X-Git-Url: http://wien.tomnetworks.com/gitweb/?a=blobdiff_plain;f=mcs%2Fclass%2FSystem.Core%2FSystem.Linq%2FEnumerable.cs;h=d2c98d0a1954f16a877a18c2c856d759747d2563;hb=cc2c56a916645e45991c43f86409966218a6a177;hp=7010fba59a9ac14ec5255d6586473cc94da11f0b;hpb=d3428efecd57b5c4e2e125bb42bf5415feebd4dd;p=mono.git diff --git a/mcs/class/System.Core/System.Linq/Enumerable.cs b/mcs/class/System.Core/System.Linq/Enumerable.cs index 7010fba59a9..d2c98d0a195 100644 --- a/mcs/class/System.Core/System.Linq/Enumerable.cs +++ b/mcs/class/System.Core/System.Linq/Enumerable.cs @@ -1,3 +1,14 @@ +// +// Enumerable.cs +// +// Authors: +// Marek Safar (marek.safar@gmail.com) +// Antonello Provenzano +// Alejandro Serrano "Serras" (trupill@yahoo.es) +// Jb Evain (jbevain@novell.com) +// +// Copyright (C) 2007 Novell, Inc (http://www.novell.com) +// // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the // "Software"), to deal in the Software without restriction, including @@ -5,22 +16,20 @@ // distribute, sublicense, and/or sell copies of the Software, and to // permit persons to whom the Software is furnished to do so, subject to // the following conditions: -// +// // The above copyright notice and this permission notice shall be // included in all copies or substantial portions of the Software. -// +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. // -// Authors: -// Marek Safar (marek.safar@gmail.com) -// Antonello Provenzano -// Alejandro Serrano "Serras" (trupill@yahoo.es) -// + +// precious: http://www.hookedonlinq.com using System; using System.Collections; @@ -31,17 +40,28 @@ namespace System.Linq { public static class Enumerable { + enum Fallback { + Default, + Throw + } + + class PredicateOf { + public static readonly Func Always = (t) => true; + } + + class Function { + public static readonly Func Identity = (t) => t; + } + #region Aggregate + public static TSource Aggregate (this IEnumerable source, Func func) { - if (source == null) - throw new ArgumentNullException ("source"); - if (func == null) - throw new ArgumentNullException ("func"); + Check.SourceAndFunc (source, func); - // custom foreach so that we can efficiently throw an exception + // custom foreach so that we can efficiently throw an exception // if zero elements and treat the first element differently - using (IEnumerator enumerator = source.GetEnumerator ()) { + using (var enumerator = source.GetEnumerator ()) { if (!enumerator.MoveNext ()) throw new InvalidOperationException ("No elements in source list"); @@ -52,406 +72,291 @@ namespace System.Linq } } - public static TAccumulate Aggregate (this IEnumerable source, TAccumulate seed, Func func) { - if (source == null || func == null) - throw new ArgumentNullException (); + Check.SourceAndFunc (source, func); TAccumulate folded = seed; foreach (TSource element in source) folded = func (folded, element); + return folded; } - public static TResult Aggregate (this IEnumerable source, TAccumulate seed, Func func, Func resultSelector) { - if (source == null) - throw new ArgumentNullException ("source"); - if (func == null) - throw new ArgumentNullException ("func"); + Check.SourceAndFunc (source, func); if (resultSelector == null) throw new ArgumentNullException ("resultSelector"); - TAccumulate result = seed; - foreach (TSource e in source) + var result = seed; + foreach (var e in source) result = func (result, e); + return resultSelector (result); } + #endregion #region All + public static bool All (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) + foreach (var element in source) if (!predicate (element)) return false; + return true; } + #endregion #region Any + public static bool Any (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (TSource element in source) - return true; - return false; - } + var collection = source as ICollection; + if (collection != null) + return collection.Count > 0; + using (var enumerator = source.GetEnumerator ()) + return enumerator.MoveNext (); + } public static bool Any (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); foreach (TSource element in source) if (predicate (element)) return true; + return false; } + #endregion #region AsEnumerable + public static IEnumerable AsEnumerable (this IEnumerable source) { return source; } + #endregion #region Average + public static double Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - long sum = 0; - long counter = 0; - foreach (int element in source) { - sum += element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); } + public static double Average (this IEnumerable source) + { + return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); + } - public static double? Average (this IEnumerable source) + public static double Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + return Average (source, (a, b) => a + b, (a, b) => a / b); + } - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + public static float Average (this IEnumerable source) + { + return Average (source, (a, b) => a + b, (a, b) => (float) a / (float) b); } + public static decimal Average (this IEnumerable source) + { + return Average (source, (a, b) => a + b, (a, b) => a / b); + } - public static double Average (this IEnumerable source) + static TResult Average (this IEnumerable source, + Func func, Func result) + where TElement : struct + where TAggregate : struct + where TResult : struct { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - long sum = 0; + var total = default (TAggregate); long counter = 0; - foreach (long element in source) { - sum += element; - counter++; + foreach (var element in source) { + total = func (total, element); + ++counter; } if (counter == 0) throw new InvalidOperationException (); - else - return (double) sum / (double) counter; - } + return result (total, counter); + } - public static double? Average (this IEnumerable source) + static TResult? AverageNullable (this IEnumerable source, + Func func, Func result) + where TElement : struct + where TAggregate : struct + where TResult : struct { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - bool onlyNull = true; - long sum = 0; + var total = default (TAggregate); long counter = 0; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); - } - - - public static double Average (this IEnumerable source) - { - if (source == null) - throw new ArgumentNullException (); + foreach (var element in source) { + if (!element.HasValue) + continue; - double sum = 0; - double counter = 0; - foreach (double element in source) { - sum += element; + total = func (total, element.Value); counter++; } if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; - } + return null; + return new TResult? (result (total, counter)); + } - public static double? Average (this IEnumerable source) + public static double? Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - double sum = 0; - double counter = 0; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) (sum / counter)); - } + Check.Source (source); + return source.AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); + } - public static decimal Average (this IEnumerable source) + public static double? Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - decimal sum = 0; - decimal counter = 0; - foreach (decimal element in source) { - sum += element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + return source.AverageNullable ((a, b) => a + b, (a, b) => (double) a / b); } + public static double? Average (this IEnumerable source) + { + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); + } public static decimal? Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - decimal sum = 0; - decimal counter = 0; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (decimal?) (sum / counter)); + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); } + public static float? Average (this IEnumerable source) + { + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => (float) a / (float) b); + } public static double Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - double sum = 0; - double counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => a / b); } - public static double? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - double sum = 0; - double counter = 0; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) (sum / counter)); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => a / b); } - - public static decimal Average (this IEnumerable source, Func selector) + public static float Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - decimal sum = 0; - decimal counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (float) a / (float) b); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + public static float? Average (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (float) a / (float) b); } + public static decimal Average (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return source.Select (selector).Average ((a, b) => a + b, (a, b) => a / b); + } public static decimal? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal sum = 0; - decimal counter = 0; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (decimal?) (sum / counter)); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => a / b); } + #endregion #region Cast + public static IEnumerable Cast (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (object element in source) - yield return (TResult) element; + return CreateCastIterator (source); + } + + static IEnumerable CreateCastIterator (IEnumerable source) + { + foreach (TResult element in source) + yield return element; } + #endregion #region Concat + public static IEnumerable Concat (this IEnumerable first, IEnumerable second) { - if (first == null || second == null) - throw new ArgumentNullException (); + Check.FirstAndSecond (first, second); + + return CreateConcatIterator (first, second); + } + static IEnumerable CreateConcatIterator (IEnumerable first, IEnumerable second) + { foreach (TSource element in first) yield return element; foreach (TSource element in second) @@ -464,59 +369,58 @@ namespace System.Linq public static bool Contains (this IEnumerable source, TSource value) { - ICollection collection = source as ICollection; + var collection = source as ICollection; if (collection != null) return collection.Contains (value); return Contains (source, value, null); } - public static bool Contains (this IEnumerable source, TSource value, IEqualityComparer comparer) { - if (source == null) - throw new ArgumentNullException ("source"); + Check.Source (source); if (comparer == null) comparer = EqualityComparer.Default; - foreach (TSource e in source) { - if (comparer.Equals (e, value)) + foreach (var element in source) + if (comparer.Equals (element, value)) return true; - } return false; } #endregion #region Count + public static int Count (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - ICollection collection = source as ICollection; + var collection = source as ICollection; if (collection != null) return collection.Count; int counter = 0; - foreach (TSource element in source) - counter++; + using (var enumerator = source.GetEnumerator ()) + while (enumerator.MoveNext ()) + counter++; + return counter; } public static int Count (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); int counter = 0; - foreach (TSource element in source) + foreach (var element in source) if (selector (element)) counter++; return counter; } + #endregion #region DefaultIfEmpty @@ -528,16 +432,20 @@ namespace System.Linq public static IEnumerable DefaultIfEmpty (this IEnumerable source, TSource defaultValue) { - if (source == null) - throw new ArgumentNullException ("source"); + Check.Source (source); + + return CreateDefaultIfEmptyIterator (source, defaultValue); + } - bool noYield = true; + static IEnumerable CreateDefaultIfEmptyIterator (IEnumerable source, TSource defaultValue) + { + bool empty = true; foreach (TSource item in source) { - noYield = false; + empty = false; yield return item; } - if (noYield) + if (empty) yield return defaultValue; } @@ -552,43 +460,56 @@ namespace System.Linq public static IEnumerable Distinct (this IEnumerable source, IEqualityComparer comparer) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); if (comparer == null) comparer = EqualityComparer.Default; - List items = new List (); // TODO: use a HashSet here - foreach (TSource element in source) { - if (!Contains (items, element, comparer)) { + return CreateDistinctIterator (source, comparer); + } + + static IEnumerable CreateDistinctIterator (IEnumerable source, IEqualityComparer comparer) + { + var items = new HashSet (comparer); + foreach (var element in source) { + if (! items.Contains (element)) { items.Add (element); yield return element; } } } + #endregion #region ElementAt + static TSource ElementAt (this IEnumerable source, int index, Fallback fallback) + { + long counter = 0L; + + foreach (var element in source) { + if (index == counter++) + return element; + } + + if (fallback == Fallback.Throw) + throw new ArgumentOutOfRangeException (); + + return default (TSource); + } + public static TSource ElementAt (this IEnumerable source, int index) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + if (index < 0) throw new ArgumentOutOfRangeException (); - IList list = source as IList; + var list = source as IList; if (list != null) return list [index]; - int counter = 0; - foreach (TSource element in source) { - if (counter == index) - return element; - counter++; - } - - throw new ArgumentOutOfRangeException (); + return source.ElementAt (index, Fallback.Throw); } #endregion @@ -597,32 +518,27 @@ namespace System.Linq public static TSource ElementAtOrDefault (this IEnumerable source, int index) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + if (index < 0) return default (TSource); - IList list = source as IList; + var list = source as IList; if (list != null) return index < list.Count ? list [index] : default (TSource); - int counter = 0; - foreach (TSource element in source) { - if (counter == index) - return element; - counter++; - } - - return default (TSource); + return source.ElementAt (index, Fallback.Default); } #endregion #region Empty + public static IEnumerable Empty () { - return new List (); + return new TResult [0]; } + #endregion #region Except @@ -634,15 +550,19 @@ namespace System.Linq public static IEnumerable Except (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { - if (first == null || second == null) - throw new ArgumentNullException (); + Check.FirstAndSecond (first, second); if (comparer == null) comparer = EqualityComparer.Default; - List items = new List (Distinct (second)); - foreach (TSource element in first) { - if (!Contains (items, element, comparer)) + return CreateExceptIterator (first, second, comparer); + } + + static IEnumerable CreateExceptIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + var items = new HashSet (second, comparer); + foreach (var element in first) { + if (!items.Contains (element, comparer)) yield return element; } } @@ -651,58 +571,61 @@ namespace System.Linq #region First - public static TSource First (this IEnumerable source) + static TSource First (this IEnumerable source, Func predicate, Fallback fallback) { - if (source == null) - throw new ArgumentNullException (); + foreach (var element in source) + if (predicate (element)) + return element; - foreach (TSource element in source) - return element; + if (fallback == Fallback.Throw) + throw new InvalidOperationException (); - throw new InvalidOperationException (); + return default (TSource); } - - public static TSource First (this IEnumerable source, Func predicate) + public static TSource First (this IEnumerable source) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (TSource element in source) { - if (predicate (element)) - return element; + var list = source as IList; + if (list != null) { + if (list.Count != 0) + return list [0]; + + throw new InvalidOperationException (); + } else { + using (var enumerator = source.GetEnumerator ()) { + if (enumerator.MoveNext ()) + return enumerator.Current; + } } throw new InvalidOperationException (); } + public static TSource First (this IEnumerable source, Func predicate) + { + Check.SourceAndPredicate (source, predicate); + + return source.First (predicate, Fallback.Throw); + } + #endregion #region FirstOrDefault public static TSource FirstOrDefault (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (TSource element in source) - return element; - - return default (TSource); + return source.First (PredicateOf.Always, Fallback.Default); } - public static TSource FirstOrDefault (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) { - if (predicate (element)) - return element; - } - - return default (TSource); + return source.First (predicate, Fallback.Default); } #endregion @@ -720,20 +643,23 @@ namespace System.Linq return null; } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector) { return GroupBy (source, keySelector, null); } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); + + return CreateGroupByIterator (source, keySelector, comparer); + } + static IEnumerable> CreateGroupByIterator (this IEnumerable source, + Func keySelector, IEqualityComparer comparer) + { Dictionary> groups = new Dictionary> (); List nullList = new List (); int counter = 0; @@ -771,19 +697,16 @@ namespace System.Linq } } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, Func elementSelector) { return GroupBy (source, keySelector, elementSelector, null); } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) { - if (source == null || keySelector == null || elementSelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); Dictionary> groups = new Dictionary> (); List nullList = new List (); @@ -823,6 +746,43 @@ namespace System.Linq } } + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, Func elementSelector, + Func, TResult> resultSelector) + { + return GroupBy (source, keySelector, elementSelector, resultSelector, null); + } + + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + IEnumerable> groups = GroupBy ( + source, keySelector, elementSelector, comparer); + + foreach (IGrouping group in groups) + yield return resultSelector (group.Key, group); + } + + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, + Func, TResult> resultSelector) + { + return GroupBy (source, keySelector, resultSelector, null); + } + + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + IEnumerable> groups = GroupBy (source, keySelector, comparer); + + foreach (IGrouping group in groups) + yield return resultSelector (group.Key, group); + } + #endregion # region GroupJoin @@ -839,14 +799,20 @@ namespace System.Linq Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer comparer) { - if (outer == null || inner == null || outerKeySelector == null || - innerKeySelector == null || resultSelector == null) - throw new ArgumentNullException (); + Check.JoinSelectors (outer, inner, outerKeySelector, innerKeySelector, resultSelector); if (comparer == null) comparer = EqualityComparer.Default; - Lookup innerKeys = ToLookup (inner, innerKeySelector, comparer); + return CreateGroupJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateGroupJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func, TResult> resultSelector, + IEqualityComparer comparer) + { + ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) { @@ -862,32 +828,34 @@ namespace System.Linq yield return resultSelector (element, innerKeys [outerKey]); else yield return resultSelector (element, Empty ()); - } + } } #endregion #region Intersect - public static IEnumerable Intersect (this IEnumerable first, IEnumerable second) { - if (first == null || second == null) - throw new ArgumentNullException (); + return Intersect (first, second, null); + } - List items = new List (Distinct (first)); - bool [] marked = new bool [items.Count]; - for (int i = 0; i < marked.Length; i++) - marked [i] = false; + public static IEnumerable Intersect (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Check.FirstAndSecond (first, second); - foreach (TSource element in second) { - int index = IndexOf (items, element); - if (index != -1) - marked [index] = true; - } - for (int i = 0; i < marked.Length; i++) { - if (marked [i]) - yield return items [i]; + if (comparer == null) + comparer = EqualityComparer.Default; + + return CreateIntersectIterator (first, second, comparer); + } + + static IEnumerable CreateIntersectIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + var items = new HashSet (second, comparer); + foreach (TSource element in first) { + if (items.Remove (element)) + yield return element; } } @@ -899,14 +867,19 @@ namespace System.Linq IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) { - if (outer == null || inner == null || outerKeySelector == null || - innerKeySelector == null || resultSelector == null) - throw new ArgumentNullException (); + Check.JoinSelectors (outer, inner, outerKeySelector, innerKeySelector, resultSelector); if (comparer == null) comparer = EqualityComparer.Default; - Lookup innerKeys = ToLookup (inner, innerKeySelector, comparer); + return CreateJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { + ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) { @@ -929,49 +902,55 @@ namespace System.Linq IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) { - return Join (outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); + return outer.Join (inner, outerKeySelector, innerKeySelector, resultSelector, null); } - # endregion + + #endregion #region Last - public static TSource Last (this IEnumerable source) + static TSource Last (this IEnumerable source, Func predicate, Fallback fallback) { - if (source == null) - throw new ArgumentNullException (); + var empty = true; + var item = default (TSource); - bool noElements = true; - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (noElements) noElements = false; - lastElement = element; + foreach (var element in source) { + if (!predicate (element)) + continue; + + item = element; + empty = false; } - if (!noElements) - return lastElement; - else + if (!empty) + return item; + + if (fallback == Fallback.Throw) throw new InvalidOperationException (); + + return item; } - public static TSource Last (this IEnumerable source, - Func predicate) + public static TSource Last (this IEnumerable source) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.Source (source); - bool noElements = true; - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (noElements) noElements = false; - lastElement = element; - } - } - - if (!noElements) - return lastElement; - else + var collection = source as ICollection; + if (collection != null && collection.Count == 0) throw new InvalidOperationException (); + + var list = source as IList; + if (list != null) + return list [list.Count - 1]; + + return source.Last (PredicateOf.Always, Fallback.Throw); + } + + public static TSource Last (this IEnumerable source, Func predicate) + { + Check.SourceAndPredicate (source, predicate); + + return source.Last (predicate, Fallback.Throw); } #endregion @@ -980,50 +959,47 @@ namespace System.Linq public static TSource LastOrDefault (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - TSource lastElement = default (TSource); - foreach (TSource element in source) - lastElement = element; + var list = source as IList; + if (list != null) + return list.Count > 0 ? list [list.Count - 1] : default (TSource); - return lastElement; + return source.Last (PredicateOf.Always, Fallback.Default); } - public static TSource LastOrDefault (this IEnumerable source, - Func predicate) + public static TSource LastOrDefault (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); - - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) - lastElement = element; - } + Check.SourceAndPredicate (source, predicate); - return lastElement; + return source.Last (predicate, Fallback.Default); } #endregion #region LongCount + public static long LongCount (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + +#if !NET_2_1 + var array = source as TSource []; + if (array != null) + return array.LongLength; +#endif long counter = 0; - foreach (TSource element in source) - counter++; + using (var enumerator = source.GetEnumerator ()) + while (enumerator.MoveNext ()) + counter++; + return counter; } - public static long LongCount (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); long counter = 0; foreach (TSource element in source) @@ -1039,160 +1015,123 @@ namespace System.Linq public static int Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - int maximum = int.MinValue; - int counter = 0; - foreach (int element in source) { - if (element > maximum) - maximum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, int.MinValue, (a, b) => Math.Max (a, b)); } - - public static int? Max (this IEnumerable source) + public static long Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - int? maximum = int.MinValue; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + Check.Source (source); + + return Iterate (source, long.MinValue, (a, b) => Math.Max (a, b)); } + public static double Max (this IEnumerable source) + { + Check.Source (source); - public static long Max (this IEnumerable source) + return Iterate (source, double.MinValue, (a, b) => Math.Max (a, b)); + } + + public static float Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - long maximum = long.MinValue; - int counter = 0; - foreach (long element in source) { - if (element > maximum) - maximum = element; - counter++; - } + return Iterate (source, float.MinValue, (a, b) => Math.Max (a, b)); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + public static decimal Max (this IEnumerable source) + { + Check.Source (source); + + return Iterate (source, decimal.MinValue, (a, b) => Math.Max (a, b)); } + public static int? Max (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, (a, b) => Math.Max (a, b)); + } public static long? Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - long? maximum = long.MinValue; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); - } + Check.Source (source); + return IterateNullable (source, (a, b) => Math.Max (a, b)); + } - public static double Max (this IEnumerable source) + public static double? Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - double maximum = double.MinValue; - int counter = 0; - foreach (double element in source) { - if (element > maximum) - maximum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return IterateNullable (source, (a, b) => Math.Max (a, b)); } - - public static double? Max (this IEnumerable source) + public static float? Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - double? maximum = double.MinValue; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + Check.Source (source); + + return IterateNullable (source, (a, b) => Math.Max (a, b)); } + public static decimal? Max (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, (a, b) => Math.Max (a, b)); + } - public static decimal Max (this IEnumerable source) + static T? IterateNullable (IEnumerable source, Func selector) where T : struct { - if (source == null) - throw new ArgumentNullException (); + bool empty = true; + T? value = null; + foreach (var element in source) { + if (!element.HasValue) + continue; - decimal maximum = decimal.MinValue; - int counter = 0; - foreach (decimal element in source) { - if (element > maximum) - maximum = element; - counter++; + if (!value.HasValue) + value = element.Value; + else + value = selector (element.Value, value.Value); + + empty = false; } - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; - } + if (empty) + return null; + return value; + } - public static decimal? Max (this IEnumerable source) + static TRet? IterateNullable ( + IEnumerable source, + Func source_selector, + Func selector) where TRet : struct { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - decimal? maximum = decimal.MinValue; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } + bool empty = true; + TRet? value = null; + foreach (var element in source) { + TRet? item = source_selector (element); + + if (!value.HasValue) + value = item; + else if (selector (item, value)) + value = item; + + empty = false; } - return (onlyNull ? null : maximum); - } + if (empty) + return null; + + return value; + } public static TSource Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); bool notAssigned = true; TSource maximum = default (TSource); @@ -1222,180 +1161,93 @@ namespace System.Linq return maximum; } - - public static int Max (this IEnumerable source, - Func selector) + public static int Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int maximum = int.MinValue; - int counter = 0; - foreach (TSource item in source) { - int element = selector (item); - if (element > maximum) - maximum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, int.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static int? Max (this IEnumerable source, - Func selector) + public static long Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - int? maximum = int.MinValue; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return Iterate (source, long.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static long Max (this IEnumerable source, - Func selector) + public static double Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - long maximum = long.MinValue; - int counter = 0; - foreach (TSource item in source) { - long element = selector (item); - if (element > maximum) - maximum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, double.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static long? Max (this IEnumerable source, - Func selector) + public static float Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long? maximum = long.MinValue; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return Iterate (source, float.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static double Max (this IEnumerable source, - Func selector) + public static decimal Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - double maximum = double.MinValue; - int counter = 0; - foreach (TSource item in source) { - double element = selector (item); - if (element > maximum) - maximum = element; - counter++; + return Iterate (source, decimal.MinValue, (a, b) => Math.Max (selector (a), b)); + } + + static U Iterate (IEnumerable source, U initValue, Func selector) + { + bool empty = true; + foreach (var element in source) { + initValue = selector (element, initValue); + empty = false; } - if (counter == 0) + if (empty) throw new InvalidOperationException (); - else - return maximum; - } + return initValue; + } - public static double? Max (this IEnumerable source, - Func selector) + public static int? Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - double? maximum = double.MinValue; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return IterateNullable (source, selector, (a, b) => a > b); } - - public static decimal Max (this IEnumerable source, - Func selector) + public static long? Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - decimal maximum = decimal.MinValue; - int counter = 0; - foreach (TSource item in source) { - decimal element = selector (item); - if (element > maximum) - maximum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return IterateNullable (source, selector, (a, b) => a > b); } + public static double? Max (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, selector, (a, b) => a > b); + } - public static decimal? Max (this IEnumerable source, - Func selector) + public static float? Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal? maximum = decimal.MinValue; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return IterateNullable (source, selector, (a, b) => a > b); } + public static decimal? Max (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, selector, (a, b) => a > b); + } - public static TResult Max (this IEnumerable source, - Func selector) + public static TResult Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); bool notAssigned = true; TResult maximum = default (TResult); @@ -1432,159 +1284,77 @@ namespace System.Linq public static int Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - int minimum = int.MaxValue; - int counter = 0; - foreach (int element in source) { - if (element < minimum) - minimum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, int.MaxValue, (a, b) => Math.Min (a, b)); } - - public static int? Min (this IEnumerable source) + public static long Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - int? minimum = int.MaxValue; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + Check.Source (source); + + return Iterate (source, long.MaxValue, (a, b) => Math.Min (a, b)); } - public static long Min (this IEnumerable source) + public static double Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - long minimum = long.MaxValue; - int counter = 0; - foreach (long element in source) { - if (element < minimum) - minimum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, double.MaxValue, (a, b) => Math.Min (a, b)); } - - public static long? Min (this IEnumerable source) + public static float Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - long? minimum = long.MaxValue; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); - } + Check.Source (source); + return Iterate (source, float.MaxValue, (a, b) => Math.Min (a, b)); + } - public static double Min (this IEnumerable source) + public static decimal Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - double minimum = double.MaxValue; - int counter = 0; - foreach (double element in source) { - if (element < minimum) - minimum = element; - counter++; - } + return Iterate (source, decimal.MaxValue, (a, b) => Math.Min (a, b)); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + public static int? Min (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, (a, b) => Math.Min (a, b)); } + public static long? Min (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, (a, b) => Math.Min (a, b)); + } public static double? Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - double? minimum = double.MaxValue; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); - } + Check.Source (source); + return IterateNullable (source, (a, b) => Math.Min (a, b)); + } - public static decimal Min (this IEnumerable source) + public static float? Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - decimal minimum = decimal.MaxValue; - int counter = 0; - foreach (decimal element in source) { - if (element < minimum) - minimum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return IterateNullable (source, (a, b) => Math.Min (a, b)); } - public static decimal? Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - decimal? minimum = decimal.MaxValue; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); - } + Check.Source (source); + return IterateNullable (source, (a, b) => Math.Min (a, b)); + } public static TSource Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); bool notAssigned = true; TSource minimum = default (TSource); @@ -1614,180 +1384,79 @@ namespace System.Linq return minimum; } - - public static int Min (this IEnumerable source, - Func selector) + public static int Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - int minimum = int.MaxValue; - int counter = 0; - foreach (TSource item in source) { - int element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, int.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static int? Min (this IEnumerable source, - Func selector) + public static long Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - int? minimum = int.MaxValue; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return Iterate (source, long.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static long Min (this IEnumerable source, - Func selector) + public static double Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long minimum = long.MaxValue; - int counter = 0; - foreach (TSource item in source) { - long element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, double.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static long? Min (this IEnumerable source, - Func selector) + public static float Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long? minimum = long.MaxValue; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return Iterate (source, float.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static double Min (this IEnumerable source, - Func selector) + public static decimal Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - double minimum = double.MaxValue; - int counter = 0; - foreach (TSource item in source) { - double element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, decimal.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static double? Min (this IEnumerable source, - Func selector) + public static int? Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - double? minimum = double.MaxValue; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return IterateNullable (source, selector, (a, b) => a < b); } - - public static decimal Min (this IEnumerable source, - Func selector) + public static long? Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - decimal minimum = decimal.MaxValue; - int counter = 0; - foreach (TSource item in source) { - decimal element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return IterateNullable (source, selector, (a, b) => a < b); } + public static float? Min (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); - public static decimal? Min (this IEnumerable source, - Func selector) + return IterateNullable (source, selector, (a, b) => a < b); + } + + public static double? Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal? minimum = decimal.MaxValue; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return IterateNullable (source, selector, (a, b) => a < b); } + public static decimal? Min (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, selector, (a, b) => a < b); + } - public static TResult Min (this IEnumerable source, - Func selector) + public static TResult Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); bool notAssigned = true; TResult minimum = default (TResult); @@ -1824,9 +1493,13 @@ namespace System.Linq public static IEnumerable OfType (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + return CreateOfTypeIterator (source); + } + + static IEnumerable CreateOfTypeIterator (IEnumerable source) + { foreach (object element in source) if (element is TResult) yield return (TResult) element; @@ -1842,16 +1515,13 @@ namespace System.Linq return OrderBy (source, keySelector, null); } - public static IOrderedEnumerable OrderBy (this IEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); - return new InternalOrderedSequence ( - source, keySelector, comparer, false); + return new OrderedSequence (source, keySelector, comparer, SortDirection.Ascending); } #endregion @@ -1864,15 +1534,12 @@ namespace System.Linq return OrderByDescending (source, keySelector, null); } - public static IOrderedEnumerable OrderByDescending (this IEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); - return new InternalOrderedSequence ( - source, keySelector, comparer, true); + return new OrderedSequence (source, keySelector, comparer, SortDirection.Descending); } #endregion @@ -1881,10 +1548,20 @@ namespace System.Linq public static IEnumerable Range (int start, int count) { - if (count < 0 || (start + count - 1) > int.MaxValue) + if (count < 0) + throw new ArgumentOutOfRangeException ("count"); + + long upto = ((long) start + count) - 1; + + if (upto > int.MaxValue) throw new ArgumentOutOfRangeException (); - for (int i = start; i < (start + count - 1); i++) + return CreateRangeIterator (start, (int) upto); + } + + static IEnumerable CreateRangeIterator (int start, int upto) + { + for (int i = start; i <= upto; i++) yield return i; } @@ -1897,46 +1574,62 @@ namespace System.Linq if (count < 0) throw new ArgumentOutOfRangeException (); + return CreateRepeatIterator (element, count); + } + + static IEnumerable CreateRepeatIterator (TResult element, int count) + { for (int i = 0; i < count; i++) yield return element; } #endregion - #region Reverse public static IEnumerable Reverse (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + + var list = source as IList; + if (list == null) + list = new List (source); + + return CreateReverseIterator (list); + } - List list = new List (source); - list.Reverse (); - return list; + static IEnumerable CreateReverseIterator (IList source) + { + for (int i = source.Count; i > 0; --i) + yield return source [i - 1]; } #endregion #region Select - public static IEnumerable Select (this IEnumerable source, - Func selector) + public static IEnumerable Select (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - foreach (TSource element in source) - yield return selector (element); + return CreateSelectIterator (source, selector); } + static IEnumerable CreateSelectIterator (IEnumerable source, Func selector) + { + foreach (var element in source) + yield return selector (element); + } - public static IEnumerable Select (this IEnumerable source, - Func selector) + public static IEnumerable Select (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); + return CreateSelectIterator (source, selector); + } + + static IEnumerable CreateSelectIterator (IEnumerable source, Func selector) + { int counter = 0; foreach (TSource element in source) { yield return selector (element, counter); @@ -1948,27 +1641,32 @@ namespace System.Linq #region SelectMany - public static IEnumerable SelectMany (this IEnumerable source, - Func> selector) + public static IEnumerable SelectMany (this IEnumerable source, Func> selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); + return CreateSelectManyIterator (source, selector); + } + + static IEnumerable CreateSelectManyIterator (IEnumerable source, Func> selector) + { foreach (TSource element in source) foreach (TResult item in selector (element)) yield return item; } - - public static IEnumerable SelectMany (this IEnumerable source, - Func> selector) + public static IEnumerable SelectMany (this IEnumerable source, Func> selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); + return CreateSelectManyIterator (source, selector); + } + + static IEnumerable CreateSelectManyIterator (IEnumerable source, Func> selector) + { int counter = 0; foreach (TSource element in source) { - foreach (TResult item in selector (element, counter++)) + foreach (TResult item in selector (element, counter)) yield return item; counter++; } @@ -1977,20 +1675,30 @@ namespace System.Linq public static IEnumerable SelectMany (this IEnumerable source, Func> collectionSelector, Func selector) { - if (source == null || collectionSelector == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndCollectionSelectors (source, collectionSelector, selector); + + return CreateSelectManyIterator (source, collectionSelector, selector); + } + static IEnumerable CreateSelectManyIterator (IEnumerable source, + Func> collectionSelector, Func selector) + { foreach (TSource element in source) foreach (TCollection collection in collectionSelector (element)) yield return selector (element, collection); } public static IEnumerable SelectMany (this IEnumerable source, - Func> collectionSelector, Func selector) + Func> collectionSelector, Func selector) { - if (source == null || collectionSelector == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndCollectionSelectors (source, collectionSelector, selector); + + return CreateSelectManyIterator (source, collectionSelector, selector); + } + static IEnumerable CreateSelectManyIterator (IEnumerable source, + Func> collectionSelector, Func selector) + { int counter = 0; foreach (TSource element in source) foreach (TCollection collection in collectionSelector (element, counter++)) @@ -2001,46 +1709,40 @@ namespace System.Linq #region Single - public static TSource Single (this IEnumerable source) + static TSource Single (this IEnumerable source, Func predicate, Fallback fallback) { - if (source == null) - throw new ArgumentNullException (); + var found = false; + var item = default (TSource); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; + foreach (var element in source) { + if (!predicate (element)) + continue; + + if (found) + throw new InvalidOperationException (); + + found = true; + item = element; } - if (otherElement) - return singleElement; - else + if (!found && fallback == Fallback.Throw) throw new InvalidOperationException (); - } + return item; + } - public static TSource Single (this IEnumerable source, - Func predicate) + public static TSource Single (this IEnumerable source) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.Source (source); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - } + return source.Single (PredicateOf.Always, Fallback.Throw); + } - if (otherElement) - return singleElement; - else - throw new InvalidOperationException (); + public static TSource Single (this IEnumerable source, Func predicate) + { + Check.SourceAndPredicate (source, predicate); + + return source.Single (predicate, Fallback.Throw); } #endregion @@ -2049,67 +1751,53 @@ namespace System.Linq public static TSource SingleOrDefault (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } + Check.Source (source); - return singleElement; + return source.Single (PredicateOf.Always, Fallback.Default); } - - public static TSource SingleOrDefault (this IEnumerable source, - Func predicate) + public static TSource SingleOrDefault (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); - - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - } + Check.SourceAndPredicate (source, predicate); - return singleElement; + return source.Single (predicate, Fallback.Default); } #endregion #region Skip + public static IEnumerable Skip (this IEnumerable source, int count) { - if (source == null) - throw new NotSupportedException (); + Check.Source (source); + return CreateSkipIterator (source, count); + } + + static IEnumerable CreateSkipIterator (IEnumerable source, int count) + { int i = 0; - foreach (TSource e in source) { - if (++i < count) + foreach (var element in source) { + if (i++ < count) continue; - yield return e; + + yield return element; } } + #endregion #region SkipWhile - - public static IEnumerable SkipWhile ( - this IEnumerable source, - Func predicate) + public static IEnumerable SkipWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + + return CreateSkipWhileIterator (source, predicate); + } + static IEnumerable CreateSkipWhileIterator (IEnumerable source, Func predicate) + { bool yield = false; foreach (TSource element in source) { @@ -2123,13 +1811,15 @@ namespace System.Linq } } - - public static IEnumerable SkipWhile (this IEnumerable source, - Func predicate) + public static IEnumerable SkipWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + + return CreateSkipWhileIterator (source, predicate); + } + static IEnumerable CreateSkipWhileIterator (IEnumerable source, Func predicate) + { int counter = 0; bool yield = false; @@ -2151,245 +1841,203 @@ namespace System.Linq public static int Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException ("source"); - - int sum = 0; - foreach (int element in source) - sum += element; + Check.Source (source); - return sum; + return Sum (source, (a, b) => checked (a + b)); } - - public static int Sum (this IEnumerable source, Func selector) + public static int? Sum (this IEnumerable source) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int sum = 0; - foreach (TSource element in source) - sum += selector (element); + Check.Source (source); - return sum; + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } - - public static int? Sum (this IEnumerable source) + public static int Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); - - int? sum = 0; - foreach (int? element in source) - if (element.HasValue) - sum += element.Value; + Check.SourceAndSelector (source, selector); - return sum; + return Sum (source, (a, b) => checked (a + selector (b))); } - public static int? Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int? sum = 0; - foreach (TSource element in source) { - int? item = selector (element); - if (item.HasValue) - sum += item.Value; - } + Check.SourceAndSelector (source, selector); - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? checked (a + value.Value) : a; + }); } - public static long Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - long sum = 0; - foreach (long element in source) - sum += element; - - return sum; + return Sum (source, (a, b) => checked (a + b)); } - - public static long Sum (this IEnumerable source, Func selector) + public static long? Sum (this IEnumerable source) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.Source (source); - long sum = 0; - foreach (TSource element in source) - sum += selector (element); - - return sum; + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); } - - public static long? Sum (this IEnumerable source) + public static long Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long? sum = 0; - foreach (long? element in source) - if (element.HasValue) - sum += element.Value; - - return sum; + return Sum (source, (a, b) => checked (a + selector (b))); } - public static long? Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long? sum = 0; - foreach (TSource element in source) { - long? item = selector (element); - if (item.HasValue) - sum += item.Value; - } - - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? checked (a + value.Value) : a; + }); } - public static double Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - double sum = 0; - foreach (double element in source) - sum += element; + Check.Source (source); - return sum; + return Sum (source, (a, b) => checked (a + b)); } + public static double? Sum (this IEnumerable source) + { + Check.Source (source); + + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); + } public static double Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - double sum = 0; - foreach (TSource element in source) - sum += selector (element); + Check.SourceAndSelector (source, selector); - return sum; + return Sum (source, (a, b) => checked (a + selector (b))); } - - public static double? Sum (this IEnumerable source) + public static double? Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - double? sum = 0; - foreach (double? element in source) - if (element.HasValue) - sum += element.Value; - - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? checked (a + value.Value) : a; + }); } + public static float Sum (this IEnumerable source) + { + Check.Source (source); - public static double? Sum (this IEnumerable source, Func selector) + return Sum (source, (a, b) => checked (a + b)); + } + + public static float? Sum (this IEnumerable source) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.Source (source); - double? sum = 0; - foreach (TSource element in source) { - double? item = selector (element); - if (item.HasValue) - sum += item.Value; - } + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); + } - return sum; + public static float Sum (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return Sum (source, (a, b) => checked (a + selector (b))); } + public static float? Sum (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? checked (a + value.Value) : a; + }); + } public static decimal Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - decimal sum = 0; - foreach (decimal element in source) - sum += element; - - return sum; + return Sum (source, (a, b) => checked (a + b)); } + public static decimal? Sum (this IEnumerable source) + { + Check.Source (source); + + return source.SumNullable (0, (total, element) => element.HasValue ? checked (total + element) : total); + } public static decimal Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - decimal sum = 0; - foreach (TSource element in source) - sum += selector (element); + Check.SourceAndSelector (source, selector); - return sum; + return Sum (source, (a, b) => checked (a + selector (b))); } - - public static decimal? Sum (this IEnumerable source) + public static decimal? Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); - - decimal? sum = 0; - foreach (decimal? element in source) - if (element.HasValue) - sum += element.Value; + Check.SourceAndSelector (source, selector); - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? checked (a + value.Value) : a; + }); } - - public static decimal? Sum (this IEnumerable source, Func selector) + static TR Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + TR total = default (TR); + long counter = 0; + foreach (var element in source) { + total = selector (total, element); + ++counter; + } - decimal? sum = 0; - foreach (TSource element in source) { - decimal? item = selector (element); - if (item.HasValue) - sum += item.Value; + return total; + } + + static TR SumNullable (this IEnumerable source, TR zero, Func selector) + { + TR total = zero; + foreach (var element in source) { + total = selector (total, element); } - return sum; + return total; } #endregion + #region Take public static IEnumerable Take (this IEnumerable source, int count) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + return CreateTakeIterator (source, count); + } + + static IEnumerable CreateTakeIterator (IEnumerable source, int count) + { if (count <= 0) yield break; - else { - int counter = 0; - foreach (TSource element in source) { - yield return element; - counter++; - if (counter == count) - yield break; - } + + int counter = 0; + foreach (TSource element in source) { + yield return element; + + if (++counter == count) + yield break; } } @@ -2399,28 +2047,36 @@ namespace System.Linq public static IEnumerable TakeWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) { - if (predicate (element)) - yield return element; - else + return CreateTakeWhileIterator (source, predicate); + } + + static IEnumerable CreateTakeWhileIterator (IEnumerable source, Func predicate) + { + foreach (var element in source) { + if (!predicate (element)) yield break; + + yield return element; } } public static IEnumerable TakeWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + + return CreateTakeWhileIterator (source, predicate); + } + static IEnumerable CreateTakeWhileIterator (IEnumerable source, Func predicate) + { int counter = 0; - foreach (TSource element in source) { - if (predicate (element, counter)) - yield return element; - else + foreach (var element in source) { + if (!predicate (element, counter)) yield break; + + yield return element; counter++; } } @@ -2434,12 +2090,10 @@ namespace System.Linq return ThenBy (source, keySelector, null); } - public static IOrderedEnumerable ThenBy (this IOrderedEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); return source.CreateOrderedEnumerable (keySelector, comparer, false); } @@ -2454,12 +2108,10 @@ namespace System.Linq return ThenByDescending (source, keySelector, null); } - public static IOrderedEnumerable ThenByDescending (this IOrderedEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); return source.CreateOrderedEnumerable (keySelector, comparer, true); } @@ -2467,47 +2119,63 @@ namespace System.Linq #endregion #region ToArray + public static TSource [] ToArray (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + + var collection = source as ICollection; + if (collection != null) { + var array = new TSource [collection.Count]; + collection.CopyTo (array, 0); + return array; + } - List list = new List (source); - return list.ToArray (); + return new List (source).ToArray (); } #endregion #region ToDictionary - public static Dictionary ToDictionary (this IEnumerable source, Func keySelector, Func elementSelector) + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector, Func elementSelector) { return ToDictionary (source, keySelector, elementSelector, null); } - - public static Dictionary ToDictionary (this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector, Func elementSelector, IEqualityComparer comparer) { - if (source == null) - throw new ArgumentNullException ("source"); - if (keySelector == null) - throw new ArgumentNullException ("keySelector"); - if (elementSelector == null) - throw new ArgumentNullException ("elementSelector"); - - Dictionary dict = new Dictionary (comparer); - foreach (TSource e in source) { + Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); + + if (comparer == null) + comparer = EqualityComparer.Default; + + var dict = new Dictionary (comparer); + foreach (var e in source) dict.Add (keySelector (e), elementSelector (e)); - } return dict; } + + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector) + { + return ToDictionary (source, keySelector, null); + } + + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector, IEqualityComparer comparer) + { + return ToDictionary (source, keySelector, Function.Identity, comparer); + } + #endregion #region ToList public static List ToList (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException ("source"); + Check.Source (source); return new List (source); } @@ -2515,84 +2183,110 @@ namespace System.Linq #region ToLookup - public static Lookup ToLookup (this IEnumerable source, Func keySelector) + public static ILookup ToLookup (this IEnumerable source, Func keySelector) { - return ToLookup (source, keySelector, null); + return ToLookup (source, keySelector, Function.Identity, null); } - - public static Lookup ToLookup (this IEnumerable source, + public static ILookup ToLookup (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); - - Dictionary> dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); - foreach (TSource element in source) { - TKey key = keySelector (element); - if (key == null) - throw new ArgumentNullException (); - if (!dictionary.ContainsKey (key)) - dictionary.Add (key, new List ()); - dictionary [key].Add (element); - } - return new Lookup (dictionary); + return ToLookup (source, keySelector, element => element, comparer); } - - public static Lookup ToLookup (this IEnumerable source, + public static ILookup ToLookup (this IEnumerable source, Func keySelector, Func elementSelector) { return ToLookup (source, keySelector, elementSelector, null); } - - public static Lookup ToLookup (this IEnumerable source, + public static ILookup ToLookup (this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) { - if (source == null || keySelector == null || elementSelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); - Dictionary> dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); - foreach (TSource element in source) { - TKey key = keySelector (element); + var dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); + foreach (var element in source) { + var key = keySelector (element); if (key == null) - throw new ArgumentNullException (); - if (!dictionary.ContainsKey (key)) - dictionary.Add (key, new List ()); - dictionary [key].Add (elementSelector (element)); + throw new ArgumentNullException ("key"); + + List list; + if (!dictionary.TryGetValue (key, out list)) { + list = new List (); + dictionary.Add (key, list); + } + + list.Add (elementSelector (element)); } + return new Lookup (dictionary); } #endregion - #region ToSequence + #region SequenceEqual + + public static bool SequenceEqual (this IEnumerable first, IEnumerable second) + { + return first.SequenceEqual (second, null); + } - public static IEnumerable ToSequence (this IEnumerable source) + public static bool SequenceEqual (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { - return (IEnumerable) source; + Check.FirstAndSecond (first, second); + + if (comparer == null) + comparer = EqualityComparer.Default; + + using (IEnumerator first_enumerator = first.GetEnumerator (), + second_enumerator = second.GetEnumerator ()) { + + while (first_enumerator.MoveNext ()) { + if (!second_enumerator.MoveNext ()) + return false; + + if (!comparer.Equals (first_enumerator.Current, second_enumerator.Current)) + return false; + } + + return !second_enumerator.MoveNext (); + } } #endregion #region Union - public static IEnumerable Union (this IEnumerable first, IEnumerable second) { - if (first == null || second == null) - throw new ArgumentNullException (); + Check.FirstAndSecond (first, second); - List items = new List (); - foreach (TSource element in first) { - if (IndexOf (items, element) == -1) { + return first.Union (second, null); + } + + public static IEnumerable Union (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Check.FirstAndSecond (first, second); + + if (comparer == null) + comparer = EqualityComparer.Default; + + return CreateUnionIterator (first, second, comparer); + } + + static IEnumerable CreateUnionIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + var items = new HashSet (comparer); + foreach (var element in first) { + if (! items.Contains (element)) { items.Add (element); yield return element; } } - foreach (TSource element in second) { - if (IndexOf (items, element) == -1) { + + foreach (var element in second) { + if (! items.Contains (element, comparer)) { items.Add (element); yield return element; } @@ -2603,24 +2297,29 @@ namespace System.Linq #region Where - public static IEnumerable Where (this IEnumerable source, - Func predicate) + public static IEnumerable Where (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + return CreateWhereIterator (source, predicate); + } + + static IEnumerable CreateWhereIterator (IEnumerable source, Func predicate) + { foreach (TSource element in source) if (predicate (element)) yield return element; } - - public static IEnumerable Where (this IEnumerable source, - Func predicate) + public static IEnumerable Where (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + return CreateWhereIterator (source, predicate); + } + + static IEnumerable CreateWhereIterator (this IEnumerable source, Func predicate) + { int counter = 0; foreach (TSource element in source) { if (predicate (element, counter)) @@ -2631,61 +2330,20 @@ namespace System.Linq #endregion - // These methods are not included in the - // .NET Standard Query Operators Specification, - // but they provide additional useful commands - - #region Compare - - private static bool Equals (T first, T second) - { - // Mostly, values in Enumerable - // sequences need to be compared using - // Equals and GetHashCode - - if (first == null || second == null) - return (first == null && second == null); - else - return ((first.Equals (second) || - first.GetHashCode () == second.GetHashCode ())); - } - - #endregion - - #region IndexOf - - static int IndexOf (this IEnumerable source, T item, IEqualityComparer comparer) - { - if (comparer == null) - comparer = EqualityComparer.Default; - - int counter = 0; - foreach (T element in source) { - if (comparer.Equals (element, item)) - return counter; - counter++; - } - // The item was not found - return -1; - } - - static int IndexOf (this IEnumerable source, T item) - { - return IndexOf (source, item, null); + class ReadOnlyCollectionOf { + public static readonly ReadOnlyCollection Empty = new ReadOnlyCollection (new T [0]); } - #endregion - #region ToReadOnlyCollection - internal static ReadOnlyCollection ToReadOnlyCollection (IEnumerable source) + internal static ReadOnlyCollection ToReadOnlyCollection (this IEnumerable source) { if (source == null) - return new ReadOnlyCollection (new List ()); + return ReadOnlyCollectionOf.Empty; - if (typeof (ReadOnlyCollection).IsInstanceOfType (source)) - return source as ReadOnlyCollection; + var ro = source as ReadOnlyCollection; + if (ro != null) + return ro; - return new ReadOnlyCollection (ToArray (source)); + return new ReadOnlyCollection (source.ToArray ()); } - #endregion } }