2005-04-16 Gonzalo Paniagua Javier <gonzalo@ximian.com>
[mono.git] / mcs / class / System / System.Net / WebConnectionStream.cs
index df8d1eb2cb9381f905ee331dd6fcaa420ce82542..70ec0173b2c87caf8824888e64ec657966e4bf4a 100644 (file)
@@ -5,15 +5,39 @@
 //     Gonzalo Paniagua Javier (gonzalo@ximian.com)
 //
 // (C) 2003 Ximian, Inc (http://www.ximian.com)
+// (C) 2004 Novell, Inc (http://www.novell.com)
+//
+
+//
+// Permission is hereby granted, free of charge, to any person obtaining
+// a copy of this software and associated documentation files (the
+// "Software"), to deal in the Software without restriction, including
+// without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to
+// permit persons to whom the Software is furnished to do so, subject to
+// the following conditions:
+// 
+// The above copyright notice and this permission notice shall be
+// included in all copies or substantial portions of the Software.
+// 
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
 
 using System.IO;
+using System.Text;
 using System.Threading;
 
 namespace System.Net
 {
        class WebConnectionStream : Stream
        {
+               static byte [] crlf = new byte [] { 13, 10 };
                bool isRead;
                WebConnection cnc;
                HttpWebRequest request;
@@ -24,21 +48,32 @@ namespace System.Net
                int totalRead;
                bool nextReadCalled;
                int pendingReads;
+               int pendingWrites;
                ManualResetEvent pending;
                bool allowBuffering;
                bool sendChunked;
                MemoryStream writeBuffer;
                bool requestWritten;
                byte [] headers;
+               bool disposed;
+               bool headersSent;
+               object locker = new object ();
+               bool initRead;
 
                public WebConnectionStream (WebConnection cnc)
                {
                        isRead = true;
                        pending = new ManualResetEvent (true);
+                       this.request = cnc.Data.request;
                        this.cnc = cnc;
-                       try {
-                               contentLength = Int32.Parse (cnc.Data.Headers ["Content-Length"]);
-                       } catch {
+                       string clength = cnc.Data.Headers ["Content-Length"];
+                       if (clength != null && clength != "") {
+                               try {
+                                       contentLength = Int32.Parse (clength);
+                               } catch {
+                                       contentLength = Int32.MaxValue;
+                               }
+                       } else {
                                contentLength = Int32.MaxValue;
                        }
                }
@@ -52,6 +87,13 @@ namespace System.Net
                        sendChunked = request.SendChunked;
                        if (allowBuffering)
                                writeBuffer = new MemoryStream ();
+
+                       if (sendChunked)
+                               pending = new ManualResetEvent (true);
+               }
+
+               internal bool SendChunked {
+                       set { sendChunked = value; }
                }
 
                internal byte [] ReadBuffer {
@@ -66,9 +108,24 @@ namespace System.Net
                        set { readBufferSize = value; }
                }
                
+               internal byte[] WriteBuffer {
+                       get { return writeBuffer.GetBuffer (); }
+               }
+
+               internal int WriteBufferLength {
+                       get { return (int) writeBuffer.Length; }
+               }
+
+               internal void ForceCompletion ()
+               {
+                       nextReadCalled = true;
+                       cnc.NextRead ();
+               }
+               
                internal void CheckComplete ()
                {
-                       if (readBufferSize - readBufferOffset == contentLength) {
+                       bool nrc = nextReadCalled;
+                       if (!nrc && readBufferSize - readBufferOffset == contentLength) {
                                nextReadCalled = true;
                                cnc.NextRead ();
                        }
@@ -76,11 +133,16 @@ namespace System.Net
 
                internal void ReadAll ()
                {
-                       if (!isRead || totalRead >= contentLength || nextReadCalled)
+                       if (!isRead || totalRead >= contentLength || nextReadCalled) {
+                               if (!nextReadCalled) {
+                                       nextReadCalled = true;
+                                       cnc.NextRead ();
+                               }
                                return;
+                       }
 
                        pending.WaitOne ();
-                       lock (this) {
+                       lock (locker) {
                                if (totalRead >= contentLength)
                                        return;
                                
@@ -90,21 +152,32 @@ namespace System.Net
 
                                if (contentLength == Int32.MaxValue) {
                                        MemoryStream ms = new MemoryStream ();
-                                       if (readBuffer != null && diff > 0)
+                                       byte [] buffer = null;
+                                       if (readBuffer != null && diff > 0) {
                                                ms.Write (readBuffer, readBufferOffset, diff);
+                                               if (readBufferSize >= 8192)
+                                                       buffer = readBuffer;
+                                       }
+
+                                       if (buffer == null)
+                                               buffer = new byte [8192];
 
-                                       byte [] buffer = new byte [2048];
                                        int read;
-                                       while ((read = cnc.Read (buffer, 0, 2048)) != 0)
+                                       while ((read = cnc.Read (buffer, 0, buffer.Length)) != 0)
                                                ms.Write (buffer, 0, read);
 
                                        b = ms.GetBuffer ();
                                        new_size = (int) ms.Length;
+                                       contentLength = new_size;
                                } else {
                                        new_size = contentLength - totalRead;
                                        b = new byte [new_size];
-                                       if (readBuffer != null && diff > 0)
+                                       if (readBuffer != null && diff > 0) {
+                                               if (diff > new_size)
+                                                       diff = new_size;
+
                                                Buffer.BlockCopy (readBuffer, readBufferOffset, b, 0, diff);
+                                       }
                                        
                                        int remaining = new_size - diff;
                                        int r = -1;
@@ -118,14 +191,37 @@ namespace System.Net
                                readBuffer = b;
                                readBufferOffset = 0;
                                readBufferSize = new_size;
-                               contentLength = new_size;
                                totalRead = 0;
                                nextReadCalled = true;
                        }
 
                        cnc.NextRead ();
                }
-               
+
+               void WriteCallbackWrapper (IAsyncResult r)
+               {
+                       WebAsyncResult result;
+                       if (r.AsyncState != null) {
+                               result = (WebAsyncResult) r.AsyncState;
+                               result.InnerAsyncResult = r;
+                               result.DoCallback ();
+                       } else {
+                               EndWrite (r);
+                       }
+               }
+
+               void ReadCallbackWrapper (IAsyncResult r)
+               {
+                       WebAsyncResult result;
+                       if (r.AsyncState != null) {
+                               result = (WebAsyncResult) r.AsyncState;
+                               result.InnerAsyncResult = r;
+                               result.DoCallback ();
+                       } else {
+                               EndRead (r);
+                       }
+               }
+
                public override int Read (byte [] buffer, int offset, int size)
                {
                        if (!isRead)
@@ -134,7 +230,13 @@ namespace System.Net
                        if (totalRead >= contentLength)
                                return 0;
 
-                       IAsyncResult res = BeginRead (buffer, offset, size, null, null);
+                       AsyncCallback cb = new AsyncCallback (ReadCallbackWrapper);
+                       WebAsyncResult res = (WebAsyncResult) BeginRead (buffer, offset, size, cb, null);
+                       if (!res.WaitUntilComplete (request.ReadWriteTimeout, false)) {
+                               cnc.Close (true);
+                               throw new IOException ("Read timed out.");
+                       }
+
                        return EndRead (res);
                }
 
@@ -151,14 +253,14 @@ namespace System.Net
                        if (size < 0 || offset < 0 || length < offset || length - offset < size)
                                throw new ArgumentOutOfRangeException ();
 
-                       lock (this) {
+                       lock (locker) {
                                pendingReads++;
                                pending.Reset ();
                        }
-                       
+
                        WebAsyncResult result = new WebAsyncResult (cb, state, buffer, offset, size);
                        if (totalRead >= contentLength) {
-                               result.SetCompleted (true, 0);
+                               result.SetCompleted (true, -1);
                                result.DoCallback ();
                                return result;
                        }
@@ -167,10 +269,10 @@ namespace System.Net
                        if (remaining > 0) {
                                int copy = (remaining > size) ? size : remaining;
                                Buffer.BlockCopy (readBuffer, readBufferOffset, buffer, offset, copy);
-                               totalRead += copy;
                                readBufferOffset += copy;
                                offset += copy;
                                size -= copy;
+                               totalRead += copy;
                                if (size == 0 || totalRead >= contentLength) {
                                        result.SetCompleted (true, copy);
                                        result.DoCallback ();
@@ -179,36 +281,50 @@ namespace System.Net
                                result.NBytes = copy;
                        }
 
-                       result.InnerAsyncResult = cnc.BeginRead (buffer, offset, size, null, null);
+                       if (cb != null)
+                               cb = new AsyncCallback (ReadCallbackWrapper);
+
+                       if (contentLength != Int32.MaxValue && contentLength - totalRead < size)
+                               size = contentLength - totalRead;
+
+                       result.InnerAsyncResult = cnc.BeginRead (buffer, offset, size, cb, result);
                        return result;
                }
 
                public override int EndRead (IAsyncResult r)
                {
                        WebAsyncResult result = (WebAsyncResult) r;
+                       if (result.EndCalled) {
+                               int xx = result.NBytes;
+                               return (xx >= 0) ? xx : 0;
+                       }
 
-                       int nbytes = -1;
-                       if (result.IsCompleted) {
-                               nbytes = result.NBytes;
-                       } else {
-                               nbytes = cnc.EndRead (result.InnerAsyncResult);
-                               lock (this) {
-                                       pendingReads--;
-                                       if (pendingReads == 0)
-                                               pending.Set ();
-                               }
+                       result.EndCalled = true;
+
+                       if (!result.IsCompleted) {
+                               int nbytes = cnc.EndRead (result);
+                               bool finished = (nbytes == -1);
+                               if (finished && result.NBytes > 0)
+                                       nbytes = 0;
 
-                               nbytes += result.NBytes; // partially filled from the read buffer
-                               result.SetCompleted (false, nbytes);
                                totalRead += nbytes;
+                               result.SetCompleted (false, nbytes + result.NBytes);
+                               result.DoCallback ();
+                               if (finished || nbytes == 0)
+                                       contentLength = totalRead;
                        }
 
-                       if (totalRead >= contentLength && !nextReadCalled) {
-                               nextReadCalled = true;
-                               cnc.NextRead ();
+                       lock (locker) {
+                               pendingReads--;
+                               if (pendingReads == 0)
+                                       pending.Set ();
                        }
 
-                       return nbytes;
+                       if (totalRead >= contentLength && !nextReadCalled)
+                               ReadAll ();
+
+                       int nb = result.NBytes;
+                       return (nb >= 0) ? nb : 0;
                }
                
                public override IAsyncResult BeginWrite (byte [] buffer, int offset, int size,
@@ -224,17 +340,44 @@ namespace System.Net
                        if (size < 0 || offset < 0 || length < offset || length - offset < size)
                                throw new ArgumentOutOfRangeException ();
 
+                       if (sendChunked) {
+                               lock (locker) {
+                                       pendingWrites++;
+                                       pending.Reset ();
+                               }
+                       }
+
                        WebAsyncResult result = new WebAsyncResult (cb, state);
                        if (allowBuffering) {
                                writeBuffer.Write (buffer, offset, size);
-                               result.SetCompleted (true, 0);
-                               result.DoCallback ();
-                       } else {
-                               result.InnerAsyncResult = cnc.BeginWrite (buffer, offset, size, cb, state);
-                               if (result.InnerAsyncResult == null)
-                                       throw new WebException ("Aborted");
+                               if (!sendChunked) {
+                                       result.SetCompleted (true, 0);
+                                       result.DoCallback ();
+                                       return result;
+                               }
                        }
 
+                       AsyncCallback callback = null;
+                       if (cb != null)
+                               callback = new AsyncCallback (WriteCallbackWrapper);
+
+                       if (sendChunked) {
+                               WriteRequest ();
+
+                               string cSize = String.Format ("{0:X}\r\n", size);
+                               byte [] head = Encoding.ASCII.GetBytes (cSize);
+                               int chunkSize = 2 + size + head.Length;
+                               byte [] newBuffer = new byte [chunkSize];
+                               Buffer.BlockCopy (head, 0, newBuffer, 0, head.Length);
+                               Buffer.BlockCopy (buffer, offset, newBuffer, head.Length, size);
+                               Buffer.BlockCopy (crlf, 0, newBuffer, head.Length + size, crlf.Length);
+
+                               buffer = newBuffer;
+                               offset = 0;
+                               size = chunkSize;
+                       }
+
+                       result.InnerAsyncResult = cnc.BeginWrite (buffer, offset, size, callback, result);
                        return result;
                }
 
@@ -243,23 +386,44 @@ namespace System.Net
                        if (r == null)
                                throw new ArgumentNullException ("r");
 
-                       if (allowBuffering)
-                               return;
-
                        WebAsyncResult result = r as WebAsyncResult;
                        if (result == null)
                                throw new ArgumentException ("Invalid IAsyncResult");
 
+                       if (result.EndCalled)
+                               return;
+
+                       result.EndCalled = true;
+
+                       if (allowBuffering && !sendChunked)
+                               return;
+
+                       if (result.GotException)
+                               throw result.Exception;
+
                        cnc.EndWrite (result.InnerAsyncResult);
-                       return;
+                       result.SetCompleted (false, 0);
+                       if (sendChunked) {
+                               lock (locker) {
+                                       pendingWrites--;
+                                       if (pendingWrites == 0)
+                                               pending.Set ();
+                               }
+                       }
                }
                
                public override void Write (byte [] buffer, int offset, int size)
                {
                        if (isRead)
-                               throw new NotSupportedException ("this stream does not allow writing");
+                               throw new NotSupportedException ("This stream does not allow writing");
+
+                       AsyncCallback cb = new AsyncCallback (WriteCallbackWrapper);
+                       WebAsyncResult res = (WebAsyncResult) BeginWrite (buffer, offset, size, cb, null);
+                       if (!res.WaitUntilComplete (request.ReadWriteTimeout, false)) {
+                               cnc.Close (true);
+                               throw new IOException ("Write timed out.");
+                       }
 
-                       IAsyncResult res = BeginWrite (buffer, offset, size, null, null);
                        EndWrite (res);
                }
 
@@ -269,59 +433,110 @@ namespace System.Net
 
                internal void SetHeaders (byte [] buffer, int offset, int size)
                {
-                       if (!allowBuffering) {
-                               Write (buffer, offset, size);
+                       if (headersSent)
+                               return;
+
+                       if (!allowBuffering || sendChunked) {
+                               headersSent = true;
+                               if (!cnc.Connected)
+                                       throw new WebException ("Not connected", null, WebExceptionStatus.SendFailure, null);
+
+                               cnc.Write (buffer, offset, size);
+                               if (!initRead) {
+                                       initRead = true;
+                                       WebConnection.InitRead (cnc);
+                               }
                        } else {
                                headers = new byte [size];
                                Buffer.BlockCopy (buffer, offset, headers, 0, size);
                        }
                }
 
+               internal bool RequestWritten {
+                       get { return requestWritten; }
+               }
+
                internal void WriteRequest ()
                {
-                       if (!allowBuffering || writeBuffer == null || requestWritten)
+                       if (requestWritten)
+                               return;
+
+                       if (sendChunked) {
+                               request.SendRequestHeaders ();
+                               requestWritten = true;
+                               return;
+                       }
+
+                       if (!allowBuffering || writeBuffer == null)
                                return;
 
                        byte [] bytes = writeBuffer.GetBuffer ();
                        int length = (int) writeBuffer.Length;
                        if (request.ContentLength != -1 && request.ContentLength < length) {
-                               throw new ProtocolViolationException ("Specified Content-Length is less than the " +
-                                                                     "number of bytes to write");
+                               throw new WebException ("Specified Content-Length is less than the number of bytes to write", null,
+                                                       WebExceptionStatus.ServerProtocolViolation, null);
                        }
 
                        request.InternalContentLength = length;
                        request.SendRequestHeaders ();
-                       cnc.WaitForContinue (headers, 0, headers.Length);
+                       requestWritten = true;
+                       cnc.Write (headers, 0, headers.Length);
+                       if (!cnc.Connected)
+                               throw new WebException ("Error writing request.", null, WebExceptionStatus.SendFailure, null);
+
+                       headersSent = true;
                        if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100)
                                return;
 
                        cnc.Write (bytes, 0, length);
-                       requestWritten = true;
-                       cnc.dataAvailable.Set ();
                }
 
+               internal void InternalClose ()
+               {
+                       disposed = true;
+               }
+               
                public override void Close ()
                {
-                       if (!allowBuffering)
+                       if (sendChunked) {
+                               pending.WaitOne ();
+                               byte [] chunk = Encoding.ASCII.GetBytes ("0\r\n\r\n");
+                               cnc.Write (chunk, 0, chunk.Length);
+                               return;
+                       }
+
+                       if (isRead) {
+                               if (!nextReadCalled) {
+                                       CheckComplete ();
+                                       // If we have not read all the contents
+                                       if (!nextReadCalled)
+                                               cnc.Close (true);
+                               }
+                               return;
+                       } else if (!allowBuffering) {
+                               if (!initRead) {
+                                       initRead = true;
+                                       WebConnection.InitRead (cnc);
+                               }
+                               return;
+                       }
+
+                       if (disposed)
                                return;
 
-                       // may be ReadAll is isRead?
+                       disposed = true;
+
                        long length = request.ContentLength;
                        if (length != -1 && length > writeBuffer.Length)
                                throw new IOException ("Cannot close the stream until all bytes are written");
 
                        WriteRequest ();
+                       if (!initRead) {
+                               initRead = true;
+                               WebConnection.InitRead (cnc);
+                       }
                }
 
-               internal void ResetWriteBuffer ()
-               {
-                       if (!allowBuffering)
-                               return;
-
-                       writeBuffer = new MemoryStream ();
-                       requestWritten = false;
-               }
-               
                public override long Seek (long a, SeekOrigin b)
                {
                        throw new NotSupportedException ();
@@ -337,7 +552,7 @@ namespace System.Net
                }
 
                public override bool CanRead {
-                       get { return isRead && (contentLength == Int32.MaxValue || totalRead < contentLength); }
+                       get { return isRead; }
                }
 
                public override bool CanWrite {