ebcd614d33243d8f024df8ff65461883e33c5da4
[mono.git] / mcs / class / referencesource / System.Core / System / Linq / SequenceQuery.cs
1 using System;
2 using System.Collections;
3 using System.Collections.Generic;
4 using System.Collections.ObjectModel;
5 using System.Linq.Expressions;
6 using System.Reflection;
7 using System.Text;
8
9 // Include Silverlight's managed resources
10 #if SILVERLIGHT
11 using System.Core;
12 #endif //SILVERLIGHT
13
14 namespace System.Linq {
15     
16     // Must remain public for Silverlight
17     public abstract class EnumerableQuery {
18         internal abstract Expression Expression { get; }
19         internal abstract IEnumerable Enumerable { get; }
20         internal static IQueryable Create(Type elementType, IEnumerable sequence){
21             Type seqType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
22 #if SILVERLIGHT
23             return (IQueryable) Activator.CreateInstance(seqType, sequence);
24 #else
25             return (IQueryable) Activator.CreateInstance(seqType, BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic, null, new object[] {sequence}, null);
26 #endif //SILVERLIGHT
27         }
28
29         internal static IQueryable Create(Type elementType, Expression expression) {
30             Type seqType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
31 #if SILVERLIGHT
32             return (IQueryable) Activator.CreateInstance(seqType, expression);
33 #else
34             return (IQueryable) Activator.CreateInstance(seqType, BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic, null, new object[] {expression}, null);
35 #endif //SILVERLIGHT
36         }
37     }    
38
39     // Must remain public for Silverlight
40     public class EnumerableQuery<T> : EnumerableQuery, IOrderedQueryable<T>, IQueryable, IQueryProvider, IEnumerable<T>, IEnumerable {
41         Expression expression;
42         IEnumerable<T> enumerable;
43         
44         IQueryProvider IQueryable.Provider {
45             get{
46                return (IQueryProvider)this;
47             }
48         }
49
50         // Must remain public for Silverlight
51         public EnumerableQuery(IEnumerable<T> enumerable) {
52             this.enumerable = enumerable;
53             this.expression = Expression.Constant(this);            
54         }
55
56         // Must remain public for Silverlight
57         public EnumerableQuery(Expression expression) {
58             this.expression = expression;            
59         }
60
61         internal override Expression Expression {
62             get { return this.expression; }
63         }
64
65         internal override IEnumerable Enumerable {
66             get { return this.enumerable; }
67         }
68         
69         Expression IQueryable.Expression {
70             get { return this.expression; }
71         }
72
73         Type IQueryable.ElementType {
74             get { return typeof(T); }
75         }
76
77         IQueryable IQueryProvider.CreateQuery(Expression expression){
78             if (expression == null)
79                 throw Error.ArgumentNull("expression");
80             Type iqType = TypeHelper.FindGenericType(typeof(IQueryable<>), expression.Type);
81             if (iqType == null)
82                 throw Error.ArgumentNotValid("expression");
83             return EnumerableQuery.Create(iqType.GetGenericArguments()[0], expression);
84         }
85
86         IQueryable<S> IQueryProvider.CreateQuery<S>(Expression expression){
87             if (expression == null)
88                 throw Error.ArgumentNull("expression");
89             if (!typeof(IQueryable<S>).IsAssignableFrom(expression.Type)){
90                 throw Error.ArgumentNotValid("expression");
91             }
92             return new EnumerableQuery<S>(expression);
93         }
94
95         // Baselining as Safe for Mix demo so that interface can be transparent. Marking this
96         // critical (which was the original annotation when porting to silverlight) would violate
97         // fxcop security rules if the interface isn't also critical. However, transparent code
98         // can't access this anyway for Mix since we're not exposing AsQueryable().
99         // Microsoft: the above assertion no longer holds. Now making AsQueryable() public again
100         // the security fallout of which will need to be re-examined.
101         object IQueryProvider.Execute(Expression expression){
102             if (expression == null)
103                 throw Error.ArgumentNull("expression");
104             Type execType = typeof(EnumerableExecutor<>).MakeGenericType(expression.Type);
105             return EnumerableExecutor.Create(expression).ExecuteBoxed();
106         }
107
108         // see above
109         S IQueryProvider.Execute<S>(Expression expression){
110             if (expression == null)
111                 throw Error.ArgumentNull("expression");
112             if (!typeof(S).IsAssignableFrom(expression.Type))
113                 throw Error.ArgumentNotValid("expression");
114             return new EnumerableExecutor<S>(expression).Execute();
115         }
116
117         IEnumerator IEnumerable.GetEnumerator() {
118             return this.GetEnumerator();
119         }
120
121         IEnumerator<T> IEnumerable<T>.GetEnumerator() {
122             return this.GetEnumerator();
123         }
124
125         IEnumerator<T> GetEnumerator() {
126             if (this.enumerable == null) {
127                 EnumerableRewriter rewriter = new EnumerableRewriter();
128                 Expression body = rewriter.Visit(this.expression);
129                 Expression<Func<IEnumerable<T>>> f = Expression.Lambda<Func<IEnumerable<T>>>(body, (IEnumerable<ParameterExpression>)null);
130                 this.enumerable = f.Compile()();
131             }
132             return this.enumerable.GetEnumerator();
133         }
134
135         public override string ToString() {
136             ConstantExpression c = this.expression as ConstantExpression;
137             if (c != null && c.Value == this) {
138                 if (this.enumerable != null)
139                     return this.enumerable.ToString();
140                 return "null";
141             }
142             return this.expression.ToString();
143         }
144     }
145
146     // Must remain public for Silverlight
147     public abstract class EnumerableExecutor {
148         internal abstract object ExecuteBoxed();
149
150         internal static EnumerableExecutor Create(Expression expression) {
151             Type execType = typeof(EnumerableExecutor<>).MakeGenericType(expression.Type);
152 #if SILVERLIGHT
153             return (EnumerableExecutor)Activator.CreateInstance(execType, expression);
154 #else
155             return (EnumerableExecutor)Activator.CreateInstance(execType, BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, new object[] { expression }, null);
156 #endif //SILVERLIGHT
157         }
158     }
159
160     // Must remain public for Silverlight
161     public class EnumerableExecutor<T> : EnumerableExecutor{
162         Expression expression;
163         Func<T> func;
164
165         // Must remain public for Silverlight
166         public EnumerableExecutor(Expression expression){
167             this.expression = expression;
168         }
169
170         internal override object ExecuteBoxed() {
171             return this.Execute();
172         }
173
174         internal T Execute(){
175             if (this.func == null){
176                 EnumerableRewriter rewriter = new EnumerableRewriter();
177                 Expression body = rewriter.Visit(this.expression);
178                 Expression<Func<T>> f = Expression.Lambda<Func<T>>(body, (IEnumerable<ParameterExpression>)null);
179                 this.func = f.Compile();
180             }
181             return this.func();
182         }
183     }
184     
185     // 
186     internal class EnumerableRewriter : OldExpressionVisitor {
187
188         internal EnumerableRewriter() {
189         }
190
191         internal override Expression VisitMethodCall(MethodCallExpression m) {
192             Expression obj = this.Visit(m.Object);
193             ReadOnlyCollection<Expression> args = this.VisitExpressionList(m.Arguments);
194
195             // check for args changed
196             if (obj != m.Object || args != m.Arguments) {
197                 Expression[] argArray = args.ToArray();
198                 Type[] typeArgs = (m.Method.IsGenericMethod) ? m.Method.GetGenericArguments() : null;
199
200                 if ((m.Method.IsStatic || m.Method.DeclaringType.IsAssignableFrom(obj.Type)) 
201                     && ArgsMatch(m.Method, args, typeArgs)) {
202                     // current method is still valid
203                     return Expression.Call(obj, m.Method, args);
204                 }
205                 else if (m.Method.DeclaringType == typeof(Queryable)) {
206                     // convert Queryable method to Enumerable method
207                     MethodInfo seqMethod = FindEnumerableMethod(m.Method.Name, args, typeArgs);
208                     args = this.FixupQuotedArgs(seqMethod, args);
209                     return Expression.Call(obj, seqMethod, args);
210                 }
211                 else {
212                     // rebind to new method
213                     BindingFlags flags = BindingFlags.Static | (m.Method.IsPublic ? BindingFlags.Public : BindingFlags.NonPublic);
214                     MethodInfo method = FindMethod(m.Method.DeclaringType, m.Method.Name, args, typeArgs, flags);
215                     args = this.FixupQuotedArgs(method, args);
216                     return Expression.Call(obj, method, args);
217                 }
218             }
219             return m;
220         }
221
222         private ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo mi, ReadOnlyCollection<Expression> argList) {
223             ParameterInfo[] pis = mi.GetParameters();
224             if (pis.Length > 0) {
225                 List<Expression> newArgs = null;
226                 for (int i = 0, n = pis.Length; i < n; i++) {
227                     Expression arg = argList[i];
228                     ParameterInfo pi = pis[i];
229                     arg = FixupQuotedExpression(pi.ParameterType, arg);
230                     if (newArgs == null && arg != argList[i]) {
231                         newArgs = new List<Expression>(argList.Count);
232                         for (int j = 0; j < i; j++) {
233                             newArgs.Add(argList[j]);
234                         }
235                     }
236                     if (newArgs != null) {
237                         newArgs.Add(arg);
238                     }
239                 }
240                 if (newArgs != null) 
241                     argList = newArgs.ToReadOnlyCollection();
242             }
243             return argList;
244         }
245
246         private Expression FixupQuotedExpression(Type type, Expression expression) {
247             Expression expr = expression;
248             while (true) {
249                 if (type.IsAssignableFrom(expr.Type))
250                     return expr;
251                 if (expr.NodeType != ExpressionType.Quote)
252                     break;
253                 expr = ((UnaryExpression)expr).Operand;
254             }
255             if (!type.IsAssignableFrom(expr.Type) && type.IsArray && expr.NodeType == ExpressionType.NewArrayInit) {
256                 Type strippedType = StripExpression(expr.Type);
257                 if (type.IsAssignableFrom(strippedType)) {
258                     Type elementType = type.GetElementType();
259                     NewArrayExpression na = (NewArrayExpression)expr;
260                     List<Expression> exprs = new List<Expression>(na.Expressions.Count);
261                     for (int i = 0, n = na.Expressions.Count; i < n; i++) {
262                         exprs.Add(this.FixupQuotedExpression(elementType, na.Expressions[i]));
263                     }
264                     expression = Expression.NewArrayInit(elementType, exprs);
265                 }
266             }
267             return expression;
268         }
269
270         internal override Expression VisitLambda(LambdaExpression lambda) {
271             return lambda;
272         }
273
274         private static Type GetPublicType(Type t)
275         {
276             // If we create a constant explicitly typed to be a private nested type,
277             // such as Lookup<,>.Grouping or a compiler-generated iterator class, then
278             // we cannot use the expression tree in a context which has only execution
279             // permissions.  We should endeavour to translate constants into 
280             // new constants which have public types.
281             if (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Lookup<,>.Grouping))
282                 return typeof(IGrouping<,>).MakeGenericType(t.GetGenericArguments());
283             if (!t.IsNestedPrivate)
284                 return t;
285             foreach (Type iType in t.GetInterfaces())
286             {
287                 if (iType.IsGenericType && iType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
288                     return iType;
289             }
290             if (typeof(IEnumerable).IsAssignableFrom(t))
291                 return typeof(IEnumerable);
292             return t;
293         }
294
295         internal override Expression VisitConstant(ConstantExpression c) {
296             EnumerableQuery sq = c.Value as EnumerableQuery;
297             if (sq != null) {
298                 if (sq.Enumerable != null)
299                 {
300                     Type t = GetPublicType(sq.Enumerable.GetType());
301                     return Expression.Constant(sq.Enumerable, t);
302                 }
303                 return this.Visit(sq.Expression);
304             }
305             return c;
306         }
307
308         internal override Expression VisitParameter(ParameterExpression p) {
309             return p;
310         }
311
312         private static volatile ILookup<string, MethodInfo> _seqMethods;
313         static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs) {
314             if (_seqMethods == null) {
315                 _seqMethods = typeof(Enumerable).GetMethods(BindingFlags.Static|BindingFlags.Public).ToLookup(m => m.Name);
316             }
317             MethodInfo mi = _seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
318             if (mi == null)
319                 throw Error.NoMethodOnTypeMatchingArguments(name, typeof(Enumerable));
320             if (typeArgs != null)
321                 return mi.MakeGenericMethod(typeArgs);
322             return mi;
323         }
324
325         internal static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs, BindingFlags flags) {
326             MethodInfo[] methods = type.GetMethods(flags).Where(m => m.Name == name).ToArray();
327             if (methods.Length == 0)
328                 throw Error.NoMethodOnType(name, type);
329             MethodInfo mi = methods.FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
330             if (mi == null)
331                 throw Error.NoMethodOnTypeMatchingArguments(name, type);
332             if (typeArgs != null)
333                 return mi.MakeGenericMethod(typeArgs);
334             return mi;
335         }
336
337         private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[] typeArgs) {
338             ParameterInfo[] mParams = m.GetParameters();
339             if (mParams.Length != args.Count)
340                 return false;
341             if (!m.IsGenericMethod && typeArgs != null && typeArgs.Length > 0) {
342                 return false;
343             }
344             if (!m.IsGenericMethodDefinition && m.IsGenericMethod && m.ContainsGenericParameters) {
345                 m = m.GetGenericMethodDefinition();
346             }
347             if (m.IsGenericMethodDefinition) {
348                 if (typeArgs == null || typeArgs.Length == 0)
349                     return false;
350                 if (m.GetGenericArguments().Length != typeArgs.Length)
351                     return false;
352                 m = m.MakeGenericMethod(typeArgs);
353                 mParams = m.GetParameters();
354             }
355             for (int i = 0, n = args.Count; i < n; i++) {
356                 Type parameterType = mParams[i].ParameterType;
357                 if (parameterType == null)
358                     return false;
359                 if (parameterType.IsByRef)
360                     parameterType = parameterType.GetElementType();
361                 Expression arg = args[i];
362                 if (!parameterType.IsAssignableFrom(arg.Type)) {
363                     if (arg.NodeType == ExpressionType.Quote) {
364                         arg = ((UnaryExpression)arg).Operand;
365                     }
366                     if (!parameterType.IsAssignableFrom(arg.Type) &&
367                         !parameterType.IsAssignableFrom(StripExpression(arg.Type))) {
368                         return false;
369                     }
370                 }
371             }
372             return true;
373         }
374
375         private static Type StripExpression(Type type) {
376             bool isArray = type.IsArray;
377             Type tmp = isArray ? type.GetElementType() : type;
378             Type eType = TypeHelper.FindGenericType(typeof(Expression<>), tmp);
379             if (eType != null)
380                 tmp = eType.GetGenericArguments()[0];
381             if (isArray) {
382                 int rank = type.GetArrayRank();
383                 return (rank == 1) ? tmp.MakeArrayType() : tmp.MakeArrayType(rank);
384             }
385             return type;
386         }
387     }
388 }