Merge pull request #900 from Blewzman/FixAggregateExceptionGetBaseException
[mono.git] / mcs / class / dlr / Runtime / Microsoft.Dynamic / Utils / CollectionUtils.cs
1 /* ****************************************************************************
2  *
3  * Copyright (c) Microsoft Corporation. 
4  *
5  * This source code is subject to terms and conditions of the Apache License, Version 2.0. A 
6  * copy of the license can be found in the License.html file at the root of this distribution. If 
7  * you cannot locate the  Apache License, Version 2.0, please send an email to 
8  * dlr@microsoft.com. By using this source code in any fashion, you are agreeing to be bound 
9  * by the terms of the Apache License, Version 2.0.
10  *
11  * You must not remove this notice, or any other, from this software.
12  *
13  *
14  * ***************************************************************************/
15
16 using System;
17 using System.Collections;
18 using System.Collections.Generic;
19
20 namespace Microsoft.Scripting.Utils {
21     /// <summary>
22     /// Allows wrapping of proxy types (like COM RCWs) to expose their IEnumerable functionality
23     /// which is supported after casting to IEnumerable, even though Reflection will not indicate 
24     /// IEnumerable as a supported interface
25     /// </summary>
26     [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1010:CollectionsShouldImplementGenericInterface")] // TODO
27     public class EnumerableWrapper : IEnumerable {
28         private IEnumerable _wrappedObject;
29         public EnumerableWrapper(IEnumerable o) {
30             _wrappedObject = o;
31         }
32
33         public IEnumerator GetEnumerator() {
34             return _wrappedObject.GetEnumerator();
35         }
36     }
37
38     public static class CollectionUtils {
39 #if !FEATURE_VARIANCE
40         public static IEnumerable<T> Cast<S, T>(this IEnumerable<S> sequence) where S : T {
41             foreach (var item in sequence) {
42                 yield return (T)item;
43             }
44         }
45 #else
46         public static IEnumerable<T> Cast<S, T>(this IEnumerable<S> sequence) where S : T {
47             return (IEnumerable<T>)sequence;
48         }
49 #endif
50
51         public static IEnumerable<TSuper> ToCovariant<T, TSuper>(IEnumerable<T> enumerable)
52             where T : TSuper {
53 #if FEATURE_VARIANCE
54             return (IEnumerable<TSuper>)enumerable;
55 #else
56             return new CovariantConvertor<T, TSuper>(enumerable);
57 #endif
58         }
59
60         public static void AddRange<T>(ICollection<T> collection, IEnumerable<T> items) {
61             ContractUtils.RequiresNotNull(collection, "collection");
62             ContractUtils.RequiresNotNull(items, "items");
63
64             List<T> list = collection as List<T>;
65             if (list != null) {
66                 list.AddRange(items);
67             } else {
68                 foreach (T item in items) {
69                     collection.Add(item);
70                 }
71             }
72         }
73
74         public static void AddRange<T>(this IList<T> list, IEnumerable<T> items) {
75             foreach (var item in items) {
76                 list.Add(item);
77             }
78         }
79
80         public static IEnumerable<T> ToEnumerable<T>(IEnumerable enumerable) {
81             foreach (T item in enumerable) {
82                 yield return item;
83             }
84         }
85
86         public static IEnumerator<TSuper> ToCovariant<T, TSuper>(IEnumerator<T> enumerator)
87             where T : TSuper {
88
89             ContractUtils.RequiresNotNull(enumerator, "enumerator");
90
91             while (enumerator.MoveNext()) {
92                 yield return enumerator.Current;
93             }
94         }
95
96         private class CovariantConvertor<T, TSuper> : IEnumerable<TSuper> where T : TSuper {
97             private IEnumerable<T> _enumerable;
98
99             public CovariantConvertor(IEnumerable<T> enumerable) {
100                 ContractUtils.RequiresNotNull(enumerable, "enumerable");
101                 _enumerable = enumerable;
102             }
103
104             public IEnumerator<TSuper> GetEnumerator() {
105                 return CollectionUtils.ToCovariant<T, TSuper>(_enumerable.GetEnumerator());
106             }
107
108             IEnumerator IEnumerable.GetEnumerator() {
109                 return GetEnumerator();
110             }
111         }
112
113         public static IDictionaryEnumerator ToDictionaryEnumerator(IEnumerator<KeyValuePair<object, object>> enumerator) {
114             return new DictionaryEnumerator(enumerator);
115         }
116
117         private sealed class DictionaryEnumerator : IDictionaryEnumerator {
118             private readonly IEnumerator<KeyValuePair<object, object>> _enumerator;
119
120             public DictionaryEnumerator(IEnumerator<KeyValuePair<object, object>> enumerator) {
121                 _enumerator = enumerator;
122             }
123
124             public DictionaryEntry Entry {
125                 get { return new DictionaryEntry(_enumerator.Current.Key, _enumerator.Current.Value); }
126             }
127
128             public object Key {
129                 get { return _enumerator.Current.Key; }
130             }
131
132             public object Value {
133                 get { return _enumerator.Current.Value; }
134             }
135
136             public object Current {
137                 get { return Entry; }
138             }
139
140             public bool MoveNext() {
141                 return _enumerator.MoveNext();
142             }
143
144             public void Reset() {
145                 _enumerator.Reset();
146             }
147         }
148
149         public static List<T> MakeList<T>(T item) {
150             List<T> result = new List<T>();
151             result.Add(item);
152             return result;
153         }
154
155         public static int CountOf<T>(IList<T> list, T item) where T : IEquatable<T> {
156             if (list == null) return 0;
157
158             int result = 0;
159             for (int i = 0; i < list.Count; i++) {
160                 if (list[i].Equals(item)) {
161                     result++;
162                 }
163             }
164             return result;
165         }
166
167         public static int Max(this IEnumerable<int> values) {
168             ContractUtils.RequiresNotNull(values, "values");
169
170             int result = Int32.MinValue;
171             foreach (var value in values) {
172                 if (value > result) {
173                     result = value;
174                 }
175             }
176             return result;
177         }
178
179         public static bool TrueForAll<T>(IEnumerable<T> collection, Predicate<T> predicate) {
180             ContractUtils.RequiresNotNull(collection, "collection");
181             ContractUtils.RequiresNotNull(predicate, "predicate");
182
183             foreach (T item in collection) {
184                 if (!predicate(item)) return false;
185             }
186
187             return true;
188         }
189
190         public static IList<TRet> ConvertAll<T, TRet>(IList<T> collection, Func<T, TRet> predicate) {
191             ContractUtils.RequiresNotNull(collection, "collection");
192             ContractUtils.RequiresNotNull(predicate, "predicate");
193
194             List<TRet> res = new List<TRet>(collection.Count);
195             foreach (T item in collection) {
196                 res.Add(predicate(item));
197             }
198
199             return res;
200         }
201
202         public static List<T> GetRange<T>(IList<T> list, int index, int count) {
203             ContractUtils.RequiresNotNull(list, "list");
204             ContractUtils.RequiresArrayRange(list, index, count, "index", "count");
205
206             List<T> result = new List<T>(count);
207             int stop = index + count;
208             for (int i = index; i < stop; i++) {
209                 result.Add(list[i]);
210             }
211             return result;
212         }
213
214         public static void InsertRange<T>(IList<T> collection, int index, IEnumerable<T> items) {
215             ContractUtils.RequiresNotNull(collection, "collection");
216             ContractUtils.RequiresNotNull(items, "items");
217             ContractUtils.RequiresArrayInsertIndex(collection, index, "index");
218
219             List<T> list = collection as List<T>;
220             if (list != null) {
221                 list.InsertRange(index, items);
222             } else {
223                 int i = index;
224                 foreach (T obj in items) {
225                     collection.Insert(i++, obj);
226                 }
227             }
228         }
229
230         public static void RemoveRange<T>(IList<T> collection, int index, int count) {
231             ContractUtils.RequiresNotNull(collection, "collection");
232             ContractUtils.RequiresArrayRange(collection, index, count, "index", "count");
233
234             List<T> list = collection as List<T>;
235             if (list != null) {
236                 list.RemoveRange(index, count);
237             } else {
238                 for (int i = index + count - 1; i >= index; i--) {
239                     collection.RemoveAt(i);
240                 }
241             }
242         }
243
244         public static int FindIndex<T>(this IList<T> collection, Predicate<T> predicate) {
245             ContractUtils.RequiresNotNull(collection, "collection");
246             ContractUtils.RequiresNotNull(predicate, "predicate");
247
248             for (int i = 0; i < collection.Count; i++) {
249                 if (predicate(collection[i])) {
250                     return i;
251                 }
252             }
253             return -1;
254         }
255
256         public static IList<T> ToSortedList<T>(this ICollection<T> collection, Comparison<T> comparison) {
257             ContractUtils.RequiresNotNull(collection, "collection");
258             ContractUtils.RequiresNotNull(comparison, "comparison");
259
260             var array = new T[collection.Count];
261             collection.CopyTo(array, 0);
262             Array.Sort(array, comparison);
263             return array;
264         }
265
266         public static T[] ToReverseArray<T>(this IList<T> list) {
267             ContractUtils.RequiresNotNull(list, "list");
268             T[] result = new T[list.Count];
269             for (int i = 0; i < result.Length; i++) {
270                 result[i] = list[result.Length - 1 - i];
271             }
272             return result;
273         }
274
275
276 #if SILVERLIGHT || WIN8 || WP75
277         // HashSet.CreateSetComparer not available on Silverlight
278         public static IEqualityComparer<HashSet<T>> CreateSetComparer<T>() {
279             return new HashSetEqualityComparer<T>();
280         }
281
282         class HashSetEqualityComparer<T> : IEqualityComparer<HashSet<T>> {
283             private IEqualityComparer<T> _comparer;
284
285             public HashSetEqualityComparer() {
286                 _comparer = EqualityComparer<T>.Default;
287             }
288
289             public bool Equals(HashSet<T> x, HashSet<T> y) {
290                 if (x == y) {
291                     return true;
292                 } else if (x == null || y == null || x.Count != y.Count) {
293                     return false;
294                 }
295
296                 foreach (T value in x) {
297                     if (!y.Contains(value)) {
298                         return false;
299                     }
300                 }
301
302                 return true;
303             }
304
305             public int GetHashCode(HashSet<T> obj) {
306                 int res = 6551;
307                 if (obj != null) {
308                     foreach (T t in obj) {
309                         res = res ^ _comparer.GetHashCode(t);
310                     }
311                 }
312
313                 return res;
314             }
315         }
316 #else
317         public static IEqualityComparer<HashSet<T>> CreateSetComparer<T>() {
318             return HashSet<T>.CreateSetComparer();
319         }
320 #endif
321     }
322 }