Merge pull request #408 from strawd/master
[mono.git] / mcs / class / System / System.Collections.Concurrent / BlockingCollection.cs
index f3d1b4e724e64cfeaa86293e76babc4d1c9fa7fa..41a241b10e20dc12a96e75c6c3c8faf43ebb44e3 100644 (file)
@@ -1,4 +1,4 @@
-#if NET_4_0
+//
 // BlockingCollection.cs
 //
 // Copyright (c) 2008 Jérémie "Garuma" Laval
 //
 //
 
+#if NET_4_0
+
 using System;
 using System.Threading;
 using System.Collections;
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Runtime.InteropServices;
 
 namespace System.Collections.Concurrent
 {
+       [ComVisible (false)]
+       [DebuggerDisplay ("Count={Count}")]
+       [DebuggerTypeProxy (typeof (CollectionDebuggerView<>))]
        public class BlockingCollection<T> : IEnumerable<T>, ICollection, IEnumerable, IDisposable
        {
+               const int spinCount = 5;
+
                readonly IProducerConsumerCollection<T> underlyingColl;
                readonly int upperBound;
-               
-               readonly SpinWait sw = new SpinWait ();
-               
+
                AtomicBoolean isComplete;
                long completeId;
 
+               /* The whole idea of the collection is to use these two long values in a transactional
+                * way to track and manage the actual data inside the underlying lock-free collection
+                * instead of directly working with it or using external locking.
+                *
+                * They are manipulated with CAS and are guaranteed to increase over time and use
+                * of the instance thus preventing ABA problems.
+                */
                long addId = long.MinValue;
                long removeId = long.MinValue;
-               
+
+               /* These events are used solely for the purpose of having an optimized sleep cycle when
+                * the BlockingCollection have to wait on an external event (Add or Remove for instance)
+                */
+               ManualResetEventSlim mreAdd = new ManualResetEventSlim (true);
+               ManualResetEventSlim mreRemove = new ManualResetEventSlim (true);
+
+               /* For time based operations, we share this instance of Stopwatch and base calculation
+                  on a time offset at each of these method call */
+               static Stopwatch watch = new Stopwatch ();
+
                #region ctors
                public BlockingCollection ()
                        : this (new ConcurrentQueue<T> (), -1)
                {
                }
-               
-               public BlockingCollection (int upperBound)
-                       : this (new ConcurrentQueue<T> (), upperBound)
+
+               public BlockingCollection (int boundedCapacity)
+                       : this (new ConcurrentQueue<T> (), boundedCapacity)
                {
                }
-               
-               public BlockingCollection (IProducerConsumerCollection<T> underlyingColl)
-                       : this (underlyingColl, -1)
+
+               public BlockingCollection (IProducerConsumerCollection<T> collection)
+                       : this (collection, -1)
                {
                }
-               
-               public BlockingCollection (IProducerConsumerCollection<T> underlyingColl, int upperBound)
+
+               public BlockingCollection (IProducerConsumerCollection<T> collection, int boundedCapacity)
                {
-                       this.underlyingColl = underlyingColl;
-                       this.upperBound     = upperBound;
+                       this.underlyingColl = collection;
+                       this.upperBound     = boundedCapacity;
                        this.isComplete     = new AtomicBoolean ();
                }
                #endregion
-               
+
                #region Add & Remove (+ Try)
                public void Add (T item)
                {
-                       Add (item, null);
+                       Add (item, CancellationToken.None);
                }
-               
-               public void Add (T item, CancellationToken token)
-               {
-                       Add (item, () => token.IsCancellationRequested);
-               }
-               
-               void Add (T item, Func<bool> cancellationFunc)
-               {
-                       while (true) {
-                               long cachedAddId = addId;
-                               long cachedRemoveId = removeId;
-                               
-                               if (upperBound != -1) {
-                                       if (cachedAddId - cachedRemoveId > upperBound) {
-                                               Block ();
-                                               continue;
-                                       }
-                               }
-                               
-                               // Check our transaction id against completed stored one
-                               if (isComplete.Value && cachedAddId >= completeId)
-                                       throw new InvalidOperationException ("The BlockingCollection<T> has"
-                                                                            + " been marked as complete with regards to additions.");
-                               
-                               if (Interlocked.CompareExchange (ref addId, cachedAddId + 1, cachedAddId) == cachedAddId)
-                                       break;
-                               
-                               if (cancellationFunc != null && cancellationFunc ())
-                                       throw new OperationCanceledException ("CancellationToken triggered");
-                       }
-                       
-                       
-                       if (!underlyingColl.TryAdd (item))
-                               throw new InvalidOperationException ("The underlying collection didn't accept the item.");
-               }
-               
-               public T Take ()
-               {
-                       return Take (null);
-               }
-               
-               public T Take (CancellationToken token)
-               {
-                       return Take (() => token.IsCancellationRequested);
-               }
-               
-               T Take (Func<bool> cancellationFunc)
+
+               public void Add (T item, CancellationToken cancellationToken)
                {
-                       while (true) {
-                               long cachedRemoveId = removeId;
-                               long cachedAddId = addId;
-                               
-                               // Empty case
-                               if (cachedRemoveId == cachedAddId) {
-                                       if (isComplete.Value && cachedRemoveId >= completeId)
-                                               throw new OperationCanceledException ("The BlockingCollection<T> has"
-                                                                                     + " been marked as complete with regards to additions.");
-                                       
-                                       Block ();
-                                       continue;
-                               }
-                               
-                               if (Interlocked.CompareExchange (ref removeId, cachedRemoveId + 1, cachedRemoveId) == cachedRemoveId)
-                                       break;
-                               
-                               if (cancellationFunc != null && cancellationFunc ())
-                                       throw new OperationCanceledException ("The CancellationToken has had cancellation requested.");
-                       }
-                       
-                       T item;
-                       while (!underlyingColl.TryTake (out item));
-                       
-                       return item;
+                       TryAdd (item, -1, cancellationToken);
                }
-               
+
                public bool TryAdd (T item)
                {
-                       return TryAdd (item, null, null);
+                       return TryAdd (item, 0, CancellationToken.None);
                }
-               
-               bool TryAdd (T item, Func<bool> contFunc, CancellationToken? token)
+
+               public bool TryAdd (T item, int millisecondsTimeout, CancellationToken cancellationToken)
                {
+                       if (millisecondsTimeout < -1)
+                               throw new ArgumentOutOfRangeException ("millisecondsTimeout");
+
+                       long start = millisecondsTimeout == -1 ? 0 : watch.ElapsedMilliseconds;
+                       SpinWait sw = new SpinWait ();
+
                        do {
-                               if (token.HasValue && token.Value.IsCancellationRequested)
-                                       throw new OperationCanceledException ("The CancellationToken has had cancellation requested.");
-                               
+                               cancellationToken.ThrowIfCancellationRequested ();
+
                                long cachedAddId = addId;
                                long cachedRemoveId = removeId;
-                               
-                               if (upperBound != -1) {
-                                       if (cachedAddId - cachedRemoveId > upperBound) {
-                                               continue;
+
+                               // If needed, we check and wait that the collection isn't full
+                               if (upperBound != -1 && cachedAddId - cachedRemoveId > upperBound) {
+                                       if (millisecondsTimeout == 0)
+                                               return false;
+
+                                       if (sw.Count <= spinCount) {
+                                               sw.SpinOnce ();
+                                       } else {
+                                               mreRemove.Reset ();
+                                               if (cachedRemoveId != removeId || cachedAddId != addId) {
+                                                       mreRemove.Set ();
+                                                       continue;
+                                               }
+
+                                               mreRemove.Wait (ComputeTimeout (millisecondsTimeout, start), cancellationToken);
                                        }
+
+                                       continue;
                                }
-                               
+
                                // Check our transaction id against completed stored one
                                if (isComplete.Value && cachedAddId >= completeId)
-                                       throw new InvalidOperationException ("The BlockingCollection<T> has"
-                                                                            + " been marked as complete with regards to additions.");
-                               
+                                       ThrowCompleteException ();
+
+                               // Validate the steps we have been doing until now
                                if (Interlocked.CompareExchange (ref addId, cachedAddId + 1, cachedAddId) != cachedAddId)
                                        continue;
-                       
+
+                               // We have a slot reserved in the underlying collection, try to take it
                                if (!underlyingColl.TryAdd (item))
                                        throw new InvalidOperationException ("The underlying collection didn't accept the item.");
-                               
+
+                               // Wake up process that may have been sleeping
+                               mreAdd.Set ();
+
                                return true;
-                       } while (contFunc != null && contFunc ());
-                       
+                       } while (millisecondsTimeout == -1 || (watch.ElapsedMilliseconds - start) < millisecondsTimeout);
+
                        return false;
                }
-               
-               public bool TryAdd (T item, TimeSpan ts)
+
+               public bool TryAdd (T item, TimeSpan timeout)
                {
-                       return TryAdd (item, (int)ts.TotalMilliseconds);
+                       return TryAdd (item, (int)timeout.TotalMilliseconds);
                }
-               
+
                public bool TryAdd (T item, int millisecondsTimeout)
                {
-                       Stopwatch sw = Stopwatch.StartNew ();
-                       return TryAdd (item, () => sw.ElapsedMilliseconds < millisecondsTimeout, null);
+                       return TryAdd (item, millisecondsTimeout, CancellationToken.None);
                }
-               
-               public bool TryAdd (T item, int millisecondsTimeout, CancellationToken token)
+
+               public T Take ()
                {
-                       Stopwatch sw = Stopwatch.StartNew ();
-                       return TryAdd (item, () => sw.ElapsedMilliseconds < millisecondsTimeout, token);
+                       return Take (CancellationToken.None);
                }
-               
+
+               public T Take (CancellationToken cancellationToken)
+               {
+                       T item;
+                       TryTake (out item, -1, cancellationToken, true);
+
+                       return item;
+               }
+
                public bool TryTake (out T item)
                {
-                       return TryTake (out item, null, null);
+                       return TryTake (out item, 0, CancellationToken.None);
                }
-               
-               bool TryTake (out T item, Func<bool> contFunc, CancellationToken? token)
+
+               public bool TryTake (out T item, int millisecondsTimeout, CancellationToken cancellationToken)
                {
+                       return TryTake (out item, millisecondsTimeout, cancellationToken, false);
+               }
+
+               bool TryTake (out T item, int milliseconds, CancellationToken cancellationToken, bool throwComplete)
+               {
+                       if (milliseconds < -1)
+                               throw new ArgumentOutOfRangeException ("milliseconds");
+
                        item = default (T);
-                       
+                       SpinWait sw = new SpinWait ();
+                       long start = milliseconds == -1 ? 0 : watch.ElapsedMilliseconds;
+
                        do {
-                               if (token.HasValue && token.Value.IsCancellationRequested)
-                                       throw new OperationCanceledException ("The CancellationToken has had cancellation requested.");
-                               
+                               cancellationToken.ThrowIfCancellationRequested ();
+
                                long cachedRemoveId = removeId;
                                long cachedAddId = addId;
-                               
+
                                // Empty case
                                if (cachedRemoveId == cachedAddId) {
-                                       if (isComplete.Value && cachedRemoveId >= completeId)
-                                               continue;
-                                       
+                                       if (milliseconds == 0)
+                                               return false;
+
+                                       if (IsCompleted) {
+                                               if (throwComplete)
+                                                       ThrowCompleteException ();
+                                               else
+                                                       return false;
+                                       }
+
+                                       if (sw.Count <= spinCount) {
+                                               sw.SpinOnce ();
+                                       } else {
+                                               mreAdd.Reset ();
+                                               if (cachedRemoveId != removeId || cachedAddId != addId) {
+                                                       mreAdd.Set ();
+                                                       continue;
+                                               }
+
+                                               mreAdd.Wait (ComputeTimeout (milliseconds, start), cancellationToken);
+                                       }
+
                                        continue;
                                }
-                               
+
                                if (Interlocked.CompareExchange (ref removeId, cachedRemoveId + 1, cachedRemoveId) != cachedRemoveId)
                                        continue;
-                               
-                               return underlyingColl.TryTake (out item);
-                       } while (contFunc != null && contFunc ());
-                       
+
+                               while (!underlyingColl.TryTake (out item));
+
+                               mreRemove.Set ();
+
+                               return true;
+
+                       } while (milliseconds == -1 || (watch.ElapsedMilliseconds - start) < milliseconds);
+
                        return false;
                }
-               
-               public bool TryTake (out T item, TimeSpan ts)
+
+               public bool TryTake (out T item, TimeSpan timeout)
                {
-                       return TryTake (out item, (int)ts.TotalMilliseconds);
+                       return TryTake (out item, (int)timeout.TotalMilliseconds);
                }
-               
+
                public bool TryTake (out T item, int millisecondsTimeout)
                {
                        item = default (T);
-                       Stopwatch sw = Stopwatch.StartNew ();
-                       
-                       return TryTake (out item, () => sw.ElapsedMilliseconds < millisecondsTimeout, null);
+
+                       return TryTake (out item, millisecondsTimeout, CancellationToken.None, false);
                }
-               
-               public bool TryTake (out T item, int millisecondsTimeout, CancellationToken token)
+
+               static int ComputeTimeout (int millisecondsTimeout, long start)
                {
-                       item = default (T);
-                       Stopwatch sw = Stopwatch.StartNew ();
-                       
-                       return TryTake (out item, () => sw.ElapsedMilliseconds < millisecondsTimeout, token);
+                       return millisecondsTimeout == -1 ? 500 : (int)Math.Max (watch.ElapsedMilliseconds - start - millisecondsTimeout, 1);
                }
                #endregion
-               
+
                #region static methods
                static void CheckArray (BlockingCollection<T>[] collections)
                {
@@ -265,7 +277,7 @@ namespace System.Collections.Concurrent
                        if (collections.Length == 0 || IsThereANullElement (collections))
                                throw new ArgumentException ("The collections argument is a 0-length array or contains a null element.", "collections");
                }
-               
+
                static bool IsThereANullElement (BlockingCollection<T>[] collections)
                {
                        foreach (BlockingCollection<T> e in collections)
@@ -273,7 +285,7 @@ namespace System.Collections.Concurrent
                                        return true;
                        return false;
                }
-               
+
                public static int AddToAny (BlockingCollection<T>[] collections, T item)
                {
                        CheckArray (collections);
@@ -287,21 +299,21 @@ namespace System.Collections.Concurrent
                        }
                        return -1;
                }
-               
-               public static int AddToAny (BlockingCollection<T>[] collections, T item, CancellationToken token)
+
+               public static int AddToAny (BlockingCollection<T>[] collections, T item, CancellationToken cancellationToken)
                {
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
                                try {
-                                       coll.Add (item, token);
+                                       coll.Add (item, cancellationToken);
                                        return index;
                                } catch {}
                                index++;
                        }
                        return -1;
                }
