50ccc2c6d099820465bb2e2fbd30978117c36f6e
[mono.git] / mcs / class / referencesource / System.Data.Linq / SqlClient / Query / Funcletizer.cs
1 using System;
2 using System.Collections.Generic;
3 using System.Collections.ObjectModel;
4 using System.Linq;
5 using System.Linq.Expressions;
6
7 namespace System.Data.Linq.SqlClient {
8     using System.Data.Linq.Mapping;
9     using System.Diagnostics.CodeAnalysis;
10
11     internal static class Funcletizer {
12
13         internal static Expression Funcletize(Expression expression) {
14             return new Localizer(new LocalMapper().MapLocals(expression)).Localize(expression);
15         }
16
17         class Localizer : ExpressionVisitor {
18             Dictionary<Expression, bool> locals;
19
20             internal Localizer(Dictionary<Expression, bool> locals) {
21                 this.locals = locals;
22             }
23
24             internal Expression Localize(Expression expression) {
25                 return this.Visit(expression);
26             }
27
28             internal override Expression Visit(Expression exp) {
29                 if (exp == null) {
30                     return null;
31                 }
32                 if (this.locals.ContainsKey(exp)) {
33                     return MakeLocal(exp);
34                 }
35                 if (exp.NodeType == (ExpressionType)InternalExpressionType.Known) {
36                     return exp;
37                 }
38                 return base.Visit(exp);
39             }
40
41             private static Expression MakeLocal(Expression e) {
42                 if (e.NodeType == ExpressionType.Constant) {
43                     return e;
44                 }
45                 else if (e.NodeType == ExpressionType.Convert || e.NodeType == ExpressionType.ConvertChecked) {
46                     UnaryExpression ue = (UnaryExpression)e;
47                     if (ue.Type == typeof(object)) {
48                         Expression local = MakeLocal(ue.Operand);
49                         return (e.NodeType == ExpressionType.Convert) ? Expression.Convert(local, e.Type) : Expression.ConvertChecked(local, e.Type);
50                     }
51                     // convert a const null
52                     if (ue.Operand.NodeType == ExpressionType.Constant) {
53                         ConstantExpression c = (ConstantExpression)ue.Operand;
54                         if (c.Value == null) {
55                             return Expression.Constant(null, ue.Type);
56                         }
57                     }
58                 }
59                 return Expression.Invoke(Expression.Constant(Expression.Lambda(e).Compile()));
60             }
61         }
62         class DependenceChecker : ExpressionVisitor {
63             HashSet<ParameterExpression> inScope = new HashSet<ParameterExpression>();
64             bool isIndependent = true;
65
66             /// <summary>
67             /// This method returns 'true' when the expression doesn't reference any parameters 
68             /// from outside the scope of the expression.
69             /// </summary>
70             static public bool IsIndependent(Expression expression) {
71                 var v = new DependenceChecker();
72                 v.Visit(expression);
73                 return v.isIndependent;
74             }
75             internal override Expression VisitLambda(LambdaExpression lambda) {
76                 foreach (var p in lambda.Parameters) {
77                     this.inScope.Add(p);
78                 }
79                 return base.VisitLambda(lambda);
80             }
81             internal override Expression VisitParameter(ParameterExpression p) {
82                 this.isIndependent &= this.inScope.Contains(p);
83                 return p;
84             }
85         }
86
87         class LocalMapper : ExpressionVisitor {
88             bool isRemote;
89             Dictionary<Expression, bool> locals;
90
91             internal Dictionary<Expression, bool> MapLocals(Expression expression) {
92                 this.locals = new Dictionary<Expression, bool>();
93                 this.isRemote = false;
94                 this.Visit(expression);
95                 return this.locals;
96             }
97
98             internal override Expression Visit(Expression expression) {
99                 if (expression == null) {
100                     return null;
101                 }
102                 bool saveIsRemote = this.isRemote;
103                 switch (expression.NodeType) {
104                     case (ExpressionType)InternalExpressionType.Known:
105                         return expression;
106                     case (ExpressionType)ExpressionType.Constant:
107                         break;
108                     default:
109                         this.isRemote = false;
110                         base.Visit(expression);
111                         if (!this.isRemote
112                             && expression.NodeType != ExpressionType.Lambda
113                             && expression.NodeType != ExpressionType.Quote
114                             && DependenceChecker.IsIndependent(expression)) {
115                             this.locals[expression] = true; // Not 'Add' because the same expression may exist in the tree twice. 
116                         }
117                         break;
118                 }
119                 if (typeof(ITable).IsAssignableFrom(expression.Type) ||
120                     typeof(DataContext).IsAssignableFrom(expression.Type)) {
121                     this.isRemote = true;
122                 }
123                 this.isRemote |= saveIsRemote;
124                 return expression;
125             }
126             internal override Expression VisitMemberAccess(MemberExpression m) {
127                 base.VisitMemberAccess(m);
128                 this.isRemote |= (m.Expression != null && typeof(ITable).IsAssignableFrom(m.Expression.Type));
129                 return m;
130             }
131             internal override Expression VisitMethodCall(MethodCallExpression m) {
132                 base.VisitMethodCall(m);
133                 this.isRemote |= m.Method.DeclaringType == typeof(System.Data.Linq.Provider.DataManipulation)
134                               || Attribute.IsDefined(m.Method, typeof(FunctionAttribute));
135                 return m;
136             }
137         }
138     }
139
140     internal abstract class ExpressionVisitor {
141         internal ExpressionVisitor() {
142         }
143
144         [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
145         [SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "Microsoft: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
146         internal virtual Expression Visit(Expression exp) {
147             if (exp == null)
148                 return exp;
149             switch (exp.NodeType) {
150                 case ExpressionType.Negate:
151                 case ExpressionType.NegateChecked:
152                 case ExpressionType.Not:
153                 case ExpressionType.Convert:
154                 case ExpressionType.ConvertChecked:
155                 case ExpressionType.ArrayLength:
156                 case ExpressionType.Quote:
157                 case ExpressionType.TypeAs:
158                     return this.VisitUnary((UnaryExpression)exp);
159                 case ExpressionType.Add:
160                 case ExpressionType.AddChecked:
161                 case ExpressionType.Subtract:
162                 case ExpressionType.SubtractChecked:
163                 case ExpressionType.Multiply:
164                 case ExpressionType.MultiplyChecked:
165                 case ExpressionType.Divide:
166                 case ExpressionType.Modulo:
167                 case ExpressionType.Power:
168                 case ExpressionType.And:
169                 case ExpressionType.AndAlso:
170                 case ExpressionType.Or:
171                 case ExpressionType.OrElse:
172                 case ExpressionType.LessThan:
173                 case ExpressionType.LessThanOrEqual:
174                 case ExpressionType.GreaterThan:
175                 case ExpressionType.GreaterThanOrEqual:
176                 case ExpressionType.Equal:
177                 case ExpressionType.NotEqual:
178                 case ExpressionType.Coalesce:
179                 case ExpressionType.ArrayIndex:
180                 case ExpressionType.RightShift:
181                 case ExpressionType.LeftShift:
182                 case ExpressionType.ExclusiveOr:
183                     return this.VisitBinary((BinaryExpression)exp);
184                 case ExpressionType.TypeIs:
185                     return this.VisitTypeIs((TypeBinaryExpression)exp);
186                 case ExpressionType.Conditional:
187                     return this.VisitConditional((ConditionalExpression)exp);
188                 case ExpressionType.Constant:
189                     return this.VisitConstant((ConstantExpression)exp);
190                 case ExpressionType.Parameter:
191                     return this.VisitParameter((ParameterExpression)exp);
192                 case ExpressionType.MemberAccess:
193                     return this.VisitMemberAccess((MemberExpression)exp);
194                 case ExpressionType.Call:
195                     return this.VisitMethodCall((MethodCallExpression)exp);
196                 case ExpressionType.Lambda:
197                     return this.VisitLambda((LambdaExpression)exp);
198                 case ExpressionType.New:
199                     return this.VisitNew((NewExpression)exp);
200                 case ExpressionType.NewArrayInit:
201                 case ExpressionType.NewArrayBounds:
202                     return this.VisitNewArray((NewArrayExpression)exp);
203                 case ExpressionType.Invoke:
204                     return this.VisitInvocation((InvocationExpression)exp);
205                 case ExpressionType.MemberInit:
206                     return this.VisitMemberInit((MemberInitExpression)exp);
207                 case ExpressionType.ListInit:
208                     return this.VisitListInit((ListInitExpression)exp);
209                 case ExpressionType.UnaryPlus:
210                     if (exp.Type == typeof(TimeSpan))
211                         return this.VisitUnary((UnaryExpression)exp);
212                     throw Error.UnhandledExpressionType(exp.NodeType);
213                 default:
214                     throw Error.UnhandledExpressionType(exp.NodeType);
215             }
216         }
217
218         internal virtual MemberBinding VisitBinding(MemberBinding binding) {
219             switch (binding.BindingType) {
220                 case MemberBindingType.Assignment:
221                     return this.VisitMemberAssignment((MemberAssignment)binding);
222                 case MemberBindingType.MemberBinding:
223                     return this.VisitMemberMemberBinding((MemberMemberBinding)binding);
224                 case MemberBindingType.ListBinding:
225                     return this.VisitMemberListBinding((MemberListBinding)binding);
226                 default:
227                     throw Error.UnhandledBindingType(binding.BindingType);
228             }
229         }
230
231         internal virtual ElementInit VisitElementInitializer(ElementInit initializer) {
232             ReadOnlyCollection<Expression> arguments = this.VisitExpressionList(initializer.Arguments);
233             if (arguments != initializer.Arguments) {
234                 return Expression.ElementInit(initializer.AddMethod, arguments);
235             }
236             return initializer;
237         }
238
239         internal virtual Expression VisitUnary(UnaryExpression u) {
240             Expression operand = this.Visit(u.Operand);
241             if (operand != u.Operand) {
242                 return Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);
243             }
244             return u;
245         }
246
247         internal virtual Expression VisitBinary(BinaryExpression b) {
248             Expression left = this.Visit(b.Left);
249             Expression right = this.Visit(b.Right);
250             if (left != b.Left || right != b.Right) {
251                 return Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
252             }
253             return b;
254         }
255
256         internal virtual Expression VisitTypeIs(TypeBinaryExpression b) {
257             Expression expr = this.Visit(b.Expression);
258             if (expr != b.Expression) {
259                 return Expression.TypeIs(expr, b.TypeOperand);
260             }
261             return b;
262         }
263
264         internal virtual Expression VisitConstant(ConstantExpression c) {
265             return c;
266         }
267
268         internal virtual Expression VisitConditional(ConditionalExpression c) {
269             Expression test = this.Visit(c.Test);
270             Expression ifTrue = this.Visit(c.IfTrue);
271             Expression ifFalse = this.Visit(c.IfFalse);
272             if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) {
273                 return Expression.Condition(test, ifTrue, ifFalse);
274             }
275             return c;
276         }
277
278         internal virtual Expression VisitParameter(ParameterExpression p) {
279             return p;
280         }
281
282         internal virtual Expression VisitMemberAccess(MemberExpression m) {
283             Expression exp = this.Visit(m.Expression);
284             if (exp != m.Expression) {
285                 return Expression.MakeMemberAccess(exp, m.Member);
286             }
287             return m;
288         }
289
290         internal virtual Expression VisitMethodCall(MethodCallExpression m) {
291             Expression obj = this.Visit(m.Object);
292             IEnumerable<Expression> args = this.VisitExpressionList(m.Arguments);
293             if (obj != m.Object || args != m.Arguments) {
294                 return Expression.Call(obj, m.Method, args);
295             }
296             return m;
297         }
298
299         internal virtual ReadOnlyCollection<Expression> VisitExpressionList(ReadOnlyCollection<Expression> original) {
300             List<Expression> list = null;
301             for (int i = 0, n = original.Count; i < n; i++) {
302                 Expression p = this.Visit(original[i]);
303                 if (list != null) {
304                     list.Add(p);
305                 }
306                 else if (p != original[i]) {
307                     list = new List<Expression>(n);
308                     for (int j = 0; j < i; j++) {
309                         list.Add(original[j]);
310                     }
311                     list.Add(p);
312                 }
313             }
314             if (list != null)
315                 return new ReadOnlyCollection<Expression>(list);
316             return original;
317         }
318
319         internal virtual MemberAssignment VisitMemberAssignment(MemberAssignment assignment) {
320             Expression e = this.Visit(assignment.Expression);
321             if (e != assignment.Expression) {
322                 return Expression.Bind(assignment.Member, e);
323             }
324             return assignment;
325         }
326
327         internal virtual MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
328             IEnumerable<MemberBinding> bindings = this.VisitBindingList(binding.Bindings);
329             if (bindings != binding.Bindings) {
330                 return Expression.MemberBind(binding.Member, bindings);
331             }
332             return binding;
333         }
334
335         internal virtual MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
336             IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(binding.Initializers);
337             if (initializers != binding.Initializers) {
338                 return Expression.ListBind(binding.Member, initializers);
339             }
340             return binding;
341         }
342
343         internal virtual IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original) {
344             List<MemberBinding> list = null;
345             for (int i = 0, n = original.Count; i < n; i++) {
346                 MemberBinding b = this.VisitBinding(original[i]);
347                 if (list != null) {
348                     list.Add(b);
349                 }
350                 else if (b != original[i]) {
351                     list = new List<MemberBinding>(n);
352                     for (int j = 0; j < i; j++) {
353                         list.Add(original[j]);
354                     }
355                     list.Add(b);
356                 }
357             }
358             if (list != null)
359                 return list;
360             return original;
361         }
362
363         internal virtual IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original) {
364             List<ElementInit> list = null;
365             for (int i = 0, n = original.Count; i < n; i++) {
366                 ElementInit init = this.VisitElementInitializer(original[i]);
367                 if (list != null) {
368                     list.Add(init);
369                 }
370                 else if (init != original[i]) {
371                     list = new List<ElementInit>(n);
372                     for (int j = 0; j < i; j++) {
373                         list.Add(original[j]);
374                     }
375                     list.Add(init);
376                 }
377             }
378             if (list != null) {
379                 return list;
380             }
381             return original;
382         }
383
384         internal virtual Expression VisitLambda(LambdaExpression lambda) {
385             Expression body = this.Visit(lambda.Body);
386             if (body != lambda.Body) {
387                 return Expression.Lambda(lambda.Type, body, lambda.Parameters);
388             }
389             return lambda;
390         }
391
392         internal virtual NewExpression VisitNew(NewExpression nex) {
393             IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
394             if (args != nex.Arguments) {
395                 if (nex.Members != null) {
396                     return Expression.New(nex.Constructor, args, nex.Members);
397                 }
398                 else {
399                     return Expression.New(nex.Constructor, args);
400                 }
401             }
402             return nex;
403         }
404
405         internal virtual Expression VisitMemberInit(MemberInitExpression init) {
406             NewExpression n = this.VisitNew(init.NewExpression);
407             IEnumerable<MemberBinding> bindings = this.VisitBindingList(init.Bindings);
408             if (n != init.NewExpression || bindings != init.Bindings) {
409                 return Expression.MemberInit(n, bindings);
410             }
411             return init;
412         }
413
414         internal virtual Expression VisitListInit(ListInitExpression init) {
415             NewExpression n = this.VisitNew(init.NewExpression);
416             IEnumerable<ElementInit> initializers = this.VisitElementInitializerList(init.Initializers);
417             if (n != init.NewExpression || initializers != init.Initializers) {
418                 return Expression.ListInit(n, initializers);
419             }
420             return init;
421         }
422
423         internal virtual Expression VisitNewArray(NewArrayExpression na) {
424             IEnumerable<Expression> exprs = this.VisitExpressionList(na.Expressions);
425             if (exprs != na.Expressions) {
426                 if (na.NodeType == ExpressionType.NewArrayInit) {
427                     return Expression.NewArrayInit(na.Type.GetElementType(), exprs);
428                 }
429                 else {
430                     return Expression.NewArrayBounds(na.Type.GetElementType(), exprs);
431                 }
432             }
433             return na;
434         }
435
436         internal virtual Expression VisitInvocation(InvocationExpression iv) {
437             IEnumerable<Expression> args = this.VisitExpressionList(iv.Arguments);
438             Expression expr = this.Visit(iv.Expression);
439             if (args != iv.Arguments || expr != iv.Expression) {
440                 return Expression.Invoke(expr, args);
441             }
442             return iv;
443         }
444     }
445 }