Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking...
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / SslClientStream.cs
index 4bcc5b9f7f87d242cff58866177f933d8f0603e7..e615e83e3f66a531ff09fbb82107745837c8eeba 100644 (file)
@@ -280,139 +280,323 @@ namespace Mono.Security.Protocol.Tls
                                        Fig. 1 - Message flow for a full handshake              
                */
 
-               internal override IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state)
+               private void SafeEndReceiveRecord (IAsyncResult ar, bool ignoreEmpty = false)
                {
-                       try
+                       byte[] record = this.protocol.EndReceiveRecord (ar);
+                       if (!ignoreEmpty && ((record == null) || (record.Length == 0))) {
+                               throw new TlsException (
+                                       AlertDescription.HandshakeFailiure,
+                                       "The server stopped the handshake.");
+                       }
+               }
+
+               private enum NegotiateState
+               {
+                       SentClientHello,
+                       ReceiveClientHelloResponse,
+                       SentCipherSpec,
+                       ReceiveCipherSpecResponse,
+                       SentKeyExchange,
+                       ReceiveFinishResponse,
+                       SentFinished,
+               };
+
+               private class NegotiateAsyncResult : IAsyncResult
+               {
+                       private object locker = new object ();
+                       private AsyncCallback _userCallback;
+                       private object _userState;
+                       private Exception _asyncException;
+                       private ManualResetEvent handle;
+                       private NegotiateState _state;
+                       private bool completed;
+
+                       public NegotiateAsyncResult(AsyncCallback userCallback, object userState, NegotiateState state)
                        {
-                               if (this.context.HandshakeState != HandshakeState.None)
-                               {
-                                       this.context.Clear();
-                               }
+                               _userCallback = userCallback;
+                               _userState = userState;
+                               _state = state;
+                       }
 
-                               // Obtain supported cipher suites
-                               this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers(this.context.SecurityProtocol);
+                       public NegotiateState State
+                       {
+                               get { return _state; }
+                               set { _state = value; }
+                       }
 
-                               // Set handshake state
-                               this.context.HandshakeState = HandshakeState.Started;
+                       public object AsyncState
+                       {
+                               get { return _userState; }
+                       }
 
-                               // Send client hello
-                               return this.protocol.BeginSendRecord(HandshakeType.ClientHello, callback, state);
+                       public Exception AsyncException
+                       {
+                               get { return _asyncException; }
                        }
-                       catch (TlsException ex)
+
+                       public bool CompletedWithError
                        {
-                               this.protocol.SendAlert(ex.Alert);
+                               get {
+                                       if (!IsCompleted)
+                                               return false; // Perhaps throw InvalidOperationExcetion?
 
-                               throw new IOException("The authentication or decryption has failed.", ex);
+                                       return null != _asyncException;
+                               }
                        }
-                       catch (Exception ex)
+
+                       public WaitHandle AsyncWaitHandle
                        {
-                               this.protocol.SendAlert(AlertDescription.InternalError);
+                               get {
+                                       lock (locker) {
+                                               if (handle == null)
+                                                       handle = new ManualResetEvent (completed);
+                                       }
+                                       return handle;
+                               }
+
+                       }
+
+                       public bool CompletedSynchronously
+                       {
+                               get { return false; }
+                       }
+
+                       public bool IsCompleted
+                       {
+                               get {
+                                       lock (locker) {
+                                               return completed;
+                                       }
+                               }
+                       }
+
+                       public void SetComplete(Exception ex)
+                       {
+                               lock (locker) {
+                                       if (completed)
+                                               return;
 
-                               throw new IOException("The authentication or decryption has failed.", ex);
+                                       completed = true;
+                                       if (handle != null)
+                                               handle.Set ();
+
+                                       if (_userCallback != null)
+                                               _userCallback.BeginInvoke (this, null, null);
+
+                                       _asyncException = ex;
+                               }
+                       }
+
+                       public void SetComplete()
+                       {
+                               SetComplete(null);
                        }
                }
 
