Merge pull request #607 from maksimenko/master
[mono.git] / mcs / class / System.Core / System.Linq / Enumerable.cs
index e6442d9bc43e21a37d1fd0a919ac9e8daba6b4b5..71f392690a47666dc15496f0ffbcca58fe56dae7 100644 (file)
@@ -45,9 +45,11 @@ namespace System.Linq
                        Throw
                }
 
+#if !FULL_AOT_RUNTIME
                static class PredicateOf<T> {
                        public static readonly Func<T, bool> Always = (t) => true;
                }
+#endif
 
                static class Function<T> {
                        public static readonly Func<T, T> Identity = (t) => t;
@@ -71,7 +73,7 @@ namespace System.Linq
                        // if zero elements and treat the first element differently
                        using (var enumerator = source.GetEnumerator ()) {
                                if (!enumerator.MoveNext ())
-                                       throw new InvalidOperationException ("No elements in source list");
+                                       throw EmptySequence ();
 
                                TSource folded = enumerator.Current;
                                while (enumerator.MoveNext ())
@@ -171,7 +173,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / (double) count;
                }
 
@@ -186,7 +188,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / (double) count;
                }
 
@@ -201,7 +203,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / count;
                }
 
@@ -216,7 +218,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / count;
                }
 
@@ -231,7 +233,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / count;
                }
 
@@ -379,7 +381,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / (double) count;
                }
 
@@ -416,7 +418,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / (double) count;
 
                }
@@ -454,7 +456,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / count;
 
                }
@@ -493,7 +495,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / count;
                }
 
@@ -530,7 +532,7 @@ namespace System.Linq
                                count++;
                        }
                        if (count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return total / count;
                }
 
@@ -637,19 +639,19 @@ namespace System.Linq
                        int counter = 0;
                        using (var enumerator = source.GetEnumerator ())
                                while (enumerator.MoveNext ())
-                                       counter++;
+                                       checked { counter++; }
 
                        return counter;
                }
 
-               public static int Count<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> selector)
+               public static int Count<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
                {
-                       Check.SourceAndSelector (source, selector);
+                       Check.SourceAndSelector (source, predicate);
 
                        int counter = 0;
                        foreach (var element in source)
-                               if (selector (element))
-                                       counter++;
+                               if (predicate (element))
+                                       checked { counter++; }
 
                        return counter;
                }
@@ -795,7 +797,7 @@ namespace System.Linq
                {
                        var items = new HashSet<TSource> (second, comparer);
                        foreach (var element in first) {
-                               if (!items.Contains (element, comparer))
+                               if (items.Add (element))
                                        yield return element;
                        }
                }
@@ -811,7 +813,7 @@ namespace System.Linq
                                        return element;
 
                        if (fallback == Fallback.Throw)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
 
                        return default (TSource);
                }
@@ -831,7 +833,7 @@ namespace System.Linq
                                }
                        }
 
-                       throw new InvalidOperationException ("The source sequence is empty");
+                       throw EmptySequence ();
                }
 
                public static TSource First<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
@@ -849,7 +851,15 @@ namespace System.Linq
                {
                        Check.Source (source);
 
+#if !FULL_AOT_RUNTIME
                        return source.First (PredicateOf<TSource>.Always, Fallback.Default);
+#else
+                       // inline the code to reduce dependency o generic causing AOT errors on device (e.g. bug #3285)
+                       foreach (var element in source)
+                               return element;
+
+                       return default (TSource);
+#endif
                }
 
                public static TSource FirstOrDefault<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
@@ -1191,7 +1201,7 @@ namespace System.Linq
                                return item;
 
                        if (fallback == Fallback.Throw)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
 
                        return item;
                }
@@ -1202,14 +1212,29 @@ namespace System.Linq
 
                        var collection = source as ICollection<TSource>;
                        if (collection != null && collection.Count == 0)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
 
                        var list = source as IList<TSource>;
                        if (list != null)
                                return list [list.Count - 1];
 
+#if !FULL_AOT_RUNTIME
                        return source.Last (PredicateOf<TSource>.Always, Fallback.Throw);
