X-Git-Url: http://wien.tomnetworks.com/gitweb/?a=blobdiff_plain;f=mcs%2Fclass%2FSystem.Core%2FSystem.Linq%2FEnumerable.cs;h=d2c98d0a1954f16a877a18c2c856d759747d2563;hb=cc2c56a916645e45991c43f86409966218a6a177;hp=d243ff7c327e53b54ab0c76581c67096c7878eae;hpb=fd6bd11aa889f906c3be7f3f1f2893015a8ea808;p=mono.git diff --git a/mcs/class/System.Core/System.Linq/Enumerable.cs b/mcs/class/System.Core/System.Linq/Enumerable.cs index d243ff7c327..d2c98d0a195 100644 --- a/mcs/class/System.Core/System.Linq/Enumerable.cs +++ b/mcs/class/System.Core/System.Linq/Enumerable.cs @@ -40,6 +40,19 @@ namespace System.Linq { public static class Enumerable { + enum Fallback { + Default, + Throw + } + + class PredicateOf { + public static readonly Func Always = (t) => true; + } + + class Function { + public static readonly Func Identity = (t) => t; + } + #region Aggregate public static TSource Aggregate (this IEnumerable source, Func func) @@ -77,8 +90,8 @@ namespace System.Linq if (resultSelector == null) throw new ArgumentNullException ("resultSelector"); - TAccumulate result = seed; - foreach (TSource e in source) + var result = seed; + foreach (var e in source) result = func (result, e); return resultSelector (result); @@ -92,7 +105,7 @@ namespace System.Linq { Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) + foreach (var element in source) if (!predicate (element)) return false; @@ -107,6 +120,10 @@ namespace System.Linq { Check.Source (source); + var collection = source as ICollection; + if (collection != null) + return collection.Count > 0; + using (var enumerator = source.GetEnumerator ()) return enumerator.MoveNext (); } @@ -137,35 +154,39 @@ namespace System.Linq public static double Average (this IEnumerable source) { - return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); + return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); } public static double Average (this IEnumerable source) { - return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); + return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); } public static double Average (this IEnumerable source) { - return Average (source, (a, b) => a + b, (a, b) => a / b); + return Average (source, (a, b) => a + b, (a, b) => a / b); } public static float Average (this IEnumerable source) { - return Average (source, (a, b) => a + b, (a, b) => a / b); + return Average (source, (a, b) => a + b, (a, b) => (float) a / (float) b); } public static decimal Average (this IEnumerable source) { - return Average (source, (a, b) => a + b, (a, b) => a / b); + return Average (source, (a, b) => a + b, (a, b) => a / b); } - static TR Average (this IEnumerable source, Func func, Func result) + static TResult Average (this IEnumerable source, + Func func, Func result) + where TElement : struct + where TAggregate : struct + where TResult : struct { Check.Source (source); - TA total = default (TA); - int counter = 0; + var total = default (TAggregate); + long counter = 0; foreach (var element in source) { total = func (total, element); ++counter; @@ -177,269 +198,135 @@ namespace System.Linq return result (total, counter); } - public static double? Average (this IEnumerable source) + static TResult? AverageNullable (this IEnumerable source, + Func func, Func result) + where TElement : struct + where TAggregate : struct + where TResult : struct { Check.Source (source); - bool onlyNull = true; - long sum = 0; + var total = default (TAggregate); long counter = 0; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } + foreach (var element in source) { + if (!element.HasValue) + continue; + + total = func (total, element.Value); + counter++; } - return (onlyNull ? null : (double?) sum / (double?) counter); + + if (counter == 0) + return null; + + return new TResult? (result (total, counter)); + } + + public static double? Average (this IEnumerable source) + { + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } public static double? Average (this IEnumerable source) { Check.Source (source); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.AverageNullable ((a, b) => a + b, (a, b) => (double) a / b); } public static double? Average (this IEnumerable source) { Check.Source (source); - bool onlyNull = true; - double sum = 0; - double counter = 0; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) (sum / counter)); + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); } public static decimal? Average (this IEnumerable source) { Check.Source (source); - bool onlyNull = true; - decimal sum = 0; - decimal counter = 0; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (decimal?) (sum / counter)); + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); } public static float? Average (this IEnumerable source) { Check.Source (source); - float sum = 0; - float counter = 0; - foreach (float? element in source) { - if (element.HasValue) { - sum += element.Value; - ++counter; - } - } - - if (counter == 0) - return null; - - return sum / counter; + return source.AverageNullable ((a, b) => a + b, (a, b) => (float) a / (float) b); } public static double Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (double) a / (double) b); } public static double? Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } public static double Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (double) a / (double) b); } public static double? Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } public static double Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - double sum = 0; - double counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => a / b); } public static double? Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - bool onlyNull = true; - double sum = 0; - double counter = 0; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) (sum / counter)); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => a / b); } public static float Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - float sum = 0; - float counter = 0; - foreach (TSource item in source) { - sum += selector (item); - ++counter; - } - - if (counter == 0) - throw new InvalidOperationException (); - - return sum / counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (float) a / (float) b); } public static float? Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - float sum = 0; - float counter = 0; - foreach (TSource item in source) { - float? value = selector (item); - if (value.HasValue) { - sum += value.Value; - ++counter; - } - } - - if (counter == 0) - throw new InvalidOperationException (); - - return sum / counter; + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (float) a / (float) b); } public static decimal Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - decimal sum = 0; - decimal counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => a / b); } public static decimal? Average (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal sum = 0; - decimal counter = 0; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (decimal?) (sum / counter)); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => a / b); } + #endregion #region Cast @@ -453,8 +340,8 @@ namespace System.Linq static IEnumerable CreateCastIterator (IEnumerable source) { - foreach (object element in source) - yield return (TResult) element; + foreach (TResult element in source) + yield return element; } #endregion @@ -496,10 +383,9 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; - foreach (TSource e in source) { - if (comparer.Equals (e, value)) + foreach (var element in source) + if (comparer.Equals (element, value)) return true; - } return false; } @@ -597,9 +483,25 @@ namespace System.Linq #region ElementAt + static TSource ElementAt (this IEnumerable source, int index, Fallback fallback) + { + long counter = 0L; + + foreach (var element in source) { + if (index == counter++) + return element; + } + + if (fallback == Fallback.Throw) + throw new ArgumentOutOfRangeException (); + + return default (TSource); + } + public static TSource ElementAt (this IEnumerable source, int index) { Check.Source (source); + if (index < 0) throw new ArgumentOutOfRangeException (); @@ -607,14 +509,7 @@ namespace System.Linq if (list != null) return list [index]; - int counter = 0; - foreach (var element in source) { - if (counter == index) - return element; - counter++; - } - - throw new ArgumentOutOfRangeException (); + return source.ElementAt (index, Fallback.Throw); } #endregion @@ -624,6 +519,7 @@ namespace System.Linq public static TSource ElementAtOrDefault (this IEnumerable source, int index) { Check.Source (source); + if (index < 0) return default (TSource); @@ -631,14 +527,7 @@ namespace System.Linq if (list != null) return index < list.Count ? list [index] : default (TSource); - int counter = 0; - foreach (TSource element in source) { - if (counter == index) - return element; - counter++; - } - - return default (TSource); + return source.ElementAt (index, Fallback.Default); } #endregion @@ -671,9 +560,9 @@ namespace System.Linq static IEnumerable CreateExceptIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) { - var items = new HashSet (Distinct (second)); - foreach (TSource element in first) { - if (! items.Contains (element, comparer)) + var items = new HashSet (second, comparer); + foreach (var element in first) { + if (!items.Contains (element, comparer)) yield return element; } } @@ -682,27 +571,43 @@ namespace System.Linq #region First + static TSource First (this IEnumerable source, Func predicate, Fallback fallback) + { + foreach (var element in source) + if (predicate (element)) + return element; + + if (fallback == Fallback.Throw) + throw new InvalidOperationException (); + + return default (TSource); + } + public static TSource First (this IEnumerable source) { Check.Source (source); - foreach (TSource element in source) - return element; + var list = source as IList; + if (list != null) { + if (list.Count != 0) + return list [0]; + + throw new InvalidOperationException (); + } else { + using (var enumerator = source.GetEnumerator ()) { + if (enumerator.MoveNext ()) + return enumerator.Current; + } + } throw new InvalidOperationException (); } - public static TSource First (this IEnumerable source, Func predicate) { Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) { - if (predicate (element)) - return element; - } - - throw new InvalidOperationException (); + return source.First (predicate, Fallback.Throw); } #endregion @@ -713,23 +618,14 @@ namespace System.Linq { Check.Source (source); - foreach (TSource element in source) - return element; - - return default (TSource); + return source.First (PredicateOf.Always, Fallback.Default); } - public static TSource FirstOrDefault (this IEnumerable source, Func predicate) { Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) { - if (predicate (element)) - return element; - } - - return default (TSource); + return source.First (predicate, Fallback.Default); } #endregion @@ -747,7 +643,6 @@ namespace System.Linq return null; } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector) { @@ -759,6 +654,12 @@ namespace System.Linq { Check.SourceAndKeySelector (source, keySelector); + return CreateGroupByIterator (source, keySelector, comparer); + } + + static IEnumerable> CreateGroupByIterator (this IEnumerable source, + Func keySelector, IEqualityComparer comparer) + { Dictionary> groups = new Dictionary> (); List nullList = new List (); int counter = 0; @@ -796,7 +697,6 @@ namespace System.Linq } } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, Func elementSelector) { @@ -853,13 +753,16 @@ namespace System.Linq return GroupBy (source, keySelector, elementSelector, resultSelector, null); } - [MonoTODO] public static IEnumerable GroupBy (this IEnumerable source, Func keySelector, Func elementSelector, Func, TResult> resultSelector, IEqualityComparer comparer) { - throw new NotImplementedException (); + IEnumerable> groups = GroupBy ( + source, keySelector, elementSelector, comparer); + + foreach (IGrouping group in groups) + yield return resultSelector (group.Key, group); } public static IEnumerable GroupBy (this IEnumerable source, @@ -869,13 +772,15 @@ namespace System.Linq return GroupBy (source, keySelector, resultSelector, null); } - [MonoTODO] public static IEnumerable GroupBy (this IEnumerable source, Func keySelector, Func, TResult> resultSelector, IEqualityComparer comparer) { - throw new NotImplementedException (); + IEnumerable> groups = GroupBy (source, keySelector, comparer); + + foreach (IGrouping group in groups) + yield return resultSelector (group.Key, group); } #endregion @@ -899,6 +804,14 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; + return CreateGroupJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateGroupJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func, TResult> resultSelector, + IEqualityComparer comparer) + { ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) @@ -941,7 +854,7 @@ namespace System.Linq { var items = new HashSet (second, comparer); foreach (TSource element in first) { - if (items.Contains (element)) + if (items.Remove (element)) yield return element; } } @@ -959,6 +872,13 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; + return CreateJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) @@ -982,48 +902,55 @@ namespace System.Linq IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) { - return Join (outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); + return outer.Join (inner, outerKeySelector, innerKeySelector, resultSelector, null); } - # endregion + #endregion #region Last - public static TSource Last (this IEnumerable source) + static TSource Last (this IEnumerable source, Func predicate, Fallback fallback) { - Check.Source (source); + var empty = true; + var item = default (TSource); - bool noElements = true; - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (noElements) noElements = false; - lastElement = element; + foreach (var element in source) { + if (!predicate (element)) + continue; + + item = element; + empty = false; } - if (!noElements) - return lastElement; - else + if (!empty) + return item; + + if (fallback == Fallback.Throw) + throw new InvalidOperationException (); + + return item; + } + + public static TSource Last (this IEnumerable source) + { + Check.Source (source); + + var collection = source as ICollection; + if (collection != null && collection.Count == 0) throw new InvalidOperationException (); + + var list = source as IList; + if (list != null) + return list [list.Count - 1]; + + return source.Last (PredicateOf.Always, Fallback.Throw); } - public static TSource Last (this IEnumerable source, - Func predicate) + public static TSource Last (this IEnumerable source, Func predicate) { Check.SourceAndPredicate (source, predicate); - bool noElements = true; - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (noElements) noElements = false; - lastElement = element; - } - } - - if (!noElements) - return lastElement; - else - throw new InvalidOperationException (); + return source.Last (predicate, Fallback.Throw); } #endregion @@ -1038,34 +965,30 @@ namespace System.Linq if (list != null) return list.Count > 0 ? list [list.Count - 1] : default (TSource); - TSource lastElement = default (TSource); - foreach (TSource element in source) - lastElement = element; - - return lastElement; + return source.Last (PredicateOf.Always, Fallback.Default); } - public static TSource LastOrDefault (this IEnumerable source, - Func predicate) + public static TSource LastOrDefault (this IEnumerable source, Func predicate) { Check.SourceAndPredicate (source, predicate); - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) - lastElement = element; - } - - return lastElement; + return source.Last (predicate, Fallback.Default); } #endregion #region LongCount + public static long LongCount (this IEnumerable source) { Check.Source (source); +#if !NET_2_1 + var array = source as TSource []; + if (array != null) + return array.LongLength; +#endif + long counter = 0; using (var enumerator = source.GetEnumerator ()) while (enumerator.MoveNext ()) @@ -1129,51 +1052,78 @@ namespace System.Linq { Check.Source (source); - return IterateNullable (source, int.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static long? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, long.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static double? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, double.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static float? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, float.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } public static decimal? Max (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, decimal.MinValue, (a, b) => a > b); + return IterateNullable (source, (a, b) => Math.Max (a, b)); } - static T? IterateNullable (IEnumerable source, T initValue, Func selector) where T : struct + static T? IterateNullable (IEnumerable source, Func selector) where T : struct { - int counter = 0; - T? value = initValue; + bool empty = true; + T? value = null; foreach (var element in source) { if (!element.HasValue) continue; - if (selector (element.Value, value)) - value = element; - ++counter; + if (!value.HasValue) + value = element.Value; + else + value = selector (element.Value, value.Value); + + empty = false; } - if (counter == 0) + if (empty) + return null; + + return value; + } + + static TRet? IterateNullable ( + IEnumerable source, + Func source_selector, + Func selector) where TRet : struct + { + bool empty = true; + TRet? value = null; + foreach (var element in source) { + TRet? item = source_selector (element); + + if (!value.HasValue) + value = item; + else if (selector (item, value)) + value = item; + + empty = false; + } + + if (empty) return null; return value; @@ -1248,79 +1198,51 @@ namespace System.Linq static U Iterate (IEnumerable source, U initValue, Func selector) { - int counter = 0; + bool empty = true; foreach (var element in source) { initValue = selector (element, initValue); - ++counter; + empty = false; } - if (counter == 0) + if (empty) throw new InvalidOperationException (); return initValue; } - static U? IterateNullable (IEnumerable source, U initialValue, Func selector) where U : struct - { - int counter = 0; - U? value = initialValue; - foreach (var element in source) { - value = selector (element, value); - if (!value.HasValue) - continue; - - ++counter; - } - - if (counter == 0) - return null; - - return value; - } - public static int? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, int.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static long? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, long.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static double? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, double.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static float? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, float.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static decimal? Max (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, decimal.MinValue, (a, b) => { - var v = selector (a); return v > b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a > b); } public static TResult Max (this IEnumerable source, Func selector) @@ -1399,35 +1321,35 @@ namespace System.Linq { Check.Source (source); - return IterateNullable (source, int.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static long? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, long.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static double? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, double.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static float? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, float.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static decimal? Min (this IEnumerable source) { Check.Source (source); - return IterateNullable (source, decimal.MaxValue, (a, b) => a < b); + return IterateNullable (source, (a, b) => Math.Min (a, b)); } public static TSource Min (this IEnumerable source) @@ -1501,45 +1423,35 @@ namespace System.Linq { Check.SourceAndSelector (source, selector); - return IterateNullable (source, int.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static long? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, long.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static float? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, float.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static double? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, double.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static decimal? Min (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return IterateNullable (source, decimal.MaxValue, (a, b) => { - var v = selector (a); return v < b ? v : b; - }); + return IterateNullable (source, selector, (a, b) => a < b); } public static TResult Min (this IEnumerable source, Func selector) @@ -1603,15 +1515,13 @@ namespace System.Linq return OrderBy (source, keySelector, null); } - public static IOrderedEnumerable OrderBy (this IEnumerable source, Func keySelector, IComparer comparer) { Check.SourceAndKeySelector (source, keySelector); - return new OrderedSequence ( - source, keySelector, comparer, false); + return new OrderedSequence (source, keySelector, comparer, SortDirection.Ascending); } #endregion @@ -1624,14 +1534,12 @@ namespace System.Linq return OrderByDescending (source, keySelector, null); } - public static IOrderedEnumerable OrderByDescending (this IEnumerable source, Func keySelector, IComparer comparer) { Check.SourceAndKeySelector (source, keySelector); - return new OrderedSequence ( - source, keySelector, comparer, true); + return new OrderedSequence (source, keySelector, comparer, SortDirection.Descending); } #endregion @@ -1677,21 +1585,23 @@ namespace System.Linq #endregion - #region Reverse public static IEnumerable Reverse (this IEnumerable source) { Check.Source (source); - return CreateReverseIterator (source); + var list = source as IList; + if (list == null) + list = new List (source); + + return CreateReverseIterator (list); } - static IEnumerable CreateReverseIterator (IEnumerable source) + static IEnumerable CreateReverseIterator (IList source) { - var list = new List (source); - list.Reverse (); - return list; + for (int i = source.Count; i > 0; --i) + yield return source [i - 1]; } #endregion @@ -1707,7 +1617,7 @@ namespace System.Linq static IEnumerable CreateSelectIterator (IEnumerable source, Func selector) { - foreach (TSource element in source) + foreach (var element in source) yield return selector (element); } @@ -1799,44 +1709,40 @@ namespace System.Linq #region Single - public static TSource Single (this IEnumerable source) + static TSource Single (this IEnumerable source, Func predicate, Fallback fallback) { - Check.Source (source); + var found = false; + var item = default (TSource); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; + foreach (var element in source) { + if (!predicate (element)) + continue; + + if (found) + throw new InvalidOperationException (); + + found = true; + item = element; } - if (otherElement) - return singleElement; - else + if (!found && fallback == Fallback.Throw) throw new InvalidOperationException (); + + return item; } + public static TSource Single (this IEnumerable source) + { + Check.Source (source); - public static TSource Single (this IEnumerable source, - Func predicate) + return source.Single (PredicateOf.Always, Fallback.Throw); + } + + public static TSource Single (this IEnumerable source, Func predicate) { Check.SourceAndPredicate (source, predicate); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - } - - if (otherElement) - return singleElement; - else - throw new InvalidOperationException (); + return source.Single (predicate, Fallback.Throw); } #endregion @@ -1847,34 +1753,14 @@ namespace System.Linq { Check.Source (source); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - - return singleElement; + return source.Single (PredicateOf.Always, Fallback.Default); } - - public static TSource SingleOrDefault (this IEnumerable source, - Func predicate) + public static TSource SingleOrDefault (this IEnumerable source, Func predicate) { Check.SourceAndPredicate (source, predicate); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - } - - return singleElement; + return source.Single (predicate, Fallback.Default); } #endregion @@ -1957,30 +1843,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static int? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static int Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static int? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -1988,30 +1874,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static long? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static long Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static long? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -2019,30 +1905,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static double? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static double Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static double? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -2050,30 +1936,30 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static float? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static float Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static float? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } @@ -2081,51 +1967,48 @@ namespace System.Linq { Check.Source (source); - return Sum (source, (a, b) => a + b); + return Sum (source, (a, b) => checked (a + b)); } public static decimal? Sum (this IEnumerable source) { Check.Source (source); - return SumNullable (source, (a, b) => a.HasValue ? a + b : a); + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } public static decimal Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return Sum (source, (a, b) => a + selector (b)); + return Sum (source, (a, b) => checked (a + selector (b))); } public static decimal? Sum (this IEnumerable source, Func selector) { Check.SourceAndSelector (source, selector); - return SumNullable (source, (a, b) => { + return source.SumNullable (0, (a, b) => { var value = selector (b); - return value.HasValue ? a + value.Value : a; + return value.HasValue ? checked (a + value.Value) : a; }); } static TR Sum (this IEnumerable source, Func selector) { TR total = default (TR); - int counter = 0; + long counter = 0; foreach (var element in source) { total = selector (total, element); ++counter; } - if (counter == 0) - throw new InvalidOperationException (); - return total; } - static TR SumNullable (this IEnumerable source, Func selector) + static TR SumNullable (this IEnumerable source, TR zero, Func selector) { - TR total = default (TR); + TR total = zero; foreach (var element in source) { total = selector (total, element); } @@ -2151,10 +2034,10 @@ namespace System.Linq int counter = 0; foreach (TSource element in source) { - if (counter++ == count) - yield break; - yield return element; + + if (++counter == count) + yield break; } } @@ -2171,7 +2054,7 @@ namespace System.Linq static IEnumerable CreateTakeWhileIterator (IEnumerable source, Func predicate) { - foreach (TSource element in source) { + foreach (var element in source) { if (!predicate (element)) yield break; @@ -2207,7 +2090,6 @@ namespace System.Linq return ThenBy (source, keySelector, null); } - public static IOrderedEnumerable ThenBy (this IOrderedEnumerable source, Func keySelector, IComparer comparer) { @@ -2226,7 +2108,6 @@ namespace System.Linq return ThenByDescending (source, keySelector, null); } - public static IOrderedEnumerable ThenByDescending (this IOrderedEnumerable source, Func keySelector, IComparer comparer) { @@ -2238,12 +2119,19 @@ namespace System.Linq #endregion #region ToArray + public static TSource [] ToArray (this IEnumerable source) { Check.Source (source); - List list = new List (source); - return list.ToArray (); + var collection = source as ICollection; + if (collection != null) { + var array = new TSource [collection.Count]; + collection.CopyTo (array, 0); + return array; + } + + return new List (source).ToArray (); } #endregion @@ -2279,16 +2167,7 @@ namespace System.Linq public static Dictionary ToDictionary (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - Check.SourceAndKeySelector (source, keySelector); - - if (comparer == null) - comparer = EqualityComparer.Default; - - var dict = new Dictionary (comparer); - foreach (var e in source) - dict.Add (keySelector (e), e); - - return dict; + return ToDictionary (source, keySelector, Function.Identity, comparer); } #endregion @@ -2306,24 +2185,13 @@ namespace System.Linq public static ILookup ToLookup (this IEnumerable source, Func keySelector) { - return ToLookup (source, keySelector, null); + return ToLookup (source, keySelector, Function.Identity, null); } public static ILookup ToLookup (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - Check.SourceAndKeySelector (source, keySelector); - - var dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); - foreach (TSource element in source) { - TKey key = keySelector (element); - if (key == null) - throw new ArgumentNullException (); - if (!dictionary.ContainsKey (key)) - dictionary.Add (key, new List ()); - dictionary [key].Add (element); - } - return new Lookup (dictionary); + return ToLookup (source, keySelector, element => element, comparer); } public static ILookup ToLookup (this IEnumerable source, @@ -2337,15 +2205,21 @@ namespace System.Linq { Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); - Dictionary> dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); - foreach (TSource element in source) { - TKey key = keySelector (element); + var dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); + foreach (var element in source) { + var key = keySelector (element); if (key == null) - throw new ArgumentNullException (); - if (!dictionary.ContainsKey (key)) - dictionary.Add (key, new List ()); - dictionary [key].Add (elementSelector (element)); + throw new ArgumentNullException ("key"); + + List list; + if (!dictionary.TryGetValue (key, out list)) { + list = new List (); + dictionary.Add (key, list); + } + + list.Add (elementSelector (element)); } + return new Lookup (dictionary); } @@ -2365,18 +2239,19 @@ namespace System.Linq if (comparer == null) comparer = EqualityComparer.Default; - var first_enumerator = first.GetEnumerator (); - var second_enumerator = second.GetEnumerator (); + using (IEnumerator first_enumerator = first.GetEnumerator (), + second_enumerator = second.GetEnumerator ()) { - while (first_enumerator.MoveNext ()) { - if (!second_enumerator.MoveNext ()) - return false; + while (first_enumerator.MoveNext ()) { + if (!second_enumerator.MoveNext ()) + return false; - if (!comparer.Equals (first_enumerator.Current, second_enumerator.Current)) - return false; - } + if (!comparer.Equals (first_enumerator.Current, second_enumerator.Current)) + return false; + } - return !second_enumerator.MoveNext (); + return !second_enumerator.MoveNext (); + } } #endregion