New tests, update.
[mono.git] / mcs / class / System.Core / System.Linq.Expressions / ExpressionTransformer.cs
1 //
2 // ExpressionTransformer.cs
3 //
4 // Authors:
5 //      Roei Erez (roeie@mainsoft.com)
6 //
7 // Copyright (C) 2007 Novell, Inc (http://www.novell.com)
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 //
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 //
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28
29 using System;
30 using System.Collections.ObjectModel;
31 using System.Collections.Generic;
32 using System.Linq;
33 using System.Linq.Expressions;
34
namespace System.Linq.Expressions {

        /// <summary>
        /// Base class for expression tree rewriters. Walks a tree node by node and
        /// rebuilds only the nodes whose children were replaced by an overriding
        /// Visit* method; untouched sub-trees are returned as the same instances.
        /// </summary>
        abstract class ExpressionTransformer {

                /// <summary>Entry point: transforms <paramref name="e"/> and returns the resulting tree.</summary>
                public Expression Transform (Expression e)
                {
                        return Visit (e);
                }

                /// <summary>
                /// Dispatches <paramref name="expression"/> to the Visit* method matching its node type.
                /// Returns null for a null input (e.g. the instance of a static member access).
                /// </summary>
                /// <exception cref="ArgumentException">The node type is not handled by this transformer.</exception>
                protected virtual Expression Visit (Expression expression)
                {
                        if (expression == null)
                                return null;

                        switch (expression.NodeType) {
                        case ExpressionType.Negate:
                        case ExpressionType.NegateChecked:
                        case ExpressionType.Not:
                        case ExpressionType.Convert:
                        case ExpressionType.ConvertChecked:
                        case ExpressionType.ArrayLength:
                        case ExpressionType.Quote:
                        case ExpressionType.TypeAs:
                        case ExpressionType.UnaryPlus:
                                return VisitUnary ((UnaryExpression) expression);
                        case ExpressionType.Add:
                        case ExpressionType.AddChecked:
                        case ExpressionType.Subtract:
                        case ExpressionType.SubtractChecked:
                        case ExpressionType.Multiply:
                        case ExpressionType.MultiplyChecked:
                        case ExpressionType.Divide:
                        case ExpressionType.Modulo:
                        case ExpressionType.Power:
                        case ExpressionType.And:
                        case ExpressionType.AndAlso:
                        case ExpressionType.Or:
                        case ExpressionType.OrElse:
                        case ExpressionType.LessThan:
                        case ExpressionType.LessThanOrEqual:
                        case ExpressionType.GreaterThan:
                        case ExpressionType.GreaterThanOrEqual:
                        case ExpressionType.Equal:
                        case ExpressionType.NotEqual:
                        case ExpressionType.Coalesce:
                        case ExpressionType.ArrayIndex:
                        case ExpressionType.RightShift:
                        case ExpressionType.LeftShift:
                        case ExpressionType.ExclusiveOr:
                                return VisitBinary ((BinaryExpression) expression);
                        case ExpressionType.TypeIs:
                                return VisitTypeIs ((TypeBinaryExpression) expression);
                        case ExpressionType.Conditional:
                                return VisitConditional ((ConditionalExpression) expression);
                        case ExpressionType.Constant:
                                return VisitConstant ((ConstantExpression) expression);
                        case ExpressionType.Parameter:
                                return VisitParameter ((ParameterExpression) expression);
                        case ExpressionType.MemberAccess:
                                return VisitMemberAccess ((MemberExpression) expression);
                        case ExpressionType.Call:
                                return VisitMethodCall ((MethodCallExpression) expression);
                        case ExpressionType.Lambda:
                                return VisitLambda ((LambdaExpression) expression);
                        case ExpressionType.New:
                                return VisitNew ((NewExpression) expression);
                        case ExpressionType.NewArrayInit:
                        case ExpressionType.NewArrayBounds:
                                return VisitNewArray ((NewArrayExpression) expression);
                        case ExpressionType.Invoke:
                                return VisitInvocation ((InvocationExpression) expression);
                        case ExpressionType.MemberInit:
                                return VisitMemberInit ((MemberInitExpression) expression);
                        case ExpressionType.ListInit:
                                return VisitListInit ((ListInitExpression) expression);
                        default:
                                throw new ArgumentException (string.Format ("Unhandled expression type: '{0}'", expression.NodeType));
                        }
                }

