using System.Threading;
using System.Collections;
using System.Collections.Generic;
+using System.Runtime.Serialization;
namespace System.Collections.Concurrent
{
- public class ConcurrentDictionary<TKey, TValue> : IDictionary<TKey, TValue>
- , ICollection<KeyValuePair<TKey, TValue>>, ICollection, IEnumerable<KeyValuePair<TKey, TValue>>, IEnumerable
- , IProducerConsumerCollection<KeyValuePair<TKey, TValue>>
-
+ public class ConcurrentDictionary<TKey, TValue> : IDictionary<TKey, TValue>,
+ ICollection<KeyValuePair<TKey, TValue>>, IEnumerable<KeyValuePair<TKey, TValue>>,
+ IDictionary, ICollection, IEnumerable, ISerializable, IDeserializationCallback
{
class Pair
{
// Assumption: a List<T> is never empty
ConcurrentSkipList<Basket> container
- = new ConcurrentSkipList<Basket> ((value) => value [0].GetHashCode ());
+ = new ConcurrentSkipList<Basket> ((value) => value[0].GetHashCode ());
int count;
int stamp = int.MinValue;
+ IEqualityComparer<TKey> comparer;
+
+ public ConcurrentDictionary () : this (EqualityComparer<TKey>.Default)
+ {
+ }
+
+ public ConcurrentDictionary (IEnumerable<KeyValuePair<TKey, TValue>> values)
+ : this (values, EqualityComparer<TKey>.Default)
+ {
+ foreach (KeyValuePair<TKey, TValue> pair in values)
+ Add (pair.Key, pair.Value);
+ }
+
+ public ConcurrentDictionary (IEqualityComparer<TKey> comparer)
+ {
+ this.comparer = comparer;
+ }
+
+ public ConcurrentDictionary (IEnumerable<KeyValuePair<TKey, TValue>> values, IEqualityComparer<TKey> comparer)
+ : this (comparer)
+ {
+ foreach (KeyValuePair<TKey, TValue> pair in values)
+ Add (pair.Key, pair.Value);
+ }
+
+ // Parameters unused
+ public ConcurrentDictionary (int concurrencyLevel, int capacity)
+ : this (EqualityComparer<TKey>.Default)
+ {
+
+ }
+
+ public ConcurrentDictionary (int concurrencyLevel,
+ IEnumerable<KeyValuePair<TKey, TValue>> values,
+ IEqualityComparer<TKey> comparer)
+ : this (values, comparer)
+ {
+
+ }
+
+ // Parameters unused
+ public ConcurrentDictionary (int concurrencyLevel, int capacity, IEqualityComparer<TKey> comparer)
+ : this (comparer)
+ {
+
+ }
- public ConcurrentDictionary ()
+ internal ConcurrentDictionary (SerializationInfo info, StreamingContext context)
{
+ throw new NotImplementedException ();
}
- public void Add (TKey key, TValue value)
+ void Add (TKey key, TValue value)
{
while (!TryAdd (key, value));
}
+ void IDictionary<TKey, TValue>.Add (TKey key, TValue value)
+ {
+ Add (key, value);
+ }
+
public bool TryAdd (TKey key, TValue value)
{
Interlocked.Increment (ref count);
// Find a maybe more sexy locking scheme later
lock (basket) {
foreach (var p in basket) {
- if (p.Key.Equals (key))
+ if (comparer.Equals (p.Key, key))
throw new ArgumentException ("An element with the same key already exists");
}
basket.Add (new Pair (key, value));
Add (pair.Key, pair.Value);
}
- public TValue GetValue (TKey key)
+ TValue GetValue (TKey key)
{
TValue temp;
if (!TryGetValue (key, out temp))
return false;
lock (basket) {
- Pair pair = basket.Find ((p) => p.Key.Equals (key));
+ Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
if (pair == null)
return false;
value = pair.Value;
return true;
}
+ public bool TryUpdate (TKey key, TValue newValue, TValue comparand)
+ {
+ Basket basket;
+ if (!TryGetBasket (key, out basket))
+ return false;
+
+ lock (basket) {
+ Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
+ if (pair.Value.Equals (comparand)) {
+ pair.Value = newValue;
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
public TValue this[TKey key] {
get {
return GetValue (key);
return;
}
lock (basket) {
- Pair pair = basket.Find ((p) => p.Key.Equals (key));
+ Pair pair = basket.Find ((p) => comparer.Equals (p.Key, key));
if (pair == null)
throw new InvalidOperationException ("pair is null, shouldn't be");
pair.Value = value;
}
}
- public bool Remove(TKey key)
+ public bool TryRemove(TKey key, out TValue value)
{
+ value = default (TValue);
Basket b;
+
if (!TryGetBasket (key, out b))
return false;
lock (b) {
+ TValue temp = default (TValue);
// Should always be == 1 but who know
- bool result = b.RemoveAll ((p) => p.Key.Equals (key)) >= 1;
+ bool result = b.RemoveAll ((p) => {
+ bool r = comparer.Equals (p.Key, key);
+ if (r) temp = p.Value;
+ return r;
+ }) >= 1;
+ value = temp;
+
if (result)
Interlocked.Decrement (ref count);
}
}
+ bool Remove (TKey key)
+ {
+ TValue dummy;
+
+ return TryRemove (key, out dummy);
+ }
+
+ bool IDictionary<TKey, TValue>.Remove (TKey key)
+ {
+ return Remove (key);
+ }
+
bool ICollection<KeyValuePair<TKey,TValue>>.Remove (KeyValuePair<TKey,TValue> pair)
{
return Remove (pair.Key);
return container.ContainsFromHash (key.GetHashCode ());
}
- bool ICollection<KeyValuePair<TKey,TValue>>.Contains (KeyValuePair<TKey, TValue> pair)
+ bool IDictionary.Contains (object key)
{
- return ContainsKey (pair.Key);
+ if (!(key is TKey))
+ return false;
+
+ return ContainsKey ((TKey)key);
}
- public bool TryAdd (KeyValuePair<TKey, TValue> item)
+ void IDictionary.Remove (object key)
{
- Add (item.Key, item.Value);
+ if (!(key is TKey))
+ return;
- return true;
+ Remove ((TKey)key);
}
- public bool TryTake (out KeyValuePair<TKey, TValue> item)
+ object IDictionary.this [object key]
{
- item = default (KeyValuePair<TKey, TValue>);
- return false;
+ get {
+ if (!(key is TKey))
+ throw new ArgumentException ("key isn't of correct type", "key");
+
+ return this[(TKey)key];
+ }
+ set {
+ if (!(key is TKey) || !(value is TValue))
+ throw new ArgumentException ("key or value aren't of correct type");
+
+ this[(TKey)key] = (TValue)value;
+ }
+ }
+
+ void IDictionary.Add (object key, object value)
+ {
+ if (!(key is TKey) || !(value is TValue))
+ throw new ArgumentException ("key or value aren't of correct type");
+
+ Add ((TKey)key, (TValue)value);
+ }
+
+ bool ICollection<KeyValuePair<TKey,TValue>>.Contains (KeyValuePair<TKey, TValue> pair)
+ {
+ return ContainsKey (pair.Key);
}
public KeyValuePair<TKey,TValue>[] ToArray ()
}
}
- public bool IsReadOnly {
+ public bool IsEmpty {
+ get {
+ return count == 0;
+ }
+ }
+
+ bool ICollection<KeyValuePair<TKey, TValue>>.IsReadOnly {
+ get {
+ return false;
+ }
+ }
+
+ bool IDictionary.IsReadOnly {
get {
return false;
}
CopyTo (arr, startIndex, count);
}
- public void CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex)
+ void CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex)
{
CopyTo (array, startIndex, count);
}
+ void ICollection<KeyValuePair<TKey, TValue>>.CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex)
+ {
+ CopyTo (array, startIndex);
+ }
+
void CopyTo (KeyValuePair<TKey, TValue>[] array, int startIndex, int num)
{
// TODO: This is quite unsafe as the count value will likely change during
}
}
+ void ISerializable.GetObjectData (SerializationInfo info, StreamingContext context)
+ {
+ throw new NotImplementedException ();
+ }
+
bool ICollection.IsSynchronized {
- get {
- return true;
- }
+ get { return true; }
+ }
+
+ void IDeserializationCallback.OnDeserialization (object sender)
+ {
+ throw new NotImplementedException ();
}
- public bool IsFixedSize {
+ bool IDictionary.IsFixedSize {
get {
return false;
}
}
+
bool TryGetBasket (TKey key, out Basket basket)
{
return true;
}
- /// <summary>
- /// </summary>
- /// <param name="element"></param>
public void Push (T element)
{
Node temp = new Node ();
temp.Value = element;
-
do {
- temp.Next = head;
- } while (Interlocked.CompareExchange<Node> (ref head, temp, temp.Next) != temp.Next);
+ temp.Next = head;
+ } while (Interlocked.CompareExchange (ref head, temp, temp.Next) != temp.Next);
Interlocked.Increment (ref count);
}
+
+ public void PushRange (T[] values)
+ {
+ PushRange (values, 0, values.Length);
+ }
+
+ public void PushRange (T[] values, int start, int length)
+ {
+ Node insert = null;
+ Node first = null;
+
+ for (int i = start; i < length; i++) {
+ Node temp = new Node ();
+ temp.Value = values[i];
+ temp.Next = insert;
+ insert = temp;
+
+ if (first == null)
+ first = temp;
+ }
+
+ do {
+ first.Next = head;
+ } while (Interlocked.CompareExchange (ref head, insert, first.Next) != first.Next);
+
+ Interlocked.Add (ref count, length);
+ }
+
- /// <summary>
- /// </summary>
- /// <returns></returns>
public bool TryPop (out T value)
{
Node temp;
value = default (T);
return false;
}
- } while (Interlocked.CompareExchange<Node> (ref head, temp.Next, temp) != temp);
+ } while (Interlocked.CompareExchange (ref head, temp.Next, temp) != temp);
Interlocked.Decrement (ref count);
value = temp.Value;
return true;
}
+
+ public int TryPopRange (T[] values)
+ {
+ return TryPopRange (values, 0, values.Length);
+ }
+
+ public int TryPopRange (T[] values, int startIndex, int count)
+ {
+ int insertIndex = startIndex;
+ Node temp;
+ Node end;
+
+ do {
+ temp = head;
+ if (temp == null)
+ return -1;
+ end = temp;
+ for (int j = 0; j < count - 1; j++) {
+ end = end.Next;
+ if (end == null)
+ break;
+ }
+ } while (Interlocked.CompareExchange (ref head, end, temp) != temp);
+
+ int i;
+ for (i = startIndex; i < count && temp != null; i++) {
+ values[i] = temp.Value;
+ temp = temp.Next;
+ }
+
+ return i - 1;
+ }
- /// <summary>
- /// </summary>
- /// <returns></returns>
public bool TryPeek (out T value)
{
Node myHead = head;
}
}
- public void CopyTo (Array array, int index)
+ void ICollection.CopyTo (Array array, int index)
{
T[] dest = array as T[];
if (dest == null)
}
}
- public virtual void GetObjectData (SerializationInfo info, StreamingContext context)
+ void ISerializable.GetObjectData (SerializationInfo info, StreamingContext context)
{
throw new NotImplementedException ();
}
get { return true; }
}
- public virtual void OnDeserialization (object sender)
+ void IDeserializationCallback.OnDeserialization (object sender)
{
throw new NotImplementedException ();
}
ConcurrentDictionary<string, int> map;
[SetUp]
- public void Setup()
+ public void Setup ()
{
- map = new ConcurrentDictionary<string, int>();
+ map = new ConcurrentDictionary<string, int> ();
AddStuff();
}
- void AddStuff()
+ void AddStuff ()
{
- map.Add("foo", 1);
- map.Add("bar", 2);
- map.Add("foobar", 3);
+ map.TryAdd ("foo", 1);
+ map.TryAdd ("bar", 2);
+ map.TryAdd ("foobar", 3);
}
[Test]
- public void AddWithoutDuplicateTest()
+ public void AddWithoutDuplicateTest ()
{
- map.Add("baz", 2);
+ map.TryAdd("baz", 2);
+ int val;
- Assert.AreEqual(2, map.GetValue("baz"));
+ Assert.IsTrue (map.TryGetValue("baz", out val));
+ Assert.AreEqual(2, val);
Assert.AreEqual(2, map["baz"]);
Assert.AreEqual(4, map.Count);
}
Setup ();
int index = 0;
bool r1 = false, r2 = false, r3 = false;
+ int val;
ParallelTestHelper.ParallelStressTest (map, delegate {
int own = Interlocked.Increment (ref index);
switch (own) {
case 1:
- r1 = map.Remove ("foo");
+ r1 = map.TryRemove ("foo", out val);
break;
case 2:
- r2 =map.Remove ("bar");
+ r2 =map.TryRemove ("bar", out val);
break;
case 3:
- r3 = map.Remove ("foobar");
+ r3 = map.TryRemove ("foobar", out val);
break;
}
}, 3);
[Test, ExpectedException(typeof(ArgumentException))]
public void AddWithDuplicate()
{
- map.Add("foo", 6);
+ map.TryAdd("foo", 6);
}
[Test]
public void GetValueTest()
{
- Assert.AreEqual(1, map.GetValue("foo"), "#1");
+ Assert.AreEqual(1, map["foo"], "#1");
Assert.AreEqual(2, map["bar"], "#2");
Assert.AreEqual(3, map.Count, "#3");
}
{
int val;
Assert.IsFalse(map.TryGetValue("barfoo", out val));
- map.GetValue("barfoo");
+ val = map["barfoo"];
}
[Test]
int val;
Assert.AreEqual(9, map["foo"], "#1");
- Assert.AreEqual(9, map.GetValue("foo"), "#2");
Assert.IsTrue(map.TryGetValue("foo", out val), "#3");
Assert.AreEqual(9, val, "#4");
}