X-Git-Url: http://wien.tomnetworks.com/gitweb/?a=blobdiff_plain;f=mcs%2Fclass%2FSystem.Core%2FSystem.Linq%2FEnumerable.cs;h=d2c98d0a1954f16a877a18c2c856d759747d2563;hb=cc2c56a916645e45991c43f86409966218a6a177;hp=57c7597917154830b3c9ca422420d2f542063a1e;hpb=8d5f4968f6e707aaea81b04ef6a01d6015789c6c;p=mono.git diff --git a/mcs/class/System.Core/System.Linq/Enumerable.cs b/mcs/class/System.Core/System.Linq/Enumerable.cs index 57c75979171..d2c98d0a195 100644 --- a/mcs/class/System.Core/System.Linq/Enumerable.cs +++ b/mcs/class/System.Core/System.Linq/Enumerable.cs @@ -49,6 +49,10 @@ namespace System.Linq public static readonly Func Always = (t) => true; } + class Function { + public static readonly Func Identity = (t) => t; + } + #region Aggregate public static TSource Aggregate (this IEnumerable source, Func func) @@ -116,6 +120,10 @@ namespace System.Linq { Check.Source (source); + var collection = source as ICollection; + if (collection != null) + return collection.Count > 0; + using (var enumerator = source.GetEnumerator ()) return enumerator.MoveNext (); } @@ -225,7 +233,7 @@ namespace System.Linq { Check.Source (source); - return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); + return source.AverageNullable ((a, b) => a + b, (a, b) => (double) a / b); } public static double? Average (this IEnumerable source) @@ -332,8 +340,8 @@ namespace System.Linq static IEnumerable CreateCastIterator (IEnumerable source) { - foreach (object element in source) - yield return (TResult) element; + foreach (TResult element in source) + yield return element; } #endregion @@ -579,7 +587,20 @@ namespace System.Linq { Check.Source (source); - return source.First (PredicateOf.Always, Fallback.Throw); + var list = source as IList; + if (list != null) { + if (list.Count != 0) + return list [0]; + + throw new InvalidOperationException (); + } else { + using (var enumerator = source.GetEnumerator ()) { + if (enumerator.MoveNext ()) + return enumerator.Current; + } + } + + throw new InvalidOperationException (); } public static TSource First (this IEnumerable source, Func predicate) @@ -633,6 +654,12 @@ namespace System.Linq { Check.SourceAndKeySelector (source, keySelector); + return CreateGroupByIterator (source, keySelector, comparer); + } + + static IEnumerable> CreateGroupByIterator (this IEnumerable source, + Func keySelector, IEqualityComparer comparer) + { Dictionary> groups = new Dictionary> (); List nullList = new List (); int counter = 0; @@ -670,7 +697,6 @@ namespace System.Linq } } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, Func elementSelector) { @@ -727,7 +753,6 @@ namespace System.Linq return GroupBy (source, keySelector, elementSelector, resultSelector, null); } - [MonoTODO] public static IEnumerable GroupBy (this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, @@ -737,7 +762,7 @@ namespace System.Linq source, keySelector, elementSelector, comparer); foreach (IGrouping group in groups) - yield return resultSelector (group.Key, group); + yield return resultSelector (group.Key, group); } public static IEnumerable GroupBy (this IEnumerable source, @@ -747,7 +772,6 @@ namespace System.Linq return GroupBy (source, keySelector, resultSelector, null); } - [MonoTODO] public static IEnumerable GroupBy (this IEnumerable source, Func keySelector, Func, TResult> resultSelector, @@ -780,6 +804,14 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; + return CreateGroupJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateGroupJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func, TResult> resultSelector, + IEqualityComparer comparer) + { ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) @@ -822,7 +854,7 @@ namespace System.Linq { var items = new HashSet (second, comparer); foreach (TSource element in first) { - if (items.Contains (element)) + if (items.Remove (element)) yield return element; } } @@ -840,6 +872,13 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; + return CreateJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) @@ -863,7 +902,7 @@ namespace System.Linq IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) { - return Join (outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); + return outer.Join (inner, outerKeySelector, innerKeySelector, resultSelector, null); } #endregion @@ -896,6 +935,14 @@ namespace System.Linq { Check.Source (source); + var collection = source as ICollection; + if (collection != null && collection.Count == 0) + throw new InvalidOperationException (); + + var list = source as IList; + if (list != null) + return list [list.Count - 1]; + return source.Last (PredicateOf.Always, Fallback.Throw); } @@ -1005,51 +1052,78 @@ namespace System.Linq { Check.Source (source); - return IterateNullable (source, int.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static long? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, long.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static double? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, double.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static float? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, float.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static decimal? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, decimal.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } - static T? IterateNullable (IEnumerable source, T initValue, Func selector) where T : struct + static T? IterateNullable (IEnumerable source, Func selector) where T : struct { - int counter = 0; - T? value = initValue; + bool empty = true; + T? value = null; foreach (var element in source) { if (!element.HasValue) continue; - if (selector (element.Value, value)) - value = element; - ++counter; + if (!value.HasValue) + value = element.Value; + else + value = selector (element.Value, value.Value); + + empty = false; } - if (counter == 0) + if (empty) + return null; + + return value; + } + + static TRet? IterateNullable ( + IEnumerable source, + Func source_selector, + Func selector) where TRet : struct + { + bool empty = true; + TRet? value = null; + foreach (var element in source) { + TRet? item = source_selector (element); + + if (!value.HasValue) + value = item; + else if (selector (item, value)) + value = item; + + empty = false; + } + + if (empty) return null; return value; @@ -1124,79 +1198,51 @@ namespace System.Linq static U Iterate (IEnumerable source, U initValue, Func selector) { - int counter = 0; + bool empty = true; foreach (var element in source) { initValue = selector (element, initValue); - ++counter; + empty = false; } - if (counter == 0) + if (empty) throw new InvalidOperationException (); return initValue; } - static U? IterateNullable (IEnumerable source, U initialValue, Func selector) where U : struct - { - int counter = 0; - U? value = initialValue; - foreach (var element in source) { - value = selector (element, value); - if (!value.HasValue) - continue; - - ++counter; - } - - if (counter == 0) - return null; - - return value; - } - public static int? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, int.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static long? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, long.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static double? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, double.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static float? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, float.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static decimal? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, decimal.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static TResult Max (this IEnumerable source, Func selector) @@ -1275,35 +1321,35 @@ namespace System.Linq { Check.Source (source); - return IterateNullable (source, int.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static long? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, long.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static double? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, double.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static float? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, float.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static decimal? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, decimal.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static TSource Min (this IEnumerable source) @@ -1377,45 +1423,35 @@ namespace System.Linq { Check.SourceAndSelector (source, selector); - return IterateNullable (source, int.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static long? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, long.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static float? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, float.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static double? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, double.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static decimal? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, decimal.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static TResult Min (this IEnumerable source, Func selector) @@ -1807,30 +1843,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static int? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static int Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static int? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -1838,30 +1874,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static long? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static long Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static long? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -1869,30 +1905,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static double? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static double Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static double? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -1900,30 +1936,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static float? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static float Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static float? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -1931,30 +1967,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static decimal? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static decimal Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static decimal? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -1967,15 +2003,12 @@ namespace System.Linq ++counter; } - if (counter == 0) - throw new InvalidOperationException (); - return total; } - static TR SumNullable (this IEnumerable source, Func selector) + static TR SumNullable (this IEnumerable source, TR zero, Func selector) { - TR total = default (TR); + TR total = zero; foreach (var element in source) { total = selector (total, element); } @@ -2001,10 +2034,10 @@ namespace System.Linq int counter = 0; foreach (TSource element in source) { - if (counter++ == count) - yield break; - yield return element; + + if (++counter == count) + yield break; } } @@ -2134,16 +2167,7 @@ namespace System.Linq public static Dictionary ToDictionary (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - Check.SourceAndKeySelector (source, keySelector); - - if (comparer == null) - comparer = EqualityComparer.Default; - - var dict = new Dictionary (comparer); - foreach (var e in source) - dict.Add (keySelector (e), e); - - return dict; + return ToDictionary (source, keySelector, Function.Identity, comparer); } #endregion @@ -2161,24 +2185,13 @@ namespace System.Linq public static ILookup ToLookup (this IEnumerable source, Func keySelector) { - return ToLookup (source, keySelector, null); + return ToLookup (source, keySelector, Function.Identity, null); } public static ILookup ToLookup (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - Check.SourceAndKeySelector (source, keySelector); - - var dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); - foreach (TSource element in source) { - TKey key = keySelector (element); - if (key == null) - throw new ArgumentNullException (); - if (!dictionary.ContainsKey (key)) - dictionary.Add (key, new List ()); - dictionary [key].Add (element); - } - return new Lookup (dictionary); + return ToLookup (source, keySelector, element => element, comparer); } public static ILookup ToLookup (this IEnumerable source, @@ -2192,15 +2205,21 @@ namespace System.Linq { Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); - Dictionary> dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); - foreach (TSource element in source) { - TKey key = keySelector (element); + var dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); + foreach (var element in source) { + var key = keySelector (element); if (key == null) - throw new ArgumentNullException (); - if (!dictionary.ContainsKey (key)) - dictionary.Add (key, new List ()); - dictionary [key].Add (elementSelector (element)); + throw new ArgumentNullException ("key"); + + List list; + if (!dictionary.TryGetValue (key, out list)) { + list = new List (); + dictionary.Add (key, list); + } + + list.Add (elementSelector (element)); } + return new Lookup (dictionary); } @@ -2220,18 +2239,19 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; - var first_enumerator = first.GetEnumerator (); - var second_enumerator = second.GetEnumerator (); + using (IEnumerator first_enumerator = first.GetEnumerator (), + second_enumerator = second.GetEnumerator ()) { - while (first_enumerator.MoveNext ()) { - if (!second_enumerator.MoveNext ()) - return false; + while (first_enumerator.MoveNext ()) { + if (!second_enumerator.MoveNext ()) + return false; - if (!comparer.Equals (first_enumerator.Current, second_enumerator.Current)) - return false; - } + if (!comparer.Equals (first_enumerator.Current, second_enumerator.Current)) + return false; + } - return !second_enumerator.MoveNext (); + return !second_enumerator.MoveNext (); + } } #endregion