[corlib] Fix extra await context switch when custom SynchronizationContext does not...
authorMarek Safar <marek.safar@gmail.com>
Fri, 25 Apr 2014 16:02:13 +0000 (18:02 +0200)
committerMarek Safar <marek.safar@gmail.com>
Fri, 25 Apr 2014 16:03:43 +0000 (18:03 +0200)
mcs/class/corlib/System.Threading.Tasks/TaskContinuation.cs
mcs/class/corlib/Test/System.Runtime.CompilerServices/TaskAwaiterTest.cs

index b826f309e396804e37b64187c816eed4b0ccd7c0..2a0547850bd044fd65d53eea6bd251ea3ea5aad9 100644 (file)
@@ -161,7 +161,11 @@ namespace System.Threading.Tasks
 
                public void Execute ()
                {
-                       ctx.Post (l => ((Action) l) (), action);
+                       // No context switch when we are on correct context
+                       if (ctx == SynchronizationContext.Current)
+                               action ();
+                       else
+                               ctx.Post (l => ((Action) l) (), action);
                }
        }
 
index 3e8233a1fe0ddb923fd5f2b39cf625d779d1be4c..a403d5287e2d6af53add6c41d825a641bea42ae6 100644 (file)
@@ -34,6 +34,7 @@ using System.Threading.Tasks;
 using NUnit.Framework;
 using System.Runtime.CompilerServices;
 using System.Collections.Generic;
+using System.Collections;
 
 namespace MonoTests.System.Runtime.CompilerServices
 {
@@ -72,6 +73,48 @@ namespace MonoTests.System.Runtime.CompilerServices
                        }
                }
 
+               class SingleThreadSynchronizationContext : SynchronizationContext
+               {
+                       readonly Queue _queue = new Queue ();
+
+                       public void RunOnCurrentThread ()
+                       {
+                               while (_queue.Count != 0) {
+                                       var workItem = (KeyValuePair<SendOrPostCallback, object>) _queue.Dequeue ();
+                                       workItem.Key (workItem.Value);
+                               }
+                       }
+                               
+                       public override void Post (SendOrPostCallback d, object state)
+                       {
+                               if (d == null) {
+                                       throw new ArgumentNullException ("d");
+                               }
+
+                               _queue.Enqueue (new KeyValuePair<SendOrPostCallback, object> (d, state));
+                       }
+
+                       public override void Send (SendOrPostCallback d, object state)
+                       {
+                               throw new NotSupportedException ("Synchronously sending is not supported.");
+                       }
+               }
+
+               string progress;
+               SynchronizationContext sc;
+
+               [SetUp]
+               public void Setup ()
+               {
+                       sc = SynchronizationContext.Current;
+               }
+
+               [TearDown]
+               public void TearDown ()
+               {
+                       SynchronizationContext.SetSynchronizationContext (sc);
+               }
+
                [Test]
                public void GetResultFaulted ()
                {
@@ -183,6 +226,86 @@ namespace MonoTests.System.Runtime.CompilerServices
                        // e.g. Touch.Unit defaults to run tests on the main thread and this will return false
                        Assert.AreEqual (Thread.CurrentThread.IsBackground, mres2.WaitOne (2000), "#2");;
                }
+
+               [Test]
+               public void CompletionOnSameCustomSynchronizationContext ()
+               {
+                       progress = "";
+                       var syncContext = new SingleThreadSynchronizationContext ();
+                       SynchronizationContext.SetSynchronizationContext (syncContext);
+
+                       syncContext.Post (delegate {
+                               Go (syncContext);
+                       }, null);
+
+                       // Custom message loop
+                       var cts = new CancellationTokenSource ();
+                       cts.CancelAfter (5000);
+                       while (progress.Length != 3 && !cts.IsCancellationRequested) {
+                               syncContext.RunOnCurrentThread ();
+                               Thread.Sleep (0);
+                       }
+
+                       Assert.AreEqual ("123", progress);
+               }
+
+               async void Go (SynchronizationContext ctx)
+               {
+                       await Wait (ctx);
+
+                       progress += "2";
+               }
+
+               async Task Wait (SynchronizationContext ctx)
+               {
+                       await Task.Delay (10); // Force block suspend/return
+
+                       ctx.Post (l => progress += "3", null);
+
+                       progress += "1";
+
+                       // Exiting same context - no need to post continuation
+               }
+
+               [Test]
+               public void CompletionOnDifferentCustomSynchronizationContext ()
+               {
+                       progress = "";
+                       var syncContext = new SingleThreadSynchronizationContext ();
+                       SynchronizationContext.SetSynchronizationContext (syncContext);
+
+                       syncContext.Post (delegate {
+                               Go2 (syncContext);
+                       }, null);
+
+                       // Custom message loop
+                       var cts = new CancellationTokenSource ();
+                       cts.CancelAfter (5000);
+                       while (progress.Length != 3 && !cts.IsCancellationRequested) {
+                               syncContext.RunOnCurrentThread ();
+                               Thread.Sleep (0);
+                       }
+
+                       Assert.AreEqual ("132", progress);
+               }
+
+               async void Go2 (SynchronizationContext ctx)
+               {
+                       await Wait2 (ctx);
+
+                       progress += "2";
+               }
+
+               async Task Wait2 (SynchronizationContext ctx)
+               {
+                       await Task.Delay (10); // Force block suspend/return
+
+                       ctx.Post (l => progress += "3", null);
+
+                       progress += "1";
+
+                       SynchronizationContext.SetSynchronizationContext (null);
+               }
        }
 }