                /// <summary>Dispatches a member binding to the Visit* method matching its binding type.</summary>
                /// <exception cref="ArgumentException">The binding type is not handled by this transformer.</exception>
                protected virtual MemberBinding VisitBinding (MemberBinding binding)
                {
                        switch (binding.BindingType) {
                        case MemberBindingType.Assignment:
                                return VisitMemberAssignment ((MemberAssignment) binding);
                        case MemberBindingType.MemberBinding:
                                return VisitMemberMemberBinding ((MemberMemberBinding) binding);
                        case MemberBindingType.ListBinding:
                                return VisitMemberListBinding ((MemberListBinding) binding);
                        default:
                                throw new ArgumentException (string.Format ("Unhandled binding type '{0}'", binding.BindingType));
                        }
                }

                /// <summary>Transforms the arguments of an element initializer, rebuilding it if any changed.</summary>
                protected virtual ElementInit VisitElementInitializer (ElementInit initializer)
                {
                        ReadOnlyCollection<Expression> transformed = VisitExpressionList (initializer.Arguments);
                        if (transformed != initializer.Arguments)
                                return Expression.ElementInit (initializer.AddMethod, transformed);
                        return initializer;
                }

                /// <summary>Transforms the operand of a unary node, rebuilding the node if the operand changed.</summary>
                protected virtual UnaryExpression VisitUnary (UnaryExpression unary)
                {
                        Expression transformedOperand = Visit (unary.Operand);
                        if (transformedOperand != unary.Operand)
                                return Expression.MakeUnary (unary.NodeType, transformedOperand, unary.Type, unary.Method);
                        return unary;
                }

                /// <summary>
                /// Transforms both operands and the optional conversion lambda of a binary node,
                /// rebuilding the node if any of them changed.
                /// </summary>
                protected virtual BinaryExpression VisitBinary (BinaryExpression binary)
                {
                        Expression left = Visit (binary.Left);
                        Expression right = Visit (binary.Right);

                        // Conversion is null for most binary nodes; VisitLambda dereferences
                        // its argument, so it must only be called when a conversion exists.
                        LambdaExpression conversion = binary.Conversion;
                        if (conversion != null)
                                conversion = VisitLambda (conversion);

                        if (left != binary.Left || right != binary.Right || conversion != binary.Conversion)
                                return Expression.MakeBinary (binary.NodeType, left, right, binary.IsLiftedToNull, binary.Method, conversion);
                        return binary;
                }

                /// <summary>Transforms the tested expression of a TypeIs node, rebuilding the node if it changed.</summary>
                protected virtual TypeBinaryExpression VisitTypeIs (TypeBinaryExpression type)
                {
                        Expression inner = Visit (type.Expression);
                        if (inner != type.Expression)
                                return Expression.TypeIs (inner, type.TypeOperand);
                        return type;
                }

                /// <summary>Returns the constant unchanged; override to replace constants.</summary>
                protected virtual ConstantExpression VisitConstant (ConstantExpression constant)
                {
                        return constant;
                }

                /// <summary>Transforms test and both branches of a conditional, rebuilding it if any changed.</summary>
                protected virtual ConditionalExpression VisitConditional (ConditionalExpression conditional)
                {
                        Expression test = Visit (conditional.Test);
                        Expression ifTrue = Visit (conditional.IfTrue);
                        Expression ifFalse = Visit (conditional.IfFalse);
                        if (test != conditional.Test || ifTrue != conditional.IfTrue || ifFalse != conditional.IfFalse)
                                return Expression.Condition (test, ifTrue, ifFalse);
                        return conditional;
                }

