[corlib] Implement task awaiters running on custom task scheduler. Fixes #16587
authorMarek Safar <marek.safar@gmail.com>
Thu, 5 Dec 2013 08:45:26 +0000 (09:45 +0100)
committerMarek Safar <marek.safar@gmail.com>
Thu, 5 Dec 2013 08:51:56 +0000 (09:51 +0100)
mcs/class/corlib/System.Runtime.CompilerServices/TaskAwaiter.cs
mcs/class/corlib/System.Runtime.CompilerServices/TaskAwaiter_T.cs
mcs/class/corlib/System.Threading.Tasks/Task.cs
mcs/class/corlib/System.Threading.Tasks/TaskContinuation.cs
mcs/class/corlib/Test/System.Runtime.CompilerServices/TaskAwaiterTest.cs

index 90620b0610d726e814c1e38fdb1252792086c325..75b5eddf39c8db99dfede7215ceeeb54e1e7e3ca 100644 (file)
@@ -53,7 +53,7 @@ namespace System.Runtime.CompilerServices
                public void GetResult ()
                {
                        if (!task.IsCompleted)
-                               task.WaitCore (Timeout.Infinite, CancellationToken.None);
+                               task.WaitCore (Timeout.Infinite, CancellationToken.None, true);
 
                        if (task.Status != TaskStatus.RanToCompletion)
                                // Merge current and dispatched stack traces if there is any
@@ -77,7 +77,16 @@ namespace System.Runtime.CompilerServices
                        if (continueOnSourceContext && SynchronizationContext.Current != null) {
                                task.ContinueWith (new SynchronizationContextContinuation (continuation, SynchronizationContext.Current));
                        } else {
-                               task.ContinueWith (new ActionContinuation (continuation));
+                               IContinuation cont;
+                               if (TaskScheduler.Current != TaskScheduler.Default) {
+                                       var runner = new Task (TaskActionInvoker.Create (continuation), null, CancellationToken.None, TaskCreationOptions.None, null);
+                                       runner.SetupScheduler (TaskScheduler.Current);
+                                       cont = new SchedulerAwaitContinuation (runner);
+                               } else {
+                                       cont = new ActionContinuation (continuation);
+                               }
+
+                               task.ContinueWith (cont);
                        }
                }
 
index f12db5fe57814cf91ada45c557aa4c5aa1b581fa..b59a0ef005e3736eed23103ee90f822c9dd2f6b7 100644 (file)
@@ -53,7 +53,7 @@ namespace System.Runtime.CompilerServices
                public TResult GetResult ()
                {
                        if (!task.IsCompleted)
-                               task.WaitCore (Timeout.Infinite, CancellationToken.None);
+                               task.WaitCore (Timeout.Infinite, CancellationToken.None, true);
 
                        if (task.Status != TaskStatus.RanToCompletion)
                                ExceptionDispatchInfo.Capture (TaskAwaiter.HandleUnexpectedTaskResult (task)).Throw ();
index 5eefcb7da5f87592bc67d2751e579f0b95ed1055..f55689114ec100f223609b640e77c52b085c3757 100644 (file)
@@ -224,7 +224,7 @@ namespace System.Threading.Tasks
                        }
 
                        Schedule ();
-                       Wait ();
+                       WaitCore (Timeout.Infinite, CancellationToken.None, false);
                }
                #endregion
                
@@ -641,7 +641,7 @@ namespace System.Threading.Tasks
                        if (millisecondsTimeout < -1)
                                throw new ArgumentOutOfRangeException ("millisecondsTimeout");
 
-                       bool result = WaitCore (millisecondsTimeout, cancellationToken);
+                       bool result = WaitCore (millisecondsTimeout, cancellationToken, true);
 
                        if (IsCanceled)
                                throw new AggregateException (new TaskCanceledException (this));
@@ -653,13 +653,13 @@ namespace System.Threading.Tasks
                        return result;
                }
 
