Merge pull request #961 from ermshiperete/bug-xamarin-18118
[mono.git] / mcs / class / System.Core / System.Collections.Generic / HashSet.cs
index 96b407fb5a189bcfa938893d88e212ebd8a9f561..a146fd9002379272a51617de921f214143b5714f 100644 (file)
@@ -34,17 +34,24 @@ using System.Runtime.Serialization;
 using System.Runtime.InteropServices;
 using System.Security;
 using System.Security.Permissions;
+using System.Diagnostics;
 
 // HashSet is basically implemented as a reduction of Dictionary<K, V>
 
 namespace System.Collections.Generic {
 
-       [Serializable, HostProtection (SecurityAction.LinkDemand, MayLeakOnAbort = true)]
-       public class HashSet<T> : ICollection<T>, ISerializable, IDeserializationCallback {
-
+       [Serializable]
+       [DebuggerDisplay ("Count={Count}")]
+       [DebuggerTypeProxy (typeof (CollectionDebuggerView<,>))]
+       public class HashSet<T> : ICollection<T>, ISerializable, IDeserializationCallback
+#if NET_4_0
+                                                       , ISet<T>
+#endif
+       {
                const int INITIAL_SIZE = 10;
                const float DEFAULT_LOAD_FACTOR = (90f / 100);
                const int NO_SLOT = -1;
+               const int HASH_FLAG = -2147483648;
 
                struct Link {
                        public int HashCode;
@@ -105,7 +112,11 @@ namespace System.Collections.Generic {
                        if (collection == null)
                                throw new ArgumentNullException ("collection");
 
-                       int capacity = collection.Count ();
+                       int capacity = 0;
+                       var col = collection as ICollection<T>;
+                       if (col != null)
+                               capacity = col.Count;
+
                        Init (capacity, comparer);
                        foreach (var item in collection)
                                Add (item);
@@ -132,7 +143,8 @@ namespace System.Collections.Generic {
                        generation = 0;
                }
 
-               void InitArrays (int size) {
+               void InitArrays (int size)
+               {
                        table = new int [size];
 
                        links = new Link [size];
@@ -151,7 +163,7 @@ namespace System.Collections.Generic {
                        int current = table [index] - 1;
                        while (current != NO_SLOT) {
                                Link link = links [current];
-                               if (link.HashCode == hash && comparer.Equals (item, slots [current]))
+                               if (link.HashCode == hash && ((hash == HASH_FLAG && (item == null || null == slots [current])) ? (item == null && null == slots [current]) : comparer.Equals (item, slots [current])))
                                        return true;
 
                                current = link.Next;
@@ -160,29 +172,36 @@ namespace System.Collections.Generic {
                        return false;
                }
 
-               public void CopyTo (T [] array, int index)
+               public void CopyTo (T [] array)
+               {
+                       CopyTo (array, 0, count);
+               }
+               
+               public void CopyTo (T [] array, int arrayIndex)
+               {
+                       CopyTo (array, arrayIndex, count);
+               }
+
+               public void CopyTo (T [] array, int arrayIndex, int count)
                {
                        if (array == null)
                                throw new ArgumentNullException ("array");
-                       if (index < 0)
-                               throw new ArgumentOutOfRangeException ("index");
-                       if (index > array.Length)
+                       if (arrayIndex < 0)
+                               throw new ArgumentOutOfRangeException ("arrayIndex");
+                       if (arrayIndex > array.Length)
                                throw new ArgumentException ("index larger than largest valid index of array");
-                       if (array.Length - index < count)
+                       if (array.Length - arrayIndex < count)
                                throw new ArgumentException ("Destination array cannot hold the requested elements!");
 
-                       for (int i = 0; i < table.Length; i++) {
-                               int current = table [i] - 1;
-                               while (current != NO_SLOT) {
-                                       array [index++] = slots [current];
-                                       current = links [current].Next;
-                               }
+                       for (int i = 0, items = 0; i < touched && items < count; i++) {
+                               if (GetLinkHashCode (i) != 0)
+                                       array [arrayIndex++] = slots [i];
                        }
                }
 
-               void Resize ()
+               void Resize (int size)
                {
-                       int newSize = PrimeHelper.ToPrime ((table.Length << 1) | 1);
+                       int newSize = HashPrimeNumbers.ToPrime (size);
 
                        // allocate new hash table and link slots array
                        var newTable = new int [newSize];
@@ -191,7 +210,7 @@ namespace System.Collections.Generic {
                        for (int i = 0; i < table.Length; i++) {
                                int current = table [i] - 1;
                                while (current != NO_SLOT) {
-                                       int hashCode = newLinks [current].HashCode = comparer.GetHashCode (slots [current]);
+                                       int hashCode = newLinks [current].HashCode = GetItemHashCode (slots [current]);
                                        int index = (hashCode & int.MaxValue) % newSize;
                                        newLinks [current].Next = newTable [index] - 1;
                                        newTable [index] = current + 1;
@@ -210,16 +229,28 @@ namespace System.Collections.Generic {
                        threshold = (int) (newSize * DEFAULT_LOAD_FACTOR);
                }
 
+               int GetLinkHashCode (int index)
+               {
+                       return links [index].HashCode & HASH_FLAG;
+               }
+
+               int GetItemHashCode (T item)
+               {
+                       if (item == null)
+                               return HASH_FLAG;
+                       return comparer.GetHashCode (item) | HASH_FLAG;
+               }
+
                public bool Add (T item)
                {
-                       int hashCode = comparer.GetHashCode (item);
+                       int hashCode = GetItemHashCode (item);
                        int index = (hashCode & int.MaxValue) % table.Length;
 
                        if (SlotsContainsAt (index, hashCode, item))
                                return false;
 
                        if (++count > threshold) {
-                               Resize ();
+                               Resize ((table.Length << 1) | 1);
                                index = (hashCode & int.MaxValue) % table.Length;
                        }
 
@@ -252,8 +283,10 @@ namespace System.Collections.Generic {
                public void Clear ()
                {
                        count = 0;
-                       // clear the hash table
+
                        Array.Clear (table, 0, table.Length);
+                       Array.Clear (slots, 0, slots.Length);
+                       Array.Clear (links, 0, links.Length);
 
                        // empty the "empty slots chain"
                        empty_slot = NO_SLOT;
@@ -264,7 +297,7 @@ namespace System.Collections.Generic {
 
                public bool Contains (T item)
                {
-                       int hashCode = comparer.GetHashCode (item);
+                       int hashCode = GetItemHashCode (item);
                        int index = (hashCode & int.MaxValue) % table.Length;
 
                        return SlotsContainsAt (index, hashCode, item);
@@ -273,7 +306,7 @@ namespace System.Collections.Generic {
                public bool Remove (T item)
                {
                        // get first item of linked list corresponding to given key
-                       int hashCode = comparer.GetHashCode (item);
+                       int hashCode = GetItemHashCode (item);
                        int index = (hashCode & int.MaxValue) % table.Length;
                        int current = table [index] - 1;
 
@@ -286,7 +319,7 @@ namespace System.Collections.Generic {
                        int prev = NO_SLOT;
                        do {
                                Link link = links [current];
-                               if (link.HashCode == hashCode && comparer.Equals (slots [current], item))
+                               if (link.HashCode == hashCode && ((hashCode == HASH_FLAG && (item == null || null == slots [current])) ? (item == null && null == slots [current]) : comparer.Equals (slots [current], item)))
                                        break;
 
                                prev = current;
@@ -310,34 +343,35 @@ namespace System.Collections.Generic {
                        links [current].Next = empty_slot;
                        empty_slot = current;
 
+                       // clear slot
+                       links [current].HashCode = 0;
+                       slots [current] = default (T);
+
                        generation++;
 
                        return true;
                }
 
-               public int RemoveWhere (Predicate<T> predicate)
+               public int RemoveWhere (Predicate<T> match)
                {
-                       if (predicate == null)
-                               throw new ArgumentNullException ("predicate");
+                       if (match == null)
+                               throw new ArgumentNullException ("match");
 
-                       int counter = 0;
+                       var candidates = new List<T> ();
 
-                       var copy = new T [count];
-                       CopyTo (copy, 0);
+                       foreach (var item in this)
+                               if (match (item)) 
+                                       candidates.Add (item);
 
-                       foreach (var item in copy) {
-                               if (predicate (item)) {
-                                       Remove (item);
-                                       counter++;
-                               }
-                       }
+                       foreach (var item in candidates)
+                               Remove (item);
 
-                       return counter;
+                       return candidates.Count;
                }
 
                public void TrimExcess ()
                {
-                       Resize ();
+                       Resize (count);
                }
 
                // set operations
@@ -347,16 +381,9 @@ namespace System.Collections.Generic {
                        if (other == null)
                                throw new ArgumentNullException ("other");
 
-                       var copy = new T [count];
-                       CopyTo (copy, 0);
-
-                       foreach (var item in copy)
-                               if (!other.Contains (item))
-                                       Remove (item);
+                       var other_set = ToSet (other);
 
-                       foreach (var item in other)
-                               if (!Contains (item))
-                                       Remove (item);
+                       RemoveWhere (item => !other_set.Contains (item));
                }
 
                public void ExceptWith (IEnumerable<T> other)
@@ -385,11 +412,13 @@ namespace System.Collections.Generic {
                        if (other == null)
                                throw new ArgumentNullException ("other");
 
-                       if (count != other.Count ())
+                       var other_set = ToSet (other);
+
+                       if (count != other_set.Count)
                                return false;
 
                        foreach (var item in this)
-                               if (!other.Contains (item))
+                               if (!other_set.Contains (item))
                                        return false;
 
                        return true;
@@ -400,12 +429,18 @@ namespace System.Collections.Generic {
                        if (other == null)
                                throw new ArgumentNullException ("other");
 
-                       foreach (var item in other) {
-                               if (Contains (item))
+                       foreach (var item in ToSet (other))
+                               if (!Add (item))
                                        Remove (item);
-                               else
-                                       Add (item);
-                       }
+               }
+
+               HashSet<T> ToSet (IEnumerable<T> enumerable)
+               {
+                       var set = enumerable as HashSet<T>;
+                       if (set == null || !Comparer.Equals (set.Comparer))
+                               set = new HashSet<T> (enumerable, Comparer);
+
+                       return set;
                }
 
                public void UnionWith (IEnumerable<T> other)
@@ -417,7 +452,7 @@ namespace System.Collections.Generic {
                                Add (item);
                }
 
-               bool CheckIsSubsetOf (IEnumerable<T> other)
+               bool CheckIsSubsetOf (HashSet<T> other)
                {
                        if (other == null)
                                throw new ArgumentNullException ("other");
@@ -437,10 +472,12 @@ namespace System.Collections.Generic {
                        if (count == 0)
                                return true;
 
-                       if (count > other.Count ())
+                       var other_set = ToSet (other);
+
+                       if (count > other_set.Count)
                                return false;
 
-                       return CheckIsSubsetOf (other);
+                       return CheckIsSubsetOf (other_set);
                }
 
                public bool IsProperSubsetOf (IEnumerable<T> other)
@@ -451,13 +488,15 @@ namespace System.Collections.Generic {
                        if (count == 0)
                                return true;
 
-                       if (count >= other.Count ())
+                       var other_set = ToSet (other);
+
+                       if (count >= other_set.Count)
                                return false;
 
-                       return CheckIsSubsetOf (other);
+                       return CheckIsSubsetOf (other_set);
                }
 
-               bool CheckIsSupersetOf (IEnumerable<T> other)
+               bool CheckIsSupersetOf (HashSet<T> other)
                {
                        if (other == null)
                                throw new ArgumentNullException ("other");
@@ -474,10 +513,12 @@ namespace System.Collections.Generic {
                        if (other == null)
                                throw new ArgumentNullException ("other");
 
-                       if (count < other.Count ())
+                       var other_set = ToSet (other);
+
+                       if (count < other_set.Count)
                                return false;
 
-                       return CheckIsSupersetOf (other);
+                       return CheckIsSupersetOf (other_set);
                }
 
                public bool IsProperSupersetOf (IEnumerable<T> other)
@@ -485,39 +526,63 @@ namespace System.Collections.Generic {
                        if (other == null)
                                throw new ArgumentNullException ("other");
 
-                       if (count <= other.Count ())
+                       var other_set = ToSet (other);
+
+                       if (count <= other_set.Count)
                                return false;
 
-                       return CheckIsSupersetOf (other);
+                       return CheckIsSupersetOf (other_set);
                }
 
-               [MonoTODO]
                public static IEqualityComparer<HashSet<T>> CreateSetComparer ()
                {
-                       throw new NotImplementedException ();
+                       return HashSetEqualityComparer<T>.Instance;
                }
 
-               [MonoTODO]
                [SecurityPermission (SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
                public virtual void GetObjectData (SerializationInfo info, StreamingContext context)
                {
-                       throw new NotImplementedException ();
+                       if (info == null) {
+                               throw new ArgumentNullException("info");
+                       }
+                       info.AddValue("Version", generation);
+                       info.AddValue("Comparer", comparer, typeof(IEqualityComparer<T>));
+                       info.AddValue("Capacity", (table == null) ? 0 : table.Length);
+                       if (table != null) {
+                               T[] tableArray = new T[count];
+                               CopyTo(tableArray);
+                               info.AddValue("Elements", tableArray, typeof(T[]));
+                       }
                }
 
-               [MonoTODO]
                public virtual void OnDeserialization (object sender)
                {
-                       if (si == null)
-                               return;
+                       if (si != null)
+                       {
+                               generation = (int) si.GetValue("Version", typeof(int));
+                               comparer = (IEqualityComparer<T>) si.GetValue("Comparer", 
+                                                                             typeof(IEqualityComparer<T>));
+                               int capacity = (int) si.GetValue("Capacity", typeof(int));
 
-                       throw new NotImplementedException ();
-               }
+                               empty_slot = NO_SLOT;
+                               if (capacity > 0) {
+                                       InitArrays(capacity);
 
-               public IEnumerator<T> GetEnumerator ()
-               {
-                       return new Enumerator (this);
+                                       T[] tableArray = (T[]) si.GetValue("Elements", typeof(T[]));
+                                       if (tableArray == null) 
+                                               throw new SerializationException("Missing Elements");
+
+                                       for (int iElement = 0; iElement < tableArray.Length; iElement++) {
+                                               Add(tableArray[iElement]);
+                                       }
+                               } else 
+                                       table = null;
+
+                               si = null;
+                       }
                }
 
+
                IEnumerator<T> IEnumerable<T>.GetEnumerator ()
                {
                        return new Enumerator (this);
@@ -527,75 +592,76 @@ namespace System.Collections.Generic {
                        get { return false; }
                }
 
-               void ICollection<T>.CopyTo (T [] array, int index)
+               void ICollection<T>.Add (T item)
                {
-                       CopyTo (array, index);
+                       Add (item);
                }
 
-               void ICollection<T>.Add (T item)
+               IEnumerator IEnumerable.GetEnumerator ()
                {
-                       if (!Add (item))
-                               throw new ArgumentException ();
+                       return new Enumerator (this);
                }
 
-               IEnumerator IEnumerable.GetEnumerator ()
+               public Enumerator GetEnumerator ()
                {
                        return new Enumerator (this);
                }
 
-               struct Enumerator : IEnumerator<T>, IDisposable {
+               [Serializable]
+               public struct Enumerator : IEnumerator<T>, IDisposable {
 
                        HashSet<T> hashset;
-                       int index, current;
+                       int next;
                        int stamp;
 
-                       public Enumerator (HashSet<T> hashset)
+                       T current;
+
+                       internal Enumerator (HashSet<T> hashset)
+                               : this ()
                        {
                                this.hashset = hashset;
                                this.stamp = hashset.generation;
-
-                               index = -1;
-                               current = NO_SLOT;
                        }
 
                        public bool MoveNext ()
                        {
                                CheckState ();
 
-                               do {
-                                       if (current != NO_SLOT) {
-                                               current = hashset.links [current].Next;
-                                               continue;
-                                       }
-
-                                       if (index + 1 >= hashset.table.Length)
-                                               return false;
+                               if (next < 0)
+                                       return false;
 
-                                       current = hashset.table [++index] - 1;;
-                               } while (current == NO_SLOT);
+                               while (next < hashset.touched) {
+                                       int cur = next++;
+                                       if (hashset.GetLinkHashCode (cur) != 0) {
+                                               current = hashset.slots [cur];
+                                               return true;
+                                       }
+                               }
 
-                               return true;
+                               next = NO_SLOT;
+                               return false;
                        }
 
                        public T Current {
-                               get {
-                                       CheckCurrent ();
-
-                                       return hashset.slots [current];
-                               }
+                               get { return current; }
                        }
 
                        object IEnumerator.Current {
-                               get { return this.Current; }
+                               get {
+                                       CheckState ();
+                                       if (next <= 0)
+                                               throw new InvalidOperationException ("Current is not valid");
+                                       return current;
+                               }
                        }
 
                        void IEnumerator.Reset ()
                        {
-                               index = -1;
-                               current = NO_SLOT;
+                               CheckState ();
+                               next = 0;
                        }
 
-                       void IDisposable.Dispose ()
+                       public void Dispose ()
                        {
                                hashset = null;
                        }
@@ -607,90 +673,39 @@ namespace System.Collections.Generic {
                                if (hashset.generation != stamp)
                                        throw new InvalidOperationException ("HashSet have been modified while it was iterated over");
                        }
-
-                       void CheckCurrent ()
-                       {
-                               CheckState ();
-
-                               if (current == NO_SLOT)
-                                       throw new InvalidOperationException ("Current is not valid");
-                       }
                }
+       }
+       
+       sealed class HashSetEqualityComparer<T> : IEqualityComparer<HashSet<T>>
+       {
+               public static readonly HashSetEqualityComparer<T> Instance = new HashSetEqualityComparer<T> ();
+                       
+               public bool Equals (HashSet<T> lhs, HashSet<T> rhs)
+               {
+                       if (lhs == rhs)
+                               return true;
 
-               // borrowed from System.Collections.HashTable
-               static class PrimeHelper {
-
-                       static readonly int [] primes_table = {
-                               11,
-                               19,
-                               37,
-                               73,
-                               109,
-                               163,
-                               251,
-                               367,
-                               557,
-                               823,
-                               1237,
-                               1861,
-                               2777,
-                               4177,
-                               6247,
-                               9371,
-                               14057,
-                               21089,
-                               31627,
-                               47431,
-                               71143,
-                               106721,
-                               160073,
-                               240101,
-                               360163,
-                               540217,
-                               810343,
-                               1215497,
-                               1823231,
-                               2734867,
-                               4102283,
-                               6153409,
-                               9230113,
-                               13845163
-                       };
-
-                       static bool TestPrime (int x)
-                       {
-                               if ((x & 1) != 0) {
-                                       int top = (int) Math.Sqrt (x);
-
-                                       for (int n = 3; n < top; n += 2) {
-                                               if ((x % n) == 0)
-                                                       return false;
-                                       }
-
-                                       return true;
-                               }
+                       if (lhs == null || rhs == null || lhs.Count != rhs.Count)
+                               return false;
 
-                               // There is only one even prime - 2.
-                               return x == 2;
-                       }
+                       foreach (var item in lhs)
+                               if (!rhs.Contains (item))
+                                       return false;
 
-                       static int CalcPrime (int x)
-                       {
-                               for (int i = (x & (~1)) - 1; i < Int32.MaxValue; i += 2)
-                                       if (TestPrime (i))
-                                               return i;
+                       return true;
+               }
 
-                               return x;
-                       }
+               public int GetHashCode (HashSet<T> hashset)
+               {
+                       if (hashset == null)
+                               return 0;
 
-                       public static int ToPrime (int x)
-                       {
-                               for (int i = 0; i < primes_table.Length; i++)
-                                       if (x <= primes_table [i])
-                                               return primes_table [i];
+                       IEqualityComparer<T> comparer = EqualityComparer<T>.Default;
+                       int hash = 0;
+                       foreach (var item in hashset)
+                               hash ^= comparer.GetHashCode (item);
 
-                               return CalcPrime (x);
-                       }
+                       return hash;
                }
        }
 }