Merge pull request #900 from Blewzman/FixAggregateExceptionGetBaseException
[mono.git] / mcs / class / System.Core / Test / System.Linq / EnumerableAsQueryableTest.cs
1 //
2 // EnumerableAsQueryableTest.cs
3 //
4 // Authors:
5 //      Roei Erez (roeie@mainsoft.com)
6 //
7 // Copyright (C) 2007 Novell, Inc (http://www.novell.com)
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 //
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 //
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28
29 using System;
30 using System.Collections.Generic;
31 using System.Linq;
32 using System.Text;
33 using NUnit.Framework;
34 using System.Linq.Expressions;
35 using System.Runtime.CompilerServices;
36 using System.Reflection;
37 using System.Collections;
38
39 namespace MonoTests.System.Linq {
40
41         [TestFixture]
42         public class EnumerableAsQueryableTest {
43
44                 int [] _array;
45                 IQueryable<int> _src;
46
47                 [SetUp]
48                 public void MyTestCleanup ()
49                 {
50                         _array = new int [] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
51                         _src = _array.AsQueryable<int> ();
52                 }
53
54                 [Test]
55                 public void NewQueryableExpression ()
56                 {
57                         var queryable = _array.AsQueryable ();
58                         var expression = queryable.Expression;
59
60                         Assert.AreEqual (ExpressionType.Constant, expression.NodeType);
61
62                         var constant = (ConstantExpression) expression;
63
64                         Assert.AreEqual (queryable, constant.Value);
65                 }
66
67                 [Test]
68                 public void Aggregate ()
69                 {
70                     Assert.AreEqual (_src.Aggregate<int> ((n, m) => n + m), _array.Aggregate<int> ((n, m) => n + m));
71                 }
72
73                 [Test]
74                 public void All ()
75                 {
76                     Assert.AreEqual (_src.All<int> ((n) => n < 11), _array.All<int> ((n) => n < 11));
77                     Assert.AreEqual (_src.All<int> ((n) => n < 10), _array.All<int> ((n) => n < 10));
78                 }
79
80                 [Test]
81                 public void Any ()
82                 {
83                         Assert.AreEqual (_src.Any<int> (i => i > 5), _array.Any<int> (i => i > 5));
84                 }
85
86                 [Test]
87                 public void Average ()
88                 {
89                         Assert.AreEqual (_src.Average<int> ((n) => 11), _array.Average<int> ((n) => 11));
90                 }
91
92                 [Test]
93                 public void Concat ()
94                 {
95                         Assert.AreEqual (_src.Concat<int> (_src).Count (), _array.Concat<int> (_src).Count ());
96                 }
97
98                 [Test]
99                 public void Contains ()
100                 {
101
102                         for (int i = 1; i < 20; ++i)
103                                 Assert.AreEqual (_src.Contains<int> (i), _array.Contains<int> (i));
104                 }
105
106                 [Test]
107                 public void Count ()
108                 {
109                         Assert.AreEqual (_src.Count<int> (), _array.Count<int> ());
110                 }
111
112                 [Test]
113                 public void Distinct ()
114                 {
115                         Assert.AreEqual (_src.Distinct<int> ().Count (), _array.Distinct<int> ().Count ());
116                         Assert.AreEqual (_src.Distinct<int> (new CustomEqualityComparer ()).Count (), _array.Distinct<int> (new CustomEqualityComparer ()).Count ());
117                 }
118
119                 [Test]
120                 public void ElementAt ()
121                 {
122                         for (int i = 0; i < 10; ++i)
123                                 Assert.AreEqual (_src.ElementAt<int> (i), _array.ElementAt<int> (i));
124                 }
125
126                 [Test]
127                 public void ElementAtOrDefault ()
128                 {
129                         for (int i = 0; i < 10; ++i)
130                                 Assert.AreEqual (_src.ElementAtOrDefault<int> (i), _array.ElementAtOrDefault<int> (i));
131                         Assert.AreEqual (_src.ElementAtOrDefault<int> (100), _array.ElementAtOrDefault<int> (100));
132                 }
133
134                 [Test]
135                 public void Except ()
136                 {
137                         int [] except = { 1, 2, 3 };
138                         Assert.AreEqual (_src.Except<int> (except.AsQueryable ()).Count (), _array.Except<int> (except).Count ());
139                 }
140
141                 [Test]
142                 public void First ()
143                 {
144                         Assert.AreEqual (_src.First<int> (), _array.First<int> ());
145                 }
146
147                 [Test]
148                 public void FirstOrDefault ()
149                 {
150                         Assert.AreEqual (_src.FirstOrDefault<int> ((n) => n > 5), _array.FirstOrDefault<int> ((n) => n > 5));
151                         Assert.AreEqual (_src.FirstOrDefault<int> ((n) => n > 10), _array.FirstOrDefault<int> ((n) => n > 10));
152                 }
153
154                 [Test]
155                 public void GroupBy ()
156                 {
157                         IQueryable<IGrouping<bool, int>> grouping = _src.GroupBy<int, bool> ((n) => n > 5);
158                         Assert.AreEqual (grouping.Count(), 2);
159                         foreach (IGrouping<bool, int> group in grouping)
160                         {
161                                 Assert.AreEqual(group.Count(), 5);
162                         }
163                 }
164
165                 [Test]
166                 public void Intersect ()
167                 {
168                         int [] subset = { 1, 2, 3 };
169                         int[] intersection = _src.Intersect<int> (subset.AsQueryable()).ToArray();
170                         Assert.AreEqual (subset, intersection);
171                 }
172
173                 [Test]
174                 public void Last ()
175                 {
176                         Assert.AreEqual (_src.Last<int> ((n) => n > 1), _array.Last<int> ((n) => n > 1));
177                 }
178
179                 [Test]
180                 public void LastOrDefault ()
181                 {
182                         Assert.AreEqual (_src.LastOrDefault<int> (), _array.LastOrDefault<int> ());
183                 }
184
185                 [Test]
186                 public void LongCount ()
187                 {
188                         Assert.AreEqual (_src.LongCount<int> (), _array.LongCount<int> ());
189                 }
190
191                 [Test]
192                 public void Max ()
193                 {
194                         Assert.AreEqual (_src.Max<int> (), _array.Max<int> ());
195                 }
196
197                 [Test]
198                 public void Min ()
199                 {
200                         Assert.AreEqual (_src.Min<int> (), _array.Min<int> ());
201                 }
202
203                 [Test]
204                 public void OfType ()
205                 {
206                         Assert.AreEqual (_src.OfType<int> ().Count (), _array.OfType<int> ().Count ());
207                 }
208
209                 [Test]
210                 public void OrderBy ()
211                 {
212                         int [] arr1 = _array.OrderBy<int, int> ((n) => n * -1).ToArray ();
213                         int [] arr2 = _src.OrderBy<int, int> ((n) => n * -1).ToArray ();
214                         Assert.AreEqual (arr1, arr2);
215                 }
216
217                 [Test]
218                 public void OrderByDescending ()
219                 {
220                         int [] arr1 = _array.OrderBy<int, int> ((n) => n).ToArray ();
221                         int [] arr2 = _src.OrderBy<int, int> ((n) => n).ToArray ();
222                         Assert.AreEqual (arr1, arr2);
223                 }
224
225                 [Test]
226                 public void Reverse ()
227                 {
228                         int [] arr1 = _array.Reverse<int> ().Reverse ().ToArray ();
229                         int [] arr2 = _src.Reverse<int> ().Reverse ().ToArray ();
230                         Assert.AreEqual (arr1, arr2);
231                 }
232
233                 [Test]
234                 public void Select ()
235                 {
236                         int [] arr1 = _array.Select<int, int> ((n) => n - 1).ToArray ();
237                         int [] arr2 = _src.Select<int, int> ((n) => n - 1).ToArray ();
238                         Assert.AreEqual (arr1, arr2);
239                 }
240
241                 [Test]
242                 public void SelectMany ()
243                 {
244                         int [] arr1 = _array.SelectMany<int, int> ((n) => new int [] { n, n, n }).ToArray ();
245                         int [] arr2 = _src.SelectMany<int, int> ((n) => new int [] { n, n, n }).ToArray ();
246                         Assert.AreEqual (arr1, arr2);
247                 }
248
249                 [Test]
250                 public void SequenceEqual ()
251                 {
252                         Assert.IsTrue (_src.SequenceEqual<int> (_src));
253                 }
254
255                 [Test]
256                 public void Single ()
257                 {
258                         Assert.AreEqual (_src.Single (n => n == 10), 10);
259                 }
260
261                 [Test]
262                 public void SingleOrDefault ()
263                 {
264                         Assert.AreEqual (_src.SingleOrDefault (n => n == 10), 10);
265                         Assert.AreEqual (_src.SingleOrDefault (n => n == 11), 0);
266                 }
267
268                 [Test]
269                 public void Skip ()
270                 {
271                         int [] arr1 = _array.Skip<int> (5).ToArray ();
272                         int [] arr2 = _src.Skip<int> (5).ToArray ();
273                         Assert.AreEqual (arr1, arr2);
274                 }
275
276                 [Test]
277                 public void SkipWhile ()
278                 {
279                         int[] arr1 = _src.SkipWhile<int> ((n) => n < 6).ToArray();
280                         int[] arr2 = _src.Skip<int> (5).ToArray();
281                         Assert.AreEqual (arr1, arr2);
282                 }
283
284                 [Test]
285                 public void Sum ()
286                 {
287                         Assert.AreEqual (_src.Sum<int> ((n) => n), _array.Sum<int> ((n) => n));
288                         Assert.AreEqual (_src.Sum<int> ((n) => n + 1), _array.Sum<int> ((n) => n + 1));
289                 }
290
291                 [Test]
292                 public void Take ()
293                 {
294                         int [] arr1 = _array.Take<int> (3).ToArray ();
295                         int [] arr2 = _src.Take<int> (3).ToArray ();
296                         Assert.AreEqual (arr1, arr2);
297                 }
298
299                 [Test]
300                 public void TakeWhile ()
301                 {
302                         int [] arr1 = _array.TakeWhile<int> (n => n < 6).ToArray ();
303                         int [] arr2 = _src.TakeWhile<int> (n => n < 6).ToArray ();
304                         Assert.AreEqual (arr1, arr2);
305                 }
306
307                 [Test]
308                 public void Union ()
309                 {
310                         int [] arr1 = _src.ToArray ();
311                         int[] arr2 = _src.Union (_src).ToArray ();
312                         Assert.AreEqual (arr1, arr2);
313
314                         int [] arr = { 11,12,13};
315                         Assert.AreEqual (_src.Union (arr).ToArray (), _array.Union (arr).ToArray ());
316                 }
317
318                 [Test]
319                 public void Where ()
320                 {
321                         int[] oddArray1 = _array.Where<int> ((n) => (n % 2) == 1).ToArray();
322                         int [] oddArray2 = _src.Where<int> ((n) => (n % 2) == 1).ToArray ();
323                         Assert.AreEqual (oddArray1, oddArray2);
324                 }
325
326                 [Test]
327                 [Category ("NotWorkingInterpreter")]
328                 public void UserExtensionMethod ()
329                 {
330                         BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public;
331                         MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags)
332                                                                  where (m.Name == "UserQueryableExt1" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>))
333                                                                  select m).FirstOrDefault ().MakeGenericMethod (typeof (int));
334                         Expression<Func<int, int>> exp = i => i;
335                         Expression e = Expression.Equal (
336                                                                         Expression.Constant ("UserEnumerableExt1"),
337                                                                         Expression.Call (method, _src.Expression, Expression.Quote (exp)));
338                         Assert.AreEqual (_src.Provider.Execute<bool> (e), true, "UserQueryableExt1");
339
340                         method = (from m in typeof (Ext).GetMethods (extensionFlags)
341                                                            where (m.Name == "UserQueryableExt2" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>))
342                                                            select m).FirstOrDefault ().MakeGenericMethod (typeof (int));
343                         e = Expression.Equal (
344                                                                         Expression.Constant ("UserEnumerableExt2"),
345                                                                         Expression.Call (method, _src.Expression, Expression.Quote (exp)));
346                         Assert.AreEqual (_src.Provider.Execute<bool> (e), true, "UserQueryableExt2");
347                 }
348
349                 [Test]
350                 [ExpectedException (typeof (InvalidOperationException))]
351                 public void UserExtensionMethodNegative ()
352                 {
353                         BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public;
354                         MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags)
355                                                                  where (m.Name == "UserQueryableExt3" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>))
356                                                                  select m).FirstOrDefault ().MakeGenericMethod (typeof (int));
357                         Expression<Func<int, int>> exp = i => i;
358                         Expression e = Expression.Call (method, _src.Expression, Expression.Quote (exp), Expression.Constant (10));
359                         _src.Provider.Execute (e);
360                 }
361
362                 [Test]
363                 public void NonGenericMethod () {
364                         BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public;
365                         MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags)
366                                                                  where (m.Name == "NonGenericMethod" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>))
367                                                                  select m).FirstOrDefault ();
368
369                         Expression e = Expression.Call (method, _src.Expression);
370                         Assert.AreEqual (_src.Provider.Execute (e), "EnumerableNonGenericMethod", "NonGenericMethod");
371                 }
372
373                 [Test]
374                 [ExpectedException(typeof(InvalidOperationException))]
375                 public void InstantiatedGenericMethod () {
376                         BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public;
377                         MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags)
378                                                                  where (m.Name == "InstantiatedGenericMethod" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>))
379                                                                  select m).FirstOrDefault ().MakeGenericMethod (typeof (int));
380
381                         Expression e = Expression.Call (method, _src.Expression, Expression.Constant(0));
382                         _src.Provider.Execute (e);
383                 }
384
385                 [Test]
386                 [ExpectedException (typeof (ArgumentNullException))]
387                 public void NullEnumerable ()
388                 {
389                         IEnumerable<int> a = null;
390                         a.AsQueryable ();
391                 }
392
393                 [Test]
394                 [ExpectedException (typeof (ArgumentException))]
395                 public void NonGenericEnumerable1 ()
396                 {
397                         new MyEnum ().AsQueryable ();
398                 }
399
400                 [Test]
401                 public void NonGenericEnumerable2 ()
402                 {
403                         IEnumerable<int> nonGen = new int[] { 1, 2, 3 };
404                         Assert.IsTrue (nonGen.AsQueryable () is IQueryable<int>);
405                 }
406
407                 class Bar<T1, T2> : IEnumerable<T2> {
408
409                         public IEnumerator<T2> GetEnumerator ()
410                         {
411                                 yield break;
412                         }
413
414                         IEnumerator IEnumerable.GetEnumerator ()
415                         {
416                                 return GetEnumerator ();
417                         }
418                 }
419
420                 [Test]
421                 public void NonGenericAsQueryableInstantiateProperQueryable ()
422                 {
423                         IEnumerable bar = new Bar<int, string> ();
424                         IQueryable queryable = bar.AsQueryable ();
425
426                         Assert.IsTrue (queryable is IQueryable<string>);
427                 }
428         }
429
430         class MyEnum : IEnumerable
431         {
432                 public IEnumerator GetEnumerator ()
433                 {
434                         throw new NotImplementedException ();
435                 }
436         }
437
438         class CustomEqualityComparer : IEqualityComparer<int> {
439
440                 public bool Equals (int x, int y)
441                 {
442                         return true;
443                 }
444
445                 public int GetHashCode (int obj)
446                 {
447                         return 0;
448                 }
449         }
450
451         public static class Ext {
452
453                 public static string UserQueryableExt1<T> (this IQueryable<T> e, Expression<Func<int, int>> ex)
454                 {
455                         return "UserQueryableExt1";
456                 }
457
458                 public static string UserQueryableExt2<T> (this IQueryable<T> e, Expression<Func<int, int>> ex)
459                 {
460                         return "UserQueryableExt2";
461                 }
462
463                 public static string UserQueryableExt3<T> (this IQueryable<T> e, Expression<Func<int, int>> ex, int dummy)
464                 {
465                         return "UserQueryableExt3";
466                 }
467
468                 public static string UserQueryableExt1<T> (this IEnumerable<T> e, Expression<Func<int, int>> ex)
469                 {
470                         return "UserEnumerableExt1";
471                 }
472
473                 public static string UserQueryableExt2<T> (this IEnumerable<T> e, Func<int, int> ex)
474                 {
475                         return "UserEnumerableExt2";
476                 }
477
478                 public static string NonGenericMethod (this IQueryable<int> iq)
479                 {
480                         return "QueryableNonGenericMethod";
481                 }
482
483                 public static string NonGenericMethod (this IEnumerable<int> iq)
484                 {
485                         return "EnumerableNonGenericMethod";
486                 }
487
488                 public static string InstantiatedGenericMethod<T> (this IQueryable<int> iq, T t)
489                 {
490                         return "QueryableInstantiatedGenericMethod";
491                 }
492
493                 public static string InstantiatedGenericMethod (this IEnumerable<int> ie, int t)
494                 {
495                         return "EnumerableInstantiatedGenericMethod";
496                 }
497         }
498 }