[WebConnection] Fix race condition between Close and BeginWrite (#4693)
[mono.git] / mcs / class / System / System.Net / WebConnectionStream.cs
index 6510de5a098a8c6a2026419b95d7b88b2bfe2061..96a2b3bf42fc4cc506ad4e6fd056e2c6168d36b6 100644 (file)
@@ -359,7 +359,6 @@ namespace System.Net
                        }
 
                        WebAsyncResult result = new WebAsyncResult (cb, state, buffer, offset, size);
-                       result.AsyncObject = request;
                        if (totalRead >= contentLength) {
                                result.SetCompleted (true, -1);
                                result.DoCallback ();
@@ -389,7 +388,7 @@ namespace System.Net
                                size = (int)(contentLength - totalRead);
 
                        if (!read_eof) {
-                               cnc.ReadAsync (request, buffer, offset, size, result);
+                               result.InnerAsyncResult = cnc.BeginRead (request, buffer, offset, size, cb, result);
                        } else {
                                result.SetCompleted (true, result.NBytes);
                                result.DoCallback ();
@@ -400,35 +399,53 @@ namespace System.Net
                public override int EndRead (IAsyncResult r)
                {
                        WebAsyncResult result = (WebAsyncResult) r;
-                       int nb = result.NBytes;
+                       if (result.EndCalled) {
+                               int xx = result.NBytes;
+                               return (xx >= 0) ? xx : 0;
+                       }
 
-                       if (result.EndCalled)
-                               return (nb >= 0) ? nb : 0;
                        result.EndCalled = true;
 
+                       if (!result.IsCompleted) {
+                               int nbytes = -1;
+                               try {
+                                       nbytes = cnc.EndRead (request, result);
+                               } catch (Exception exc) {
+                                       lock (locker) {
+                                               pendingReads--;
+                                               if (pendingReads == 0)
+                                                       pending.Set ();
+                                       }
+
+                                       nextReadCalled = true;
+                                       cnc.Close (true);
+                                       result.SetCompleted (false, exc);
+                                       result.DoCallback ();
+                                       throw;
+                               }
+
+                               if (nbytes < 0) {
+                                       nbytes = 0;
+                                       read_eof = true;
+                               }
+
+                               totalRead += nbytes;
+                               result.SetCompleted (false, nbytes + result.NBytes);
+                               result.DoCallback ();
+                               if (nbytes == 0)
+                                       contentLength = totalRead;
+                       }
+
                        lock (locker) {
                                pendingReads--;
                                if (pendingReads == 0)
                                        pending.Set ();
                        }
 
-                       if (result.GotException) {
-                               nextReadCalled = true;
-                               cnc.Close (true);
-                               throw result.Exception;
-                       }
-
-                       if (nb < 0) {
-                               read_eof = true;
-                       } else {
-                               totalRead += result.NBytes;
-                               if (nb == 0)
-                                       contentLength = totalRead;
-                       }
-
                        if (totalRead >= contentLength && !nextReadCalled)
                                ReadAll ();
 
+                       int nb = result.NBytes;
                        return (nb >= 0) ? nb : 0;
                }
 
@@ -442,7 +459,7 @@ namespace System.Net
                                result.SetCompleted (false, 0);
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
                        } catch (Exception e) {
                                KillBuffer ();
@@ -463,7 +480,7 @@ namespace System.Net
                                                        AsyncCallback cb, object state)
                {
                        if (request.Aborted)
-                               throw new WebException ("The request was canceled.", null, WebExceptionStatus.RequestCanceled);
+                               throw new WebException ("The request was canceled.", WebExceptionStatus.RequestCanceled);
 
                        if (isRead)
                                throw new NotSupportedException ("this stream does not allow writing");
@@ -649,7 +666,7 @@ namespace System.Net
                                        cnc.EndWrite (request, true, r);
                                        if (!initRead) {
                                                initRead = true;
-                                               WebConnection.InitRead (cnc);
+                                               cnc.InitRead ();
                                        }
                                        var cl = request.ContentLength;
                                        if (!sendChunked && cl == 0)
@@ -658,7 +675,7 @@ namespace System.Net
                                } catch (WebException e) {
                                        result.SetCompleted (false, e);
                                } catch (Exception e) {
-                                       result.SetCompleted (false, new WebException ("Error writing headers", e, WebExceptionStatus.SendFailure));
+                                       result.SetCompleted (false, new WebException ("Error writing headers", WebExceptionStatus.SendFailure, WebExceptionInternalStatus.RequestFatal, e));
                                }
                        }, null);
 
@@ -702,23 +719,23 @@ namespace System.Net
 
                        SetHeadersAsync (true, inner => {
                                if (inner.GotException) {
-                                       result.SetCompleted (inner.CompletedSynchronously, inner.Exception);
+                                       result.SetCompleted (inner.CompletedSynchronouslyPeek, inner.Exception);
                                        return;
                                }
 
                                if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100) {
-                                       result.SetCompleted (inner.CompletedSynchronously);
+                                       result.SetCompleted (inner.CompletedSynchronouslyPeek);
                                        return;
                                }
 
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
 
                                if (length == 0) {
                                        complete_request_written = true;
-                                       result.SetCompleted (inner.CompletedSynchronously);
+                                       result.SetCompleted (inner.CompletedSynchronouslyPeek);
                                        return;
                                }
 
@@ -783,7 +800,7 @@ namespace System.Net
                                complete_request_written = true;
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
                                return;
                        }
@@ -797,7 +814,7 @@ namespace System.Net
                                IOException io = new IOException ("Cannot close the stream until all bytes are written");
                                nextReadCalled = true;
                                cnc.Close (true);
-                               throw new WebException ("Request was cancelled.", io, WebExceptionStatus.RequestCanceled);
+                               throw new WebException ("Request was cancelled.", WebExceptionStatus.RequestCanceled, WebExceptionInternalStatus.RequestFatal, io);
                        }
 
                        // Commented out the next line to fix xamarin bug #1512