public void GetResult ()
{
if (!task.IsCompleted)
- task.WaitCore (Timeout.Infinite, CancellationToken.None);
+ task.WaitCore (Timeout.Infinite, CancellationToken.None, true);
if (task.Status != TaskStatus.RanToCompletion)
// Merge current and dispatched stack traces if there is any
if (continueOnSourceContext && SynchronizationContext.Current != null) {
task.ContinueWith (new SynchronizationContextContinuation (continuation, SynchronizationContext.Current));
} else {
- task.ContinueWith (new ActionContinuation (continuation));
+ IContinuation cont;
+ if (TaskScheduler.Current != TaskScheduler.Default) {
+ var runner = new Task (TaskActionInvoker.Create (continuation), null, CancellationToken.None, TaskCreationOptions.None, null);
+ runner.SetupScheduler (TaskScheduler.Current);
+ cont = new SchedulerAwaitContinuation (runner);
+ } else {
+ cont = new ActionContinuation (continuation);
+ }
+
+ task.ContinueWith (cont);
}
}
public TResult GetResult ()
{
if (!task.IsCompleted)
- task.WaitCore (Timeout.Infinite, CancellationToken.None);
+ task.WaitCore (Timeout.Infinite, CancellationToken.None, true);
if (task.Status != TaskStatus.RanToCompletion)
ExceptionDispatchInfo.Capture (TaskAwaiter.HandleUnexpectedTaskResult (task)).Throw ();
}
Schedule ();
- Wait ();
+ WaitCore (Timeout.Infinite, CancellationToken.None, false);
}
#endregion
if (millisecondsTimeout < -1)
throw new ArgumentOutOfRangeException ("millisecondsTimeout");
- bool result = WaitCore (millisecondsTimeout, cancellationToken);
+ bool result = WaitCore (millisecondsTimeout, cancellationToken, true);
if (IsCanceled)
throw new AggregateException (new TaskCanceledException (this));
return result;
}
- internal bool WaitCore (int millisecondsTimeout, CancellationToken cancellationToken)
+ internal bool WaitCore (int millisecondsTimeout, CancellationToken cancellationToken, bool runInline)
{
if (IsCompleted)
return true;
// If the task is ready to be run and we were supposed to wait on it indefinitely without cancellation, just run it
- if (Status == TaskStatus.WaitingToRun && millisecondsTimeout == Timeout.Infinite && scheduler != null && !cancellationToken.CanBeCanceled)
+ if (runInline && Status == TaskStatus.WaitingToRun && millisecondsTimeout == Timeout.Infinite && scheduler != null && !cancellationToken.CanBeCanceled)
scheduler.RunInline (this, true);
bool result = true;
}
}
+ class SchedulerAwaitContinuation : IContinuation
+ {
+ readonly Task task;
+
+ public SchedulerAwaitContinuation (Task task)
+ {
+ this.task = task;
+ }
+
+ public void Execute ()
+ {
+ task.RunSynchronouslyCore (task.scheduler);
+ }
+ }
+
class SynchronizationContextContinuation : IContinuation
{
readonly Action action;
using System.Threading.Tasks;
using NUnit.Framework;
using System.Runtime.CompilerServices;
+using System.Collections.Generic;
namespace MonoTests.System.Runtime.CompilerServices
{
[TestFixture]
public class TaskAwaiterTest
{
+ class Scheduler : TaskScheduler
+ {
+ string name;
+
+ public Scheduler (string name)
+ {
+ this.name = name;
+ }
+
+ public int InlineCalls { get; set; }
+ public int QueueCalls { get; set; }
+
+ protected override IEnumerable<Task> GetScheduledTasks ()
+ {
+ throw new NotImplementedException ();
+ }
+
+ protected override void QueueTask (Task task)
+ {
+ ++QueueCalls;
+ ThreadPool.QueueUserWorkItem (o => {
+ TryExecuteTask (task);
+ });
+ }
+
+ protected override bool TryExecuteTaskInline (Task task, bool taskWasPreviouslyQueued)
+ {
+ ++InlineCalls;
+ return false;
+ }
+ }
+
[Test]
public void GetResultFaulted ()
{
awaiter.GetResult ();
Assert.AreEqual (TaskStatus.RanToCompletion, task.Status);
}
+
+ [Test]
+ public void CustomScheduler ()
+ {
+ var a = new Scheduler ("a");
+ var b = new Scheduler ("b");
+
+ var r = TestCS (a, b).Result;
+ Assert.AreEqual (0, r, "#1");
+ Assert.AreEqual (1, a.InlineCalls, "#2a");
+ Assert.AreEqual (0, b.InlineCalls, "#2b");
+ Assert.AreEqual (2, a.QueueCalls, "#3a");
+ Assert.AreEqual (1, b.QueueCalls, "#3b");
+ }
+
+ static async Task<int> TestCS (TaskScheduler schedulerA, TaskScheduler schedulerB)
+ {
+ var res = await Task.Factory.StartNew (async () => {
+ if (TaskScheduler.Current != schedulerA)
+ return 1;
+
+ await Task.Factory.StartNew (
+ () => {
+ if (TaskScheduler.Current != schedulerB)
+ return 2;
+
+ return 0;
+ }, CancellationToken.None, TaskCreationOptions.None, schedulerB);
+
+ if (TaskScheduler.Current != schedulerA)
+ return 3;
+
+ return 0;
+ }, CancellationToken.None, TaskCreationOptions.None, schedulerA);
+
+ return res.Result;
+ }
}
}