merge -r 60439:60440
[mono.git] / mcs / class / System.Data / Mono.Data.SqlExpressions / Aggregation.cs
index ab6d0028d91b1e64ea044a6b3c50aaeb422aae5d..5f27a849f2e1cf2e1896dce0644bc1948bb164ec 100644 (file)
@@ -46,7 +46,9 @@ namespace Mono.Data.SqlExpressions {
                AggregationFunction function;
                int count;
                IConvertible result;
-       
+               DataRowChangeEventHandler RowChangeHandler;
+               DataTable table ;
+
                public Aggregation (bool cacheResults, DataRow[] rows, AggregationFunction function, ColumnReference column)
                {
                        this.cacheResults = cacheResults;
@@ -54,13 +56,57 @@ namespace Mono.Data.SqlExpressions {
                        this.column = column;
                        this.function = function;
                        this.result = null;
+                       if (cacheResults)
+                               RowChangeHandler = new DataRowChangeEventHandler (InvalidateCache);
+               }
+
+               public override bool Equals(object obj)
+               {
+                       if (!base.Equals (obj))
+                               return false;
+
+                       if (!(obj is Aggregation))
+                               return false;
+
+                       Aggregation other = (Aggregation) obj;
+                       if (!other.function.Equals( function))
+                               return false;
+
+                       if (!other.column.Equals (column))
+                               return false;           
+
+                       if (other.rows != null && rows != null) {
+                       if (other.rows.Length != rows.Length)
+                               return false;
+
+                       for (int i=0; i < rows.Length; i++)
+                               if (other.rows [i] != rows [i])
+                                       return false;
+
+                       }
+                       else if (!(other.rows == null && rows == null))
+                               return false;
+               
+                       return true;
                }
+
+               public override int GetHashCode()
+               {
+                       int hashCode = base.GetHashCode ();
+                       hashCode ^= function.GetHashCode ();
+                       hashCode ^= column.GetHashCode ();
+                       for (int i=0; i < rows.Length; i++)
+                               hashCode ^= rows [i].GetHashCode ();
+                       
+                       return hashCode;
+               }
+               
        
                public override object Eval (DataRow row)
                {
                        //TODO: implement a better caching strategy and a mechanism for cache invalidation.
                        //for now only aggregation over the table owning 'row' (e.g. 'sum(parts)'
-                       //in constrast to 'sum(parent.parts)' and 'sum(child.parts)') is cached.
+                       //in constrast to 'sum(child.parts)') is cached.
                        if (cacheResults && result != null && column.ReferencedTable == ReferencedTable.Self)
                                return result;
                                
@@ -80,7 +126,7 @@ namespace Mono.Data.SqlExpressions {
                                count++;
                                Aggregate ((IConvertible)val);
                        }
-                       
+
                        switch (function) {
                        case AggregationFunction.StDev:
                        case AggregationFunction.Var:
@@ -88,7 +134,7 @@ namespace Mono.Data.SqlExpressions {
                                break;
                                        
                        case AggregationFunction.Avg:
-                               result = Numeric.Divide (result, count);
+                               result = ((count == 0) ? DBNull.Value : Numeric.Divide (result, count));
                                break;
                        
                        case AggregationFunction.Count:
@@ -98,7 +144,12 @@ namespace Mono.Data.SqlExpressions {
                        
                        if (result == null)
                                result = DBNull.Value;
-                               
+                       
+                       if (cacheResults && column.ReferencedTable == ReferencedTable.Self) 
+                       {
+                               table = row.Table;
+                               row.Table.RowChanged += RowChangeHandler;
+                       }       
                        return result;
                }
 
@@ -129,6 +180,9 @@ namespace Mono.Data.SqlExpressions {
                
                private IConvertible CalcStatisticalFunction (object[] values)
                {
+                       if (count < 2)
+                               return DBNull.Value;
+
                        double average = (double)Convert.ChangeType(result, TypeCode.Double) / count;
                        double res = 0.0;
                                                
@@ -143,8 +197,20 @@ namespace Mono.Data.SqlExpressions {
                        
                        if (function == AggregationFunction.StDev)
                                res = System.Math.Sqrt (res);
-                       
+
                        return res;
                }
+
+               public override void ResetExpression ()
+               {
+                       if (table != null)
+                               InvalidateCache (table, null);
+               }
+
+               private void InvalidateCache (Object sender, DataRowChangeEventArgs args)
+               {
+                       result = null; 
+                       ((DataTable)sender).RowChanged -= RowChangeHandler;
+               }
        }
 }