520d5dcd0620567b5298bd1c00b6c297418ee685
[mono.git] / mcs / class / referencesource / System.Data.Entity / System / Data / Common / CommandTrees / DefaultExpressionVisitor.cs
1 //---------------------------------------------------------------------
2 // <copyright file="DefaultExpressionVisitor.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //
6 // @owner  Microsoft
7 // @backupOwner Microsoft
8 //---------------------------------------------------------------------
9
10 namespace System.Data.Common.CommandTrees
11 {
12     using System;
13     using System.Collections.Generic;
14     using System.Data.Metadata.Edm;
15     using System.Diagnostics;
16     using System.Linq;
17     using CqtBuilder = System.Data.Common.CommandTrees.ExpressionBuilder.DbExpressionBuilder;
18
19     /// <summary>
20     /// Visits each element of an expression tree from a given root expression. If any element changes, the tree is
21     /// rebuilt back to the root and the new root expression is returned; otherwise the original root expression is returned.
22     /// </summary>
23     public class DefaultExpressionVisitor : DbExpressionVisitor<DbExpression>
24     {
/// <summary>
/// Maps a variable reference from the original tree to the reference that replaces it after the
/// scope that declared it was rebuilt. Written by RebindVariable and consulted when a
/// DbVariableReferenceExpression is visited.
/// </summary>
private readonly Dictionary<DbVariableReferenceExpression, DbVariableReferenceExpression> varMappings =
    new Dictionary<DbVariableReferenceExpression, DbVariableReferenceExpression>();

/// <summary>
/// Creates a new visitor. Only derived classes can instantiate it.
/// </summary>
protected DefaultExpressionVisitor()
{
}
30
/// <summary>
/// Called after visiting <paramref name="oldExpression"/> yielded a different instance,
/// <paramref name="newExpression"/>. The base implementation does nothing.
/// </summary>
protected virtual void OnExpressionReplaced(DbExpression oldExpression, DbExpression newExpression) { }

/// <summary>
/// Called when references to <paramref name="fromVarRef"/> are to be replaced by
/// <paramref name="toVarRef"/> for the remainder of the visit. The base implementation does nothing.
/// </summary>
protected virtual void OnVariableRebound(DbVariableReferenceExpression fromVarRef, DbVariableReferenceExpression toVarRef) { }

/// <summary>
/// Called when the visitor enters a scope that declares <paramref name="scopeVariables"/>.
/// The base implementation does nothing.
/// </summary>
protected virtual void OnEnterScope(IEnumerable<DbVariableReferenceExpression> scopeVariables) { }

/// <summary>
/// Called when the visitor leaves a scope that was previously announced through OnEnterScope.
/// The base implementation does nothing.
/// </summary>
protected virtual void OnExitScope() { }
46                 
/// <summary>
/// Double-dispatches <paramref name="expression"/> to the Visit overload for its concrete type.
/// </summary>
/// <returns>
/// The visited (possibly replaced) expression, or null when <paramref name="expression"/> is null.
/// </returns>
protected virtual DbExpression VisitExpression(DbExpression expression)
{
    if (expression == null)
    {
        return null;
    }

    return expression.Accept<DbExpression>(this);
}
57                 
/// <summary>
/// Visits every expression in <paramref name="list"/>.
/// </summary>
/// <returns>
/// <paramref name="list"/> itself when no element was replaced; otherwise a new list holding the replacements.
/// </returns>
protected virtual IList<DbExpression> VisitExpressionList(IList<DbExpression> list)
{
    Func<DbExpression, DbExpression> visitElement = this.VisitExpression;
    return this.VisitList(list, visitElement);
}
62
/// <summary>
/// Visits the input of <paramref name="binding"/>.
/// </summary>
/// <returns>
/// null for a null binding; the original binding when its input is unchanged; otherwise a new binding
/// with the same variable name over the new input.
/// </returns>
protected virtual DbExpressionBinding VisitExpressionBinding(DbExpressionBinding binding)
{
    if (binding == null)
    {
        return null;
    }

    DbExpression visitedInput = this.VisitExpression(binding.Expression);
    if (object.ReferenceEquals(binding.Expression, visitedInput))
    {
        return binding;
    }

    // A changed input means a new binding and therefore a new variable; references to the
    // old binding variable must be redirected to it.
    DbExpressionBinding rebuilt = CqtBuilder.BindAs(visitedInput, binding.VariableName);
    this.RebindVariable(binding.Variable, rebuilt.Variable);
    return rebuilt;
}
77
/// <summary>
/// Visits every binding in <paramref name="list"/>.
/// </summary>
/// <returns>
/// <paramref name="list"/> itself when no binding was replaced; otherwise a new list holding the replacements.
/// </returns>
protected virtual IList<DbExpressionBinding> VisitExpressionBindingList(IList<DbExpressionBinding> list)
{
    Func<DbExpressionBinding, DbExpressionBinding> visitBinding = this.VisitExpressionBinding;
    return this.VisitList(list, visitBinding);
}
82
/// <summary>
/// Visits the input of a group binding.
/// </summary>
/// <returns>
/// null for a null binding; the original binding when its input is unchanged; otherwise a new group
/// binding with the same variable and group-variable names over the new input.
/// </returns>
protected virtual DbGroupExpressionBinding VisitGroupExpressionBinding(DbGroupExpressionBinding binding)
{
    if (binding == null)
    {
        return null;
    }

    DbExpression visitedInput = this.VisitExpression(binding.Expression);
    if (object.ReferenceEquals(binding.Expression, visitedInput))
    {
        return binding;
    }

    // The rebuilt binding introduces both a new element variable and a new group variable;
    // references to each original variable are redirected to its counterpart.
    DbGroupExpressionBinding rebuilt = CqtBuilder.GroupBindAs(visitedInput, binding.VariableName, binding.GroupVariableName);
    this.RebindVariable(binding.Variable, rebuilt.Variable);
    this.RebindVariable(binding.GroupVariable, rebuilt.GroupVariable);
    return rebuilt;
}
98
/// <summary>
/// Visits the sort key of <paramref name="clause"/>.
/// </summary>
/// <returns>
/// null for a null clause; the original clause when its key is unchanged; otherwise a new clause over
/// the new key that keeps the original direction and (when present) collation.
/// </returns>
protected virtual DbSortClause VisitSortClause(DbSortClause clause)
{
    if (clause == null)
    {
        return null;
    }

    DbExpression visitedKey = this.VisitExpression(clause.Expression);
    if (object.ReferenceEquals(clause.Expression, visitedKey))
    {
        return clause;
    }

    bool hasCollation = !string.IsNullOrEmpty(clause.Collation);
    if (clause.Ascending)
    {
        return hasCollation
            ? CqtBuilder.ToSortClause(visitedKey, clause.Collation)
            : CqtBuilder.ToSortClause(visitedKey);
    }

    return hasCollation
        ? CqtBuilder.ToSortClauseDescending(visitedKey, clause.Collation)
        : CqtBuilder.ToSortClauseDescending(visitedKey);
}
119
/// <summary>
/// Visits every clause of a sort order.
/// </summary>
/// <returns>
/// <paramref name="sortOrder"/> itself when no clause was replaced; otherwise a new list holding the replacements.
/// </returns>
protected virtual IList<DbSortClause> VisitSortOrder(IList<DbSortClause> sortOrder)
{
    Func<DbSortClause, DbSortClause> visitClause = this.VisitSortClause;
    return this.VisitList(sortOrder, visitClause);
}
124
/// <summary>
/// Dispatches <paramref name="aggregate"/> to VisitFunctionAggregate or VisitGroupAggregate.
/// </summary>
/// <remarks>
/// Only function aggregates and group aggregates are currently possible. Anything that is not a
/// DbFunctionAggregate is cast to DbGroupAggregate, which throws InvalidCastException for any other
/// subtype; a null aggregate goes to VisitGroupAggregate.
/// </remarks>
protected virtual DbAggregate VisitAggregate(DbAggregate aggregate)
{
    if (aggregate is DbFunctionAggregate)
    {
        return this.VisitFunctionAggregate((DbFunctionAggregate)aggregate);
    }

    return this.VisitGroupAggregate((DbGroupAggregate)aggregate);
}
137
/// <summary>
/// Visits the function and the (single) argument of a function aggregate.
/// </summary>
/// <returns>
/// null for a null aggregate; the original aggregate when neither the function nor the argument list
/// changed; otherwise a new aggregate that keeps the original Distinct setting.
/// </returns>
protected virtual DbFunctionAggregate VisitFunctionAggregate(DbFunctionAggregate aggregate)
{
    if (aggregate == null)
    {
        return null;
    }

    EdmFunction visitedFunction = this.VisitFunction(aggregate.Function);
    IList<DbExpression> visitedArgs = this.VisitExpressionList(aggregate.Arguments);

    Debug.Assert(visitedArgs.Count == 1, "Function aggregate had more than one argument?");

    bool unchanged = object.ReferenceEquals(aggregate.Function, visitedFunction) &&
                     object.ReferenceEquals(aggregate.Arguments, visitedArgs);
    if (unchanged)
    {
        return aggregate;
    }

    return aggregate.Distinct
        ? CqtBuilder.AggregateDistinct(visitedFunction, visitedArgs[0])
        : CqtBuilder.Aggregate(visitedFunction, visitedArgs[0]);
}
163
/// <summary>
/// Visits the (single) argument of a group aggregate.
/// </summary>
/// <returns>
/// null for a null aggregate; the original aggregate when its argument list is unchanged; otherwise a
/// new group aggregate over the new argument.
/// </returns>
protected virtual DbGroupAggregate VisitGroupAggregate(DbGroupAggregate aggregate)
{
    if (aggregate == null)
    {
        return null;
    }

    IList<DbExpression> visitedArgs = this.VisitExpressionList(aggregate.Arguments);
    Debug.Assert(visitedArgs.Count == 1, "Group aggregate had more than one argument?");

    return object.ReferenceEquals(aggregate.Arguments, visitedArgs)
        ? aggregate
        : CqtBuilder.GroupAggregate(visitedArgs[0]);
}
179
/// <summary>
/// Visits the formal parameters and the body of <paramref name="lambda"/>.
/// </summary>
/// <param name="lambda">The lambda to visit; must not be null.</param>
/// <returns>
/// <paramref name="lambda"/> itself when neither its formals nor its body changed; otherwise a new
/// lambda built from the visited formals and body.
/// </returns>
/// <remarks>
/// A formal whose type is replaced by <c>VisitTypeUsage</c> gets a new variable reference. References
/// to the original formal are rebound to that new reference before the body is visited, just as
/// <c>VisitExpressionBinding</c> rebinds a binding variable whose input changed. Without this, the
/// rebuilt body would still refer to the formal with its old type.
/// </remarks>
protected virtual DbLambda VisitLambda(DbLambda lambda)
{
    EntityUtil.CheckArgumentNull(lambda, "lambda");

    DbLambda result = lambda;
    IList<DbVariableReferenceExpression> newFormals = this.VisitList(lambda.Variables, varRef =>
        {
            TypeUsage newVarType = this.VisitTypeUsage(varRef.ResultType);
            if (object.ReferenceEquals(varRef.ResultType, newVarType))
            {
                return varRef;
            }

            DbVariableReferenceExpression newVarRef = CqtBuilder.Variable(newVarType, varRef.VariableName);

            // Redirect references to the old formal (inside the body visited below) to the retyped
            // formal. RebindVariable records nothing when the two references are equivalent.
            this.RebindVariable(varRef, newVarRef);
            return newVarRef;
        }
    );

    // ToArray: Don't pass the List instance directly to OnEnterScope
    this.EnterScope(newFormals.ToArray());
    DbExpression newBody = this.VisitExpression(lambda.Body);
    this.ExitScope();

    if (!object.ReferenceEquals(lambda.Variables, newFormals) ||
        !object.ReferenceEquals(lambda.Body, newBody))
    {
        result = CqtBuilder.Lambda(newBody, newFormals);
    }
    return result;
}
209
// Metadata 'Visitor' methods. Each is an identity mapping by default; derived visitors may override them
// to substitute metadata. Callers compare the result by reference to decide whether to rebuild.

/// <summary>Maps an EdmType; the default returns <paramref name="type"/> unchanged.</summary>
protected virtual EdmType VisitType(EdmType type)
{
    return type;
}

/// <summary>Maps a TypeUsage; the default returns <paramref name="type"/> unchanged.</summary>
protected virtual TypeUsage VisitTypeUsage(TypeUsage type)
{
    return type;
}

/// <summary>Maps an entity set; the default returns <paramref name="entitySet"/> unchanged.</summary>
protected virtual EntitySetBase VisitEntitySet(EntitySetBase entitySet)
{
    return entitySet;
}

/// <summary>Maps function metadata; the default returns <paramref name="functionMetadata"/> unchanged.</summary>
protected virtual EdmFunction VisitFunction(EdmFunction functionMetadata)
{
    return functionMetadata;
}
215                 
216         #region Private Implementation
217
/// <summary>
/// Invokes OnExpressionReplaced only when <paramref name="newExpression"/> is a different instance
/// from <paramref name="originalExpression"/>.
/// </summary>
private void NotifyIfChanged(DbExpression originalExpression, DbExpression newExpression)
{
    if (object.ReferenceEquals(originalExpression, newExpression))
    {
        return;
    }

    this.OnExpressionReplaced(originalExpression, newExpression);
}
225
/// <summary>
/// Applies <paramref name="map"/> to each element of <paramref name="list"/>. A copy of the list is made
/// only once the first element maps to a different instance.
/// </summary>
/// <returns>
/// <paramref name="list"/> itself (including when it is null) if every element mapped to the same
/// instance; otherwise a new List holding the mapped elements.
/// </returns>
private IList<TElement> VisitList<TElement>(IList<TElement> list, Func<TElement, TElement> map)
{
    if (list == null)
    {
        return list;
    }

    List<TElement> copy = null;
    for (int i = 0; i < list.Count; i++)
    {
        TElement original = list[i];
        TElement mapped = map(original);

        if (copy != null)
        {
            copy[i] = mapped;
        }
        else if (!object.ReferenceEquals(original, mapped))
        {
            // First change: earlier slots already hold their unchanged values, so copy the list
            // and then overwrite this slot.
            copy = new List<TElement>(list);
            copy[i] = mapped;
        }
    }

    return copy ?? list;
}
250
/// <summary>
/// Visits the argument of a unary expression and, if it changed, rebuilds the expression with
/// <paramref name="callback"/>. Notifies OnExpressionReplaced when the result differs.
/// </summary>
private DbExpression VisitUnary(DbUnaryExpression expression, Func<DbExpression, DbExpression> callback)
{
    DbExpression visitedArgument = this.VisitExpression(expression.Argument);
    DbExpression result = object.ReferenceEquals(expression.Argument, visitedArgument)
        ? (DbExpression)expression
        : callback(visitedArgument);

    this.NotifyIfChanged(expression, result);
    return result;
}
262
/// <summary>
/// Visits the argument of a type-parameterized unary expression together with its associated
/// <paramref name="type"/>. If either changed, rebuilds via <paramref name="callback"/>.
/// Notifies OnExpressionReplaced when the result differs.
/// </summary>
private DbExpression VisitTypeUnary(DbUnaryExpression expression, TypeUsage type, Func<DbExpression, TypeUsage, DbExpression> callback)
{
    DbExpression visitedArgument = this.VisitExpression(expression.Argument);
    TypeUsage visitedType = this.VisitTypeUsage(type);

    bool unchanged = object.ReferenceEquals(expression.Argument, visitedArgument) &&
                     object.ReferenceEquals(type, visitedType);

    DbExpression result = unchanged ? (DbExpression)expression : callback(visitedArgument, visitedType);

    this.NotifyIfChanged(expression, result);
    return result;
}
278
/// <summary>
/// Visits both operands of a binary expression (left first). If either changed, rebuilds via
/// <paramref name="callback"/>. Notifies OnExpressionReplaced when the result differs.
/// </summary>
private DbExpression VisitBinary(DbBinaryExpression expression, Func<DbExpression, DbExpression, DbExpression> callback)
{
    DbExpression visitedLeft = this.VisitExpression(expression.Left);
    DbExpression visitedRight = this.VisitExpression(expression.Right);

    bool unchanged = object.ReferenceEquals(expression.Left, visitedLeft) &&
                     object.ReferenceEquals(expression.Right, visitedRight);

    DbExpression result = unchanged ? (DbExpression)expression : callback(visitedLeft, visitedRight);

    this.NotifyIfChanged(expression, result);
    return result;
}
293
/// <summary>
/// Visits the relationship ends and the target entity reference of a related-entity reference.
/// </summary>
/// <returns>The original reference when nothing changed; otherwise a newly created reference.</returns>
private DbRelatedEntityRef VisitRelatedEntityRef(DbRelatedEntityRef entityRef)
{
    RelationshipEndMember mappedSource;
    RelationshipEndMember mappedTarget;
    this.VisitRelationshipEnds(entityRef.SourceEnd, entityRef.TargetEnd, out mappedSource, out mappedTarget);

    DbExpression mappedTargetRef = this.VisitExpression(entityRef.TargetEntityReference);

    bool unchanged = object.ReferenceEquals(entityRef.SourceEnd, mappedSource) &&
                     object.ReferenceEquals(entityRef.TargetEnd, mappedTarget) &&
                     object.ReferenceEquals(entityRef.TargetEntityReference, mappedTargetRef);

    if (unchanged)
    {
        return entityRef;
    }

    return CqtBuilder.CreateRelatedEntityRef(mappedSource, mappedTarget, mappedTargetRef);
}
312
/// <summary>
/// Maps a pair of relationship ends. The declaring relationship type is passed through VisitType
/// once, and each end is then looked up by name in the mapped type.
/// </summary>
private void VisitRelationshipEnds(RelationshipEndMember source, RelationshipEndMember target, out RelationshipEndMember newSource, out RelationshipEndMember newTarget)
{
    Debug.Assert(source.DeclaringType.EdmEquals(target.DeclaringType), "Relationship ends not declared by same relationship type?");

    // Both ends share one declaring type (asserted above), so the target's declaring type is mapped.
    RelationshipType mappedRelationship = (RelationshipType)this.VisitType(target.DeclaringType);
    var mappedEnds = mappedRelationship.RelationshipEndMembers;

    newSource = mappedEnds[source.Name];
    newTarget = mappedEnds[target.Name];
}
322
/// <summary>
/// Visits the result type of a leaf expression. If it changed, rebuilds the leaf via
/// <paramref name="reconstructor"/>. Notifies OnExpressionReplaced when the result differs.
/// </summary>
private DbExpression VisitTerminal(DbExpression expression, Func<TypeUsage, DbExpression> reconstructor)
{
    TypeUsage visitedType = this.VisitTypeUsage(expression.ResultType);

    DbExpression result;
    if (object.ReferenceEquals(expression.ResultType, visitedType))
    {
        result = expression;
    }
    else
    {
        result = reconstructor(visitedType);
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
334
/// <summary>
/// Records that references to <paramref name="from"/> should be replaced by <paramref name="to"/>,
/// and calls OnVariableRebound, unless the two references are equivalent.
/// </summary>
private void RebindVariable(DbVariableReferenceExpression from, DbVariableReferenceExpression to)
{
    //
    // Two references are equivalent (and no rebinding is needed) only when they have the same name
    // and equivalent result types. For the result types, the TypeUsage instances may merely be equal
    // or equivalent, but the EdmType must be the very same instance. That way expressions such as a
    // DbPropertyExpression whose Instance is the variable reference remain valid.
    //
    bool equivalent = from.VariableName.Equals(to.VariableName, StringComparison.Ordinal) &&
                      object.ReferenceEquals(from.ResultType.EdmType, to.ResultType.EdmType) &&
                      from.ResultType.EdmEquals(to.ResultType);
    if (equivalent)
    {
        return;
    }

    this.varMappings[from] = to;
    this.OnVariableRebound(from, to);
}
355
/// <summary>
/// Visits <paramref name="binding"/>, then enters a scope declaring the (possibly new) binding variable.
/// </summary>
/// <returns>The visited binding.</returns>
private DbExpressionBinding VisitExpressionBindingEnterScope(DbExpressionBinding binding)
{
    DbExpressionBinding visited = this.VisitExpressionBinding(binding);
    DbVariableReferenceExpression[] declared = new DbVariableReferenceExpression[] { visited.Variable };
    this.OnEnterScope(declared);
    return visited;
}
362
/// <summary>Announces, through OnEnterScope, a scope that declares <paramref name="scopeVars"/>.</summary>
private void EnterScope(params DbVariableReferenceExpression[] scopeVars)
{
    IEnumerable<DbVariableReferenceExpression> declaredVariables = scopeVars;
    this.OnEnterScope(declaredVariables);
}

/// <summary>Announces, through OnExitScope, that the current scope has been left.</summary>
private void ExitScope()
{
    this.OnExitScope();
}
372
373         #endregion
374
375         #region DbExpressionVisitor<DbExpression> Members
376
/// <summary>
/// Catch-all overload. Any expression routed here is rejected as unsupported.
/// </summary>
public override DbExpression Visit(DbExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    string expressionTypeName = expression.GetType().FullName;
    throw EntityUtil.NotSupported(System.Data.Entity.Strings.Cqt_General_UnsupportedExpression(expressionTypeName));
}
383
/// <summary>
/// Visits a constant. The constant is rebuilt (with the same value) only if its result type is remapped.
/// </summary>
public override DbExpression Visit(DbConstantExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    // Calling DbConstantExpression.GetValue here is only safe because DbExpressionBuilder.Constant
    // must clone immutable values (byte[]).
    Func<TypeUsage, DbExpression> rebuildWithType = retyped => CqtBuilder.Constant(retyped, expression.GetValue());
    return this.VisitTerminal(expression, rebuildWithType);
}
392                 
/// <summary>
/// Visits a typed null. A new null expression is created only if its result type is remapped.
/// </summary>
public override DbExpression Visit(DbNullExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    return this.VisitTerminal(expression, nullType => CqtBuilder.Null(nullType));
}
399
/// <summary>
/// Visits a variable reference. A reference that was rebound (see RebindVariable) is replaced by its
/// recorded replacement; any other reference is returned unchanged.
/// </summary>
public override DbExpression Visit(DbVariableReferenceExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    DbVariableReferenceExpression replacement;
    DbExpression result = this.varMappings.TryGetValue(expression, out replacement)
        ? replacement
        : expression;

    this.NotifyIfChanged(expression, result);
    return result;
}
413
/// <summary>
/// Visits a parameter reference. It is rebuilt with the same parameter name only if its result type
/// is remapped.
/// </summary>
public override DbExpression Visit(DbParameterReferenceExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    string parameterName = expression.ParameterName;
    Func<TypeUsage, DbExpression> rebuildWithType = retyped => CqtBuilder.Parameter(retyped, parameterName);
    return this.VisitTerminal(expression, rebuildWithType);
}
420
/// <summary>
/// Visits the arguments and then the function metadata of a function invocation. The invocation is
/// rebuilt if either changed.
/// </summary>
public override DbExpression Visit(DbFunctionExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    IList<DbExpression> visitedArgs = this.VisitExpressionList(expression.Arguments);
    EdmFunction visitedFunction = this.VisitFunction(expression.Function);

    bool unchanged = object.ReferenceEquals(expression.Arguments, visitedArgs) &&
                     object.ReferenceEquals(expression.Function, visitedFunction);

    DbExpression result = expression;
    if (!unchanged)
    {
        result = CqtBuilder.Invoke(visitedFunction, visitedArgs);
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
437         
/// <summary>
/// Visits the arguments and then the lambda of a lambda invocation. The invocation is rebuilt if
/// either changed.
/// </summary>
public override DbExpression Visit(DbLambdaExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    IList<DbExpression> visitedArgs = this.VisitExpressionList(expression.Arguments);
    DbLambda visitedLambda = this.VisitLambda(expression.Lambda);

    bool unchanged = object.ReferenceEquals(expression.Arguments, visitedArgs) &&
                     object.ReferenceEquals(expression.Lambda, visitedLambda);

    DbExpression result = expression;
    if (!unchanged)
    {
        result = CqtBuilder.Invoke(visitedLambda, visitedArgs);
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
454
/// <summary>
/// Visits the instance of a property access. If the instance changed, the access is rebuilt by
/// looking the property up by name on the new instance.
/// </summary>
public override DbExpression Visit(DbPropertyExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    DbExpression visitedInstance = this.VisitExpression(expression.Instance);

    DbExpression result;
    if (object.ReferenceEquals(expression.Instance, visitedInstance))
    {
        result = expression;
    }
    else
    {
        string propertyName = expression.Property.Name;
        result = CqtBuilder.Property(visitedInstance, propertyName);
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
468
/// <summary>
/// Visits a comparison, rebuilding it with the builder method that matches its kind.
/// Throws NotSupported for a comparison kind without a matching builder.
/// </summary>
public override DbExpression Visit(DbComparisonExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression, DbExpression> rebuild;
    switch (expression.ExpressionKind)
    {
        case DbExpressionKind.Equals:
            rebuild = CqtBuilder.Equal;
            break;

        case DbExpressionKind.NotEquals:
            rebuild = CqtBuilder.NotEqual;
            break;

        case DbExpressionKind.GreaterThan:
            rebuild = CqtBuilder.GreaterThan;
            break;

        case DbExpressionKind.GreaterThanOrEquals:
            rebuild = CqtBuilder.GreaterThanOrEqual;
            break;

        case DbExpressionKind.LessThan:
            rebuild = CqtBuilder.LessThan;
            break;

        case DbExpressionKind.LessThanOrEquals:
            rebuild = CqtBuilder.LessThanOrEqual;
            break;

        default:
            throw EntityUtil.NotSupported();
    }

    return this.VisitBinary(expression, rebuild);
}
497
/// <summary>
/// Visits the argument, pattern and escape of a LIKE expression, in that order. The expression is
/// rebuilt if any of them changed.
/// </summary>
public override DbExpression Visit(DbLikeExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    DbExpression visitedArgument = this.VisitExpression(expression.Argument);
    DbExpression visitedPattern = this.VisitExpression(expression.Pattern);
    DbExpression visitedEscape = this.VisitExpression(expression.Escape);

    bool unchanged = object.ReferenceEquals(expression.Argument, visitedArgument) &&
                     object.ReferenceEquals(expression.Pattern, visitedPattern) &&
                     object.ReferenceEquals(expression.Escape, visitedEscape);

    DbExpression result = unchanged
        ? (DbExpression)expression
        : CqtBuilder.Like(visitedArgument, visitedPattern, visitedEscape);

    this.NotifyIfChanged(expression, result);
    return result;
}
517         
/// <summary>
/// Visits the argument and the limit of a LIMIT expression. The expression is rebuilt if either changed.
/// </summary>
public override DbExpression Visit(DbLimitExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    DbExpression visitedArgument = this.VisitExpression(expression.Argument);
    DbExpression visitedLimit = this.VisitExpression(expression.Limit);

    bool changed = !object.ReferenceEquals(expression.Argument, visitedArgument) ||
                   !object.ReferenceEquals(expression.Limit, visitedLimit);

    DbExpression result = expression;
    if (changed)
    {
        // CqtBuilder.Limit takes no WithTies argument; the assert flags the (presumably unexpected)
        // WithTies case in debug builds.
        Debug.Assert(!expression.WithTies, "Limit.WithTies == true?");
        result = CqtBuilder.Limit(visitedArgument, visitedLimit);
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
536
/// <summary>
/// Visits an IS NULL test. A row-typed argument is rebuilt with the factory that explicitly allows a
/// row-type argument; any other argument is rebuilt with CqtBuilder.IsNull.
/// </summary>
public override DbExpression Visit(DbIsNullExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression> rebuild = argument =>
        TypeSemantics.IsRowType(argument.ResultType)
            ? CqtBuilder.CreateIsNullExpressionAllowingRowTypeArgument(argument)
            : CqtBuilder.IsNull(argument);

    return this.VisitUnary(expression, rebuild);
}
555
/// <summary>
/// Visits the operands of an arithmetic expression. If any changed, the expression is rebuilt with
/// the operator that matches its kind. Throws NotSupported only when a rebuild is needed for an
/// unrecognized kind.
/// </summary>
public override DbExpression Visit(DbArithmeticExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    IList<DbExpression> operands = this.VisitExpressionList(expression.Arguments);

    DbExpression result;
    if (object.ReferenceEquals(expression.Arguments, operands))
    {
        result = expression;
    }
    else
    {
        // UnaryMinus is the only single-operand kind handled here; the others take two operands.
        switch (expression.ExpressionKind)
        {
            case DbExpressionKind.Plus:
                result = CqtBuilder.Plus(operands[0], operands[1]);
                break;

            case DbExpressionKind.Minus:
                result = CqtBuilder.Minus(operands[0], operands[1]);
                break;

            case DbExpressionKind.Multiply:
                result = CqtBuilder.Multiply(operands[0], operands[1]);
                break;

            case DbExpressionKind.Divide:
                result = CqtBuilder.Divide(operands[0], operands[1]);
                break;

            case DbExpressionKind.Modulo:
                result = CqtBuilder.Modulo(operands[0], operands[1]);
                break;

            case DbExpressionKind.UnaryMinus:
                result = CqtBuilder.UnaryMinus(operands[0]);
                break;

            default:
                throw EntityUtil.NotSupported();
        }
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
597
/// <summary>Visits a logical AND, rebuilding it with CqtBuilder.And when an operand changes.</summary>
public override DbExpression Visit(DbAndExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression, DbExpression> rebuild = CqtBuilder.And;
    return this.VisitBinary(expression, rebuild);
}

/// <summary>Visits a logical OR, rebuilding it with CqtBuilder.Or when an operand changes.</summary>
public override DbExpression Visit(DbOrExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression, DbExpression> rebuild = CqtBuilder.Or;
    return this.VisitBinary(expression, rebuild);
}

/// <summary>Visits a logical NOT, rebuilding it with CqtBuilder.Not when the argument changes.</summary>
public override DbExpression Visit(DbNotExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression> rebuild = CqtBuilder.Not;
    return this.VisitUnary(expression, rebuild);
}

/// <summary>Visits a DISTINCT, rebuilding it with CqtBuilder.Distinct when the argument changes.</summary>
public override DbExpression Visit(DbDistinctExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression> rebuild = CqtBuilder.Distinct;
    return this.VisitUnary(expression, rebuild);
}
625
/// <summary>
/// Visits an ELEMENT expression. The rebuild keeps the single-property-unwrapping form when the
/// original used it.
/// </summary>
public override DbExpression Visit(DbElementExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression> rebuild = expression.IsSinglePropertyUnwrapped
        ? (Func<DbExpression, DbExpression>)CqtBuilder.CreateElementExpressionUnwrapSingleProperty
        : CqtBuilder.Element;

    return this.VisitUnary(expression, rebuild);
}
643
/// <summary>Visits an IS EMPTY test, rebuilding it with CqtBuilder.IsEmpty when the argument changes.</summary>
public override DbExpression Visit(DbIsEmptyExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression> rebuild = CqtBuilder.IsEmpty;
    return this.VisitUnary(expression, rebuild);
}

/// <summary>Visits a UNION ALL, rebuilding it with CqtBuilder.UnionAll when an input changes.</summary>
public override DbExpression Visit(DbUnionAllExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression, DbExpression> rebuild = CqtBuilder.UnionAll;
    return this.VisitBinary(expression, rebuild);
}

/// <summary>Visits an INTERSECT, rebuilding it with CqtBuilder.Intersect when an input changes.</summary>
public override DbExpression Visit(DbIntersectExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression, DbExpression> rebuild = CqtBuilder.Intersect;
    return this.VisitBinary(expression, rebuild);
}

/// <summary>Visits an EXCEPT, rebuilding it with CqtBuilder.Except when an input changes.</summary>
public override DbExpression Visit(DbExceptExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, DbExpression, DbExpression> rebuild = CqtBuilder.Except;
    return this.VisitBinary(expression, rebuild);
}
671
/// <summary>Visits a TREAT, remapping both its argument and its result type.</summary>
public override DbExpression Visit(DbTreatExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, TypeUsage, DbExpression> rebuild = CqtBuilder.TreatAs;
    return this.VisitTypeUnary(expression, expression.ResultType, rebuild);
}

/// <summary>
/// Visits an IS OF / IS OF ONLY test, remapping its argument and tested type. The rebuild keeps the
/// ONLY variant when the original used it.
/// </summary>
public override DbExpression Visit(DbIsOfExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, TypeUsage, DbExpression> rebuild;
    if (expression.ExpressionKind == DbExpressionKind.IsOfOnly)
    {
        rebuild = CqtBuilder.IsOfOnly;
    }
    else
    {
        rebuild = CqtBuilder.IsOf;
    }

    return this.VisitTypeUnary(expression, expression.OfType, rebuild);
}

/// <summary>Visits a CAST, remapping both its argument and its target (result) type.</summary>
public override DbExpression Visit(DbCastExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, TypeUsage, DbExpression> rebuild = CqtBuilder.CastTo;
    return this.VisitTypeUnary(expression, expression.ResultType, rebuild);
}
699
/// <summary>
/// Visits the WHEN list, the THEN list and the ELSE expression of a CASE, in that order. The CASE is
/// rebuilt if any of them changed.
/// </summary>
public override DbExpression Visit(DbCaseExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    IList<DbExpression> visitedWhens = this.VisitExpressionList(expression.When);
    IList<DbExpression> visitedThens = this.VisitExpressionList(expression.Then);
    DbExpression visitedElse = this.VisitExpression(expression.Else);

    bool unchanged = object.ReferenceEquals(expression.When, visitedWhens) &&
                     object.ReferenceEquals(expression.Then, visitedThens) &&
                     object.ReferenceEquals(expression.Else, visitedElse);

    DbExpression result = unchanged
        ? (DbExpression)expression
        : CqtBuilder.Case(visitedWhens, visitedThens, visitedElse);

    this.NotifyIfChanged(expression, result);
    return result;
}
719         
/// <summary>
/// Visits an OFTYPE / OFTYPE ONLY filter, remapping its argument and filter type. The rebuild keeps
/// the ONLY variant when the original used it.
/// </summary>
public override DbExpression Visit(DbOfTypeExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    Func<DbExpression, TypeUsage, DbExpression> rebuild =
        expression.ExpressionKind == DbExpressionKind.OfTypeOnly
            ? (Func<DbExpression, TypeUsage, DbExpression>)CqtBuilder.OfTypeOnly
            : CqtBuilder.OfType;

    return this.VisitTypeUnary(expression, expression.OfType, rebuild);
}
733
/// <summary>
/// Visits the result type and the arguments of a constructor expression and, when present, its
/// related-entity references. The expression is rebuilt if anything changed. When related-entity
/// references are present, the rebuild uses the entity-with-relationships factory.
/// </summary>
public override DbExpression Visit(DbNewInstanceExpression expression)
{
    EntityUtil.CheckArgumentNull(expression, "expression");

    TypeUsage visitedType = this.VisitTypeUsage(expression.ResultType);
    IList<DbExpression> visitedArgs = this.VisitExpressionList(expression.Arguments);
    bool coreUnchanged = object.ReferenceEquals(expression.ResultType, visitedType) &&
                         object.ReferenceEquals(expression.Arguments, visitedArgs);

    DbExpression result = expression;
    if (expression.HasRelatedEntityReferences)
    {
        Func<DbRelatedEntityRef, DbRelatedEntityRef> visitRef = this.VisitRelatedEntityRef;
        IList<DbRelatedEntityRef> visitedRefs = this.VisitList(expression.RelatedEntityReferences, visitRef);

        bool refsUnchanged = object.ReferenceEquals(expression.RelatedEntityReferences, visitedRefs);
        if (!(coreUnchanged && refsUnchanged))
        {
            result = CqtBuilder.CreateNewEntityWithRelationshipsExpression((EntityType)visitedType.EdmType, visitedArgs, visitedRefs);
        }
    }
    else if (!coreUnchanged)
    {
        result = CqtBuilder.New(visitedType, visitedArgs.ToArray());
    }

    this.NotifyIfChanged(expression, result);
    return result;
}
761
762         public override DbExpression Visit(DbRefExpression expression)
763         {
764             EntityUtil.CheckArgumentNull(expression, "expression");
765
766             DbExpression result = expression;
767
768             EntityType targetType = (EntityType)TypeHelpers.GetEdmType<RefType>(expression.ResultType).ElementType;
769
770             DbExpression newArgument = this.VisitExpression(expression.Argument);
771             EntityType newType = (EntityType)this.VisitType(targetType);
772             EntitySet newSet = (EntitySet)this.VisitEntitySet(expression.EntitySet);
773             if (!object.ReferenceEquals(expression.Argument, newArgument) ||
774                 !object.ReferenceEquals(targetType, newType) ||
775                 !object.ReferenceEquals(expression.EntitySet, newSet))
776             {
777                 result = CqtBuilder.RefFromKey(newSet, newArgument, newType);
778             }
779             NotifyIfChanged(expression, result);
780             return result;
781         }
782
783         public override DbExpression Visit(DbRelationshipNavigationExpression expression)
784         {
785             EntityUtil.CheckArgumentNull(expression, "expression");
786
787             DbExpression result = expression;
788
789             RelationshipEndMember newFrom;
790             RelationshipEndMember newTo;
791             VisitRelationshipEnds(expression.NavigateFrom, expression.NavigateTo, out newFrom, out newTo);
792             DbExpression newNavSource = this.VisitExpression(expression.NavigationSource);
793
794             if (!object.ReferenceEquals(expression.NavigateFrom, newFrom) ||
795                 !object.ReferenceEquals(expression.NavigateTo, newTo) ||
796                 !object.ReferenceEquals(expression.NavigationSource, newNavSource))
797             {
798                 result = CqtBuilder.Navigate(newNavSource, newFrom, newTo);
799             }
800             NotifyIfChanged(expression, result);
801             return result;
802         }
803
804         public override DbExpression Visit(DbDerefExpression expression)
805         {
806             EntityUtil.CheckArgumentNull(expression, "expression");
807
808             return this.VisitUnary(expression, CqtBuilder.Deref);
809         }
810
811         public override DbExpression Visit(DbRefKeyExpression expression)
812         {
813             EntityUtil.CheckArgumentNull(expression, "expression");
814
815             return this.VisitUnary(expression, CqtBuilder.GetRefKey);
816         }
817
818         public override DbExpression Visit(DbEntityRefExpression expression)
819         {
820             EntityUtil.CheckArgumentNull(expression, "expression");
821
822             return this.VisitUnary(expression, CqtBuilder.GetEntityRef);
823         }
824
825         public override DbExpression Visit(DbScanExpression expression)
826         {
827             EntityUtil.CheckArgumentNull(expression, "expression");
828
829             DbExpression result = expression;
830
831             EntitySetBase newSet = this.VisitEntitySet(expression.Target);
832             if (!object.ReferenceEquals(expression.Target, newSet))
833             {
834                 result = CqtBuilder.Scan(newSet);
835             }
836             NotifyIfChanged(expression, result);
837             return result;
838         }
839                 
840         public override DbExpression Visit(DbFilterExpression expression)
841         {
842             EntityUtil.CheckArgumentNull(expression, "expression");
843
844             DbExpression result = expression;
845
846             DbExpressionBinding input = this.VisitExpressionBindingEnterScope(expression.Input);
847             DbExpression predicate = this.VisitExpression(expression.Predicate);
848             this.ExitScope();
849             if (!object.ReferenceEquals(expression.Input, input) ||
850                 !object.ReferenceEquals(expression.Predicate, predicate))
851             {
852                 result = CqtBuilder.Filter(input, predicate);
853             }
854             NotifyIfChanged(expression, result);
855             return result;
856         }
857         
858         public override DbExpression Visit(DbProjectExpression expression)
859         {
860             EntityUtil.CheckArgumentNull(expression, "expression");
861
862             DbExpression result = expression;
863
864             DbExpressionBinding input = this.VisitExpressionBindingEnterScope(expression.Input);
865             DbExpression projection = this.VisitExpression(expression.Projection);
866             this.ExitScope();
867             if (!object.ReferenceEquals(expression.Input, input) ||
868                 !object.ReferenceEquals(expression.Projection, projection))
869             {
870                 result = CqtBuilder.Project(input, projection);
871             }
872             NotifyIfChanged(expression, result);
873             return result;
874         }
875
876         public override DbExpression Visit(DbCrossJoinExpression expression)
877         {
878             EntityUtil.CheckArgumentNull(expression, "expression");
879
880             DbExpression result = expression;
881
882             IList<DbExpressionBinding> newInputs = this.VisitExpressionBindingList(expression.Inputs);
883             if (!object.ReferenceEquals(expression.Inputs, newInputs))
884             {
885                 result = CqtBuilder.CrossJoin(newInputs);
886             }
887             NotifyIfChanged(expression, result);
888             return result;
889         }
890
891         public override DbExpression Visit(DbJoinExpression expression)
892         {
893             EntityUtil.CheckArgumentNull(expression, "expression");
894
895             DbExpression result = expression;
896
897             DbExpressionBinding newLeft = this.VisitExpressionBinding(expression.Left);
898             DbExpressionBinding newRight = this.VisitExpressionBinding(expression.Right);
899             
900             this.EnterScope(newLeft.Variable, newRight.Variable);
901             DbExpression newCondition = this.VisitExpression(expression.JoinCondition);
902             this.ExitScope();
903
904             if (!object.ReferenceEquals(expression.Left, newLeft) ||
905                 !object.ReferenceEquals(expression.Right, newRight) ||
906                 !object.ReferenceEquals(expression.JoinCondition, newCondition))
907             {
908                 if (DbExpressionKind.InnerJoin == expression.ExpressionKind)
909                 {
910                     result = CqtBuilder.InnerJoin(newLeft, newRight, newCondition);
911                 }
912                 else if (DbExpressionKind.LeftOuterJoin == expression.ExpressionKind)
913                 {
914                     result = CqtBuilder.LeftOuterJoin(newLeft, newRight, newCondition);
915                 }
916                 else
917                 {
918                     Debug.Assert(expression.ExpressionKind == DbExpressionKind.FullOuterJoin, "DbJoinExpression had ExpressionKind other than InnerJoin, LeftOuterJoin or FullOuterJoin?");
919                     result = CqtBuilder.FullOuterJoin(newLeft, newRight, newCondition);
920                 }
921             }
922             NotifyIfChanged(expression, result);
923             return result;
924         }
925
926         public override DbExpression Visit(DbApplyExpression expression)
927         {
928             EntityUtil.CheckArgumentNull(expression, "expression");
929
930             DbExpression result = expression;
931
932             DbExpressionBinding newInput = this.VisitExpressionBindingEnterScope(expression.Input);
933             DbExpressionBinding newApply = this.VisitExpressionBinding(expression.Apply);
934             this.ExitScope();
935
936             if (!object.ReferenceEquals(expression.Input, newInput) ||
937                 !object.ReferenceEquals(expression.Apply, newApply))
938             {
939                 if (DbExpressionKind.CrossApply == expression.ExpressionKind)
940                 {
941                     result = CqtBuilder.CrossApply(newInput, newApply);
942                 }
943                 else
944                 {
945                     Debug.Assert(expression.ExpressionKind == DbExpressionKind.OuterApply, "DbApplyExpression had ExpressionKind other than CrossApply or OuterApply?");
946                     result = CqtBuilder.OuterApply(newInput, newApply);
947                 }
948             }
949             NotifyIfChanged(expression, result);
950             return result;
951         }
952
953         public override DbExpression Visit(DbGroupByExpression expression)
954         {
955             EntityUtil.CheckArgumentNull(expression, "expression");
956
957             DbExpression result = expression;
958
959             DbGroupExpressionBinding newInput = this.VisitGroupExpressionBinding(expression.Input);
960             this.EnterScope(newInput.Variable);
961             IList<DbExpression> newKeys = this.VisitExpressionList(expression.Keys);
962             this.ExitScope();
963             this.EnterScope(newInput.GroupVariable);
964             IList<DbAggregate> newAggs = this.VisitList<DbAggregate>(expression.Aggregates, this.VisitAggregate);
965             this.ExitScope();
966
967             if (!object.ReferenceEquals(expression.Input, newInput) ||
968                 !object.ReferenceEquals(expression.Keys, newKeys) ||
969                 !object.ReferenceEquals(expression.Aggregates, newAggs))
970             {
971                 RowType groupOutput =
972                     TypeHelpers.GetEdmType<RowType>(TypeHelpers.GetEdmType<CollectionType>(expression.ResultType).TypeUsage);
973
974                 var boundKeys = groupOutput.Properties.Take(newKeys.Count).Select(p => p.Name).Zip(newKeys).ToList();
975                 var boundAggs = groupOutput.Properties.Skip(newKeys.Count).Select(p => p.Name).Zip(newAggs).ToList();
976
977                 result = CqtBuilder.GroupBy(newInput, boundKeys, boundAggs);
978             }
979             NotifyIfChanged(expression, result);
980             return result;
981         }
982                 
983         public override DbExpression Visit(DbSkipExpression expression)
984         {
985             EntityUtil.CheckArgumentNull(expression, "expression");
986
987             DbExpression result = expression;
988
989             DbExpressionBinding newInput = this.VisitExpressionBindingEnterScope(expression.Input);
990             IList<DbSortClause> newSortOrder = this.VisitSortOrder(expression.SortOrder);
991             this.ExitScope();
992             DbExpression newCount = this.VisitExpression(expression.Count);
993
994             if (!object.ReferenceEquals(expression.Input, newInput) ||
995                 !object.ReferenceEquals(expression.SortOrder, newSortOrder) ||
996                 !object.ReferenceEquals(expression.Count, newCount))
997             {
998                 result = CqtBuilder.Skip(newInput, newSortOrder, newCount);
999             }
1000             NotifyIfChanged(expression, result);
1001             return result;
1002         }
1003
1004         public override DbExpression Visit(DbSortExpression expression)
1005         {
1006             EntityUtil.CheckArgumentNull(expression, "expression");
1007
1008             DbExpression result = expression;
1009
1010             DbExpressionBinding newInput = this.VisitExpressionBindingEnterScope(expression.Input);
1011             IList<DbSortClause> newSortOrder = this.VisitSortOrder(expression.SortOrder);
1012             this.ExitScope();
1013
1014             if (!object.ReferenceEquals(expression.Input, newInput) ||
1015                 !object.ReferenceEquals(expression.SortOrder, newSortOrder))
1016             {
1017                 result = CqtBuilder.Sort(newInput, newSortOrder);
1018             }
1019             NotifyIfChanged(expression, result);
1020             return result;
1021         }
1022
1023         public override DbExpression Visit(DbQuantifierExpression expression)
1024         {
1025             EntityUtil.CheckArgumentNull(expression, "expression");
1026
1027             DbExpression result = expression;
1028
1029             DbExpressionBinding input = this.VisitExpressionBindingEnterScope(expression.Input);
1030             DbExpression predicate = this.VisitExpression(expression.Predicate);
1031             this.ExitScope();
1032
1033             if (!object.ReferenceEquals(expression.Input, input) ||
1034                 !object.ReferenceEquals(expression.Predicate, predicate))
1035             {
1036                 if (DbExpressionKind.All == expression.ExpressionKind)
1037                 {
1038                     result = CqtBuilder.All(input, predicate);
1039                 }
1040                 else
1041                 {
1042                     Debug.Assert(expression.ExpressionKind == DbExpressionKind.Any, "DbQuantifierExpression had ExpressionKind other than All or Any?");
1043                     result = CqtBuilder.Any(input, predicate);
1044                 }
1045             }
1046             NotifyIfChanged(expression, result);
1047             return result;
1048         }
1049
1050         #endregion
1051     }
1052 }