* roottypes.cs: Rename from tree.cs.
[mono.git] / mcs / class / System / System.Net / WebConnectionStream.cs
index 6cd780e659dc35af9c46c83a04765dc3f7efc9a3..37c9eb4b8406ba486b592b2fe23d8b14f0e35562 100644 (file)
@@ -57,15 +57,23 @@ namespace System.Net
                byte [] headers;
                bool disposed;
                bool headersSent;
-               bool forceCompletion;
+               object locker = new object ();
+               bool initRead;
+               bool read_eof;
+               bool complete_request_written;
+               long max_buffer_size;
 
                public WebConnectionStream (WebConnection cnc)
                {
                        isRead = true;
                        pending = new ManualResetEvent (true);
+                       this.request = cnc.Data.request;
                        this.cnc = cnc;
+                       string contentType = cnc.Data.Headers ["Transfer-Encoding"];
+                       bool chunkedRead = (contentType != null && contentType.ToLower ().IndexOf ("chunked") != -1);
                        string clength = cnc.Data.Headers ["Content-Length"];
-                       if (clength != null && clength != "") {
+                       if (!chunkedRead && clength != null && clength != "") {
+
                                try {
                                        contentLength = Int32.Parse (clength);
                                } catch {
@@ -83,13 +91,21 @@ namespace System.Net
                        this.request = request;
                        allowBuffering = request.InternalAllowBuffering;
                        sendChunked = request.SendChunked;
-                       if (allowBuffering)
+                       if (allowBuffering) {
                                writeBuffer = new MemoryStream ();
+                               max_buffer_size = request.ContentLength;
+                       } else {
+                               max_buffer_size = -1;
+                       }
 
                        if (sendChunked)
                                pending = new ManualResetEvent (true);
                }
 
+               internal bool CompleteRequestWritten {
+                       get { return complete_request_written; }
+               }
+
                internal bool SendChunked {
                        set { sendChunked = value; }
                }
@@ -116,13 +132,14 @@ namespace System.Net
 
                internal void ForceCompletion ()
                {
-                       forceCompletion = true;
+                       nextReadCalled = true;
+                       cnc.NextRead ();
                }
                
                internal void CheckComplete ()
                {
                        bool nrc = nextReadCalled;
-                       if (forceCompletion || (!nrc && readBufferSize - readBufferOffset == contentLength)) {
+                       if (!nrc && readBufferSize - readBufferOffset == contentLength) {
                                nextReadCalled = true;
                                cnc.NextRead ();
                        }
@@ -130,8 +147,8 @@ namespace System.Net
 
                internal void ReadAll ()
                {
-                       if (!isRead || totalRead >= contentLength || nextReadCalled) {
-                               if (!nextReadCalled) {
+                       if (!isRead || read_eof || totalRead >= contentLength || nextReadCalled) {
+                               if (isRead && !nextReadCalled) {
                                        nextReadCalled = true;
                                        cnc.NextRead ();
                                }
@@ -139,7 +156,7 @@ namespace System.Net
                        }
 
                        pending.WaitOne ();
-                       lock (this) {
+                       lock (locker) {
                                if (totalRead >= contentLength)
                                        return;
                                
@@ -194,12 +211,29 @@ namespace System.Net
 
                        cnc.NextRead ();
                }
-               
-               static void CallbackWrapper (IAsyncResult r)
+
+               void WriteCallbackWrapper (IAsyncResult r)
                {
-                       WebAsyncResult result = (WebAsyncResult) r.AsyncState;
-                       result.InnerAsyncResult = r;
-                       result.DoCallback ();
+                       WebAsyncResult result;
+                       if (r.AsyncState != null) {
+                               result = (WebAsyncResult) r.AsyncState;
+                               result.InnerAsyncResult = r;
+                               result.DoCallback ();
+                       } else {
+                               EndWrite (r);
+                       }
+               }
+
+               void ReadCallbackWrapper (IAsyncResult r)
+               {
+                       WebAsyncResult result;
+                       if (r.AsyncState != null) {
+                               result = (WebAsyncResult) r.AsyncState;
+                               result.InnerAsyncResult = r;
+                               result.DoCallback ();
+                       } else {
+                               EndRead (r);
+                       }
                }
 
                public override int Read (byte [] buffer, int offset, int size)
@@ -210,7 +244,14 @@ namespace System.Net
                        if (totalRead >= contentLength)
                                return 0;
 
-                       IAsyncResult res = BeginRead (buffer, offset, size, null, null);
+                       AsyncCallback cb = new AsyncCallback (ReadCallbackWrapper);
+                       WebAsyncResult res = (WebAsyncResult) BeginRead (buffer, offset, size, cb, null);
+                       if (!res.IsCompleted && !res.WaitUntilComplete (request.ReadWriteTimeout, false)) {
+                               nextReadCalled = true;
+                               cnc.Close (true);
+                               throw new IOException ("Read timed out.");
+                       }
+
                        return EndRead (res);
                }
 
@@ -227,7 +268,7 @@ namespace System.Net
                        if (size < 0 || offset < 0 || length < offset || length - offset < size)
                                throw new ArgumentOutOfRangeException ();
 
-                       lock (this) {
+                       lock (locker) {
                                pendingReads++;
                                pending.Reset ();
                        }
@@ -256,32 +297,60 @@ namespace System.Net
                        }
 
                        if (cb != null)
-                               cb = new AsyncCallback (CallbackWrapper);
+                               cb = new AsyncCallback (ReadCallbackWrapper);
 
                        if (contentLength != Int32.MaxValue && contentLength - totalRead < size)
                                size = contentLength - totalRead;
 
-                       result.InnerAsyncResult = cnc.BeginRead (buffer, offset, size, cb, result);
+                       if (!read_eof) {
+                               result.InnerAsyncResult = cnc.BeginRead (buffer, offset, size, cb, result);
+                       } else {
+                               result.SetCompleted (true, result.NBytes);
+                               result.DoCallback ();
+                       }
                        return result;
                }
 
                public override int EndRead (IAsyncResult r)
                {
                        WebAsyncResult result = (WebAsyncResult) r;
+                       if (result.EndCalled) {
+                               int xx = result.NBytes;
+                               return (xx >= 0) ? xx : 0;
+                       }
+
+                       result.EndCalled = true;
 
                        if (!result.IsCompleted) {
-                               int nbytes = cnc.EndRead (result.InnerAsyncResult);
-                               bool finished = (nbytes == -1);
-                               if (finished && result.NBytes > 0)
+                               int nbytes = -1;
+                               try {
+                                       nbytes = cnc.EndRead (result);
+                               } catch (Exception exc) {
+                                       lock (locker) {
+                                               pendingReads--;
+                                               if (pendingReads == 0)
+                                                       pending.Set ();
+                                       }
+
+                                       nextReadCalled = true;
+                                       cnc.Close (true);
+                                       result.SetCompleted (false, exc);
+                                       throw;
+                               }
+
+                               if (nbytes < 0) {
                                        nbytes = 0;
+                                       read_eof = true;
+                               }
 
-                               result.SetCompleted (false, nbytes + result.NBytes);
                                totalRead += nbytes;
-                               if (finished || nbytes == 0)
+                               result.SetCompleted (false, nbytes + result.NBytes);
+                               result.DoCallback ();
+                               if (nbytes == 0)
                                        contentLength = totalRead;
                        }
 
-                       lock (this) {
+                       lock (locker) {
                                pendingReads--;
                                if (pendingReads == 0)
                                        pending.Set ();
@@ -290,7 +359,8 @@ namespace System.Net
                        if (totalRead >= contentLength && !nextReadCalled)
                                ReadAll ();
 
-                       return result.NBytes;
+                       int nb = result.NBytes;
+                       return (nb >= 0) ? nb : 0;
                }
                
                public override IAsyncResult BeginWrite (byte [] buffer, int offset, int size,
@@ -307,7 +377,7 @@ namespace System.Net
                                throw new ArgumentOutOfRangeException ();
 
                        if (sendChunked) {
-                               lock (this) {
+                               lock (locker) {
                                        pendingWrites++;
                                        pending.Reset ();
                                }
@@ -315,6 +385,15 @@ namespace System.Net
 
                        WebAsyncResult result = new WebAsyncResult (cb, state);
                        if (allowBuffering) {
+                               if (max_buffer_size >= 0) {
+                                       long avail = max_buffer_size - writeBuffer.Length;
+                                       if (size > avail) {
+                                               if (requestWritten)
+                                                       throw new ProtocolViolationException (
+                                                       "The number of bytes to be written is greater than " +
+                                                       "the specified ContentLength.");
+                                       }
+                               }
                                writeBuffer.Write (buffer, offset, size);
                                if (!sendChunked) {
                                        result.SetCompleted (true, 0);
@@ -325,7 +404,7 @@ namespace System.Net
 
                        AsyncCallback callback = null;
                        if (cb != null)
-                               callback = new AsyncCallback (CallbackWrapper);
+                               callback = new AsyncCallback (WriteCallbackWrapper);
 
                        if (sendChunked) {
                                WriteRequest ();
@@ -352,19 +431,30 @@ namespace System.Net
                        if (r == null)
                                throw new ArgumentNullException ("r");
 
-                       if (allowBuffering && !sendChunked)
-                               return;
-
                        WebAsyncResult result = r as WebAsyncResult;
                        if (result == null)
                                throw new ArgumentException ("Invalid IAsyncResult");
 
+                       if (result.EndCalled)
+                               return;
+
+                       result.EndCalled = true;
+
+                       if (allowBuffering && !sendChunked)
+                               return;
+
                        if (result.GotException)
                                throw result.Exception;
 
-                       cnc.EndWrite (result.InnerAsyncResult);
+                       try { 
+                               cnc.EndWrite (result.InnerAsyncResult);
+                               result.SetCompleted (false, 0);
+                       } catch (Exception e) {
+                               result.SetCompleted (false, e);
+                       }
+
                        if (sendChunked) {
-                               lock (this) {
+                               lock (locker) {
                                        pendingWrites--;
                                        if (pendingWrites == 0)
                                                pending.Set ();
@@ -377,7 +467,14 @@ namespace System.Net
                        if (isRead)
                                throw new NotSupportedException ("This stream does not allow writing");
 
-                       IAsyncResult res = BeginWrite (buffer, offset, size, null, null);
+                       AsyncCallback cb = new AsyncCallback (WriteCallbackWrapper);
+                       WebAsyncResult res = (WebAsyncResult) BeginWrite (buffer, offset, size, cb, null);
+                       if (!res.IsCompleted && !res.WaitUntilComplete (request.ReadWriteTimeout, false)) {
+                               nextReadCalled = true;
+                               cnc.Close (true);
+                               throw new IOException ("Write timed out.");
+                       }
+
                        EndWrite (res);
                }
 
@@ -392,16 +489,13 @@ namespace System.Net
 
                        if (!allowBuffering || sendChunked) {
                                headersSent = true;
-                               try {
-                                       cnc.Write (buffer, offset, size);
-                               } catch (IOException) {
-                                       if (cnc.Connected)
-                                               throw;
+                               if (!cnc.Connected)
+                                       throw new WebException ("Not connected", null, WebExceptionStatus.SendFailure, null);
 
-                                       if (!cnc.TryReconnect ())
-                                               throw;
-
-                                       cnc.Write (buffer, offset, size);
+                               cnc.Write (buffer, offset, size);
+                               if (!initRead) {
+                                       initRead = true;
+                                       WebConnection.InitRead (cnc);
                                }
                        } else {
                                headers = new byte [size];
@@ -409,6 +503,10 @@ namespace System.Net
                        }
                }
 
+               internal bool RequestWritten {
+                       get { return requestWritten; }
+               }
+
                internal void WriteRequest ()
                {
                        if (requestWritten)
@@ -426,39 +524,49 @@ namespace System.Net
                        byte [] bytes = writeBuffer.GetBuffer ();
                        int length = (int) writeBuffer.Length;
                        if (request.ContentLength != -1 && request.ContentLength < length) {
-                               throw new ProtocolViolationException ("Specified Content-Length is less than the " +
-                                                                     "number of bytes to write");
+                               throw new WebException ("Specified Content-Length is less than the number of bytes to write", null,
+                                                       WebExceptionStatus.ServerProtocolViolation, null);
                        }
 
                        request.InternalContentLength = length;
                        request.SendRequestHeaders ();
                        requestWritten = true;
-                       while (true) {
-                               cnc.Write (headers, 0, headers.Length);
-                               if (!cnc.Connected) {
-                                       if (!cnc.TryReconnect ())
-                                               return;
+                       cnc.Write (headers, 0, headers.Length);
+                       if (!cnc.Connected)
+                               throw new WebException ("Error writing request.", null, WebExceptionStatus.SendFailure, null);
 
-                                       continue;
-                               }
-                               headersSent = true;
-
-                               if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100)
-                                       return;
+                       headersSent = true;
+                       if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100)
+                               return;
 
-                               cnc.Write (bytes, 0, length);
-                               if (!cnc.Connected && cnc.TryReconnect ())
-                                       continue;
+                       IAsyncResult result = null;
+                       if (length > 0)
+                               result = cnc.BeginWrite (bytes, 0, length, null, null);
 
-                               break;
+                       if (!initRead) {
+                               initRead = true;
+                               WebConnection.InitRead (cnc);
                        }
+
+                       if (length > 0) 
+                               complete_request_written = cnc.EndWrite (result);
+                       else
+                               complete_request_written = true;
                }
 
                internal void InternalClose ()
                {
                        disposed = true;
                }
-               
+
+               internal void ForceCloseConnection ()
+               {
+                       if (!disposed) {
+                               disposed = true;
+                               cnc.Close (true);
+                       }
+               }
+
                public override void Close ()
                {
                        if (sendChunked) {
@@ -468,16 +576,34 @@ namespace System.Net
                                return;
                        }
 
-                       if (isRead || !allowBuffering || disposed)
+                       if (isRead) {
+                               if (!nextReadCalled) {
+                                       CheckComplete ();
+                                       // If we have not read all the contents
+                                       if (!nextReadCalled) {
+                                               nextReadCalled = true;
+                                               cnc.Close (true);
+                                       }
+                               }
+                               return;
+                       } else if (!allowBuffering) {
+                               complete_request_written = true;
+                               if (!initRead) {
+                                       initRead = true;
+                                       WebConnection.InitRead (cnc);
+                               }
                                return;
+                       }
 
-                       disposed = true;
+                       if (disposed)
+                               return;
 
                        long length = request.ContentLength;
                        if (length != -1 && length > writeBuffer.Length)
                                throw new IOException ("Cannot close the stream until all bytes are written");
 
                        WriteRequest ();
+                       disposed = true;
                }
 
                public override long Seek (long a, SeekOrigin b)