Merge pull request #409 from Alkarex/patch-1
[mono.git] / mcs / class / corlib / System.Collections.Concurrent / ConcurrentDictionary.cs
index e4af46c1e9b9630c50d454eadbd215e5ad24b231..2bfb952af7d33c0aaeb4fe7c33e2246e4b819528 100644 (file)
@@ -1,5 +1,4 @@
-#if NET_4_0
-// ConcurrentSkipList.cs
+// ConcurrentDictionary.cs
 //
 // Copyright (c) 2009 Jérémie "Garuma" Laval
 //
 //
 //
 
+#if NET_4_0 || MOBILE
+
 using System;
 using System.Threading;
 using System.Collections;
 using System.Collections.Generic;
 using System.Runtime.Serialization;
+using System.Diagnostics;
 
 namespace System.Collections.Concurrent
 {
+       [DebuggerDisplay ("Count={Count}")]
+       [DebuggerTypeProxy (typeof (CollectionDebuggerView<,>))]
        public class ConcurrentDictionary<TKey, TValue> : IDictionary<TKey, TValue>,
          ICollection<KeyValuePair<TKey, TValue>>, IEnumerable<KeyValuePair<TKey, TValue>>,
          IDictionary, ICollection, IEnumerable
        {
-               class Pair
-               {
-                       public readonly TKey Key;
-                       public TValue Value;
-
-                       public Pair (TKey key, TValue value)
-                       {
-                               Key = key;
-                               Value = value;
-                       }
-
-                       public override bool Equals (object obj)
-                       {
-                               Pair rhs = obj as Pair;
-                               return rhs == null ? false : Key.Equals (rhs.Key) && Value.Equals (rhs.Value);
-                       }
-
-                       public override int GetHashCode ()
-                       {
-                               return Key.GetHashCode ();
-                       }
-               }
-
-               class Basket: List<Pair>
-               {
-                       public SpinLock Lock = new SpinLock ();
-               }
-
-               // Assumption: a List<T> is never empty
-               ConcurrentSkipList<Basket> container
-                       = new ConcurrentSkipList<Basket> ((value) => value[0].GetHashCode ());
-               int count;
                IEqualityComparer<TKey> comparer;
 
+               SplitOrderedList<TKey, KeyValuePair<TKey, TValue>> internalDictionary;
+
                public ConcurrentDictionary () : this (EqualityComparer<TKey>.Default)
                {
                }
 
-               public ConcurrentDictionary (IEnumerable<KeyValuePair<TKey, TValue>> values)
-                       : this (values, EqualityComparer<TKey>.Default)
+               public ConcurrentDictionary (IEnumerable<KeyValuePair<TKey, TValue>> collection)
+                       : this (collection, EqualityComparer<TKey>.Default)
                {
-                       foreach (KeyValuePair<TKey, TValue> pair in values)
-                               Add (pair.Key, pair.Value);
                }
 
                public ConcurrentDictionary (IEqualityComparer<TKey> comparer)
                {
                        this.comparer = comparer;
+                       this.internalDictionary = new SplitOrderedList<TKey, KeyValuePair<TKey, TValue>> (comparer);
                }
 
-               public ConcurrentDictionary (IEnumerable<KeyValuePair<TKey, TValue>> values, IEqualityComparer<TKey> comparer)
+               public ConcurrentDictionary (IEnumerable<KeyValuePair<TKey, TValue>> collection, IEqualityComparer<TKey> comparer)
                        : this (comparer)
                {
-                       foreach (KeyValuePair<TKey, TValue> pair in values)
+                       foreach (KeyValuePair<TKey, TValue> pair in collection)
                                Add (pair.Key, pair.Value);
                }
 
@@ -100,9 +73,9 @@ namespace System.Collections.Concurrent
                }
 
                public ConcurrentDictionary (int concurrencyLevel,
-                                            IEnumerable<KeyValuePair<TKey, TValue>> values,
+                                            IEnumerable<KeyValuePair<TKey, TValue>> collection,
                                             IEqualityComparer<TKey> comparer)
-                       : this (values, comparer)
+                       : this (collection, comparer)
                {
 
                }
@@ -114,6 +87,12 @@ namespace System.Collections.Concurrent
 
                }
 
