Implements lifted nullable comparison with null
[mono.git] / mcs / class / System.Core / System.Linq.Expressions / ExpressionTransformer.cs
1 //
2 // ExpressionTransformer.cs
3 //
4 // Authors:
5 //      Roei Erez (roeie@mainsoft.com)
6 //  Jb Evain (jbevain@novell.com)
7 //
8 // Copyright (C) 2007 Novell, Inc (http://www.novell.com)
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining
11 // a copy of this software and associated documentation files (the
12 // "Software"), to deal in the Software without restriction, including
13 // without limitation the rights to use, copy, modify, merge, publish,
14 // distribute, sublicense, and/or sell copies of the Software, and to
15 // permit persons to whom the Software is furnished to do so, subject to
16 // the following conditions:
17 //
18 // The above copyright notice and this permission notice shall be
19 // included in all copies or substantial portions of the Software.
20 //
21 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
22 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
24 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
25 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
26 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
27 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28 //
29
30 using System;
31 using System.Collections.ObjectModel;
32 using System.Collections.Generic;
33 using System.Linq.Expressions;
34
namespace System.Linq.Expressions {

	/// <summary>
	/// Base class for rewriting expression trees. The base Visit* methods visit a
	/// node's children and rebuild the node only when at least one child was
	/// replaced; otherwise they return the original node instance, so an
	/// untouched subtree keeps its identity. Subclasses override individual
	/// Visit* methods to perform the actual rewriting.
	/// </summary>
	abstract class ExpressionTransformer {

		/// <summary>
		/// Entry point: runs the visitor over <paramref name="expression"/> and
		/// returns the (possibly rewritten) tree.
		/// </summary>
		public Expression Transform (Expression expression)
		{
			return Visit (expression);
		}

		/// <summary>
		/// Dispatches <paramref name="exp"/> to the Visit* method matching its
		/// <see cref="Expression.NodeType"/>. A null expression is returned as-is,
		/// which lets callers visit optional children (e.g. a static call's
		/// <c>Object</c>) without checking for null first.
		/// </summary>
		/// <exception cref="NotSupportedException">The node type is not handled.</exception>
		protected virtual Expression Visit (Expression exp)
		{
			if (exp == null) return exp;

			switch (exp.NodeType) {
			case ExpressionType.Negate:
			case ExpressionType.NegateChecked:
			case ExpressionType.Not:
			case ExpressionType.Convert:
			case ExpressionType.ConvertChecked:
			case ExpressionType.ArrayLength:
			case ExpressionType.Quote:
			case ExpressionType.TypeAs:
			case ExpressionType.UnaryPlus:
				return this.VisitUnary ((UnaryExpression) exp);
			case ExpressionType.Add:
			case ExpressionType.AddChecked:
			case ExpressionType.Subtract:
			case ExpressionType.SubtractChecked:
			case ExpressionType.Multiply:
			case ExpressionType.MultiplyChecked:
			case ExpressionType.Divide:
			case ExpressionType.Power:
			case ExpressionType.Modulo:
			case ExpressionType.And:
			case ExpressionType.AndAlso:
			case ExpressionType.Or:
			case ExpressionType.OrElse:
			case ExpressionType.LessThan:
			case ExpressionType.LessThanOrEqual:
			case ExpressionType.GreaterThan:
			case ExpressionType.GreaterThanOrEqual:
			case ExpressionType.Equal:
			case ExpressionType.NotEqual:
			case ExpressionType.Coalesce:
			case ExpressionType.ArrayIndex:
			case ExpressionType.RightShift:
			case ExpressionType.LeftShift:
			case ExpressionType.ExclusiveOr:
				return this.VisitBinary ((BinaryExpression) exp);
			case ExpressionType.TypeIs:
				return this.VisitTypeIs ((TypeBinaryExpression) exp);
			case ExpressionType.Conditional:
				return this.VisitConditional ((ConditionalExpression) exp);
			case ExpressionType.Constant:
				return this.VisitConstant ((ConstantExpression) exp);
			case ExpressionType.Parameter:
				return this.VisitParameter ((ParameterExpression) exp);
			case ExpressionType.MemberAccess:
				return this.VisitMemberAccess ((MemberExpression) exp);
			case ExpressionType.Call:
				return this.VisitMethodCall ((MethodCallExpression) exp);
			case ExpressionType.Lambda:
				return this.VisitLambda ((LambdaExpression) exp);
			case ExpressionType.New:
				return this.VisitNew ((NewExpression) exp);
			case ExpressionType.NewArrayInit:
			case ExpressionType.NewArrayBounds:
				return this.VisitNewArray ((NewArrayExpression) exp);
			case ExpressionType.Invoke:
				return this.VisitInvocation ((InvocationExpression) exp);
			case ExpressionType.MemberInit:
				return this.VisitMemberInit ((MemberInitExpression) exp);
			case ExpressionType.ListInit:
				return this.VisitListInit ((ListInitExpression) exp);
			default:
				// NotSupportedException derives from Exception, so existing
				// catch (Exception) handlers keep working.
				throw new NotSupportedException (string.Format ("Unhandled expression type: '{0}'", exp.NodeType));
			}
		}

