Updates referencesource to .NET 4.7
[mono.git] / mcs / class / referencesource / System.Core / System / Linq / SequenceQuery.cs
1 using System;
2 using System.Collections;
3 using System.Collections.Generic;
4 using System.Collections.ObjectModel;
5 using System.Linq.Expressions;
6 using System.Reflection;
7 using System.Text;
8
9 // Include Silverlight's managed resources
10 #if SILVERLIGHT
11 using System.Core;
12 #endif //SILVERLIGHT
13
14 namespace System.Linq {
15     
16     // Must remain public for Silverlight
17     public abstract class EnumerableQuery {
18         internal abstract Expression Expression { get; }
19         internal abstract IEnumerable Enumerable { get; }
20         internal static IQueryable Create(Type elementType, IEnumerable sequence){
21             Type seqType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
22 #if SILVERLIGHT
23             return (IQueryable) Activator.CreateInstance(seqType, sequence);
24 #else
25             return (IQueryable) Activator.CreateInstance(seqType, BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic, null, new object[] {sequence}, null);
26 #endif //SILVERLIGHT
27         }
28
29         internal static IQueryable Create(Type elementType, Expression expression) {
30             Type seqType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
31 #if SILVERLIGHT
32             return (IQueryable) Activator.CreateInstance(seqType, expression);
33 #else
34             return (IQueryable) Activator.CreateInstance(seqType, BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic, null, new object[] {expression}, null);
35 #endif //SILVERLIGHT
36         }
37     }    
38
39     // Must remain public for Silverlight
40     public class EnumerableQuery<T> : EnumerableQuery, IOrderedQueryable<T>, IQueryable, IQueryProvider, IEnumerable<T>, IEnumerable {
41         Expression expression;
42         IEnumerable<T> enumerable;
43         
44         IQueryProvider IQueryable.Provider {
45             get{
46                return (IQueryProvider)this;
47             }
48         }
49
50         // Must remain public for Silverlight
51         public EnumerableQuery(IEnumerable<T> enumerable) {
52             this.enumerable = enumerable;
53             this.expression = Expression.Constant(this);            
54         }
55
56         // Must remain public for Silverlight
57         public EnumerableQuery(Expression expression) {
58             this.expression = expression;            
59         }
60
61         internal override Expression Expression {
62             get { return this.expression; }
63         }
64
65         internal override IEnumerable Enumerable {
66             get { return this.enumerable; }
67         }
68         
69         Expression IQueryable.Expression {
70             get { return this.expression; }
71         }
72
73         Type IQueryable.ElementType {
74             get { return typeof(T); }
75         }
76
77         IQueryable IQueryProvider.CreateQuery(Expression expression){
78             if (expression == null)
79                 throw Error.ArgumentNull("expression");
80             Type iqType = TypeHelper.FindGenericType(typeof(IQueryable<>), expression.Type);
81             if (iqType == null)
82                 throw Error.ArgumentNotValid("expression");
83             return EnumerableQuery.Create(iqType.GetGenericArguments()[0], expression);
84         }
85
86         IQueryable<S> IQueryProvider.CreateQuery<S>(Expression expression){
87             if (expression == null)
88                 throw Error.ArgumentNull("expression");
89             if (!typeof(IQueryable<S>).IsAssignableFrom(expression.Type)){
90                 throw Error.ArgumentNotValid("expression");
91             }
92             return new EnumerableQuery<S>(expression);
93         }
94
95         // Baselining as Safe for Mix demo so that interface can be transparent. Marking this
96         // critical (which was the original annotation when porting to silverlight) would violate
97         // fxcop security rules if the interface isn't also critical. However, transparent code
98         // can't access this anyway for Mix since we're not exposing AsQueryable().
99         // Microsoft: the above assertion no longer holds. Now making AsQueryable() public again
100         // the security fallout of which will need to be re-examined.
101         object IQueryProvider.Execute(Expression expression){
102             if (expression == null)
103                 throw Error.ArgumentNull("expression");
104 #if !MONO
105             Type execType = typeof(EnumerableExecutor<>).MakeGenericType(expression.Type);
106 #endif
107             return EnumerableExecutor.Create(expression).ExecuteBoxed();
108         }
109
110         // see above
111         S IQueryProvider.Execute<S>(Expression expression){
112             if (expression == null)
113                 throw Error.ArgumentNull("expression");
114             if (!typeof(S).IsAssignableFrom(expression.Type))
115                 throw Error.ArgumentNotValid("expression");
116             return new EnumerableExecutor<S>(expression).Execute();
117         }
118
119         IEnumerator IEnumerable.GetEnumerator() {
120             return this.GetEnumerator();
121         }
122
123         IEnumerator<T> IEnumerable<T>.GetEnumerator() {
124             return this.GetEnumerator();
125         }
126
127         IEnumerator<T> GetEnumerator() {
128             if (this.enumerable == null) {
129                 EnumerableRewriter rewriter = new EnumerableRewriter();
130                 Expression body = rewriter.Visit(this.expression);
131                 Expression<Func<IEnumerable<T>>> f = Expression.Lambda<Func<IEnumerable<T>>>(body, (IEnumerable<ParameterExpression>)null);
132                 this.enumerable = f.Compile()();
133             }
134             return this.enumerable.GetEnumerator();
135         }
136
137         public override string ToString() {
138             ConstantExpression c = this.expression as ConstantExpression;
139             if (c != null && c.Value == this) {
140                 if (this.enumerable != null)
141                     return this.enumerable.ToString();
142                 return "null";
143             }
144             return this.expression.ToString();
145         }
146     }
147
148     // Must remain public for Silverlight
149     public abstract class EnumerableExecutor {
150         internal abstract object ExecuteBoxed();
151
152         internal static EnumerableExecutor Create(Expression expression) {
153             Type execType = typeof(EnumerableExecutor<>).MakeGenericType(expression.Type);
154 #if SILVERLIGHT
155             return (EnumerableExecutor)Activator.CreateInstance(execType, expression);
156 #else
157             return (EnumerableExecutor)Activator.CreateInstance(execType, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, new object[] { expression }, null);
158 #endif //SILVERLIGHT
159         }
160     }
161
162     // Must remain public for Silverlight
163     public class EnumerableExecutor<T> : EnumerableExecutor{
164         Expression expression;
165         Func<T> func;
166
167         // Must remain public for Silverlight
168         public EnumerableExecutor(Expression expression){
169             this.expression = expression;
170         }
171
172         internal override object ExecuteBoxed() {
173             return this.Execute();
174         }
175
176         internal T Execute(){
177             if (this.func == null){
178                 EnumerableRewriter rewriter = new EnumerableRewriter();
179                 Expression body = rewriter.Visit(this.expression);
180                 Expression<Func<T>> f = Expression.Lambda<Func<T>>(body, (IEnumerable<ParameterExpression>)null);
181                 this.func = f.Compile();
182             }
183             return this.func();
184         }
185     }
186     
187     // 
188     internal class EnumerableRewriter : OldExpressionVisitor {
189
190         internal EnumerableRewriter() {
191         }
192
193         internal override Expression VisitMethodCall(MethodCallExpression m) {
194             Expression obj = this.Visit(m.Object);
195             ReadOnlyCollection<Expression> args = this.VisitExpressionList(m.Arguments);
196
197             // check for args changed
198             if (obj != m.Object || args != m.Arguments) {
199 #if !MONO
200                 Expression[] argArray = args.ToArray();
201 #endif
202                 Type[] typeArgs = (m.Method.IsGenericMethod) ? m.Method.GetGenericArguments() : null;
203
204                 if ((m.Method.IsStatic || m.Method.DeclaringType.IsAssignableFrom(obj.Type)) 
205                     && ArgsMatch(m.Method, args, typeArgs)) {
206                     // current method is still valid
207                     return Expression.Call(obj, m.Method, args);
208                 }
209                 else if (m.Method.DeclaringType == typeof(Queryable)) {
210                     // convert Queryable method to Enumerable method
211                     MethodInfo seqMethod = FindEnumerableMethod(m.Method.Name, args, typeArgs);
212                     args = this.FixupQuotedArgs(seqMethod, args);
213                     return Expression.Call(obj, seqMethod, args);
214                 }
215                 else {
216                     // rebind to new method
217                     BindingFlags flags = BindingFlags.Static | (m.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
218                     MethodInfo method = FindMethod(m.Method.DeclaringType, m.Method.Name, args, typeArgs, flags);
219                     args = this.FixupQuotedArgs(method, args);
220                     return Expression.Call(obj, method, args);
221                 }
222             }
223             return m;
224         }
225
226         private ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo mi, ReadOnlyCollection<Expression> argList) {
227             ParameterInfo[] pis = mi.GetParameters();
228             if (pis.Length > 0) {
229                 List<Expression> newArgs = null;
230                 for (int i = 0, n = pis.Length; i < n; i++) {
231                     Expression arg = argList[i];
232                     ParameterInfo pi = pis[i];
233                     arg = FixupQuotedExpression(pi.ParameterType, arg);
234                     if (newArgs == null && arg != argList[i]) {
235                         newArgs = new List<Expression>(argList.Count);
236                         for (int j = 0; j < i; j++) {
237                             newArgs.Add(argList[j]);
238                         }
239                     }
240                     if (newArgs != null) {
241                         newArgs.Add(arg);
242                     }
243                 }
244                 if (newArgs != null) 
245                     argList = newArgs.ToReadOnlyCollection();
246             }
247             return argList;
248         }
249
250         private Expression FixupQuotedExpression(Type type, Expression expression) {
251             Expression expr = expression;
252             while (true) {
253                 if (type.IsAssignableFrom(expr.Type))
254                     return expr;
255                 if (expr.NodeType != ExpressionType.Quote)
256                     break;
257                 expr = ((UnaryExpression)expr).Operand;
258             }
259             if (!type.IsAssignableFrom(expr.Type) && type.IsArray && expr.NodeType == ExpressionType.NewArrayInit) {
260                 Type strippedType = StripExpression(expr.Type);
261                 if (type.IsAssignableFrom(strippedType)) {
262                     Type elementType = type.GetElementType();
263                     NewArrayExpression na = (NewArrayExpression)expr;
264                     List<Expression> exprs = new List<Expression>(na.Expressions.Count);
265                     for (int i = 0, n = na.Expressions.Count; i < n; i++) {
266                         exprs.Add(this.FixupQuotedExpression(elementType, na.Expressions[i]));
267                     }
268                     expression = Expression.NewArrayInit(elementType, exprs);
269                 }
270             }
271             return expression;
272         }
273
274         internal override Expression VisitLambda(LambdaExpression lambda) {
275             return lambda;
276         }
277
278         private static Type GetPublicType(Type t)
279         {
280             // If we create a constant explicitly typed to be a private nested type,
281             // such as Lookup<,>.Grouping or a compiler-generated iterator class, then
282             // we cannot use the expression tree in a context which has only execution
283             // permissions.  We should endeavour to translate constants into 
284             // new constants which have public types.
285             if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Lookup<,>.Grouping))
286                 return typeof(IGrouping<,>).MakeGenericType(t.GetGenericArguments());
287             if (!t.IsNestedPrivate)
288                 return t;
289             foreach (Type iType in t.GetInterfaces())
290             {
291                 if (iType.IsGenericType && iType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
292                     return iType;
293             }
294             if (typeof(IEnumerable).IsAssignableFrom(t))
295                 return typeof(IEnumerable);
296             return t;
297         }
298
299         internal override Expression VisitConstant(ConstantExpression c) {
300             EnumerableQuery sq = c.Value as EnumerableQuery;
301             if (sq != null) {
302                 if (sq.Enumerable != null)
303                 {
304                     Type t = GetPublicType(sq.Enumerable.GetType());
305                     return Expression.Constant(sq.Enumerable, t);
306                 }
307                 return this.Visit(sq.Expression);
308             }
309             return c;
310         }
311
312         internal override Expression VisitParameter(ParameterExpression p) {
313             return p;
314         }
315
316         private static volatile ILookup<string, MethodInfo> _seqMethods;
317         static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs) {
318             if (_seqMethods == null) {
319                 _seqMethods = typeof(Enumerable).GetMethods(BindingFlags.Static|BindingFlags.Public).ToLookup(m => m.Name);
320             }
321             MethodInfo mi = _seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
322             if (mi == null)
323                 throw Error.NoMethodOnTypeMatchingArguments(name, typeof(Enumerable));
324             if (typeArgs != null)
325                 return mi.MakeGenericMethod(typeArgs);
326             return mi;
327         }
328
329         internal static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags) {
330             MethodInfo[] methods = type.GetMethods(flags).Where(m => m.Name == name).ToArray();
331             if (methods.Length == 0)
332                 throw Error.NoMethodOnType(name, type);
333             MethodInfo mi = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
334             if (mi == null)
335                 throw Error.NoMethodOnTypeMatchingArguments(name, type);
336             if (typeArgs != null)
337                 return mi.MakeGenericMethod(typeArgs);
338             return mi;
339         }
340
341         private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[] typeArgs) {
342             ParameterInfo[] mParams = m.GetParameters();
343             if (mParams.Length != args.Count)
344                 return false;
345             if (!m.IsGenericMethod && typeArgs != null && typeArgs.Length > 0) {
346                 return false;
347             }
348             if (!m.IsGenericMethodDefinition && m.IsGenericMethod && m.ContainsGenericParameters) {
349                 m = m.GetGenericMethodDefinition();
350             }
351             if (m.IsGenericMethodDefinition) {
352                 if (typeArgs == null || typeArgs.Length == 0)
353                     return false;
354                 if (m.GetGenericArguments().Length != typeArgs.Length)
355                     return false;
356                 m = m.MakeGenericMethod(typeArgs);
357                 mParams = m.GetParameters();
358             }
359             for (int i = 0, n = args.Count; i < n; i++) {
360                 Type parameterType = mParams[i].ParameterType;
361                 if (parameterType == null)
362                     return false;
363                 if (parameterType.IsByRef)
364                     parameterType = parameterType.GetElementType();
365                 Expression arg = args[i];
366                 if (!parameterType.IsAssignableFrom(arg.Type)) {
367                     if (arg.NodeType == ExpressionType.Quote) {
368                         arg = ((UnaryExpression)arg).Operand;
369                     }
370                     if (!parameterType.IsAssignableFrom(arg.Type) &&
371                         !parameterType.IsAssignableFrom(StripExpression(arg.Type))) {
372                         return false;
373                     }
374                 }
375             }
376             return true;
377         }
378
379         private static Type StripExpression(Type type) {
380             bool isArray = type.IsArray;
381             Type tmp = isArray ? type.GetElementType() : type;
382             Type eType = TypeHelper.FindGenericType(typeof(Expression<>), tmp);
383             if (eType != null)
384                 tmp = eType.GetGenericArguments()[0];
385             if (isArray) {
386                 int rank = type.GetArrayRank();
387                 return (rank == 1) ? tmp.MakeArrayType() : tmp.MakeArrayType(rank);
388             }
389             return type;
390         }
391     }
392 }