                /// <summary>Returns the parameter unchanged; override to replace parameters.</summary>
                protected virtual ParameterExpression VisitParameter (ParameterExpression parameter)
                {
                        return parameter;
                }

                /// <summary>
                /// Transforms the target of a member access (null for static members),
                /// rebuilding the node if the target changed.
                /// </summary>
                protected virtual MemberExpression VisitMemberAccess (MemberExpression member)
                {
                        Expression memberExp = Visit (member.Expression);
                        if (memberExp != member.Expression)
                                return Expression.MakeMemberAccess (memberExp, member.Member);
                        return member;
                }

                /// <summary>
                /// Transforms the instance (null for static calls) and the arguments of a
                /// method call, rebuilding the node if any of them changed.
                /// </summary>
                protected virtual MethodCallExpression VisitMethodCall (MethodCallExpression methodCall)
                {
                        Expression instance = Visit (methodCall.Object);
                        ReadOnlyCollection<Expression> args = VisitExpressionList (methodCall.Arguments);
                        if (instance != methodCall.Object || args != methodCall.Arguments)
                                return Expression.Call (instance, methodCall.Method, args);
                        return methodCall;
                }

                /// <summary>
                /// Transforms every expression in <paramref name="list"/>. Returns the same
                /// collection instance when no element changed.
                /// </summary>
                protected virtual ReadOnlyCollection<Expression> VisitExpressionList (ReadOnlyCollection<Expression> list)
                {
                        return VisitList<Expression> (list, Visit);
                }

                // Applies selector to every element of list. The original collection is
                // returned when no element was replaced, so callers can detect changes
                // with a reference comparison. A new array is allocated lazily on the
                // first replaced element; the unchanged elements preceding it are copied
                // over so the result never contains holes.
                private ReadOnlyCollection<T> VisitList<T> (ReadOnlyCollection<T> list, Func<T,T> selector) where T : class
                {
                        T [] arr = null;
                        int count = list.Count;

                        for (int i = 0; i < count; i++) {
                                T original = list [i];
                                T visited = selector (original);

                                if (arr == null) {
                                        if (visited == original)
                                                continue;

                                        // First change: back-fill everything visited so far.
                                        arr = new T [count];
                                        for (int j = 0; j < i; j++)
                                                arr [j] = list [j];
                                }

                                arr [i] = visited;
                        }

                        if (arr != null)
                                return arr.ToReadOnlyCollection ();
                        return list;
                }

                /// <summary>Transforms the assigned value of a member assignment, rebuilding it if the value changed.</summary>
                protected virtual MemberAssignment VisitMemberAssignment (MemberAssignment assignment)
                {
                        Expression inner = Visit (assignment.Expression);
                        if (inner != assignment.Expression)
                                return Expression.Bind (assignment.Member, inner);
                        return assignment;
                }

                /// <summary>Transforms the nested bindings of a member-member binding, rebuilding it if any changed.</summary>
                protected virtual MemberMemberBinding VisitMemberMemberBinding (MemberMemberBinding binding)
                {
                        ReadOnlyCollection<MemberBinding> bindingExp = VisitBindingList (binding.Bindings);
                        if (bindingExp != binding.Bindings)
                                return Expression.MemberBind (binding.Member, bindingExp);
                        return binding;
                }

                /// <summary>Transforms the element initializers of a member list binding, rebuilding it if any changed.</summary>
                protected virtual MemberListBinding VisitMemberListBinding (MemberListBinding binding)
                {
                        ReadOnlyCollection<ElementInit> initializers =
                                VisitElementInitializerList (binding.Initializers);
                        if (initializers != binding.Initializers)
                                return Expression.ListBind (binding.Member, initializers);
                        return binding;
                }

                /// <summary>Transforms every binding in <paramref name="list"/>; same instance when nothing changed.</summary>
                protected virtual ReadOnlyCollection<MemberBinding> VisitBindingList (ReadOnlyCollection<MemberBinding> list)
                {
                        return VisitList<MemberBinding> (list, VisitBinding);
                }

