2009-07-11 Michael Barker <mike@middlesoft.co.uk>
[mono.git] / mcs / class / System.Data.Linq / src / DbLinq / Data / Linq / Sugar / Implementation / SqlBuilder.cs
1 #region MIT license\r
2 // \r
3 // MIT license\r
4 //\r
5 // Copyright (c) 2007-2008 Jiri Moudry, Pascal Craponne\r
6 // \r
7 // Permission is hereby granted, free of charge, to any person obtaining a copy\r
8 // of this software and associated documentation files (the "Software"), to deal\r
9 // in the Software without restriction, including without limitation the rights\r
10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r
11 // copies of the Software, and to permit persons to whom the Software is\r
12 // furnished to do so, subject to the following conditions:\r
13 // \r
14 // The above copyright notice and this permission notice shall be included in\r
15 // all copies or substantial portions of the Software.\r
16 // \r
17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r
18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r
19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r
20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r
21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r
22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\r
23 // THE SOFTWARE.\r
24 // \r
25 #endregion\r
26 \r
27 using System;\r
28 using System.Collections.Generic;\r
29 using System.Linq;\r
30 using System.Linq.Expressions;\r
31 \r
32 using DbLinq.Data.Linq.Sql;\r
33 using DbLinq.Data.Linq.Sugar.ExpressionMutator;\r
34 using DbLinq.Data.Linq.Sugar.Expressions;\r
35 \r
36 using DbLinq.Factory;\r
37 using DbLinq.Util;\r
38 \r
39 namespace DbLinq.Data.Linq.Sugar.Implementation\r
40 {\r
41     internal class SqlBuilder : ISqlBuilder\r
42     {\r
43         public IExpressionQualifier ExpressionQualifier { get; set; }\r
44 \r
45         public SqlBuilder()\r
46         {\r
47             ExpressionQualifier = ObjectFactory.Get<IExpressionQualifier>();\r
48         }\r
49 \r
50         /// <summary>\r
51         /// Builds a SQL string, based on a QueryContext\r
52         /// The build indirectly depends on ISqlProvider which provides all SQL Parts.\r
53         /// </summary>\r
54         /// <param name="expressionQuery"></param>\r
55         /// <param name="queryContext"></param>\r
56         /// <returns></returns>\r
57         public SqlStatement BuildSelect(ExpressionQuery expressionQuery, QueryContext queryContext)\r
58         {\r
59             return Build(expressionQuery.Select, queryContext);\r
60         }\r
61 \r
62         /// <summary>\r
63         /// Returns a list of sorted tables, given a select expression.\r
64         /// The tables are sorted by dependency: independent tables first, dependent tables next\r
65         /// </summary>\r
66         /// <param name="selectExpression"></param>\r
67         /// <returns></returns>\r
68         protected IList<TableExpression> GetSortedTables(SelectExpression selectExpression)\r
69         {\r
70             var tables = new List<TableExpression>();\r
71             foreach (var table in selectExpression.Tables)\r
72             {\r
73                 // the rules are:\r
74                 // a table climbs up to 0 until we find the table it depends on\r
75                 // we keep the index and insert on it\r
76                 // we place joining tables under joined tables\r
77                 int tableIndex;\r
78                 for (tableIndex = tables.Count; tableIndex > 0; tableIndex--)\r
79                 {\r
80                     // above us, the joined table? Stop now\r
81                     if (tables[tableIndex - 1] == table.JoinedTable)\r
82                         break;\r
83                     // if the current table is joining and we have a non-joining table above, we stop here too\r
84                     if (table.JoinExpression != null && tables[tableIndex - 1].JoinExpression == null)\r
85                         break;\r
86                 }\r
87                 tables.Insert(tableIndex, table);\r
88             }\r
89             return tables;\r
90         }\r
91 \r
92         /// <summary>\r
93         /// Main SQL builder\r
94         /// </summary>\r
95         /// <param name="selectExpression"></param>\r
96         /// <param name="queryContext"></param>\r
97         /// <returns></returns>\r
98         public SqlStatement Build(SelectExpression selectExpression, QueryContext queryContext)\r
99         {\r
100             var translator = GetTranslator(queryContext.DataContext.Vendor.SqlProvider);\r
101             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
102             selectExpression = translator.OuterExpression(selectExpression);\r
103 \r
104             // A scope usually has:\r
105             // - a SELECT: the operation creating a CLR object with data coming from SQL tier\r
106             // - a FROM: list of tables\r
107             // - a WHERE: list of conditions\r
108             // - a GROUP BY: grouping by selected columns\r
109             // - a ORDER BY: sort\r
110             var select = BuildSelect(selectExpression, queryContext);\r
111             if (select.ToString() == string.Empty)\r
112             {\r
113                 SubSelectExpression subselect = null;\r
114                 if (selectExpression.Tables.Count == 1)\r
115                     subselect = selectExpression.Tables[0] as SubSelectExpression;\r
116                 if(subselect != null)\r
117                     return sqlProvider.GetParenthesis(Build(subselect.Select, queryContext));\r
118             }\r
119 \r
120             // TODO: the following might be wrong (at least this might be the wrong place to do this\r
121             if (select.ToString() == string.Empty)\r
122                 select = new SqlStatement("SELECT " + sqlProvider.GetLiteral(null) + " AS " + sqlProvider.GetSafeName("Empty"));\r
123 \r
124             var tables = GetSortedTables(selectExpression);\r
125             var from = BuildFrom(tables, queryContext);\r
126             var join = BuildJoin(tables, queryContext);\r
127             var where = BuildWhere(tables, selectExpression.Where, queryContext);\r
128             var groupBy = BuildGroupBy(selectExpression.Group, queryContext);\r
129             var having = BuildHaving(selectExpression.Where, queryContext);\r
130             var orderBy = BuildOrderBy(selectExpression.OrderBy, queryContext);\r
131             select = Join(queryContext, select, from, join, where, groupBy, having, orderBy);\r
132             select = BuildLimit(selectExpression, select, queryContext);\r
133 \r
134             if (selectExpression.NextSelectExpression != null)\r
135             {\r
136                 var nextLiteralSelect = Build(selectExpression.NextSelectExpression, queryContext);\r
137                 select = queryContext.DataContext.Vendor.SqlProvider.GetLiteral(\r
138                     selectExpression.NextSelectExpressionOperator,\r
139                     select, nextLiteralSelect);\r
140             }\r
141 \r
142             return select;\r
143         }\r
144 \r
145         public SqlStatement Join(QueryContext queryContext, params SqlStatement[] clauses)\r
146         {\r
147             return SqlStatement.Join(queryContext.DataContext.Vendor.SqlProvider.NewLine,\r
148                                (from clause in clauses where clause.ToString() != string.Empty select clause).ToList());\r
149         }\r
150 \r
151         /// <summary>\r
152         /// The simple part: converts an expression to SQL\r
153         /// This is not used for FROM clause\r
154         /// </summary>\r
155         /// <param name="expression"></param>\r
156         /// <param name="queryContext"></param>\r
157         /// <returns></returns>\r
158         protected virtual SqlStatement BuildExpression(Expression expression, QueryContext queryContext)\r
159         {\r
160             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
161             var currentPrecedence = ExpressionQualifier.GetPrecedence(expression);\r
162             // first convert operands\r
163             var operands = expression.GetOperands();\r
164             var literalOperands = new List<SqlStatement>();\r
165             foreach (var operand in operands)\r
166             {\r
167                 var operandPrecedence = ExpressionQualifier.GetPrecedence(operand);\r
168                 var literalOperand = BuildExpression(operand, queryContext);\r
169                 if (operandPrecedence > currentPrecedence)\r
170                     literalOperand = sqlProvider.GetParenthesis(literalOperand);\r
171                 literalOperands.Add(literalOperand);\r
172             }\r
173 \r
174             // then converts expression\r
175             if (expression is SpecialExpression)\r
176                 return sqlProvider.GetLiteral(((SpecialExpression)expression).SpecialNodeType, literalOperands);\r
177             if (expression is TableExpression)\r
178             {\r
179                 var tableExpression = (TableExpression)expression;\r
180                 if (tableExpression.Alias != null) // if we have an alias, use it\r
181                 {\r
182                     return sqlProvider.GetColumn(sqlProvider.GetTableAlias(tableExpression.Alias),\r
183                                                  sqlProvider.GetColumns());\r
184                 }\r
185                 return sqlProvider.GetColumns();\r
186             }\r
187             if (expression is ColumnExpression)\r
188             {\r
189                 var columnExpression = (ColumnExpression)expression;\r
190                 if (columnExpression.Table.Alias != null)\r
191                 {\r
192                     return sqlProvider.GetColumn(sqlProvider.GetTableAlias(columnExpression.Table.Alias),\r
193                                                  columnExpression.Name);\r
194                 }\r
195                 return sqlProvider.GetColumn(columnExpression.Name);\r
196             }\r
197             if (expression is InputParameterExpression)\r
198             {\r
199                 var inputParameterExpression = (InputParameterExpression)expression;\r
200                 if (expression.Type.IsArray)\r
201                 {\r
202                     int i = 0;\r
203                     List<SqlStatement> inputParameters = new List<SqlStatement>();\r
204                     foreach (object p in (Array)inputParameterExpression.GetValue())\r
205                     {\r
206                         inputParameters.Add(new SqlStatement(new SqlParameterPart(sqlProvider.GetParameterName(inputParameterExpression.Alias + i.ToString()),\r
207                                                           inputParameterExpression.Alias + i.ToString())));\r
208                         ++i;\r
209                     }\r
210                     return new SqlStatement(sqlProvider.GetLiteral(inputParameters.ToArray()));\r
211                 }\r
212                 return\r
213                     new SqlStatement(new SqlParameterPart(sqlProvider.GetParameterName(inputParameterExpression.Alias),\r
214                                                           inputParameterExpression.Alias));\r
215             }\r
216             if (expression is SelectExpression)\r
217                 return Build((SelectExpression)expression, queryContext);\r
218             if (expression is ConstantExpression)\r
219                 return sqlProvider.GetLiteral(((ConstantExpression)expression).Value);\r
220             if (expression is GroupExpression)\r
221                 return BuildExpression(((GroupExpression)expression).GroupedExpression, queryContext);\r
222 \r
223             StartIndexOffsetExpression indexExpression = expression as StartIndexOffsetExpression;\r
224             if (indexExpression!=null)\r
225             {\r
226                 if (indexExpression.StartsAtOne)\r
227                 {\r
228                     literalOperands.Add(BuildExpression(Expression.Constant(1), queryContext));\r
229                     return sqlProvider.GetLiteral(ExpressionType.Add, literalOperands);\r
230                 }\r
231                 else\r
232                     return literalOperands.First();\r
233             }\r
234             if (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked)\r
235             {\r
236                 var unaryExpression = (UnaryExpression)expression;\r
237                 var firstOperand = literalOperands.First();\r
238                 if (IsConversionRequired(unaryExpression))\r
239                     return sqlProvider.GetLiteralConvert(firstOperand, unaryExpression.Type);\r
240                 return firstOperand;\r
241             }\r
242             return sqlProvider.GetLiteral(expression.NodeType, literalOperands);\r
243         }\r
244 \r
245         private Expressions.ExpressionTranslator GetTranslator(DbLinq.Vendor.ISqlProvider provider)\r
246         {\r
247             var p = provider as DbLinq.Vendor.Implementation.SqlProvider;\r
248             if (p != null)\r
249                 return p.GetTranslator();\r
250             return new ExpressionTranslator();\r
251         }\r
252 \r
253         /// <summary>\r
254         /// Determines if a SQL conversion is required\r
255         /// </summary>\r
256         /// <param name="expression"></param>\r
257         /// <returns></returns>\r
258         private bool IsConversionRequired(UnaryExpression expression)\r
259         {\r
260             // obvious (and probably never happens), conversion to the same type\r
261             if (expression.Type == expression.Operand.Type)\r
262                 return false;\r
263             // second, nullable to non-nullable for the same type\r
264             if (expression.Type.IsNullable() && !expression.Operand.Type.IsNullable())\r
265             {\r
266                 if (expression.Type.GetNullableType() == expression.Operand.Type)\r
267                     return false;\r
268             }\r
269             // third, non-nullable to nullable\r
270             if (!expression.Type.IsNullable() && expression.Operand.Type.IsNullable())\r
271             {\r
272                 if (expression.Type == expression.Operand.Type.GetNullableType())\r
273                     return false;\r
274             }\r
275             // found no excuse not to convert? then convert\r
276             return true;\r
277         }\r
278 \r
279         protected virtual bool MustDeclareAsJoin(IList<TableExpression> tables, TableExpression table)\r
280         {\r
281             // the first table can not be declared as join\r
282             if (table == tables[0])\r
283                 return false;\r
284             // we must declare as join, whatever the join is,\r
285             // if some of the registered tables are registered as complex join\r
286             if (tables.Any(t => t.JoinType != TableJoinType.Inner))\r
287                 return table.JoinExpression != null;\r
288             return false;\r
289         }\r
290 \r
291         protected virtual SqlStatement BuildFrom(IList<TableExpression> tables, QueryContext queryContext)\r
292         {\r
293             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
294             var fromClauses = new List<SqlStatement>();\r
295             foreach (var tableExpression in tables)\r
296             {\r
297                 if (!MustDeclareAsJoin(tables, tableExpression))\r
298                 {\r
299                     if (tableExpression.Alias != null)\r
300                     {\r
301                         string tableAlias;\r
302 \r
303                         // All subqueries has an alias in FROM\r
304                         SubSelectExpression subquery = tableExpression as SubSelectExpression;\r
305                         if (subquery == null)\r
306                             tableAlias = sqlProvider.GetTableAsAlias(tableExpression.Name, tableExpression.Alias);\r
307                         else\r
308                         {\r
309                             var subqueryStatements = new SqlStatement(Build(subquery.Select, queryContext));\r
310                             tableAlias = sqlProvider.GetSubQueryAsAlias(subqueryStatements.ToString(), tableExpression.Alias);\r
311                         }\r
312 \r
313                         if ((tableExpression.JoinType & TableJoinType.LeftOuter) != 0)\r
314                             tableAlias = "/* LEFT OUTER */ " + tableAlias;\r
315                         if ((tableExpression.JoinType & TableJoinType.RightOuter) != 0)\r
316                             tableAlias = "/* RIGHT OUTER */ " + tableAlias;\r
317                         fromClauses.Add(tableAlias);\r
318                     }\r
319                     else\r
320                     {\r
321                         fromClauses.Add(sqlProvider.GetTable(tableExpression.Name));\r
322                     }\r
323                 }\r
324             }\r
325             return sqlProvider.GetFromClause(fromClauses.ToArray());\r
326         }\r
327 \r
328         /// <summary>\r
329         /// Builds join clauses\r
330         /// </summary>\r
331         /// <param name="tables"></param>\r
332         /// <param name="queryContext"></param>\r
333         /// <returns></returns>\r
334         protected virtual SqlStatement BuildJoin(IList<TableExpression> tables, QueryContext queryContext)\r
335         {\r
336             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
337             var joinClauses = new List<SqlStatement>();\r
338             foreach (var tableExpression in tables)\r
339             {\r
340                 // this is the pending declaration of direct tables\r
341                 if (MustDeclareAsJoin(tables, tableExpression))\r
342                 {\r
343                     // get constitutive Parts\r
344                     var joinExpression = BuildExpression(tableExpression.JoinExpression, queryContext);\r
345                     var tableAlias = sqlProvider.GetTableAsAlias(tableExpression.Name, tableExpression.Alias);\r
346                     SqlStatement joinClause;\r
347                     switch (tableExpression.JoinType)\r
348                     {\r
349                         case TableJoinType.Inner:\r
350                             joinClause = sqlProvider.GetInnerJoinClause(tableAlias, joinExpression);\r
351                             break;\r
352                         case TableJoinType.LeftOuter:\r
353                             joinClause = sqlProvider.GetLeftOuterJoinClause(tableAlias, joinExpression);\r
354                             break;\r
355                         case TableJoinType.RightOuter:\r
356                             joinClause = sqlProvider.GetRightOuterJoinClause(tableAlias, joinExpression);\r
357                             break;\r
358                         case TableJoinType.FullOuter:\r
359                             throw new NotImplementedException();\r
360                         default:\r
361                             throw new ArgumentOutOfRangeException();\r
362                     }\r
363                     joinClauses.Add(joinClause);\r
364                 }\r
365             }\r
366             return sqlProvider.GetJoinClauses(joinClauses.ToArray());\r
367         }\r
368 \r
369         protected virtual bool IsHavingClause(Expression expression)\r
370         {\r
371             bool isHaving = false;\r
372             expression.Recurse(delegate(Expression e)\r
373                                    {\r
374                                        if (e is GroupExpression)\r
375                                            isHaving = true;\r
376                                        return e;\r
377                                    });\r
378             return isHaving;\r
379         }\r
380 \r
381         protected virtual SqlStatement BuildWhere(IList<TableExpression> tables, IList<Expression> wheres, QueryContext queryContext)\r
382         {\r
383             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
384             var whereClauses = new List<SqlStatement>();\r
385             foreach (var tableExpression in tables)\r
386             {\r
387                 if (!MustDeclareAsJoin(tables, tableExpression) && tableExpression.JoinExpression != null)\r
388                     whereClauses.Add(BuildExpression(tableExpression.JoinExpression, queryContext));\r
389             }\r
390             foreach (var whereExpression in wheres)\r
391             {\r
392                 if (!IsHavingClause(whereExpression))\r
393                     whereClauses.Add(BuildExpression(whereExpression, queryContext));\r
394             }\r
395             return sqlProvider.GetWhereClause(whereClauses.ToArray());\r
396         }\r
397 \r
398         protected virtual SqlStatement BuildHaving(IList<Expression> wheres, QueryContext queryContext)\r
399         {\r
400             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
401             var havingClauses = new List<SqlStatement>();\r
402             foreach (var whereExpression in wheres)\r
403             {\r
404                 if (IsHavingClause(whereExpression))\r
405                     havingClauses.Add(BuildExpression(whereExpression, queryContext));\r
406             }\r
407             return sqlProvider.GetHavingClause(havingClauses.ToArray());\r
408         }\r
409 \r
410         protected virtual SqlStatement GetGroupByClause(ColumnExpression columnExpression, QueryContext queryContext)\r
411         {\r
412             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
413             if (columnExpression.Table.Alias != null)\r
414             {\r
415                 return sqlProvider.GetColumn(sqlProvider.GetTableAlias(columnExpression.Table.Alias),\r
416                                              columnExpression.Name);\r
417             }\r
418             return sqlProvider.GetColumn(columnExpression.Name);\r
419         }\r
420 \r
421         protected virtual SqlStatement BuildGroupBy(IList<GroupExpression> groupByExpressions, QueryContext queryContext)\r
422         {\r
423             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
424             var groupByClauses = new List<SqlStatement>();\r
425             foreach (var groupByExpression in groupByExpressions)\r
426             {\r
427                 foreach (var operand in groupByExpression.Clauses)\r
428                 {\r
429                     var columnOperand = operand as ColumnExpression;\r
430                     if (columnOperand == null)\r
431                         throw Error.BadArgument("S0201: Groupby argument must be a ColumnExpression");\r
432                     groupByClauses.Add(GetGroupByClause(columnOperand, queryContext));\r
433                 }\r
434             }\r
435             return sqlProvider.GetGroupByClause(groupByClauses.ToArray());\r
436         }\r
437 \r
438         protected virtual SqlStatement BuildOrderBy(IList<OrderByExpression> orderByExpressions, QueryContext queryContext)\r
439         {\r
440             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
441             var orderByClauses = new List<SqlStatement>();\r
442             foreach (var clause in orderByExpressions)\r
443             {\r
444                 orderByClauses.Add(sqlProvider.GetOrderByColumn(BuildExpression(clause.ColumnExpression, queryContext),\r
445                                                                 clause.Descending));\r
446             }\r
447             return sqlProvider.GetOrderByClause(orderByClauses.ToArray());\r
448         }\r
449 \r
450         protected virtual SqlStatement BuildSelect(Expression select, QueryContext queryContext)\r
451         {\r
452             var sqlProvider = queryContext.DataContext.Vendor.SqlProvider;\r
453             var selectClauses = new List<SqlStatement>();\r
454             foreach (var selectExpression in select.GetOperands())\r
455             {\r
456                 var expressionString = BuildExpression(selectExpression, queryContext);\r
457                 if (selectExpression is SelectExpression)\r
458                     selectClauses.Add(sqlProvider.GetParenthesis(expressionString));\r
459                 else\r
460                     selectClauses.Add(expressionString);\r
461             }\r
462             SelectExpression selectExp = select as SelectExpression;\r
463             if (selectExp != null)\r
464             {\r
465                 if (selectExp.Group.Count == 1 && selectExp.Group[0].GroupedExpression == selectExp.Group[0].KeyExpression)\r
466                 {\r
467                     // this is a select DISTINCT expression\r
468                     // TODO: better handle selected columns on DISTINCT: I suspect this will not work in some cases\r
469                     if (selectClauses.Count == 0)\r
470                     {\r
471                         selectClauses.Add(sqlProvider.GetColumns());\r
472                     }\r
473                     return sqlProvider.GetSelectDistinctClause(selectClauses.ToArray());\r
474                 }\r
475             }\r
476             return sqlProvider.GetSelectClause(selectClauses.ToArray());\r
477         }\r
478 \r
479         protected virtual SqlStatement BuildLimit(SelectExpression select, SqlStatement literalSelect, QueryContext queryContext)\r
480         {\r
481             if (select.Limit != null)\r
482             {\r
483                 var literalLimit = BuildExpression(select.Limit, queryContext);\r
484                 if (select.Offset != null)\r
485                 {\r
486                     var literalOffset = BuildExpression(select.Offset, queryContext);\r
487                     var literalOffsetAndLimit = BuildExpression(select.OffsetAndLimit, queryContext);\r
488                     return queryContext.DataContext.Vendor.SqlProvider.GetLiteralLimit(literalSelect, literalLimit,\r
489                                                                                        literalOffset,\r
490                                                                                        literalOffsetAndLimit);\r
491                 }\r
492                 return queryContext.DataContext.Vendor.SqlProvider.GetLiteralLimit(literalSelect, literalLimit);\r
493             }\r
494             return literalSelect;\r
495         }\r
496     }\r
497 }