Finish reimplementation of ReaderWriterLockSlim so that it pass unit tests.
authorJérémie Laval <jeremie.laval@gmail.com>
Tue, 31 Aug 2010 10:22:02 +0000 (11:22 +0100)
committerJérémie Laval <jeremie.laval@gmail.com>
Tue, 31 Aug 2010 16:01:36 +0000 (17:01 +0100)
mcs/class/System.Core/System.Threading/ReaderWriterLockSlim.cs

index c30a6348a12964fa6fdc57c323c7bc0a050c4fa6..b0f3ba4d56e1dd370f9d3149ce7f7e45700fe781 100644 (file)
@@ -35,43 +35,63 @@ using System.Threading;
 
 namespace System.Threading {
 
-       [HostProtectionAttribute(SecurityAction.LinkDemand, MayLeakOnAbort = true)]
-       [HostProtectionAttribute(SecurityAction.LinkDemand, Synchronization = true, ExternalThreading = true)]
-       public class ReaderWriterLockSlim : IDisposable 
+       [Flags]
+       internal enum ThreadLockState
+       {
+               None = 0,
+               Read = 1,
+               Write = 2,
+               Upgradable = 4,
+               UpgradedRead = 5,
+               UpgradedWrite = 6
+       }
+
+       internal static class ThreadLockStateExtensions
        {
-               enum ThreadLockState
+               internal static bool Has (this ThreadLockState state, ThreadLockState value)
                {
-                       None = 0,
-                       Read,
-                       Write,
-                       Upgradable
+                       return (state & value) > 0;
                }
+       }
 
+       [HostProtectionAttribute(SecurityAction.LinkDemand, MayLeakOnAbort = true)]
+       [HostProtectionAttribute(SecurityAction.LinkDemand, Synchronization = true, ExternalThreading = true)]
+       public class ReaderWriterLockSlim : IDisposable
+       {
                /* Position of each bit isn't really important 
                 * but their relative order is
                 */
                const int RwWaitBit = 0;
-               const int RwWriteBit = 1;
-               const int RwReadBit = 2;
+               const int RwWaitUpgradeBit = 1;
+               const int RwWriteBit = 2;
+               const int RwReadBit = 3;
 
                const int RwWait = 1;
-               const int RwWrite = 2;
-               const int RwRead = 4;
+               const int RwWaitUpgrade = 2;
+               const int RwWrite = 4;
+               const int RwRead = 8;
 
                int rwlock;
                
                readonly LockRecursionPolicy recursionPolicy;
-               AtomicBoolean upgradableTaken;
+
+               AtomicBoolean upgradableTaken = new AtomicBoolean ();
+               ManualResetEventSlim upgradableEvent = new ManualResetEventSlim (true);
+
+               int numReadWaiters, numUpgradeWaiters, numWriteWaiters;
+               bool disposed;
 
                [ThreadStatic]
-               IDictionary<ReaderWriterLockSlim, ThreadLockState> currentThreadState;
+               static IDictionary<ReaderWriterLockSlim, ThreadLockState> currentThreadState;
 
-               public ReaderWriterLockSlim (LockRecursionPolicy.None)
+               public ReaderWriterLockSlim () : this (LockRecursionPolicy.NoRecursion)
                {
                }
 
                public ReaderWriterLockSlim (LockRecursionPolicy recursionPolicy)
                {
+                       if (recursionPolicy != LockRecursionPolicy.NoRecursion)
+                               throw new NotSupportedException ("Creating a recursion-aware reader-writer lock is not yet supported");
                        this.recursionPolicy = recursionPolicy;
                }
 
@@ -84,37 +104,51 @@ namespace System.Threading {
                {
                        if (CheckState (millisecondsTimeout, ThreadLockState.Read))
                                return true;
+
+                       // This is downgrading from upgradable, no need for check since
+                       // we already have a sort-of read lock that's going to disappear
+                       // after user calls ExitUpgradeableReadLock
+                       if (CurrentThreadState.Has (ThreadLockState.Upgradable)) {
+                               Interlocked.Add (ref rwlock, RwRead);
+                               CurrentThreadState = CurrentThreadState ^ ThreadLockState.Read;
+                               return true;
+                       }
                        
                        Stopwatch sw = Stopwatch.StartNew ();
-                       while (millisecondsTimeout < || && sw.ElapsedMilliseconds < millisecondsTimeout) {
-                               if ((rwlock & (RwWrite | RwWait)) > 0) {
-                                       // Should sleep
+                       Interlocked.Increment (ref numReadWaiters);
+
+                       while (millisecondsTimeout == -1 || sw.ElapsedMilliseconds < millisecondsTimeout) {
+                               if ((rwlock & 0x7) > 0) {
+                                       Thread.Sleep (1);
                                        continue;
                                }
-                               
-                               if ((Interlocked.Add (ref rwlock, RwRead) & (RwWrite | RwWait)) == 0) {
-                                       CurrentThreadState = ThreadLockState.Read;
+
+                               if ((Interlocked.Add (ref rwlock, RwRead) & 0x7) == 0) {
+                                       CurrentThreadState = CurrentThreadState ^ ThreadLockState.Read;
+                                       Interlocked.Decrement (ref numReadWaiters);
                                        return true;
                                }
 
                                Interlocked.Add (ref rwlock, -RwRead);
-                               // Should sleep
+
+                               Thread.Sleep (1);
                        }
 
+                       Interlocked.Decrement (ref numReadWaiters);
                        return false;
                }
 
                public bool TryEnterReadLock (TimeSpan timeout)
                {
-                       return TryEnterReadLock (timeout.TotalMilliseconds);
+                       return TryEnterReadLock (CheckTimeout (timeout));
                }
 
                public void ExitReadLock ()
                {
-                       if (CurrentThreadState != Read)
+                       if (CurrentThreadState != ThreadLockState.Read)
                                throw new SynchronizationLockException ("The current thread has not entered the lock in read mode");
-                       
-                       CurrentThreadState = None;
+
+                       CurrentThreadState = ThreadLockState.None;
                        Interlocked.Add (ref rwlock, -RwRead);
                }
 
@@ -125,45 +159,54 @@ namespace System.Threading {
                
                public bool TryEnterWriteLock (int millisecondsTimeout)
                {
-                       if (CheckState (millisecondsTimeout, ThreadLockState.Write))
+                       bool isUpgradable = CurrentThreadState.Has (ThreadLockState.Upgradable);
+                       if (CheckState (millisecondsTimeout, isUpgradable ? ThreadLockState.UpgradedWrite : ThreadLockState.Write))
                                return true;
 
                        Stopwatch sw = Stopwatch.StartNew ();
+                       Interlocked.Increment (ref numWriteWaiters);
+
+                       // If the code goes there that means we had a read lock beforehand
+                       if (isUpgradable && rwlock >= RwRead)
+                               Interlocked.Add (ref rwlock, -RwRead);
+
+                       int stateCheck = isUpgradable ? RwWaitUpgrade : RwWait;
+                       int appendValue = RwWait | (isUpgradable ? RwWaitUpgrade : 0);
 
                        while (millisecondsTimeout < 0 || sw.ElapsedMilliseconds < millisecondsTimeout) {
                                int state = rwlock;
 
-                               if (state < RwWrite) {
-                                       if (Interlocked.CompareExchange (ref rwlock, state, RwWrite) == state) {
-                                               CurrentThreadState = Write;
+                               if (state <= stateCheck) {
+                                       if (Interlocked.CompareExchange (ref rwlock, RwWrite, state) == state) {
+                                               CurrentThreadState = isUpgradable ? ThreadLockState.UpgradedWrite : ThreadLockState.Write;
+                                               Interlocked.Decrement (ref numWriteWaiters);
                                                return true;
                                        }
                                        state = rwlock;
                                }
 
-                               while ((state & RwWait) == 0 && Interlocked.CompareExchange (ref rwlock, state, state | RwWait) != state)
+                               while ((state & RwWait) == 0 && Interlocked.CompareExchange (ref rwlock, state | appendValue, state) == state)
                                        state = rwlock;
 
-                               while (rwlock > RwWait && (millisecondsTimeout < 0 || sw.ElapsedMilliseconds < millisecondsTimeout)) {
-                                       // Should wait here
-                                       
-                               }
+                               while (rwlock > stateCheck && (millisecondsTimeout < 0 || sw.ElapsedMilliseconds < millisecondsTimeout))
+                                       Thread.Sleep (1);
                        }
 
+                       Interlocked.Decrement (ref numWriteWaiters);
                        return false;
                }
 
                public bool TryEnterWriteLock (TimeSpan timeout)
                {
-                       return TryEnterWriteLock (timeout.TotalMilliseconds);
+                       return TryEnterWriteLock (CheckTimeout (timeout));
                }
 
                public void ExitWriteLock ()
                {
-                       if (CurrentThreadState != Read)
-                               throw new SynchronizationLockException ("The current thread has not entered the lock in read mode");
+                       if (!CurrentThreadState.Has (ThreadLockState.Write))
+                               throw new SynchronizationLockException ("The current thread has not entered the lock in write mode");
                        
-                       CurrentThreadState = None;
+                       CurrentThreadState = CurrentThreadState ^ ThreadLockState.Write;
                        Interlocked.Add (ref rwlock, -RwWrite);
                }
 
@@ -181,8 +224,35 @@ namespace System.Threading {
                        if (CheckState (millisecondsTimeout, ThreadLockState.Upgradable))
                                return true;
 
-                       if (CurrentThreadState == ThreadLockState.Read)
-                               throw new LockRecursionException ("The current thread has already entered read mode");                          
+                       if (CurrentThreadState.Has (ThreadLockState.Read))
+                               throw new LockRecursionException ("The current thread has already entered read mode");
+
+                       Stopwatch sw = Stopwatch.StartNew ();
+                       Interlocked.Increment (ref numUpgradeWaiters);
+
+                       while (!upgradableEvent.IsSet || !upgradableTaken.TryRelaxedSet ()) {
+                               if (millisecondsTimeout != -1 && sw.ElapsedMilliseconds > millisecondsTimeout) {
+                                       Interlocked.Decrement (ref numUpgradeWaiters);
+                                       return false;
+                               }
+
+                               upgradableEvent.Wait (ComputeTimeout (millisecondsTimeout, sw));
+                       }
+
+                       upgradableEvent.Reset ();
+
+                       if (TryEnterReadLock (ComputeTimeout (millisecondsTimeout, sw))) {
+                               CurrentThreadState = ThreadLockState.Upgradable;
+                               Interlocked.Decrement (ref numUpgradeWaiters);
+                               return true;
+                       }
+
+                       upgradableTaken.Value = false;
+                       upgradableEvent.Set ();
+
+                       Interlocked.Decrement (ref numUpgradeWaiters);
+
+                       return false;
                }
 
                public bool TryEnterUpgradeableReadLock (TimeSpan timeout)
@@ -192,79 +262,87 @@ namespace System.Threading {
               
                public void ExitUpgradeableReadLock ()
                {
+                       if (!CurrentThreadState.Has (ThreadLockState.Upgradable | ThreadLockState.Read))
+                               throw new SynchronizationLockException ("The current thread has not entered the lock in upgradable mode");
                        
-               }
+                       upgradableTaken.Value = false;
+                       upgradableEvent.Set ();
 
-               bool CheckState (int millisecondsTimeout, ThreadLockState validState)
-               {
-                       if (millisecondsTimeout < Timeout.Infinite)
-                               throw new ArgumentOutOfRangeException ("millisecondsTimeout");
-                       
-                       // Detect and prevent recursion
-                       if (recursionPolicy == LockRecursionPolicy.None && CurrentThreadState != None)
-                               throw new LockRecursionException ("The current thread has already a lock and recursion isn't supported");
-                       
-                       // If we already had write lock, just return
-                       if (CurrentThreadState == validState)
-                               return true;
+                       CurrentThreadState = CurrentThreadState ^ ThreadLockState.Upgradable;
+                       Interlocked.Add (ref rwlock, -RwRead);
                }
 
                public void Dispose ()
                {
-                       read_locks = null;
+                       disposed = true;
                }
 
                public bool IsReadLockHeld {
-                       get { return RecursiveReadCount != 0; }
+                       get {
+                               return rwlock >= RwRead;
+                       }
                }
                
                public bool IsWriteLockHeld {
-                       get { return RecursiveWriteCount != 0; }
+                       get {
+                               return (rwlock & RwWrite) > 0;
+                       }
                }
                
                public bool IsUpgradeableReadLockHeld {
-                       get { return RecursiveUpgradeCount != 0; }
+                       get {
+                               return upgradableTaken.Value;
+                       }
                }
 
                public int CurrentReadCount {
-                       get { return owners & 0xFFFFFFF; }
+                       get {
+                               return (rwlock >> RwReadBit) - (IsUpgradeableReadLockHeld ? 1 : 0);
+                       }
                }
                
                public int RecursiveReadCount {
                        get {
-                               EnterMyLock ();
-                               LockDetails ld = GetReadLockDetails (Thread.CurrentThread.ManagedThreadId, false);
-                               int count = ld == null ? 0 : ld.ReadLocks;
-                               ExitMyLock ();
-                               return count;
+                               return IsReadLockHeld ? IsUpgradeableReadLockHeld ? 0 : 1 : 0;
                        }
                }
 
                public int RecursiveUpgradeCount {
-                       get { return upgradable_thread == Thread.CurrentThread ? 1 : 0; }
+                       get {
+                               return IsUpgradeableReadLockHeld ? 1 : 0;
+                       }
                }
 
                public int RecursiveWriteCount {
-                       get { return write_thread == Thread.CurrentThread ? 1 : 0; }
+                       get {
+                               return IsWriteLockHeld ? 1 : 0;
+                       }
                }
 
                public int WaitingReadCount {
-                       get { return (int) numReadWaiters; }
+                       get {
+                               return numReadWaiters;
+                       }
                }
 
                public int WaitingUpgradeCount {
-                       get { return (int) numUpgradeWaiters; }
+                       get {
+                               return numUpgradeWaiters;
+                       }
                }
 
                public int WaitingWriteCount {
-                       get { return (int) numWriteWaiters; }
+                       get {
+                               return numWriteWaiters;
+                       }
                }
 
                public LockRecursionPolicy RecursionPolicy {
-                       get { return recursionPolicy; }
+                       get {
+                               return recursionPolicy;
+                       }
                }
                
-#region Private methods
                ThreadLockState CurrentThreadState {
                        get {
                                // TODO: provide a IEqualityComparer thingie to have better hashes
@@ -284,6 +362,42 @@ namespace System.Threading {
                                currentThreadState[this] = value;
                        }
                }
-#endregion
+
+               bool CheckState (int millisecondsTimeout, ThreadLockState validState)
+               {
+                       if (disposed)
+                               throw new ObjectDisposedException ("ReaderWriterLockSlim");
+
+                       if (millisecondsTimeout < Timeout.Infinite)
+                               throw new ArgumentOutOfRangeException ("millisecondsTimeout");
+
+                       // Detect and prevent recursion
+                       ThreadLockState ctstate = CurrentThreadState;
+
+                       if (recursionPolicy == LockRecursionPolicy.NoRecursion)
+                               if ((ctstate != ThreadLockState.None && ctstate != ThreadLockState.Upgradable)
+                                   || (ctstate == ThreadLockState.Upgradable && validState == ThreadLockState.Upgradable))
+                                       throw new LockRecursionException ("The current thread has already a lock and recursion isn't supported");
+
+                       // If we already had lock, just return
+                       if (CurrentThreadState == validState)
+                               return true;
+
+                       return false;
+               }
+
+               static int CheckTimeout (TimeSpan timeout)
+               {
+                       try {
+                               return checked ((int)timeout.TotalMilliseconds);
+                       } catch (System.OverflowException) {
+                               throw new ArgumentOutOfRangeException ("timeout");
+                       }
+               }
+
+               static int ComputeTimeout (int millisecondsTimeout, Stopwatch sw)
+               {
+                       return millisecondsTimeout == -1 ? -1 : (int)Math.Max (sw.ElapsedMilliseconds - millisecondsTimeout, 1);
+               }
        }
 }