Merge pull request #498 from Unroll-Me/master
[mono.git] / mcs / class / System.Core / System.Linq.Expressions / BinaryExpression.cs
index 8dd56b9d6a9c82d0ea312a3b53bad9c72f6bf127..143f57513f47f7ba7c38307bf8729162d48b5d95 100644 (file)
@@ -97,13 +97,7 @@ namespace System.Linq.Expressions {
                        this.is_lifted = is_lifted;
                }
 
-               void EmitMethod (EmitContext ec)
-               {
-                       left.Emit (ec);
-                       right.Emit (ec);
-                       ec.EmitCall (method);
-               }
-
+#if !FULL_AOT_RUNTIME
                void EmitArrayAccess (EmitContext ec)
                {
                        left.Emit (ec);
@@ -187,6 +181,36 @@ namespace System.Linq.Expressions {
                        ig.MarkLabel (done);
                }
 
+               MethodInfo GetFalseOperator ()
+               {
+                       return GetFalseOperator (left.Type.GetNotNullableType ());
+               }
+
+               MethodInfo GetTrueOperator ()
+               {
+                       return GetTrueOperator (left.Type.GetNotNullableType ());
+               }
+
+               void EmitUserDefinedLogicalShortCircuit (EmitContext ec)
+               {
+                       var ig = ec.ig;
+                       var and = NodeType == ExpressionType.AndAlso;
+
+                       var done = ig.DefineLabel ();
+
+                       var left = ec.EmitStored (this.left);
+
+                       ec.EmitLoad (left);
+                       ig.Emit (OpCodes.Dup);
+                       ec.EmitCall (and ? GetFalseOperator () : GetTrueOperator ());
+                       ig.Emit (OpCodes.Brtrue, done);
+
+                       ec.Emit (this.right);
+                       ec.EmitCall (method);
+
+                       ig.MarkLabel (done);
+               }
+
                void EmitLiftedLogicalShortCircuit (EmitContext ec)
                {
                        var ig = ec.ig;
@@ -231,11 +255,11 @@ namespace System.Linq.Expressions {
                        ig.Emit (and ? OpCodes.Ldc_I4_0 : OpCodes.Ldc_I4_1);
 
                        ig.MarkLabel (ret_new);
-                       ec.EmitNullableNew (typeof (bool?));
+                       ec.EmitNullableNew (Type);
                        ig.Emit (OpCodes.Br, done);
 
                        ig.MarkLabel (ret_null);
-                       var ret = ig.DeclareLocal (typeof (bool?));
+                       var ret = ig.DeclareLocal (Type);
                        ec.EmitNullableInitialize (ret);
 
                        ig.MarkLabel (done);
@@ -248,6 +272,36 @@ namespace System.Linq.Expressions {
                        var load_right = ig.DefineLabel ();
 
                        var left = ec.EmitStored (this.left);
+                       var left_is_nullable = left.LocalType.IsNullable ();
+
+                       if (left_is_nullable)
+                               ec.EmitNullableHasValue (left);
+                       else
+                               ec.EmitLoad (left);
+
+                       ig.Emit (OpCodes.Brfalse, load_right);
+
+                       if (left_is_nullable && !Type.IsNullable ())
+                               ec.EmitNullableGetValue (left);
+                       else
+                               ec.EmitLoad (left);
+
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (load_right);
+                       ec.Emit (this.right);
+
+                       ig.MarkLabel (done);
+               }
+
+               void EmitConvertedCoalesce (EmitContext ec)
+               {
+                       var ig = ec.ig;
+                       var done = ig.DefineLabel ();
+                       var load_right = ig.DefineLabel ();
+
+                       var left = ec.EmitStored (this.left);
+
                        if (left.LocalType.IsNullable ())
                                ec.EmitNullableHasValue (left);
                        else
@@ -255,7 +309,10 @@ namespace System.Linq.Expressions {
 
                        ig.Emit (OpCodes.Brfalse, load_right);
 
+                       ec.Emit (conversion);
                        ec.EmitLoad (left);
+                       ig.Emit (OpCodes.Callvirt, conversion.Type.GetInvokeMethod ());
+
                        ig.Emit (OpCodes.Br, done);
 
                        ig.MarkLabel (load_right);
@@ -314,10 +371,13 @@ namespace System.Linq.Expressions {
                                ig.Emit (is_unsigned ? OpCodes.Rem_Un : OpCodes.Rem);
                                break;
                        case ExpressionType.RightShift:
-                               ig.Emit (is_unsigned ? OpCodes.Shr_Un : OpCodes.Shr);
-                               break;
                        case ExpressionType.LeftShift:
-                               ig.Emit (OpCodes.Shl);
+                               ig.Emit (OpCodes.Ldc_I4, left.Type == typeof (int) ? 0x1f : 0x3f);
+                               ig.Emit (OpCodes.And);
+                               if (NodeType == ExpressionType.RightShift)
+                                       ig.Emit (is_unsigned ? OpCodes.Shr_Un : OpCodes.Shr);
+                               else
+                                       ig.Emit (OpCodes.Shl);
                                break;
                        case ExpressionType.And:
                                ig.Emit (OpCodes.And);
@@ -369,9 +429,46 @@ namespace System.Linq.Expressions {
                        }
                }
 
+               bool IsLeftLiftedBinary ()
+               {
+                       return left.Type.IsNullable () && !right.Type.IsNullable ();
+               }
+
+               void EmitLeftLiftedToNullBinary (EmitContext ec)
+               {
+                       var ig = ec.ig;
+
+                       var ret = ig.DefineLabel ();
+                       var done = ig.DefineLabel ();
+
+                       var left = ec.EmitStored (this.left);
+
+                       ec.EmitNullableHasValue (left);
+                       ig.Emit (OpCodes.Brfalse, ret);
+
+                       ec.EmitNullableGetValueOrDefault (left);
+                       ec.Emit (right);
+
+                       EmitBinaryOperator (ec);
+
+                       ec.EmitNullableNew (Type);
+
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (ret);
+
+                       var temp = ig.DeclareLocal (Type);
+                       ec.EmitNullableInitialize (temp);
+
+                       ig.MarkLabel (done);
+               }
+
                void EmitLiftedArithmeticBinary (EmitContext ec)
                {
-                       EmitLiftedToNullBinary (ec);
+                       if (IsLeftLiftedBinary ())
+                               EmitLeftLiftedToNullBinary (ec);
+                       else
+                               EmitLiftedToNullBinary (ec);
                }
 
                void EmitLiftedToNullBinary (EmitContext ec)
@@ -479,10 +576,154 @@ namespace System.Linq.Expressions {
                                EmitLiftedRelationalBinary (ec);
                }
 
+               void EmitLiftedUserDefinedOperator (EmitContext ec)
+               {
+                       var ig = ec.ig;
+
+                       var ret_true = ig.DefineLabel ();
+                       var ret_false = ig.DefineLabel ();
+                       var done = ig.DefineLabel ();
+
+                       var left = ec.EmitStored (this.left);
+                       var right = ec.EmitStored (this.right);
+
+                       ec.EmitNullableHasValue (left);
+                       ec.EmitNullableHasValue (right);
+                       switch (NodeType) {
+                       case ExpressionType.Equal:
+                               ig.Emit (OpCodes.Bne_Un, ret_false);
+                               ec.EmitNullableHasValue (left);
+                               ig.Emit (OpCodes.Brfalse, ret_true);
+                               break;
+                       case ExpressionType.NotEqual:
+                               ig.Emit (OpCodes.Bne_Un, ret_true);
+                               ec.EmitNullableHasValue (left);
+                               ig.Emit (OpCodes.Brfalse, ret_false);
+                               break;
+                       default:
+                               ig.Emit (OpCodes.And);
+                               ig.Emit (OpCodes.Brfalse, ret_false);
+                               break;
+                       }
+
+                       ec.EmitNullableGetValueOrDefault (left);
+                       ec.EmitNullableGetValueOrDefault (right);
+                       ec.EmitCall (method);
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (ret_true);
+                       ig.Emit (OpCodes.Ldc_I4_1);
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (ret_false);
+                       ig.Emit (OpCodes.Ldc_I4_0);
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (done);
+               }
+
+               void EmitLiftedToNullUserDefinedOperator (EmitContext ec)
+               {
+                       var ig = ec.ig;
+
+                       var ret = ig.DefineLabel ();
+                       var done = ig.DefineLabel ();
+
+                       var left = ec.EmitStored (this.left);
+                       var right = ec.EmitStored (this.right);
+
+                       ec.EmitNullableHasValue (left);
+                       ec.EmitNullableHasValue (right);
+                       ig.Emit (OpCodes.And);
+                       ig.Emit (OpCodes.Brfalse, ret);
+
+                       ec.EmitNullableGetValueOrDefault (left);
+                       ec.EmitNullableGetValueOrDefault (right);
+                       ec.EmitCall (method);
+                       ec.EmitNullableNew (Type);
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (ret);
+                       var temp = ig.DeclareLocal (Type);
+                       ec.EmitNullableInitialize (temp);
+
+                       ig.MarkLabel (done);
+               }
+
+               void EmitUserDefinedLiftedLogicalShortCircuit (EmitContext ec)
+               {
+                       var ig = ec.ig;
+                       var and = NodeType == ExpressionType.AndAlso;
+
+                       var left_is_null = ig.DefineLabel ();
+                       var ret_left = ig.DefineLabel ();
+                       var ret_null = ig.DefineLabel ();
+                       var done = ig.DefineLabel ();
+
+                       var left = ec.EmitStored (this.left);
+
+                       ec.EmitNullableHasValue (left);
+                       ig.Emit (OpCodes.Brfalse, and ? ret_null : left_is_null);
+
+                       ec.EmitNullableGetValueOrDefault (left);
+                       ec.EmitCall (and ? GetFalseOperator () : GetTrueOperator ());
+                       ig.Emit (OpCodes.Brtrue, ret_left);
+
+                       ig.MarkLabel (left_is_null);
+                       var right = ec.EmitStored (this.right);
+                       ec.EmitNullableHasValue (right);
+                       ig.Emit (OpCodes.Brfalse, ret_null);
+
+                       ec.EmitNullableGetValueOrDefault (left);
+                       ec.EmitNullableGetValueOrDefault (right);
+                       ec.EmitCall (method);
+
+                       ec.EmitNullableNew (Type);
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (ret_left);
+                       ec.EmitLoad (left);
+                       ig.Emit (OpCodes.Br, done);
+
+                       ig.MarkLabel (ret_null);
+                       var ret = ig.DeclareLocal (Type);
+                       ec.EmitNullableInitialize (ret);
+
+                       ig.MarkLabel (done);
+               }
+
+               void EmitUserDefinedOperator (EmitContext ec)
+               {
+                       if (!IsLifted) {
+                               switch (NodeType) {
+                               case ExpressionType.AndAlso:
+                               case ExpressionType.OrElse:
+                                       EmitUserDefinedLogicalShortCircuit (ec);
+                                       break;
+                               default:
+                                       left.Emit (ec);
+                                       right.Emit (ec);
+                                       ec.EmitCall (method);
+                                       break;
+                               }
+                       } else if (IsLiftedToNull) {
+                               switch (NodeType) {
+                               case ExpressionType.AndAlso:
+                               case ExpressionType.OrElse:
+                                       EmitUserDefinedLiftedLogicalShortCircuit (ec);
+                                       break;
+                               default:
+                                       EmitLiftedToNullUserDefinedOperator (ec);
+                                       break;
+                               }
+                       }  else
+                               EmitLiftedUserDefinedOperator (ec);
+               }
+
                internal override void Emit (EmitContext ec)
                {
-                       if (method != null){
-                               EmitMethod (ec);
+                       if (method != null) {
+                               EmitUserDefinedOperator (ec);
                                return;
                        }
 
@@ -491,7 +732,10 @@ namespace System.Linq.Expressions {
                                EmitArrayAccess (ec);
                                return;
                        case ExpressionType.Coalesce:
-                               EmitCoalesce (ec);
+                               if (conversion != null)
+                                       EmitConvertedCoalesce (ec);
+                               else
+                                       EmitCoalesce (ec);
                                return;
                        case ExpressionType.Power:
                        case ExpressionType.Add:
@@ -525,5 +769,6 @@ namespace System.Linq.Expressions {
                                throw new NotSupportedException (this.NodeType.ToString ());
                        }
                }
+#endif
        }
 }