[Mono.Security.Interface]: Improve synergy between `SslStream` and `IMonoSslStream...
[mono.git] / mcs / class / System / System.Net / HttpConnection.cs
index d3cacbc4ba48698dd66723bbbbcebc45b03d414d..680ec8188ede53fa1566e771c2eae315be289c84 100644 (file)
 //
 
 #if SECURITY_DEP
+#if MONO_SECURITY_ALIAS
+extern alias MonoSecurity;
+#endif
 
-#if MONOTOUCH || MONODROID
-using Mono.Security.Protocol.Tls;
+#if MONO_SECURITY_ALIAS
+using MSI = MonoSecurity::Mono.Security.Interface;
 #else
-extern alias MonoSecurity;
-using MonoSecurity::Mono.Security.Protocol.Tls;
+using MSI = Mono.Security.Interface;
 #endif
 
 using System.IO;
 using System.Net.Sockets;
 using System.Text;
 using System.Threading;
+using System.Net.Security;
+using System.Security.Authentication;
 using System.Security.Cryptography;
 using System.Security.Cryptography.X509Certificates;
 
@@ -62,32 +66,44 @@ namespace System.Net {
                int reuses;
                bool context_bound;
                bool secure;
-               AsymmetricAlgorithm key;
+               X509Certificate cert;
                int s_timeout = 90000; // 90k ms for first request, 15k ms from then on
                Timer timer;
                IPEndPoint local_ep;
                HttpListener last_listener;
                int [] client_cert_errors;
                X509Certificate2 client_cert;
+               SslStream ssl_stream;
 
-               public HttpConnection (Socket sock, EndPointListener epl, bool secure, X509Certificate2 cert, AsymmetricAlgorithm key)
+               public HttpConnection (Socket sock, EndPointListener epl, bool secure, X509Certificate cert)
                {
                        this.sock = sock;
                        this.epl = epl;
                        this.secure = secure;
-                       this.key = key;
+                       this.cert = cert;
                        if (secure == false) {
                                stream = new NetworkStream (sock, false);
                        } else {
-                               SslServerStream ssl_stream = new SslServerStream (new NetworkStream (sock, false), cert, false, true, false);
-                               ssl_stream.PrivateKeyCertSelectionDelegate += OnPVKSelection;
-                               ssl_stream.ClientCertValidationDelegate += OnClientCertificateValidation;
+                               ssl_stream = epl.Listener.CreateSslStream (new NetworkStream (sock, false), false, (t, c, ch, e) => {
+                                       if (c == null)
+                                               return true;
+                                       var c2 = c as X509Certificate2;
+                                       if (c2 == null)
+                                               c2 = new X509Certificate2 (c.GetRawCertData ());
+                                       client_cert = c2;
+                                       client_cert_errors = new int[] { (int)e };
+                                       return true;
+                               });
                                stream = ssl_stream;
                        }
                        timer = new Timer (OnTimeout, null, Timeout.Infinite, Timeout.Infinite);
                        Init ();
                }
 
+               internal SslStream SslStream {
+                       get { return ssl_stream; }
+               }
+
                internal int [] ClientCertificateErrors {
                        get { return client_cert_errors; }
                }
@@ -96,25 +112,12 @@ namespace System.Net {
                        get { return client_cert; }
                }
 
-               bool OnClientCertificateValidation (X509Certificate certificate, int[] errors)
-               {
-                       if (certificate == null)
-                               return true;
-                       X509Certificate2 cert = certificate as X509Certificate2;
-                       if (cert == null)
-                               cert = new X509Certificate2 (certificate.GetRawCertData ());
-                       client_cert = cert;
-                       client_cert_errors = errors;
-                       return true;
-               }
-
-               AsymmetricAlgorithm OnPVKSelection (X509Certificate certificate, string targetHost)
-               {
-                       return key;
-               }
-
                void Init ()
                {
+                       if (ssl_stream != null) {
+                               ssl_stream.AuthenticateAsServer (cert, true, (SslProtocols)ServicePointManager.SecurityProtocol, false);
+                       }
+
                        context_bound = false;
                        i_stream = null;
                        o_stream = null;
@@ -202,8 +205,11 @@ namespace System.Net {
                        // TODO: can we get this stream before reading the input?
                        if (o_stream == null) {
                                HttpListener listener = context.Listener;
-                               bool ign = (listener == null) ? true : listener.IgnoreWriteExceptions;
-                               o_stream = new ResponseStream (stream, context.Response, ign);
+                               
+                               if(listener == null)
+                                       return new ResponseStream (stream, context.Response, true);
+
+                               o_stream = new ResponseStream (stream, context.Response, listener.IgnoreWriteExceptions);
                        }
                        return o_stream;
                }
@@ -305,18 +311,25 @@ namespace System.Net {
                        int used = 0;
                        string line;
 
-                       try {
-                               line = ReadLine (buffer, position, len - position, ref used);
-                               position += used;
-                       } catch {
-                               context.ErrorMessage = "Bad request";
-                               context.ErrorStatus = 400;
-                               return true;
-                       }
+                       while (true) {
+                               if (context.HaveError)
+                                       return true;
+
+                               if (position >= len)
+                                       break;
+
+                               try {
+                                       line = ReadLine (buffer, position, len - position, ref used);
+                                       position += used;
+                               } catch {
+                                       context.ErrorMessage = "Bad request";
+                                       context.ErrorStatus = 400;
+                                       return true;
+                               }
 
-                       do {
                                if (line == null)
                                        break;
+
                                if (line == "") {
                                        if (input_state == InputState.RequestLine)
                                                continue;
@@ -337,21 +350,7 @@ namespace System.Net {
                                                return true;
                                        }
                                }
-
-                               if (context.HaveError)
-                                       return true;
-
-                               if (position >= len)
-                                       break;
-                               try {
-                                       line = ReadLine (buffer, position, len - position, ref used);
-                                       position += used;
-                               } catch {
-                                       context.ErrorMessage = "Bad request";
-                                       context.ErrorStatus = 400;
-                                       return true;
-                               }
-                       } while (line != null);
+                       }
 
                        if (used == len) {
                                ms.SetLength (0);
@@ -394,7 +393,7 @@ namespace System.Net {
                                HttpListenerResponse response = context.Response;
                                response.StatusCode = status;
                                response.ContentType = "text/html";
-                               string description = HttpListenerResponse.GetStatusDescription (status);
+                               string description = HttpListenerResponseHelper.GetStatusDescription (status);
                                string str;
                                if (msg != null)
                                        str = String.Format ("<h1>{0} ({1})</h1>", description, msg);