Removed Consoles and ^Ms
[mono.git] / mcs / class / System / System.Net / WebConnection.cs
index fa76308c4b2021bab192bd925d71c67f4d002191..d5a2920a093842b2361d8d9ce50790a4c06874ed 100644 (file)
@@ -7,8 +7,11 @@
 // (C) 2003 Ximian, Inc (http://www.ximian.com)
 //
 
+using System.IO;
 using System.Collections;
 using System.Net.Sockets;
+using System.Reflection;
+using System.Security.Cryptography.X509Certificates;
 using System.Text;
 using System.Threading;
 
@@ -25,64 +28,115 @@ namespace System.Net
        class WebConnection
        {
                ServicePoint sPoint;
-               NetworkStream nstream;
+               Stream nstream;
                Socket socket;
                WebExceptionStatus status;
                WebConnectionGroup group;
                bool busy;
-               ArrayList queue;
                WaitOrTimerCallback initConn;
-               internal ManualResetEvent dataAvailable;
                bool keepAlive;
-               bool aborted;
                byte [] buffer;
-               internal static AsyncCallback readDoneDelegate = new AsyncCallback (ReadDone);
+               static AsyncCallback readDoneDelegate = new AsyncCallback (ReadDone);
                EventHandler abortHandler;
                ReadState readState;
                internal WebConnectionData Data;
                WebConnectionStream prevStream;
                bool chunkedRead;
                ChunkStream chunkStream;
-               AutoResetEvent waitForContinue;
-               bool waitingForContinue;
-               
+               AutoResetEvent goAhead;
+               Queue queue;
+               bool reused;
+               int position;
+
+               bool ssl;
+               bool certsAvailable;
+               static bool sslCheck;
+               static Type sslStream;
+               static PropertyInfo piClient;
+               static PropertyInfo piServer;
+
                public WebConnection (WebConnectionGroup group, ServicePoint sPoint)
                {
                        this.group = group;
                        this.sPoint = sPoint;
-                       queue = new ArrayList (1);
-                       dataAvailable = new ManualResetEvent (true);
                        buffer = new byte [4096];
                        readState = ReadState.None;
                        Data = new WebConnectionData ();
                        initConn = new WaitOrTimerCallback (InitConnection);
                        abortHandler = new EventHandler (Abort);
+                       goAhead = new AutoResetEvent (true);
+                       queue = group.Queue;
                }
 
                public void Connect ()
                {
-                       if (socket != null && socket.Connected && status == WebExceptionStatus.Success)
-                               return;
-
                        lock (this) {
-                               if (socket != null && socket.Connected && status == WebExceptionStatus.Success)
+                               if (socket != null && socket.Connected && status == WebExceptionStatus.Success) {
+                                       reused = true;
                                        return;
+                               }
 
+                               reused = false;
+                               if (socket != null) {
+                                       socket.Close();
+                                       socket = null;
+                               }
                                
-                               if (socket == null)
-                                       socket = new Socket (AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.IP);
-
-                               status = sPoint.Connect (socket);
                                chunkStream = null;
+                               IPHostEntry hostEntry = sPoint.HostEntry;
+
+                               if (hostEntry == null) {
+                                       status = sPoint.UsesProxy ? WebExceptionStatus.ProxyNameResolutionFailure :
+                                                                   WebExceptionStatus.NameResolutionFailure;
+                                       return;
+                               }
+
+                               foreach (IPAddress address in hostEntry.AddressList) {
+                                       socket = new Socket (address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+                                       try {
+                                               socket.Connect (new IPEndPoint(address, sPoint.Address.Port));
+                                               status = WebExceptionStatus.Success;
+                                               break;
+                                       } catch (SocketException) {
+                                               socket.Close();
+                                               socket = null;
+                                               status = WebExceptionStatus.ConnectFailure;
+                                       }
+                               }
                        }
                }
 
                bool CreateStream (HttpWebRequest request)
                {
-                       //TODO: create stream for https
                        try {
-                               nstream = new NetworkStream (socket, false);
-                       } catch (Exception e) {
+                               NetworkStream serverStream = new NetworkStream (socket, false);
+                               if (request.RequestUri.Scheme == Uri.UriSchemeHttps) {
+                                       ssl = true;
+                                       if (!sslCheck) {
+                                               lock (typeof (WebConnection)) {
+                                                       sslCheck = true;
+                                                       // HttpsClientStream is an internal glue class in Mono.Security.dll
+                                                       sslStream = Type.GetType ("Mono.Security.Protocol.Tls.HttpsClientStream, " + Consts.AssemblyMono_Security, false);
+                                                       if (sslStream != null) {
+                                                               piClient = sslStream.GetProperty ("SelectedClientCertificate");
+                                                               piServer = sslStream.GetProperty ("ServerCertificate");
+                                                       }
+                                               }
+                                       }
+                                       if (sslStream == null)
+                                               throw new NotSupportedException ("Missing Mono.Security.dll assembly. Support for SSL/TLS is unavailable.");
+
+                                       object[] args = new object [4] { serverStream, request.RequestUri.Host, request.ClientCertificates, request };
+                                       nstream = (Stream) Activator.CreateInstance (sslStream, args);
+
+                                       // we also need to set ServicePoint.Certificate 
+                                       // and ServicePoint.ClientCertificate but this can
+                                       // only be done later (after handshake - which is
+                                       // done only after a read operation).
+                               }
+                               else
+                                       nstream = serverStream;
+                       } catch (Exception) {
                                status = WebExceptionStatus.ConnectFailure;
                                return false;
                        }
@@ -93,10 +147,17 @@ namespace System.Net
                void HandleError (WebExceptionStatus st, Exception e)
                {
                        status = st;
-                       Close ();
+                       lock (this) {
+                               busy = false;
+                               if (st == WebExceptionStatus.RequestCanceled)
+                                       Data = new WebConnectionData ();
+
+                               status = st;
+                       }
+
                        if (e == null) { // At least we now where it comes from
                                try {
-                                       throw new Exception ();
+                                       throw new Exception (new System.Diagnostics.StackTrace ().ToString ());
                                } catch (Exception e2) {
                                        e = e2;
                                }
@@ -104,98 +165,71 @@ namespace System.Net
 
                        if (Data != null && Data.request != null)
                                Data.request.SetResponseError (st, e);
-               }
-               
-               internal bool WaitForContinue (byte [] headers, int offset, int size)
-               {
-                       Data.StatusCode = 0;
-                       waitingForContinue = sPoint.SendContinue;
-                       if (waitingForContinue && waitForContinue == null)
-                               waitForContinue = new AutoResetEvent (false);
-
-                       Write (headers, offset, size);
-                       if (!waitingForContinue)
-                               return false;
-
-                       bool result = waitForContinue.WaitOne (2000, false);
-                       waitingForContinue = false;
-                       if (result) {
-                               sPoint.SendContinue = true;
-                               if (Data.request.ExpectContinue)
-                                       Data.request.DoContinueDelegate (Data.StatusCode, Data.Headers);
-                       } else {
-                               sPoint.SendContinue = false;
-                       }
 
-                       return result;
+                       Close (true);
                }
                
                static void ReadDone (IAsyncResult result)
                {
                        WebConnection cnc = (WebConnection) result.AsyncState;
                        WebConnectionData data = cnc.Data;
-                       NetworkStream ns = cnc.nstream;
-                       if (ns == null)
+                       Stream ns = cnc.nstream;
+                       if (ns == null) {
+                               cnc.Close (true);
                                return;
+                       }
 
                        int nread = -1;
-                       cnc.dataAvailable.Reset ();
                        try {
                                nread = ns.EndRead (result);
                        } catch (Exception e) {
                                cnc.status = WebExceptionStatus.ReceiveFailure;
                                cnc.HandleError (cnc.status, e);
-                               cnc.dataAvailable.Set ();
                                return;
                        }
 
                        if (nread == 0) {
-                               Console.WriteLine ("nread == 0: may be the connection was closed?");
-                               cnc.dataAvailable.Set ();
+                               cnc.status = WebExceptionStatus.ReceiveFailure;
+                               cnc.HandleError (cnc.status, null);
                                return;
                        }
 
                        if (nread < 0) {
                                cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, null);
-                               cnc.dataAvailable.Set ();
                                return;
                        }
 
-                       //Console.WriteLine (System.Text.Encoding.Default.GetString (cnc.buffer, 0, nread));
+                       //Console.WriteLine (System.Text.Encoding.Default.GetString (cnc.buffer, 0, nread + cnc.position));
                        int pos = -1;
+                       nread += cnc.position;
                        if (cnc.readState == ReadState.None) { 
                                Exception exc = null;
                                try {
                                        pos = cnc.GetResponse (cnc.buffer, nread);
-                                       if (data.StatusCode == 100) {
-                                               cnc.readState = ReadState.None;
-                                               InitRead (cnc);
-                                               cnc.sPoint.SendContinue = true;
-                                               if (cnc.waitingForContinue) {
-                                                       cnc.waitForContinue.Set ();
-                                               } else if (data.request.ExpectContinue) { // We get a 100 after waiting for it.
-                                                       data.request.DoContinueDelegate (data.StatusCode, data.Headers);
-                                               }
-
-                                               return;
-                                       }
                                } catch (Exception e) {
                                        exc = e;
                                }
 
-                               if (pos == -1 || exc != null) {
+                               if (exc != null) {
                                        cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, exc);
-                                       cnc.dataAvailable.Set ();
                                        return;
                                }
                        }
 
                        if (cnc.readState != ReadState.Content) {
-                               cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, null);
-                               cnc.dataAvailable.Set ();
+                               int est = nread * 2;
+                               int max = (est < cnc.buffer.Length) ? cnc.buffer.Length : est;
+                               byte [] newBuffer = new byte [max];
+                               Buffer.BlockCopy (cnc.buffer, 0, newBuffer, 0, nread);
+                               cnc.buffer = newBuffer;
+                               cnc.position = nread;
+                               cnc.readState = ReadState.None;
+                               InitRead (cnc);
                                return;
                        }
 
+                       cnc.position = 0;
+
                        WebConnectionStream stream = new WebConnectionStream (cnc);
 
                        string contentType = data.Headers ["Transfer-Encoding"];
@@ -211,22 +245,41 @@ namespace System.Net
                                cnc.chunkStream.Write (cnc.buffer, pos, nread);
                        }
 
-                       cnc.prevStream = stream;
                        data.stream = stream;
+                       
+                       lock (cnc) {
+                               lock (cnc.queue) {
+                                       if (cnc.queue.Count > 0) {
+                                               stream.ReadAll ();
+                                       } else {
+                                               cnc.prevStream = stream;
+                                               stream.CheckComplete ();
+                                       }
+                               }
+                       }
+                       
                        data.request.SetResponseData (data);
-                       stream.CheckComplete ();
+               }
+
+               internal void GetCertificates () 
+               {
+                       // here the SSL negotiation have been done
+                       X509Certificate client = (X509Certificate) piClient.GetValue (nstream, null);
+                       X509Certificate server = (X509Certificate) piServer.GetValue (nstream, null);
+                       sPoint.SetCertificates (client, server);
+                       certsAvailable = (server != null);
                }
                
                static void InitRead (object state)
                {
                        WebConnection cnc = (WebConnection) state;
-                       NetworkStream ns = cnc.nstream;
-                       
+                       Stream ns = cnc.nstream;
+
                        try {
-                               ns.BeginRead (cnc.buffer, 0, cnc.buffer.Length, readDoneDelegate, cnc);
+                               int size = cnc.buffer.Length - cnc.position;
+                               ns.BeginRead (cnc.buffer, cnc.position, size, readDoneDelegate, cnc);
                        } catch (Exception e) {
                                cnc.HandleError (WebExceptionStatus.ReceiveFailure, e);
-                               cnc.dataAvailable.Set ();
                        }
                }
                
@@ -235,67 +288,91 @@ namespace System.Net
                        int pos = 0;
                        string line = null;
                        bool lineok = false;
-                       
-                       if (readState == ReadState.None) {
-                               lineok = ReadLine (buffer, ref pos, max, ref line);
-                               if (!lineok)
-                                       return -1;
-
-                               readState = ReadState.Status;
+                       bool isContinue = false;
+                       do {
+                               if (readState == ReadState.None) {
+                                       lineok = ReadLine (buffer, ref pos, max, ref line);
+                                       if (!lineok)
+                                               return -1;
+
+                                       readState = ReadState.Status;
+
+                                       string [] parts = line.Split (' ');
+                                       if (parts.Length < 2)
+                                               return -1;
+
+                                       if (String.Compare (parts [0], "HTTP/1.1", true) == 0) {
+                                               Data.Version = HttpVersion.Version11;
+                                               sPoint.SetVersion (HttpVersion.Version11);
+                                       } else {
+                                               Data.Version = HttpVersion.Version10;
+                                               sPoint.SetVersion (HttpVersion.Version10);
+                                       }
 
-                               string [] parts = line.Split (' ');
-                               if (parts.Length < 3)
-                                       return -1;
+                                       Data.StatusCode = (int) UInt32.Parse (parts [1]);
+                                       if (parts.Length >= 3)
+                                               Data.StatusDescription = String.Join (" ", parts, 2, parts.Length - 2);
+                                       else
+                                               Data.StatusDescription = "";
 
-                               if (String.Compare (parts [0], "HTTP/1.1", true) == 0) {
-                                       Data.Version = HttpVersion.Version11;
-                               } else {
-                                       Data.Version = HttpVersion.Version10;
+                                       if (pos >= max)
+                                               return pos;
                                }
 
-                               Data.StatusCode = (int) UInt32.Parse (parts [1]);
-                               Data.StatusDescription = String.Join (" ", parts, 2, parts.Length - 2);
-                               if (pos >= max)
-                                       return pos;
-                       }
-
-                       if (readState == ReadState.Status) {
-                               readState = ReadState.Headers;
-                               Data.Headers = new WebHeaderCollection ();
-                               ArrayList headers = new ArrayList ();
-                               bool finished = false;
-                               while (!finished) {
-                                       if (ReadLine (buffer, ref pos, max, ref line) == false)
-                                               break;
+                               if (readState == ReadState.Status) {
+                                       readState = ReadState.Headers;
+                                       Data.Headers = new WebHeaderCollection ();
+                                       ArrayList headers = new ArrayList ();
+                                       bool finished = false;
+                                       while (!finished) {
+                                               if (ReadLine (buffer, ref pos, max, ref line) == false)
+                                                       break;
                                        
-                                       if (line == null) {
-                                               // Empty line: end of headers
-                                               finished = true;
-                                               continue;
-                                       }
+                                               if (line == null) {
+                                                       // Empty line: end of headers
+                                                       finished = true;
+                                                       continue;
+                                               }
                                        
-                                       if (line.Length > 0 && (line [0] == ' ' || line [0] == '\t')) {
-                                               int count = headers.Count - 1;
-                                               if (count < 0)
-                                                       break;
-
-                                               string prev = (string) headers [count] + line;
-                                               headers [count] = prev;
-                                       } else {
-                                               headers.Add (line);
+                                               if (line.Length > 0 && (line [0] == ' ' || line [0] == '\t')) {
+                                                       int count = headers.Count - 1;
+                                                       if (count < 0)
+                                                               break;
+
+                                                       string prev = (string) headers [count] + line;
+                                                       headers [count] = prev;
+                                               } else {
+                                                       headers.Add (line);
+                                               }
                                        }
-                               }
 
-                               if (!finished) {
-                                       // handle the error...
-                               } else {
+                                       if (!finished)
+                                               return -1;
+
                                        foreach (string s in headers)
-                                               Data.Headers.Add (s);
+                                               Data.Headers.SetInternal (s);
+
+                                       if (Data.StatusCode == (int) HttpStatusCode.Continue) {
+                                               sPoint.SendContinue = true;
+                                               if (pos >= max)
+                                                       return pos;
+
+                                               if (Data.request.ExpectContinue) {
+                                                       Data.request.DoContinueDelegate (Data.StatusCode, Data.Headers);
+                                                       // Prevent double calls when getting the
+                                                       // headers in several packets.
+                                                       Data.request.ExpectContinue = false;
+                                               }
 
-                                       readState = ReadState.Content;
-                                       return pos;
+                                               readState = ReadState.None;
+                                               isContinue = true;
+                                       }
+                                       else {
+                                               readState = ReadState.Content;
+                                               return pos;
+                                       }
                                }
-                       }
+                       } while (isContinue == true);
 
                        return -1;
                }
@@ -303,22 +380,29 @@ namespace System.Net
                void InitConnection (object state, bool notUsed)
                {
                        HttpWebRequest request = (HttpWebRequest) state;
-                       if (aborted) {
-                               status = WebExceptionStatus.RequestCanceled;
-                               request.SetWriteStreamError (status);
+
+                       if (status == WebExceptionStatus.RequestCanceled) {
+                               busy = false;
+                               Data = new WebConnectionData ();
+                               goAhead.Set ();
+                               SendNext ();
                                return;
                        }
 
+                       keepAlive = request.KeepAlive;
+                       Data = new WebConnectionData ();
+                       Data.request = request;
+
                        Connect ();
                        if (status != WebExceptionStatus.Success) {
                                request.SetWriteStreamError (status);
-                               Close ();
+                               Close (true);
                                return;
                        }
                        
                        if (!CreateStream (request)) {
                                request.SetWriteStreamError (status);
-                               Close ();
+                               Close (true);
                                return;
                        }
 
@@ -327,65 +411,62 @@ namespace System.Net
                        InitRead (this);
                }
                
-               void BeginRequest (HttpWebRequest request)
-               {
-                       lock (this) {
-                               keepAlive = request.KeepAlive;
-                               Data.Init ();
-                               Data.request = request;
-                       }
-
-                       ThreadPool.RegisterWaitForSingleObject (dataAvailable, initConn, request, -1, true);
-               }
-
                internal EventHandler SendRequest (HttpWebRequest request)
                {
-                       Monitor.Enter (this);
-
-                       if (prevStream != null && socket != null && socket.Connected) {
-                               prevStream.ReadAll ();
-                               prevStream = null;
-                       }
+                       lock (this) {
+                               if (prevStream != null && socket != null && socket.Connected) {
+                                       prevStream.ReadAll ();
+                                       prevStream = null;
+                               }
 
-                       if (!busy) {
-                               busy = true;
-                               Monitor.Exit (this);
-                               BeginRequest (request);
-                       } else {
-                               queue.Add (request);
-                               Monitor.Exit (this);
+                               if (!busy) {
+                                       busy = true;
+                                       ThreadPool.RegisterWaitForSingleObject (goAhead, initConn,
+                                                                               request, -1, true);
+                               } else {
+                                       lock (queue) {
+                                               queue.Enqueue (request);
+                                       }
+                               }
                        }
 
                        return abortHandler;
                }
                
-               internal void NextRead ()
+               void SendNext ()
                {
-                       Monitor.Enter (this);
-                       string header = (sPoint.UsesProxy) ? "Proxy-Connection" : "Connection";
-                       string cncHeader = (Data.Headers != null) ? Data.Headers [header] : null;
-                       bool keepAlive = this.keepAlive;
-                       if (cncHeader != null) {
-                               cncHeader = cncHeader.ToLower ();
-                               keepAlive = (keepAlive && cncHeader.IndexOf ("keep-alive") != -1);
+                       lock (queue) {
+                               if (queue.Count > 0) {
+                                       prevStream = null;
+                                       SendRequest ((HttpWebRequest) queue.Dequeue ());
+                               }
                        }
+               }
 
-                       if ((socket != null && !socket.Connected) ||
-                          (!keepAlive || (cncHeader != null && cncHeader.IndexOf ("close") != -1))) {
-                               Console.WriteLine ("CLosing");
-                               Close ();
-                       }
+               internal void NextRead ()
+               {
+                       lock (this) {
+                               busy = false;
+                               string header = (sPoint.UsesProxy) ? "Proxy-Connection" : "Connection";
+                               string cncHeader = (Data.Headers != null) ? Data.Headers [header] : null;
+                               bool keepAlive = (Data.Version == HttpVersion.Version11);
+                               if (cncHeader != null) {
+                                       cncHeader = cncHeader.ToLower ();
+                                       keepAlive = (this.keepAlive && cncHeader.IndexOf ("keep-alive") != -1);
+                               }
 
-                       busy = false;
-                       dataAvailable.Set ();
+                               if ((socket != null && !socket.Connected) ||
+                                  (!keepAlive || (cncHeader != null && cncHeader.IndexOf ("close") != -1))) {
+                                       Close (false);
+                               }
 
-                       if (queue.Count > 0) {
-                               HttpWebRequest request = (HttpWebRequest) queue [0];
-                               queue.RemoveAt (0);
-                               Monitor.Exit (this);
-                               SendRequest (request);
-                       } else {
-                               Monitor.Exit (this);
+                               goAhead.Set ();
+                               lock (queue) {
+                                       if (queue.Count > 0) {
+                                               prevStream = null;
+                                               SendRequest ((HttpWebRequest) queue.Dequeue ());
+                                       }
+                               }
                        }
                }
                
@@ -440,7 +521,7 @@ namespace System.Net
                        if (!chunkedRead || chunkStream.WantMore) {
                                try {
                                        result = nstream.BeginRead (buffer, offset, size, cb, state);
-                               } catch (Exception e) {
+                               } catch (Exception) {
                                        status = WebExceptionStatus.ReceiveFailure;
                                        throw;
                                }
@@ -467,6 +548,10 @@ namespace System.Net
                                        nbytes = nstream.EndRead (wr.InnerAsyncResult);
 
                                chunkStream.WriteAndReadBack (wr.Buffer, wr.Offset, wr.Size, ref nbytes);
+                               if (nbytes == 0 && chunkStream.WantMore) {
+                                       nbytes = nstream.Read (wr.Buffer, wr.Offset, wr.Size);          
+                                       chunkStream.WriteAndReadBack (wr.Buffer, wr.Offset, wr.Size, ref nbytes);
+                               }
                                return nbytes;
                        }
 
@@ -481,7 +566,7 @@ namespace System.Net
 
                        try {
                                result = nstream.BeginWrite (buffer, offset, size, cb, state);
-                       } catch (Exception e) {
+                       } catch (Exception) {
                                status = WebExceptionStatus.SendFailure;
                                throw;
                        }
@@ -522,15 +607,42 @@ namespace System.Net
 
                        try {
                                nstream.Write (buffer, offset, size);
-                       } catch (Exception e) {
-                               status = WebExceptionStatus.SendFailure;
-                               HandleError (status, e);
+                               // here SSL handshake should have been done
+                               if (ssl && !certsAvailable) {
+                                       GetCertificates ();
+                               }
+                       } catch (Exception) {
                        }
                }
 
-               void Close ()
+               internal bool TryReconnect ()
                {
                        lock (this) {
+                               if (!reused) {
+                                       HandleError (WebExceptionStatus.SendFailure, null);
+                                       return false;
+                               }
+
+                               Close (false);
+                               reused = false;
+                               Connect ();
+                               if (status != WebExceptionStatus.Success) {
+                                       HandleError (WebExceptionStatus.SendFailure, null);
+                                       return false;
+                               }
+                       
+                               if (!CreateStream (Data.request)) {
+                                       HandleError (WebExceptionStatus.SendFailure, null);
+                                       return false;
+                               }
+                       }
+                       return true;
+               }
+
+               void Close (bool sendNext)
+               {
+                       lock (this) {
+                               busy = false;
                                if (nstream != null) {
                                        try {
                                                nstream.Close ();
@@ -544,6 +656,11 @@ namespace System.Net
                                        } catch {}
                                        socket = null;
                                }
+
+                               if (sendNext) {
+                                       goAhead.Set ();
+                                       SendNext ();
+                               }
                        }
                }
 
@@ -552,9 +669,21 @@ namespace System.Net
                        HandleError (WebExceptionStatus.RequestCanceled, null);
                }
 
+               internal bool Busy {
+                       get { lock (this) return busy; }
+               }
+               
+               internal bool Connected {
+                       get {
+                               lock (this) {
+                                       return (socket != null && socket.Connected);
+                               }
+                       }
+               }
+               
                ~WebConnection ()
                {
-                       Close ();
+                       Close (false);
                }
        }
 }