-               
+
                public static int TryAddToAny (BlockingCollection<T>[] collections, T item)
                {
                        CheckArray (collections);
@@ -313,19 +325,19 @@ namespace System.Collections.Concurrent
                        }
                        return -1;
                }
-               
-               public static int TryAddToAny (BlockingCollection<T>[] collections, T item, TimeSpan ts)
+
+               public static int TryAddToAny (BlockingCollection<T>[] collections, T item, TimeSpan timeout)
                {
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
-                               if (coll.TryAdd (item, ts))
+                               if (coll.TryAdd (item, timeout))
                                        return index;
                                index++;
                        }
                        return -1;
                }
-               
+
                public static int TryAddToAny (BlockingCollection<T>[] collections, T item, int millisecondsTimeout)
                {
                        CheckArray (collections);
@@ -337,20 +349,20 @@ namespace System.Collections.Concurrent
                        }
                        return -1;
                }
-               
+
                public static int TryAddToAny (BlockingCollection<T>[] collections, T item, int millisecondsTimeout,
-                                              CancellationToken token)
+                                              CancellationToken cancellationToken)
                {
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
-                               if (coll.TryAdd (item, millisecondsTimeout, token))
+                               if (coll.TryAdd (item, millisecondsTimeout, cancellationToken))
                                        return index;
                                index++;
                        }
                        return -1;
                }
