Rework exception handling in MemoryStream async methods
authorLarry Ewing <lewing@gmail.com>
Fri, 14 Sep 2012 18:27:54 +0000 (13:27 -0500)
committerLarry Ewing <lewing@gmail.com>
Fri, 14 Sep 2012 18:27:54 +0000 (13:27 -0500)
Return faulted tasks for runtime errors when using the MemoryStream
Async methods.

mcs/class/corlib/System.IO/MemoryStream.cs
mcs/class/corlib/System.Threading.Tasks/Task_T.cs
mcs/class/corlib/Test/System.IO/MemoryStreamTest.cs

index fa874887d6ff71c6a770c983022bce552ebb2d9d..4ce86df30b616f0b9ff96498f8d7fa783d76a698 100644 (file)
@@ -376,9 +376,6 @@ namespace System.IO
 
                public override void Write (byte [] buffer, int offset, int count)
                {
-                       if (!canWrite)
-                               throw new NotSupportedException ("Cannot write to this stream.");
-
                        if (buffer == null)
                                throw new ArgumentNullException ("buffer");
                        
@@ -391,6 +388,9 @@ namespace System.IO
 
                        CheckIfClosedThrowDisposed ();
 
+                       if (!CanWrite)
+                               throw new NotSupportedException ("Cannot write to this stream.");
+
                        // reordered to avoid possible integer overflow
                        if (position > length - count)
                                Expand (position + count);
@@ -436,33 +436,64 @@ namespace System.IO
                public override Task FlushAsync (CancellationToken cancellationToken)
                {
                        if (cancellationToken.IsCancellationRequested)
-                               return TaskConstants<int>.Canceled;
+                               return TaskConstants.Canceled;
 
-                       Flush ();
-                       return TaskConstants.Finished;
+                       try {
+                               Flush ();
+                               return TaskConstants.Finished;
+                       } catch (Exception ex) {
+                               return Task<object>.FromException (ex);
+                       }
                }
 
                public override Task<int> ReadAsync (byte[] buffer, int offset, int count, CancellationToken cancellationToken)
                {
+                       if (buffer == null)
+                               throw new ArgumentNullException ("buffer");
+
+                       if (offset < 0 || count < 0)
+                               throw new ArgumentOutOfRangeException ("offset or count less than zero.");
+
+                       if (buffer.Length - offset < count )
+                               throw new ArgumentException ("offset+count",
+                                                            "The size of the buffer is less than offset + count.");
                        if (cancellationToken.IsCancellationRequested)
                                return TaskConstants<int>.Canceled;
 
-                       count = Read (buffer, offset, count);
+                       try {
+                               count = Read (buffer, offset, count);
 
-                       // Try not to allocate a new task for every buffer read
-                       if (read_task == null || read_task.Result != count)
-                               read_task = Task<int>.FromResult (count);
+                               // Try not to allocate a new task for every buffer read
+                               if (read_task == null || read_task.Result != count)
+                                       read_task = Task<int>.FromResult (count);
 
-                       return read_task;
+                               return read_task;
+                       } catch (Exception ex) {
+                               return Task<int>.FromException (ex);
+                       }
                }
 
                public override Task WriteAsync (byte[] buffer, int offset, int count, CancellationToken cancellationToken)
                {
+                       if (buffer == null)
+                               throw new ArgumentNullException ("buffer");
+                       
+                       if (offset < 0 || count < 0)
+                               throw new ArgumentOutOfRangeException ();
+
+                       if (buffer.Length - offset < count)
+                               throw new ArgumentException ("offset+count",
+                                                            "The size of the buffer is less than offset + count.");
+
                        if (cancellationToken.IsCancellationRequested)
-                               return TaskConstants<int>.Canceled;
+                               return TaskConstants.Canceled;
 
-                       Write (buffer, offset, count);
-                       return TaskConstants.Finished;
+                       try {
+                               Write (buffer, offset, count);
+                               return TaskConstants.Finished;
+                       } catch (Exception ex) {
+                               return Task<object>.FromException (ex);
+                       }
                }
 #endif
        }               
index f02f17bf9d0f4938717ac7103f789e93d2801627..ce7063309037bd8fbb5b53ede1ff23f7a14418fc 100644 (file)
@@ -323,6 +323,13 @@ namespace System.Threading.Tasks
                {
                        return new TaskAwaiter<TResult> (this);
                }