                /// <summary>Transforms every element initializer in <paramref name="list"/>; same instance when nothing changed.</summary>
                protected virtual ReadOnlyCollection<ElementInit> VisitElementInitializerList (ReadOnlyCollection<ElementInit> list)
                {
                        return VisitList<ElementInit> (list, VisitElementInitializer);
                }

                /// <summary>
                /// Transforms the body and parameters of a lambda, rebuilding it if any changed.
                /// The rebuilt lambda keeps the delegate type of the original.
                /// </summary>
                protected virtual LambdaExpression VisitLambda (LambdaExpression lambda)
                {
                        Expression body = Visit (lambda.Body);
                        ReadOnlyCollection<ParameterExpression> parameters =
                                VisitList<ParameterExpression> (lambda.Parameters, VisitParameter);
                        if (body != lambda.Body || parameters != lambda.Parameters) {
                                // Pass the original delegate type explicitly; otherwise the
                                // factory infers a Func<>/Action<> type and lambdas built on a
                                // custom delegate silently change type.
                                return Expression.Lambda (lambda.Type, body, parameters);
                        }
                        return lambda;
                }

                /// <summary>
                /// Transforms the constructor arguments of a New node, rebuilding it if any
                /// changed. Member metadata (present for anonymous types) is carried over.
                /// </summary>
                protected virtual NewExpression VisitNew (NewExpression nex)
                {
                        ReadOnlyCollection<Expression> args = VisitExpressionList (nex.Arguments);
                        if (args == nex.Arguments)
                                return nex;

                        if (nex.Members != null)
                                return Expression.New (nex.Constructor, args, nex.Members);
                        return Expression.New (nex.Constructor, args);
                }

                /// <summary>Transforms the New node and bindings of a member initializer, rebuilding it if any changed.</summary>
                protected virtual MemberInitExpression VisitMemberInit (MemberInitExpression init)
                {
                        NewExpression newExp = VisitNew (init.NewExpression);
                        ReadOnlyCollection<MemberBinding> bindings = VisitBindingList (init.Bindings);
                        if (newExp != init.NewExpression || bindings != init.Bindings)
                                return Expression.MemberInit (newExp, bindings);
                        return init;
                }

                /// <summary>Transforms the New node and element initializers of a list initializer, rebuilding it if any changed.</summary>
                protected virtual ListInitExpression VisitListInit (ListInitExpression init)
                {
                        NewExpression newExp = VisitNew (init.NewExpression);
                        ReadOnlyCollection<ElementInit> initializers = VisitElementInitializerList (init.Initializers);
                        if (newExp != init.NewExpression || initializers != init.Initializers)
                                return Expression.ListInit (newExp, initializers.ToArray ());
                        return init;
                }

                /// <summary>
                /// Transforms the bounds or element expressions of an array creation node,
                /// rebuilding it if any changed.
                /// </summary>
                protected virtual NewArrayExpression VisitNewArray (NewArrayExpression newArray)
                {
                        ReadOnlyCollection<Expression> expressions = VisitExpressionList (newArray.Expressions);
                        if (expressions == newArray.Expressions)
                                return newArray;

                        // newArray.Type is the array type itself, while both factory
                        // methods expect the element type.
                        Type elementType = newArray.Type.GetElementType ();
                        if (newArray.NodeType == ExpressionType.NewArrayBounds)
                                return Expression.NewArrayBounds (elementType, expressions);
                        return Expression.NewArrayInit (elementType, expressions);
                }

                /// <summary>Transforms the arguments and the invoked expression, rebuilding the invocation if any changed.</summary>
                protected virtual InvocationExpression VisitInvocation (InvocationExpression invocation)
                {
                        ReadOnlyCollection<Expression> args = VisitExpressionList (invocation.Arguments);
                        Expression invocationExp = Visit (invocation.Expression);
                        if (args != invocation.Arguments || invocationExp != invocation.Expression)
                                return Expression.Invoke (invocationExp, args);
                        return invocation;
                }
        }
}