Updates referencesource to .NET 4.7
[mono.git] / mcs / class / referencesource / System.Core / System / Linq / Parallel / QueryOperators / Unary / IndexedWhereQueryOperator.cs
1 // ==++==
2 //
3 //   Copyright (c) Microsoft Corporation.  All rights reserved.
4 // 
5 // ==--==
6 // =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+
7 //
8 // IndexedWhereQueryOperator.cs
9 //
10 // <OWNER>Microsoft</OWNER>
11 //
12 // =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
13
14 using System.Collections.Generic;
15 using System.Diagnostics.Contracts;
16 using System.Threading;
17
18 namespace System.Linq.Parallel
19 {
20     /// <summary>
21     /// A variant of the Where operator that supplies element index while performing the
22     /// filtering operation. This requires cooperation with partitioning and merging to
23     /// guarantee ordering is preserved.
24     ///
25     /// </summary>
26     /// <typeparam name="TInputOutput"></typeparam>
27     internal sealed class IndexedWhereQueryOperator<TInputOutput> : UnaryQueryOperator<TInputOutput, TInputOutput>
28     {
29
30         // Predicate function. Used to filter out non-matching elements during execution.
31         private Func<TInputOutput, int, bool> m_predicate;
32         private bool m_prematureMerge = false; // Whether to prematurely merge the input of this operator.
33         private bool m_limitsParallelism = false; // Whether this operator limits parallelism
34
35         //---------------------------------------------------------------------------------------
36         // Initializes a new where operator.
37         //
38         // Arguments:
39         //    child         - the child operator or data source from which to pull data
40         //    predicate     - a delegate representing the predicate function
41         //
42         // Assumptions:
43         //    predicate must be non null.
44         //
45
46         internal IndexedWhereQueryOperator(IEnumerable<TInputOutput> child,
47                                            Func<TInputOutput, int, bool> predicate)
48             :base(child)
49         {
50             Contract.Assert(child != null, "child data source cannot be null");
51             Contract.Assert(predicate != null, "need a filter function");
52
53             m_predicate = predicate;
54
55             // In an indexed Select, elements must be returned in the order in which
56             // indices were assigned.
57             m_outputOrdered = true;
58
59             InitOrdinalIndexState();
60         }
61
62         private void InitOrdinalIndexState()
63         {
64             OrdinalIndexState childIndexState = Child.OrdinalIndexState;
65             if (ExchangeUtilities.IsWorseThan(childIndexState, OrdinalIndexState.Correct))
66             {
67                 m_prematureMerge = true;
68                 m_limitsParallelism = childIndexState != OrdinalIndexState.Shuffled;
69             }
70
71             SetOrdinalIndexState(OrdinalIndexState.Increasing);
72         }
73
74
75         //---------------------------------------------------------------------------------------
76         // Just opens the current operator, including opening the child and wrapping it with
77         // partitions as needed.
78         //
79
80         internal override QueryResults<TInputOutput> Open(
81             QuerySettings settings, bool preferStriping)
82         {
83             QueryResults<TInputOutput> childQueryResults = Child.Open(settings, preferStriping);
84             return new UnaryQueryOperatorResults(childQueryResults, this, settings, preferStriping);
85         }
86
87         internal override void WrapPartitionedStream<TKey>(
88             PartitionedStream<TInputOutput, TKey> inputStream, IPartitionedStreamRecipient<TInputOutput> recipient, bool preferStriping, QuerySettings settings)
89         {
90             int partitionCount = inputStream.PartitionCount;
91
92             // If the index is not correct, we need to reindex.
93             PartitionedStream<TInputOutput, int> inputStreamInt;
94             if (m_prematureMerge)
95             {
96                 ListQueryResults<TInputOutput> listResults = ExecuteAndCollectResults(inputStream, partitionCount, Child.OutputOrdered, preferStriping, settings);
97                 inputStreamInt = listResults.GetPartitionedStream();
98             }
99             else
100             {
101                 Contract.Assert(typeof(TKey) == typeof(int));
102                 inputStreamInt = (PartitionedStream<TInputOutput, int>)(object)inputStream;
103             }
104
105             // Since the index is correct, the type of the index must be int
106             PartitionedStream<TInputOutput, int> outputStream =
107                 new PartitionedStream<TInputOutput, int>(partitionCount, Util.GetDefaultComparer<int>(), OrdinalIndexState);
108
109             for (int i = 0; i < partitionCount; i++)
110             {
111                 outputStream[i] = new IndexedWhereQueryOperatorEnumerator(inputStreamInt[i], m_predicate, settings.CancellationState.MergedCancellationToken);
112             }
113
114             recipient.Receive(outputStream);
115         }
116
117
118         //---------------------------------------------------------------------------------------
119         // Returns an enumerable that represents the query executing sequentially.
120         //
121
122         internal override IEnumerable<TInputOutput> AsSequentialQuery(CancellationToken token)
123         {
124             IEnumerable<TInputOutput> wrappedChild = CancellableEnumerable.Wrap(Child.AsSequentialQuery(token), token);
125             return wrappedChild.Where(m_predicate);
126         }
127
128
129         //---------------------------------------------------------------------------------------
130         // Whether this operator performs a premature merge that would not be performed in
131         // a similar sequential operation (i.e., in LINQ to Objects).
132         //
133
134         internal override bool LimitsParallelism
135         {
136             get { return m_limitsParallelism; }
137         }
138
139         //-----------------------------------------------------------------------------------
140         // An enumerator that implements the filtering logic.
141         //
142
143         private class IndexedWhereQueryOperatorEnumerator : QueryOperatorEnumerator<TInputOutput, int>
144         {
145
146             private readonly QueryOperatorEnumerator<TInputOutput, int> m_source; // The data source to enumerate.
147             private readonly Func<TInputOutput, int, bool> m_predicate; // The predicate used for filtering.
148             private CancellationToken m_cancellationToken;
149             private Shared<int> m_outputLoopCount;
150             //-----------------------------------------------------------------------------------
151             // Instantiates a new enumerator.
152             //
153
154             internal IndexedWhereQueryOperatorEnumerator(QueryOperatorEnumerator<TInputOutput, int> source, Func<TInputOutput, int, bool> predicate,
155                 CancellationToken cancellationToken)
156             {
157                 Contract.Assert(source != null);
158                 Contract.Assert(predicate != null);
159                 m_source = source;
160                 m_predicate = predicate;
161                 m_cancellationToken = cancellationToken;
162             }
163
164             //-----------------------------------------------------------------------------------
165             // Moves to the next matching element in the underlying data stream.
166             //
167
168             internal override bool MoveNext(ref TInputOutput currentElement, ref int currentKey)
169             {
170                 Contract.Assert(m_predicate != null, "expected a compiled operator");
171
172                 // Iterate through the input until we reach the end of the sequence or find
173                 // an element matching the predicate.
174                 
175                 if (m_outputLoopCount == null)
176                     m_outputLoopCount = new Shared<int>(0);
177                 
178                 while (m_source.MoveNext(ref currentElement, ref currentKey))
179                 {
180                     if ((m_outputLoopCount.Value++ & CancellationState.POLL_INTERVAL) == 0)
181                         CancellationState.ThrowIfCanceled(m_cancellationToken);
182
183                     if (m_predicate(currentElement, currentKey))
184                     {
185                         return true;
186                     }
187                 }
188
189                 return false;
190             }
191
192             protected override void Dispose(bool disposing)
193             {
194                 Contract.Assert(m_source != null);
195                 m_source.Dispose();
196             }
197
198         }
199
200     }
201 }