+
+               internal static Task<TResult> FromException (Exception ex)
+               {
+                       var tcs = new TaskCompletionSource<TResult>();
+                       tcs.TrySetException (ex);
+                       return tcs.Task;
+               }
 #endif
        }
 }
index 938d578a85f8aa51e628ac41f25a531e67263f02..2f8e2c05f5a779229b5cb3a1d15ece50c1f80045 100644 (file)
@@ -17,6 +17,9 @@ using System.IO;
 using System.Runtime.Serialization.Formatters.Binary;
 using System.Text;
 using System.Threading;
+#if NET_4_5
+using System.Threading.Tasks;
+#endif
 
 using NUnit.Framework;
 
@@ -45,6 +48,54 @@ namespace MonoTests.System.IO
                        }
                }
 
+               class ExceptionalStream : MemoryStream
+               {
+                       public static string Message = "ExceptionalMessage";
+                       public bool Throw = false;
+
+                       public ExceptionalStream ()
+                       {
+                               AllowRead = true;
+                               AllowWrite = true;
+                       }
+
+                       public ExceptionalStream (byte [] buffer, bool writable) : base (buffer, writable)
+                       {
+                               AllowRead = true;
+                               AllowWrite = true;  // we are testing the inherited write property
+                       }
+
+
+                       public override int Read(byte[] buffer, int offset, int count)
+                       {
+                               if (Throw)
+                                       throw new Exception(Message);
+
+                               return base.Read(buffer, offset, count);
+                       }
+
+                       public override void Write(byte[] buffer, int offset, int count)
+                       {
+                               if (Throw)
+                                       throw new Exception(Message);
+
+                               base.Write(buffer, offset, count);
+                       }
+
+                       public bool AllowRead { get; set; }
+                       public override bool CanRead { get { return AllowRead; } }
+
+                       public bool AllowWrite { get; set; }
+                       public override bool CanWrite { get { return AllowWrite; } }
+                       
+                       public override void Flush()
+                       {
+                               if (Throw)
+                                       throw new Exception(Message);
+
+                               base.Flush();
+                       }
+               }
 
                MemoryStream testStream;
                byte [] testStreamData;
@@ -1074,6 +1125,150 @@ namespace MonoTests.System.IO
                        Assert.AreEqual (1, buffer[0], "#4");
                }
 
