9717a97a0cce46f48d4c0081720d81998dc5ed3f
[mono.git] / mcs / class / referencesource / System.Data.Linq / SqlClient / Query / SqlFlattener.cs
1 using System;
2 using System.Collections.Generic;
3
4 namespace System.Data.Linq.SqlClient {
5     using System.Data.Linq.Mapping;
6     using System.Data.Linq.Provider;
7     using System.Diagnostics.CodeAnalysis;
8
// Flattens object-shaped expressions in a query tree into rows of columns.
// All of the work is done by the nested Visitor; this class only owns a single
// Visitor instance and exposes it through Flatten.
internal class SqlFlattener {
    Visitor visitor;

    // sql: factory passed through to the visitor (stored there, never read).
    // columnizer: used by the visitor to columnize each selection before it is
    // flattened into its row.
    internal SqlFlattener(SqlFactory sql, SqlColumnizer columnizer) {
        this.visitor = new Visitor(sql, columnizer);
    }

    // Visits 'node' and returns the (possibly rewritten) result.
    // NOTE: the same Visitor instance is reused for every call, so the
    // folded-column map it accumulates is never cleared between calls.
    internal SqlNode Flatten(SqlNode node) {
        node = this.visitor.Visit(node);
        return node;
    }

    // Rewrites SELECT and INSERT nodes so that their selection / expression only
    // refers to columns of their row (folding equivalent columns together), and
    // validates GROUP BY and ORDER BY expressions.
    class Visitor : SqlVisitor {
        [SuppressMessage("Microsoft.Performance", "CA1823:AvoidUnusedPrivateFields", Justification = "Microsoft: part of our standard visitor pattern")]
        SqlFactory sql;
        SqlColumnizer columnizer;
        // Whether the select being finished is the outermost one. VisitSelectCore
        // clears it while visiting a select's children and restores it afterwards;
        // VisitSelect replaces the selection of any select finished while it is
        // false with a SqlNop.
        bool isTopLevel;
        // Folded column -> the row column it was folded into. Written by
        // SelectionFlattener.VisitColumn and read by VisitColumnRef.
        Dictionary<SqlColumn, SqlColumn> map = new Dictionary<SqlColumn,SqlColumn>();

        [SuppressMessage("Microsoft.Performance", "CA1805:DoNotInitializeUnnecessarily", Justification="Unknown reason.")]
        internal Visitor(SqlFactory sql, SqlColumnizer columnizer) {
            this.sql = sql;
            this.columnizer = columnizer;
            this.isTopLevel = true;
        }

        // Redirects a reference to a folded column to the column that replaced it;
        // any other reference is returned unchanged.
        internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
            SqlColumn mapped;
            if (this.map.TryGetValue(cref.Column, out mapped)) {
                return new SqlColumnRef(mapped);
            }
            return cref;
        }

        // Visits the children of 'select' with isTopLevel cleared, restoring the
        // previous value afterwards even if visiting throws.
        internal override SqlSelect VisitSelectCore(SqlSelect select) {
            bool saveIsTopLevel = this.isTopLevel;
            this.isTopLevel = false;
            try {
                return base.VisitSelectCore(select);
            }
            finally {
                this.isTopLevel = saveIsTopLevel;
            }
        }

        // Flattens the selection into select.Row, flattens GROUP BY and validates
        // ORDER BY. A select that is not top-level keeps only a SqlNop selection
        // carrying the original selection's CLR and SQL types.
        internal override SqlSelect VisitSelect(SqlSelect select) {
            select = base.VisitSelect(select);

            select.Selection = this.FlattenSelection(select.Row, false, select.Selection);

            if (select.GroupBy.Count > 0) {
                this.FlattenGroupBy(select.GroupBy);
            }

            if (select.OrderBy.Count > 0) {
                this.FlattenOrderBy(select.OrderBy);
            }

            if (!this.isTopLevel) {
                select.Selection = new SqlNop(select.Selection.ClrType, select.Selection.SqlType, select.SourceExpression);
            }

            return select;
        }

        // Flattens the insert expression into sin.Row. It is flattened as input,
        // so its columns are never merged with row columns merely because their
        // expressions compare equal.
        internal override SqlStatement VisitInsert(SqlInsert sin) {
            base.VisitInsert(sin);
            sin.Expression = this.FlattenSelection(sin.Row, true, sin.Expression);
            return sin;
        }

        // Columnizes 'selection', then rewrites it so that every column it uses
        // belongs to 'row'. Folded columns are recorded in the shared map.
        private SqlExpression FlattenSelection(SqlRow row, bool isInput, SqlExpression selection) {
            selection = this.columnizer.ColumnizeSelection(selection);
            return new SelectionFlattener(row, this.map, isInput).VisitExpression(selection);
        }

        // Rewrites one columnized selection so that it references columns of
        // 'row'. Existing row columns are reused where possible, and columns that
        // get folded into another column are recorded in 'map'.
        class SelectionFlattener : SqlVisitor {
            SqlRow row;
            Dictionary<SqlColumn, SqlColumn> map;
            // True for INSERT expressions: disables matching by equal expression.
            bool isInput;
            // Set once any SqlNew has been visited; never reset afterwards.
            // NOTE(review): siblings visited after a nested SqlNew also see this as
            // true -- confirm that is intended.
            bool isNew;

            internal SelectionFlattener(SqlRow row, Dictionary<SqlColumn, SqlColumn> map, bool isInput) {
                this.row = row;
                this.map = map;
                this.isInput = isInput;
            }

            internal override SqlExpression VisitNew(SqlNew sox) {
                this.isNew = true;
                return base.VisitNew(sox);
            }

            // Replaces 'col' with a reference to a row column. It first looks for a
            // row column that refers to 'col'. For non-input selections it then
            // looks for one with an equal expression; once a SqlNew has been seen,
            // constant columns are excluded from this. If neither is found, 'col'
            // itself is appended to the row.
            internal override SqlExpression VisitColumn(SqlColumn col) {
                SqlColumn c = this.FindColumn(this.row.Columns, col);
                if (c == null && col.Expression != null && !this.isInput && (!this.isNew || !col.Expression.IsConstantColumn)) {
                    c = this.FindColumnWithExpression(this.row.Columns, col.Expression);
                }
                if (c == null) {
                    this.row.Columns.Add(col);
                    c = col;
                }
                else if (c != col) {
                    // preserve expr-sets when folding expressions together.
                    // col.Expression may be null here (c can come from FindColumn
                    // alone, and the lookup above tolerates a null expression), so
                    // check it before reading NodeType.
                    if (col.Expression != null &&
                        col.Expression.NodeType == SqlNodeType.ExprSet &&
                        c.Expression.NodeType != SqlNodeType.ExprSet) {
                        c.Expression = col.Expression;
                    }
                    this.map[col] = c;
                }
                return new SqlColumnRef(c);
            }

            // Re-points the reference at the row column that refers to its column.
            // If there is none, the reference itself becomes the expression of a
            // (possibly new) row column.
            internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
                SqlColumn c = this.FindColumn(this.row.Columns, cref.Column);
                if (c == null) {
                    return MakeFlattenedColumn(cref, null);
                }
                else {
                    return new SqlColumnRef(c);
                }
            }

            // ignore subquery in selection
            internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
                return ss;
            }

            // Client queries are likewise left as-is and not flattened into the row.
            internal override SqlExpression VisitClientQuery(SqlClientQuery cq) {
                return cq;
            }

            // Returns a reference to a row column computing 'expr'. For non-input
            // rows an existing column with an equal expression is reused; otherwise
            // a new column named 'name' wrapping 'expr' is appended to the row.
            private SqlColumnRef MakeFlattenedColumn(SqlExpression expr, string name) {
                SqlColumn c = (!this.isInput) ? this.FindColumnWithExpression(this.row.Columns, expr) : null;
                if (c == null) {
                    c = new SqlColumn(expr.ClrType, expr.SqlType, name, null, expr, expr.SourceExpression);
                    this.row.Columns.Add(c);
                }
                return new SqlColumnRef(c);
            }


            // First column in 'columns' that RefersToColumn reports as referring
            // to 'col', or null if there is none.
            private SqlColumn FindColumn(IEnumerable<SqlColumn> columns, SqlColumn col) {
                foreach (SqlColumn c in columns) {
                    if (this.RefersToColumn(c, col)) {
                        return c;
                    }
                }
                return null;
            }

            // First column in 'columns' that either is 'expr' itself or has an
            // expression SqlComparer considers equal to 'expr'; null if none.
            private SqlColumn FindColumnWithExpression(IEnumerable<SqlColumn> columns, SqlExpression expr) {
                foreach (SqlColumn c in columns) {
                    if (c == expr) {
                        return c;
                    }
                    if (SqlComparer.AreEqual(c.Expression, expr)) {
                        return c;
                    }
                }
                return null;
            }
        }

        // Replaces the GROUP BY list in place with the flattened leaf expressions
        // of each entry. Sequence-typed grouping expressions are rejected.
        private void FlattenGroupBy(List<SqlExpression> exprs) {
            List<SqlExpression> list = new List<SqlExpression>(exprs.Count);
            foreach (SqlExpression gex in exprs) {
                if (TypeSystem.IsSequenceType(gex.ClrType)) {
                    throw Error.InvalidGroupByExpressionType(gex.ClrType.Name);
                }
                this.FlattenGroupByExpression(list, gex);
            }
            exprs.Clear();
            exprs.AddRange(list);
        }

        // Recursively decomposes 'expr' and appends its leaves to 'exprs'. The
        // expression kinds handled are: new (members, then args), type case
        // (discriminator, then each when's binding), link (its expansion, or its
        // keys when there is no expansion), optional value (has-value, then
        // value), outer-joined value (its operand) and discriminated type (its
        // discriminator). Any other leaf must be a column reference or an
        // expr-set; anything else throws.
        private void FlattenGroupByExpression(List<SqlExpression> exprs, SqlExpression expr) {
            SqlNew sn = expr as SqlNew;
            if (sn != null) {
                foreach (SqlMemberAssign ma in sn.Members) {
                    this.FlattenGroupByExpression(exprs, ma.Expression);
                }
                foreach (SqlExpression arg in sn.Args) {
                    this.FlattenGroupByExpression(exprs, arg);
                }
            }
            else if (expr.NodeType == SqlNodeType.TypeCase) {
                SqlTypeCase tc = (SqlTypeCase)expr;
                this.FlattenGroupByExpression(exprs, tc.Discriminator);
                foreach (SqlTypeCaseWhen when in tc.Whens) {
                    this.FlattenGroupByExpression(exprs, when.TypeBinding);
                }
            }
            else if (expr.NodeType == SqlNodeType.Link) {
                SqlLink link = (SqlLink)expr;
                if (link.Expansion != null) {
                    this.FlattenGroupByExpression(exprs, link.Expansion);
                }
                else {
                    foreach (SqlExpression key in link.KeyExpressions) {
                        this.FlattenGroupByExpression(exprs, key);
                    }
                }
            }
            else if (expr.NodeType == SqlNodeType.OptionalValue) {
                SqlOptionalValue sop = (SqlOptionalValue)expr;
                this.FlattenGroupByExpression(exprs, sop.HasValue);
                this.FlattenGroupByExpression(exprs, sop.Value);
            }
            else if (expr.NodeType == SqlNodeType.OuterJoinedValue) {
                this.FlattenGroupByExpression(exprs, ((SqlUnary)expr).Operand);
            }
            else if (expr.NodeType == SqlNodeType.DiscriminatedType) {
                SqlDiscriminatedType dt = (SqlDiscriminatedType)expr;
                this.FlattenGroupByExpression(exprs, dt.Discriminator);
            }
            else {
                // this expression should have been 'pushed-down' in SqlBinder, so we
                // should only find column-references & expr-sets unless the expression could not
                // be columnized (in which case it was a bad group-by expression.)
                if (expr.NodeType != SqlNodeType.ColumnRef &&
                    expr.NodeType != SqlNodeType.ExprSet) {
                    if (!expr.SqlType.CanBeColumn) {
                        throw Error.InvalidGroupByExpressionType(expr.SqlType.ToQueryString());
                    }
                    throw Error.InvalidGroupByExpression();
                }
                exprs.Add(expr);
            }
        }

        // Validates (does not rewrite) ORDER BY expressions. It throws for any
        // expression whose SQL type is not orderable. The error names the SQL type
        // when it can be a column, and otherwise the CLR type name.
        [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
        private void FlattenOrderBy(List<SqlOrderExpression> exprs) {
            foreach (SqlOrderExpression obex in exprs) {
                if (!obex.Expression.SqlType.IsOrderable) {
                    if (obex.Expression.SqlType.CanBeColumn) {
                        throw Error.InvalidOrderByExpression(obex.Expression.SqlType.ToQueryString());
                    }
                    else {
                        throw Error.InvalidOrderByExpression(obex.Expression.ClrType.Name);
                    }
                }
            }
        }
    }
}
255 }