[WebConnection] Fix race condition between Close and BeginWrite (#4693)
[mono.git] / mcs / class / System / System.Net / WebConnectionStream.cs
index 8ef558d72fd997d284c76319ee04b76159d0b7e6..96a2b3bf42fc4cc506ad4e6fd056e2c6168d36b6 100644 (file)
@@ -45,8 +45,8 @@ namespace System.Net
                int readBufferOffset;
                int readBufferSize;
                int stream_length; // -1 when CL not present
-               int contentLength;
-               int totalRead;
+               long contentLength;
+               long totalRead;
                internal long totalWritten;
                bool nextReadCalled;
                int pendingReads;
@@ -93,10 +93,10 @@ namespace System.Net
                                                ReadAll ();
                                        }
                                } catch {
-                                       contentLength = Int32.MaxValue;
+                                       contentLength = Int64.MaxValue;
                                }
                        } else {
-                               contentLength = Int32.MaxValue;
+                               contentLength = Int64.MaxValue;
                        }
 
                        // Negative numbers?
@@ -208,7 +208,7 @@ namespace System.Net
                internal void ForceCompletion ()
                {
                        if (!nextReadCalled) {
-                               if (contentLength == Int32.MaxValue)
+                               if (contentLength == Int64.MaxValue)
                                        contentLength = 0;
                                nextReadCalled = true;
                                cnc.NextRead ();
@@ -234,7 +234,8 @@ namespace System.Net
                                return;
                        }
 
-                       pending.WaitOne ();
+                       if (!pending.WaitOne (ReadTimeout))
+                               throw new WebException ("The operation has timed out.", WebExceptionStatus.Timeout);
                        lock (locker) {
                                if (totalRead >= contentLength)
                                        return;
@@ -243,7 +244,7 @@ namespace System.Net
                                int diff = readBufferSize - readBufferOffset;
                                int new_size;
 
-                               if (contentLength == Int32.MaxValue) {
+                               if (contentLength == Int64.MaxValue) {
                                        MemoryStream ms = new MemoryStream ();
                                        byte [] buffer = null;
                                        if (readBuffer != null && diff > 0) {
@@ -263,7 +264,7 @@ namespace System.Net
                                        new_size = (int) ms.Length;
                                        contentLength = new_size;
                                } else {
-                                       new_size = contentLength - totalRead;
+                                       new_size = (int) (contentLength - totalRead);
                                        b = new byte [new_size];
                                        if (readBuffer != null && diff > 0) {
                                                if (diff > new_size)
@@ -383,8 +384,8 @@ namespace System.Net
                        if (cb != null)
                                cb = cb_wrapper;
 
-                       if (contentLength != Int32.MaxValue && contentLength - totalRead < size)
-                               size = contentLength - totalRead;
+                       if (contentLength != Int64.MaxValue && contentLength - totalRead < size)
+                               size = (int)(contentLength - totalRead);
 
                        if (!read_eof) {
                                result.InnerAsyncResult = cnc.BeginRead (request, buffer, offset, size, cb, result);
@@ -458,7 +459,7 @@ namespace System.Net
                                result.SetCompleted (false, 0);
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
                        } catch (Exception e) {
                                KillBuffer ();
@@ -479,7 +480,7 @@ namespace System.Net
                                                        AsyncCallback cb, object state)
                {
                        if (request.Aborted)
-                               throw new WebException ("The request was canceled.", null, WebExceptionStatus.RequestCanceled);
+                               throw new WebException ("The request was canceled.", WebExceptionStatus.RequestCanceled);
 
                        if (isRead)
                                throw new NotSupportedException ("this stream does not allow writing");
@@ -592,6 +593,14 @@ namespace System.Net
                        if (result.EndCalled)
                                return;
 
+                       if (sendChunked) {
+                               lock (locker) {
+                                       pendingWrites--;
+                                       if (pendingWrites <= 0)
+                                               pending.Set ();
+                               }
+                       }
+
                        result.EndCalled = true;
                        if (result.AsyncWriteAll) {
                                result.WaitUntilComplete ();
@@ -605,14 +614,6 @@ namespace System.Net
 
                        if (result.GotException)
                                throw result.Exception;
-
-                       if (sendChunked) {
-                               lock (locker) {
-                                       pendingWrites--;
-                                       if (pendingWrites == 0)
-                                               pending.Set ();
-                               }
-                       }
                }
                
                public override void Write (byte [] buffer, int offset, int size)
@@ -653,7 +654,8 @@ namespace System.Net
                        if (setInternalLength && !no_writestream && writeBuffer != null)
                                request.InternalContentLength = writeBuffer.Length;
 
-                       if (!(sendChunked || request.ContentLength > -1 || no_writestream || webdav))
+                       bool has_content = !no_writestream && (writeBuffer == null || request.ContentLength > -1);
+                       if (!(sendChunked || has_content || no_writestream || webdav))
                                return false;
 
                        headersSent = true;
@@ -664,7 +666,7 @@ namespace System.Net
                                        cnc.EndWrite (request, true, r);
                                        if (!initRead) {
                                                initRead = true;
-                                               WebConnection.InitRead (cnc);
+                                               cnc.InitRead ();
                                        }
                                        var cl = request.ContentLength;
                                        if (!sendChunked && cl == 0)
@@ -673,7 +675,7 @@ namespace System.Net
                                } catch (WebException e) {
                                        result.SetCompleted (false, e);
                                } catch (Exception e) {
-                                       result.SetCompleted (false, new WebException ("Error writing headers", e, WebExceptionStatus.SendFailure));
+                                       result.SetCompleted (false, new WebException ("Error writing headers", WebExceptionStatus.SendFailure, WebExceptionInternalStatus.RequestFatal, e));
                                }
                        }, null);
 
@@ -717,23 +719,23 @@ namespace System.Net
 
                        SetHeadersAsync (true, inner => {
                                if (inner.GotException) {
-                                       result.SetCompleted (inner.CompletedSynchronously, inner.Exception);
+                                       result.SetCompleted (inner.CompletedSynchronouslyPeek, inner.Exception);
                                        return;
                                }
 
                                if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100) {
-                                       result.SetCompleted (inner.CompletedSynchronously);
+                                       result.SetCompleted (inner.CompletedSynchronouslyPeek);
                                        return;
                                }
 
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
 
                                if (length == 0) {
                                        complete_request_written = true;
-                                       result.SetCompleted (inner.CompletedSynchronously);
+                                       result.SetCompleted (inner.CompletedSynchronouslyPeek);
                                        return;
                                }
 
@@ -775,7 +777,9 @@ namespace System.Net
                                if (disposed)
                                        return;
                                disposed = true;
-                               pending.WaitOne ();
+                               if (!pending.WaitOne (WriteTimeout)) {
+                                       throw new WebException ("The operation has timed out.", WebExceptionStatus.Timeout);
+                               }
                                byte [] chunk = Encoding.ASCII.GetBytes ("0\r\n\r\n");
                                string err_msg = null;
                                cnc.Write (request, chunk, 0, chunk.Length, ref err_msg);
@@ -796,7 +800,7 @@ namespace System.Net
                                complete_request_written = true;
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
                                return;
                        }
@@ -810,7 +814,7 @@ namespace System.Net
                                IOException io = new IOException ("Cannot close the stream until all bytes are written");
                                nextReadCalled = true;
                                cnc.Close (true);
-                               throw new WebException ("Request was cancelled.", io, WebExceptionStatus.RequestCanceled);
+                               throw new WebException ("Request was cancelled.", WebExceptionStatus.RequestCanceled, WebExceptionInternalStatus.RequestFatal, io);
                        }
 
                        // Commented out the next line to fix xamarin bug #1512