Remove duplicated files from system.dll (saves about 100kb)
[mono.git] / mcs / class / System / System.Net / HttpConnection.cs
index c776105f155702047b9e3731c60a6d2cdd47350d..a34f19024c80761d1d16660ed488d4b62bd8ee33 100644 (file)
@@ -2,9 +2,10 @@
 // System.Net.HttpConnection
 //
 // Author:
-//     Gonzalo Paniagua Javier (gonzalo@novell.com)
+//     Gonzalo Paniagua Javier (gonzalo.mono@gmail.com)
 //
-// Copyright (c) 2005 Novell, Inc. (http://www.novell.com)
+// Copyright (c) 2005-2009 Novell, Inc. (http://www.novell.com)
+// Copyright (c) 2012 Xamarin, Inc. (http://xamarin.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
-#if NET_2_0
+
+#if SECURITY_DEP
+
+extern alias MonoSecurity;
+
 using System.IO;
 using System.Net.Sockets;
+using System.Reflection;
 using System.Text;
+using System.Threading;
+using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
+using MonoSecurity::Mono.Security.Protocol.Tls;
+
 namespace System.Net {
        sealed class HttpConnection
        {
+               static AsyncCallback onread_cb = new AsyncCallback (OnRead);
                const int BufferSize = 8192;
                Socket sock;
-               NetworkStream stream;
+               Stream stream;
                EndPointListener epl;
                MemoryStream ms;
                byte [] buffer;
                HttpListenerContext context;
-               bool secure;
                StringBuilder current_line;
                ListenerPrefix prefix;
                RequestStream i_stream;
                ResponseStream o_stream;
                bool chunked;
-               int chunked_uses;
+               int reuses;
                bool context_bound;
-
-               public HttpConnection (Socket sock, EndPointListener epl, bool secure)
+               bool secure;
+               AsymmetricAlgorithm key;
+               int s_timeout = 90000; // 90k ms for first request, 15k ms from then on
+               Timer timer;
+               IPEndPoint local_ep;
+               HttpListener last_listener;
+               int [] client_cert_errors;
+               X509Certificate2 client_cert;
+
+               public HttpConnection (Socket sock, EndPointListener epl, bool secure, X509Certificate2 cert, AsymmetricAlgorithm key)
                {
                        this.sock = sock;
-                       stream = new NetworkStream (sock, false);
                        this.epl = epl;
                        this.secure = secure;
+                       this.key = key;
+                       if (secure == false) {
+                               stream = new NetworkStream (sock, false);
+                       } else {
+                               SslServerStream ssl_stream = new SslServerStream (new NetworkStream (sock, false), cert, false, true, false);
+                               ssl_stream.PrivateKeyCertSelectionDelegate += OnPVKSelection;
+                               ssl_stream.ClientCertValidationDelegate += OnClientCertificateValidation;
+                               stream = ssl_stream;
+                       }
+                       timer = new Timer (OnTimeout, null, Timeout.Infinite, Timeout.Infinite);
                        Init ();
                }
 
+               internal int [] ClientCertificateErrors {
+                       get { return client_cert_errors; }
+               }
+
+               internal X509Certificate2 ClientCertificate {
+                       get { return client_cert; }
+               }
+
+               bool OnClientCertificateValidation (X509Certificate certificate, int[] errors)
+               {
+                       if (certificate == null)
+                               return true;
+                       X509Certificate2 cert = certificate as X509Certificate2;
+                       if (cert == null)
+                               cert = new X509Certificate2 (certificate.GetRawCertData ());
+                       client_cert = cert;
+                       client_cert_errors = errors;
+                       return true;
+               }
+
+               AsymmetricAlgorithm OnPVKSelection (X509Certificate certificate, string targetHost)
+               {
+                       return key;
+               }
+
                void Init ()
                {
                        context_bound = false;
@@ -71,12 +124,22 @@ namespace System.Net {
                        context = new HttpListenerContext (this);
                }
 
-               public int ChunkedUses {
-                       get { return chunked_uses; }
+               public bool IsClosed {
+                       get { return (sock == null); }
+               }
+
+               public int Reuses {
+                       get { return reuses; }
                }
 
                public IPEndPoint LocalEndPoint {
-                       get { return (IPEndPoint) sock.LocalEndPoint; }
+                       get {
+                               if (local_ep != null)
+                                       return local_ep;
+
+                               local_ep = (IPEndPoint) sock.LocalEndPoint;
+                               return local_ep;
+                       }
                }
 
                public IPEndPoint RemoteEndPoint {
@@ -92,11 +155,26 @@ namespace System.Net {
                        set { prefix = value; }
                }
 
+               void OnTimeout (object unused)
+               {
+                       CloseSocket ();
+                       Unbind ();
+               }
+
                public void BeginReadRequest ()
                {
                        if (buffer == null)
                                buffer = new byte [BufferSize];
-                       stream.BeginRead (buffer, 0, BufferSize, OnRead, this);
+                       try {
+                               if (reuses == 1)
+                                       s_timeout = 15000;
+                               timer.Change (s_timeout, Timeout.Infinite);
+                               stream.BeginRead (buffer, 0, BufferSize, onread_cb, this);
+                       } catch {
+                               timer.Change (Timeout.Infinite, Timeout.Infinite);
+                               CloseSocket ();
+                               Unbind ();
+                       }
                }
 
                public RequestStream GetRequestStream (bool chunked, long contentlength)
@@ -108,9 +186,9 @@ namespace System.Net {
                                if (chunked) {
                                        this.chunked = true;
                                        context.Response.SendChunked = true;
-                                       i_stream = new ChunkedInputStream (context, sock, buffer, position, length - position);
+                                       i_stream = new ChunkedInputStream (context, stream, buffer, position, length - position);
                                } else {
-                                       i_stream = new RequestStream (sock, buffer, position, length - position, contentlength);
+                                       i_stream = new RequestStream (stream, buffer, position, length - position, contentlength);
                                }
                        }
                        return i_stream;
@@ -122,30 +200,44 @@ namespace System.Net {
                        if (o_stream == null) {
                                HttpListener listener = context.Listener;
                                bool ign = (listener == null) ? true : listener.IgnoreWriteExceptions;
-                               o_stream = new ResponseStream (sock, context.Response, ign);
+                               o_stream = new ResponseStream (stream, context.Response, ign);
                        }
                        return o_stream;
                }
 
-               void OnRead (IAsyncResult ares)
+               static void OnRead (IAsyncResult ares)
                {
-                       // TODO: set a limit on ms length.
                        HttpConnection cnc = (HttpConnection) ares.AsyncState;
+                       cnc.OnReadInternal (ares);
+               }
+
+               void OnReadInternal (IAsyncResult ares)
+               {
+                       timer.Change (Timeout.Infinite, Timeout.Infinite);
                        int nread = -1;
                        try {
                                nread = stream.EndRead (ares);
                                ms.Write (buffer, 0, nread);
+                               if (ms.Length > 32768) {
+                                       SendError ("Bad request", 400);
+                                       Close (true);
+                                       return;
+                               }
                        } catch {
-                               if (ms.Length > 0)
+                               if (ms != null && ms.Length > 0)
                                        SendError ();
-                               sock.Close ();
+                               if (sock != null) {
+                                       CloseSocket ();
+                                       Unbind ();
+                               }
                                return;
                        }
 
                        if (nread == 0) {
                                //if (ms.Length > 0)
                                //      SendError (); // Why bother?
-                               sock.Close ();
+                               CloseSocket ();
+                               Unbind ();
                                return;
                        }
 
@@ -155,18 +247,35 @@ namespace System.Net {
 
                                if (context.HaveError) {
                                        SendError ();
-                                       Close ();
+                                       Close (true);
                                        return;
                                }
 
                                if (!epl.BindContext (context)) {
                                        SendError ("Invalid host", 400);
-                                       Close ();
+                                       Close (true);
+                                       return;
                                }
+                               HttpListener listener = context.Listener;
+                               if (last_listener != listener) {
+                                       RemoveConnection ();
+                                       listener.AddConnection (this);
+                                       last_listener = listener;
+                               }
+
                                context_bound = true;
+                               listener.RegisterContext (context);
                                return;
                        }
-                       stream.BeginRead (buffer, 0, BufferSize, OnRead, cnc);
+                       stream.BeginRead (buffer, 0, BufferSize, onread_cb, this);
+               }
+
+               void RemoveConnection ()
+               {
+                       if (last_listener == null)
+                               epl.RemoveConnection (this);
+                       else
+                               last_listener.RemoveConnection (this);
                }
 
                enum InputState {
@@ -192,8 +301,19 @@ namespace System.Net {
                        int len = (int) ms.Length;
                        int used = 0;
                        string line;
-                       while ((line = ReadLine (buffer, position, len - position, ref used)) != null) {
+
+                       try {
+                               line = ReadLine (buffer, position, len - position, ref used);
                                position += used;
+                       } catch {
+                               context.ErrorMessage = "Bad request";
+                               context.ErrorStatus = 400;
+                               return true;
+                       }
+
+                       do {
+                               if (line == null)
+                                       break;
                                if (line == "") {
                                        if (input_state == InputState.RequestLine)
                                                continue;
@@ -206,7 +326,13 @@ namespace System.Net {
                                        context.Request.SetRequestLine (line);
                                        input_state = InputState.Headers;
                                } else {
-                                       context.Request.AddHeader (line);
+                                       try {
+                                               context.Request.AddHeader (line);
+                                       } catch (Exception e) {
+                                               context.ErrorMessage = e.Message;
+                                               context.ErrorStatus = 400;
+                                               return true;
+                                       }
                                }
 
                                if (context.HaveError)
@@ -214,7 +340,15 @@ namespace System.Net {
 
                                if (position >= len)
                                        break;
-                       }
+                               try {
+                                       line = ReadLine (buffer, position, len - position, ref used);
+                                       position += used;
+                               } catch {
+                                       context.ErrorMessage = "Bad request";
+                                       context.ErrorStatus = 400;
+                                       return true;
+                               }
+                       } while (line != null);
 
                        if (used == len) {
                                ms.SetLength (0);
@@ -226,7 +360,7 @@ namespace System.Net {
                string ReadLine (byte [] buffer, int offset, int len, ref int used)
                {
                        if (current_line == null)
-                               current_line = new StringBuilder ();
+                               current_line = new StringBuilder (128);
                        int last = offset + len;
                        used = 0;
                        for (int i = offset; i < last && line_state != LineState.LF; i++) {
@@ -253,18 +387,22 @@ namespace System.Net {
 
                public void SendError (string msg, int status)
                {
-                       HttpListenerResponse response = context.Response;
-                       response.StatusCode = status;
-                       response.ContentType = "text/html";
-                       string description = HttpListenerResponse.GetStatusDescription (status);
-                       string str;
-                       if (msg != null)
-                               str = String.Format ("<h1>{0} ({1})</h1>", description, msg);
-                       else
-                               str = String.Format ("<h1>{0}</h1>", description);
-
-                       byte [] error = context.Response.ContentEncoding.GetBytes (str);
-                       response.Close (error, false);
+                       try {
+                               HttpListenerResponse response = context.Response;
+                               response.StatusCode = status;
+                               response.ContentType = "text/html";
+                               string description = HttpListenerResponse.GetStatusDescription (status);
+                               string str;
+                               if (msg != null)
+                                       str = String.Format ("<h1>{0} ({1})</h1>", description, msg);
+                               else
+                                       str = String.Format ("<h1>{0}</h1>", description);
+
+                               byte [] error = context.Response.ContentEncoding.GetBytes (str);
+                               response.Close (error, false);
+                       } catch {
+                               // response was already closed
+                       }
                }
 
                public void SendError ()
@@ -272,18 +410,69 @@ namespace System.Net {
                        SendError (context.ErrorMessage, context.ErrorStatus);
                }
 
+               void Unbind ()
+               {
+                       if (context_bound) {
+                               epl.UnbindContext (context);
+                               context_bound = false;
+                       }
+               }
+
                public void Close ()
                {
-                       if (o_stream != null) {
-                               Stream st = o_stream;
-                               st.Close ();
+                       Close (false);
+               }
+
+               void CloseSocket ()
+               {
+                       if (sock == null)
+                               return;
+
+                       try {
+                               sock.Close ();
+                       } catch {
+                       } finally {
+                               sock = null;
+                       }
+                       RemoveConnection ();
+               }
+
+               internal void Close (bool force_close)
+               {
+                       if (sock != null) {
+                               Stream st = GetResponseStream ();
+                               if (st != null)
+                                       st.Close ();
+
                                o_stream = null;
                        }
 
                        if (sock != null) {
-                               if (chunked && context.Response.ForceCloseChunked == false) {
-                                       // Don't close. Keep working.
-                                       chunked_uses++;
+                               force_close |= !context.Request.KeepAlive;
+                               if (!force_close)
+                                       force_close = (context.Response.Headers ["connection"] == "close");
+                               /*
+                               if (!force_close) {
+//                                     bool conn_close = (status_code == 400 || status_code == 408 || status_code == 411 ||
+//                                                     status_code == 413 || status_code == 414 || status_code == 500 ||
+//                                                     status_code == 503);
+
+                                       force_close |= (context.Request.ProtocolVersion <= HttpVersion.Version10);
+                               }
+                               */
+
+                               if (!force_close && context.Request.FlushInput ()) {
+                                       if (chunked && context.Response.ForceCloseChunked == false) {
+                                               // Don't close. Keep working.
+                                               reuses++;
+                                               Unbind ();
+                                               Init ();
+                                               BeginReadRequest ();
+                                               return;
+                                       }
+
+                                       reuses++;
+                                       Unbind ();
                                        Init ();
                                        BeginReadRequest ();
                                        return;
@@ -291,10 +480,17 @@ namespace System.Net {
 
                                Socket s = sock;
                                sock = null;
-                               s.Shutdown (SocketShutdown.Both);
-                               s.Close ();
-                               if (context_bound)
-                                       epl.UnbindContext (context);
+                               try {
+                                       if (s != null)
+                                               s.Shutdown (SocketShutdown.Both);
+                               } catch {
+                               } finally {
+                                       if (s != null)
+                                               s.Close ();
+                               }
+                               Unbind ();
+                               RemoveConnection ();
+                               return;
                        }
                }
        }