44ad177a3b79df586abb4dc70f1faebb1a737287
[mono.git] / mcs / class / System.Core / System.Linq.Parallel / ParallelExecuter.cs
1 //
2 // ParallelExecuter.cs
3 //
4 // Author:
5 //       Jérémie "Garuma" Laval <jeremie.laval@gmail.com>
6 //
7 // Copyright (c) 2010 Jérémie "Garuma" Laval
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining a copy
10 // of this software and associated documentation files (the "Software"), to deal
11 // in the Software without restriction, including without limitation the rights
12 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 // copies of the Software, and to permit persons to whom the Software is
14 // furnished to do so, subject to the following conditions:
15 //
16 // The above copyright notice and this permission notice shall be included in
17 // all copies or substantial portions of the Software.
18 //
19 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25 // THE SOFTWARE.
26
27 #if NET_4_0
28 using System;
29 using System.Threading;
30 using System.Threading.Tasks;
31 using System.Collections;
32 using System.Collections.Generic;
33 using System.Collections.Concurrent;
34 using System.Linq.Parallel.QueryNodes;
35
36 namespace System.Linq.Parallel
37 {
internal static class ParallelExecuter
{
	// Builds the QueryOptions for the query rooted at startingNode, sized for a
	// non-blocking caller (same as CheckQuery (startingNode, false)).
	internal static QueryOptions CheckQuery<T> (QueryBaseNode<T> startingNode)
	{
		return CheckQuery<T> (startingNode, false);
	}

	// Builds the QueryOptions for the query rooted at startingNode, using
	// GetBestWorkerNumber (blocking) as the partition count.
	internal static QueryOptions CheckQuery<T> (QueryBaseNode<T> startingNode, bool blocking)
	{
		return CheckQuery (startingNode, GetBestWorkerNumber (blocking));
	}

	// Walks the query tree with a QueryCheckerVisitor created for partitionCount
	// partitions and returns the options that visitor collected.
	internal static QueryOptions CheckQuery<T> (QueryBaseNode<T> startingNode, int partitionCount)
	{
		QueryCheckerVisitor visitor = new QueryCheckerVisitor (partitionCount);
		startingNode.Visit (visitor);

		return visitor.Options;
	}

	// Returns a token that becomes cancelled as soon as either `self` or `other` is cancelled.
	// NOTE(review): the linked CancellationTokenSource created here is never disposed, and the
	// caller only receives its token, so its registrations on both parents stay alive until the
	// parents themselves are disposed. Releasing it needs an API change (e.g. return the source).
	internal static CancellationToken Chain (this CancellationToken self, CancellationTokenSource other)
	{
		CancellationTokenSource linked = CancellationTokenSource.CreateLinkedTokenSource (self, other.Token);
		return linked.Token;
	}

	// Walks the query tree with a QueryIsOrderedVisitor and returns its BehindOrderGuard flag
	// (presumably true when an ordering operator sits upstream of `source` - confirm in the visitor).
	internal static bool IsOrdered<TSource> (this QueryBaseNode<TSource> source)
	{
		QueryIsOrderedVisitor visitor = new QueryIsOrderedVisitor ();
		source.Visit (visitor);

		return visitor.BehindOrderGuard;
	}

	// Worker count for a non-blocking caller (same as GetBestWorkerNumber (false)).
	internal static int GetBestWorkerNumber ()
	{
		return GetBestWorkerNumber (false);
	}

	// Environment.ProcessorCount workers, plus one when the caller is going to block and is not
	// itself running inside a Task (Task.CurrentId is null outside a task).
	// NOTE(review): the extra worker presumably compensates for the caller's thread sitting idle
	// in a wait - confirm.
	internal static int GetBestWorkerNumber (bool blocking)
	{
		return blocking && Task.CurrentId == null ? Environment.ProcessorCount + 1 : Environment.ProcessorCount;
	}

	// Same as the overload below, without an endAction.
	internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node,
	                                                   Action<TElement, CancellationToken> call,
	                                                   Func<QueryBaseNode<TSource>, QueryOptions, IList<IEnumerable<TElement>>> acquisitionFunc,
	                                                   QueryOptions options)
	{
		return Process<TSource, TElement> (node, call, acquisitionFunc, null, options);
	}

	// Starts one task per enumerable returned by acquisitionFunc (node, options) and passes every
	// element to `call`, together with a token linked to options.ImplementerToken and options.Token.
	//
	// Before each element a worker checks the tokens (see ShouldContinue):
	//  - options.ImplementerToken cancelled: the worker stops quietly (used by specific operators);
	//  - options.Token cancelled: OperationCanceledException (options.Token) is thrown.
	// If `call` itself throws OperationCanceledException for the linked token, the same rules are
	// applied instead of letting that exception fault the task.
	// endAction, when non-null, runs once per worker after it stops, whatever the reason.
	// Returns the started tasks without waiting for them.
	internal static Task[] Process<TSource, TElement> (QueryBaseNode<TSource> node,
	                                                   Action<TElement, CancellationToken> call,
	                                                   Func<QueryBaseNode<TSource>, QueryOptions, IList<IEnumerable<TElement>>> acquisitionFunc,
	                                                   Action endAction,
	                                                   QueryOptions options)
	{
		CancellationTokenSource src
			= CancellationTokenSource.CreateLinkedTokenSource (options.ImplementerToken, options.Token);
		// Read once here: CancellationTokenSource.Token throws after the source is disposed.
		CancellationToken linkedToken = src.Token;

		IList<IEnumerable<TElement>> enumerables;
		try {
			enumerables = acquisitionFunc (node, options);
		} catch {
			src.Dispose ();
			throw;
		}

		Task[] tasks = new Task[enumerables.Count];

		for (int i = 0; i < tasks.Length; i++) {
			int index = i;
			tasks[i] = Task.Factory.StartNew (() => {
				try {
					foreach (TElement item in enumerables[index]) {
						if (!ShouldContinue (options))
							break;

						try {
							call (item, linkedToken);
						} catch (OperationCanceledException e) {
							// An OCE carrying linkedToken does not match the task's token
							// (options.Token), so rethrowing it would fault the task even for a
							// deliberate implementer stop. Apply the normal rules instead: stop
							// quietly, or rethrow as OperationCanceledException (options.Token).
							if (e.CancellationToken != linkedToken || ShouldContinue (options))
								throw;
							break;
						}
					}
				} finally {
					if (endAction != null)
						endAction ();
				}
			}, options.Token);
		}

		// Release the linked source's registrations on the parent tokens once every worker is done.
		// TaskFactory.ContinueWhenAll rejects an empty array, hence the explicit branch.
		if (tasks.Length == 0)
			src.Dispose ();
		else
			Task.Factory.ContinueWhenAll (tasks, _ => src.Dispose (), CancellationToken.None,
			                              TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);

		return tasks;
	}

	// Token check done by Process workers. Returns false when options.ImplementerToken asks the
	// worker to stop early, throws OperationCanceledException (options.Token) when the caller's
	// token is cancelled, and returns true otherwise. The implementer token is checked first.
	static bool ShouldContinue (QueryOptions options)
	{
		if (options.ImplementerToken.IsCancellationRequested)
			return false;
		if (options.Token.IsCancellationRequested)
			throw new OperationCanceledException (options.Token);
		return true;
	}

	// Runs `call` for every element of the query on the partitions chosen by CheckQuery (node, true)
	// and blocks until all workers finish. Task.WaitAll surfaces worker failures as
	// AggregateException and throws OperationCanceledException if options.Token is cancelled while waiting.
	internal static void ProcessAndBlock<T> (QueryBaseNode<T> node, Action<T, CancellationToken> call)
	{
		QueryOptions options = CheckQuery (node, true);

		Task[] tasks = Process (node, call, (n, o) => n.GetEnumerables (o), options);
		Task.WaitAll (tasks, options.Token);
	}

	// Starts the workers without blocking. `callback`, when non-null, is scheduled to run after all
	// workers have finished. The returned Action blocks until they finish (Task.WaitAll with options.Token).
	internal static Action ProcessAndCallback<T> (QueryBaseNode<T> node, Action<T, CancellationToken> call,
	                                              Action callback, QueryOptions options)
	{
		Task[] tasks = Process (node, call, (n, o) => n.GetEnumerables (o), options);
		RunWhenDone (tasks, callback);

		return () => Task.WaitAll (tasks, options.Token);
	}

	// Ordered variant: each element reaches `call` paired with a long key taken from
	// GetOrderedEnumerables (presumably its source position - confirm). endAction runs once per
	// worker when it stops. Otherwise identical to the overload above.
	internal static Action ProcessAndCallback<T> (QueryBaseNode<T> node, Action<KeyValuePair<long, T>, CancellationToken> call,
	                                              Action endAction,
	                                              Action callback, QueryOptions options)
	{
		Task[] tasks = Process (node, call, (n, o) => n.GetOrderedEnumerables (o), endAction, options);
		RunWhenDone (tasks, callback);

		return () => Task.WaitAll (tasks, options.Token);
	}

	// Schedules `callback` (if any) to run once every task in `tasks` has finished, whatever their
	// final state. With no tasks there is nothing to wait for, so it is invoked right away; this
	// avoids the ArgumentException TaskFactory.ContinueWhenAll throws for an empty array.
	static void RunWhenDone (Task[] tasks, Action callback)
	{
		if (callback == null)
			return;

		if (tasks.Length == 0)
			callback ();
		else
			Task.Factory.ContinueWhenAll (tasks, (_) => callback ());
	}

	// Parallel fold. Each partition from CheckQuery (node, true) is folded into its own slot of a
	// U[] with localCall. Slots start from seedFunc () or, when seedFunc is null, from the
	// partition's first element. Blocks until all partitions are done, then hands the per-partition
	// results to `call` (if non-null).
	// NOTE(review): without a seed, an empty partition leaves its slot at default (U), and `call`
	// cannot tell it apart from a real result - confirm callers account for that.
	internal static void ProcessAndAggregate<T, U> (QueryBaseNode<T> node,
	                                                Func<U> seedFunc,
	                                                Func<U, T, U> localCall,
	                                                Action<IList<U>> call)
	{
		QueryOptions options = CheckQuery (node, true);

		IList<IEnumerable<T>> enumerables = node.GetEnumerables (options);
		int count = enumerables.Count;
		U[] locals = new U[count];
		Task[] tasks = new Task[count];

		if (seedFunc != null) {
			for (int i = 0; i < count; i++)
				locals[i] = seedFunc ();
		}

		for (int i = 0; i < count; i++) {
			// The slot creates (and always disposes) its enumerator on the worker. A task cancelled
			// before it starts therefore never leaves an undisposed enumerator behind.
			var procSlot = new AggregateProcessSlot<T, U> (options,
			                                               i,
			                                               enumerables[i],
			                                               locals,
			                                               localCall,
			                                               seedFunc);

			tasks[i] = Task.Factory.StartNew (procSlot.Process, options.Token);
		}

		Task.WaitAll (tasks, options.Token);

		if (call != null)
			call (locals);
	}

	// Per-partition worker state for ProcessAndAggregate: folds one partition into locals[index].
	sealed class AggregateProcessSlot<T, U>
	{
		readonly QueryOptions options;
		readonly int index;
		readonly IEnumerable<T> source;
		readonly U[] locals;
		readonly Func<U, T, U> localCall;
		readonly Func<U> seedFunc;

		public AggregateProcessSlot (QueryOptions options,
		                             int index,
		                             IEnumerable<T> source,
		                             U[] locals,
		                             Func<U, T, U> localCall,
		                             Func<U> seedFunc)
		{
			this.options = options;
			this.index = index;
			this.source = source;
			this.locals = locals;
			this.localCall = localCall;
			this.seedFunc = seedFunc;
		}

		// Folds the partition into locals[index]. Without a seedFunc, the first element becomes the
		// starting value (the cast through object requires its runtime type to be a U), and an empty
		// partition leaves the slot untouched. Stops quietly when options.ImplementerToken is
		// cancelled and throws OperationCanceledException when options.Token is cancelled.
		// The enumerator is always disposed.
		public void Process ()
		{
			CancellationToken token = options.Token;
			CancellationToken implementerToken = options.ImplementerToken;

			using (IEnumerator<T> enumerator = source.GetEnumerator ()) {
				if (seedFunc == null) {
					if (!enumerator.MoveNext ())
						return;
					locals[index] = (U)(object)enumerator.Current;
				}

				while (enumerator.MoveNext ()) {
					if (implementerToken.IsCancellationRequested)
						break;
					token.ThrowIfCancellationRequested ();
					locals[index] = localCall (locals[index], enumerator.Current);
				}
			}
		}
	}
}
236 }
237 #endif