Implement hoisting of anonymous methods with variables inside async body
authorMarek Safar <marek.safar@gmail.com>
Sat, 13 Aug 2011 17:10:13 +0000 (18:10 +0100)
committerMarek Safar <marek.safar@gmail.com>
Sat, 13 Aug 2011 17:15:00 +0000 (18:15 +0100)
mcs/mcs/anonymous.cs
mcs/mcs/async.cs
mcs/mcs/context.cs
mcs/mcs/ecore.cs
mcs/mcs/expression.cs
mcs/mcs/iterators.cs
mcs/mcs/statement.cs
mcs/tests/test-async-16.cs [new file with mode: 0644]
mcs/tests/test-async-18.cs [new file with mode: 0644]
mcs/tests/ver-il-net_4_0.xml

index 117974b3ecc4db8650d8852c82b02106e38eaea6..11169c25aa854424e19ffdd2dcdcbec4b0b8f91a 100644 (file)
@@ -193,7 +193,7 @@ namespace Mono.CSharp {
                protected HoistedThis hoisted_this;
 
                // Local variable which holds this storey instance
-               public LocalTemporary Instance;
+               public Expression Instance;
 
                public AnonymousMethodStorey (Block block, TypeContainer parent, MemberBase host, TypeParameter[] tparams, string name)
                        : base (parent, MakeMemberName (host, name, unique_id, tparams, block.StartLocation),
@@ -236,7 +236,12 @@ namespace Mono.CSharp {
 
                protected Field AddCompilerGeneratedField (string name, FullNamedExpression type)
                {
-                       const Modifiers mod = Modifiers.INTERNAL | Modifiers.COMPILER_GENERATED;
+                       return AddCompilerGeneratedField (name, type, false);
+               }
+
+               protected Field AddCompilerGeneratedField (string name, FullNamedExpression type, bool privateAccess)
+               {
+                       Modifiers mod = Modifiers.COMPILER_GENERATED | (privateAccess ? Modifiers.PRIVATE : Modifiers.INTERNAL);
                        Field f = new Field (this, type, mod, new MemberName (name, Location), null);
                        AddField (f);
                        return f;
@@ -391,21 +396,47 @@ namespace Mono.CSharp {
                        SymbolWriter.OpenCompilerGeneratedBlock (ec);
 
                        //
-                       // Create an instance of a storey
+                       // Create an instance of this storey
                        //
-                       var storey_type_expr = CreateStoreyTypeExpression (ec);
-
                        ResolveContext rc = new ResolveContext (ec.MemberContext);
                        rc.CurrentBlock = block;
-                       Expression e = new New (storey_type_expr, null, Location).Resolve (rc);
-                       e.Emit (ec);
 
-                       Instance = new LocalTemporary (storey_type_expr.Type);
-                       Instance.Store (ec);
+                       var storey_type_expr = CreateStoreyTypeExpression (ec);
+                       var source = new New (storey_type_expr, null, Location).Resolve (rc);
+
+                       //
+                       // When the current context is async (or iterator) lift local storey
+                       // instantiation to the currect storey
+                       //
+                       if (ec.CurrentAnonymousMethod is StateMachineInitializer) {
+                               //
+                               // Unfortunately, normal capture mechanism could not be used because we are
+                               // too late in the pipeline and standart assign cannot be used either due to
+                               // recursive nature of GetStoreyInstanceExpression
+                               //
+                               var field = ec.CurrentAnonymousMethod.Storey.AddCompilerGeneratedField (
+                                       LocalVariable.GetCompilerGeneratedName (block), storey_type_expr, true);
+
+                               field.Define ();
+                               field.Emit ();
+
+                               var fexpr = new FieldExpr (field, Location);
+                               fexpr.InstanceExpression = new CompilerGeneratedThis (ec.CurrentType, Location);
+                               fexpr.EmitAssign (ec, source, false, false);
+
+                               Instance = fexpr;
+                       } else {
+                               var local = TemporaryVariableReference.Create (source.Type, block, Location);
+                               local.EmitAssign (ec, source);
+
+                               Instance = local;
+                       }
 
                        EmitHoistedFieldsInitialization (rc, ec);
 
-                       SymbolWriter.DefineScopeVariable (ID, Instance.Builder);
+                       // TODO: Implement properly
+                       //SymbolWriter.DefineScopeVariable (ID, Instance.Builder);
+
                        SymbolWriter.CloseCompilerGeneratedBlock (ec);
                }
 
@@ -526,10 +557,11 @@ namespace Mono.CSharp {
                        if (f == null) {
                                if (am.Storey == this) {
                                        //
-                                       // Access inside of same storey (S -> S)
+                                       // Access from inside of same storey (S -> S)
                                        //
                                        return new CompilerGeneratedThis (CurrentType, Location);
                                }
+
                                //
                                // External field access
                                //
@@ -770,12 +802,20 @@ namespace Mono.CSharp {
        {
                readonly string name;
 
-               public HoistedLocalVariable (AnonymousMethodStorey scope, LocalVariable local, string name)
-                       : base (scope, name, local.Type)
+               public HoistedLocalVariable (AnonymousMethodStorey storey, LocalVariable local, string name)
+                       : base (storey, name, local.Type)
                {
                        this.name = local.Name;
                }
 
+               //
+               // For compiler generated local variables
+               //
+               public HoistedLocalVariable (AnonymousMethodStorey storey, Field field)
+                       : base (storey, field)
+               {
+               }
+
                public override void EmitSymbolInfo ()
                {
                        SymbolWriter.DefineCapturedLocal (storey.ID, name, field.Name);
index c0847e9fcf5695a1ee262a8da22a061f03687557..129c97d24a1c37575ef07a9d9a3652c15f88a5e3 100644 (file)
@@ -610,7 +610,7 @@ namespace Mono.CSharp
 
                public Field AddAwaiter (TypeSpec type, Location loc)
                {
-                       return AddCompilerGeneratedField ("$awaiter" + awaiters++.ToString ("X"), new TypeExpression (type, loc));
+                       return AddCompilerGeneratedField ("$awaiter" + awaiters++.ToString ("X"), new TypeExpression (type, loc), true);
                }
 
                public Field AddCapturedLocalVariable (TypeSpec type)
@@ -618,7 +618,7 @@ namespace Mono.CSharp
                        if (mutator != null)
                                type = mutator.Mutate (type);
 
-                       var field = AddCompilerGeneratedField ("<s>$" + locals_captured++.ToString ("X"), new TypeExpression (type, Location));
+                       var field = AddCompilerGeneratedField ("<s>$" + locals_captured++.ToString ("X"), new TypeExpression (type, Location), true);
                        field.Define ();
 
                        return field;
@@ -628,7 +628,7 @@ namespace Mono.CSharp
                {
                        var action = Module.PredefinedTypes.Action.Resolve ();
                        if (action != null) {
-                               continuation = AddCompilerGeneratedField ("$continuation", new TypeExpression (action, Location));
+                               continuation = AddCompilerGeneratedField ("$continuation", new TypeExpression (action, Location), true);
                                continuation.ModFlags |= Modifiers.READONLY;
                        }
 
index 4c52e7615532083bc73ff4474c73ac14986ce660..2b4849d1a697dfa6c7d1fdc3ea3e9d6736eea302 100644 (file)
@@ -470,7 +470,7 @@ namespace Mono.CSharp
 
                        // FIXME: IsIterator is too aggressive, we should capture only if child
                        // block contains yield
-                       if (CurrentAnonymousMethod.IsIterator)
+                       if (CurrentAnonymousMethod.IsIterator || CurrentAnonymousMethod is AsyncInitializer)
                                return true;
 
                        return local.Block.ParametersBlock != CurrentBlock.ParametersBlock.Original;
index bcc4701dfb6ac889d67ef05c8af23100d519395a..090b8359b86f3ef1a7ec2645f5264251b2a78672 100644 (file)
@@ -6106,20 +6106,15 @@ namespace Mono.CSharp {
                        return new TemporaryVariableReference (li, loc);
                }
 
-               public override Expression CreateExpressionTree (ResolveContext ec)
-               {
-                       throw new NotSupportedException ("ET");
-               }
-
                protected override Expression DoResolve (ResolveContext ec)
                {
                        eclass = ExprClass.Variable;
 
                        //
                        // Don't capture temporary variables except when using
-                       // iterator redirection
+                       // state machine redirection
                        //
-                       if (ec.CurrentAnonymousMethod != null && ec.CurrentAnonymousMethod.IsIterator && ec.IsVariableCapturingRequired) {
+                       if (ec.CurrentAnonymousMethod != null && ec.CurrentAnonymousMethod is StateMachineInitializer && ec.IsVariableCapturingRequired) {
                                AnonymousMethodStorey storey = li.Block.Explicit.CreateAnonymousMethodStorey (ec);
                                storey.CaptureLocalVariable (ec, li);
                        }
@@ -6173,7 +6168,7 @@ namespace Mono.CSharp {
                }
 
                public override VariableInfo VariableInfo {
-                       get { throw new NotImplementedException (); }
+                       get { return null; }
                }
        }
 
index 0d02836d24a521615472b3376d78c7d03560a3fa..8b8dcf98bb003aae6b805c4832d94416d209078d 100644 (file)
@@ -4596,6 +4596,17 @@ namespace Mono.CSharp
                        return false;
                }
 
+               public override Expression CreateExpressionTree (ResolveContext ec)
+               {
+                       HoistedVariable hv = GetHoistedVariable (ec);
+                       if (hv != null)
+                               return hv.CreateExpressionTree ();
+
+                       Arguments arg = new Arguments (1);
+                       arg.Add (new Argument (this));
+                       return CreateExpressionFactoryCall (ec, "Constant", arg);
+               }
+
                public override Expression DoResolveLValue (ResolveContext rc, Expression right_side)
                {
                        if (IsLockedByStatement) {
@@ -4794,17 +4805,6 @@ namespace Mono.CSharp
                        local_info.AddressTaken = true;
                }
 
-               public override Expression CreateExpressionTree (ResolveContext ec)
-               {
-                       HoistedVariable hv = GetHoistedVariable (ec);
-                       if (hv != null)
-                               return hv.CreateExpressionTree ();
-
-                       Arguments arg = new Arguments (1);
-                       arg.Add (new Argument (this));
-                       return CreateExpressionFactoryCall (ec, "Constant", arg);
-               }
-
                void DoResolveBase (ResolveContext ec)
                {
                        VerifyAssigned (ec);
@@ -6830,16 +6830,6 @@ namespace Mono.CSharp
                        }
                }
 
-               public override Expression CreateExpressionTree (ResolveContext ec)
-               {
-                       Arguments args = new Arguments (1);
-                       args.Add (new Argument (this));
-                       
-                       // Use typeless constant for ldarg.0 to save some
-                       // space and avoid problems with anonymous stories
-                       return CreateExpressionFactoryCall (ec, "Constant", args);
-               }
-               
                protected override Expression DoResolve (ResolveContext ec)
                {
                        ResolveBase (ec);
index 54fa95118f6bd88021b9d42e45f857e0e435b94f..7a5920b4fd904eb358d48081380f2611477eb080 100644 (file)
@@ -185,7 +185,7 @@ namespace Mono.CSharp
 
                protected override string GetVariableMangledName (LocalVariable local_info)
                {
-                       return "<" + local_info.Name + ">__" + local_name_idx++.ToString ();
+                       return "<" + local_info.Name + ">__" + local_name_idx++.ToString ("X");
                }
        }
 
index d79ccc6efbf8f81d0df484e5f38e15108cd95171..9b233561e1dcf3c04c3109ba77316d485139e8a3 100644 (file)
@@ -1674,7 +1674,7 @@ namespace Mono.CSharp {
 
                public static LocalVariable CreateCompilerGenerated (TypeSpec type, Block block, Location loc)
                {
-                       LocalVariable li = new LocalVariable (block, "<$$>", Flags.CompilerGenerated | Flags.Used, loc);
+                       LocalVariable li = new LocalVariable (block, GetCompilerGeneratedName (block), Flags.CompilerGenerated | Flags.Used, loc);
                        li.Type = type;
                        return li;
                }
@@ -1710,6 +1710,11 @@ namespace Mono.CSharp {
                        ec.Emit (OpCodes.Ldloca, builder);
                }
 
+               public static string GetCompilerGeneratedName (Block block)
+               {
+                       return "$locvar" + block.ParametersBlock.TemporaryLocalsCount++.ToString ("X");
+               }
+
                public string GetReadOnlyContext ()
                {
                        switch (flags & Flags.ReadonlyMask) {
@@ -2506,6 +2511,8 @@ namespace Mono.CSharp {
                        }
                }
 
+               public int TemporaryLocalsCount { get; set; }
+
                #endregion
 
                // <summary>
@@ -4203,14 +4210,14 @@ namespace Mono.CSharp {
                        // Have to keep original lock value around to unlock same location
                        // in the case the original has changed or is null
                        //
-                       expr_copy = TemporaryVariableReference.Create (ec.BuiltinTypes.Object, ec.CurrentBlock.Parent, loc);
+                       expr_copy = TemporaryVariableReference.Create (ec.BuiltinTypes.Object, ec.CurrentBlock, loc);
                        expr_copy.Resolve (ec);
 
                        //
                        // Ensure Monitor methods are available
                        //
                        if (ResolvePredefinedMethods (ec) > 1) {
-                               lock_taken = TemporaryVariableReference.Create (ec.BuiltinTypes.Bool, ec.CurrentBlock.Parent, loc);
+                               lock_taken = TemporaryVariableReference.Create (ec.BuiltinTypes.Bool, ec.CurrentBlock, loc);
                                lock_taken.Resolve (ec);
                        }
 
diff --git a/mcs/tests/test-async-16.cs b/mcs/tests/test-async-16.cs
new file mode 100644 (file)
index 0000000..409cc87
--- /dev/null
@@ -0,0 +1,102 @@
+// Compiler options: -langversion:future
+
+using System;
+using System.Threading.Tasks;
+using System.Threading;
+using System.Reflection;
+using System.Linq;
+
+class Base : IDisposable
+{
+       protected object ovalue;
+       
+       protected static int dispose_counter;
+       
+       public void Dispose ()
+       {
+               ++dispose_counter;
+       }
+}
+
+class Tester : Base
+{
+       async Task<int> SwitchTest_1 ()
+       {
+               switch (await Task.Factory.StartNew (() => "X")) {
+                       case "A":
+                               return 1;
+                       case "B":
+                               return 2;
+                       case "C":
+                               return 3;
+                       case "D":
+                               return 4;
+                       case "X":
+                               return 0;
+               }
+               
+               return 5;
+       }
+       
+       async Task<int> Using_1 ()
+       {
+               using (Base a = await Task.Factory.StartNew (() => new Base ()),
+                               b = await Task.Factory.StartNew (() => new Tester ()),
+                               c = await Task.Factory.StartNew (() => new Base ()),
+                               d = await Task.Factory.StartNew (() => new Base ()))
+               {
+               }
+               
+               if (dispose_counter != 4)
+                       return 1;
+               
+               return 0;
+       }
+       
+       static bool RunTest (MethodInfo test)
+       {
+               Console.Write ("Running test {0, -25}", test.Name);
+               try {
+                       Task t = test.Invoke (new Tester (), null) as Task;
+                       if (!Task.WaitAll (new[] { t }, 1000)) {
+                               Console.WriteLine ("FAILED (Timeout)");
+                               return false;
+                       }
+
+                       var ti = t as Task<int>;
+                       if (ti != null) {
+                               if (ti.Result != 0) {
+                                       Console.WriteLine ("FAILED (Result={0})", ti.Result);
+                                       return false;
+                               }
+                       } else {
+                               var tb = t as Task<bool>;
+                               if (tb != null) {
+                                       if (!tb.Result) {
+                                               Console.WriteLine ("FAILED (Result={0})", tb.Result);
+                                               return false;
+                                       }
+                               }
+                       }
+
+                       Console.WriteLine ("OK");
+                       return true;
+               } catch (Exception e) {
+                       Console.WriteLine ("FAILED");
+                       Console.WriteLine (e.ToString ());
+                       return false;
+               }
+       }
+
+       public static int Main ()
+       {
+               var tests = from test in typeof (Tester).GetMethods (BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly)
+                                       where test.GetParameters ().Length == 0
+                                       orderby test.Name
+                                       select RunTest (test);
+
+               int failures = tests.Count (a => !a);
+               Console.WriteLine (failures + " tests failed");
+               return failures;
+       }
+}
diff --git a/mcs/tests/test-async-18.cs b/mcs/tests/test-async-18.cs
new file mode 100644 (file)
index 0000000..898d9ed
--- /dev/null
@@ -0,0 +1,73 @@
+// Compiler options: -langversion:future
+using System;
+using System.Threading.Tasks;
+using System.Threading;
+
+class Tester
+{
+       async Task<int> Lambda_1 ()
+       {
+               int res = 1;
+               {
+                       int a = 8;
+                       Func<int> f = () => a;
+                       res = await Task.Factory.StartNew (f);
+                       res += f ();
+               }
+               
+               return res - 16;
+       }
+       
+       async Task<int> Lambda_2 ()
+       {
+               int res = 1;
+               {
+                       int a = 8;
+                       Func<int> f = () => a + res;
+                       res = await Task.Factory.StartNew (f);
+                       res += f ();
+               }
+               
+               return res - 26;
+       }
+       
+       async Task<int> Lambda_3<T> ()
+       {
+               int res = 1;
+               {
+                       int a = 8;
+                       Func<int> f = () => a;
+                       res = await Task.Factory.StartNew (f);
+                       res += f ();
+               }
+               
+               return res - 16;
+       }       
+
+       public static int Main ()
+       {
+               var t = new Tester ().Lambda_1 ();
+               if (!Task.WaitAll (new [] { t }, 1000))
+                       return 1;
+               
+               if (t.Result != 0)
+                       return 2;
+               
+               t = new Tester ().Lambda_2 ();
+               if (!Task.WaitAll (new [] { t }, 1000))
+                       return 3;
+               
+               if (t.Result != 0)
+                       return 4;
+
+               t = new Tester ().Lambda_3<ulong>();
+               if (!Task.WaitAll (new [] { t }, 1000))
+                       return 5;
+               
+               if (t.Result != 0)
+                       return 6;
+               
+               Console.WriteLine ("ok");
+               return 0;
+       }
+}
index 0ed6703614d9c18b730a2e1b889269e84eba0a01..8d7d856297f403109606b62ad9a8bb94a3951213 100644 (file)
       </method>
     </type>
   </test>
+  <test name="test-async-16.cs">
+    <type name="Base">
+      <method name="Void Dispose()">
+        <size>13</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+    <type name="Tester">
+      <method name="System.Threading.Tasks.Task`1[System.Int32] SwitchTest_1()">
+        <size>27</size>
+      </method>
+      <method name="System.Threading.Tasks.Task`1[System.Int32] Using_1()">
+        <size>27</size>
+      </method>
+      <method name="Boolean RunTest(System.Reflection.MethodInfo)">
+        <size>235</size>
+      </method>
+      <method name="Int32 Main()">
+        <size>179</size>
+      </method>
+      <method name="Boolean &lt;Main&gt;m__0(System.Reflection.MethodInfo)">
+        <size>12</size>
+      </method>
+      <method name="System.String &lt;Main&gt;m__1(System.Reflection.MethodInfo)">
+        <size>7</size>
+      </method>
+      <method name="Boolean &lt;Main&gt;m__2(System.Reflection.MethodInfo)">
+        <size>7</size>
+      </method>
+      <method name="Boolean &lt;Main&gt;m__3(Boolean)">
+        <size>5</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;SwitchTest_1&gt;c__async0">
+      <method name="Void MoveNext()">
+        <size>380</size>
+      </method>
+      <method name="System.String &lt;&gt;m__4()">
+        <size>6</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>36</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Using_1&gt;c__async1">
+      <method name="Void MoveNext()">
+        <size>831</size>
+      </method>
+      <method name="Base &lt;&gt;m__5()">
+        <size>6</size>
+      </method>
+      <method name="Tester &lt;&gt;m__6()">
+        <size>6</size>
+      </method>
+      <method name="Base &lt;&gt;m__7()">
+        <size>6</size>
+      </method>
+      <method name="Base &lt;&gt;m__8()">
+        <size>6</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>36</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Using_1&gt;c__async1+&lt;Using_1&gt;c__AnonStorey2">
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+  </test>
   <test name="test-async-17.cs">
     <type name="Tester">
       <method name="System.Threading.Tasks.Task`1[System.Int32] TestException_1()">
       </method>
     </type>
   </test>
+  <test name="test-async-18.cs">
+    <type name="Tester">
+      <method name="System.Threading.Tasks.Task`1[System.Int32] Lambda_1()">
+        <size>27</size>
+      </method>
+      <method name="System.Threading.Tasks.Task`1[System.Int32] Lambda_2()">
+        <size>27</size>
+      </method>
+      <method name="System.Threading.Tasks.Task`1[System.Int32] Lambda_3[T]()">
+        <size>27</size>
+      </method>
+      <method name="Int32 Main()">
+        <size>165</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Lambda_1&gt;c__async0">
+      <method name="Void MoveNext()">
+        <size>290</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>36</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Lambda_2&gt;c__async1">
+      <method name="Void MoveNext()">
+        <size>290</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>36</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Lambda_3&gt;c__async2`1[T]">
+      <method name="Void MoveNext()">
+        <size>290</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>36</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Lambda_1&gt;c__async0+&lt;Lambda_1&gt;c__AnonStorey3">
+      <method name="Int32 &lt;&gt;m__0()">
+        <size>7</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Lambda_2&gt;c__async1+&lt;Lambda_2&gt;c__AnonStorey4">
+      <method name="Int32 &lt;&gt;m__1()">
+        <size>19</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+    <type name="Tester+&lt;Lambda_3&gt;c__async2`1+&lt;Lambda_3&gt;c__AnonStorey5`1[T]">
+      <method name="Int32 &lt;&gt;m__2()">
+        <size>7</size>
+      </method>
+      <method name="Void .ctor()">
+        <size>7</size>
+      </method>
+    </type>
+  </test>
   <test name="test-cls-00.cs">
     <type name="CLSCLass_6">
       <method name="Void .ctor()">