Include set accessor in ThreadLocal's Value property. Catch possible exception in...
authorJérémie Laval <jeremie.laval@gmail.com>
Wed, 4 Aug 2010 17:21:30 +0000 (19:21 +0200)
committerJérémie Laval <jeremie.laval@gmail.com>
Wed, 4 Aug 2010 17:27:09 +0000 (19:27 +0200)
mcs/class/corlib/System.Threading/ThreadLocal.cs
mcs/class/corlib/Test/System.Threading/ThreadLazyTests.cs

index 24d3267cd1467250064f06b3a250b810c53d4797..6f5b339278a2d31f2c3c9efda3dc13563582b6e5 100644 (file)
@@ -38,9 +38,11 @@ namespace System.Threading
        {
                readonly Func<T> initializer;
                LocalDataStoreSlot localStore;
+               Exception cachedException;
                
                class DataSlotWrapper
                {
+                       public bool Creating;
                        public bool Init;
                        public Func<T> Getter;
                }
@@ -60,7 +62,7 @@ namespace System.Threading
                
                public void Dispose ()
                {
-                       Dispose(true);
+                       Dispose (true);
                }
                
                protected virtual void Dispose (bool dispManagedRes)
@@ -70,43 +72,62 @@ namespace System.Threading
                
                public bool IsValueCreated {
                        get {
+                               ThrowIfNeeded ();
                                return IsInitializedThreadLocal ();
                        }
                }
                
                public T Value {
                        get {
+                               ThrowIfNeeded ();
                                return GetValueThreadLocal ();
                        }
+                       set {
+                               ThrowIfNeeded ();
+
+                               DataSlotWrapper w = GetWrapper ();
+                               w.Init = true;
+                               w.Getter = () => value;
+                       }
                }
                
                public override string ToString ()
                {
-                       return string.Format("[ThreadLocal: IsValueCreated={0}, Value={1}]", IsValueCreated, Value);
+                       return string.Format ("[ThreadLocal: IsValueCreated={0}, Value={1}]", IsValueCreated, Value);
                }
 
                
                T GetValueThreadLocal ()
                {
-                       DataSlotWrapper myWrapper = Thread.GetData (localStore) as DataSlotWrapper;
-                       // In case it's the first time the Thread access its data
-                       if (myWrapper == null) {
-                               myWrapper = DataSlotCreator ();
-                               Thread.SetData (localStore, myWrapper);
-                       }
-                               
-                       return myWrapper.Getter();
+                       DataSlotWrapper myWrapper = GetWrapper ();
+                       if (myWrapper.Creating)
+                               throw new InvalidOperationException ("The initialization function attempted to reference Value recursively");
+
+                       return myWrapper.Getter ();
                }
                
                bool IsInitializedThreadLocal ()
+               {
+                       DataSlotWrapper myWrapper = GetWrapper ();
+
+                       return myWrapper.Init;
+               }
+
+               DataSlotWrapper GetWrapper ()
                {
                        DataSlotWrapper myWrapper = (DataSlotWrapper)Thread.GetData (localStore);
                        if (myWrapper == null) {
                                myWrapper = DataSlotCreator ();
                                Thread.SetData (localStore, myWrapper);
                        }
-                       
-                       return myWrapper.Init;
+
+                       return myWrapper;
+               }
+
+               void ThrowIfNeeded ()
+               {
+                       if (cachedException != null)
+                               throw cachedException;
                }
 
                DataSlotWrapper DataSlotCreator ()
@@ -115,10 +136,17 @@ namespace System.Threading
                        Func<T> valSelector = initializer;
        
                        wrapper.Getter = delegate {
-                               T val = valSelector ();
-                               wrapper.Init = true;
-                               wrapper.Getter = delegate { return val; };
-                               return val;
+                               wrapper.Creating = true;
+                               try {
+                                       T val = valSelector ();
+                                       wrapper.Creating = false;
+                                       wrapper.Init = true;
+                                       wrapper.Getter = () => val;
+                                       return val;
+                               } catch (Exception e) {
+                                       cachedException = e;
+                                       throw e;
+                               }
                        };
                        
                        return wrapper;
index abf59a202dcad9ea7b7ed638ff3401b25bb3802a..1488b781f0550f5f0031ae83363a3d952ed5235b 100644 (file)
@@ -62,7 +62,50 @@ namespace MonoTests.System.Threading
                        t.Start ();
                        t.Join ();
                }
-               
+
+               [Test]
+               public void InitializeThrowingTest ()
+               {
+                       int callTime = 0;
+                       threadLocal = new ThreadLocal<int> (() => {
+                                       Interlocked.Increment (ref callTime);
+                                       throw new ApplicationException ("foo");
+                                       return 43;
+                               });
+
+                       Exception exception = null;
+
+                       try {
+                               var foo = threadLocal.Value;
+                       } catch (Exception e) {
+                               exception = e;
+                       }
+
+                       Assert.IsNotNull (exception, "#1");
+                       Assert.IsInstanceOfType (typeof (ApplicationException), exception, "#2");
+                       Assert.AreEqual (1, callTime, "#3");
+
+                       exception = null;
+
+                       try {
+                               var foo = threadLocal.Value;
+                       } catch (Exception e) {
+                               exception = e;
+                       }
+
+                       Assert.IsNotNull (exception, "#4");
+                       Assert.IsInstanceOfType (typeof (ApplicationException), exception, "#5");
+                       Assert.AreEqual (1, callTime, "#6");
+               }
+
+               [Test, ExpectedException (typeof (InvalidOperationException))]
+               public void MultipleReferenceToValueTest ()
+               {
+                       threadLocal = new ThreadLocal<int> (() => threadLocal.Value + 1);
+
+                       var value = threadLocal.Value;
+               }
+
                void AssertThreadLocal ()
                {
                        Assert.IsFalse (threadLocal.IsValueCreated, "#1");