		/// <summary>
		/// Dispatches a member binding to the Visit* method matching its
		/// <see cref="MemberBinding.BindingType"/>.
		/// </summary>
		/// <exception cref="NotSupportedException">The binding type is not handled.</exception>
		protected virtual MemberBinding VisitBinding (MemberBinding binding)
		{
			switch (binding.BindingType) {
			case MemberBindingType.Assignment:
				return this.VisitMemberAssignment ((MemberAssignment) binding);
			case MemberBindingType.MemberBinding:
				return this.VisitMemberMemberBinding ((MemberMemberBinding) binding);
			case MemberBindingType.ListBinding:
				return this.VisitMemberListBinding ((MemberListBinding) binding);
			default:
				throw new NotSupportedException (string.Format ("Unhandled binding type '{0}'", binding.BindingType));
			}
		}

		/// <summary>
		/// Visits the arguments of a collection-initializer element; rebuilds the
		/// <see cref="ElementInit"/> with the same Add method only if an argument changed.
		/// </summary>
		protected virtual ElementInit VisitElementInitializer (ElementInit initializer)
		{
			ReadOnlyCollection<Expression> arguments = this.VisitExpressionList (initializer.Arguments);
			if (arguments != initializer.Arguments) return Expression.ElementInit (initializer.AddMethod, arguments);
			return initializer;
		}

		/// <summary>
		/// Visits the operand; rebuilds the unary node (same node type, type and
		/// method) only if the operand changed.
		/// </summary>
		protected virtual Expression VisitUnary (UnaryExpression u)
		{
			Expression operand = this.Visit (u.Operand);
			if (operand != u.Operand) return Expression.MakeUnary (u.NodeType, operand, u.Type, u.Method);
			return u;
		}

		/// <summary>
		/// Visits left, right and (for coalesce) the conversion lambda; rebuilds
		/// the binary node only if one of them changed. Lifted-to-null and the
		/// operator method are carried over from the original node.
		/// </summary>
		protected virtual Expression VisitBinary (BinaryExpression b)
		{
			Expression left = this.Visit (b.Left);
			Expression right = this.Visit (b.Right);
			Expression conversion = this.Visit (b.Conversion);
			if (left != b.Left || right != b.Right || conversion != b.Conversion) {
				if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null) {
					// Direct cast: a visitor that turned the conversion into a
					// non-lambda must fail loudly rather than have the conversion
					// silently dropped (which "as" would do by yielding null).
					return Expression.Coalesce (left, right, (LambdaExpression) conversion);
				} else {
					return Expression.MakeBinary (b.NodeType, left, right, b.IsLiftedToNull, b.Method);
				}
			}
			return b;
		}

		/// <summary>
		/// Visits the tested expression; rebuilds the TypeIs node only if it changed.
		/// </summary>
		protected virtual Expression VisitTypeIs (TypeBinaryExpression b)
		{
			Expression expr = this.Visit (b.Expression);
			if (expr != b.Expression) {
				return Expression.TypeIs (expr, b.TypeOperand);
			}
			return b;
		}

		/// <summary>Constants have no children; returned unchanged.</summary>
		protected virtual Expression VisitConstant (ConstantExpression c)
		{
			return c;
		}

