X-Git-Url: http://wien.tomnetworks.com/gitweb/?a=blobdiff_plain;f=mcs%2Fclass%2FSystem.Core%2FSystem.Linq%2FEnumerable.cs;h=4d46b21bc0a431d5927645d9739417f6694e4e10;hb=935d94af6ab8f7262e84def31e47a66e25c696a9;hp=f24e56c4b7aaaaa6c126339d3fc2f4b40088ca98;hpb=fdcf281f9b18ca409ed2ed64e22cd7c4e0ffa03c;p=mono.git diff --git a/mcs/class/System.Core/System.Linq/Enumerable.cs b/mcs/class/System.Core/System.Linq/Enumerable.cs index f24e56c4b7a..4d46b21bc0a 100644 --- a/mcs/class/System.Core/System.Linq/Enumerable.cs +++ b/mcs/class/System.Core/System.Linq/Enumerable.cs @@ -1,3 +1,14 @@ +// +// Enumerable.cs +// +// Authors: +// Marek Safar (marek.safar@gmail.com) +// Antonello Provenzano +// Alejandro Serrano "Serras" (trupill@yahoo.es) +// Jb Evain (jbevain@novell.com) +// +// Copyright (C) 2007 Novell, Inc (http://www.novell.com) +// // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the // "Software"), to deal in the Software without restriction, including @@ -5,22 +16,20 @@ // distribute, sublicense, and/or sell copies of the Software, and to // permit persons to whom the Software is furnished to do so, subject to // the following conditions: -// +// // The above copyright notice and this permission notice shall be // included in all copies or substantial portions of the Software. -// +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. // -// Authors: -// Marek Safar (marek.safar@gmail.com) -// Antonello Provenzano -// Alejandro Serrano "Serras" (trupill@yahoo.es) -// + +// precious: http://www.hookedonlinq.com using System; using System.Collections; @@ -31,17 +40,24 @@ namespace System.Linq { public static class Enumerable { + enum Fallback { + Default, + Throw + } + + class PredicateOf { + public static readonly Func Always = (t) => true; + } + #region Aggregate + public static TSource Aggregate (this IEnumerable source, Func func) { - if (source == null) - throw new ArgumentNullException ("source"); - if (func == null) - throw new ArgumentNullException ("func"); + Check.SourceAndFunc (source, func); - // custom foreach so that we can efficiently throw an exception + // custom foreach so that we can efficiently throw an exception // if zero elements and treat the first element differently - using (IEnumerator enumerator = source.GetEnumerator ()) { + using (var enumerator = source.GetEnumerator ()) { if (!enumerator.MoveNext ()) throw new InvalidOperationException ("No elements in source list"); @@ -52,406 +68,287 @@ namespace System.Linq } } - public static TAccumulate Aggregate (this IEnumerable source, TAccumulate seed, Func func) { - if (source == null || func == null) - throw new ArgumentNullException (); + Check.SourceAndFunc (source, func); TAccumulate folded = seed; foreach (TSource element in source) folded = func (folded, element); + return folded; } - public static TResult Aggregate (this IEnumerable source, TAccumulate seed, Func func, Func resultSelector) { - if (source == null) - throw new ArgumentNullException ("source"); - if (func == null) - throw new ArgumentNullException ("func"); + Check.SourceAndFunc (source, func); if (resultSelector == null) throw new ArgumentNullException ("resultSelector"); - TAccumulate result = seed; - foreach (TSource e in source) + var result = seed; + foreach (var e in source) result = func (result, e); + return resultSelector (result); } + #endregion #region All + public static bool All (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) + foreach (var element in source) if (!predicate (element)) return false; + return true; } + #endregion #region Any + public static bool Any (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (TSource element in source) - return true; - return false; + using (var enumerator = source.GetEnumerator ()) + return enumerator.MoveNext (); } - public static bool Any (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); foreach (TSource element in source) if (predicate (element)) return true; + return false; } + #endregion #region AsEnumerable + public static IEnumerable AsEnumerable (this IEnumerable source) { return source; } + #endregion #region Average + public static double Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - long sum = 0; - long counter = 0; - foreach (int element in source) { - sum += element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); } + public static double Average (this IEnumerable source) + { + return Average (source, (a, b) => a + b, (a, b) => (double) a / (double) b); + } - public static double? Average (this IEnumerable source) + public static double Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + return Average (source, (a, b) => a + b, (a, b) => a / b); + } - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + public static float Average (this IEnumerable source) + { + return Average (source, (a, b) => a + b, (a, b) => (float) a / (float) b); } + public static decimal Average (this IEnumerable source) + { + return Average (source, (a, b) => a + b, (a, b) => a / b); + } - public static double Average (this IEnumerable source) + static TResult Average (this IEnumerable source, + Func func, Func result) + where TElement : struct + where TAggregate : struct + where TResult : struct { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - long sum = 0; + var total = default (TAggregate); long counter = 0; - foreach (long element in source) { - sum += element; - counter++; + foreach (var element in source) { + total = func (total, element); + ++counter; } if (counter == 0) throw new InvalidOperationException (); - else - return (double) sum / (double) counter; - } + return result (total, counter); + } - public static double? Average (this IEnumerable source) + static TResult? AverageNullable (this IEnumerable source, + Func func, Func result) + where TElement : struct + where TAggregate : struct + where TResult : struct { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - bool onlyNull = true; - long sum = 0; + var total = default (TAggregate); long counter = 0; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); - } - - - public static double Average (this IEnumerable source) - { - if (source == null) - throw new ArgumentNullException (); + foreach (var element in source) { + if (!element.HasValue) + continue; - double sum = 0; - double counter = 0; - foreach (double element in source) { - sum += element; + total = func (total, element.Value); counter++; } if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; - } + return null; + return new TResult? (result (total, counter)); + } - public static double? Average (this IEnumerable source) + public static double? Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - double sum = 0; - double counter = 0; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) (sum / counter)); - } + Check.Source (source); + return source.AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); + } - public static decimal Average (this IEnumerable source) + public static double? Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - decimal sum = 0; - decimal counter = 0; - foreach (decimal element in source) { - sum += element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); } + public static double? Average (this IEnumerable source) + { + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); + } public static decimal? Average (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - decimal sum = 0; - decimal counter = 0; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (decimal?) (sum / counter)); + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => a / b); } + public static float? Average (this IEnumerable source) + { + Check.Source (source); + + return source.AverageNullable ((a, b) => a + b, (a, b) => (float) a / (float) b); + } public static double Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return (double) sum / (double) counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long sum = 0; - long counter = 0; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) sum / (double?) counter); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (double) a / (double) b); } - public static double Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - double sum = 0; - double counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + return source.Select (selector).Average ((a, b) => a + b, (a, b) => a / b); } - public static double? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - double sum = 0; - double counter = 0; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (double?) (sum / counter)); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => a / b); } - - public static decimal Average (this IEnumerable source, Func selector) + public static float Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - decimal sum = 0; - decimal counter = 0; - foreach (TSource item in source) { - sum += selector (item); - counter++; - } + return source.Select (selector).Average ((a, b) => a + b, (a, b) => (float) a / (float) b); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return sum / counter; + public static float? Average (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => (float) a / (float) b); } + public static decimal Average (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return source.Select (selector).Average ((a, b) => a + b, (a, b) => a / b); + } public static decimal? Average (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal sum = 0; - decimal counter = 0; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - sum += element.Value; - counter++; - } - } - return (onlyNull ? null : (decimal?) (sum / counter)); + return source.Select (selector).AverageNullable ((a, b) => a + b, (a, b) => a / b); } + #endregion #region Cast + public static IEnumerable Cast (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (object element in source) - yield return (TResult) element; + return CreateCastIterator (source); + } + + static IEnumerable CreateCastIterator (IEnumerable source) + { + foreach (TResult element in source) + yield return element; } + #endregion #region Concat + public static IEnumerable Concat (this IEnumerable first, IEnumerable second) { - if (first == null || second == null) - throw new ArgumentNullException (); + Check.FirstAndSecond (first, second); + + return CreateConcatIterator (first, second); + } + static IEnumerable CreateConcatIterator (IEnumerable first, IEnumerable second) + { foreach (TSource element in first) yield return element; foreach (TSource element in second) @@ -464,94 +361,83 @@ namespace System.Linq public static bool Contains (this IEnumerable source, TSource value) { - if (source is ICollection) { - ICollection collection = (ICollection) source; + var collection = source as ICollection; + if (collection != null) return collection.Contains (value); - } return Contains (source, value, null); } - public static bool Contains (this IEnumerable source, TSource value, IEqualityComparer comparer) { - if (source == null) - throw new ArgumentNullException ("source"); + Check.Source (source); if (comparer == null) comparer = EqualityComparer.Default; - - foreach (TSource e in source) { - if (comparer.Equals (e, value)) + foreach (var element in source) + if (comparer.Equals (element, value)) return true; - } return false; } #endregion #region Count + public static int Count (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + + var collection = source as ICollection; + if (collection != null) + return collection.Count; - if (source is ICollection) - return ((ICollection) source).Count; - else { - int counter = 0; - foreach (TSource element in source) + int counter = 0; + using (var enumerator = source.GetEnumerator ()) + while (enumerator.MoveNext ()) counter++; - return counter; - } - } + return counter; + } public static int Count (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); int counter = 0; - foreach (TSource element in source) + foreach (var element in source) if (selector (element)) counter++; return counter; } + #endregion #region DefaultIfEmpty public static IEnumerable DefaultIfEmpty (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool noYield = true; - foreach (TSource item in source) { - noYield = false; - yield return item; - } - - if (noYield) - yield return default (TSource); + return DefaultIfEmpty (source, default (TSource)); } - public static IEnumerable DefaultIfEmpty (this IEnumerable source, TSource defaultValue) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + + return CreateDefaultIfEmptyIterator (source, defaultValue); + } - bool noYield = true; + static IEnumerable CreateDefaultIfEmptyIterator (IEnumerable source, TSource defaultValue) + { + bool empty = true; foreach (TSource item in source) { - noYield = false; + empty = false; yield return item; } - if (noYield) + if (empty) yield return defaultValue; } @@ -566,42 +452,56 @@ namespace System.Linq public static IEnumerable Distinct (this IEnumerable source, IEqualityComparer comparer) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); if (comparer == null) comparer = EqualityComparer.Default; - List items = new List (); - foreach (TSource element in source) { - if (!Contains (items, element, comparer)) { + return CreateDistinctIterator (source, comparer); + } + + static IEnumerable CreateDistinctIterator (IEnumerable source, IEqualityComparer comparer) + { + var items = new HashSet (comparer); + foreach (var element in source) { + if (! items.Contains (element)) { items.Add (element); yield return element; } } } + #endregion #region ElementAt + static TSource ElementAt (this IEnumerable source, int index, Fallback fallback) + { + long counter = 0L; + + foreach (var element in source) { + if (index == counter++) + return element; + } + + if (fallback == Fallback.Throw) + throw new ArgumentOutOfRangeException (); + + return default (TSource); + } + public static TSource ElementAt (this IEnumerable source, int index) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + if (index < 0) throw new ArgumentOutOfRangeException (); - if (source is IList) - return ((IList) source) [index]; - else { - int counter = 0; - foreach (TSource element in source) { - if (counter == index) - return element; - counter++; - } - throw new ArgumentOutOfRangeException (); - } + var list = source as IList; + if (list != null) + return list [index]; + + return source.ElementAt (index, Fallback.Throw); } #endregion @@ -610,34 +510,27 @@ namespace System.Linq public static TSource ElementAtOrDefault (this IEnumerable source, int index) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + if (index < 0) return default (TSource); - if (source is IList) { - if (((IList) source).Count >= index) - return default (TSource); - else - return ((IList) source) [index]; - } else { - int counter = 0; - foreach (TSource element in source) { - if (counter == index) - return element; - counter++; - } - return default (TSource); - } + var list = source as IList; + if (list != null) + return index < list.Count ? list [index] : default (TSource); + + return source.ElementAt (index, Fallback.Default); } #endregion #region Empty + public static IEnumerable Empty () { - return new List (); + return new TResult [0]; } + #endregion #region Except @@ -649,80 +542,82 @@ namespace System.Linq public static IEnumerable Except (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { - if (first == null || second == null) - throw new ArgumentNullException (); + Check.FirstAndSecond (first, second); if (comparer == null) comparer = EqualityComparer.Default; - List items = new List (Distinct (first)); - foreach (TSource element in second) { - int index = IndexOf (items, element, comparer); - if (index == -1) - items.Add (element); - else - items.RemoveAt (index); + return CreateExceptIterator (first, second, comparer); + } + + static IEnumerable CreateExceptIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + var items = new HashSet (second, comparer); + foreach (var element in first) { + if (!items.Contains (element, comparer)) + yield return element; } - foreach (TSource item in items) - yield return item; } #endregion #region First - public static TSource First (this IEnumerable source) + static TSource First (this IEnumerable source, Func predicate, Fallback fallback) { - if (source == null) - throw new ArgumentNullException (); + foreach (var element in source) + if (predicate (element)) + return element; - foreach (TSource element in source) - return element; + if (fallback == Fallback.Throw) + throw new InvalidOperationException (); - throw new InvalidOperationException (); + return default (TSource); } - - public static TSource First (this IEnumerable source, Func predicate) + public static TSource First (this IEnumerable source) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (TSource element in source) { - if (predicate (element)) - return element; + var list = source as IList; + if (list != null) { + if (list.Count != 0) + return list [0]; + + throw new InvalidOperationException (); + } else { + using (var enumerator = source.GetEnumerator ()) { + if (enumerator.MoveNext ()) + return enumerator.Current; + } } throw new InvalidOperationException (); } + public static TSource First (this IEnumerable source, Func predicate) + { + Check.SourceAndPredicate (source, predicate); + + return source.First (predicate, Fallback.Throw); + } + #endregion #region FirstOrDefault public static TSource FirstOrDefault (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - foreach (TSource element in source) - return element; - - return default (TSource); + return source.First (PredicateOf.Always, Fallback.Default); } - public static TSource FirstOrDefault (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) { - if (predicate (element)) - return element; - } - - return default (TSource); + return source.First (predicate, Fallback.Default); } #endregion @@ -740,20 +635,23 @@ namespace System.Linq return null; } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector) { return GroupBy (source, keySelector, null); } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); + + return CreateGroupByIterator (source, keySelector, comparer); + } + static IEnumerable> CreateGroupByIterator (this IEnumerable source, + Func keySelector, IEqualityComparer comparer) + { Dictionary> groups = new Dictionary> (); List nullList = new List (); int counter = 0; @@ -791,19 +689,16 @@ namespace System.Linq } } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, Func elementSelector) { return GroupBy (source, keySelector, elementSelector, null); } - public static IEnumerable> GroupBy (this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) { - if (source == null || keySelector == null || elementSelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); Dictionary> groups = new Dictionary> (); List nullList = new List (); @@ -843,6 +738,43 @@ namespace System.Linq } } + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, Func elementSelector, + Func, TResult> resultSelector) + { + return GroupBy (source, keySelector, elementSelector, resultSelector, null); + } + + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + IEnumerable> groups = GroupBy ( + source, keySelector, elementSelector, comparer); + + foreach (IGrouping group in groups) + yield return resultSelector (group.Key, group); + } + + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, + Func, TResult> resultSelector) + { + return GroupBy (source, keySelector, resultSelector, null); + } + + public static IEnumerable GroupBy (this IEnumerable source, + Func keySelector, + Func, TResult> resultSelector, + IEqualityComparer comparer) + { + IEnumerable> groups = GroupBy (source, keySelector, comparer); + + foreach (IGrouping group in groups) + yield return resultSelector (group.Key, group); + } + #endregion # region GroupJoin @@ -859,14 +791,20 @@ namespace System.Linq Func innerKeySelector, Func, TResult> resultSelector, IEqualityComparer comparer) { - if (outer == null || inner == null || outerKeySelector == null || - innerKeySelector == null || resultSelector == null) - throw new ArgumentNullException (); + Check.JoinSelectors (outer, inner, outerKeySelector, innerKeySelector, resultSelector); if (comparer == null) comparer = EqualityComparer.Default; - Lookup innerKeys = ToLookup (inner, innerKeySelector, comparer); + return CreateGroupJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateGroupJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func, TResult> resultSelector, + IEqualityComparer comparer) + { + ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) { @@ -880,6 +818,8 @@ namespace System.Linq TKey outerKey = outerKeySelector (element); if (innerKeys.Contains (outerKey)) yield return resultSelector (element, innerKeys [outerKey]); + else + yield return resultSelector (element, Empty ()); } } @@ -887,25 +827,27 @@ namespace System.Linq #region Intersect - public static IEnumerable Intersect (this IEnumerable first, IEnumerable second) { - if (first == null || second == null) - throw new ArgumentNullException (); + return Intersect (first, second, null); + } - List items = new List (Distinct (first)); - bool [] marked = new bool [items.Count]; - for (int i = 0; i < marked.Length; i++) - marked [i] = false; + public static IEnumerable Intersect (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Check.FirstAndSecond (first, second); - foreach (TSource element in second) { - int index = IndexOf (items, element); - if (index != -1) - marked [index] = true; - } - for (int i = 0; i < marked.Length; i++) { - if (marked [i]) - yield return items [i]; + if (comparer == null) + comparer = EqualityComparer.Default; + + return CreateIntersectIterator (first, second, comparer); + } + + static IEnumerable CreateIntersectIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + var items = new HashSet (second, comparer); + foreach (TSource element in first) { + if (items.Contains (element)) + yield return element; } } @@ -917,14 +859,19 @@ namespace System.Linq IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) { - if (outer == null || inner == null || outerKeySelector == null || - innerKeySelector == null || resultSelector == null) - throw new ArgumentNullException (); + Check.JoinSelectors (outer, inner, outerKeySelector, innerKeySelector, resultSelector); if (comparer == null) comparer = EqualityComparer.Default; - Lookup innerKeys = ToLookup (inner, innerKeySelector, comparer); + return CreateJoinIterator (outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer); + } + + static IEnumerable CreateJoinIterator (this IEnumerable outer, + IEnumerable inner, Func outerKeySelector, + Func innerKeySelector, Func resultSelector, IEqualityComparer comparer) + { + ILookup innerKeys = ToLookup (inner, innerKeySelector, comparer); /*Dictionary> innerKeys = new Dictionary> (); foreach (U element in inner) { @@ -947,49 +894,47 @@ namespace System.Linq IEnumerable inner, Func outerKeySelector, Func innerKeySelector, Func resultSelector) { - return Join (outer, inner, outerKeySelector, innerKeySelector, resultSelector, null); + return outer.Join (inner, outerKeySelector, innerKeySelector, resultSelector, null); } - # endregion + + #endregion #region Last - public static TSource Last (this IEnumerable source) + static TSource Last (this IEnumerable source, Func predicate, Fallback fallback) { - if (source == null) - throw new ArgumentNullException (); + var empty = true; + var item = default (TSource); - bool noElements = true; - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (noElements) noElements = false; - lastElement = element; + foreach (var element in source) { + if (!predicate (element)) + continue; + + item = element; + empty = false; } - if (!noElements) - return lastElement; - else + if (!empty) + return item; + + if (fallback == Fallback.Throw) throw new InvalidOperationException (); + + return item; } - public static TSource Last (this IEnumerable source, - Func predicate) + public static TSource Last (this IEnumerable source) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.Source (source); - bool noElements = true; - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (noElements) noElements = false; - lastElement = element; - } - } + return source.Last (PredicateOf.Always, Fallback.Throw); + } - if (!noElements) - return lastElement; - else - throw new InvalidOperationException (); + public static TSource Last (this IEnumerable source, Func predicate) + { + Check.SourceAndPredicate (source, predicate); + + return source.Last (predicate, Fallback.Throw); } #endregion @@ -998,50 +943,47 @@ namespace System.Linq public static TSource LastOrDefault (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - TSource lastElement = default (TSource); - foreach (TSource element in source) - lastElement = element; + var list = source as IList; + if (list != null) + return list.Count > 0 ? list [list.Count - 1] : default (TSource); - return lastElement; + return source.Last (PredicateOf.Always, Fallback.Default); } - public static TSource LastOrDefault (this IEnumerable source, - Func predicate) + public static TSource LastOrDefault (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - TSource lastElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) - lastElement = element; - } - - return lastElement; + return source.Last (predicate, Fallback.Default); } #endregion #region LongCount + public static long LongCount (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + +#if !NET_2_1 + var array = source as TSource []; + if (array != null) + return array.LongLength; +#endif long counter = 0; - foreach (TSource element in source) - counter++; + using (var enumerator = source.GetEnumerator ()) + while (enumerator.MoveNext ()) + counter++; + return counter; } - public static long LongCount (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); long counter = 0; foreach (TSource element in source) @@ -1057,160 +999,97 @@ namespace System.Linq public static int Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - int maximum = int.MinValue; - int counter = 0; - foreach (int element in source) { - if (element > maximum) - maximum = element; - counter++; - } - - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, int.MinValue, (a, b) => Math.Max (a, b)); } - - public static int? Max (this IEnumerable source) + public static long Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - int? maximum = int.MinValue; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); - } + Check.Source (source); + return Iterate (source, long.MinValue, (a, b) => Math.Max (a, b)); + } - public static long Max (this IEnumerable source) + public static double Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - long maximum = long.MinValue; - int counter = 0; - foreach (long element in source) { - if (element > maximum) - maximum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, double.MinValue, (a, b) => Math.Max (a, b)); } - - public static long? Max (this IEnumerable source) + public static float Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - long? maximum = long.MinValue; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); - } + Check.Source (source); + return Iterate (source, float.MinValue, (a, b) => Math.Max (a, b)); + } - public static double Max (this IEnumerable source) + public static decimal Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - double maximum = double.MinValue; - int counter = 0; - foreach (double element in source) { - if (element > maximum) - maximum = element; - counter++; - } + return Iterate (source, decimal.MinValue, (a, b) => Math.Max (a, b)); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + public static int? Max (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, int.MinValue, (a, b) => a > b); } + public static long? Max (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, long.MinValue, (a, b) => a > b); + } public static double? Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - double? maximum = double.MinValue; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); - } + Check.Source (source); + return IterateNullable (source, double.MinValue, (a, b) => a > b); + } - public static decimal Max (this IEnumerable source) + public static float? Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - decimal maximum = decimal.MinValue; - int counter = 0; - foreach (decimal element in source) { - if (element > maximum) - maximum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return IterateNullable (source, float.MinValue, (a, b) => a > b); } - public static decimal? Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - decimal? maximum = decimal.MinValue; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + Check.Source (source); + + return IterateNullable (source, decimal.MinValue, (a, b) => a > b); } + static T? IterateNullable (IEnumerable source, T initValue, Func selector) where T : struct + { + bool empty = true; + T? value = initValue; + foreach (var element in source) { + if (!element.HasValue) + continue; + + if (selector (element.Value, value)) + value = element; + + empty = false; + } + + if (empty) + return null; + + return value; + } public static TSource Max (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); bool notAssigned = true; TSource maximum = default (TSource); @@ -1240,180 +1119,121 @@ namespace System.Linq return maximum; } - - public static int Max (this IEnumerable source, - Func selector) + public static int Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int maximum = int.MinValue; - int counter = 0; - foreach (TSource item in source) { - int element = selector (item); - if (element > maximum) - maximum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, int.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static int? Max (this IEnumerable source, - Func selector) + public static long Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - int? maximum = int.MinValue; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return Iterate (source, long.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static long Max (this IEnumerable source, - Func selector) + public static double Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - long maximum = long.MinValue; - int counter = 0; - foreach (TSource item in source) { - long element = selector (item); - if (element > maximum) - maximum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + return Iterate (source, double.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static long? Max (this IEnumerable source, - Func selector) + public static float Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long? maximum = long.MinValue; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return Iterate (source, float.MinValue, (a, b) => Math.Max (selector (a), b)); } - - public static double Max (this IEnumerable source, - Func selector) + public static decimal Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - double maximum = double.MinValue; - int counter = 0; - foreach (TSource item in source) { - double element = selector (item); - if (element > maximum) - maximum = element; - counter++; + return Iterate (source, decimal.MinValue, (a, b) => Math.Max (selector (a), b)); + } + + static U Iterate (IEnumerable source, U initValue, Func selector) + { + bool empty = true; + foreach (var element in source) { + initValue = selector (element, initValue); + empty = false; } - if (counter == 0) + if (empty) throw new InvalidOperationException (); - else - return maximum; - } + return initValue; + } - public static double? Max (this IEnumerable source, - Func selector) + static U? IterateNullable (IEnumerable source, U initialValue, Func selector) where U : struct { - if (source == null || selector == null) - throw new ArgumentNullException (); + bool empty = true; + U? value = initialValue; + foreach (var element in source) { + value = selector (element, value); + if (!value.HasValue) + continue; - bool onlyNull = true; - double? maximum = double.MinValue; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } + empty = false; } - return (onlyNull ? null : maximum); - } + if (empty) + return null; - public static decimal Max (this IEnumerable source, - Func selector) + return value; + } + + public static int? Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - decimal maximum = decimal.MinValue; - int counter = 0; - foreach (TSource item in source) { - decimal element = selector (item); - if (element > maximum) - maximum = element; - counter++; - } + return IterateNullable (source, int.MinValue, (a, b) => { + var v = selector (a); return v > b ? v : b; + }); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return maximum; + public static long? Max (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, long.MinValue, (a, b) => { + var v = selector (a); return v > b ? v : b; + }); } + public static double? Max (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, double.MinValue, (a, b) => { + var v = selector (a); return v > b ? v : b; + }); + } - public static decimal? Max (this IEnumerable source, - Func selector) + public static float? Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal? maximum = decimal.MinValue; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element > maximum) - maximum = element; - } - } - return (onlyNull ? null : maximum); + return IterateNullable (source, float.MinValue, (a, b) => { + var v = selector (a); return v > b ? v : b; + }); } + public static decimal? Max (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, decimal.MinValue, (a, b) => { + var v = selector (a); return v > b ? v : b; + }); + } - public static TResult Max (this IEnumerable source, - Func selector) + public static TResult Max (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); bool notAssigned = true; TResult maximum = default (TResult); @@ -1450,159 +1270,77 @@ namespace System.Linq public static int Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - int minimum = int.MaxValue; - int counter = 0; - foreach (int element in source) { - if (element < minimum) - minimum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, int.MaxValue, (a, b) => Math.Min (a, b)); } - - public static int? Min (this IEnumerable source) + public static long Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - int? minimum = int.MaxValue; - foreach (int? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + Check.Source (source); + + return Iterate (source, long.MaxValue, (a, b) => Math.Min (a, b)); } - public static long Min (this IEnumerable source) + public static double Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - long minimum = long.MaxValue; - int counter = 0; - foreach (long element in source) { - if (element < minimum) - minimum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, double.MaxValue, (a, b) => Math.Min (a, b)); } - - public static long? Min (this IEnumerable source) + public static float Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - long? minimum = long.MaxValue; - foreach (long? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); - } + Check.Source (source); + return Iterate (source, float.MaxValue, (a, b) => Math.Min (a, b)); + } - public static double Min (this IEnumerable source) + public static decimal Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - double minimum = double.MaxValue; - int counter = 0; - foreach (double element in source) { - if (element < minimum) - minimum = element; - counter++; - } + return Iterate (source, decimal.MaxValue, (a, b) => Math.Min (a, b)); + } - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + public static int? Min (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, int.MaxValue, (a, b) => a < b); } + public static long? Min (this IEnumerable source) + { + Check.Source (source); + + return IterateNullable (source, long.MaxValue, (a, b) => a < b); + } public static double? Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - double? minimum = double.MaxValue; - foreach (double? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); - } + Check.Source (source); + return IterateNullable (source, double.MaxValue, (a, b) => a < b); + } - public static decimal Min (this IEnumerable source) + public static float? Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - decimal minimum = decimal.MaxValue; - int counter = 0; - foreach (decimal element in source) { - if (element < minimum) - minimum = element; - counter++; - } + Check.Source (source); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return IterateNullable (source, float.MaxValue, (a, b) => a < b); } - public static decimal? Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool onlyNull = true; - decimal? minimum = decimal.MaxValue; - foreach (decimal? element in source) { - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); - } + Check.Source (source); + return IterateNullable (source, decimal.MaxValue, (a, b) => a < b); + } public static TSource Min (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); bool notAssigned = true; TSource minimum = default (TSource); @@ -1632,180 +1370,89 @@ namespace System.Linq return minimum; } - - public static int Min (this IEnumerable source, - Func selector) + public static int Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int minimum = int.MaxValue; - int counter = 0; - foreach (TSource item in source) { - int element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, int.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static int? Min (this IEnumerable source, - Func selector) + public static long Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - int? minimum = int.MaxValue; - foreach (TSource item in source) { - int? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return Iterate (source, long.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static long Min (this IEnumerable source, - Func selector) + public static double Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - long minimum = long.MaxValue; - int counter = 0; - foreach (TSource item in source) { - long element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, double.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static long? Min (this IEnumerable source, - Func selector) + public static float Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - long? minimum = long.MaxValue; - foreach (TSource item in source) { - long? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return Iterate (source, float.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static double Min (this IEnumerable source, - Func selector) + public static decimal Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - double minimum = double.MaxValue; - int counter = 0; - foreach (TSource item in source) { - double element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return Iterate (source, decimal.MaxValue, (a, b) => Math.Min (selector (a), b)); } - - public static double? Min (this IEnumerable source, - Func selector) + public static int? Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - double? minimum = double.MaxValue; - foreach (TSource item in source) { - double? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return IterateNullable (source, int.MaxValue, (a, b) => { + var v = selector (a); return v < b ? v : b; + }); } - - public static decimal Min (this IEnumerable source, - Func selector) + public static long? Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - decimal minimum = decimal.MaxValue; - int counter = 0; - foreach (TSource item in source) { - decimal element = selector (item); - if (element < minimum) - minimum = element; - counter++; - } + Check.SourceAndSelector (source, selector); - if (counter == 0) - throw new InvalidOperationException (); - else - return minimum; + return IterateNullable (source, long.MaxValue, (a, b) => { + var v = selector (a); return v < b ? v : b; + }); } + public static float? Min (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, float.MaxValue, (a, b) => { + var v = selector (a); return v < b ? v : b; + }); + } - public static decimal? Min (this IEnumerable source, - Func selector) + public static double? Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - bool onlyNull = true; - decimal? minimum = decimal.MaxValue; - foreach (TSource item in source) { - decimal? element = selector (item); - if (element.HasValue) { - onlyNull = false; - if (element < minimum) - minimum = element; - } - } - return (onlyNull ? null : minimum); + return IterateNullable (source, double.MaxValue, (a, b) => { + var v = selector (a); return v < b ? v : b; + }); } + public static decimal? Min (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return IterateNullable (source, decimal.MaxValue, (a, b) => { + var v = selector (a); return v < b ? v : b; + }); + } - public static TResult Min (this IEnumerable source, - Func selector) + public static TResult Min (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); bool notAssigned = true; TResult minimum = default (TResult); @@ -1842,9 +1489,13 @@ namespace System.Linq public static IEnumerable OfType (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + return CreateOfTypeIterator (source); + } + + static IEnumerable CreateOfTypeIterator (IEnumerable source) + { foreach (object element in source) if (element is TResult) yield return (TResult) element; @@ -1860,16 +1511,13 @@ namespace System.Linq return OrderBy (source, keySelector, null); } - public static IOrderedEnumerable OrderBy (this IEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); - return new InternalOrderedSequence ( - source, keySelector, comparer, false); + return new OrderedSequence (source, keySelector, comparer, SortDirection.Ascending); } #endregion @@ -1882,15 +1530,12 @@ namespace System.Linq return OrderByDescending (source, keySelector, null); } - public static IOrderedEnumerable OrderByDescending (this IEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); - return new InternalOrderedSequence ( - source, keySelector, comparer, true); + return new OrderedSequence (source, keySelector, comparer, SortDirection.Descending); } #endregion @@ -1899,10 +1544,20 @@ namespace System.Linq public static IEnumerable Range (int start, int count) { - if (count < 0 || (start + count - 1) > int.MaxValue) + if (count < 0) + throw new ArgumentOutOfRangeException ("count"); + + long upto = ((long) start + count) - 1; + + if (upto > int.MaxValue) throw new ArgumentOutOfRangeException (); - for (int i = start; i < (start + count - 1); i++) + return CreateRangeIterator (start, (int) upto); + } + + static IEnumerable CreateRangeIterator (int start, int upto) + { + for (int i = start; i <= upto; i++) yield return i; } @@ -1915,46 +1570,62 @@ namespace System.Linq if (count < 0) throw new ArgumentOutOfRangeException (); + return CreateRepeatIterator (element, count); + } + + static IEnumerable CreateRepeatIterator (TResult element, int count) + { for (int i = 0; i < count; i++) yield return element; } #endregion - #region Reverse public static IEnumerable Reverse (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - List list = new List (source); - list.Reverse (); - return list; + var list = source as IList; + if (list == null) + list = new List (source); + + return CreateReverseIterator (list); + } + + static IEnumerable CreateReverseIterator (IList source) + { + for (int i = source.Count; i > 0; --i) + yield return source [i - 1]; } #endregion #region Select - public static IEnumerable Select (this IEnumerable source, - Func selector) + public static IEnumerable Select (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - foreach (TSource element in source) - yield return selector (element); + return CreateSelectIterator (source, selector); } + static IEnumerable CreateSelectIterator (IEnumerable source, Func selector) + { + foreach (var element in source) + yield return selector (element); + } - public static IEnumerable Select (this IEnumerable source, - Func selector) + public static IEnumerable Select (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); + + return CreateSelectIterator (source, selector); + } + static IEnumerable CreateSelectIterator (IEnumerable source, Func selector) + { int counter = 0; foreach (TSource element in source) { yield return selector (element, counter); @@ -1966,27 +1637,32 @@ namespace System.Linq #region SelectMany - public static IEnumerable SelectMany (this IEnumerable source, - Func> selector) + public static IEnumerable SelectMany (this IEnumerable source, Func> selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); + return CreateSelectManyIterator (source, selector); + } + + static IEnumerable CreateSelectManyIterator (IEnumerable source, Func> selector) + { foreach (TSource element in source) foreach (TResult item in selector (element)) yield return item; } - - public static IEnumerable SelectMany (this IEnumerable source, - Func> selector) + public static IEnumerable SelectMany (this IEnumerable source, Func> selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); + return CreateSelectManyIterator (source, selector); + } + + static IEnumerable CreateSelectManyIterator (IEnumerable source, Func> selector) + { int counter = 0; foreach (TSource element in source) { - foreach (TResult item in selector (element, counter++)) + foreach (TResult item in selector (element, counter)) yield return item; counter++; } @@ -1995,20 +1671,30 @@ namespace System.Linq public static IEnumerable SelectMany (this IEnumerable source, Func> collectionSelector, Func selector) { - if (source == null || collectionSelector == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndCollectionSelectors (source, collectionSelector, selector); + + return CreateSelectManyIterator (source, collectionSelector, selector); + } + static IEnumerable CreateSelectManyIterator (IEnumerable source, + Func> collectionSelector, Func selector) + { foreach (TSource element in source) foreach (TCollection collection in collectionSelector (element)) yield return selector (element, collection); } public static IEnumerable SelectMany (this IEnumerable source, - Func> collectionSelector, Func selector) + Func> collectionSelector, Func selector) { - if (source == null || collectionSelector == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndCollectionSelectors (source, collectionSelector, selector); + + return CreateSelectManyIterator (source, collectionSelector, selector); + } + static IEnumerable CreateSelectManyIterator (IEnumerable source, + Func> collectionSelector, Func selector) + { int counter = 0; foreach (TSource element in source) foreach (TCollection collection in collectionSelector (element, counter++)) @@ -2019,46 +1705,40 @@ namespace System.Linq #region Single - public static TSource Single (this IEnumerable source) + static TSource Single (this IEnumerable source, Func predicate, Fallback fallback) { - if (source == null) - throw new ArgumentNullException (); + var found = false; + var item = default (TSource); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; + foreach (var element in source) { + if (!predicate (element)) + continue; + + if (found) + throw new InvalidOperationException (); + + found = true; + item = element; } - if (otherElement) - return singleElement; - else + if (!found && fallback == Fallback.Throw) throw new InvalidOperationException (); - } + return item; + } - public static TSource Single (this IEnumerable source, - Func predicate) + public static TSource Single (this IEnumerable source) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.Source (source); - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - } + return source.Single (PredicateOf.Always, Fallback.Throw); + } - if (otherElement) - return singleElement; - else - throw new InvalidOperationException (); + public static TSource Single (this IEnumerable source, Func predicate) + { + Check.SourceAndPredicate (source, predicate); + + return source.Single (predicate, Fallback.Throw); } #endregion @@ -2067,67 +1747,53 @@ namespace System.Linq public static TSource SingleOrDefault (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } + Check.Source (source); - return singleElement; + return source.Single (PredicateOf.Always, Fallback.Default); } - - public static TSource SingleOrDefault (this IEnumerable source, - Func predicate) + public static TSource SingleOrDefault (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); - - bool otherElement = false; - TSource singleElement = default (TSource); - foreach (TSource element in source) { - if (predicate (element)) { - if (otherElement) throw new InvalidOperationException (); - if (!otherElement) otherElement = true; - singleElement = element; - } - } + Check.SourceAndPredicate (source, predicate); - return singleElement; + return source.Single (predicate, Fallback.Default); } #endregion #region Skip + public static IEnumerable Skip (this IEnumerable source, int count) { - if (source == null) - throw new NotSupportedException (); + Check.Source (source); + return CreateSkipIterator (source, count); + } + + static IEnumerable CreateSkipIterator (IEnumerable source, int count) + { int i = 0; - foreach (TSource e in source) { - if (++i < count) + foreach (var element in source) { + if (i++ < count) continue; - yield return e; + + yield return element; } } + #endregion #region SkipWhile - - public static IEnumerable SkipWhile ( - this IEnumerable source, - Func predicate) + public static IEnumerable SkipWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + + return CreateSkipWhileIterator (source, predicate); + } + static IEnumerable CreateSkipWhileIterator (IEnumerable source, Func predicate) + { bool yield = false; foreach (TSource element in source) { @@ -2141,13 +1807,15 @@ namespace System.Linq } } - - public static IEnumerable SkipWhile (this IEnumerable source, - Func predicate) + public static IEnumerable SkipWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + + return CreateSkipWhileIterator (source, predicate); + } + static IEnumerable CreateSkipWhileIterator (IEnumerable source, Func predicate) + { int counter = 0; bool yield = false; @@ -2169,245 +1837,203 @@ namespace System.Linq public static int Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException ("source"); - - int sum = 0; - foreach (int element in source) - sum += element; + Check.Source (source); - return sum; + return Sum (source, (a, b) => a + b); } - - public static int Sum (this IEnumerable source, Func selector) + public static int? Sum (this IEnumerable source) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int sum = 0; - foreach (TSource element in source) - sum += selector (element); + Check.Source (source); - return sum; + return source.SumNullable (0, (a, b) => a.HasValue ? a + b : a); } - - public static int? Sum (this IEnumerable source) + public static int Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); - - int? sum = 0; - foreach (int? element in source) - if (element.HasValue) - sum += element.Value; + Check.SourceAndSelector (source, selector); - return sum; + return Sum (source, (a, b) => a + selector (b)); } - public static int? Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - int? sum = 0; - foreach (TSource element in source) { - int? item = selector (element); - if (item.HasValue) - sum += item.Value; - } + Check.SourceAndSelector (source, selector); - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? a + value.Value : a; + }); } - public static long Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - long sum = 0; - foreach (long element in source) - sum += element; - - return sum; + return Sum (source, (a, b) => a + b); } - - public static long Sum (this IEnumerable source, Func selector) + public static long? Sum (this IEnumerable source) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.Source (source); - long sum = 0; - foreach (TSource element in source) - sum += selector (element); - - return sum; + return source.SumNullable (0, (a, b) => a.HasValue ? a + b : a); } - - public static long? Sum (this IEnumerable source) + public static long Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long? sum = 0; - foreach (long? element in source) - if (element.HasValue) - sum += element.Value; - - return sum; + return Sum (source, (a, b) => a + selector (b)); } - public static long? Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - long? sum = 0; - foreach (TSource element in source) { - long? item = selector (element); - if (item.HasValue) - sum += item.Value; - } - - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? a + value.Value : a; + }); } - public static double Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); - - double sum = 0; - foreach (double element in source) - sum += element; + Check.Source (source); - return sum; + return Sum (source, (a, b) => a + b); } + public static double? Sum (this IEnumerable source) + { + Check.Source (source); + + return source.SumNullable (0, (a, b) => a.HasValue ? a + b : a); + } public static double Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - double sum = 0; - foreach (TSource element in source) - sum += selector (element); + Check.SourceAndSelector (source, selector); - return sum; + return Sum (source, (a, b) => a + selector (b)); } - - public static double? Sum (this IEnumerable source) + public static double? Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); + Check.SourceAndSelector (source, selector); - double? sum = 0; - foreach (double? element in source) - if (element.HasValue) - sum += element.Value; - - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? a + value.Value : a; + }); } + public static float Sum (this IEnumerable source) + { + Check.Source (source); - public static double? Sum (this IEnumerable source, Func selector) + return Sum (source, (a, b) => a + b); + } + + public static float? Sum (this IEnumerable source) { - if (source == null || selector == null) - throw new ArgumentNullException (); + Check.Source (source); - double? sum = 0; - foreach (TSource element in source) { - double? item = selector (element); - if (item.HasValue) - sum += item.Value; - } + return source.SumNullable (0, (a, b) => a.HasValue ? a + b : a); + } - return sum; + public static float Sum (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return Sum (source, (a, b) => a + selector (b)); } + public static float? Sum (this IEnumerable source, Func selector) + { + Check.SourceAndSelector (source, selector); + + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? a + value.Value : a; + }); + } public static decimal Sum (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); - decimal sum = 0; - foreach (decimal element in source) - sum += element; - - return sum; + return Sum (source, (a, b) => a + b); } + public static decimal? Sum (this IEnumerable source) + { + Check.Source (source); + + return source.SumNullable (0, (a, b) => a.HasValue ? a + b : a); + } public static decimal Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); - - decimal sum = 0; - foreach (TSource element in source) - sum += selector (element); + Check.SourceAndSelector (source, selector); - return sum; + return Sum (source, (a, b) => a + selector (b)); } - - public static decimal? Sum (this IEnumerable source) + public static decimal? Sum (this IEnumerable source, Func selector) { - if (source == null) - throw new ArgumentNullException (); - - decimal? sum = 0; - foreach (decimal? element in source) - if (element.HasValue) - sum += element.Value; + Check.SourceAndSelector (source, selector); - return sum; + return source.SumNullable (0, (a, b) => { + var value = selector (b); + return value.HasValue ? a + value.Value : a; + }); } - - public static decimal? Sum (this IEnumerable source, Func selector) + static TR Sum (this IEnumerable source, Func selector) { - if (source == null || selector == null) - throw new ArgumentNullException (); + TR total = default (TR); + long counter = 0; + foreach (var element in source) { + total = selector (total, element); + ++counter; + } - decimal? sum = 0; - foreach (TSource element in source) { - decimal? item = selector (element); - if (item.HasValue) - sum += item.Value; + return total; + } + + static TR SumNullable (this IEnumerable source, TR zero, Func selector) + { + TR total = zero; + foreach (var element in source) { + total = selector (total, element); } - return sum; + return total; } #endregion + #region Take public static IEnumerable Take (this IEnumerable source, int count) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + return CreateTakeIterator (source, count); + } + + static IEnumerable CreateTakeIterator (IEnumerable source, int count) + { if (count <= 0) yield break; - else { - int counter = 0; - foreach (TSource element in source) { - yield return element; - counter++; - if (counter == count) - yield break; - } + + int counter = 0; + foreach (TSource element in source) { + yield return element; + + if (++counter == count) + yield break; } } @@ -2417,28 +2043,36 @@ namespace System.Linq public static IEnumerable TakeWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); - foreach (TSource element in source) { - if (predicate (element)) - yield return element; - else + return CreateTakeWhileIterator (source, predicate); + } + + static IEnumerable CreateTakeWhileIterator (IEnumerable source, Func predicate) + { + foreach (var element in source) { + if (!predicate (element)) yield break; + + yield return element; } } public static IEnumerable TakeWhile (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + + return CreateTakeWhileIterator (source, predicate); + } + static IEnumerable CreateTakeWhileIterator (IEnumerable source, Func predicate) + { int counter = 0; - foreach (TSource element in source) { - if (predicate (element, counter)) - yield return element; - else + foreach (var element in source) { + if (!predicate (element, counter)) yield break; + + yield return element; counter++; } } @@ -2452,12 +2086,10 @@ namespace System.Linq return ThenBy (source, keySelector, null); } - public static IOrderedEnumerable ThenBy (this IOrderedEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); return source.CreateOrderedEnumerable (keySelector, comparer, false); } @@ -2472,12 +2104,10 @@ namespace System.Linq return ThenByDescending (source, keySelector, null); } - public static IOrderedEnumerable ThenByDescending (this IOrderedEnumerable source, Func keySelector, IComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); return source.CreateOrderedEnumerable (keySelector, comparer, true); } @@ -2485,47 +2115,72 @@ namespace System.Linq #endregion #region ToArray + public static TSource [] ToArray (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException (); + Check.Source (source); + + var collection = source as ICollection; + if (collection != null) { + var array = new TSource [collection.Count]; + collection.CopyTo (array, 0); + return array; + } - List list = new List (source); - return list.ToArray (); + return new List (source).ToArray (); } #endregion #region ToDictionary - public static Dictionary ToDictionary (this IEnumerable source, Func keySelector, Func elementSelector) + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector, Func elementSelector) { return ToDictionary (source, keySelector, elementSelector, null); } - - public static Dictionary ToDictionary (this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector, Func elementSelector, IEqualityComparer comparer) { - if (source == null) - throw new ArgumentNullException ("source"); - if (keySelector == null) - throw new ArgumentNullException ("keySelector"); - if (elementSelector == null) - throw new ArgumentNullException ("elementSelector"); - - Dictionary dict = new Dictionary (comparer); - foreach (TSource e in source) { + Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); + + if (comparer == null) + comparer = EqualityComparer.Default; + + var dict = new Dictionary (comparer); + foreach (var e in source) dict.Add (keySelector (e), elementSelector (e)); - } return dict; } + + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector) + { + return ToDictionary (source, keySelector, null); + } + + public static Dictionary ToDictionary (this IEnumerable source, + Func keySelector, IEqualityComparer comparer) + { + Check.SourceAndKeySelector (source, keySelector); + + if (comparer == null) + comparer = EqualityComparer.Default; + + var dict = new Dictionary (comparer); + foreach (var e in source) + dict.Add (keySelector (e), e); + + return dict; + } + #endregion #region ToList public static List ToList (this IEnumerable source) { - if (source == null) - throw new ArgumentNullException ("source"); + Check.Source (source); return new List (source); } @@ -2533,19 +2188,17 @@ namespace System.Linq #region ToLookup - public static Lookup ToLookup (this IEnumerable source, Func keySelector) + public static ILookup ToLookup (this IEnumerable source, Func keySelector) { return ToLookup (source, keySelector, null); } - - public static Lookup ToLookup (this IEnumerable source, + public static ILookup ToLookup (this IEnumerable source, Func keySelector, IEqualityComparer comparer) { - if (source == null || keySelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeySelector (source, keySelector); - Dictionary> dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); + var dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); foreach (TSource element in source) { TKey key = keySelector (element); if (key == null) @@ -2557,19 +2210,16 @@ namespace System.Linq return new Lookup (dictionary); } - - public static Lookup ToLookup (this IEnumerable source, + public static ILookup ToLookup (this IEnumerable source, Func keySelector, Func elementSelector) { return ToLookup (source, keySelector, elementSelector, null); } - - public static Lookup ToLookup (this IEnumerable source, + public static ILookup ToLookup (this IEnumerable source, Func keySelector, Func elementSelector, IEqualityComparer comparer) { - if (source == null || keySelector == null || elementSelector == null) - throw new ArgumentNullException (); + Check.SourceAndKeyElementSelectors (source, keySelector, elementSelector); Dictionary> dictionary = new Dictionary> (comparer ?? EqualityComparer.Default); foreach (TSource element in source) { @@ -2585,32 +2235,67 @@ namespace System.Linq #endregion - #region ToSequence + #region SequenceEqual + + public static bool SequenceEqual (this IEnumerable first, IEnumerable second) + { + return first.SequenceEqual (second, null); + } - public static IEnumerable ToSequence (this IEnumerable source) + public static bool SequenceEqual (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { - return (IEnumerable) source; + Check.FirstAndSecond (first, second); + + if (comparer == null) + comparer = EqualityComparer.Default; + + var first_enumerator = first.GetEnumerator (); + var second_enumerator = second.GetEnumerator (); + + while (first_enumerator.MoveNext ()) { + if (!second_enumerator.MoveNext ()) + return false; + + if (!comparer.Equals (first_enumerator.Current, second_enumerator.Current)) + return false; + } + + return !second_enumerator.MoveNext (); } #endregion #region Union - public static IEnumerable Union (this IEnumerable first, IEnumerable second) { - if (first == null || second == null) - throw new ArgumentNullException (); + Check.FirstAndSecond (first, second); - List items = new List (); - foreach (TSource element in first) { - if (IndexOf (items, element) == -1) { + return first.Union (second, null); + } + + public static IEnumerable Union (this IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + Check.FirstAndSecond (first, second); + + if (comparer == null) + comparer = EqualityComparer.Default; + + return CreateUnionIterator (first, second, comparer); + } + + static IEnumerable CreateUnionIterator (IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + var items = new HashSet (comparer); + foreach (var element in first) { + if (! items.Contains (element)) { items.Add (element); yield return element; } } - foreach (TSource element in second) { - if (IndexOf (items, element) == -1) { + + foreach (var element in second) { + if (! items.Contains (element, comparer)) { items.Add (element); yield return element; } @@ -2621,24 +2306,29 @@ namespace System.Linq #region Where - public static IEnumerable Where (this IEnumerable source, - Func predicate) + public static IEnumerable Where (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + return CreateWhereIterator (source, predicate); + } + + static IEnumerable CreateWhereIterator (IEnumerable source, Func predicate) + { foreach (TSource element in source) if (predicate (element)) yield return element; } - - public static IEnumerable Where (this IEnumerable source, - Func predicate) + public static IEnumerable Where (this IEnumerable source, Func predicate) { - if (source == null || predicate == null) - throw new ArgumentNullException (); + Check.SourceAndPredicate (source, predicate); + return CreateWhereIterator (source, predicate); + } + + static IEnumerable CreateWhereIterator (this IEnumerable source, Func predicate) + { int counter = 0; foreach (TSource element in source) { if (predicate (element, counter)) @@ -2649,61 +2339,20 @@ namespace System.Linq #endregion - // These methods are not included in the - // .NET Standard Query Operators Specification, - // but they provide additional useful commands - - #region Compare - - private static bool Equals (T first, T second) - { - // Mostly, values in Enumerable - // sequences need to be compared using - // Equals and GetHashCode - - if (first == null || second == null) - return (first == null && second == null); - else - return ((first.Equals (second) || - first.GetHashCode () == second.GetHashCode ())); - } - - #endregion - - #region IndexOf - - static int IndexOf (this IEnumerable source, T item, IEqualityComparer comparer) - { - if (comparer == null) - comparer = EqualityComparer.Default; - - int counter = 0; - foreach (T element in source) { - if (comparer.Equals (element, item)) - return counter; - counter++; - } - // The item was not found - return -1; - } - - static int IndexOf (this IEnumerable source, T item) - { - return IndexOf (source, item, null); + class ReadOnlyCollectionOf { + public static readonly ReadOnlyCollection Empty = new ReadOnlyCollection (new T [0]); } - #endregion - #region ToReadOnlyCollection - internal static ReadOnlyCollection ToReadOnlyCollection (IEnumerable source) + internal static ReadOnlyCollection ToReadOnlyCollection (this IEnumerable source) { if (source == null) - return new ReadOnlyCollection (new List ()); + return ReadOnlyCollectionOf.Empty; - if (typeof (ReadOnlyCollection).IsInstanceOfType (source)) - return source as ReadOnlyCollection; + var ro = source as ReadOnlyCollection; + if (ro != null) + return ro; - return new ReadOnlyCollection (ToArray (source)); + return new ReadOnlyCollection (source.ToArray ()); } - #endregion } }