Refactor Add/Take/TryAdd/TryTake methods into a common block in BlockingCollection
authorJérémie Laval <jeremie.laval@gmail.com>
Thu, 9 Dec 2010 15:35:51 +0000 (15:35 +0000)
committerJérémie Laval <jeremie.laval@gmail.com>
Thu, 9 Dec 2010 16:00:15 +0000 (16:00 +0000)
mcs/class/System/System.Collections.Concurrent/BlockingCollection.cs

index fa3f5aa179530ca2564afb17baecdf51ac147ae9..2b56cc634d904896d0f6393b211d2ede95ff37b8 100644 (file)
@@ -48,7 +48,7 @@ namespace System.Collections.Concurrent
                long completeId;
 
                /* The whole idea of the collection is to use these two long values in a transactional
-                * to track and manage the actual data inside the underlying lock-free collection
+                * way to track and manage the actual data inside the underlying lock-free collection
                 * instead of directly working with it or using external locking.
                 *
                 * They are manipulated with CAS and are guaranteed to increase over time and use
@@ -63,6 +63,10 @@ namespace System.Collections.Concurrent
                ManualResetEventSlim mreAdd = new ManualResetEventSlim (true);
                ManualResetEventSlim mreRemove = new ManualResetEventSlim (true);
 
+               /* For time based operations, we share this instance of Stopwatch and base calculation
+                  on a time offset at each of these method call */
+               static Stopwatch watch = new Stopwatch ();
+
                #region ctors
                public BlockingCollection ()
                        : this (new ConcurrentQueue<T> (), -1)
@@ -95,129 +99,65 @@ namespace System.Collections.Concurrent
 
                public void Add (T item, CancellationToken token)
                {
-                       SpinWait sw = new SpinWait ();
-                       long cachedAddId;
-
-                       while (true) {
-                               token.ThrowIfCancellationRequested ();
-
-                               cachedAddId = addId;
-                               long cachedRemoveId = removeId;
-
-                               if (upperBound != -1) {
-                                       if (cachedAddId - cachedRemoveId > upperBound) {
-                                               if (sw.Count <= spinCount) {
-                                                       sw.SpinOnce ();
-                                               } else {
-                                                       if (mreRemove.IsSet)
-                                                               continue;
-                                                       if (cachedRemoveId != removeId)
-                                                               continue;
-
-                                                       mreRemove.Wait (token);
-                                                       mreRemove.Reset ();
-                                               }
-
-                                               continue;
-                                       }
-                               }
-
-                               // Check our transaction id against completed stored one
-                               if (isComplete.Value && cachedAddId >= completeId)
-                                       ThrowCompleteException ();
-                               if (Interlocked.CompareExchange (ref addId, cachedAddId + 1, cachedAddId) == cachedAddId)
-                                       break;
-                       }
-
-                       if (isComplete.Value && cachedAddId >= completeId)
-                               ThrowCompleteException ();
-
-                       while (!underlyingColl.TryAdd (item));
-
-                       if (!mreAdd.IsSet)
-                               mreAdd.Set ();
+                       TryAdd (item, -1, token);
                }
 
-               public T Take ()
+               public bool TryAdd (T item)
                {
-                       return Take (CancellationToken.None);
+                       return TryAdd (item, 0, CancellationToken.None);
                }
 