-               
+
                public static int TakeFromAny (BlockingCollection<T>[] collections, out T item)
                {
                        item = default (T);
@@ -365,26 +377,26 @@ namespace System.Collections.Concurrent
                        }
                        return -1;
                }
-               
-               public static int TakeFromAny (BlockingCollection<T>[] collections, out T item, CancellationToken token)
+
+               public static int TakeFromAny (BlockingCollection<T>[] collections, out T item, CancellationToken cancellationToken)
                {
                        item = default (T);
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
                                try {
-                                       item = coll.Take (token);
+                                       item = coll.Take (cancellationToken);
                                        return index;
                                } catch {}
                                index++;
                        }
                        return -1;
                }
-               
+
                public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item)
                {
                        item = default (T);
-                       
+
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
@@ -394,25 +406,25 @@ namespace System.Collections.Concurrent
                        }
                        return -1;
                }
-               
-               public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, TimeSpan ts)
+
+               public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, TimeSpan timeout)
                {
                        item = default (T);
-                       
+
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
-                               if (coll.TryTake (out item, ts))
+                               if (coll.TryTake (out item, timeout))
                                        return index;
                                index++;
                        }
                        return -1;
                }
-               
+
                public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, int millisecondsTimeout)
                {
                        item = default (T);
-                       
+
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
@@ -422,126 +434,128 @@ namespace System.Collections.Concurrent
                        }
                        return -1;
                }