		/// <summary>
		/// Visits test, true and false branches; rebuilds the conditional only if
		/// one of them changed.
		/// </summary>
		protected virtual Expression VisitConditional (ConditionalExpression c)
		{
			Expression test = this.Visit (c.Test);
			Expression ifTrue = this.Visit (c.IfTrue);
			Expression ifFalse = this.Visit (c.IfFalse);
			if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse) {
				return Expression.Condition (test, ifTrue, ifFalse);
			}
			return c;
		}

		/// <summary>Parameters have no children; returned unchanged.</summary>
		protected virtual Expression VisitParameter (ParameterExpression p)
		{
			return p;
		}

		/// <summary>
		/// Visits the instance expression (null for static members); rebuilds the
		/// member access only if it changed.
		/// </summary>
		protected virtual Expression VisitMemberAccess (MemberExpression m)
		{
			Expression exp = this.Visit (m.Expression);
			if (exp != m.Expression) {
				return Expression.MakeMemberAccess (exp, m.Member);
			}
			return m;
		}

		/// <summary>
		/// Visits the target object (null for static calls) and the arguments;
		/// rebuilds the call only if one of them changed.
		/// </summary>
		protected virtual Expression VisitMethodCall (MethodCallExpression m)
		{
			Expression obj = this.Visit (m.Object);
			IEnumerable<Expression> args = this.VisitExpressionList (m.Arguments);
			if (obj != m.Object || args != m.Arguments) {
				return Expression.Call (obj, m.Method, args);
			}
			return m;
		}

		/// <summary>
		/// Visits every expression in <paramref name="original"/>. Returns
		/// <paramref name="original"/> itself when no element changed, so callers
		/// can detect "nothing changed" with a reference comparison.
		/// </summary>
		protected virtual ReadOnlyCollection<Expression> VisitExpressionList (ReadOnlyCollection<Expression> original)
		{
			IList<Expression> list = VisitList (original, Visit);

			// VisitList hands back the input collection itself when nothing
			// changed. Wrapping it unconditionally would make every caller think
			// its arguments were rewritten and rebuild the node needlessly (and
			// Expression.New would then throw for value-type NewExpressions,
			// whose Constructor is null).
			if (object.ReferenceEquals (list, original))
				return original;

			return new ReadOnlyCollection<Expression> (list);
		}

		/// <summary>
		/// Visits the assigned value; rebuilds the binding only if it changed.
		/// </summary>
		protected virtual MemberAssignment VisitMemberAssignment (MemberAssignment assignment)
		{
			Expression e = this.Visit (assignment.Expression);
			if (e != assignment.Expression) return Expression.Bind (assignment.Member, e);
			return assignment;
		}

		/// <summary>
		/// Visits the nested bindings; rebuilds the member-member binding only if
		/// one of them changed.
		/// </summary>
		protected virtual MemberMemberBinding VisitMemberMemberBinding (MemberMemberBinding binding)
		{
			IEnumerable<MemberBinding> bindings = this.VisitBindingList (binding.Bindings);
			if (bindings != binding.Bindings) return Expression.MemberBind (binding.Member, bindings);
			return binding;
		}

		/// <summary>
		/// Visits the element initializers; rebuilds the list binding only if one
		/// of them changed.
		/// </summary>
		protected virtual MemberListBinding VisitMemberListBinding (MemberListBinding binding)
		{
			IEnumerable<ElementInit> initializers = this.VisitElementInitializerList (binding.Initializers);
			if (initializers != binding.Initializers) return Expression.ListBind (binding.Member, initializers);
			return binding;
		}

		/// <summary>
		/// Visits each binding; returns <paramref name="original"/> itself when none changed.
		/// </summary>
		protected virtual IEnumerable<MemberBinding> VisitBindingList (ReadOnlyCollection<MemberBinding> original)
		{
			return VisitList (original, VisitBinding);
		}

		/// <summary>
		/// Visits each element initializer; returns <paramref name="original"/>
		/// itself when none changed.
		/// </summary>
		protected virtual IEnumerable<ElementInit> VisitElementInitializerList (ReadOnlyCollection<ElementInit> original)
		{
			return VisitList (original, VisitElementInitializer);
		}

		/// <summary>
		/// Applies <paramref name="visit"/> to every element. Returns
		/// <paramref name="original"/> itself when every visited element compares
		/// equal (per <see cref="EqualityComparer{T}.Default"/>) to its input;
		/// otherwise returns a new list containing the visited elements. A copy is
		/// only allocated once the first change is seen.
		/// </summary>
		private IList<TElement> VisitList<TElement> (ReadOnlyCollection<TElement> original, Func<TElement, TElement> visit)
		{
			List<TElement> list = null;
			for (int i = 0, n = original.Count; i < n; i++) {
				TElement element = visit (original [i]);
				if (list != null) {
					list.Add (element);
				} else if (!EqualityComparer<TElement>.Default.Equals (element, original [i])) {
					// First change: copy the unchanged prefix, then continue appending.
					list = new List<TElement> (n);
					for (int j = 0; j < i; j++) {
						list.Add (original [j]);
					}
					list.Add (element);
				}
			}
			if (list != null)
				return list;

			return original;
		}

		/// <summary>
		/// Visits the lambda body; rebuilds the lambda (same delegate type and
		/// parameters) only if the body changed.
		/// </summary>
		protected virtual Expression VisitLambda (LambdaExpression lambda)
		{
			Expression body = this.Visit (lambda.Body);
			if (body != lambda.Body) return Expression.Lambda (lambda.Type, body, lambda.Parameters);
			return lambda;
		}

		/// <summary>
		/// Visits the constructor arguments; rebuilds the NewExpression (keeping
		/// its member list, if any) only if an argument changed.
		/// </summary>
		protected virtual NewExpression VisitNew (NewExpression nex)
		{
			IEnumerable<Expression> args = this.VisitExpressionList (nex.Arguments);
			if (args != nex.Arguments) {
				if (nex.Members != null)
					return Expression.New (nex.Constructor, args, nex.Members);
				else
					return Expression.New (nex.Constructor, args);
			}
			return nex;
		}

		/// <summary>
		/// Visits the NewExpression and the member bindings; rebuilds the
		/// member-init node only if either changed.
		/// </summary>
		protected virtual Expression VisitMemberInit (MemberInitExpression init)
		{
			NewExpression n = this.VisitNew (init.NewExpression);
			IEnumerable<MemberBinding> bindings = this.VisitBindingList (init.Bindings);
			if (n != init.NewExpression || bindings != init.Bindings) return Expression.MemberInit (n, bindings);
			return init;
		}

		/// <summary>
		/// Visits the NewExpression and the element initializers; rebuilds the
		/// list-init node only if either changed.
		/// </summary>
		protected virtual Expression VisitListInit (ListInitExpression init)
		{
			NewExpression n = this.VisitNew (init.NewExpression);
			IEnumerable<ElementInit> initializers = this.VisitElementInitializerList (init.Initializers);
			if (n != init.NewExpression || initializers != init.Initializers) return Expression.ListInit (n, initializers);
			return init;
		}

		/// <summary>
		/// Visits the element (NewArrayInit) or bound (NewArrayBounds) expressions;
		/// rebuilds the array creation with the original element type only if one changed.
		/// </summary>
		protected virtual Expression VisitNewArray (NewArrayExpression na)
		{
			IEnumerable<Expression> exprs = this.VisitExpressionList (na.Expressions);
			if (exprs != na.Expressions) {
				if (na.NodeType == ExpressionType.NewArrayInit) {
					return Expression.NewArrayInit (na.Type.GetElementType (), exprs);
				} else {
					return Expression.NewArrayBounds (na.Type.GetElementType (), exprs);
				}
			}
			return na;
		}

		/// <summary>
		/// Visits the arguments and then the invoked expression; rebuilds the
		/// invocation only if either changed.
		/// </summary>
		protected virtual Expression VisitInvocation (InvocationExpression iv)
		{
			IEnumerable<Expression> args = this.VisitExpressionList (iv.Arguments);
			Expression expr = this.Visit (iv.Expression);
			if (args != iv.Arguments || expr != iv.Expression) return Expression.Invoke (expr, args);
			return iv;
		}
	}
}