-               }
+#else
+                       var empty = true;
+                       var item = default (TSource);
+
+                       foreach (var element in source) {
+                               item = element;
+                               empty = false;
+                       }
+
+                       if (!empty)
+                               return item;
+
+                       throw EmptySequence ();
+#endif
+        }
 
                public static TSource Last<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
                {
@@ -1230,7 +1255,22 @@ namespace System.Linq
                        if (list != null)
                                return list.Count > 0 ? list [list.Count - 1] : default (TSource);
 
+#if !FULL_AOT_RUNTIME
                        return source.Last (PredicateOf<TSource>.Always, Fallback.Default);
+#else
+                       var empty = true;
+                       var item = default (TSource);
+
+                       foreach (var element in source) {
+                               item = element;
+                               empty = false;
+                       }
+
+                       if (!empty)
+                               return item;
+
+                       return item;
+#endif
                }
 
                public static TSource LastOrDefault<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
@@ -1262,13 +1302,13 @@ namespace System.Linq
                        return counter;
                }
 
-               public static long LongCount<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> selector)
+               public static long LongCount<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
                {
-                       Check.SourceAndSelector (source, selector);
+                       Check.SourceAndSelector (source, predicate);
 
                        long counter = 0;
                        foreach (TSource element in source)
-                               if (selector (element))
+                               if (predicate (element))
                                        counter++;
 
                        return counter;
@@ -1289,7 +1329,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence();
                        return max;
                }
 
@@ -1304,7 +1344,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return max;
                }
 
@@ -1319,7 +1359,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return max;
                }
 
@@ -1334,7 +1374,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return max;
                }
 
@@ -1349,7 +1389,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return max;
                }
 
@@ -1487,7 +1527,7 @@ namespace System.Linq
                                                max = element;
                                }
                                if (empty)
-                                       throw new InvalidOperationException ();
+                                       throw EmptySequence ();
                        }
                        return max;
                }
@@ -1503,7 +1543,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return max;
                }
 
@@ -1518,7 +1558,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return max;
                }
 
@@ -1533,7 +1573,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return max;
                }
 
@@ -1548,7 +1588,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return max;
                }
 
@@ -1563,7 +1603,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return max;
                }
 
@@ -1576,7 +1616,7 @@ namespace System.Linq
                        }
 
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
 
                        return initValue;
                }
@@ -1709,7 +1749,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return min;
                }
 
@@ -1724,7 +1764,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return min;
                }
 
@@ -1739,7 +1779,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return min;
                }
 
@@ -1754,7 +1794,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return min;
                }
 
@@ -1769,7 +1809,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw EmptySequence ();
                        return min;
                }
 
@@ -1906,7 +1946,7 @@ namespace System.Linq
                                                min = element;
                                }
                                if (empty)
-                                       throw new InvalidOperationException ();
+                                       throw EmptySequence ();
                        }
                        return min;
                }
@@ -1922,7 +1962,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return min;
                }
 
@@ -1937,7 +1977,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return min;
                }
 
@@ -1952,7 +1992,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return min;
                }
 
@@ -1967,7 +2007,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return min;
                }
 
@@ -1982,7 +2022,7 @@ namespace System.Linq
                                empty = false;
                        }
                        if (empty)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
                        return min;
                }
 
@@ -2163,18 +2203,16 @@ namespace System.Linq
                        if (count < 0)
                                throw new ArgumentOutOfRangeException ("count");
 
-                       long upto = ((long) start + count) - 1;
-
-                       if (upto > int.MaxValue)
+                       if (((long) start + count) - 1L > int.MaxValue)
                                throw new ArgumentOutOfRangeException ();
 
-                       return CreateRangeIterator (start, (int) upto);
+                       return CreateRangeIterator (start, count);
                }
 
-               static IEnumerable<int> CreateRangeIterator (int start, int upto)
+               static IEnumerable<int> CreateRangeIterator (int start, int count)
                {
-                       for (int i = start; i <= upto; i++)
-                               yield return i;
+                       for (int i = 0; i < count; i++)
+                               yield return start + i;
                }
 
                #endregion
@@ -2208,12 +2246,10 @@ namespace System.Linq
 
                static IEnumerable<TSource> CreateReverseIterator<TSource> (IEnumerable<TSource> source)
                {
-                       var list = source as IList<TSource>;
-                       if (list == null)
-                               list = new List<TSource> (source);
+                       var array = source.ToArray ();
 
-                       for (int i = list.Count - 1; i >= 0; i--)
-                               yield return list [i];
+                       for (int i = array.Length - 1; i >= 0; i--)
+                               yield return array [i];
                }
 
                #endregion