-               
+
                public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, int millisecondsTimeout,
-                                                 CancellationToken token)
+                                                 CancellationToken cancellationToken)
                {
                        item = default (T);
-                       
+
                        CheckArray (collections);
                        int index = 0;
                        foreach (var coll in collections) {
-                               if (coll.TryTake (out item, millisecondsTimeout, token))
+                               if (coll.TryTake (out item, millisecondsTimeout, cancellationToken))
                                        return index;
                                index++;
                        }
                        return -1;
                }
                #endregion
-               
+
                public void CompleteAdding ()
                {
-                 // No further add beside that point
-                 completeId = addId;
-                 isComplete.Value = true;
+                       // No further add beside that point
+                       completeId = addId;
+                       isComplete.Value = true;
+                       // Wakeup some operation in case this has an impact
+                       mreAdd.Set ();
+                       mreRemove.Set ();
                }
-               
+
+               void ThrowCompleteException ()
+               {
+                       throw new InvalidOperationException ("The BlockingCollection<T> has"
+                                                            + " been marked as complete with regards to additions.");
+               }
+
                void ICollection.CopyTo (Array array, int index)
                {
                        underlyingColl.CopyTo (array, index);
                }
-               
+
                public void CopyTo (T[] array, int index)
                {
                        underlyingColl.CopyTo (array, index);
                }
