2010-03-06 Rodrigo Kumpera <rkumpera@novell.com>
[mono.git] / mcs / class / System.Data / Mono.Data.SqlExpressions / Aggregation.cs
1 //
2 // Aggregation.cs
3 //
4 // Author:
5 //   Juraj Skripsky (juraj@hotfeet.ch)
6 //
7 // (C) 2004 HotFeet GmbH (http://www.hotfeet.ch)
8 //
9
10 //
11 // Copyright (C) 2004 Novell, Inc (http://www.novell.com)
12 //
13 // Permission is hereby granted, free of charge, to any person obtaining
14 // a copy of this software and associated documentation files (the
15 // "Software"), to deal in the Software without restriction, including
16 // without limitation the rights to use, copy, modify, merge, publish,
17 // distribute, sublicense, and/or sell copies of the Software, and to
18 // permit persons to whom the Software is furnished to do so, subject to
19 // the following conditions:
20 // 
21 // The above copyright notice and this permission notice shall be
22 // included in all copies or substantial portions of the Software.
23 // 
24 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
28 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
29 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
30 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 //
32
33 using System;
34 using System.Collections;
35 using System.Data;
36
37 namespace Mono.Data.SqlExpressions {
38         internal enum AggregationFunction {
39                 Count, Sum, Min, Max, Avg, StDev, Var
40         }
41
42         internal class Aggregation : BaseExpression {
43                 bool cacheResults;
44                 DataRow[] rows;
45                 ColumnReference column;
46                 AggregationFunction function;
47                 int count;
48                 IConvertible result;
49                 DataRowChangeEventHandler RowChangeHandler;
50                 DataTable table ;
51
52                 public Aggregation (bool cacheResults, DataRow[] rows, AggregationFunction function, ColumnReference column)
53                 {
54                         this.cacheResults = cacheResults;
55                         this.rows = rows;
56                         this.column = column;
57                         this.function = function;
58                         this.result = null;
59                         if (cacheResults)
60                                 RowChangeHandler = new DataRowChangeEventHandler (InvalidateCache);
61                 }
62
63                 public override bool Equals(object obj)
64                 {
65                         if (!base.Equals (obj))
66                                 return false;
67
68                         if (!(obj is Aggregation))
69                                 return false;
70
71                         Aggregation other = (Aggregation) obj;
72                         if (!other.function.Equals( function))
73                                 return false;
74
75                         if (!other.column.Equals (column))
76                                 return false;           
77
78                         if (other.rows != null && rows != null) {
79                         if (other.rows.Length != rows.Length)
80                                 return false;
81
82                         for (int i=0; i < rows.Length; i++)
83                                 if (other.rows [i] != rows [i])
84                                         return false;
85
86                         }
87                         else if (!(other.rows == null && rows == null))
88                                 return false;
89                 
90                         return true;
91                 }
92
93                 public override int GetHashCode()
94                 {
95                         int hashCode = base.GetHashCode ();
96                         hashCode ^= function.GetHashCode ();
97                         hashCode ^= column.GetHashCode ();
98                         for (int i=0; i < rows.Length; i++)
99                                 hashCode ^= rows [i].GetHashCode ();
100                         
101                         return hashCode;
102                 }
103                 
104         
105                 public override object Eval (DataRow row)
106                 {
107                         //TODO: implement a better caching strategy and a mechanism for cache invalidation.
108                         //for now only aggregation over the table owning 'row' (e.g. 'sum(parts)'
109                         //in constrast to 'sum(child.parts)') is cached.
110                         if (cacheResults && result != null && column.ReferencedTable == ReferencedTable.Self)
111                                 return result;
112                                 
113                         count = 0;
114                         result = null;
115                         
116                         object[] values;
117                         if (rows == null)
118                                 values = column.GetValues (column.GetReferencedRows (row));
119                         else
120                                 values = column.GetValues (rows);
121                         
122                         foreach (object val in values) {
123                                 if (val == null)
124                                         continue;
125                                         
126                                 count++;
127                                 Aggregate ((IConvertible)val);
128                         }
129
130                         switch (function) {
131                         case AggregationFunction.StDev:
132                         case AggregationFunction.Var:
133                                 result = CalcStatisticalFunction (values);
134                                 break;
135                                         
136                         case AggregationFunction.Avg:
137                                 result = ((count == 0) ? DBNull.Value : Numeric.Divide (result, count));
138                                 break;
139                         
140                         case AggregationFunction.Count:
141                                 result = count;
142                                 break;
143                         }
144                         
145                         if (result == null)
146                                 result = DBNull.Value;
147                         
148                         if (cacheResults && column.ReferencedTable == ReferencedTable.Self) 
149                         {
150                                 table = row.Table;
151                                 row.Table.RowChanged += RowChangeHandler;
152                         }       
153                         return result;
154                 }
155
156                 override public bool DependsOn(DataColumn other)
157                 {
158                         return column.DependsOn(other);
159                 }
160                 
161                 private void Aggregate (IConvertible val)
162                 {
163                         switch (function) {
164                         case AggregationFunction.Min:
165                                 result = (result != null ? Numeric.Min (result, val) : val);
166                                 return;
167                         
168                         case AggregationFunction.Max:
169                                 result = (result != null ? Numeric.Max (result, val) : val);
170                                 return;
171
172                         case AggregationFunction.Sum:
173                         case AggregationFunction.Avg:
174                         case AggregationFunction.StDev:
175                         case AggregationFunction.Var:
176                                 result = (result != null ? Numeric.Add (result, val) : val);
177                                 return;
178                         }
179                 }
180                 
181                 private IConvertible CalcStatisticalFunction (object[] values)
182                 {
183                         if (count < 2)
184                                 return DBNull.Value;
185
186                         double average = (double)Convert.ChangeType(result, TypeCode.Double) / count;
187                         double res = 0.0;
188                                                 
189                         foreach (object val in values) {
190                                 if (val == null)
191                                         continue;
192                                         
193                                 double diff = average - (double)Convert.ChangeType(val, TypeCode.Double);
194                                 res += System.Math.Pow (diff, 2);
195                         }
196                         res /= (count - 1);
197                         
198                         if (function == AggregationFunction.StDev)
199                                 res = System.Math.Sqrt (res);
200
201                         return res;
202                 }
203
204                 public override void ResetExpression ()
205                 {
206                         if (table != null)
207                                 InvalidateCache (table, null);
208                 }
209
210                 private void InvalidateCache (Object sender, DataRowChangeEventArgs args)
211                 {
212                         result = null; 
213                         ((DataTable)sender).RowChanged -= RowChangeHandler;
214                 }
215         }
216 }