@@ -2285,11 +2321,11 @@ namespace System.Linq
                }
 
                public static IEnumerable<TResult> SelectMany<TSource, TCollection, TResult> (this IEnumerable<TSource> source,
-                       Func<TSource, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> selector)
+                       Func<TSource, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
                {
-                       Check.SourceAndCollectionSelectors (source, collectionSelector, selector);
+                       Check.SourceAndCollectionSelectors (source, collectionSelector, resultSelector);
 
-                       return CreateSelectManyIterator (source, collectionSelector, selector);
+                       return CreateSelectManyIterator (source, collectionSelector, resultSelector);
                }
 
                static IEnumerable<TResult> CreateSelectManyIterator<TSource, TCollection, TResult> (IEnumerable<TSource> source,
@@ -2301,11 +2337,11 @@ namespace System.Linq
                }
 
                public static IEnumerable<TResult> SelectMany<TSource, TCollection, TResult> (this IEnumerable<TSource> source,
-                       Func<TSource, int, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> selector)
+                       Func<TSource, int, IEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
                {
-                       Check.SourceAndCollectionSelectors (source, collectionSelector, selector);
+                       Check.SourceAndCollectionSelectors (source, collectionSelector, resultSelector);
 
-                       return CreateSelectManyIterator (source, collectionSelector, selector);
+                       return CreateSelectManyIterator (source, collectionSelector, resultSelector);
                }
 
                static IEnumerable<TResult> CreateSelectManyIterator<TSource, TCollection, TResult> (IEnumerable<TSource> source,
@@ -2331,14 +2367,14 @@ namespace System.Linq
                                        continue;
 
                                if (found)
-                                       throw new InvalidOperationException ();
+                                       throw MoreThanOneMatchingElement ();
 
                                found = true;
                                item = element;
                        }
 
                        if (!found && fallback == Fallback.Throw)
-                               throw new InvalidOperationException ();
+                               throw NoMatchingElement ();
 
                        return item;
                }