-               
+
                public IEnumerable<T> GetConsumingEnumerable ()
                {
-                       return GetConsumingEnumerable (Take);
-               }
-               
-               public IEnumerable<T> GetConsumingEnumerable (CancellationToken token)
-               {
-                       return GetConsumingEnumerable (() => Take (token));
+                       return GetConsumingEnumerable (CancellationToken.None);
                }
-               
-               IEnumerable<T> GetConsumingEnumerable (Func<T> getFunc)
+
+               public IEnumerable<T> GetConsumingEnumerable (CancellationToken cancellationToken)
                {
                        while (true) {
                                T item = default (T);
-                               
+
                                try {
-                                       item = getFunc ();
+                                       item = Take (cancellationToken);
                                } catch {
-                                       break;
+                                       // Then the exception is perfectly normal
+                                       if (IsCompleted)
+                                               break;
+                                       // otherwise rethrow
+                                       throw;
                                }
-                               
+
                                yield return item;
                        }
                }
-               
+
                IEnumerator IEnumerable.GetEnumerator ()
                {
                        return ((IEnumerable)underlyingColl).GetEnumerator ();
                }
-               
+
                IEnumerator<T> IEnumerable<T>.GetEnumerator ()
                {
                        return ((IEnumerable<T>)underlyingColl).GetEnumerator ();
                }
-               
+
                public void Dispose ()
                {
-                       
+
                }
-               
-               protected virtual void Dispose (bool managedRes)
+
+               protected virtual void Dispose (bool disposing)
                {
-                       
+
                }
-               
+
                public T[] ToArray ()
                {
                        return underlyingColl.ToArray ();
                }
-               
-               // Method used to stall the thread for a limited period of time before retrying an operation
-               void Block ()
-               {
-                       sw.SpinOnce ();
-               }
-               
+
                public int BoundedCapacity {
                        get {
                                return upperBound;
                        }
                }
-               
+
                public int Count {
                        get {
                                return underlyingColl.Count;
                        }
                }
-               
+
                public bool IsAddingCompleted {
                        get {
                                return isComplete.Value;
                        }
                }
-               
+
                public bool IsCompleted {
                        get {
                                return isComplete.Value && addId == removeId;
                        }
                }
-               
+
                object ICollection.SyncRoot {
                        get {
                                return underlyingColl.SyncRoot;
                        }
                }
-               
+
                bool ICollection.IsSynchronized {
                        get {
                                return underlyingColl.IsSynchronized;