2 using System.Collections.Generic;
3 using System.Data.Linq;
4 using System.Diagnostics;
5 using System.Diagnostics.CodeAnalysis;
7 namespace System.Data.Linq.SqlClient {
8 internal abstract class SqlVisitor {
// Central dispatcher: switches on node.NodeType and forwards the node, cast to
// its concrete Sql* type, to the matching Visit* hook; the hook's return value
// becomes the result. Node types with no matching case throw Error.UnexpectedNode.
// NOTE(review): this copy has gaps in its embedded line numbering -- e.g. the case
// labels preceding `result = this.VisitIn(...)` and `result = this.VisitRow(...)`,
// the break statements and the closing braces are missing. Compare against the
// complete source before relying on the case list shown here.
12 [SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification="Microsoft: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
13 [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
14 [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
15 internal virtual SqlNode Visit(SqlNode node) {
16 SqlNode result = null;
// Recursion guard; CheckRecursionDepth is [Conditional("DEBUG")], so this call is
// compiled out of non-debug builds. nDepth is declared outside this view --
// presumably a nesting-depth counter maintained by this method (TODO confirm).
23 CheckRecursionDepth(500, nDepth);
25 switch (node.NodeType) {
// Unary-operator node types are all cast to SqlUnary and sent to VisitUnaryOperator.
27 case SqlNodeType.Not2V:
28 case SqlNodeType.Negate:
29 case SqlNodeType.BitNot:
30 case SqlNodeType.IsNull:
31 case SqlNodeType.IsNotNull:
32 case SqlNodeType.Count:
33 case SqlNodeType.LongCount:
38 case SqlNodeType.Stddev:
39 case SqlNodeType.Convert:
40 case SqlNodeType.ValueOf:
41 case SqlNodeType.OuterJoinedValue:
42 case SqlNodeType.ClrLength:
43 result = this.VisitUnaryOperator((SqlUnary)node);
45 case SqlNodeType.Lift:
46 result = this.VisitLift((SqlLift)node);
// Binary-operator node types are all cast to SqlBinary and sent to VisitBinaryOperator.
53 case SqlNodeType.BitAnd:
54 case SqlNodeType.BitOr:
55 case SqlNodeType.BitXor:
64 case SqlNodeType.EQ2V:
65 case SqlNodeType.NE2V:
66 case SqlNodeType.Concat:
67 case SqlNodeType.Coalesce:
68 result = this.VisitBinaryOperator((SqlBinary)node);
70 case SqlNodeType.Between:
71 result = this.VisitBetween((SqlBetween)node);
74 result = this.VisitIn((SqlIn)node);
76 case SqlNodeType.Like:
77 result = this.VisitLike((SqlLike)node);
79 case SqlNodeType.Treat:
80 result = this.VisitTreat((SqlUnary)node);
82 case SqlNodeType.Alias:
83 result = this.VisitAlias((SqlAlias)node);
85 case SqlNodeType.AliasRef:
86 result = this.VisitAliasRef((SqlAliasRef)node);
88 case SqlNodeType.Member:
89 result = this.VisitMember((SqlMember)node);
92 result = this.VisitRow((SqlRow)node);
94 case SqlNodeType.Column:
95 result = this.VisitColumn((SqlColumn)node);
97 case SqlNodeType.ColumnRef:
98 result = this.VisitColumnRef((SqlColumnRef)node);
100 case SqlNodeType.Table:
101 result = this.VisitTable((SqlTable)node);
103 case SqlNodeType.UserQuery:
104 result = this.VisitUserQuery((SqlUserQuery)node);
106 case SqlNodeType.StoredProcedureCall:
107 result = this.VisitStoredProcedureCall((SqlStoredProcedureCall)node);
109 case SqlNodeType.UserRow:
110 result = this.VisitUserRow((SqlUserRow)node);
112 case SqlNodeType.UserColumn:
113 result = this.VisitUserColumn((SqlUserColumn)node);
// All sub-select forms share VisitSubSelect, which re-dispatches on NodeType.
115 case SqlNodeType.Multiset:
116 case SqlNodeType.ScalarSubSelect:
117 case SqlNodeType.Element:
118 case SqlNodeType.Exists:
119 result = this.VisitSubSelect((SqlSubSelect)node);
121 case SqlNodeType.Join:
122 result = this.VisitJoin((SqlJoin)node);
124 case SqlNodeType.Select:
125 result = this.VisitSelect((SqlSelect)node);
127 case SqlNodeType.Parameter:
128 result = this.VisitParameter((SqlParameter)node);
130 case SqlNodeType.New:
131 result = this.VisitNew((SqlNew)node);
133 case SqlNodeType.Link:
134 result = this.VisitLink((SqlLink)node);
136 case SqlNodeType.ClientQuery:
137 result = this.VisitClientQuery((SqlClientQuery)node);
139 case SqlNodeType.JoinedCollection:
140 result = this.VisitJoinedCollection((SqlJoinedCollection)node);
142 case SqlNodeType.Value:
143 result = this.VisitValue((SqlValue)node);
145 case SqlNodeType.ClientArray:
146 result = this.VisitClientArray((SqlClientArray)node);
148 case SqlNodeType.Insert:
149 result = this.VisitInsert((SqlInsert)node);
151 case SqlNodeType.Update:
152 result = this.VisitUpdate((SqlUpdate)node);
154 case SqlNodeType.Delete:
155 result = this.VisitDelete((SqlDelete)node);
157 case SqlNodeType.MemberAssign:
158 result = this.VisitMemberAssign((SqlMemberAssign)node);
160 case SqlNodeType.Assign:
161 result = this.VisitAssign((SqlAssign)node);
163 case SqlNodeType.Block:
164 result = this.VisitBlock((SqlBlock)node);
166 case SqlNodeType.SearchedCase:
167 result = this.VisitSearchedCase((SqlSearchedCase)node);
169 case SqlNodeType.ClientCase:
170 result = this.VisitClientCase((SqlClientCase)node);
172 case SqlNodeType.SimpleCase:
173 result = this.VisitSimpleCase((SqlSimpleCase)node);
175 case SqlNodeType.TypeCase:
176 result = this.VisitTypeCase((SqlTypeCase)node);
178 case SqlNodeType.Union:
179 result = this.VisitUnion((SqlUnion)node);
181 case SqlNodeType.ExprSet:
182 result = this.VisitExprSet((SqlExprSet)node);
184 case SqlNodeType.Variable:
185 result = this.VisitVariable((SqlVariable)node);
187 case SqlNodeType.DoNotVisit:
188 result = this.VisitDoNotVisit((SqlDoNotVisitExpression)node);
190 case SqlNodeType.OptionalValue:
191 result = this.VisitOptionalValue((SqlOptionalValue)node);
193 case SqlNodeType.FunctionCall:
194 result = this.VisitFunctionCall((SqlFunctionCall)node);
196 case SqlNodeType.TableValuedFunctionCall:
197 result = this.VisitTableValuedFunctionCall((SqlTableValuedFunctionCall)node);
199 case SqlNodeType.MethodCall:
200 result = this.VisitMethodCall((SqlMethodCall)node);
202 case SqlNodeType.Nop:
203 result = this.VisitNop((SqlNop)node);
205 case SqlNodeType.SharedExpression:
206 result = this.VisitSharedExpression((SqlSharedExpression)node);
208 case SqlNodeType.SharedExpressionRef:
209 result = this.VisitSharedExpressionRef((SqlSharedExpressionRef)node);
211 case SqlNodeType.SimpleExpression:
212 result = this.VisitSimpleExpression((SqlSimpleExpression)node);
214 case SqlNodeType.Grouping:
215 result = this.VisitGrouping((SqlGrouping)node);
217 case SqlNodeType.DiscriminatedType:
218 result = this.VisitDiscriminatedType((SqlDiscriminatedType)node);
220 case SqlNodeType.DiscriminatorOf:
221 result = this.VisitDiscriminatorOf((SqlDiscriminatorOf)node);
223 case SqlNodeType.ClientParameter:
224 result = this.VisitClientParameter((SqlClientParameter)node);
226 case SqlNodeType.RowNumber:
227 result = this.VisitRowNumber((SqlRowNumber)node);
229 case SqlNodeType.IncludeScope:
230 result = this.VisitIncludeScope((SqlIncludeScope)node);
// No matching case: the node type is not supported by this visitor.
233 throw Error.UnexpectedNode(node);
244 /// This method checks the recursion level to help diagnose/prevent
245 /// infinite recursion in debug builds. Calls are omitted in non-debug builds.
/// When level exceeds maxLevel it fires Debug.Assert(false) and then throws System.Exception.
247 [SuppressMessage("Microsoft.Usage", "CA2201:DoNotRaiseReservedExceptionTypes", Justification="Debug-only code.")]
248 [Conditional("DEBUG")]
249 internal static void CheckRecursionDepth(int maxLevel, int level) {
250 if (level > maxLevel) {
251 System.Diagnostics.Debug.Assert(false);
252 //**********************************************************************
253 // EXCLUDING FROM LOCALIZATION.
254 // Reason: This code only executes in DEBUG.
255 throw new Exception("Infinite Descent?");
256 //**********************************************************************
// Returns the constant held by a SqlValue node. Any other node type throws
// Error.UnexpectedNode. Uses no visitor state (hence the CA1822 suppression).
260 [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
261 internal object Eval(SqlExpression expr) {
262 if (expr.NodeType == SqlNodeType.Value) {
263 return ((SqlValue)expr).Value;
265 throw Error.UnexpectedNode(expr.NodeType);
// Unwraps a SqlDoNotVisitExpression: returns its inner Expression as-is,
// without visiting it.
268 internal virtual SqlExpression VisitDoNotVisit(SqlDoNotVisitExpression expr) {
269 return expr.Expression;
// Visits the Expression of each OrderBy entry of the row-number node and writes
// each result back into the entry (in-place rewrite).
271 internal virtual SqlRowNumber VisitRowNumber(SqlRowNumber rowNumber) {
272 for (int i = 0, n = rowNumber.OrderBy.Count; i < n; i++) {
273 rowNumber.OrderBy[i].Expression = this.VisitExpression(rowNumber.OrderBy[i].Expression);
// Visits an expression through the central Visit dispatcher and casts the
// result back to SqlExpression. The cast fails if a hook returns a non-expression.
278 internal virtual SqlExpression VisitExpression(SqlExpression exp) {
279 return (SqlExpression)this.Visit(exp);
// Visits a select through Visit and casts the result back to SqlSelect.
281 internal virtual SqlSelect VisitSequence(SqlSelect sel) {
282 return (SqlSelect)this.Visit(sel);
// Override hook for SqlNop nodes.
284 internal virtual SqlExpression VisitNop(SqlNop nop) {
// Visits the lifted inner Expression and stores the result back on the node.
287 internal virtual SqlExpression VisitLift(SqlLift lift) {
288 lift.Expression = this.VisitExpression(lift.Expression);
// Visits the single Operand of a unary node and stores the result back on the node.
291 internal virtual SqlExpression VisitUnaryOperator(SqlUnary uo) {
292 uo.Operand = this.VisitExpression(uo.Operand);
// Visits Left, then Right, storing each result back on the binary node.
295 internal virtual SqlExpression VisitBinaryOperator(SqlBinary bo) {
296 bo.Left = this.VisitExpression(bo.Left);
297 bo.Right = this.VisitExpression(bo.Right);
// Visits the aliased Node (as a general SqlNode, not necessarily an expression)
// and stores the result back on the alias.
300 internal virtual SqlAlias VisitAlias(SqlAlias a) {
301 a.Node = this.Visit(a.Node);
// Override hook for SqlAliasRef nodes.
304 internal virtual SqlExpression VisitAliasRef(SqlAliasRef aref) {
// Visits the Expression the member is accessed on and stores the result back.
307 internal virtual SqlNode VisitMember(SqlMember m) {
308 m.Expression = this.VisitExpression(m.Expression);
// Visits the Operand of a cast and stores the result back on the node.
// NOTE(review): no visible case in Visit dispatches here (Convert is routed to
// VisitUnaryOperator); confirm which callers use this hook.
311 internal virtual SqlExpression VisitCast(SqlUnary c) {
312 c.Operand = this.VisitExpression(c.Operand);
// Visits the Operand of a Treat node and stores the result back on the node.
315 internal virtual SqlExpression VisitTreat(SqlUnary t) {
316 t.Operand = this.VisitExpression(t.Operand);
// Override hook for SqlTable nodes.
319 internal virtual SqlTable VisitTable(SqlTable tab) {
// Visits a user query in place: each Argument, then the Projection, then each
// Column (visited via Visit and cast back to SqlUserColumn).
322 internal virtual SqlUserQuery VisitUserQuery(SqlUserQuery suq) {
323 for (int i = 0, n = suq.Arguments.Count; i < n; i++) {
324 suq.Arguments[i] = this.VisitExpression(suq.Arguments[i]);
326 suq.Projection = this.VisitExpression(suq.Projection);
327 for (int i = 0, n = suq.Columns.Count; i < n; i++) {
328 suq.Columns[i] = (SqlUserColumn) this.Visit(suq.Columns[i]);
// Visits a stored-procedure call in place: each Argument, then the Projection,
// then each Column (visited via Visit and cast back to SqlUserColumn).
332 internal virtual SqlStoredProcedureCall VisitStoredProcedureCall(SqlStoredProcedureCall spc) {
333 for (int i = 0, n = spc.Arguments.Count; i < n; i++) {
334 spc.Arguments[i] = this.VisitExpression(spc.Arguments[i]);
336 spc.Projection = this.VisitExpression(spc.Projection);
337 for (int i = 0, n = spc.Columns.Count; i < n; i++) {
338 spc.Columns[i] = (SqlUserColumn) this.Visit(spc.Columns[i]);
// Override hook for SqlUserColumn nodes.
342 internal virtual SqlExpression VisitUserColumn(SqlUserColumn suc) {
// Override hook for SqlUserRow nodes.
345 internal virtual SqlExpression VisitUserRow(SqlUserRow row) {
// Visits the Expression of every column in the row (the column objects
// themselves are kept) and writes each result back into its column.
348 internal virtual SqlRow VisitRow(SqlRow row) {
349 for (int i = 0, n = row.Columns.Count; i < n; i++) {
350 row.Columns[i].Expression = this.VisitExpression(row.Columns[i].Expression);
// Visits a construction node in place: each constructor Arg, then the
// Expression of each member assignment in Members.
354 internal virtual SqlExpression VisitNew(SqlNew sox) {
355 for (int i = 0, n = sox.Args.Count; i < n; i++) {
356 sox.Args[i] = this.VisitExpression(sox.Args[i]);
358 for (int i = 0, n = sox.Members.Count; i < n; i++) {
359 sox.Members[i].Expression = this.VisitExpression(sox.Members[i].Expression);
// Visits each of the link's KeyExpressions in place.
363 internal virtual SqlNode VisitLink(SqlLink link) {
364 // Don't visit the link's Expansion
365 for (int i = 0, n = link.KeyExpressions.Count; i < n; i++) {
366 link.KeyExpressions[i] = this.VisitExpression(link.KeyExpressions[i]);
// Visits each Argument of the client query in place.
370 internal virtual SqlExpression VisitClientQuery(SqlClientQuery cq) {
371 for (int i = 0, n = cq.Arguments.Count; i < n; i++) {
372 cq.Arguments[i] = this.VisitExpression(cq.Arguments[i]);
// Visits the collection's Expression, then its Count expression, in place.
376 internal virtual SqlExpression VisitJoinedCollection(SqlJoinedCollection jc) {
377 jc.Expression = this.VisitExpression(jc.Expression);
378 jc.Count = this.VisitExpression(jc.Count);
// Visits each element expression of the client array in place.
381 internal virtual SqlExpression VisitClientArray(SqlClientArray scar) {
382 for (int i = 0, n = scar.Expressions.Count; i < n; i++) {
383 scar.Expressions[i] = this.VisitExpression(scar.Expressions[i]);
// Override hook for SqlClientParameter nodes.
387 internal virtual SqlExpression VisitClientParameter(SqlClientParameter cp) {
// Visits the column's defining Expression and stores the result back on the column.
390 internal virtual SqlExpression VisitColumn(SqlColumn col) {
391 col.Expression = this.VisitExpression(col.Expression);
// Override hook for SqlColumnRef nodes.
394 internal virtual SqlExpression VisitColumnRef(SqlColumnRef cref) {
// Override hook for SqlParameter nodes.
397 internal virtual SqlExpression VisitParameter(SqlParameter p) {
// Override hook for SqlValue nodes.
400 internal virtual SqlExpression VisitValue(SqlValue value) {
// Routes a sub-select to the hook for its specific form (scalar, multiset,
// element, exists). Any other NodeType throws Error.UnexpectedNode.
403 internal virtual SqlExpression VisitSubSelect(SqlSubSelect ss) {
404 switch(ss.NodeType) {
405 case SqlNodeType.ScalarSubSelect: return this.VisitScalarSubSelect(ss);
406 case SqlNodeType.Multiset: return this.VisitMultiset(ss);
407 case SqlNodeType.Element: return this.VisitElement(ss);
408 case SqlNodeType.Exists: return this.VisitExists(ss);
410 throw Error.UnexpectedNode(ss.NodeType);
// Visits the inner Select of a scalar sub-select via VisitSequence, in place.
412 internal virtual SqlExpression VisitScalarSubSelect(SqlSubSelect ss) {
413 ss.Select = this.VisitSequence(ss.Select);
// Visits the inner Select of a multiset sub-select via VisitSequence, in place.
416 internal virtual SqlExpression VisitMultiset(SqlSubSelect sms) {
417 sms.Select = this.VisitSequence(sms.Select);
// Visits the inner Select of an element sub-select via VisitSequence, in place.
420 internal virtual SqlExpression VisitElement(SqlSubSelect elem) {
421 elem.Select = this.VisitSequence(elem.Select);
// Visits the inner Select of an exists sub-select via VisitSequence, in place.
424 internal virtual SqlExpression VisitExists(SqlSubSelect sqlExpr) {
425 sqlExpr.Select = this.VisitSequence(sqlExpr.Select);
// Visits the Left source, the Right source, then the join Condition, in place.
428 internal virtual SqlSource VisitJoin(SqlJoin join) {
429 join.Left = this.VisitSource(join.Left);
430 join.Right = this.VisitSource(join.Right);
431 join.Condition = this.VisitExpression(join.Condition);
// Visits a source through Visit and casts the result back to SqlSource.
434 internal virtual SqlSource VisitSource(SqlSource source) {
435 return (SqlSource) this.Visit(source);
// Visits the clauses of a select in this order, writing each result back:
// From, Where, each GroupBy expression, Having, each OrderBy entry's
// Expression, Top, and finally Row (visited via Visit and cast to SqlRow).
// Does not visit Selection (see VisitSelect).
437 internal virtual SqlSelect VisitSelectCore(SqlSelect select) {
438 select.From = this.VisitSource(select.From);
439 select.Where = this.VisitExpression(select.Where);
440 for (int i = 0, n = select.GroupBy.Count; i < n; i++) {
441 select.GroupBy[i] = this.VisitExpression(select.GroupBy[i]);
443 select.Having = this.VisitExpression(select.Having);
444 for (int i = 0, n = select.OrderBy.Count; i < n; i++) {
445 select.OrderBy[i].Expression = this.VisitExpression(select.OrderBy[i].Expression);
447 select.Top = this.VisitExpression(select.Top);
448 select.Row = (SqlRow)this.Visit(select.Row);
// Visits the select's clauses via VisitSelectCore (whose returned select replaces
// the local), then visits the Selection expression.
451 internal virtual SqlSelect VisitSelect(SqlSelect select) {
452 select = this.VisitSelectCore(select);
453 select.Selection = this.VisitExpression(select.Selection);
// Visits an insert in place: Table (cast back to SqlTable), the Expression,
// then Row (cast back to SqlRow).
456 internal virtual SqlStatement VisitInsert(SqlInsert insert) {
457 insert.Table = (SqlTable)this.Visit(insert.Table);
458 insert.Expression = this.VisitExpression(insert.Expression);
459 insert.Row = (SqlRow)this.Visit(insert.Row);
// Visits an update in place: the target Select via VisitSequence, then each
// assignment (visited via Visit and cast back to SqlAssign).
462 internal virtual SqlStatement VisitUpdate(SqlUpdate update) {
463 update.Select = this.VisitSequence(update.Select);
464 for (int i = 0, n = update.Assignments.Count; i < n; i++) {
465 update.Assignments[i] = (SqlAssign)this.Visit(update.Assignments[i]);
// Visits the delete's Select via VisitSequence, in place.
469 internal virtual SqlStatement VisitDelete(SqlDelete delete) {
470 delete.Select = this.VisitSequence(delete.Select);
// Visits the assigned value Expression and stores the result back.
473 internal virtual SqlMemberAssign VisitMemberAssign(SqlMemberAssign ma) {
474 ma.Expression = this.VisitExpression(ma.Expression);
// Visits the assignment's LValue, then its RValue, in place.
477 internal virtual SqlStatement VisitAssign(SqlAssign sa) {
478 sa.LValue = this.VisitExpression(sa.LValue);
479 sa.RValue = this.VisitExpression(sa.RValue);
// Visits each statement of the block in order (via Visit, cast back to
// SqlStatement) and writes each result back.
482 internal virtual SqlBlock VisitBlock(SqlBlock b) {
483 for (int i = 0, n = b.Statements.Count; i < n; i++) {
484 b.Statements[i] = (SqlStatement)this.Visit(b.Statements[i]);
// Visits a searched CASE in place: each When's Match then Value, then Else.
488 internal virtual SqlExpression VisitSearchedCase(SqlSearchedCase c) {
489 for (int i = 0, n = c.Whens.Count; i < n; i++) {
490 SqlWhen when = c.Whens[i];
491 when.Match = this.VisitExpression(when.Match);
492 when.Value = this.VisitExpression(when.Value);
494 c.Else = this.VisitExpression(c.Else);
// Visits a client-side case in place: the tested Expression first, then each
// When's Match and Value.
497 internal virtual SqlExpression VisitClientCase(SqlClientCase c) {
498 c.Expression = this.VisitExpression(c.Expression);
499 for (int i = 0, n = c.Whens.Count; i < n; i++) {
500 SqlClientWhen when = c.Whens[i];
501 when.Match = this.VisitExpression(when.Match);
502 when.Value = this.VisitExpression(when.Value);
// Visits a simple CASE in place: the tested Expression first, then each
// When's Match and Value.
506 internal virtual SqlExpression VisitSimpleCase(SqlSimpleCase c) {
507 c.Expression = this.VisitExpression(c.Expression);
508 for (int i = 0, n = c.Whens.Count; i < n; i++) {
509 SqlWhen when = c.Whens[i];
510 when.Match = this.VisitExpression(when.Match);
511 when.Value = this.VisitExpression(when.Value);
// Visits a type case in place: the Discriminator first, then each When's
// Match and TypeBinding.
515 internal virtual SqlExpression VisitTypeCase(SqlTypeCase tc) {
516 tc.Discriminator = this.VisitExpression(tc.Discriminator);
517 for (int i = 0, n = tc.Whens.Count; i < n; i++) {
518 SqlTypeCaseWhen when = tc.Whens[i];
519 when.Match = this.VisitExpression(when.Match);
520 when.TypeBinding = this.VisitExpression(when.TypeBinding);
// Visits the Left, then the Right, operand of the union as general SqlNodes
// (via Visit, no cast), in place.
524 internal virtual SqlNode VisitUnion(SqlUnion su) {
525 su.Left = this.Visit(su.Left);
526 su.Right = this.Visit(su.Right);
// Visits each expression of the expression set in place.
529 internal virtual SqlExpression VisitExprSet(SqlExprSet xs) {
530 for (int i = 0, n = xs.Expressions.Count; i < n; i++) {
531 xs.Expressions[i] = this.VisitExpression(xs.Expressions[i]);
// Override hook for SqlVariable nodes.
535 internal virtual SqlExpression VisitVariable(SqlVariable v) {
// Visits the HasValue expression, then the Value expression, in place.
538 internal virtual SqlExpression VisitOptionalValue(SqlOptionalValue sov) {
539 sov.HasValue = this.VisitExpression(sov.HasValue);
540 sov.Value = this.VisitExpression(sov.Value);
// Visits the tested Expression, then Start, then End, in place.
543 internal virtual SqlExpression VisitBetween(SqlBetween between) {
544 between.Expression = this.VisitExpression(between.Expression);
545 between.Start = this.VisitExpression(between.Start);
546 between.End = this.VisitExpression(between.End);
// Visits the tested Expression, then each candidate in Values, in place.
549 internal virtual SqlExpression VisitIn(SqlIn sin) {
550 sin.Expression = this.VisitExpression(sin.Expression);
551 for (int i = 0, n = sin.Values.Count; i < n; i++) {
552 sin.Values[i] = this.VisitExpression(sin.Values[i]);
// Visits the tested Expression, the Pattern, then the Escape expression, in place.
556 internal virtual SqlExpression VisitLike(SqlLike like) {
557 like.Expression = this.VisitExpression(like.Expression);
558 like.Pattern = this.VisitExpression(like.Pattern);
559 like.Escape = this.VisitExpression(like.Escape);
// Visits each Argument of the function call in place.
562 internal virtual SqlExpression VisitFunctionCall(SqlFunctionCall fc) {
563 for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
564 fc.Arguments[i] = this.VisitExpression(fc.Arguments[i]);
// Visits each Argument of the table-valued function call in place.
568 internal virtual SqlExpression VisitTableValuedFunctionCall(SqlTableValuedFunctionCall fc) {
569 for (int i = 0, n = fc.Arguments.Count; i < n; i++) {
570 fc.Arguments[i] = this.VisitExpression(fc.Arguments[i]);
// Visits the target Object first, then each Argument, in place.
574 internal virtual SqlExpression VisitMethodCall(SqlMethodCall mc) {
575 mc.Object = this.VisitExpression(mc.Object);
576 for (int i = 0, n = mc.Arguments.Count; i < n; i++) {
577 mc.Arguments[i] = this.VisitExpression(mc.Arguments[i]);
// Visits the shared Expression and stores the result back on the node.
581 internal virtual SqlExpression VisitSharedExpression(SqlSharedExpression shared) {
582 shared.Expression = this.VisitExpression(shared.Expression);
// Override hook for SqlSharedExpressionRef nodes.
585 internal virtual SqlExpression VisitSharedExpressionRef(SqlSharedExpressionRef sref) {
// Visits the wrapped Expression and stores the result back on the node.
588 internal virtual SqlExpression VisitSimpleExpression(SqlSimpleExpression simple) {
589 simple.Expression = this.VisitExpression(simple.Expression);
// Visits the grouping Key, then the Group expression, in place.
592 internal virtual SqlExpression VisitGrouping(SqlGrouping g) {
593 g.Key = this.VisitExpression(g.Key);
594 g.Group = this.VisitExpression(g.Group);
// Visits the Discriminator expression and stores the result back on the node.
597 internal virtual SqlExpression VisitDiscriminatedType(SqlDiscriminatedType dt) {
598 dt.Discriminator = this.VisitExpression(dt.Discriminator);
// Visits the Object whose discriminator is requested and stores the result back.
601 internal virtual SqlExpression VisitDiscriminatorOf(SqlDiscriminatorOf dof) {
602 dof.Object = this.VisitExpression(dof.Object);
// Visits the scope's Child as a general SqlNode (via Visit, no cast), in place.
605 internal virtual SqlNode VisitIncludeScope(SqlIncludeScope node) {
606 node.Child = this.Visit(node.Child);
613 internal bool RefersToColumn(SqlExpression exp, SqlColumn col) {
617 System.Diagnostics.Debug.Assert(refersDepth < 20);
620 switch (exp.NodeType) {
621 case SqlNodeType.Column:
622 return exp == col || this.RefersToColumn(((SqlColumn)exp).Expression, col);
623 case SqlNodeType.ColumnRef:
624 SqlColumnRef cref = (SqlColumnRef)exp;
625 return cref.Column == col || this.RefersToColumn(cref.Column.Expression, col);
626 case SqlNodeType.ExprSet:
627 SqlExprSet set = (SqlExprSet)exp;
628 for (int i = 0, n = set.Expressions.Count; i < n; i++) {
629 if (this.RefersToColumn(set.Expressions[i], col)) {
634 case SqlNodeType.OuterJoinedValue:
635 return this.RefersToColumn(((SqlUnary)exp).Operand, col);