From 6561ede7956abcab59c624e8e590c08ef72cb199 Mon Sep 17 00:00:00 2001 From: Roei Erez Date: Thu, 15 May 2008 12:36:13 +0000 Subject: [PATCH] Add implementation for AsQueryable. svn path=/trunk/mcs/; revision=103265 --- .../System.Core-2008.JavaEE.csproj | 5 +- mcs/class/System.Core/System.Core-2008.csproj | 3 + mcs/class/System.Core/System.Core-2008.sln | 16 +- mcs/class/System.Core/System.Core.dll.sources | 3 + .../System.Core/System.Core_test.dll.sources | 1 + .../System.Linq.Expressions/ChangeLog | 5 + .../ExpressionTransformer.cs | 321 ++++++++++++ mcs/class/System.Core/System.Linq/ChangeLog | 6 + .../System.Core/System.Linq/Queryable.cs | 2 +- .../System.Linq/QueryableEnumerable.cs | 94 ++++ .../System.Linq/QueryableTransformer.cs | 172 +++++++ .../Test/System.Core.Tests-2008.JavaEE.csproj | 11 +- .../Test/System.Core.Tests-2008.csproj | 3 +- .../System.Core/Test/System.Linq/ChangeLog | 4 + .../System.Linq/EnumerableAsQueryableTest.cs | 468 ++++++++++++++++++ 15 files changed, 1108 insertions(+), 6 deletions(-) create mode 100644 mcs/class/System.Core/System.Linq.Expressions/ExpressionTransformer.cs create mode 100644 mcs/class/System.Core/System.Linq/QueryableEnumerable.cs create mode 100644 mcs/class/System.Core/System.Linq/QueryableTransformer.cs create mode 100644 mcs/class/System.Core/Test/System.Linq/EnumerableAsQueryableTest.cs diff --git a/mcs/class/System.Core/System.Core-2008.JavaEE.csproj b/mcs/class/System.Core/System.Core-2008.JavaEE.csproj index 776b54e8763..a4a7919fa6e 100644 --- a/mcs/class/System.Core/System.Core-2008.JavaEE.csproj +++ b/mcs/class/System.Core/System.Core-2008.JavaEE.csproj @@ -152,6 +152,7 @@ + @@ -192,6 +193,8 @@ + + @@ -280,4 +283,4 @@ - + \ No newline at end of file diff --git a/mcs/class/System.Core/System.Core-2008.csproj b/mcs/class/System.Core/System.Core-2008.csproj index 4f1cf58f653..ee9f8c07c23 100755 --- a/mcs/class/System.Core/System.Core-2008.csproj +++ b/mcs/class/System.Core/System.Core-2008.csproj @@ -53,6 +53,7 @@ + @@ -74,6 +75,8 @@ + + diff --git a/mcs/class/System.Core/System.Core-2008.sln b/mcs/class/System.Core/System.Core-2008.sln index bce8a7e8f32..35a64066e0b 100755 --- a/mcs/class/System.Core/System.Core-2008.sln +++ b/mcs/class/System.Core/System.Core-2008.sln @@ -1,10 +1,18 @@  Microsoft Visual Studio Solution File, Format Version 10.00 -# Visual C# Express 2008 +# Visual Studio 2008 Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "System.Core-2008", "System.Core-2008.csproj", "{D287D5CA-4F81-4215-AFC8-8A1413696884}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "System.Core.Tests-2008", "Test\System.Core.Tests-2008.csproj", "{F902A50D-6156-4935-A1AC-E82DF0EB83D3}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{6DD21DC4-1760-412E-BA3B-9295911E8794}" + ProjectSection(SolutionItems) = preProject + LocalTestRun.testrunconfig = LocalTestRun.testrunconfig + System.Core-2008.vsmdi = System.Core-2008.vsmdi + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LINQIntro", "..\LINQIntro\LINQIntro.csproj", "{9C0334A8-9E1C-4581-A547-46DD4D957E76}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -24,6 +32,12 @@ Global {F902A50D-6156-4935-A1AC-E82DF0EB83D3}.Release|Any CPU.Build.0 = Release|Any CPU {F902A50D-6156-4935-A1AC-E82DF0EB83D3}.Test.NET|Any CPU.ActiveCfg = Test.NET|Any CPU {F902A50D-6156-4935-A1AC-E82DF0EB83D3}.Test.NET|Any CPU.Build.0 = Test.NET|Any CPU + {9C0334A8-9E1C-4581-A547-46DD4D957E76}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9C0334A8-9E1C-4581-A547-46DD4D957E76}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9C0334A8-9E1C-4581-A547-46DD4D957E76}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9C0334A8-9E1C-4581-A547-46DD4D957E76}.Release|Any CPU.Build.0 = Release|Any CPU + {9C0334A8-9E1C-4581-A547-46DD4D957E76}.Test.NET|Any CPU.ActiveCfg = Release|Any CPU + {9C0334A8-9E1C-4581-A547-46DD4D957E76}.Test.NET|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/mcs/class/System.Core/System.Core.dll.sources b/mcs/class/System.Core/System.Core.dll.sources index fbd78daa6ba..2f3171fcbd6 100644 --- a/mcs/class/System.Core/System.Core.dll.sources +++ b/mcs/class/System.Core/System.Core.dll.sources @@ -15,6 +15,8 @@ System.Runtime.CompilerServices/IStrongBox.cs System.Runtime.CompilerServices/StrongBox_T.cs System.Linq/Check.cs System.Linq/Enumerable.cs +System.Linq/QueryableEnumerable.cs +System.Linq/QueryableTransformer.cs System.Linq/Grouping.cs System.Linq/IGrouping.cs System.Linq/IOrderedQueryable.cs @@ -42,6 +44,7 @@ System.Linq.Expressions/Expression_T.cs System.Linq.Expressions/ExpressionPrinter.cs System.Linq.Expressions/ExpressionType.cs System.Linq.Expressions/ExpressionVisitor.cs +System.Linq.Expressions/ExpressionTransformer.cs System.Linq.Expressions/Extensions.cs System.Linq.Expressions/InvocationExpression.cs System.Linq.Expressions/LambdaExpression.cs diff --git a/mcs/class/System.Core/System.Core_test.dll.sources b/mcs/class/System.Core/System.Core_test.dll.sources index 2f6fc57be49..9046e242ed9 100644 --- a/mcs/class/System.Core/System.Core_test.dll.sources +++ b/mcs/class/System.Core/System.Core_test.dll.sources @@ -4,6 +4,7 @@ System/TimeZoneInfo.TransitionTimeTest.cs System.Collections.Generic/HashSetTest.cs System.Linq/EnumerableTest.cs System.Linq/EnumerableMoreTest.cs +System.Linq/EnumerableAsQueryableTest.cs System.Linq.Expressions/ExpressionTest.cs System.Linq.Expressions/ExpressionTest_Add.cs System.Linq.Expressions/ExpressionTest_AddChecked.cs diff --git a/mcs/class/System.Core/System.Linq.Expressions/ChangeLog b/mcs/class/System.Core/System.Linq.Expressions/ChangeLog index 27275f180b9..258832a4eea 100644 --- a/mcs/class/System.Core/System.Linq.Expressions/ChangeLog +++ b/mcs/class/System.Core/System.Linq.Expressions/ChangeLog @@ -1,3 +1,8 @@ +2008-05-15 Roei Erez + + * ExpressionTransformer.cs: Add a base class for transforming Expressions. + In use at AsQueryable() implementation. + 2008-05-14 Jb Evain * EmitContext.cs: only generate a new lambda name if we're in diff --git a/mcs/class/System.Core/System.Linq.Expressions/ExpressionTransformer.cs b/mcs/class/System.Core/System.Linq.Expressions/ExpressionTransformer.cs new file mode 100644 index 00000000000..f82a1f65658 --- /dev/null +++ b/mcs/class/System.Core/System.Linq.Expressions/ExpressionTransformer.cs @@ -0,0 +1,321 @@ +// +// ExpressionTransformer.cs +// +// Authors: +// Roei Erez (roeie@mainsoft.com) +// +// Copyright (C) 2007 Novell, Inc (http://www.novell.com) +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using System.Collections.ObjectModel; + +namespace System.Linq +{ + abstract class ExpressionTransformer + { + + protected ExpressionTransformer () + { + } + + public Expression Transform (Expression e) + { + return Visit (e); + } + + protected virtual Expression Visit (Expression expression) + { + if (expression == null) + return null; + + switch (expression.NodeType) { + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + case ExpressionType.Not: + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.ArrayLength: + case ExpressionType.Quote: + case ExpressionType.TypeAs: + case ExpressionType.UnaryPlus: + return VisitUnary ((UnaryExpression) expression); + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Power: + case ExpressionType.And: + case ExpressionType.AndAlso: + case ExpressionType.Or: + case ExpressionType.OrElse: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.Coalesce: + case ExpressionType.ArrayIndex: + case ExpressionType.RightShift: + case ExpressionType.LeftShift: + case ExpressionType.ExclusiveOr: + return VisitBinary ((BinaryExpression) expression); + case ExpressionType.TypeIs: + return VisitTypeIs ((TypeBinaryExpression) expression); + case ExpressionType.Conditional: + return VisitConditional ((ConditionalExpression) expression); + case ExpressionType.Constant: + return VisitConstant ((ConstantExpression) expression); + case ExpressionType.Parameter: + return VisitParameter ((ParameterExpression) expression); + case ExpressionType.MemberAccess: + return VisitMemberAccess ((MemberExpression) expression); + case ExpressionType.Call: + return VisitMethodCall ((MethodCallExpression) expression); + case ExpressionType.Lambda: + return VisitLambda ((LambdaExpression) expression); + case ExpressionType.New: + return VisitNew ((NewExpression) expression); + case ExpressionType.NewArrayInit: + case ExpressionType.NewArrayBounds: + return VisitNewArray ((NewArrayExpression) expression); + case ExpressionType.Invoke: + return VisitInvocation ((InvocationExpression) expression); + case ExpressionType.MemberInit: + return VisitMemberInit ((MemberInitExpression) expression); + case ExpressionType.ListInit: + return VisitListInit ((ListInitExpression) expression); + default: + throw new ArgumentException (string.Format ("Unhandled expression type: '{0}'", expression.NodeType)); + } + } + + protected virtual MemberBinding VisitBinding (MemberBinding binding) + { + switch (binding.BindingType) { + case MemberBindingType.Assignment: + return VisitMemberAssignment ((MemberAssignment) binding); + case MemberBindingType.MemberBinding: + return VisitMemberMemberBinding ((MemberMemberBinding) binding); + case MemberBindingType.ListBinding: + return VisitMemberListBinding ((MemberListBinding) binding); + default: + throw new ArgumentException (string.Format ("Unhandled binding type '{0}'", binding.BindingType)); + } + } + + protected virtual ElementInit VisitElementInitializer (ElementInit initializer) + { + ReadOnlyCollection transformed = VisitExpressionList (initializer.Arguments); + if (transformed != initializer.Arguments) + return Expression.ElementInit (initializer.AddMethod, transformed); + return initializer; + } + + protected virtual UnaryExpression VisitUnary (UnaryExpression unary) + { + Expression transformedOperand = Visit (unary.Operand); + if (transformedOperand != unary.Operand) + return Expression.MakeUnary (unary.NodeType, transformedOperand, unary.Type, unary.Method); + return unary; + } + + protected virtual BinaryExpression VisitBinary (BinaryExpression binary) + { + Expression left = Visit (binary.Left); + Expression right = Visit (binary.Right); + LambdaExpression conversion = VisitLambda (binary.Conversion); + if (left != binary.Left || right != binary.Right || conversion != binary.Conversion) + return Expression.MakeBinary (binary.NodeType, left, right, binary.IsLiftedToNull, binary.Method, conversion); + return binary; + } + + protected virtual TypeBinaryExpression VisitTypeIs (TypeBinaryExpression type) + { + Expression inner = Visit (type.Expression); + if (inner != type.Expression) + return Expression.TypeIs (inner, type.TypeOperand); + return type; + } + + protected virtual ConstantExpression VisitConstant (ConstantExpression constant) + { + return constant; + } + + protected virtual ConditionalExpression VisitConditional (ConditionalExpression conditional) + { + Expression test = Visit (conditional.Test); + Expression ifTrue = Visit (conditional.IfTrue); + Expression ifFalse = Visit (conditional.IfFalse); + if (test != conditional.Test || ifTrue != conditional.IfTrue || ifFalse != conditional.IfFalse) + return Expression.Condition (test, ifTrue, ifFalse); + return conditional; + } + + protected virtual ParameterExpression VisitParameter (ParameterExpression parameter) + { + return parameter; + } + + protected virtual MemberExpression VisitMemberAccess (MemberExpression member) + { + Expression memberExp = Visit (member.Expression); + if (memberExp != member.Expression) + return Expression.MakeMemberAccess (memberExp, member.Member); + return member; + } + + protected virtual MethodCallExpression VisitMethodCall (MethodCallExpression methodCall) + { + Expression instance = Visit (methodCall.Object); + ReadOnlyCollection args = VisitExpressionList (methodCall.Arguments); + if (instance != methodCall.Object || args != methodCall.Arguments) + return Expression.Call (instance, methodCall.Method, args); + return methodCall; + } + + protected virtual ReadOnlyCollection VisitExpressionList (ReadOnlyCollection list) + { + return VisitList (list, Visit); + } + + private ReadOnlyCollection VisitList (ReadOnlyCollection list, Func selector) where T :class + { + int index = 0; + T [] arr = null; + foreach (T e in list) { + T visited = selector (e); + if (visited != e || arr != null) { + if (arr == null) + arr = new T [list.Count]; + arr [index] = visited; + } + index++; + } + if (arr != null) + return arr.ToReadOnlyCollection (); + return list; + } + + protected virtual MemberAssignment VisitMemberAssignment (MemberAssignment assignment) + { + Expression inner = Visit (assignment.Expression); + if (inner != assignment.Expression) + return Expression.Bind (assignment.Member, inner); + return assignment; + } + + protected virtual MemberMemberBinding VisitMemberMemberBinding (MemberMemberBinding binding) + { + ReadOnlyCollection bindingExp = VisitBindingList (binding.Bindings); + if (bindingExp != binding.Bindings) + return Expression.MemberBind (binding.Member, bindingExp); + return binding; + } + + protected virtual MemberListBinding VisitMemberListBinding (MemberListBinding binding) + { + ReadOnlyCollection initializers = + VisitElementInitializerList (binding.Initializers); + if (initializers != binding.Initializers) + return Expression.ListBind (binding.Member, initializers); + return binding; + } + + protected virtual ReadOnlyCollection VisitBindingList (ReadOnlyCollection list) + { + return VisitList (list, VisitBinding); + } + + protected virtual ReadOnlyCollection VisitElementInitializerList (ReadOnlyCollection list) + { + return VisitList (list, VisitElementInitializer); + } + + protected virtual LambdaExpression VisitLambda (LambdaExpression lambda) + { + Expression body = Visit (lambda.Body); + ReadOnlyCollection parameters = + VisitList (lambda.Parameters, VisitParameter); + if (body != lambda.Body || parameters != lambda.Parameters) + return Expression.Lambda (body, parameters.ToArray()); + return lambda; + } + + protected virtual NewExpression VisitNew (NewExpression nex) + { + ReadOnlyCollection args = VisitList (nex.Arguments, Visit); + if (args != nex.Arguments) + return Expression.New (nex.Constructor, args); + return nex; + } + + protected virtual MemberInitExpression VisitMemberInit (MemberInitExpression init) + { + NewExpression newExp = VisitNew (init.NewExpression); + ReadOnlyCollection bindings = VisitBindingList (init.Bindings); + if (newExp != init.NewExpression || bindings != init.Bindings) + return Expression.MemberInit (newExp, bindings); + return init; + } + + protected virtual ListInitExpression VisitListInit (ListInitExpression init) + { + NewExpression newExp = VisitNew (init.NewExpression); + ReadOnlyCollection initializers = VisitElementInitializerList (init.Initializers); + if (newExp != init.NewExpression || initializers != init.Initializers) + return Expression.ListInit (newExp, initializers.ToArray()); + return init; + } + + protected virtual NewArrayExpression VisitNewArray (NewArrayExpression newArray) + { + ReadOnlyCollection expressions = VisitExpressionList (newArray.Expressions); + if (expressions != newArray.Expressions) { + if (newArray.NodeType == ExpressionType.NewArrayBounds) + return Expression.NewArrayBounds (newArray.Type, expressions); + else + return Expression.NewArrayInit (newArray.Type, expressions); + } + return newArray; + } + + protected virtual InvocationExpression VisitInvocation (InvocationExpression invocation) + { + ReadOnlyCollection args = VisitExpressionList (invocation.Arguments); + Expression invocationExp = Visit (invocation.Expression); + if (args != invocation.Arguments || invocationExp != invocation.Expression) + return Expression.Invoke (invocationExp, args); + return invocation; + } + } +} diff --git a/mcs/class/System.Core/System.Linq/ChangeLog b/mcs/class/System.Core/System.Linq/ChangeLog index b20e6b9f076..7bddb18e3dd 100644 --- a/mcs/class/System.Core/System.Linq/ChangeLog +++ b/mcs/class/System.Core/System.Linq/ChangeLog @@ -1,3 +1,9 @@ +2008-05-15 Roei Erez + + * QueryableTransformer.cs, QueryableEnumerable.cs: two classes added for implementation + of Queryable.AsQueryable() implementation. + * Queryable.cs: Implement AsQueryable() method. + 2008-05-08 Jonathan Pryor * Enumerable.cs: LongCount() can be optimized for arrays, and Reverse() can diff --git a/mcs/class/System.Core/System.Linq/Queryable.cs b/mcs/class/System.Core/System.Linq/Queryable.cs index 3daf0e05303..dd08ebe241d 100644 --- a/mcs/class/System.Core/System.Linq/Queryable.cs +++ b/mcs/class/System.Core/System.Linq/Queryable.cs @@ -144,7 +144,7 @@ namespace System.Linq { if (queryable != null) return queryable; - throw new NotImplementedException (); + return new QueryableEnumerable (new ConstantExpression (source, typeof (IQueryable))); } [MonoTODO] diff --git a/mcs/class/System.Core/System.Linq/QueryableEnumerable.cs b/mcs/class/System.Core/System.Linq/QueryableEnumerable.cs new file mode 100644 index 00000000000..47078033f7b --- /dev/null +++ b/mcs/class/System.Core/System.Linq/QueryableEnumerable.cs @@ -0,0 +1,94 @@ +// +// QueryableEnumerable.cs +// +// Authors: +// Roei Erez (roeie@mainsoft.com) +// +// Copyright (C) 2007 Novell, Inc (http://www.novell.com) +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Linq.Expressions; + +namespace System.Linq +{ + internal class QueryableEnumerable : IQueryable, IQueryProvider, IOrderedQueryable + { + Expression expression; + + public QueryableEnumerable (Expression expression) { + this.expression = expression; + } + + public Type ElementType { + get { return expression.Type; } + } + + public Expression Expression { + get { return expression; } + } + + public IQueryProvider Provider { + get { return this; } + } + + public System.Collections.IEnumerator GetEnumerator () + { + return ((IEnumerable)this).GetEnumerator (); + } + + IEnumerator IEnumerable.GetEnumerator () + { + return Execute> (Expression).GetEnumerator (); + } + + public IQueryable CreateQuery (System.Linq.Expressions.Expression expression) + { + return (IQueryable) Activator.CreateInstance ( + typeof (QueryableEnumerable<>).MakeGenericType (expression.Type.GetGenericArguments()[0]), expression); + } + + public object Execute (System.Linq.Expressions.Expression expression) + { + QueryableTransformer visitor = new QueryableTransformer (); + Expression body = visitor.Transform (expression); + LambdaExpression lambda = Expression.Lambda (body); + return lambda.Compile ().DynamicInvoke(); + } + + public IQueryable CreateQuery (System.Linq.Expressions.Expression expression) + { + return new QueryableEnumerable (expression); + } + + public TResult Execute (System.Linq.Expressions.Expression expression) + { + QueryableTransformer visitor = new QueryableTransformer (); + Expression body = visitor.Transform (expression); + Expression> lambda = Expression.Lambda> (body); + return lambda.Compile () (); + } + } +} diff --git a/mcs/class/System.Core/System.Linq/QueryableTransformer.cs b/mcs/class/System.Core/System.Linq/QueryableTransformer.cs new file mode 100644 index 00000000000..d0aaf777c20 --- /dev/null +++ b/mcs/class/System.Core/System.Linq/QueryableTransformer.cs @@ -0,0 +1,172 @@ +// +// QueryableTransformer.cs +// +// Authors: +// Roei Erez (roeie@mainsoft.com) +// +// Copyright (C) 2007 Novell, Inc (http://www.novell.com) +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Collections.ObjectModel; + +namespace System.Linq +{ + internal class QueryableTransformer : ExpressionTransformer + { + + internal QueryableTransformer () {} + + protected override MethodCallExpression VisitMethodCall (MethodCallExpression methodCall) + { + if ( IsQueryableExtension ( methodCall.Method )) + { + return ReplaceIQueryableMethod (methodCall); + } + return base.VisitMethodCall (methodCall); + } + + protected override LambdaExpression VisitLambda (LambdaExpression lambda) + { + return lambda; + } + + bool IsQueryableExtension (MethodInfo method) + { + return method.GetCustomAttributes(typeof(ExtensionAttribute), false).Count() > 0 && + typeof(IQueryable).IsAssignableFrom( method.GetParameters () [0].ParameterType ); + } + + MethodCallExpression ReplaceIQueryableMethod (MethodCallExpression oldCall) + { + Expression target = null; + if (oldCall.Object != null){ + target = Visit (oldCall.Object); + } + MethodInfo newMethod = ReplaceIQueryableMethodInfo(oldCall.Method); + + Expression [] args = new Expression [oldCall.Arguments.Count]; + int counter = 0; + foreach (Expression e in oldCall.Arguments) { + Type methodParam = newMethod.GetParameters() [counter].ParameterType; + args [counter++] = ReplaceQuotedLambdaIfNeeded(Visit (e), methodParam); + } + ReadOnlyCollection col = args.ToReadOnlyCollection(); + MethodCallExpression newMethodCall = new MethodCallExpression (target, newMethod, col); + return newMethodCall; + } + + static Expression ReplaceQuotedLambdaIfNeeded (Expression e, Type delegateType) + { + UnaryExpression unary = e as UnaryExpression; + if (unary != null) { + LambdaExpression lambda = unary.Operand as LambdaExpression; + if (lambda != null && lambda.Type == delegateType) + return lambda; + } + return e; + } + + static MethodInfo ReplaceIQueryableMethodInfo (MethodInfo qm) + { + Type typeToSearch = qm.DeclaringType == typeof (Queryable) ? typeof (Enumerable) : qm.DeclaringType; + MethodInfo result = GetMatchingMethod (qm, typeToSearch); + if (result == null) + throw new InvalidOperationException ( + string.Format("There is no method {0} on type {1} that matches the specified arguments", + qm.Name, + qm.DeclaringType.FullName)); + return result; + } + + static MethodInfo GetMatchingMethod (MethodInfo qm, Type fromType) + { + return (from em in fromType.GetMethods () + where Match (em, qm) + select em.MakeGenericMethod (qm.GetGenericArguments ())) + .FirstOrDefault (); + } + + static bool Match (MethodInfo em, MethodInfo qm) { + + if (em.GetCustomAttributes (typeof (ExtensionAttribute), false).Count() == 0) + return false; + + if (em.Name != qm.Name) + return false; + + if (em.GetGenericArguments ().Length != qm.GetGenericArguments ().Length) + return false; + + Type [] parameters = (from p in qm.GetParameters () select p.ParameterType).ToArray (); + Type returnType = qm.ReturnType; + + if (parameters.Length != em.GetParameters ().Length) + return false; + + MethodInfo instanceMethod = em; + if (qm.IsGenericMethod) { + if (!qm.IsGenericMethod) + return false; + if (em.GetParameters ().Length != qm.GetParameters ().Length) + return false; + Type [] genArgs = qm.GetGenericArguments (); + instanceMethod = em.MakeGenericMethod (genArgs); + } + + Type [] enumerableParams = (from p in instanceMethod.GetParameters () select p.ParameterType).ToArray (); + + if (enumerableParams [0] != ConvertParameter (parameters [0])) + return false; + for (int i = 1; i < enumerableParams.Length; ++i) + if (!ArgumentMatch(enumerableParams [i], parameters [i])) + return false; + if (!ArgumentMatch(instanceMethod.ReturnType, returnType)) + return false; + return true; + } + + static bool ArgumentMatch (Type enumerableParam, Type queryableParam) + { + return enumerableParam == queryableParam || enumerableParam == ConvertParameter (queryableParam); + } + + static Type ConvertParameter (Type type) + { + if (type.IsGenericType && type.GetGenericTypeDefinition () == typeof (IQueryable<>)) + type = typeof (IEnumerable<>).MakeGenericType (type.GetGenericArguments ()); + else if (type.IsGenericType && type.GetGenericTypeDefinition () == typeof (IOrderedQueryable<>)) + type = typeof (IOrderedEnumerable<>).MakeGenericType (type.GetGenericArguments ()); + else if (type.IsGenericType && type.GetGenericTypeDefinition () == typeof (Expression<>)) + type = type.GetGenericArguments () [0]; + else if (type == typeof (IQueryable)) + type = typeof (System.Collections.IEnumerable); + return type; + } + } +} diff --git a/mcs/class/System.Core/Test/System.Core.Tests-2008.JavaEE.csproj b/mcs/class/System.Core/Test/System.Core.Tests-2008.JavaEE.csproj index 5a58919298b..d8945b06e3c 100644 --- a/mcs/class/System.Core/Test/System.Core.Tests-2008.JavaEE.csproj +++ b/mcs/class/System.Core/Test/System.Core.Tests-2008.JavaEE.csproj @@ -115,7 +115,7 @@ --> - + @@ -148,10 +148,13 @@ + + + @@ -204,12 +207,16 @@ + + + + - \ No newline at end of file + diff --git a/mcs/class/System.Core/Test/System.Core.Tests-2008.csproj b/mcs/class/System.Core/Test/System.Core.Tests-2008.csproj index e592341572f..01c228b99cf 100755 --- a/mcs/class/System.Core/Test/System.Core.Tests-2008.csproj +++ b/mcs/class/System.Core/Test/System.Core.Tests-2008.csproj @@ -103,6 +103,7 @@ + @@ -126,4 +127,4 @@ --> - \ No newline at end of file + diff --git a/mcs/class/System.Core/Test/System.Linq/ChangeLog b/mcs/class/System.Core/Test/System.Linq/ChangeLog index 97894bcf362..d0728314d09 100644 --- a/mcs/class/System.Core/Test/System.Linq/ChangeLog +++ b/mcs/class/System.Core/Test/System.Linq/ChangeLog @@ -1,3 +1,7 @@ +2008-05-15 Roei Erez + + * EnumerableAsQueryableTest.cs: test cases for Queryable.AsQueryable() implementation. + 2008-05-08 Jonathan Pryor * EnumerableTest.cs: test Reverse() for non-IList types. diff --git a/mcs/class/System.Core/Test/System.Linq/EnumerableAsQueryableTest.cs b/mcs/class/System.Core/Test/System.Linq/EnumerableAsQueryableTest.cs new file mode 100644 index 00000000000..9740f7b1a84 --- /dev/null +++ b/mcs/class/System.Core/Test/System.Linq/EnumerableAsQueryableTest.cs @@ -0,0 +1,468 @@ +// +// EnumerableAsQueryableTest.cs +// +// Authors: +// Roei Erez (roeie@mainsoft.com) +// +// Copyright (C) 2007 Novell, Inc (http://www.novell.com) +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using NUnit.Framework; +using System.Linq.Expressions; +using System.Runtime.CompilerServices; +using System.Reflection; + +namespace MonoTests.System.Linq { + + [TestFixture] + public class EnumerableAsQueryableTest { + + int [] _array; + IQueryable _src; + + [SetUp] + public void MyTestCleanup () + { + _array = new int [] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + _src = _array.AsQueryable (); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Aggregate () + { + Assert.AreEqual (_src.Aggregate ((n, m) => n + m), _array.Aggregate ((n, m) => n + m)); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void All () + { + Assert.AreEqual (_src.All ((n) => n < 11), _array.All ((n) => n < 11)); + Assert.AreEqual (_src.All ((n) => n < 10), _array.All ((n) => n < 10)); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Any () + { + Assert.AreEqual (_src.Any (i => i > 5), _array.Any (i => i > 5)); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Average () + { + Assert.AreEqual (_src.Average ((n) => 11), _array.Average ((n) => 11)); + } +#endif + + [Test] + public void Concat () + { + Assert.AreEqual (_src.Concat (_src).Count (), _array.Concat (_src).Count ()); + } + + [Test] + public void Contains () + { + + for (int i = 1; i < 20; ++i) + Assert.AreEqual (_src.Contains (i), _array.Contains (i)); + } + + [Test] + public void Count () + { + Assert.AreEqual (_src.Count (), _array.Count ()); + } + + + [Test] + public void Distinct () + { + Assert.AreEqual (_src.Distinct ().Count (), _array.Distinct ().Count ()); + Assert.AreEqual (_src.Distinct (new CustomEqualityComparer ()).Count (), _array.Distinct (new CustomEqualityComparer ()).Count ()); + } + + [Test] + public void ElementAt () + { + for (int i = 0; i < 10; ++i) + Assert.AreEqual (_src.ElementAt (i), _array.ElementAt (i)); + } + + [Test] + public void ElementAtOrDefault () + { + for (int i = 0; i < 10; ++i) + Assert.AreEqual (_src.ElementAtOrDefault (i), _array.ElementAtOrDefault (i)); + Assert.AreEqual (_src.ElementAtOrDefault (100), _array.ElementAtOrDefault (100)); + } + + [Test] + public void Except () + { + int [] except = { 1, 2, 3 }; + Assert.AreEqual (_src.Except (except.AsQueryable ()).Count (), _array.Except (except).Count ()); + } + + [Test] + public void First () + { + Assert.AreEqual (_src.First (), _array.First ()); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void FirstOrDefault () + { + Assert.AreEqual (_src.FirstOrDefault ((n) => n > 5), _array.FirstOrDefault ((n) => n > 5)); + Assert.AreEqual (_src.FirstOrDefault ((n) => n > 10), _array.FirstOrDefault ((n) => n > 10)); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void GroupBy () + { + IQueryable> grouping = _src.GroupBy ((n) => n > 5); + Assert.AreEqual (grouping.Count(), 2); + foreach (IGrouping group in grouping) + { + Assert.AreEqual(group.Count(), 5); + } + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Intersect () + { + int [] subset = { 1, 2, 3 }; + int[] intersection = _src.Intersect (subset.AsQueryable()).ToArray(); + Assert.AreEqual (subset, intersection); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Last () + { + Assert.AreEqual (_src.Last ((n) => n > 1), _array.Last ((n) => n > 1)); + } +#endif + + [Test] + public void LastOrDefault () + { + Assert.AreEqual (_src.LastOrDefault (), _array.LastOrDefault ()); + } + + [Test] + public void LongCount () + { + Assert.AreEqual (_src.LongCount (), _array.LongCount ()); + } + + [Test] + public void Max () + { + Assert.AreEqual (_src.Max (), _array.Max ()); + } + + [Test] + public void Min () + { + Assert.AreEqual (_src.Min (), _array.Min ()); + } + + [Test] + public void OfType () + { + Assert.AreEqual (_src.OfType ().Count (), _array.OfType ().Count ()); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void OrderBy () + { + int [] arr1 = _array.OrderBy ((n) => n * -1).ToArray (); + int [] arr2 = _src.OrderBy ((n) => n * -1).ToArray (); + Assert.AreEqual (arr1, arr2); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void OrderByDescending () + { + int [] arr1 = _array.OrderBy ((n) => n).ToArray (); + int [] arr2 = _src.OrderBy ((n) => n).ToArray (); + Assert.AreEqual (arr1, arr2); + } +#endif + + [Test] + public void Reverse () + { + int [] arr1 = _array.Reverse ().Reverse ().ToArray (); + int [] arr2 = _src.Reverse ().Reverse ().ToArray (); + Assert.AreEqual (arr1, arr2); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Select () + { + int [] arr1 = _array.Select ((n) => n - 1).ToArray (); + int [] arr2 = _src.Select ((n) => n - 1).ToArray (); + Assert.AreEqual (arr1, arr2); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void SelectMany () + { + int [] arr1 = _array.SelectMany ((n) => new int [] { n, n, n }).ToArray (); + int [] arr2 = _src.SelectMany ((n) => new int [] { n, n, n }).ToArray (); + Assert.AreEqual (arr1, arr2); + } +#endif + + [Test] + public void SequenceEqual () + { + Assert.IsTrue (_src.SequenceEqual (_src)); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Single () + { + Assert.AreEqual (_src.Single (n => n == 10), 10); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void SingleOrDefault () + { + Assert.AreEqual (_src.SingleOrDefault (n => n == 10), 10); + Assert.AreEqual (_src.SingleOrDefault (n => n == 11), 0); + } +#endif + + [Test] + public void Skip () + { + int [] arr1 = _array.Skip (5).ToArray (); + int [] arr2 = _src.Skip (5).ToArray (); + Assert.AreEqual (arr1, arr2); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void SkipWhile () + { + int[] arr1 = _src.SkipWhile ((n) => n < 6).ToArray(); + int[] arr2 = _src.Skip (5).ToArray(); + Assert.AreEqual (arr1, arr2); + } +#endif + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Sum () + { + Assert.AreEqual (_src.Sum ((n) => n), _array.Sum ((n) => n)); + Assert.AreEqual (_src.Sum ((n) => n + 1), _array.Sum ((n) => n + 1)); + } +#endif + + [Test] + public void Take () + { + int [] arr1 = _array.Take (3).ToArray (); + int [] arr2 = _src.Take (3).ToArray (); + Assert.AreEqual (arr1, arr2); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void TakeWhile () + { + int [] arr1 = _array.TakeWhile (n => n < 6).ToArray (); + int [] arr2 = _src.TakeWhile (n => n < 6).ToArray (); + Assert.AreEqual (arr1, arr2); + } +#endif + + [Test] + public void Union () + { + int [] arr1 = _src.ToArray (); + int[] arr2 = _src.Union (_src).ToArray (); + Assert.AreEqual (arr1, arr2); + + int [] arr = { 11,12,13}; + Assert.AreEqual (_src.Union (arr).ToArray (), _array.Union (arr).ToArray ()); + } + +#if TARGET_JVM //TODO: gmcs fails, bug #390666 + [Test] + public void Where () + { + int[] oddArray1 = _array.Where ((n) => (n % 2) == 1).ToArray(); + int [] oddArray2 = _src.Where ((n) => (n % 2) == 1).ToArray (); + Assert.AreEqual (oddArray1, oddArray2); + } +#endif + + [Test] + public void UserExtensionMethod () + { + BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public; + MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags) + where (m.Name == "UserQueryableExt1" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>)) + select m).FirstOrDefault ().MakeGenericMethod (typeof (int)); + Expression> exp = i => i; + Expression e = Expression.Equal ( + Expression.Constant ("UserEnumerableExt1"), + Expression.Call (method, _src.Expression, Expression.Quote (exp))); + Assert.AreEqual (_src.Provider.Execute (e), true, "UserQueryableExt1"); + + method = (from m in typeof (Ext).GetMethods (extensionFlags) + where (m.Name == "UserQueryableExt2" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>)) + select m).FirstOrDefault ().MakeGenericMethod (typeof (int)); + e = Expression.Equal ( + Expression.Constant ("UserEnumerableExt2"), + Expression.Call (method, _src.Expression, Expression.Quote (exp))); + Assert.AreEqual (_src.Provider.Execute (e), true, "UserQueryableExt2"); + } + + [Test] + [ExpectedException (typeof (InvalidOperationException))] + public void UserExtensionMethodNegative () + { + BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public; + MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags) + where (m.Name == "UserQueryableExt3" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>)) + select m).FirstOrDefault ().MakeGenericMethod (typeof (int)); + Expression> exp = i => i; + Expression e = Expression.Call (method, _src.Expression, Expression.Quote (exp), Expression.Constant (10)); + _src.Provider.Execute (e); + } + + [Test] + public void NonGenericMethod () { + BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public; + MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags) + where (m.Name == "NonGenericMethod" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>)) + select m).FirstOrDefault (); + Expression> exp = i => i; + Expression e = Expression.Call (method, _src.Expression); + Assert.AreEqual (_src.Provider.Execute (e), "EnumerableNonGenericMethod", "NonGenericMethod"); + } + + [Test] + [ExpectedException(typeof(InvalidOperationException))] + public void InstantiatedGenericMethod () { + BindingFlags extensionFlags = BindingFlags.Static | BindingFlags.Public; + MethodInfo method = (from m in typeof (Ext).GetMethods (extensionFlags) + where (m.Name == "InstantiatedGenericMethod" && m.GetParameters () [0].ParameterType.GetGenericTypeDefinition () == typeof (IQueryable<>)) + select m).FirstOrDefault ().MakeGenericMethod (typeof (int)); + Expression> exp = i => i; + Expression e = Expression.Call (method, _src.Expression, Expression.Constant(0)); + _src.Provider.Execute (e); + } + } + + class CustomEqualityComparer : IEqualityComparer { + + public bool Equals (int x, int y) + { + return true; + } + + public int GetHashCode (int obj) + { + return 0; + } + } + + public static class Ext { + + public static string UserQueryableExt1 (this IQueryable e, Expression> ex) + { + return "UserQueryableExt1"; + } + + public static string UserQueryableExt2 (this IQueryable e, Expression> ex) + { + return "UserQueryableExt2"; + } + + public static string UserQueryableExt3 (this IQueryable e, Expression> ex, int dummy) + { + return "UserQueryableExt3"; + } + + public static string UserQueryableExt1 (this IEnumerable e, Expression> ex) + { + return "UserEnumerableExt1"; + } + + public static string UserQueryableExt2 (this IEnumerable e, Func ex) + { + return "UserEnumerableExt2"; + } + + public static string NonGenericMethod (this IQueryable iq) + { + return "QueryableNonGenericMethod"; + } + + public static string NonGenericMethod (this IEnumerable iq) + { + return "EnumerableNonGenericMethod"; + } + + public static string InstantiatedGenericMethod (this IQueryable iq, T t) + { + return "QueryableInstantiatedGenericMethod"; + } + + public static string InstantiatedGenericMethod (this IEnumerable ie, int t) + { + return "EnumerableInstantiatedGenericMethod"; + } + } +} -- 2.25.1