Fix for bug 349053 - unable to serialize SortedDictionary
authorLuke Halliwell <luke.j.halliwell@gmail.com>
Sun, 26 Dec 2010 16:26:44 +0000 (08:26 -0800)
committerGonzalo Paniagua Javier <gonzalo.mono@gmail.com>
Tue, 4 Jan 2011 02:43:27 +0000 (21:43 -0500)
mcs/class/System/System.Collections.Generic/SortedDictionary.cs
mcs/class/System/Test/System.Collections.Generic/SortedDictionaryTest.cs

index b535aa6881d983d981d8397ad951e1b949737acc..e8354d32d9d10ce21ebad572ba70b7f0349b4211 100644 (file)
 using System;
 using System.Collections;
 using System.Diagnostics;
+using System.Runtime.Serialization;
+using System.Security.Permissions;
 
 namespace System.Collections.Generic
 {
        [Serializable]
        [DebuggerDisplay ("Count={Count}")]
        [DebuggerTypeProxy (typeof (CollectionDebuggerView<,>))]
-       public class SortedDictionary<TKey,TValue> : IDictionary<TKey,TValue>, ICollection<KeyValuePair<TKey,TValue>>, IEnumerable<KeyValuePair<TKey,TValue>>, IDictionary, ICollection, IEnumerable
+       public class SortedDictionary<TKey,TValue> : IDictionary<TKey,TValue>, ICollection<KeyValuePair<TKey,TValue>>, IEnumerable<KeyValuePair<TKey,TValue>>, IDictionary, ICollection, IEnumerable, ISerializable
        {
                class Node : RBTree.Node {
                        public TKey key;
@@ -76,6 +78,7 @@ namespace System.Collections.Generic
                        }
                }
 
+               [Serializable]
                class NodeHelper : RBTree.INodeHelper<TKey> {
                        public IComparer<TKey> cmp;
 
@@ -127,6 +130,17 @@ namespace System.Collections.Generic
                        foreach (KeyValuePair<TKey, TValue> entry in dic)
                                Add (entry.Key, entry.Value);
                }
+
+               protected SortedDictionary (SerializationInfo info, StreamingContext context)
+               {
+                       hlp = (NodeHelper)info.GetValue("Helper", typeof(NodeHelper));
+                       tree = new RBTree (hlp);
+
+                       KeyValuePair<TKey, TValue> [] data = (KeyValuePair<TKey, TValue>[])info.GetValue("KeyValuePairs", typeof(KeyValuePair<TKey, TValue>[]));
+                       foreach (KeyValuePair<TKey, TValue> entry in data)
+                               Add(entry.Key, entry.Value);
+               }
+
                #endregion
 
                #region PublicProperty
@@ -226,6 +240,18 @@ namespace System.Collections.Generic
                        return n != null;
                }
 
+               [SecurityPermission (SecurityAction.LinkDemand, Flags=SecurityPermissionFlag.SerializationFormatter)]
+               public virtual void GetObjectData (SerializationInfo info, StreamingContext context)
+               {
+                       if (info == null)
+                               throw new ArgumentNullException ("info");
+
+                       KeyValuePair<TKey, TValue> [] data = new KeyValuePair<TKey,TValue> [Count];
+                       CopyTo (data, 0);
+                       info.AddValue ("KeyValuePairs", data);
+                       info.AddValue ("Helper", hlp);
+               }
+
                #endregion
 
                #region PrivateMethod
@@ -280,7 +306,7 @@ namespace System.Collections.Generic
                {
                        TValue value;
                        return TryGetValue (item.Key, out value) &&
-                               EqualityComparer<TValue>.Default.Equals (item.Value, value) &&
+                               EqualityComparer<TValue>.Default.Equals (item.Value, value) &&
                                Remove (item.Key);
                }
 
@@ -466,7 +492,7 @@ namespace System.Collections.Generic
 
                        IEnumerator IEnumerable.GetEnumerator ()
                        {
-                               return new Enumerator (_dic);
+                               return new Enumerator (_dic);
                        }
 
                        public struct Enumerator : IEnumerator<TValue>,IEnumerator, IDisposable
@@ -601,7 +627,7 @@ namespace System.Collections.Generic
 
                        IEnumerator IEnumerable.GetEnumerator ()
                        {
-                               return new Enumerator (_dic);
+                               return new Enumerator (_dic);
                        }
 
                        public struct Enumerator : IEnumerator<TKey>, IEnumerator, IDisposable
index 5fd1f6bcd1df972275e0fb469dcba25bcfc846d5..bb1c225f63f0f7de813a84cb56154c26868e0446 100644 (file)
 #if NET_2_0
 
 using System;
+using System.IO;
 using System.Collections;
 using System.Collections.Generic;
-using System.Runtime.Serialization;
+using System.Runtime.Serialization.Formatters.Binary;
 
 using NUnit.Framework;
 
 namespace MonoTests.System.Collections.Generic
 {
        [TestFixture]
-        public class SortedDictionaryTest
+       public class SortedDictionaryTest
        {
                [Test]
                public void CtorNullComparer ()
@@ -565,8 +566,58 @@ namespace MonoTests.System.Collections.Generic
                        ((IDisposable) e4).Dispose ();
                        Assert.IsTrue (Throws (delegate { var x = e4.Current; GC.KeepAlive (x); }));
                }
+
+               // Serialize a dictionary out and deserialize it back in again
+               SortedDictionary<int, string> Roundtrip(SortedDictionary<int, string> dic)
+               {
+                       BinaryFormatter bf = new BinaryFormatter ();
+                       MemoryStream stream = new MemoryStream ();
+                       bf.Serialize (stream, dic);
+                       stream.Position = 0;
+                       return (SortedDictionary<int, string>)bf.Deserialize (stream);
+               }
+           
+               [Test]
+               public void Serialize()
+               {
+                       SortedDictionary<int, string> test = new SortedDictionary<int, string>();
+                       test.Add(1, "a");
+                       test.Add(3, "c");
+                       test.Add(2, "b");
+
+                       SortedDictionary<int, string> result = Roundtrip(test);
+                       Assert.AreEqual(3, result.Count);
+
+                       Assert.AreEqual("a", result[1]);
+                       Assert.AreEqual("b", result[2]);
+                       Assert.AreEqual("c", result[3]);
+               }
+
+               [Test]
+               public void SerializeReverseComparer()
+               {
+                       SortedDictionary<int,string> test =
+                               new SortedDictionary<int,string> (
+                                       ReverseComparer<int>.Instance);
+
+                       test.Add (1, "A");
+                       test.Add (3, "B");
+                       test.Add (2, "C");
+
+                       SortedDictionary<int,string> result = Roundtrip (test);
+                   
+                       SortedDictionary<int,string>.Enumerator e = result.GetEnumerator ();
+                       Assert.IsTrue (e.MoveNext (), "#1");
+                       Assert.AreEqual ("B", e.Current.Value, "#2");
+                       Assert.IsTrue (e.MoveNext (), "#3");
+                       Assert.AreEqual ("C", e.Current.Value, "#4");
+                       Assert.IsTrue (e.MoveNext (), "#5");
+                       Assert.AreEqual ("A", e.Current.Value, "#6");
+                       Assert.IsFalse (e.MoveNext (), "#7");
+               }
        }
 
+       [Serializable]
        class ReverseComparer<T> : IComparer<T>
        {
                static ReverseComparer<T> instance = new ReverseComparer<T> ();