Merge pull request #2034 from alexrp/ctx-cleanup
[mono.git] / mcs / class / corlib / Test / System.IO / MemoryStreamTest.cs
index 938d578a85f8aa51e628ac41f25a531e67263f02..f552e933a65290eca43c3c6cd3ffa9df00148d7d 100644 (file)
@@ -17,6 +17,9 @@ using System.IO;
 using System.Runtime.Serialization.Formatters.Binary;
 using System.Text;
 using System.Threading;
+#if NET_4_5
+using System.Threading.Tasks;
+#endif
 
 using NUnit.Framework;
 
@@ -45,6 +48,54 @@ namespace MonoTests.System.IO
                        }
                }
 
+               class ExceptionalStream : MemoryStream
+               {
+                       public static string Message = "ExceptionalMessage";
+                       public bool Throw = false;
+
+                       public ExceptionalStream ()
+                       {
+                               AllowRead = true;
+                               AllowWrite = true;
+                       }
+
+                       public ExceptionalStream (byte [] buffer, bool writable) : base (buffer, writable)
+                       {
+                               AllowRead = true;
+                               AllowWrite = true;  // we are testing the inherited write property
+                       }
+
+
+                       public override int Read(byte[] buffer, int offset, int count)
+                       {
+                               if (Throw)
+                                       throw new Exception(Message);
+
+                               return base.Read(buffer, offset, count);
+                       }
+
+                       public override void Write(byte[] buffer, int offset, int count)
+                       {
+                               if (Throw)
+                                       throw new Exception(Message);
+
+                               base.Write(buffer, offset, count);
+                       }
+
+                       public bool AllowRead { get; set; }
+                       public override bool CanRead { get { return AllowRead; } }
+
+                       public bool AllowWrite { get; set; }
+                       public override bool CanWrite { get { return AllowWrite; } }
+                       
+                       public override void Flush()
+                       {
+                               if (Throw)
+                                       throw new Exception(Message);
+
+                               base.Flush();
+                       }
+               }
 
                MemoryStream testStream;
                byte [] testStreamData;
@@ -283,39 +334,44 @@ namespace MonoTests.System.IO
 
                        wh.Close ();
                }