+               [Test]
+               public void TestAsyncReadExceptions ()
+               {
+                       var buffer = new byte [3];
+                       using (var stream = new ExceptionalStream ()) {
+                               stream.Write (buffer, 0, buffer.Length);
+                               stream.Write (buffer, 0, buffer.Length);
+                               stream.Position = 0;
+                               var task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, task.Status, "#1");
+
+                               stream.Throw = true;
+                               task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.IsTrue (task.IsFaulted, "#2");
+                               Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#3");
+                       }
+               }
+
+               [Test]
+               public void TestAsyncWriteExceptions ()
+               {
+                       var buffer = new byte [3];
+                       using (var stream = new ExceptionalStream ()) {
+                               var task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual(TaskStatus.RanToCompletion, task.Status, "#1");
+
+                               stream.Throw = true;
+                               task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.IsTrue (task.IsFaulted, "#2");
+                               Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#3");
+                       }
+               }
+
+               [Test]
+               public void TestAsyncArgumentExceptions ()
+               {
+                       var buffer = new byte [3];
+                       using (var stream = new ExceptionalStream ()) {
+                               var task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.IsTrue (task.IsCompleted);
+
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (buffer, 0, 1000); }), "#2");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (buffer, 0, 1000); }), "#3");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (buffer, 0, 1000, new CancellationToken (true)); }), "#4");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (buffer, 0, 1000, new CancellationToken (true)); }), "#5");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (null, 0, buffer.Length, new CancellationToken (true)); }), "#6");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (null, 0, buffer.Length, new CancellationToken (true)); }), "#7");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (buffer, 1000, buffer.Length, new CancellationToken (true)); }), "#8");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (buffer, 1000, buffer.Length, new CancellationToken (true)); }), "#9");
+
+                               stream.AllowRead = false;
+                               var read_task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, read_task.Status, "#8");
+                               Assert.AreEqual (0, read_task.Result, "#9");
+
+                               stream.Position = 0;
+                               read_task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, read_task.Status, "#9");
+                               Assert.AreEqual (3, read_task.Result, "#10");
+
+                               var write_task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, write_task.Status, "#10");
+
+                               // test what happens when CanRead is overridden
+                               using (var norm = new ExceptionalStream (buffer, false)) {
+                                       write_task = norm.WriteAsync (buffer, 0, buffer.Length);
+                                       Assert.AreEqual (TaskStatus.RanToCompletion, write_task.Status, "#11");
+                               }
+
+                               stream.AllowWrite = false;
+                               Assert.IsTrue (Throws<NotSupportedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#12");
+                               write_task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.Faulted, write_task.Status, "#13");
+                       }
+               }
+
+               [Test]
+               public void TestAsyncFlushExceptions ()
+               {
+                       using (var stream = new ExceptionalStream ()) {
+                               var task = stream.FlushAsync ();
+                               Assert.IsTrue (task.IsCompleted, "#1");
+                               
+                               task = stream.FlushAsync (new CancellationToken(true));
+                               Assert.IsTrue (task.IsCanceled, "#2");
+
+                               stream.Throw = true;
+                               task = stream.FlushAsync ();
+                               Assert.IsTrue (task.IsFaulted, "#3");
+                               Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#4");
+
+                               task = stream.FlushAsync (new CancellationToken (true));
+                               Assert.IsTrue (task.IsCanceled, "#5");
+                       }
+               }
+
+               [Test]
+               public void TestCopyAsync ()
+               {
+                       using (var stream = new ExceptionalStream ()) {
+                               using (var dest = new ExceptionalStream ()) {
+                                       byte [] buffer = new byte [] { 12, 13, 8 };
+
+                                       stream.Write (buffer, 0, buffer.Length);
+                                       stream.Position = 0;
+                                       var task = stream.CopyToAsync (dest, 1);
+                                       Assert.AreEqual (TaskStatus.RanToCompletion, task.Status);
+                                       Assert.AreEqual (3, stream.Length);
+                                       Assert.AreEqual (3, dest.Length);
+
+                                       stream.Position = 0;
+                                       dest.Throw = true;
+                                       task = stream.CopyToAsync (dest, 1);
+                                       Assert.AreEqual (TaskStatus.Faulted, task.Status);
+                                       Assert.AreEqual (3, stream.Length);
+                                       Assert.AreEqual (3, dest.Length);
+                               }
+                       }
+               }
+
+               [Test]
+               public void WritableOverride ()
+               {
+                       var buffer = new byte [3];
+                       var stream = new MemoryStream (buffer, false);
+                       Assert.IsTrue (Throws<NotSupportedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#1");
+                       Assert.IsTrue (Throws<ArgumentNullException> (() => { stream.Write (null, 0, buffer.Length); }), "#1.1");
+                       stream.Close ();
+                       Assert.IsTrue (Throws<ObjectDisposedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#2");
+                       stream = new MemoryStream (buffer, true);
+                       stream.Close ();
+                       Assert.IsFalse (stream.CanWrite, "#3");
+
+                       var estream = new ExceptionalStream (buffer, false);
+                       Assert.IsFalse (Throws<Exception> (() => { estream.Write (buffer, 0, buffer.Length); }), "#4");
+                       estream.AllowWrite = false;
+                       estream.Position = 0;
+                       Assert.IsTrue (Throws<NotSupportedException> (() => { estream.Write (buffer, 0, buffer.Length); }), "#5");
+                       estream.AllowWrite = true;
+                       estream.Close ();
+                       Assert.IsTrue (estream.CanWrite, "#6");
+                       Assert.IsTrue (Throws<ObjectDisposedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#7");
+               }
+
                [Test]
                public void ReadAsync_Canceled ()
                {
@@ -1109,6 +1304,16 @@ namespace MonoTests.System.IO
                        t = testStream.WriteAsync (buffer, 0, buffer.Length);
                        Assert.IsTrue (t.IsCompleted, "#1");
                }
+
+               bool Throws<T> (Action a) where T : Exception
+               {
+                       try {
+                               a ();
+                               return false;
+                       } catch (T) {
+                               return true;
+                       }
+               }
 #endif
        }
 }