{
readonly Func<T> initializer;
LocalDataStoreSlot localStore;
+ Exception cachedException;
class DataSlotWrapper
{
+ public bool Creating;
public bool Init;
public Func<T> Getter;
}
public void Dispose ()
{
- Dispose(true);
+ Dispose (true);
}
protected virtual void Dispose (bool dispManagedRes)
public bool IsValueCreated {
get {
+ ThrowIfNeeded ();
return IsInitializedThreadLocal ();
}
}
public T Value {
get {
+ ThrowIfNeeded ();
return GetValueThreadLocal ();
}
+ set {
+ ThrowIfNeeded ();
+
+ DataSlotWrapper w = GetWrapper ();
+ w.Init = true;
+ w.Getter = () => value;
+ }
}
public override string ToString ()
{
- return string.Format("[ThreadLocal: IsValueCreated={0}, Value={1}]", IsValueCreated, Value);
+ return string.Format ("[ThreadLocal: IsValueCreated={0}, Value={1}]", IsValueCreated, Value);
}
T GetValueThreadLocal ()
{
- DataSlotWrapper myWrapper = Thread.GetData (localStore) as DataSlotWrapper;
- // In case it's the first time the Thread access its data
- if (myWrapper == null) {
- myWrapper = DataSlotCreator ();
- Thread.SetData (localStore, myWrapper);
- }
-
- return myWrapper.Getter();
+ DataSlotWrapper myWrapper = GetWrapper ();
+ if (myWrapper.Creating)
+ throw new InvalidOperationException ("The initialization function attempted to reference Value recursively");
+
+ return myWrapper.Getter ();
}
bool IsInitializedThreadLocal ()
+ {
+ DataSlotWrapper myWrapper = GetWrapper ();
+
+ return myWrapper.Init;
+ }
+
+ DataSlotWrapper GetWrapper ()
{
DataSlotWrapper myWrapper = (DataSlotWrapper)Thread.GetData (localStore);
if (myWrapper == null) {
myWrapper = DataSlotCreator ();
Thread.SetData (localStore, myWrapper);
}
-
- return myWrapper.Init;
+
+ return myWrapper;
+ }
+
+ void ThrowIfNeeded ()
+ {
+ if (cachedException != null)
+ throw cachedException;
}
DataSlotWrapper DataSlotCreator ()
Func<T> valSelector = initializer;
wrapper.Getter = delegate {
- T val = valSelector ();
- wrapper.Init = true;
- wrapper.Getter = delegate { return val; };
- return val;
+ wrapper.Creating = true;
+ try {
+ T val = valSelector ();
+ wrapper.Creating = false;
+ wrapper.Init = true;
+ wrapper.Getter = () => val;
+ return val;
+ } catch (Exception e) {
+ cachedException = e;
+ throw e;
+ }
};
return wrapper;
t.Start ();
t.Join ();
}
-
+
+ [Test]
+ public void InitializeThrowingTest ()
+ {
+ int callTime = 0;
+ threadLocal = new ThreadLocal<int> (() => {
+ Interlocked.Increment (ref callTime);
+ throw new ApplicationException ("foo");
+ return 43;
+ });
+
+ Exception exception = null;
+
+ try {
+ var foo = threadLocal.Value;
+ } catch (Exception e) {
+ exception = e;
+ }
+
+ Assert.IsNotNull (exception, "#1");
+ Assert.IsInstanceOfType (typeof (ApplicationException), exception, "#2");
+ Assert.AreEqual (1, callTime, "#3");
+
+ exception = null;
+
+ try {
+ var foo = threadLocal.Value;
+ } catch (Exception e) {
+ exception = e;
+ }
+
+ Assert.IsNotNull (exception, "#4");
+ Assert.IsInstanceOfType (typeof (ApplicationException), exception, "#5");
+ Assert.AreEqual (1, callTime, "#6");
+ }
+
+ [Test, ExpectedException (typeof (InvalidOperationException))]
+ public void MultipleReferenceToValueTest ()
+ {
+ threadLocal = new ThreadLocal<int> (() => threadLocal.Value + 1);
+
+ var value = threadLocal.Value;
+ }
+
void AssertThreadLocal ()
{
Assert.IsFalse (threadLocal.IsValueCreated, "#1");