-               
+
                [Test]
                public void BeginReadIsBlockingNextRead ()
                {
                        byte[] readBytes = new byte[5];
                        byte[] readBytes2 = new byte[3];
-                       var wh = new ManualResetEvent (false);
-                       var end = new ManualResetEvent (false);
-
-                       using (var testStream = new SignaledMemoryStream (testStreamData, wh)) {
-                               var res = testStream.BeginRead (readBytes, 0, 5, null, null);
-
-                               bool blocking = true;
-                               ThreadPool.QueueUserWorkItem (l => {
-                                       var res2 = testStream.BeginRead (readBytes2, 0, 3, null, null);
-                                       blocking = false;
-                                       Assert.IsTrue (res2.AsyncWaitHandle.WaitOne (2000), "#10");
-                                       Assert.IsTrue (res2.IsCompleted, "#11");
-                                       Assert.AreEqual (3, testStream.EndRead (res2), "#12");
-                                       Assert.AreEqual (95, readBytes2[0], "#13");
-                                       end.Set ();
-                               });
-
-                               Assert.IsFalse (res.IsCompleted, "#1");
-                               Thread.Sleep (500);     // Lame but don't know how to wait for another BeginRead which does not return
-                               Assert.IsTrue (blocking, "#2");
+                       ManualResetEvent begin_read_unblock = new ManualResetEvent (false);
+                       ManualResetEvent begin_read_blocking = new ManualResetEvent (false);
+                       Task begin_read_task = null;
 
-                               wh.Set ();
-                               Assert.IsTrue (res.AsyncWaitHandle.WaitOne (2000), "#3");
-                               Assert.IsTrue (res.IsCompleted, "#4");
-                               Assert.AreEqual (5, testStream.EndRead (res), "#5");
-                               Assert.IsTrue (end.WaitOne (2000), "#6");
-                               Assert.AreEqual (100, readBytes[0], "#7");
+                       try {
+                               using (var testStream = new SignaledMemoryStream (testStreamData, begin_read_unblock)) {
+                                       IAsyncResult begin_read_1_ares = testStream.BeginRead (readBytes, 0, 5, null, null);
+
+                                       begin_read_task = Task.Factory.StartNew (() => {
+                                               IAsyncResult begin_read_2_ares = testStream.BeginRead (readBytes2, 0, 3, null, null);
+                                               begin_read_blocking.Set ();
+
+                                               Assert.IsTrue (begin_read_2_ares.AsyncWaitHandle.WaitOne (2000), "#10");
+                                               Assert.IsTrue (begin_read_2_ares.IsCompleted, "#11");
+                                               Assert.AreEqual (3, testStream.EndRead (begin_read_2_ares), "#12");
+                                               Assert.AreEqual (95, readBytes2[0], "#13");
+                                       });
+
+                                       Assert.IsFalse (begin_read_1_ares.IsCompleted, "#1");
+                                       Assert.IsFalse (begin_read_blocking.WaitOne (500), "#2");
+
+                                       begin_read_unblock.Set ();
+
+                                       Assert.IsTrue (begin_read_1_ares.AsyncWaitHandle.WaitOne (2000), "#3");
+                                       Assert.IsTrue (begin_read_1_ares.IsCompleted, "#4");
+                                       Assert.AreEqual (5, testStream.EndRead (begin_read_1_ares), "#5");
+                                       Assert.IsTrue (begin_read_task.Wait (2000), "#6");
+                                       Assert.AreEqual (100, readBytes[0], "#7");
+                               }
+                       } finally {
+                               if (begin_read_task != null)
+                                       begin_read_task.Wait ();
                        }
                }
 
@@ -344,34 +400,38 @@ namespace MonoTests.System.IO
                {
                        byte[] readBytes = new byte[5];
                        byte[] readBytes2 = new byte[3] { 1, 2, 3 };
-                       var wh = new ManualResetEvent (false);
-                       var end = new ManualResetEvent (false);
-
-                       using (var testStream = new SignaledMemoryStream (testStreamData, wh)) {
-                               var res = testStream.BeginRead (readBytes, 0, 5, null, null);
-
-                               bool blocking = true;
-                               ThreadPool.QueueUserWorkItem (l => {
-                                       var res2 = testStream.BeginWrite (readBytes2, 0, 3, null, null);
-                                       blocking = false;
-                                       Assert.IsTrue (res2.AsyncWaitHandle.WaitOne (2000), "#10");
-                                       Assert.IsTrue (res2.IsCompleted, "#11");
-                                       testStream.EndWrite (res2);
-                                       end.Set ();
-                               });
+                       ManualResetEvent begin_read_unblock = new ManualResetEvent (false);
+                       ManualResetEvent begin_write_blocking = new ManualResetEvent (false);
+                       Task begin_write_task = null;
 
-                               Assert.IsFalse (res.IsCompleted, "#1");
-                               Thread.Sleep (500);     // Lame but don't know how to wait for another BeginWrite which does not return
-                               Assert.IsTrue (blocking, "#2");
-
-                               wh.Set ();
-                               Assert.IsTrue (res.AsyncWaitHandle.WaitOne (2000), "#3");
-                               Assert.IsTrue (res.IsCompleted, "#4");
-                               Assert.AreEqual (5, testStream.EndRead (res), "#5");
-                               Assert.IsTrue (end.WaitOne (2000), "#6");
+                       try {
+                               using (MemoryStream stream = new SignaledMemoryStream (testStreamData, begin_read_unblock)) {
+                                       IAsyncResult begin_read_ares = stream.BeginRead (readBytes, 0, 5, null, null);
+
+                                       begin_write_task = Task.Factory.StartNew (() => {
+                                               var begin_write_ares = stream.BeginWrite (readBytes2, 0, 3, null, null);
+                                               begin_write_blocking.Set ();
+                                               Assert.IsTrue (begin_write_ares.AsyncWaitHandle.WaitOne (2000), "#10");
+                                               Assert.IsTrue (begin_write_ares.IsCompleted, "#11");
+                                               stream.EndWrite (begin_write_ares);
+                                       });
+
+                                       Assert.IsFalse (begin_read_ares.IsCompleted, "#1");
+                                       Assert.IsFalse (begin_write_blocking.WaitOne (500), "#2");
+
+                                       begin_read_unblock.Set ();
+
+                                       Assert.IsTrue (begin_read_ares.AsyncWaitHandle.WaitOne (2000), "#3");
+                                       Assert.IsTrue (begin_read_ares.IsCompleted, "#4");
+                                       Assert.AreEqual (5, stream.EndRead (begin_read_ares), "#5");
+                                       Assert.IsTrue (begin_write_task.Wait (2000), "#6");
+                               }
+                       } finally {
+                               if (begin_write_task != null)
+                                       begin_write_task.Wait ();
                        }
                }
-               
+
                [Test]
                public void BeginWrite ()
                {
@@ -1074,6 +1134,150 @@ namespace MonoTests.System.IO
                        Assert.AreEqual (1, buffer[0], "#4");
                }
 
+               [Test]
+               public void TestAsyncReadExceptions ()
+               {
+                       var buffer = new byte [3];
+                       using (var stream = new ExceptionalStream ()) {
+                               stream.Write (buffer, 0, buffer.Length);
+                               stream.Write (buffer, 0, buffer.Length);
+                               stream.Position = 0;
+                               var task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, task.Status, "#1");
+
+                               stream.Throw = true;
+                               task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.IsTrue (task.IsFaulted, "#2");
+                               Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#3");
+                       }
+               }
+
+               [Test]
+               public void TestAsyncWriteExceptions ()
+               {
+                       var buffer = new byte [3];
+                       using (var stream = new ExceptionalStream ()) {
+                               var task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual(TaskStatus.RanToCompletion, task.Status, "#1");
+
+                               stream.Throw = true;
+                               task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.IsTrue (task.IsFaulted, "#2");
+                               Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#3");
+                       }
+               }
+
+               [Test]
+               public void TestAsyncArgumentExceptions ()
+               {
+                       var buffer = new byte [3];
+                       using (var stream = new ExceptionalStream ()) {
+                               var task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.IsTrue (task.IsCompleted);
+
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (buffer, 0, 1000); }), "#2");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (buffer, 0, 1000); }), "#3");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (buffer, 0, 1000, new CancellationToken (true)); }), "#4");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (buffer, 0, 1000, new CancellationToken (true)); }), "#5");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (null, 0, buffer.Length, new CancellationToken (true)); }), "#6");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (null, 0, buffer.Length, new CancellationToken (true)); }), "#7");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.WriteAsync (buffer, 1000, buffer.Length, new CancellationToken (true)); }), "#8");
+                               Assert.IsTrue (Throws<ArgumentException> (() => { stream.ReadAsync (buffer, 1000, buffer.Length, new CancellationToken (true)); }), "#9");
+
+                               stream.AllowRead = false;
+                               var read_task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, read_task.Status, "#8");
+                               Assert.AreEqual (0, read_task.Result, "#9");
+
+                               stream.Position = 0;
+                               read_task = stream.ReadAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, read_task.Status, "#9");
+                               Assert.AreEqual (3, read_task.Result, "#10");
+
+                               var write_task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.RanToCompletion, write_task.Status, "#10");
+
+                               // test what happens when CanRead is overridden
+                               using (var norm = new ExceptionalStream (buffer, false)) {
+                                       write_task = norm.WriteAsync (buffer, 0, buffer.Length);
+                                       Assert.AreEqual (TaskStatus.RanToCompletion, write_task.Status, "#11");
+                               }
+
+                               stream.AllowWrite = false;
+                               Assert.IsTrue (Throws<NotSupportedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#12");
+                               write_task = stream.WriteAsync (buffer, 0, buffer.Length);
+                               Assert.AreEqual (TaskStatus.Faulted, write_task.Status, "#13");
+                       }
+               }
+
+               [Test]
+               public void TestAsyncFlushExceptions ()
+               {
+                       using (var stream = new ExceptionalStream ()) {
+                               var task = stream.FlushAsync ();
+                               Assert.IsTrue (task.IsCompleted, "#1");
+                               
+                               task = stream.FlushAsync (new CancellationToken(true));
+                               Assert.IsTrue (task.IsCanceled, "#2");
+
+                               stream.Throw = true;
+                               task = stream.FlushAsync ();
+                               Assert.IsTrue (task.IsFaulted, "#3");
+                               Assert.AreEqual (ExceptionalStream.Message, task.Exception.InnerException.Message, "#4");
+
+                               task = stream.FlushAsync (new CancellationToken (true));
+                               Assert.IsTrue (task.IsCanceled, "#5");
+                       }
+               }
+
+               [Test]
+               public void TestCopyAsync ()
+               {
+                       using (var stream = new ExceptionalStream ()) {
+                               using (var dest = new ExceptionalStream ()) {
+                                       byte [] buffer = new byte [] { 12, 13, 8 };
+
+                                       stream.Write (buffer, 0, buffer.Length);
+                                       stream.Position = 0;
+                                       var task = stream.CopyToAsync (dest, 1);
+                                       Assert.AreEqual (TaskStatus.RanToCompletion, task.Status);
+                                       Assert.AreEqual (3, stream.Length);
+                                       Assert.AreEqual (3, dest.Length);
+
+                                       stream.Position = 0;
+                                       dest.Throw = true;
+                                       task = stream.CopyToAsync (dest, 1);
+                                       Assert.AreEqual (TaskStatus.Faulted, task.Status);
+                                       Assert.AreEqual (3, stream.Length);
+                                       Assert.AreEqual (3, dest.Length);
+                               }
+                       }
+               }
+
+               [Test]
+               public void WritableOverride ()
+               {
+                       var buffer = new byte [3];
+                       var stream = new MemoryStream (buffer, false);
+                       Assert.IsTrue (Throws<NotSupportedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#1");
+                       Assert.IsTrue (Throws<ArgumentNullException> (() => { stream.Write (null, 0, buffer.Length); }), "#1.1");
+                       stream.Close ();
+                       Assert.IsTrue (Throws<ObjectDisposedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#2");
+                       stream = new MemoryStream (buffer, true);
+                       stream.Close ();
+                       Assert.IsFalse (stream.CanWrite, "#3");
+
+                       var estream = new ExceptionalStream (buffer, false);
+                       Assert.IsFalse (Throws<Exception> (() => { estream.Write (buffer, 0, buffer.Length); }), "#4");
+                       estream.AllowWrite = false;
+                       estream.Position = 0;
+                       Assert.IsTrue (Throws<NotSupportedException> (() => { estream.Write (buffer, 0, buffer.Length); }), "#5");
+                       estream.AllowWrite = true;
+                       estream.Close ();
+                       Assert.IsTrue (estream.CanWrite, "#6");
+                       Assert.IsTrue (Throws<ObjectDisposedException> (() => { stream.Write (buffer, 0, buffer.Length); }), "#7");
+               }
+
                [Test]
                public void ReadAsync_Canceled ()
                {
@@ -1109,6 +1313,16 @@ namespace MonoTests.System.IO
                        t = testStream.WriteAsync (buffer, 0, buffer.Length);
                        Assert.IsTrue (t.IsCompleted, "#1");
                }
+
+               bool Throws<T> (Action a) where T : Exception
+               {
+                       try {
+                               a ();
+                               return false;
+                       } catch (T) {
+                               return true;
+                       }
+               }
 #endif
        }
 }