Implements lifted nullable comparison with null
authorMarek Safar <marek.safar@gmail.com>
Fri, 16 Aug 2013 13:18:49 +0000 (15:18 +0200)
committerMarek Safar <marek.safar@gmail.com>
Fri, 16 Aug 2013 13:18:49 +0000 (15:18 +0200)
mcs/class/System.Core/System.Linq.Expressions/BinaryExpression.cs
mcs/class/System.Core/System.Linq.Expressions/ConstantExpression.cs
mcs/class/System.Core/System.Linq.Expressions/Expression.cs
mcs/class/System.Core/Test/System.Linq.Expressions/ExpressionTest_Equal.cs

index 03f5f5ee492870bb104c560872620fffbc2e7fe6..b0528917b2c0c052d248dfea792d0fd0c0bad41d 100644 (file)
@@ -570,12 +570,48 @@ namespace System.Linq.Expressions {
 
                void EmitRelationalBinary (EmitContext ec)
                {
-                       if (!IsLifted)
+                       if (!IsLifted) {
                                EmitNonLiftedBinary (ec);
-                       else if (IsLiftedToNull)
+                               return;
+                       }
+
+                       if (IsLiftedToNull) {
                                EmitLiftedToNullBinary (ec);
-                       else
-                               EmitLiftedRelationalBinary (ec);
+                               return;
+                       }
+
+                       if (ConstantExpression.IsNull (right) && !ConstantExpression.IsNull (left) && left.Type.IsNullable ()) {
+                               EmitNullEquality (ec, left);
+                               return;
+                       }
+
+                       if (ConstantExpression.IsNull (left) && !ConstantExpression.IsNull (right) && right.Type.IsNullable ()) {
+                               EmitNullEquality (ec, right);
+                               return;
+                       }
+
+                       EmitLiftedRelationalBinary (ec);
+               }
+
+               void EmitNullEquality (EmitContext ec, Expression e)
+               {
+                       var ig = ec.ig;
+
+                       if (IsLiftedToNull) {
+                               e.Emit (ec);
+                               if (e.Type != typeof (void))
+                                       ig.Emit (OpCodes.Pop);
+
+                               ec.EmitNullableNew (typeof (bool?));
+                               return;
+                       }
+
+                       var se = ec.EmitStored (e);
+                       ec.EmitNullableHasValue (se);
+                       if (NodeType == ExpressionType.Equal) {
+                               ig.Emit (OpCodes.Ldc_I4_0);
+                               ig.Emit (OpCodes.Ceq);
+                       }               
                }
 
                void EmitLiftedUserDefinedOperator (EmitContext ec)
index 10dd69ce8f6aaacf4768b62328749ca99a09fad4..2ff30c8b81c538b11470a40c4e855c931ad48a95 100644 (file)
@@ -51,6 +51,12 @@ namespace System.Linq.Expressions {
                {
                        this.value = value;
                }
+               
+               internal static bool IsNull (Expression e)
+               {
+                       var c = e as ConstantExpression;
+                       return c != null && c.value == null;
+               }
 
 #if !FULL_AOT_RUNTIME
                internal override void Emit (EmitContext ec)
index f79156284c2ef6e037613f1117aa09b5d698b815..696da9ff91c14074c2986996e4cc6a470fbd1fb1 100644 (file)
@@ -272,6 +272,12 @@ namespace System.Linq.Expressions {
 
                                        if (ltype == rtype && ultype == typeof (bool))
                                                return null;
+
+                                       if (ltype.IsNullable () && ConstantExpression.IsNull (right) && !ConstantExpression.IsNull (left))
+                                               return null;
+
+                                       if (rtype.IsNullable () && ConstantExpression.IsNull (left) && !ConstantExpression.IsNull (right))
+                                               return null;
                                }
 
                                if (oper_name == "op_LeftShift" || oper_name == "op_RightShift") {
@@ -392,12 +398,16 @@ namespace System.Linq.Expressions {
                                if (!left.Type.IsNullable () && !right.Type.IsNullable ()) {
                                        is_lifted = false;
                                        liftToNull = false;
-                                       type = typeof (bool);
+                                       type = typeof(bool);
                                } else if (left.Type.IsNullable () && right.Type.IsNullable ()) {
                                        is_lifted = true;
-                                       type = liftToNull ? typeof (bool?) : typeof (bool);
-                               } else
+                                       type = liftToNull ? typeof(bool?) : typeof(bool);
+                               } else if (ConstantExpression.IsNull (left) || ConstantExpression.IsNull (right)) {
+                                       is_lifted = true;
+                                       type = typeof (bool);
+                               } else {                        
                                        throw new InvalidOperationException ();
+                               }
                        } else {
                                var parameters = method.GetParameters ();
 
index 06612bdfa1820c312c899e87bfc69b27cde54db3..749d78e40f8dcec109c137a286ebb06c09054468 100644 (file)
@@ -458,5 +458,23 @@ namespace MonoTests.System.Linq.Expressions
                        Assert.AreEqual (false, eq (Foo.Bar, null));
                        Assert.AreEqual (true, eq (null, null));
                }
+
+               [Test]
+               public void NullableNullEqual ()
+               {
+                       var param = Expression.Parameter (typeof (DateTime?), "x");
+
+                       var node = Expression.Equal (param, Expression.Constant (null));
+
+                       Assert.IsTrue (node.IsLifted);
+                       Assert.IsFalse (node.IsLiftedToNull);
+                       Assert.AreEqual (typeof (bool), node.Type);
+                       Assert.IsNull (node.Method);
+
+                       var eq = Expression.Lambda<Func<DateTime?, bool>> (node, new [] { param }).Compile ();
+
+                       Assert.AreEqual (true, eq (null));
+                       Assert.AreEqual (false, eq (DateTime.Now));
+               }
        }
 }