@@ -2347,8 +2383,26 @@ namespace System.Linq
                {
                        Check.Source (source);
 
+#if !FULL_AOT_RUNTIME
                        return source.Single (PredicateOf<TSource>.Always, Fallback.Throw);
-               }
+#else
+                       var found = false;
+                       var item = default (TSource);
+
+                       foreach (var element in source) {
+                               if (found)
+                                       throw MoreThanOneElement ();
+
+                               found = true;
+                               item = element;
+                       }
+
+                       if (!found)
+                               throw NoMatchingElement ();
+
+                       return item;
+#endif
+        }
 
                public static TSource Single<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
                {
@@ -2365,8 +2419,23 @@ namespace System.Linq
                {
                        Check.Source (source);
 
+#if !FULL_AOT_RUNTIME
                        return source.Single (PredicateOf<TSource>.Always, Fallback.Default);
-               }
+#else
+                       var found = false;
+                       var item = default (TSource);
+
+                       foreach (var element in source) {
+                               if (found)
+                                       throw MoreThanOneMatchingElement ();
+
+                               found = true;
+                               item = element;
+                       }
+
+                       return item;
+#endif
+        }
 
                public static TSource SingleOrDefault<TSource> (this IEnumerable<TSource> source, Func<TSource, bool> predicate)
                {
@@ -2766,6 +2835,12 @@ namespace System.Linq
                {
                        Check.SourceAndKeySelector (source, keySelector);
 
+#if FULL_AOT_RUNTIME
+                       var oe = source as OrderedEnumerable <TSource>;
+                       if (oe != null)
+                               return oe.CreateOrderedEnumerable (keySelector, comparer, false);
+#endif
+
                        return source.CreateOrderedEnumerable (keySelector, comparer, false);
                }
 
@@ -2784,6 +2859,11 @@ namespace System.Linq
                {
                        Check.SourceAndKeySelector (source, keySelector);
 
+#if FULL_AOT_RUNTIME
+                       var oe = source as OrderedEnumerable <TSource>;
+                       if (oe != null)
+                               return oe.CreateOrderedEnumerable (keySelector, comparer, true);
+#endif
                        return source.CreateOrderedEnumerable (keySelector, comparer, true);
                }
 
@@ -2795,14 +2875,34 @@ namespace System.Linq
                {
                        Check.Source (source);
 
+                       TSource[] array;
                        var collection = source as ICollection<TSource>;
                        if (collection != null) {
-                               var array = new TSource [collection.Count];
+                               if (collection.Count == 0)
+                                       return EmptyOf<TSource>.Instance;
+                               
+                               array = new TSource [collection.Count];
                                collection.CopyTo (array, 0);
                                return array;
                        }
 
-                       return new List<TSource> (source).ToArray ();
+                       int pos = 0;
+                       array = EmptyOf<TSource>.Instance;
+                       foreach (var element in source) {
+                               if (pos == array.Length) {
+                                       if (pos == 0)
+                                               array = new TSource [4];
+                                       else
+                                               Array.Resize (ref array, pos * 2);
+                               }
+
+                               array[pos++] = element;
+                       }
+
+                       if (pos != array.Length)
+                               Array.Resize (ref array, pos);
+                       
+                       return array;
                }
 
                #endregion
@@ -2862,7 +2962,7 @@ namespace System.Linq
                public static ILookup<TKey, TSource> ToLookup<TSource, TKey> (this IEnumerable<TSource> source,
                        Func<TSource, TKey> keySelector, IEqualityComparer<TKey> comparer)
                {
-                       return ToLookup<TSource, TKey, TSource> (source, keySelector, element => element, comparer);
+                       return ToLookup<TSource, TKey, TSource> (source, keySelector, Function<TSource>.Identity, comparer);
                }
 
                public static ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement> (this IEnumerable<TSource> source,
@@ -2963,7 +3063,7 @@ namespace System.Linq
                        }
 
                        foreach (var element in second) {
-                               if (! items.Contains (element, comparer)) {
+                               if (! items.Contains (element)) {
                                        items.Add (element);
                                        yield return element;
                                }
@@ -2972,7 +3072,7 @@ namespace System.Linq
 
                #endregion
                
-#if NET_4_0            
+#if NET_4_0
                #region Zip
                
                public static IEnumerable<TResult> Zip<TFirst, TSecond, TResult> (this IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector)
@@ -3005,6 +3105,11 @@ namespace System.Linq
                {
                        Check.SourceAndPredicate (source, predicate);
 
+                       // It cannot be IList<TSource> because it may break on user implementation
+                       var array = source as TSource[];
+                       if (array != null)
+                               return CreateWhereIterator (array, predicate);
+
                        return CreateWhereIterator (source, predicate);
                }
 
@@ -3015,14 +3120,27 @@ namespace System.Linq
                                        yield return element;
                }
 
+               static IEnumerable<TSource> CreateWhereIterator<TSource> (TSource[] source, Func<TSource, bool> predicate)
+               {
+                       for (int i = 0; i < source.Length; ++i) {
+                               var element = source [i];
+                               if (predicate (element))
+                                       yield return element;
+                       }
+               }       
+
                public static IEnumerable<TSource> Where<TSource> (this IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
                {
                        Check.SourceAndPredicate (source, predicate);
 
+                       var array = source as TSource[];
+                       if (array != null)
+                               return CreateWhereIterator (array, predicate);
+
                        return CreateWhereIterator (source, predicate);
                }
 
-               static IEnumerable<TSource> CreateWhereIterator<TSource> (this IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
+               static IEnumerable<TSource> CreateWhereIterator<TSource> (IEnumerable<TSource> source, Func<TSource, int, bool> predicate)
                {
                        int counter = 0;
                        foreach (TSource element in source) {
@@ -3032,6 +3150,15 @@ namespace System.Linq
                        }
                }
 
+               static IEnumerable<TSource> CreateWhereIterator<TSource> (TSource[] source, Func<TSource, int, bool> predicate)
+               {
+                       for (int i = 0; i < source.Length; ++i) {
+                               var element = source [i];
+                               if (predicate (element, i))
+                                       yield return element;
+                       }
+               }
+
                #endregion
 
                internal static ReadOnlyCollection<TSource> ToReadOnlyCollection<TSource> (this IEnumerable<TSource> source)
@@ -3045,5 +3172,26 @@ namespace System.Linq
 
                        return new ReadOnlyCollection<TSource> (source.ToArray<TSource> ());
                }
+
+               #region Exception helpers
+
+               static Exception EmptySequence ()
+               {
+                       return new InvalidOperationException (Locale.GetText ("Sequence contains no elements"));
+               }
+               static Exception NoMatchingElement ()
+               {
+                       return new InvalidOperationException (Locale.GetText ("Sequence contains no matching element"));
+               }
+               static Exception MoreThanOneElement ()
+               {
+                       return new InvalidOperationException (Locale.GetText ("Sequence contains more than one element"));
+               }
+               static Exception MoreThanOneMatchingElement ()
+               {
+                       return new InvalidOperationException (Locale.GetText ("Sequence contains more than one matching element"));
+               }
+
+               #endregion
        }
 }