-               private void SafeReceiveRecord (Stream s, bool ignoreEmpty = false)
+               internal override IAsyncResult BeginNegotiateHandshake(AsyncCallback callback, object state)
                {
-                       byte[] record = this.protocol.ReceiveRecord (s);
-                       if (!ignoreEmpty && ((record == null) || (record.Length == 0))) {
-                               throw new TlsException (
-                                       AlertDescription.HandshakeFailiure,
-                                       "The server stopped the handshake.");
+                       if (this.context.HandshakeState != HandshakeState.None) {
+                               this.context.Clear ();
                        }
+
+                       // Obtain supported cipher suites
+                       this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers (this.context.SecurityProtocol);
+
+                       // Set handshake state
+                       this.context.HandshakeState = HandshakeState.Started;
+
+                       NegotiateAsyncResult result = new NegotiateAsyncResult (callback, state, NegotiateState.SentClientHello);
+
+                       // Begin sending the client hello
+                       this.protocol.BeginSendRecord (HandshakeType.ClientHello, NegotiateAsyncWorker, result);
+
+                       return result;
+               }
+
+               internal override void EndNegotiateHandshake (IAsyncResult result)
+               {
+                       NegotiateAsyncResult negotiate = result as NegotiateAsyncResult;
+
+                       if (negotiate == null)
+                               throw new ArgumentNullException ();
+                       if (!negotiate.IsCompleted)
+                               negotiate.AsyncWaitHandle.WaitOne();
+                       if (negotiate.CompletedWithError)
+                               throw negotiate.AsyncException;
                }
 
-               internal override void OnNegotiateHandshakeCallback(IAsyncResult asyncResult)
+               private void NegotiateAsyncWorker (IAsyncResult result)
                {
-                       this.protocol.EndSendRecord(asyncResult);
+                       NegotiateAsyncResult negotiate = result.AsyncState as NegotiateAsyncResult;
 
-                       // Read server response
-                       while (this.context.LastHandshakeMsg != HandshakeType.ServerHelloDone) 
+                       try
                        {
-                               // Read next record (skip empty, e.g. warnings alerts)
-                               SafeReceiveRecord (this.innerStream, true);
+                               switch (negotiate.State)
+                               {
+                               case NegotiateState.SentClientHello:
+                                       this.protocol.EndSendRecord (result);
 
-                               // special case for abbreviated handshake where no ServerHelloDone is sent from the server
-                               if (this.context.AbbreviatedHandshake && (this.context.LastHandshakeMsg == HandshakeType.ServerHello))
+                                       // we are now ready to ready the receive the hello response.
+                                       negotiate.State = NegotiateState.ReceiveClientHelloResponse;
+
+                                       // Start reading the client hello response
+                                       this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
                                        break;
-                       }
 
-                       // the handshake is much easier if we can reuse a previous session settings
-                       if (this.context.AbbreviatedHandshake) 
-                       {
-                               ClientSessionCache.SetContextFromCache (this.context);
-                               this.context.Negotiating.Cipher.ComputeKeys ();
-                               this.context.Negotiating.Cipher.InitializeCipher ();
+                               case NegotiateState.ReceiveClientHelloResponse:
+                                       this.SafeEndReceiveRecord (result, true);
+
+                                       if (this.context.LastHandshakeMsg != HandshakeType.ServerHelloDone &&
+                                               (!this.context.AbbreviatedHandshake || this.context.LastHandshakeMsg != HandshakeType.ServerHello)) {
+                                               // Read next record (skip empty, e.g. warnings alerts)
+                                               this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                               break;
+                                       }
+
+                                       // special case for abbreviated handshake where no ServerHelloDone is sent from the server
+                                       if (this.context.AbbreviatedHandshake) {
+                                               ClientSessionCache.SetContextFromCache (this.context);
+                                               this.context.Negotiating.Cipher.ComputeKeys ();
+                                               this.context.Negotiating.Cipher.InitializeCipher ();
+
+                                               negotiate.State = NegotiateState.SentCipherSpec;
+
+                                               // Send Change Cipher Spec message with the current cipher
+                                               // or as plain text if this is the initial negotiation
+                                               this.protocol.BeginSendChangeCipherSpec(NegotiateAsyncWorker, negotiate);
+                                       } else {
+                                               // Send client certificate if requested
+                                               // even if the server ask for it it _may_ still be optional
+                                               bool clientCertificate = this.context.ServerSettings.CertificateRequest;
+
+                                               using (var memstream = new MemoryStream())
+                                               {
+                                                       // NOTE: sadly SSL3 and TLS1 differs in how they handle this and
+                                                       // the current design doesn't allow a very cute way to handle 
+                                                       // SSL3 alert warning for NoCertificate (41).
+                                                       if (this.context.SecurityProtocol == SecurityProtocolType.Ssl3)
+                                                       {
+                                                               clientCertificate = ((this.context.ClientSettings.Certificates != null) &&
+                                                                       (this.context.ClientSettings.Certificates.Count > 0));
+                                                               // this works well with OpenSSL (but only for SSL3)
+                                                       }
+
+                                                       byte[] record = null;
+
+                                                       if (clientCertificate)
+                                                       {
+                                                               record = this.protocol.EncodeHandshakeRecord(HandshakeType.Certificate);
+                                                               memstream.Write(record, 0, record.Length);
+                                                       }
+
+                                                       // Send Client Key Exchange
+                                                       record = this.protocol.EncodeHandshakeRecord(HandshakeType.ClientKeyExchange);
+                                                       memstream.Write(record, 0, record.Length);
+
+                                                       // Now initialize session cipher with the generated keys
+                                                       this.context.Negotiating.Cipher.InitializeCipher();
+
+                                                       // Send certificate verify if requested (optional)
+                                                       if (clientCertificate && (this.context.ClientSettings.ClientCertificate != null))
+                                                       {
+                                                               record = this.protocol.EncodeHandshakeRecord(HandshakeType.CertificateVerify);
+                                                               memstream.Write(record, 0, record.Length);
+                                                       }
+
+                                                       // send the chnage cipher spec.
+                                                       this.protocol.SendChangeCipherSpec(memstream);
+
+                                                       // Send Finished message
+                                                       record = this.protocol.EncodeHandshakeRecord(HandshakeType.Finished);
+                                                       memstream.Write(record, 0, record.Length);
+
+                                                       negotiate.State = NegotiateState.SentKeyExchange;
+
+                                                       // send all the records.
+                                                       this.innerStream.BeginWrite (memstream.GetBuffer (), 0, (int)memstream.Length, NegotiateAsyncWorker, negotiate);
+                                               }
+                                       }
+                                       break;
 
-                               // Send Cipher Spec protocol
-                               this.protocol.SendChangeCipherSpec ();
+                               case NegotiateState.SentKeyExchange:
+                                       this.innerStream.EndWrite (result);
 
-                               // Read record until server finished is received
-                               while (this.context.HandshakeState != HandshakeState.Finished) 
-                               {
-                                       // If all goes well this will process messages:
-                                       //              Change Cipher Spec
-                                       //              Server finished
-                                       SafeReceiveRecord (this.innerStream);
-                               }
+                                       negotiate.State = NegotiateState.ReceiveFinishResponse;
 
-                               // Send Finished message
-                               this.protocol.SendRecord (HandshakeType.Finished);
-                       }
-                       else
-                       {
-                               // Send client certificate if requested
-                               // even if the server ask for it it _may_ still be optional
-                               bool clientCertificate = this.context.ServerSettings.CertificateRequest;
-
-                               // NOTE: sadly SSL3 and TLS1 differs in how they handle this and
-                               // the current design doesn't allow a very cute way to handle 
-                               // SSL3 alert warning for NoCertificate (41).
-                               if (this.context.SecurityProtocol == SecurityProtocolType.Ssl3)
-                               {
-                                       clientCertificate = ((this.context.ClientSettings.Certificates != null) &&
-                                               (this.context.ClientSettings.Certificates.Count > 0));
-                                       // this works well with OpenSSL (but only for SSL3)
-                               }
+                                       this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
 
-                               if (clientCertificate)
-                               {
-                                       this.protocol.SendRecord(HandshakeType.Certificate);
-                               }
+                                       break;
 
-                               // Send Client Key Exchange
-                               this.protocol.SendRecord(HandshakeType.ClientKeyExchange);
+                               case NegotiateState.ReceiveFinishResponse:
+                                       this.SafeEndReceiveRecord (result);
+
+                                       // Read record until server finished is received
+                                       if (this.context.HandshakeState != HandshakeState.Finished) {
+                                               // If all goes well this will process messages:
+                                               //              Change Cipher Spec
+                                               //              Server finished
+                                               this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                       }
+                                       else {
+                                               // Reset Handshake messages information
+                                               this.context.HandshakeMessages.Reset ();
+
+                                               // Clear Key Info
+                                               this.context.ClearKeyInfo();
+
+                                               negotiate.SetComplete ();
+                                       }
+                                       break;
 
-                               // Now initialize session cipher with the generated keys
-                               this.context.Negotiating.Cipher.InitializeCipher();
 
-                               // Send certificate verify if requested (optional)
-                               if (clientCertificate && (this.context.ClientSettings.ClientCertificate != null))
-                               {
-                                       this.protocol.SendRecord(HandshakeType.CertificateVerify);
-                               }
+                               case NegotiateState.SentCipherSpec:
+                                       this.protocol.EndSendChangeCipherSpec (result);
 
-                               // Send Cipher Spec protocol
-                               this.protocol.SendChangeCipherSpec ();
+                                       negotiate.State = NegotiateState.ReceiveCipherSpecResponse;
 
-                               // Send Finished message
-                               this.protocol.SendRecord (HandshakeType.Finished);
+                                       // Start reading the cipher spec response
+                                       this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                       break;
 
-                               // Read record until server finished is received
-                               while (this.context.HandshakeState != HandshakeState.Finished) {
-                                       // If all goes well this will process messages:
-                                       //              Change Cipher Spec
-                                       //              Server finished
-                                       SafeReceiveRecord (this.innerStream);
-                               }
-                       }
+                               case NegotiateState.ReceiveCipherSpecResponse:
+                                       this.SafeEndReceiveRecord (result, true);
+
+                                       if (this.context.HandshakeState != HandshakeState.Finished)
+                                       {
+                                               this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                       }
+                                       else
+                                       {
+                                               negotiate.State = NegotiateState.SentFinished;
+                                               this.protocol.BeginSendRecord(HandshakeType.Finished, NegotiateAsyncWorker, negotiate);
+                                       }
+                                       break;
+
+                               case NegotiateState.SentFinished:
+                                       this.protocol.EndSendRecord (result);
+
+                                       // Reset Handshake messages information
+                                       this.context.HandshakeMessages.Reset ();
 
-                       // Reset Handshake messages information
-                       this.context.HandshakeMessages.Reset ();
+                                       // Clear Key Info
+                                       this.context.ClearKeyInfo();
 
-                       // Clear Key Info
-                       this.context.ClearKeyInfo();
+                                       negotiate.SetComplete ();
 
+                                       break;
+                               }
+                       }
+                       catch (TlsException ex)
+                       {
+                               // FIXME: should the send alert also be done asynchronously here and below?
+                               this.protocol.SendAlert(ex.Alert);
+                               negotiate.SetComplete (new IOException("The authentication or decryption has failed.", ex));
+                       }
+                       catch (Exception ex)
+                       {
+                               this.protocol.SendAlert(AlertDescription.InternalError);
+                               negotiate.SetComplete (new IOException("The authentication or decryption has failed.", ex));
+                       }
                }
 
                #endregion