-               internal bool WaitCore (int millisecondsTimeout, CancellationToken cancellationToken)
+               internal bool WaitCore (int millisecondsTimeout, CancellationToken cancellationToken, bool runInline)
                {
                        if (IsCompleted)
                                return true;
 
                        // If the task is ready to be run and we were supposed to wait on it indefinitely without cancellation, just run it
-                       if (Status == TaskStatus.WaitingToRun && millisecondsTimeout == Timeout.Infinite && scheduler != null && !cancellationToken.CanBeCanceled)
+                       if (runInline && Status == TaskStatus.WaitingToRun && millisecondsTimeout == Timeout.Infinite && scheduler != null && !cancellationToken.CanBeCanceled)
                                scheduler.RunInline (this, true);
 
                        bool result = true;
index 9825be780009dbd896e61e072e17d8118bdedfef..010e3310c365bdf06ef85c49abaef670fda728e8 100644 (file)
@@ -125,6 +125,21 @@ namespace System.Threading.Tasks
                }
        }
 
+       class SchedulerAwaitContinuation : IContinuation
+       {
+               readonly Task task;
+
+               public SchedulerAwaitContinuation (Task task)
+               {
+                       this.task = task;
+               }
+
+               public void Execute ()
+               {
+                       task.RunSynchronouslyCore (task.scheduler);
+               }
+       }
+
        class SynchronizationContextContinuation : IContinuation
        {
                readonly Action action;
index 48f629adface65b5d690e707af73fd8c1bb957c9..4526e9574d1820ce156bea4044c3f81ca62aecb9 100644 (file)
@@ -33,12 +33,45 @@ using System.Threading;
 using System.Threading.Tasks;
 using NUnit.Framework;
 using System.Runtime.CompilerServices;
+using System.Collections.Generic;
 
 namespace MonoTests.System.Runtime.CompilerServices
 {
        [TestFixture]
        public class TaskAwaiterTest
        {
+               class Scheduler : TaskScheduler
+               {
+                       string name;
+
+                       public Scheduler (string name)
+                       {
+                               this.name = name;
+                       }
+
+                       public int InlineCalls { get; set; }
+                       public int QueueCalls { get; set; }
+
+                       protected override IEnumerable<Task> GetScheduledTasks ()
+                       {
+                               throw new NotImplementedException ();
+                       }
+
+                       protected override void QueueTask (Task task)
+                       {
+                               ++QueueCalls;
+                               ThreadPool.QueueUserWorkItem (o => {
+                                       TryExecuteTask (task);
+                               });
+                       }
+
+                       protected override bool TryExecuteTaskInline (Task task, bool taskWasPreviouslyQueued)
+                       {
+                               ++InlineCalls;
+                               return false;
+                       }
+               }
+
                [Test]
                public void GetResultFaulted ()
                {
@@ -85,6 +118,43 @@ namespace MonoTests.System.Runtime.CompilerServices
                        awaiter.GetResult ();
                        Assert.AreEqual (TaskStatus.RanToCompletion, task.Status);
                }
+
+               [Test]
+               public void CustomScheduler ()
+               {
+                       var a = new Scheduler ("a");
+                       var b = new Scheduler ("b");
+
+                       var r = TestCS (a, b).Result;
+                       Assert.AreEqual (0, r, "#1");
+                       Assert.AreEqual (1, a.InlineCalls, "#2a");
+                       Assert.AreEqual (0, b.InlineCalls, "#2b");
+                       Assert.AreEqual (2, a.QueueCalls, "#3a");
+                       Assert.AreEqual (1, b.QueueCalls, "#3b");
+               }
+
+               static async Task<int> TestCS (TaskScheduler schedulerA, TaskScheduler schedulerB)
+               {
+                       var res = await Task.Factory.StartNew (async () => {
+                               if (TaskScheduler.Current != schedulerA)
+                                       return 1;
+
+                               await Task.Factory.StartNew (
+                                       () => {
+                                               if (TaskScheduler.Current != schedulerB)
+                                                       return 2;
+
+                                               return 0;
+                                       }, CancellationToken.None, TaskCreationOptions.None, schedulerB);
+
+                               if (TaskScheduler.Current != schedulerA)
+                                       return 3;
+
+                               return 0;
+                       }, CancellationToken.None, TaskCreationOptions.None, schedulerA);
+
+                       return res.Result;
+               }
        }
 }