2008-03-05 Jb Evain <jbevain@novell.com>
[mono.git] / mcs / class / System.Core / System.Linq.Expressions / Expression.cs
index 684f3305bcb87beb73b1fd24dd4bf76464dc96b7..d24322901d8a585964e76affb0b1370ee726374a 100644 (file)
@@ -66,22 +66,7 @@ namespace System.Linq.Expressions {
                        return ExpressionPrinter.ToString (this);
                }
 
-#region Binary Expressions
-               static bool IsInt (Type t)
-               {
-                       return t == typeof (byte) || t == typeof (sbyte) ||
-                               t == typeof (short) || t == typeof (ushort) ||
-                               t == typeof (int) || t == typeof (uint) ||
-                               t == typeof (long) || t == typeof (ulong);
-               }
-
-               static bool IsNumber (Type t)
-               {
-                       if (IsInt (t))
-                               return true;
-
-                       return t == typeof (float) || t == typeof (double) || t == typeof (decimal);
-               }
+               #region Binary Expressions
 
                static MethodInfo GetUnaryOperator (string oper_name, Type on_type, Expression expression)
                {
@@ -144,7 +129,7 @@ namespace System.Linq.Expressions {
                {
                        MethodInfo [] methods = on_type.GetMethods (PublicStatic);
 
-                       foreach (MethodInfo m in methods){
+                       foreach (MethodInfo m in methods) {
                                if (m.Name != oper_name)
                                        continue;
 
@@ -187,30 +172,21 @@ namespace System.Linq.Expressions {
                                if (pi.Length != 2)
                                        throw new ArgumentException ("Must have only two parameters", "method");
 
-                               Type ltype = left.Type.IsValueType && IsNullable (left.Type) ? GetNullableOf(left.Type) : left.Type;
-                               Type rtype = left.Type.IsValueType && IsNullable (right.Type) ? GetNullableOf(right.Type) :right.Type;
-
-                               if (ltype != pi [0].ParameterType)
+                               if (!pi [0].ParameterType.IsAssignableFrom (GetNotNullableOf (left.Type)))
                                        throw new InvalidOperationException ("left-side argument type does not match left expression type");
 
-                               if (rtype != pi [1].ParameterType)
+                               if (!pi [1].ParameterType.IsAssignableFrom (GetNotNullableOf (right.Type)))
                                        throw new InvalidOperationException ("right-side argument type does not match right expression type");
 
                                return method;
                        } else {
                                Type ltype = left.Type;
                                Type rtype = right.Type;
-                               Type ultype = left.Type;
-                               Type urtype = right.Type;
-
-                               if (IsNullable (ltype))
-                                       ultype = GetNullableOf (ltype);
-
-                               if (IsNullable (rtype))
-                                       urtype = GetNullableOf (rtype);
+                               Type ultype = GetNotNullableOf (ltype);
+                               Type urtype = GetNotNullableOf (rtype);
 
-                               if (oper_name == "op_BitwiseOr" || oper_name == "op_BitwiseAnd"){
-                                       if (ultype == typeof (bool)){
+                               if (oper_name == "op_BitwiseOr" || oper_name == "op_BitwiseAnd") {
+                                       if (ultype == typeof (bool)) {
                                                if (ultype == urtype && ltype == rtype)
                                                        return null;
                                        }
@@ -219,7 +195,7 @@ namespace System.Linq.Expressions {
                                // Use IsNumber to avoid expensive reflection.
                                if (IsNumber (ultype)){
                                        if (ultype == urtype && ltype == rtype)
-                                               return method;
+                                               return null;
 
                                        if (oper_name != null){
                                                method = GetBinaryOperator (oper_name, rtype, left, right);
@@ -257,34 +233,37 @@ namespace System.Linq.Expressions {
                        if (right == null)
                                throw new ArgumentNullException ("right");
 
-                       if (method == null){
+                       if (method == null) {
                                // avoid reflection shortcut and catches Ints/bools before we check Numbers in general
-                               if (left.Type == right.Type && (left.Type == typeof (bool) || IsInt (left.Type)))
-                                       return method;
+                               if (left.Type == right.Type && IsIntOrBool (left.Type))
+                                       return null;
                        }
 
                        method = BinaryCoreCheck (oper_name, left, right, method);
-                       if (method == null){
-                               //
+                       if (method == null) {
                                // The check in BinaryCoreCheck allows a bit more than we do
                                // (floats and doubles).  Catch this here
-                               //
-                               if (left.Type == typeof(double) || left.Type == typeof(float))
+                               if (left.Type == typeof (double) || left.Type == typeof (float))
                                        throw new InvalidOperationException ("Types not supported");
                        }
 
                        return method;
                }
 
+               static Type GetResultType (Expression expression, MethodInfo method)
+               {
+                       return method == null ? expression.Type : method.ReturnType;
+               }
+
                static BinaryExpression MakeSimpleBinary (ExpressionType et, Expression left, Expression right, MethodInfo method)
                {
-                       Type result = method == null ? left.Type : method.ReturnType;
                        bool is_lifted;
 
-                       if (method == null){
-                               if (IsNullable (left.Type)){
+                       if (method == null) {
+                               if (IsNullable (left.Type)) {
                                        if (!IsNullable (right.Type))
-                                               throw new Exception ("Assertion, internal error: left is nullable, requires right to be as well");
+                                               throw new InvalidOperationException ("Assertion, internal error: left is nullable, requires right to be as well");
+
                                        is_lifted = true;
                                } else
                                        is_lifted = false;
@@ -295,14 +274,12 @@ namespace System.Linq.Expressions {
                                is_lifted = false;
                        }
 
-                       return new BinaryExpression (et, result, left, right, false, is_lifted, method, null);
+                       return new BinaryExpression (et, GetResultType (left, method), left, right, is_lifted, is_lifted, method, null);
                }
 
                static UnaryExpression MakeSimpleUnary (ExpressionType et, Expression expression, MethodInfo method)
                {
-                       Type result = method == null ? expression.Type : method.ReturnType;
-
-                       return new UnaryExpression (et, expression, result, method);
+                       return new UnaryExpression (et, expression, GetResultType (expression, method), method);
                }
 
                static BinaryExpression MakeBoolBinary (ExpressionType et, Expression left, Expression right, bool liftToNull, MethodInfo method)
@@ -314,28 +291,27 @@ namespace System.Linq.Expressions {
                        bool rnullable = IsNullable (rtype);
                        bool is_lifted;
 
-                       //
                        // Implement the rules as described in "Expression.Equal" method.
-                       //
-                       if (method == null){
-                               if (lnullable == false && rnullable == false){
+                       if (method == null) {
+                               if (!lnullable && !rnullable) {
                                        is_lifted = false;
+                                       liftToNull = false;
                                        result = typeof (bool);
-                               } else if (lnullable && rnullable){
+                               } else if (lnullable && rnullable) {
                                        is_lifted = true;
                                        result = liftToNull ? typeof(bool?) : typeof (bool);
                                } else
-                                       throw new Exception ("Internal error: this should have been caught in BinaryCoreCheck");
+                                       throw new InvalidOperationException ("Internal error: this should have been caught in BinaryCoreCheck");
                        } else {
                                ParameterInfo [] pi = method.GetParameters ();
                                Type mltype = pi [0].ParameterType;
                                Type mrtype = pi [1].ParameterType;
 
-                               if (ltype == mltype && rtype == mrtype){
+                               if (ltype == mltype && rtype == mrtype) {
                                        is_lifted = false;
+                                       liftToNull = false;
                                        result = method.ReturnType;
-                               }
-                               else if (ltype.IsValueType && rtype.IsValueType &&
+                               } else if (ltype.IsValueType && rtype.IsValueType &&
                                           ((lnullable && GetNullableOf (ltype) == mltype) ||
                                                (rnullable && GetNullableOf (rtype) == mrtype))){
                                        is_lifted = true;
@@ -351,10 +327,10 @@ namespace System.Linq.Expressions {
                                                // See:
                                                // https://connect.microsoft.com/VisualStudio/feedback/ViewFeedback.aspx?FeedbackID=323139
                                                result = typeof (Nullable<>).MakeGenericType (method.ReturnType);
-                                                       //Type.GetType ("System.Nullable`1[" + method.ReturnType.ToString () + "]");
                                        }
                                } else {
                                        is_lifted = false;
+                                       liftToNull = false;
                                        result = method.ReturnType;
                                }
                        }
@@ -386,15 +362,10 @@ namespace System.Linq.Expressions {
                {
                        method = BinaryCoreCheck ("op_Addition", left, right, method);
 
-                       //
                        // The check in BinaryCoreCheck allows a bit more than we do
                        // (byte, sbyte).  Catch that here
-                       //
-
-                       if (method == null){
-                               Type ltype = left.Type;
-
-                               if (ltype == typeof (byte) || ltype == typeof (sbyte))
+                       if (method == null) {
+                               if (left.Type == typeof (byte) || left.Type == typeof (sbyte))
                                        throw new InvalidOperationException (String.Format ("AddChecked not defined for {0} and {1}", left.Type, right.Type));
                        }
 
@@ -409,6 +380,7 @@ namespace System.Linq.Expressions {
                public static BinaryExpression Subtract (Expression left, Expression right, MethodInfo method)
                {
                        method = BinaryCoreCheck ("op_Subtraction", left, right, method);
+
                        return MakeSimpleBinary (ExpressionType.Subtract, left, right, method);
                }
 
@@ -421,17 +393,13 @@ namespace System.Linq.Expressions {
                {
                        method = BinaryCoreCheck ("op_Subtraction", left, right, method);
 
-                       //
                        // The check in BinaryCoreCheck allows a bit more than we do
                        // (byte, sbyte).  Catch that here
-                       //
-
-                       if (method == null){
-                               Type ltype = left.Type;
-
-                               if (ltype == typeof (byte) || ltype == typeof (sbyte))
+                       if (method == null) {
+                               if (left.Type == typeof (byte) || left.Type == typeof (sbyte))
                                        throw new InvalidOperationException (String.Format ("SubtractChecked not defined for {0} and {1}", left.Type, right.Type));
                        }
+
                        return MakeSimpleBinary (ExpressionType.SubtractChecked, left, right, method);
                }
 
@@ -573,7 +541,7 @@ namespace System.Linq.Expressions {
                {
                        method = ConditionalBinaryCheck ("op_BitwiseAnd", left, right, method);
 
-                       return MakeBoolBinary (ExpressionType.AndAlso, left, right, false, method);
+                       return MakeBoolBinary (ExpressionType.AndAlso, left, right, true, method);
                }
 
                static MethodInfo ConditionalBinaryCheck (string oper, Expression left, Expression right, MethodInfo method)
@@ -581,11 +549,10 @@ namespace System.Linq.Expressions {
                        method = BinaryCoreCheck (oper, left, right, method);
 
                        if (method == null) {
-                               if (left.Type != typeof (bool))
+                               if (GetNotNullableOf (left.Type) != typeof (bool))
                                        throw new InvalidOperationException ("Only booleans are allowed");
                        } else {
                                // The method should have identical parameter and return types.
-
                                if (left.Type != right.Type || method.ReturnType != left.Type)
                                        throw new ArgumentException ("left, right and return type must match");
                        }
@@ -602,7 +569,7 @@ namespace System.Linq.Expressions {
                {
                        method = ConditionalBinaryCheck ("op_BitwiseOr", left, right, method);
 
-                       return MakeBoolBinary (ExpressionType.OrElse, left, right, false, method);
+                       return MakeBoolBinary (ExpressionType.OrElse, left, right, true, method);
                }
 
                //
@@ -686,7 +653,7 @@ namespace System.Linq.Expressions {
                // Miscelaneous
                //
 
-               static void ArrayCheck (Expression array)
+               static void CheckArray (Expression array)
                {
                        if (array == null)
                                throw new ArgumentNullException ("array");
@@ -696,7 +663,8 @@ namespace System.Linq.Expressions {
 
                public static BinaryExpression ArrayIndex (Expression array, Expression index)
                {
-                       ArrayCheck (array);
+                       CheckArray (array);
+
                        if (index == null)
                                throw new ArgumentNullException ("index");
                        if (array.Type.GetArrayRank () != 1)
@@ -819,7 +787,7 @@ namespace System.Linq.Expressions {
                        throw new ArgumentException ("MakeBinary expect a binary node type");
                }
 
-#endregion
+               #endregion
 
                public static MethodCallExpression ArrayIndex (Expression array, params Expression [] indexes)
                {
@@ -828,7 +796,7 @@ namespace System.Linq.Expressions {
 
                public static MethodCallExpression ArrayIndex (Expression array, IEnumerable<Expression> indexes)
                {
-                       ArrayCheck (array);
+                       CheckArray (array);
 
                        if (indexes == null)
                                throw new ArgumentNullException ("indexes");
@@ -934,6 +902,25 @@ namespace System.Linq.Expressions {
                        return new MethodCallExpression (instance, method, args);
                }
 
+               static Type [] CollectTypes (IEnumerable<Expression> expressions)
+               {
+                       return (from arg in expressions select arg.Type).ToArray ();
+               }
+
+               static MethodInfo TryMakeGeneric (MethodInfo method, Type [] args)
+               {
+                       if (method == null)
+                               return null;
+
+                       if (!method.IsGenericMethod && args == null)
+                               return method;
+
+                       if (args.Length == method.GetGenericArguments ().Length)
+                               return method.MakeGenericMethod (args);
+
+                       return null;
+               }
+
                public static MethodCallExpression Call (Expression instance, string methodName, Type [] typeArguments, params Expression [] arguments)
                {
                        if (instance == null)
@@ -941,16 +928,13 @@ namespace System.Linq.Expressions {
                        if (methodName == null)
                                throw new ArgumentNullException ("methodName");
 
-                       if (typeArguments == null)
-                               typeArguments = new Type [0];
-
-                       var method = instance.Type.GetMethod (methodName, AllInstance, null, typeArguments, null);
+                       var method = instance.Type.GetMethod (methodName, AllInstance, null, CollectTypes (arguments), null);
+                       method = TryMakeGeneric (method, typeArguments);
                        if (method == null)
                                throw new InvalidOperationException ("No such method");
 
                        var args = arguments.ToReadOnlyCollection ();
-                       if (typeArguments.Length != args.Count)
-                               throw new InvalidOperationException ("Argument count doesn't match parameters length");
+                       CheckMethodArguments (method, args);
 
                        return new MethodCallExpression (instance, method, args);
                }
@@ -962,16 +946,13 @@ namespace System.Linq.Expressions {
                        if (methodName == null)
                                throw new ArgumentNullException ("methodName");
 
-                       if (typeArguments == null)
-                               typeArguments = new Type [0];
-
-                       var method = type.GetMethod (methodName, AllStatic, null, typeArguments, null);
+                       var method = type.GetMethod (methodName, AllStatic, null, CollectTypes (arguments), null);
+                       method = TryMakeGeneric (method, typeArguments);
                        if (method == null)
                                throw new InvalidOperationException ("No such method");
 
                        var args = arguments.ToReadOnlyCollection ();
-                       if (typeArguments.Length != args.Count)
-                               throw new InvalidOperationException ("Argument count doesn't match parameters length");
+                       CheckMethodArguments (method, args);
 
                        return new MethodCallExpression (method, args);
                }
@@ -1460,39 +1441,53 @@ namespace System.Linq.Expressions {
                        throw new ArgumentException ("MakeUnary expect an unary operator");
                }
 
-               [MonoTODO]
-               public static MemberMemberBinding MemberBind (MemberInfo member, params MemberBinding [] binding)
+               public static MemberMemberBinding MemberBind (MemberInfo member, params MemberBinding [] bindings)
                {
-                       throw new NotImplementedException ();
+                       return MemberBind (member, bindings as IEnumerable<MemberBinding>);
                }
 
-               [MonoTODO]
-               public static MemberMemberBinding MemberBind (MemberInfo member, IEnumerable<MemberBinding> binding)
+               public static MemberMemberBinding MemberBind (MemberInfo member, IEnumerable<MemberBinding> bindings)
                {
-                       throw new NotImplementedException ();
-               }
+                       if (member == null)
+                               throw new ArgumentNullException ("member");
 
-               [MonoTODO]
-               public static MemberMemberBinding MemberBind (MethodInfo propertyAccessor, params MemberBinding [] binding)
-               {
-                       throw new NotImplementedException ();
+                       Type type = null;
+                       switch (member.MemberType) {
+                       case MemberTypes.Field:
+                               type = (member as FieldInfo).FieldType;
+                               break;
+                       case MemberTypes.Property:
+                               type = (member as PropertyInfo).PropertyType;
+                               break;
+                       default:
+                               throw new ArgumentException ("Member is neither a field or a property");
+                       }
+
+                       return new MemberMemberBinding (member, CheckMemberBindings (type, bindings));
                }
 
-               [MonoTODO]
-               public static MemberMemberBinding MemberBind (MethodInfo propertyAccessor, IEnumerable<MemberBinding> binding)
+               public static MemberMemberBinding MemberBind (MethodInfo propertyAccessor, params MemberBinding [] bindings)
                {
-                       throw new NotImplementedException ();
+                       return MemberBind (propertyAccessor, bindings as IEnumerable<MemberBinding>);
                }
 
-               public static MemberInitExpression MemberInit (NewExpression newExpression, params MemberBinding [] bindings)
+               public static MemberMemberBinding MemberBind (MethodInfo propertyAccessor, IEnumerable<MemberBinding> bindings)
                {
-                       return MemberInit (newExpression, bindings as IEnumerable<MemberBinding>);
+                       if (propertyAccessor == null)
+                               throw new ArgumentNullException ("propertyAccessor");
+
+                       var bds = bindings.ToReadOnlyCollection ();
+                       CheckForNull (bds, "bindings");
+
+                       var prop = GetAssociatedProperty (propertyAccessor);
+                       if (prop == null)
+                               throw new ArgumentException ("propertyAccessor");
+
+                       return new MemberMemberBinding (prop, CheckMemberBindings (prop.PropertyType, bindings));
                }
 
-               public static MemberInitExpression MemberInit (NewExpression newExpression, IEnumerable<MemberBinding> bindings)
+               static ReadOnlyCollection<MemberBinding> CheckMemberBindings (Type type, IEnumerable<MemberBinding> bindings)
                {
-                       if (newExpression == null)
-                               throw new ArgumentNullException ("newExpression");
                        if (bindings == null)
                                throw new ArgumentNullException ("bindings");
 
@@ -1500,10 +1495,23 @@ namespace System.Linq.Expressions {
                        CheckForNull (bds, "bindings");
 
                        foreach (var binding in bds)
-                               if (!binding.Member.DeclaringType.IsAssignableFrom (newExpression.Type))
-                                       throw new ArgumentException ("Expression type not assignable to member type");
+                               if (!binding.Member.DeclaringType.IsAssignableFrom (type))
+                                       throw new ArgumentException ("Type not assignable to member type");
 
-                       return new MemberInitExpression (newExpression, bds);
+                       return bds;
+               }
+
+               public static MemberInitExpression MemberInit (NewExpression newExpression, params MemberBinding [] bindings)
+               {
+                       return MemberInit (newExpression, bindings as IEnumerable<MemberBinding>);
+               }
+
+               public static MemberInitExpression MemberInit (NewExpression newExpression, IEnumerable<MemberBinding> bindings)
+               {
+                       if (newExpression == null)
+                               throw new ArgumentNullException ("newExpression");
+
+                       return new MemberInitExpression (newExpression, CheckMemberBindings (newExpression.Type, bindings));
                }
 
                public static UnaryExpression Negate (Expression expression)
@@ -1546,11 +1554,16 @@ namespace System.Linq.Expressions {
                        if (type == null)
                                throw new ArgumentNullException ("type");
 
+                       var args = (null as IEnumerable<Expression>).ToReadOnlyCollection ();
+
+                       if (type.IsValueType)
+                               return new NewExpression (type, args);
+
                        var ctor = type.GetConstructor (Type.EmptyTypes);
                        if (ctor == null)
                                throw new ArgumentException ("Type doesn't have a parameter less constructor");
 
-                       return new NewExpression (ctor, (null as IEnumerable<Expression>).ToReadOnlyCollection (), null);
+                       return new NewExpression (ctor, args, null);
                }
 
                public static NewExpression New (ConstructorInfo constructor, params Expression [] arguments)
@@ -1824,6 +1837,27 @@ namespace System.Linq.Expressions {
                        return MakeSimpleUnary (ExpressionType.UnaryPlus, expression, method);
                }
 
+               static bool IsInt (Type t)
+               {
+                       return t == typeof (byte) || t == typeof (sbyte) ||
+                               t == typeof (short) || t == typeof (ushort) ||
+                               t == typeof (int) || t == typeof (uint) ||
+                               t == typeof (long) || t == typeof (ulong);
+               }
+
+               static bool IsIntOrBool (Type t)
+               {
+                       return IsInt (t) || t == typeof (bool);
+               }
+
+               static bool IsNumber (Type t)
+               {
+                       if (IsInt (t))
+                               return true;
+
+                       return t == typeof (float) || t == typeof (double) || t == typeof (decimal);
+               }
+
                internal static bool IsNullable (Type type)
                {
                        return type.IsGenericType && type.GetGenericTypeDefinition () == typeof (Nullable<>);
@@ -1834,7 +1868,10 @@ namespace System.Linq.Expressions {
                        if (t.IsPointer)
                                return IsUnsigned (t.GetElementType ());
 
-                       return t == typeof (ushort) || t == typeof (uint) || t == typeof (ulong) || t == typeof (byte);
+                       return t == typeof (ushort) ||
+                               t == typeof (uint) ||
+                               t == typeof (ulong) ||
+                               t == typeof (byte);
                }
 
                //
@@ -1845,6 +1882,11 @@ namespace System.Linq.Expressions {
                        return type.GetGenericArguments () [0];
                }
 
+               internal static Type GetNotNullableOf (Type type)
+               {
+                       return IsNullable (type) ? GetNullableOf (type) : type;
+               }
+
                //
                // This method must be overwritten by derived classes to
                // compile the expression