+               void CheckKey (TKey key)
+               {
+                       if (key == null)
+                               throw new ArgumentNullException ("key");
+               }
+
                void Add (TKey key, TValue value)
                {
                        while (!TryAdd (key, value));
@@ -126,42 +105,8 @@ namespace System.Collections.Concurrent
 
                public bool TryAdd (TKey key, TValue value)
                {
-                       Basket basket;
-                       bool taken = false;
-
-                       // Add a value to an existing basket
-                       if (TryGetBasket (key, out basket)) {
-                               while (!taken) {
-                                       try {
-                                               basket.Lock.Enter (ref taken);
-                                               if (!taken)
-                                                       continue;
-
-                                               foreach (var p in basket) {
-                                                       if (comparer.Equals (p.Key, key))
-                                                               return false;
-                                               }
-                                               basket.Add (new Pair (key, value));
-                                       } finally {
-                                               basket.Lock.Exit ();
-                                       }
-                               }
-                       } else {
-                               // Add a new basket
-                               basket = new Basket ();
-                               basket.Add (new Pair (key, value));
-
-                               if (container.TryAdd (basket)) {
-                                       Interlocked.Increment (ref count);
-                                       return true;
-                               } else {
-                                       return false;
-                               }
-                       }
-
-                       Interlocked.Increment (ref count);
-
-                       return true;
+                       CheckKey (key);
+                       return internalDictionary.Insert (Hash (key), key, Make (key, value));
                }
 
                void ICollection<KeyValuePair<TKey,TValue>>.Add (KeyValuePair<TKey, TValue> pair)
@@ -171,30 +116,15 @@ namespace System.Collections.Concurrent
 
                public TValue AddOrUpdate (TKey key, Func<TKey, TValue> addValueFactory, Func<TKey, TValue, TValue> updateValueFactory)
                {
-                       Basket basket;
-                       TValue temp = default (TValue);
-                       bool taken = false;
-
-                       if (!TryGetBasket (key, out basket)) {
-                               Add (key, (temp = addValueFactory (key)));
-                       } else {
-                               while (!taken) {
-                                       try {
-                                               basket.Lock.Enter (ref taken);
-                                               if (!taken)
-                                                       continue;
-
-                                               Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
-                                               if (pair == null)
-                                                       throw new InvalidOperationException ("pair is null, shouldn't be");
-                                               pair.Value = (temp = updateValueFactory (key, pair.Value));
-                                       } finally {
-                                               basket.Lock.Exit ();
-                                       }
-                               }
-                       }
-
-                       return temp;
+                       CheckKey (key);
+                       if (addValueFactory == null)
+                               throw new ArgumentNullException ("addValueFactory");
+                       if (updateValueFactory == null)
+                               throw new ArgumentNullException ("updateValueFactory");
+                       return internalDictionary.InsertOrUpdate (Hash (key),
+                                                                 key,
+                                                                 () => Make (key, addValueFactory (key)),
+                                                                 (e) => Make (key, updateValueFactory (key, e.Value))).Value;
                }
 
                public TValue AddOrUpdate (TKey key, TValue addValue, Func<TKey, TValue, TValue> updateValueFactory)
@@ -202,68 +132,37 @@ namespace System.Collections.Concurrent
                        return AddOrUpdate (key, (_) => addValue, updateValueFactory);
                }
 
+               TValue AddOrUpdate (TKey key, TValue addValue, TValue updateValue)
+               {
+                       CheckKey (key);
+                       return internalDictionary.InsertOrUpdate (Hash (key),
+                                                                 key,
+                                                                 Make (key, addValue),
+                                                                 Make (key, updateValue)).Value;
+               }
+
                TValue GetValue (TKey key)
                {
                        TValue temp;
                        if (!TryGetValue (key, out temp))
-                               // TODO: find a correct Exception
-                               throw new ArgumentException ("Not a valid key for this dictionary", "key");
+                               throw new KeyNotFoundException (key.ToString ());
                        return temp;
                }
 
                public bool TryGetValue (TKey key, out TValue value)
                {
-                       Basket basket;
-                       value = default (TValue);
-                       bool taken = false;
+                       CheckKey (key);
+                       KeyValuePair<TKey, TValue> pair;
+                       bool result = internalDictionary.Find (Hash (key), key, out pair);
+                       value = pair.Value;
 
-                       if (!TryGetBasket (key, out basket))
-                               return false;
-
-                       while (!taken) {
-                               try {
-                                       basket.Lock.Enter (ref taken);
-                                       if (!taken)
-                                               continue;
-
-                                       Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
-                                       if (pair == null)
-                                               return false;
-                                       value = pair.Value;
-                               } finally {
-                                       basket.Lock.Exit ();
-                               }
-                       }
-
-                       return true;
+                       return result;
                }
 
