Fix for #475124. DowngradeFromWriterLock only resets to a reader lock state
[mono.git] / mcs / class / corlib / Mono.Math / BigInteger.cs
index 20c6d0eb11fe43a553b39cd82d3ba897da3d6b45..0356d0a166487af7ffae6befcd7c59b362d180d0 100644 (file)
@@ -13,7 +13,7 @@
 // Copyright (c) 2002 Chew Keong TAN
 // All rights reserved.
 //
-// Copyright (C) 2004 Novell, Inc (http://www.novell.com)
+// Copyright (C) 2004, 2007 Novell, Inc (http://www.novell.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -208,6 +208,8 @@ namespace Mono.Math {
                
                public BigInteger (byte [] inData)
                {
+                       if (inData.Length == 0)
+                               inData = new byte [1];
                        length = (uint)inData.Length >> 2;
                        int leftOver = inData.Length & 0x3;
 
@@ -239,6 +241,8 @@ namespace Mono.Math {
 #endif 
                public BigInteger (uint [] inData)
                {
+                       if (inData.Length == 0)
+                               inData = new uint [1];
                        length = (uint)inData.Length;
 
                        data = new uint [length];
@@ -862,10 +866,16 @@ namespace Mono.Math {
 
                public override bool Equals (object o)
                {
-                       if (o == null) return false;
-                       if (o is int) return (int)o >= 0 && this == (uint)o;
+                       if (o == null)
+                               return false;
+                       if (o is int)
+                               return (int)o >= 0 && this == (uint)o;
 
-                       return Kernel.Compare (this, (BigInteger)o) == 0;
+                       BigInteger bi = o as BigInteger;
+                       if (bi == null)
+                               return false;
+                       
+                       return Kernel.Compare (this, bi) == 0;
                }
 
                #endregion
@@ -894,19 +904,23 @@ namespace Mono.Math {
 
                public bool IsProbablePrime ()
                {
-                       if (this < smallPrimes [smallPrimes.Length - 1]) {
+                       // can we use our small-prime table ?
+                       if (this <= smallPrimes[smallPrimes.Length - 1]) {
                                for (int p = 0; p < smallPrimes.Length; p++) {
-                                       if (this == smallPrimes [p])
+                                       if (this == smallPrimes[p])
                                                return true;
                                }
+                               // the list is complete, so it's not a prime
+                               return false;
                        }
-                       else {
-                               for (int p = 0; p < smallPrimes.Length; p++) {
-                                       if (this % smallPrimes [p] == 0)
-                                               return false;
-                               }
+
+                       // otherwise check if we can divide by one of the small primes
+                       for (int p = 0; p < smallPrimes.Length; p++) {
+                               if (this % smallPrimes[p] == 0)
+                                       return false;
                        }
-                       return PrimalityTests.RabinMillerTest (this, Prime.ConfidenceFactor.Medium);
+                       // the last step is to confirm the "large" prime with the SPP or Miller-Rabin test
+                       return PrimalityTests.Test (this, Prime.ConfidenceFactor.Medium);
                }
 
                #endregion
@@ -1038,19 +1052,13 @@ namespace Mono.Math {
                        {
                                if (a == 0 || b == 0) return 0;
 
-                               if (a.length >= mod.length << 1)
+                               if (a > mod)
                                        a %= mod;
 
-                               if (b.length >= mod.length << 1)
+                               if (b > mod)
                                        b %= mod;
 
-                               if (a.length >= mod.length)
-                                       BarrettReduction (a);
-
-                               if (b.length >= mod.length)
-                                       BarrettReduction (b);
-
-                               BigInteger ret = new BigInteger (a * b);
+                               BigInteger ret = a * b;
                                BarrettReduction (ret);
 
                                return ret;
@@ -1082,7 +1090,25 @@ namespace Mono.Math {
                                        diff = mod - diff;
                                return diff;
                        }
-
+#if true
+                       public BigInteger Pow (BigInteger a, BigInteger k)
+                       {
+                               BigInteger b = new BigInteger (1);
+                               if (k == 0)
+                                       return b;
+
+                               BigInteger A = a;
+                               if (k.TestBit (0))
+                                       b = a;
+
+                               for (int i = 1; i < k.BitCount (); i++) {
+                                       A = Multiply (A, A);
+                                       if (k.TestBit (i))
+                                               b = Multiply (A, b);
+                               }
+                               return b;
+                       }
+#else
                        public BigInteger Pow (BigInteger b, BigInteger exp)
                        {
                                if ((mod.data [0] & 1) == 1) return OddPow (b, exp);
@@ -1146,14 +1172,17 @@ namespace Mono.Math {
                                                Montgomery.Reduce (resultNum, mod, mPrime);
                                        }
 
-                                       Kernel.SquarePositive (tempNum, ref wkspace);
-                                       Montgomery.Reduce (tempNum, mod, mPrime);
+                                       // the value of tempNum is required in the last loop
+                                       if (pos < totalBits - 1) {
+                                               Kernel.SquarePositive (tempNum, ref wkspace);
+                                               Montgomery.Reduce (tempNum, mod, mPrime);
+                                       }
                                }
 
                                Montgomery.Reduce (resultNum, mod, mPrime);
                                return resultNum;
                        }
-
+#endif
                        #region Pow Small Base
 
                        // TODO: Make tests for this, not really needed b/c prime stuff
@@ -1161,6 +1190,12 @@ namespace Mono.Math {
 #if !INSIDE_CORLIB
                         [CLSCompliant (false)]
 #endif 
+#if true
+                       public BigInteger Pow (uint b, BigInteger exp)
+                       {
+                               return Pow (new BigInteger (b), exp);
+                       }
+#else
                        public BigInteger Pow (uint b, BigInteger exp)
                        {
 //                             if (b != 2) {
@@ -1168,7 +1203,7 @@ namespace Mono.Math {
                                                return OddPow (b, exp);
                                        else
                                                return EvenPow (b, exp);
-/* buggy in some cases (like the well tested primes)
+/* buggy in some cases (like the well tested primes) 
                                } else {
                                        if ((mod.data [0] & 1) == 1)
                                                return OddModTwoPow (exp);
@@ -1187,7 +1222,8 @@ namespace Mono.Math {
 
                                uint mPrime = Montgomery.Inverse (mod.data [0]);
 
-                               uint pos = (uint)exp.BitCount () - 2;
+                               int bc = exp.BitCount () - 2;
+                               uint pos = (bc > 1 ? (uint) bc : 1);
 
                                //
                                // We know that the first itr will make the val b
@@ -1387,8 +1423,9 @@ namespace Mono.Math {
 
                                return resultNum;
                        }
-
-/* known to be buggy in some cases
+#endif
+/* known to be buggy in some cases */
+#if false
                        private unsafe BigInteger EvenModTwoPow (BigInteger exp)
                        {
                                exp.Normalize ();
@@ -1521,94 +1558,8 @@ namespace Mono.Math {
                                resultNum = Montgomery.Reduce (resultNum, mod, mPrime);
                                return resultNum;
                        }
-*/                     
-                       #endregion
-               }
-
-               internal sealed class Montgomery {
-
-                       private Montgomery () 
-                       {
-                       }
-
-                       public static uint Inverse (uint n)
-                       {
-                               uint y = n, z;
-
-                               while ((z = n * y) != 1)
-                                       y *= 2 - z;
-
-                               return (uint)-y;
-                       }
-
-                       public static BigInteger ToMont (BigInteger n, BigInteger m)
-                       {
-                               n.Normalize (); m.Normalize ();
-
-                               n <<= (int)m.length * 32;
-                               n %= m;
-                               return n;
-                       }
-
-                       public static unsafe BigInteger Reduce (BigInteger n, BigInteger m, uint mPrime)
-                       {
-                               BigInteger A = n;
-                               fixed (uint* a = A.data, mm = m.data) {
-                                       for (uint i = 0; i < m.length; i++) {
-                                               // The mod here is taken care of by the CPU,
-                                               // since the multiply will overflow.
-                                               uint u_i = a [0] * mPrime /* % 2^32 */;
-
-                                               //
-                                               // A += u_i * m;
-                                               // A >>= 32
-                                               //
-
-                                               // mP = Position in mod
-                                               // aSP = the source of bits from a
-                                               // aDP = destination for bits
-                                               uint* mP = mm, aSP = a, aDP = a;
-
-                                               ulong c = (ulong)u_i * ((ulong)*(mP++)) + *(aSP++);
-                                               c >>= 32;
-                                               uint j = 1;
-
-                                               // Multiply and add
-                                               for (; j < m.length; j++) {
-                                                       c += (ulong)u_i * (ulong)*(mP++) + *(aSP++);
-                                                       *(aDP++) = (uint)c;
-                                                       c >>= 32;
-                                               }
-
-                                               // Account for carry
-                                               // TODO: use a better loop here, we dont need the ulong stuff
-                                               for (; j < A.length; j++) {
-                                                       c += *(aSP++);
-                                                       *(aDP++) = (uint)c;
-                                                       c >>= 32;
-                                                       if (c == 0) {j++; break;}
-                                               }
-                                               // Copy the rest
-                                               for (; j < A.length; j++) {
-                                                       *(aDP++) = *(aSP++);
-                                               }
-
-                                               *(aDP++) = (uint)c;
-                                       }
-
-                                       while (A.length > 1 && a [A.length-1] == 0) A.length--;
-
-                               }
-                               if (A >= m) Kernel.MinusEq (A, m);
-
-                               return A;
-                       }
-#if _NOT_USED_
-                       public static BigInteger Reduce (BigInteger n, BigInteger m)
-                       {
-                               return Reduce (n, m, Inverse (m.data [0]));
-                       }
 #endif
+                       #endregion
                }
 
                /// <summary>