Fix bug in BlockingCollection<T>.TryTake; add corresponding unit test
[mono.git] / mcs / class / System / System.Collections.Concurrent / BlockingCollection.cs
1 //
2 // BlockingCollection.cs
3 //
4 // Copyright (c) 2008 Jérémie "Garuma" Laval
5 //
6 // Permission is hereby granted, free of charge, to any person obtaining a copy
7 // of this software and associated documentation files (the "Software"), to deal
8 // in the Software without restriction, including without limitation the rights
9 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 // copies of the Software, and to permit persons to whom the Software is
11 // furnished to do so, subject to the following conditions:
12 //
13 // The above copyright notice and this permission notice shall be included in
14 // all copies or substantial portions of the Software.
15 //
16 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 // THE SOFTWARE.
23 //
24 //
25
26 #if NET_4_0 || BOOTSTRAP_NET_4_0
27
28 using System;
29 using System.Threading;
30 using System.Collections;
31 using System.Collections.Generic;
32 using System.Diagnostics;
33 using System.Runtime.InteropServices;
34
35 namespace System.Collections.Concurrent
36 {
37         [ComVisible (false)]
38         [DebuggerDisplay ("Count={Count}")]
39         [DebuggerTypeProxy (typeof (CollectionDebuggerView<>))]
public class BlockingCollection<T> : IEnumerable<T>, ICollection, IEnumerable, IDisposable
{
        // Number of SpinWait iterations a blocking Add/Take performs before parking on an event.
        const int spinCount = 5;

        readonly IProducerConsumerCollection<T> underlyingColl;
        // -1 means unbounded; otherwise the maximum number of items the collection may hold.
        readonly int upperBound;

        // Set by CompleteAdding. Volatile so that a thread observing the flag as true also
        // observes the completeId value written just before it.
        volatile bool isComplete;
        long completeId;

        /* The whole idea of the collection is to use these two long values in a transactional way
         * to track and manage the actual data inside the underlying lock-free collection
         * instead of directly working with it or using external locking.
         *
         * They are manipulated with CAS and are guaranteed to increase over time and use
         * of the instance, thus preventing ABA problems. A successful CAS on addId reserves a
         * slot that MUST then be filled; a successful CAS on removeId claims an item that MUST
         * then be dequeued.
         */
        long addId = long.MinValue;
        long removeId = long.MinValue;

        /* These events are used solely for the purpose of having an optimized sleep cycle when
         * the BlockingCollection has to wait on an external event (Add or Remove for instance).
         * A waiter Resets the event, re-checks its condition and only then Waits, so a Set issued
         * after the re-check can never be lost.
         */
        ManualResetEventSlim mreAdd = new ManualResetEventSlim (true);
        ManualResetEventSlim mreRemove = new ManualResetEventSlim (true);

        #region ctors
        /// <summary>Creates an unbounded collection backed by a <see cref="ConcurrentQueue{T}"/>.</summary>
        public BlockingCollection ()
                : this (new ConcurrentQueue<T> (), -1)
        {
        }

        /// <summary>Creates a collection backed by a <see cref="ConcurrentQueue{T}"/>.</summary>
        /// <param name="upperBound">Maximum number of items, or -1 for no bound.</param>
        public BlockingCollection (int upperBound)
                : this (new ConcurrentQueue<T> (), upperBound)
        {
        }

        /// <summary>Creates an unbounded collection backed by <paramref name="underlyingColl"/>.</summary>
        public BlockingCollection (IProducerConsumerCollection<T> underlyingColl)
                : this (underlyingColl, -1)
        {
        }

        /// <summary>Creates a collection backed by <paramref name="underlyingColl"/>.</summary>
        /// <param name="underlyingColl">Storage for the items.</param>
        /// <param name="upperBound">Maximum number of items, or -1 for no bound.</param>
        /// <exception cref="ArgumentNullException"><paramref name="underlyingColl"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="upperBound"/> is neither -1 nor positive.</exception>
        public BlockingCollection (IProducerConsumerCollection<T> underlyingColl, int upperBound)
        {
                if (underlyingColl == null)
                        throw new ArgumentNullException ("underlyingColl");
                // With any other non-positive bound the collection would always look full and
                // every Add would block forever.
                if (upperBound != -1 && upperBound < 1)
                        throw new ArgumentOutOfRangeException ("upperBound", "The bound must be positive, or -1 for an unbounded collection.");

                this.underlyingColl = underlyingColl;
                this.upperBound     = upperBound;
        }
        #endregion

        #region Add & Remove (+ Try)
        /// <summary>Adds an item, blocking while a bounded collection is full.</summary>
        /// <exception cref="InvalidOperationException">CompleteAdding has been called.</exception>
        public void Add (T item)
        {
                Add (item, CancellationToken.None);
        }

        /// <summary>Adds an item, blocking while a bounded collection is full.</summary>
        /// <exception cref="InvalidOperationException">
        /// CompleteAdding has been called, including while this call was waiting for room.
        /// </exception>
        /// <exception cref="OperationCanceledException"><paramref name="token"/> was canceled.</exception>
        public void Add (T item, CancellationToken token)
        {
                SpinWait sw = new SpinWait ();

                while (true) {
                        token.ThrowIfCancellationRequested ();

                        long cachedAddId = addId;
                        long cachedRemoveId = removeId;

                        // Check our transaction id against completed stored one. This happens before
                        // the capacity check so an Add blocked on a full collection fails as soon as
                        // CompleteAdding wakes it, instead of waiting for room first.
                        if (isComplete && cachedAddId >= completeId)
                                ThrowCompleteException ();

                        // Full case: the collection already holds upperBound items.
                        if (upperBound != -1 && cachedAddId - cachedRemoveId >= upperBound) {
                                if (sw.Count <= spinCount) {
                                        sw.SpinOnce ();
                                } else {
                                        // Reset before re-checking: a Take (or CompleteAdding) that happens
                                        // after the re-check finds the event unset and Sets it, waking us.
                                        mreRemove.Reset ();
                                        if (cachedRemoveId != removeId || isComplete)
                                                continue;

                                        mreRemove.Wait (token);
                                }

                                continue;
                        }

                        if (Interlocked.CompareExchange (ref addId, cachedAddId + 1, cachedAddId) == cachedAddId)
                                break;
                }

                // The CAS above reserved a slot that consumers already count on, so the item must be
                // stored even if CompleteAdding raced with us; throwing here would leave a phantom
                // slot that a Take would spin on forever.
                while (!underlyingColl.TryAdd (item));

                if (!mreAdd.IsSet)
                        mreAdd.Set ();
        }

        /// <summary>Removes an item, blocking while the collection is empty.</summary>
        /// <exception cref="InvalidOperationException">The collection is completed and empty.</exception>
        public T Take ()
        {
                return Take (CancellationToken.None);
        }

        /// <summary>Removes an item, blocking while the collection is empty.</summary>
        /// <exception cref="InvalidOperationException">The collection is completed and empty.</exception>
        /// <exception cref="OperationCanceledException"><paramref name="token"/> was canceled.</exception>
        public T Take (CancellationToken token)
        {
                SpinWait sw = new SpinWait ();

                while (true) {
                        token.ThrowIfCancellationRequested ();

                        long cachedRemoveId = removeId;
                        long cachedAddId = addId;

                        // Empty case
                        if (cachedRemoveId == cachedAddId) {
                                if (IsCompleted)
                                        ThrowCompleteException ();

                                if (sw.Count <= spinCount) {
                                        sw.SpinOnce ();
                                } else {
                                        // Reset before re-checking: an Add (or CompleteAdding) that happens
                                        // after the re-check finds the event unset and Sets it, waking us.
                                        mreAdd.Reset ();
                                        if (cachedAddId != addId)
                                                continue;
                                        if (IsCompleted)
                                                ThrowCompleteException ();

                                        mreAdd.Wait (token);
                                }

                                continue;
                        }

                        if (Interlocked.CompareExchange (ref removeId, cachedRemoveId + 1, cachedRemoveId) == cachedRemoveId)
                                break;
                }

                T item;
                // The CAS above claimed an item whose producer may still be storing it; wait for it.
                while (!underlyingColl.TryTake (out item));

                if (!mreRemove.IsSet)
                        mreRemove.Set ();

                return item;
        }

        /// <summary>Adds an item if there is room right now.</summary>
        /// <returns>false if the bounded collection is full.</returns>
        /// <exception cref="InvalidOperationException">CompleteAdding has been called.</exception>
        public bool TryAdd (T item)
        {
                return TryAdd (item, () => false, CancellationToken.None);
        }

        // Core of the TryAdd overloads: contFunc decides whether a full collection is retried.
        bool TryAdd (T item, Func<bool> contFunc, CancellationToken token)
        {
                SpinWait spinner = new SpinWait ();

                while (true) {
                        token.ThrowIfCancellationRequested ();

                        long cachedAddId = addId;
                        long cachedRemoveId = removeId;

                        // Check our transaction id against completed stored one
                        if (isComplete && cachedAddId >= completeId)
                                ThrowCompleteException ();

                        // Full case: only keep trying while the caller's deadline allows it.
                        if (upperBound != -1 && cachedAddId - cachedRemoveId >= upperBound) {
                                if (!contFunc ())
                                        return false;

                                spinner.SpinOnce ();
                                continue;
                        }

                        // Losing the CAS only means another producer got in first; retry at once
                        // instead of reporting a spurious failure.
                        if (Interlocked.CompareExchange (ref addId, cachedAddId + 1, cachedAddId) != cachedAddId)
                                continue;

                        while (!underlyingColl.TryAdd (item));

                        if (!mreAdd.IsSet)
                                mreAdd.Set ();

                        return true;
                }
        }

        /// <summary>Adds an item, waiting up to <paramref name="ts"/> for room.</summary>
        public bool TryAdd (T item, TimeSpan ts)
        {
                return TryAdd (item, (int)ts.TotalMilliseconds);
        }

        /// <summary>Adds an item, waiting up to the timeout for room (<see cref="Timeout.Infinite"/> waits forever).</summary>
        public bool TryAdd (T item, int millisecondsTimeout)
        {
                return TryAdd (item, CreateTimeoutFunc (millisecondsTimeout), CancellationToken.None);
        }

        /// <summary>Adds an item, waiting up to the timeout for room (<see cref="Timeout.Infinite"/> waits forever).</summary>
        public bool TryAdd (T item, int millisecondsTimeout, CancellationToken token)
        {
                return TryAdd (item, CreateTimeoutFunc (millisecondsTimeout), token);
        }

        /// <summary>Removes an item if one is available right now.</summary>
        /// <returns>false if the collection is empty.</returns>
        public bool TryTake (out T item)
        {
                return TryTake (out item, () => false, CancellationToken.None);
        }

        // Core of the TryTake overloads: contFunc decides whether an empty collection is retried.
        bool TryTake (out T item, Func<bool> contFunc, CancellationToken token)
        {
                item = default (T);
                SpinWait spinner = new SpinWait ();

                while (true) {
                        token.ThrowIfCancellationRequested ();

                        long cachedRemoveId = removeId;
                        long cachedAddId = addId;

                        // Empty case: a completed collection will never produce more items.
                        if (cachedRemoveId == cachedAddId) {
                                if (IsCompleted || !contFunc ())
                                        return false;

                                spinner.SpinOnce ();
                                continue;
                        }

                        // Losing the CAS only means another consumer got in first; an item may
                        // still be available, so retry at once instead of failing spuriously.
                        if (Interlocked.CompareExchange (ref removeId, cachedRemoveId + 1, cachedRemoveId) != cachedRemoveId)
                                continue;

                        while (!underlyingColl.TryTake (out item));

                        if (!mreRemove.IsSet)
                                mreRemove.Set ();

                        return true;
                }
        }

        /// <summary>Removes an item, waiting up to <paramref name="ts"/> for one.</summary>
        public bool TryTake (out T item, TimeSpan ts)
        {
                return TryTake (out item, (int)ts.TotalMilliseconds);
        }

        /// <summary>Removes an item, waiting up to the timeout for one (<see cref="Timeout.Infinite"/> waits forever).</summary>
        public bool TryTake (out T item, int millisecondsTimeout)
        {
                return TryTake (out item, CreateTimeoutFunc (millisecondsTimeout), CancellationToken.None);
        }

        /// <summary>Removes an item, waiting up to the timeout for one (<see cref="Timeout.Infinite"/> waits forever).</summary>
        public bool TryTake (out T item, int millisecondsTimeout, CancellationToken token)
        {
                return TryTake (out item, CreateTimeoutFunc (millisecondsTimeout), token);
        }

        // Builds the "keep trying?" predicate used by the timed overloads. Timeout.Infinite (-1)
        // retries forever; any other value retries until that many milliseconds have elapsed,
        // so 0 (or another negative value) means a single attempt.
        static Func<bool> CreateTimeoutFunc (int millisecondsTimeout)
        {
                if (millisecondsTimeout == Timeout.Infinite)
                        return () => true;

                Stopwatch watch = Stopwatch.StartNew ();
                return () => watch.ElapsedMilliseconds < millisecondsTimeout;
        }
        #endregion

        #region static methods
        static void CheckArray (BlockingCollection<T>[] collections)
        {
                if (collections == null)
                        throw new ArgumentNullException ("collections");
                if (collections.Length == 0 || IsThereANullElement (collections))
                        throw new ArgumentException ("The collections argument is a 0-length array or contains a null element.", "collections");
        }

        static bool IsThereANullElement (BlockingCollection<T>[] collections)
        {
                foreach (BlockingCollection<T> e in collections)
                        if (e == null)
                                return true;
                return false;
        }

        // Repeatedly sweeps the collections in order, making one non-blocking attempt on each
        // collection that is not exhausted. Returns the index of the first successful attempt,
        // or -1 once every collection is exhausted or contFunc says to stop. Sweeping all of them
        // means one busy/empty collection cannot starve the others.
        static int SweepAny (BlockingCollection<T>[] collections,
                             Func<BlockingCollection<T>, bool> attempt,
                             Func<BlockingCollection<T>, bool> isExhausted,
                             Func<bool> contFunc, CancellationToken token)
        {
                CheckArray (collections);
                SpinWait spinner = new SpinWait ();

                while (true) {
                        token.ThrowIfCancellationRequested ();

                        bool allExhausted = true;
                        for (int i = 0; i < collections.Length; i++) {
                                BlockingCollection<T> coll = collections [i];
                                if (isExhausted (coll))
                                        continue;

                                allExhausted = false;
                                if (attempt (coll))
                                        return i;
                        }

                        if (allExhausted || !contFunc ())
                                return -1;

                        spinner.SpinOnce ();
                }
        }

        // Non-blocking add used by the *ToAny methods; a CompleteAdding that races with the
        // IsAddingCompleted check is treated as "not added" rather than aborting the sweep.
        static bool TryAddIgnoringCompletion (BlockingCollection<T> coll, T item)
        {
                try {
                        return coll.TryAdd (item);
                } catch (InvalidOperationException) {
                        return false;
                }
        }

        static int AddToAnyCore (BlockingCollection<T>[] collections, T item, Func<bool> contFunc, CancellationToken token)
        {
                return SweepAny (collections, c => TryAddIgnoringCompletion (c, item), c => c.IsAddingCompleted, contFunc, token);
        }

        static int TakeFromAnyCore (BlockingCollection<T>[] collections, out T item, Func<bool> contFunc, CancellationToken token)
        {
                // Lambdas cannot capture an out parameter, so collect the item in a local first.
                T taken = default (T);
                int index = SweepAny (collections, c => c.TryTake (out taken), c => c.IsCompleted, contFunc, token);
                item = taken;
                return index;
        }

        /// <summary>Adds to the first collection with room, blocking until one has room.</summary>
        /// <returns>The index used, or -1 if every collection is marked complete for adding.</returns>
        public static int AddToAny (BlockingCollection<T>[] collections, T item)
        {
                return AddToAnyCore (collections, item, () => true, CancellationToken.None);
        }

        /// <summary>Adds to the first collection with room, blocking until one has room.</summary>
        /// <returns>The index used, or -1 if every collection is marked complete for adding.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="token"/> was canceled.</exception>
        public static int AddToAny (BlockingCollection<T>[] collections, T item, CancellationToken token)
        {
                return AddToAnyCore (collections, item, () => true, token);
        }

        /// <summary>Adds to the first collection that has room right now, or returns -1.</summary>
        public static int TryAddToAny (BlockingCollection<T>[] collections, T item)
        {
                return AddToAnyCore (collections, item, () => false, CancellationToken.None);
        }

        /// <summary>Adds to the first collection with room, waiting up to <paramref name="ts"/> overall.</summary>
        public static int TryAddToAny (BlockingCollection<T>[] collections, T item, TimeSpan ts)
        {
                return TryAddToAny (collections, item, (int)ts.TotalMilliseconds);
        }

        /// <summary>Adds to the first collection with room, waiting up to the timeout overall.</summary>
        public static int TryAddToAny (BlockingCollection<T>[] collections, T item, int millisecondsTimeout)
        {
                return AddToAnyCore (collections, item, CreateTimeoutFunc (millisecondsTimeout), CancellationToken.None);
        }

        /// <summary>Adds to the first collection with room, waiting up to the timeout overall.</summary>
        public static int TryAddToAny (BlockingCollection<T>[] collections, T item, int millisecondsTimeout,
                                       CancellationToken token)
        {
                return AddToAnyCore (collections, item, CreateTimeoutFunc (millisecondsTimeout), token);
        }

        /// <summary>Takes from the first collection with an item, blocking until one has an item.</summary>
        /// <returns>The index used, or -1 if every collection is completed and empty.</returns>
        public static int TakeFromAny (BlockingCollection<T>[] collections, out T item)
        {
                return TakeFromAnyCore (collections, out item, () => true, CancellationToken.None);
        }

        /// <summary>Takes from the first collection with an item, blocking until one has an item.</summary>
        /// <returns>The index used, or -1 if every collection is completed and empty.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="token"/> was canceled.</exception>
        public static int TakeFromAny (BlockingCollection<T>[] collections, out T item, CancellationToken token)
        {
                return TakeFromAnyCore (collections, out item, () => true, token);
        }

        /// <summary>Takes from the first collection that has an item right now, or returns -1.</summary>
        public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item)
        {
                return TakeFromAnyCore (collections, out item, () => false, CancellationToken.None);
        }

        /// <summary>Takes from the first collection with an item, waiting up to <paramref name="ts"/> overall.</summary>
        public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, TimeSpan ts)
        {
                return TryTakeFromAny (collections, out item, (int)ts.TotalMilliseconds);
        }

        /// <summary>Takes from the first collection with an item, waiting up to the timeout overall.</summary>
        public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, int millisecondsTimeout)
        {
                return TakeFromAnyCore (collections, out item, CreateTimeoutFunc (millisecondsTimeout), CancellationToken.None);
        }

        /// <summary>Takes from the first collection with an item, waiting up to the timeout overall.</summary>
        public static int TryTakeFromAny (BlockingCollection<T>[] collections, out T item, int millisecondsTimeout,
                                          CancellationToken token)
        {
                return TakeFromAnyCore (collections, out item, CreateTimeoutFunc (millisecondsTimeout), token);
        }
        #endregion

        /// <summary>Marks the collection as not accepting any more additions.</summary>
        public void CompleteAdding ()
        {
                // No further add beside that point. completeId is written before the volatile
                // flag so any thread that sees the flag also sees this id.
                completeId = addId;
                isComplete = true;
                // Wake up blocked Add/Take calls so they can observe the completion
                mreAdd.Set ();
                mreRemove.Set ();
        }

        void ThrowCompleteException ()
        {
                throw new InvalidOperationException ("The BlockingCollection<T> has"
                                                     + " been marked as complete with regards to additions.");
        }

        void ICollection.CopyTo (Array array, int index)
        {
                underlyingColl.CopyTo (array, index);
        }

        public void CopyTo (T[] array, int index)
        {
                underlyingColl.CopyTo (array, index);
        }

        /// <summary>Yields items via Take until the collection is completed and drained.</summary>
        public IEnumerable<T> GetConsumingEnumerable ()
        {
                return GetConsumingEnumerable (Take);
        }

        /// <summary>Yields items via Take until the collection is completed and drained.</summary>
        /// <exception cref="OperationCanceledException">Thrown during enumeration if <paramref name="token"/> is canceled.</exception>
        public IEnumerable<T> GetConsumingEnumerable (CancellationToken token)
        {
                return GetConsumingEnumerable (() => Take (token));
        }

        IEnumerable<T> GetConsumingEnumerable (Func<T> getFunc)
        {
                while (true) {
                        T item;

                        try {
                                item = getFunc ();
                        } catch (InvalidOperationException) {
                                // Take signals "completed and empty" this way: end the sequence.
                                // Cancellation and any other failure propagate to the consumer.
                                break;
                        }

                        yield return item;
                }
        }

        IEnumerator IEnumerable.GetEnumerator ()
        {
                return ((IEnumerable)underlyingColl).GetEnumerator ();
        }

        IEnumerator<T> IEnumerable<T>.GetEnumerator ()
        {
                return ((IEnumerable<T>)underlyingColl).GetEnumerator ();
        }

        /// <summary>Releases resources by dispatching to <see cref="Dispose(bool)"/>.</summary>
        public void Dispose ()
        {
                Dispose (true);
                GC.SuppressFinalize (this);
        }

        // Override point for subclasses. The base implementation releases nothing.
        // NOTE(review): mreAdd/mreRemove could be disposed here once no thread can still wait on them.
        protected virtual void Dispose (bool managedRes)
        {
        }

        public T[] ToArray ()
        {
                return underlyingColl.ToArray ();
        }

        /// <summary>The bound passed at construction, or -1 when unbounded.</summary>
        public int BoundedCapacity {
                get {
                        return upperBound;
                }
        }

        public int Count {
                get {
                        return underlyingColl.Count;
                }
        }

        public bool IsAddingCompleted {
                get {
                        return isComplete;
                }
        }

        /// <summary>True once CompleteAdding was called and every added item has been taken.</summary>
        public bool IsCompleted {
                get {
                        return isComplete && addId == removeId;
                }
        }

        object ICollection.SyncRoot {
                get {
                        return underlyingColl.SyncRoot;
                }
        }

        bool ICollection.IsSynchronized {
                get {
                        return underlyingColl.IsSynchronized;
                }
        }
}
594 }
595 #endif