-               public bool TryUpdate (TKey key, TValue newValue, TValue comparand)
+               public bool TryUpdate (TKey key, TValue newValue, TValue comparisonValue)
                {
-                       Basket basket;
-                       bool taken = false;
-
-                       if (!TryGetBasket (key, out basket))
-                               return false;
-
-                       while (!taken) {
-                               try {
-                                       basket.Lock.Enter (ref taken);
-                                       if (!taken)
-                                               continue;
-
-                                       Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
-                                       if (pair.Value.Equals (comparand)) {
-                                               pair.Value = newValue;
-
-                                               return true;
-                                       }
-                               } finally {
-                                       basket.Lock.Exit ();
-                               }
-                       }
-
-                       return false;
+                       CheckKey (key);
+                       return internalDictionary.CompareExchange (Hash (key), key, Make (key, newValue), (e) => e.Value.Equals (comparisonValue));
                }
 
                public TValue this[TKey key] {
@@ -271,101 +170,29 @@ namespace System.Collections.Concurrent
                                return GetValue (key);
                        }
                        set {
-                               Basket basket;
-                               bool taken = false;
-
-                               if (!TryGetBasket (key, out basket)) {
-                                       Add (key, value);
-                                       return;
-                               }
-
-                               while (!taken) {
-                                       try {
-                                               basket.Lock.Enter (ref taken);
-                                               if (!taken)
-                                                       continue;
-
-                                               Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
-                                               if (pair == null)
-                                                       throw new InvalidOperationException ("pair is null, shouldn't be");
-                                               pair.Value = value;
-                                       } finally {
-                                               basket.Lock.Exit ();
-                                       }
-                               }
+                               AddOrUpdate (key, value, value);
                        }
                }
 
                public TValue GetOrAdd (TKey key, Func<TKey, TValue> valueFactory)
                {
-                       Basket basket;
-                       TValue temp = default (TValue);
-
-                       if (TryGetBasket (key, out basket)) {
-                               Pair pair = null;
-                               bool taken = false;
-
-                               while (!taken) {
-                                       try {
-                                               basket.Lock.Enter (ref taken);
-                                               if (!taken)
-                                                       continue;
-                                               pair = basket.Find ((p) => comparer.Equals (p.Key, key));
-                                               if (pair != null)
-                                                       temp = pair.Value;
-                                       } finally {
-                                               basket.Lock.Exit ();
-                                       }
-                               }
-
-                               if (pair == null)
-                                       Add (key, (temp = valueFactory (key)));
-                       } else {
-                               Add (key, (temp = valueFactory (key)));
-                       }
-
-                       return temp;
+                       CheckKey (key);
+                       return internalDictionary.InsertOrGet (Hash (key), key, Make (key, default(TValue)), () => Make (key, valueFactory (key))).Value;
                }
 
                public TValue GetOrAdd (TKey key, TValue value)
                {
-                       return GetOrAdd (key, (_) => value);
+                       CheckKey (key);
+                       return internalDictionary.InsertOrGet (Hash (key), key, Make (key, value), null).Value;
                }
 
-               public bool TryRemove(TKey key, out TValue value)
+               public bool TryRemove (TKey key, out TValue value)
                {
-                       value = default (TValue);
-                       Basket b;
-                       bool taken = false;
-
-                       if (!TryGetBasket (key, out b))
-                               return false;
-
-                       while (!taken) {
-                               try {
-                                       b.Lock.Enter (ref taken);
-                                       if (!taken)
-                                               continue;
-
-                                       TValue temp = default (TValue);
-                                       // Should always be == 1 but who know
-                                       bool result = b.RemoveAll ((p) => {
-                                               bool r = comparer.Equals (p.Key, key);
-                                               if (r) temp = p.Value;
-                                               return r;
-                                       }) >= 1;
-                                       value = temp;
-
-                                       if (result)
-                                               Interlocked.Decrement (ref count);
-
-                                       return result;
-                               } finally {
-                                       b.Lock.Exit ();
-                               }
-                       }
-
-                       return false;
+                       CheckKey (key);
+                       KeyValuePair<TKey, TValue> data;
+                       bool result = internalDictionary.Delete (Hash (key), key, out data);
+                       value = data.Value;
+                       return result;
                }
 
                bool Remove (TKey key)
