Check key for null in ConcurrentDictionary
authorPetr Onderka <gsvick@gmail.com>
Wed, 18 Jul 2012 17:30:43 +0000 (19:30 +0200)
committerPetr Onderka <gsvick@gmail.com>
Sun, 19 Aug 2012 22:03:31 +0000 (00:03 +0200)
mcs/class/corlib/System.Collections.Concurrent/ConcurrentDictionary.cs
mcs/class/corlib/Test/System.Collections.Concurrent/ConcurrentDictionaryTests.cs

index 46c32b35186d1ef5fdd6e3e6edf8694fc89b1a8d..2bfb952af7d33c0aaeb4fe7c33e2246e4b819528 100644 (file)
@@ -87,6 +87,12 @@ namespace System.Collections.Concurrent
 
                }
 
+               void CheckKey (TKey key)
+               {
+                       if (key == null)
+                               throw new ArgumentNullException ("key");
+               }
+
                void Add (TKey key, TValue value)
                {
                        while (!TryAdd (key, value));
@@ -99,6 +105,7 @@ namespace System.Collections.Concurrent
 
                public bool TryAdd (TKey key, TValue value)
                {
+                       CheckKey (key);
                        return internalDictionary.Insert (Hash (key), key, Make (key, value));
                }
 
@@ -109,6 +116,11 @@ namespace System.Collections.Concurrent
 
                public TValue AddOrUpdate (TKey key, Func<TKey, TValue> addValueFactory, Func<TKey, TValue, TValue> updateValueFactory)
                {
+                       CheckKey (key);
+                       if (addValueFactory == null)
+                               throw new ArgumentNullException ("addValueFactory");
+                       if (updateValueFactory == null)
+                               throw new ArgumentNullException ("updateValueFactory");
                        return internalDictionary.InsertOrUpdate (Hash (key),
                                                                  key,
                                                                  () => Make (key, addValueFactory (key)),
@@ -122,6 +134,7 @@ namespace System.Collections.Concurrent
 
                TValue AddOrUpdate (TKey key, TValue addValue, TValue updateValue)
                {
+                       CheckKey (key);
                        return internalDictionary.InsertOrUpdate (Hash (key),
                                                                  key,
                                                                  Make (key, addValue),
@@ -138,6 +151,7 @@ namespace System.Collections.Concurrent
 
                public bool TryGetValue (TKey key, out TValue value)
                {
+                       CheckKey (key);
                        KeyValuePair<TKey, TValue> pair;
                        bool result = internalDictionary.Find (Hash (key), key, out pair);
                        value = pair.Value;
@@ -147,6 +161,7 @@ namespace System.Collections.Concurrent
 
                public bool TryUpdate (TKey key, TValue newValue, TValue comparisonValue)
                {
+                       CheckKey (key);
                        return internalDictionary.CompareExchange (Hash (key), key, Make (key, newValue), (e) => e.Value.Equals (comparisonValue));
                }
 
@@ -161,16 +176,19 @@ namespace System.Collections.Concurrent
 
                public TValue GetOrAdd (TKey key, Func<TKey, TValue> valueFactory)
                {
+                       CheckKey (key);
                        return internalDictionary.InsertOrGet (Hash (key), key, Make (key, default(TValue)), () => Make (key, valueFactory (key))).Value;
                }
 
                public TValue GetOrAdd (TKey key, TValue value)
                {
+                       CheckKey (key);
                        return internalDictionary.InsertOrGet (Hash (key), key, Make (key, value), null).Value;
                }
 
                public bool TryRemove (TKey key, out TValue value)
                {
+                       CheckKey (key);
                        KeyValuePair<TKey, TValue> data;
                        bool result = internalDictionary.Delete (Hash (key), key, out data);
                        value = data.Value;
@@ -196,6 +214,7 @@ namespace System.Collections.Concurrent
 
                public bool ContainsKey (TKey key)
                {
+                       CheckKey (key);
                        KeyValuePair<TKey, TValue> dummy;
                        return internalDictionary.Find (Hash (key), key, out dummy);
                }
index 70888e8637bca4080e656eabee40bd46d0675195..5978f6b219fe335253770e7bb0947cc784ff9377 100644 (file)
@@ -314,6 +314,33 @@ namespace MonoTests.System.Collections.Concurrent
                        foreach (var id in ids)
                                Assert.IsFalse (dict.TryGetValue (id, out result), id.ToString () + " (second)");
                }
+
+               [Test]
+               public void NullArgumentsTest ()
+               {
+                       AssertThrowsArgumentNullException (() => { var x = map[null]; });
+                       AssertThrowsArgumentNullException (() => map[null] = 0);
+                       AssertThrowsArgumentNullException (() => map.AddOrUpdate (null, k => 0, (k, v) => v));
+                       AssertThrowsArgumentNullException (() => map.AddOrUpdate ("", null, (k, v) => v));
+                       AssertThrowsArgumentNullException (() => map.AddOrUpdate ("", k => 0, null));
+                       AssertThrowsArgumentNullException (() => map.AddOrUpdate (null, 0, (k, v) => v));
+                       AssertThrowsArgumentNullException (() => map.AddOrUpdate ("", 0, null));
+                       AssertThrowsArgumentNullException (() => map.ContainsKey (null));
+                       AssertThrowsArgumentNullException (() => map.GetOrAdd (null, 0));
+                       int value;
+                       AssertThrowsArgumentNullException (() => map.TryGetValue (null, out value));
+                       AssertThrowsArgumentNullException (() => map.TryRemove (null, out value));
+                       AssertThrowsArgumentNullException (() => map.TryUpdate (null, 0, 0));
+               } 
+
+               void AssertThrowsArgumentNullException (Action action)
+               {
+                       try {
+                               action ();
+                               Assert.Fail ("Expected ArgumentNullException.");
+                       } catch (ArgumentNullException ex) {
+                       }
+               }
        }
 }
 #endif