-               public T Take (CancellationToken token)
+               public bool TryAdd (T item, int milliseconds, CancellationToken token)
                {
+                       if (milliseconds < -1)
+                               throw new ArgumentOutOfRangeException ("milliseconds");
+
+                       long start = milliseconds == -1 ? 0 : watch.ElapsedMilliseconds;
                        SpinWait sw = new SpinWait ();
 
-                       while (true) {
+                       do {
                                token.ThrowIfCancellationRequested ();
 
-                               long cachedRemoveId = removeId;
                                long cachedAddId = addId;
+                               long cachedRemoveId = removeId;
 
-                               // Empty case
-                               if (cachedRemoveId == cachedAddId) {
-                                       if (IsCompleted)
-                                               ThrowCompleteException ();
+                               // If needed, we check and wait that the collection isn't full
+                               if (upperBound != -1 && cachedAddId - cachedRemoveId > upperBound) {
+                                       if (milliseconds == 0)
+                                               return false;
 
                                        if (sw.Count <= spinCount) {
                                                sw.SpinOnce ();
                                        } else {
-                                               if (cachedAddId != addId)
+                                               mreRemove.Reset ();
+                                               if (cachedRemoveId != removeId || cachedAddId != addId) {
+                                                       mreRemove.Set ();
                                                        continue;
-                                               if (IsCompleted)
-                                                       ThrowCompleteException ();
+                                               }
 
-                                               mreAdd.Wait (token);
-                                               mreAdd.Reset ();
+                                               mreRemove.Wait (ComputeTimeout (milliseconds, start), token);
                                        }
 
                                        continue;
                                }
 
-                               if (Interlocked.CompareExchange (ref removeId, cachedRemoveId + 1, cachedRemoveId) == cachedRemoveId)
-                                       break;
-                       }
-
-                       T item;
-                       while (!underlyingColl.TryTake (out item));
-
-                       if (!mreRemove.IsSet)
-                               mreRemove.Set ();
-
-                       return item;
-               }
-
-               public bool TryAdd (T item)
-               {
-                       return TryAdd (item, () => false, CancellationToken.None);
-               }
-
-               bool TryAdd (T item, Func<bool> contFunc, CancellationToken token)
-               {
-                       do {
-                               token.ThrowIfCancellationRequested ();
-
-                               long cachedAddId = addId;
-                               long cachedRemoveId = removeId;
-
-                               if (upperBound != -1)
-                                       if (cachedAddId - cachedRemoveId > upperBound)
-                                               continue;
-
                                // Check our transaction id against completed stored one
                                if (isComplete.Value && cachedAddId >= completeId)
-                                       throw new InvalidOperationException ("The BlockingCollection<T> has"
-                                                                            + " been marked as complete with regards to additions.");
+                                       ThrowCompleteException ();
 
+                               // Validate the steps we have been doing until now
                                if (Interlocked.CompareExchange (ref addId, cachedAddId + 1, cachedAddId) != cachedAddId)
                                        continue;
 
-                               while (!underlyingColl.TryAdd (item));
+                               // We have a slot reserved in the underlying collection, try to take it
+                               if (!underlyingColl.TryAdd (item))
+                                       throw new InvalidOperationException ("The underlying collection didn't accept the item.");
 
-                               if (!mreAdd.IsSet)
-                                       mreAdd.Set ();
+                               // Wake up process that may have been sleeping
+                               mreAdd.Set ();
 
                                return true;
-                       } while (contFunc ());
+                       } while (milliseconds == -1 || (watch.ElapsedMilliseconds - start) < milliseconds);
 
                        return false;
                }
@@ -229,24 +169,40 @@ namespace System.Collections.Concurrent
 
                public bool TryAdd (T item, int millisecondsTimeout)
                {
-                       Stopwatch sw = Stopwatch.StartNew ();
-                       return TryAdd (item, () => sw.ElapsedMilliseconds < millisecondsTimeout, CancellationToken.None);
+                       return TryAdd (item, millisecondsTimeout, CancellationToken.None);
+               }
+
+               public T Take ()
+               {
+                       return Take (CancellationToken.None);
                }
 
-               public bool TryAdd (T item, int millisecondsTimeout, CancellationToken token)
+               public T Take (CancellationToken token)
                {
-                       Stopwatch sw = Stopwatch.StartNew ();
-                       return TryAdd (item, () => sw.ElapsedMilliseconds < millisecondsTimeout, token);
+                       T item;
+                       TryTake (out item, -1, token, true);
+
+                       return item;
                }
 
                public bool TryTake (out T item)
                {
-                       return TryTake (out item, () => false, CancellationToken.None);
+                       return TryTake (out item, 0, CancellationToken.None);
                }
 
-               bool TryTake (out T item, Func<bool> contFunc, CancellationToken token)
+               public bool TryTake (out T item, int millisecondsTimeout, CancellationToken token)
                {
+                       return TryTake (out item, millisecondsTimeout, token, false);
+               }
+
+               bool TryTake (out T item, int milliseconds, CancellationToken token, bool throwComplete)
+               {
+                       if (milliseconds < -1)
+                               throw new ArgumentOutOfRangeException ("milliseconds");
+
                        item = default (T);
+                       SpinWait sw = new SpinWait ();
+                       long start = milliseconds == -1 ? 0 : watch.ElapsedMilliseconds;
 
                        do {
                                token.ThrowIfCancellationRequested ();
@@ -256,9 +212,28 @@ namespace System.Collections.Concurrent
 
                                // Empty case
                                if (cachedRemoveId == cachedAddId) {
-                                       if (IsCompleted)
+                                       if (milliseconds == 0)
                                                return false;
 
+                                       if (IsCompleted) {
+                                               if (throwComplete)
+                                                       ThrowCompleteException ();
+                                               else
+                                                       return false;
+                                       }
+
+                                       if (sw.Count <= spinCount) {
+                                               sw.SpinOnce ();
+                                       } else {
+                                               mreAdd.Reset ();
+                                               if (cachedRemoveId != removeId || cachedAddId != addId) {
+                                                       mreAdd.Set ();
+                                                       continue;
+                                               }
+
+                                               mreAdd.Wait (ComputeTimeout (milliseconds, start), token);
+                                       }
+
                                        continue;
                                }
 
@@ -267,10 +242,11 @@ namespace System.Collections.Concurrent
 
                                while (!underlyingColl.TryTake (out item));
 
-                               if (!mreRemove.IsSet)
-                                       mreRemove.Set ();
+                               mreRemove.Set ();
+
                                return true;
-                       } while (contFunc ());
+
+                       } while (milliseconds == -1 || (watch.ElapsedMilliseconds - start) < milliseconds);
 
                        return false;
                }
@@ -283,17 +259,13 @@ namespace System.Collections.Concurrent
                public bool TryTake (out T item, int millisecondsTimeout)
                {
                        item = default (T);
-                       Stopwatch sw = Stopwatch.StartNew ();
 
-                       return TryTake (out item, () => sw.ElapsedMilliseconds < millisecondsTimeout, CancellationToken.None);
+                       return TryTake (out item, millisecondsTimeout, CancellationToken.None, false);
                }
 
-               public bool TryTake (out T item, int millisecondsTimeout, CancellationToken token)
+               static int ComputeTimeout (int millisecondsTimeout, long start)
                {
-                       item = default (T);
-                       Stopwatch sw = Stopwatch.StartNew ();
-
-                       return TryTake (out item, () => sw.ElapsedMilliseconds < millisecondsTimeout, token);
+                       return millisecondsTimeout == -1 ? 500 : (int)Math.Max (watch.ElapsedMilliseconds - start - millisecondsTimeout, 1);
                }
                #endregion