@@ -387,7 +214,9 @@ namespace System.Collections.Concurrent
 
                public bool ContainsKey (TKey key)
                {
-                       return container.ContainsFromHash (key.GetHashCode ());
+                       CheckKey (key);
+                       KeyValuePair<TKey, TValue> dummy;
+                       return internalDictionary.Find (Hash (key), key, out dummy);
                }
 
                bool IDictionary.Contains (object key)
@@ -446,18 +275,18 @@ namespace System.Collections.Concurrent
                public void Clear()
                {
                        // Pronk
-                       container = new ConcurrentSkipList<Basket> ((value) => value [0].GetHashCode ());
+                       internalDictionary = new SplitOrderedList<TKey, KeyValuePair<TKey, TValue>> (comparer);
                }
 
                public int Count {
                        get {
-                               return count;
+                               return internalDictionary.Count;
                        }
                }
 
                public bool IsEmpty {
                        get {
-                               return count == 0;
+                               return Count == 0;
                        }
                }
 
@@ -513,12 +342,12 @@ namespace System.Collections.Concurrent
                        if (arr == null)
                                return;
 
-                       CopyTo (arr, startIndex, count);
+                       CopyTo (arr, startIndex, Count);
                }
 
                void CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex)
                {
-                       CopyTo (array, startIndex, count);
+                       CopyTo (array, startIndex, Count);
                }
 
                void ICollection<KeyValuePair<TKey, TValue>>.CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex)
@@ -528,31 +357,11 @@ namespace System.Collections.Concurrent
 
                void CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex, int num)
                {
-                       // TODO: This is quite unsafe as the count value will likely change during
-                       // the copying. Watchout for IndexOutOfRange thingies
-                       if (array.Length <= count + startIndex)
-                               throw new InvalidOperationException ("The array isn't big enough");
-
-                       int i = startIndex;
-
-                       foreach (Basket b in container) {
-                               bool taken = false;
-
-                               while (!taken) {
-                                       try {
-                                               b.Lock.Enter (ref taken);
-                                               if (!taken)
-                                                       continue;
-
-                                               foreach (Pair p in b) {
-                                                       if (i >= num)
-                                                               break;
-                                                       array[i++] = new KeyValuePair<TKey, TValue> (p.Key, p.Value);
-                                               }
-                                       } finally {
-                                               b.Lock.Exit ();
-                                       }
-                               }
+                       foreach (var kvp in this) {
+                               array [startIndex++] = kvp;
+
+                               if (--num <= 0)
+                                       return;
                        }
                }
 
@@ -568,22 +377,7 @@ namespace System.Collections.Concurrent
 
                IEnumerator<KeyValuePair<TKey, TValue>> GetEnumeratorInternal ()
                {
-                       foreach (Basket b in container) {
-                               bool taken = false;
-
-                               while (!taken) {
-                                       try {
-                                               b.Lock.Enter (ref taken);
-                                               if (!taken)
-                                                       continue;
-
-                                               foreach (Pair p in b)
-                                                       yield return new KeyValuePair<TKey, TValue> (p.Key, p.Value);
-                                       } finally {
-                                               b.Lock.Exit ();
-                                       }
-                               }
-                       }
+                       return internalDictionary.GetEnumerator ();
                }
 
                IDictionaryEnumerator IDictionary.GetEnumerator ()
@@ -642,7 +436,6 @@ namespace System.Collections.Concurrent
                        }
                }
 
-
                bool IDictionary.IsFixedSize {
                        get {
                                return false;
@@ -653,13 +446,14 @@ namespace System.Collections.Concurrent
                        get { return true; }
                }
 
-               bool TryGetBasket (TKey key, out Basket basket)
+               static KeyValuePair<U, V> Make<U, V> (U key, V value)
                {
-                       basket = null;
-                       if (!container.GetFromHash (key.GetHashCode (), out basket))
-                               return false;
+                       return new KeyValuePair<U, V> (key, value);
+               }
 
-                       return true;
+               uint Hash (TKey key)
+               {
+                       return (uint)comparer.GetHashCode (key);
                }
        }
 }