Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking...
authorBassam Tabbara <bassam@symform.com>
Fri, 6 Dec 2013 00:21:15 +0000 (16:21 -0800)
committerBassam Tabbara <bassam@symform.com>
Fri, 6 Dec 2013 17:32:22 +0000 (09:32 -0800)
More details:
  * HttpWebRequest now writes the HTTP request headers asynchronously. This will avoid blocking a thread pool thread until the headers are written.
  * SslClientStream was doing the TLS negotiations in a blocking manner. Converted the negotiation method to an async state machine where all operations are done asynchronously.
  * The test case identified in #15451 now completes quickly with less than 15 threads created, vs. 100s of threads used before this fix.

Future work if needed:
  * I did not convert SslServerStream but the same pattern could be used here.
  * SendAlert is still blocking. I thought it was not worth converting it to non-blocking.

mcs/class/Mono.Security/Mono.Security.Protocol.Tls/RecordProtocol.cs
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/SslClientStream.cs
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/SslServerStream.cs
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/SslStreamBase.cs
mcs/class/System/System.Net/HttpWebRequest.cs
mcs/class/System/System.Net/WebConnectionStream.cs

index 9ef45da2c4452226391c9bf53552677fdb4298d8..166f12f0d23c7d0790fc722c2d140494f75b1b11 100644 (file)
@@ -437,10 +437,88 @@ namespace Mono.Security.Protocol.Tls
 
                public byte[] ReceiveRecord(Stream record)
                {
+                       if (this.context.ReceivedConnectionEnd)
+                       {
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
+                       }
+
+                       record_processing.Reset ();
+                       byte[] recordTypeBuffer = new byte[1];
+
+                       int bytesRead = record.Read(recordTypeBuffer, 0, recordTypeBuffer.Length);
+
+                       //We're at the end of the stream. Time to bail.
+                       if (bytesRead == 0)
+                       {
+                               return null;
+                       }
+
+                       // Try to read the Record Content Type
+                       int type = recordTypeBuffer[0];
+
+                       // Set last handshake message received to None
+                       this.context.LastHandshakeMsg = HandshakeType.ClientHello;
 
-                       IAsyncResult ar = this.BeginReceiveRecord(record, null, null);
-                       return this.EndReceiveRecord(ar);
+                       ContentType     contentType     = (ContentType)type;
+                       byte[] buffer = this.ReadRecordBuffer(type, record);
+                       if (buffer == null)
+                       {
+                               // record incomplete (at the moment)
+                               return null;
+                       }
 
+                       // Decrypt message contents if needed
+                       if (contentType == ContentType.Alert && buffer.Length == 2)
+                       {
+                       }
+                       else if ((this.Context.Read != null) && (this.Context.Read.Cipher != null))
+                       {
+                               buffer = this.decryptRecordFragment (contentType, buffer);
+                               DebugHelper.WriteLine ("Decrypted record data", buffer);
+                       }
+
+                       // Process record
+                       switch (contentType)
+                       {
+                       case ContentType.Alert:
+                               this.ProcessAlert((AlertLevel)buffer [0], (AlertDescription)buffer [1]);
+                               if (record.CanSeek) 
+                               {
+                                       // don't reprocess that memory block
+                                       record.SetLength (0); 
+                               }
+                               buffer = null;
+                               break;
+
+                       case ContentType.ChangeCipherSpec:
+                               this.ProcessChangeCipherSpec();
+                               break;
+
+                       case ContentType.ApplicationData:
+                               break;
+
+                       case ContentType.Handshake:
+                               TlsStream message = new TlsStream (buffer);
+                               while (!message.EOF)
+                               {
+                                       this.ProcessHandshakeMessage(message);
+                               }
+                               break;
+
+                       case (ContentType)0x80:
+                               this.context.HandshakeMessages.Write (buffer);
+                               break;
+
+                       default:
+                               throw new TlsException(
+                                       AlertDescription.UnexpectedMessage,
+                                       "Unknown record received from server.");
+                       }
+
+                       record_processing.Set ();
+                       return buffer;
                }
 
                private byte[] ReadRecordBuffer (int contentType, Stream record)
@@ -655,6 +733,57 @@ namespace Mono.Security.Protocol.Tls
                        }
                }
 
+               public void SendChangeCipherSpec(Stream recordStream)
+               {
+                       DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
+                       byte[] record = this.EncodeRecord (ContentType.ChangeCipherSpec, new byte[] { 1 });
+
+                       // Send Change Cipher Spec message with the current cipher
+                       // or as plain text if this is the initial negotiation
+                       recordStream.Write(record, 0, record.Length);
+
+                       Context ctx = this.context;
+
+                       // Reset sequence numbers
+                       ctx.WriteSequenceNumber = 0;
+
+                       // all further data sent will be encrypted with the negotiated
+                       // security parameters (now the current parameters)
+                       if (ctx is ClientContext) {
+                               ctx.StartSwitchingSecurityParameters (true);
+                       } else {
+                               ctx.EndSwitchingSecurityParameters (false);
+                       }
+               }
+
+               public IAsyncResult BeginSendChangeCipherSpec(AsyncCallback callback, object state)
+               {
+                       DebugHelper.WriteLine (">>>> Write Change Cipher Spec");
+
+                       // Send Change Cipher Spec message with the current cipher
+                       // or as plain text if this is the initial negotiation
+                       return this.BeginSendRecord (ContentType.ChangeCipherSpec, new byte[] { 1 }, callback, state);
+               }
+
+               public void EndSendChangeCipherSpec (IAsyncResult asyncResult)
+               {
+                       this.EndSendRecord (asyncResult);
+
+                       Context ctx = this.context;
+
+                       // Reset sequence numbers
+                       ctx.WriteSequenceNumber = 0;
+
+                       // all further data sent will be encrypted with the negotiated
+                       // security parameters (now the current parameters)
+                       if (ctx is ClientContext) {
+                               ctx.StartSwitchingSecurityParameters (true);
+                       } else {
+                               ctx.EndSwitchingSecurityParameters (false);
+                       }
+               }
+
                public IAsyncResult BeginSendRecord(HandshakeType handshakeType, AsyncCallback callback, object state)
                {
                        HandshakeMessage msg = this.GetMessage(handshakeType);
@@ -793,7 +922,22 @@ namespace Mono.Security.Protocol.Tls
 
                        return record.ToArray();
                }
-               
+
+               public byte[] EncodeHandshakeRecord(HandshakeType handshakeType)
+               {
+                       HandshakeMessage msg = this.GetMessage(handshakeType);
+
+                       msg.Process();
+
+                       var bytes = this.EncodeRecord (msg.ContentType, msg.EncodeMessage ());
+
+                       msg.Update();
+
+                       msg.Reset();
+
+                       return bytes;
+               }
+                               
                #endregion
 
                #region Cryptography Methods
index 4bcc5b9f7f87d242cff58866177f933d8f0603e7..e615e83e3f66a531ff09fbb82107745837c8eeba 100644 (file)
@@ -280,139 +280,323 @@ namespace Mono.Security.Protocol.Tls
                                        Fig. 1 - Message flow for a full handshake              
                */
 
-               internal override IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state)
+               private void SafeEndReceiveRecord (IAsyncResult ar, bool ignoreEmpty = false)
                {
-                       try
+                       byte[] record = this.protocol.EndReceiveRecord (ar);
+                       if (!ignoreEmpty && ((record == null) || (record.Length == 0))) {
+                               throw new TlsException (
+                                       AlertDescription.HandshakeFailiure,
+                                       "The server stopped the handshake.");
+                       }
+               }
+
+               private enum NegotiateState
+               {
+                       SentClientHello,
+                       ReceiveClientHelloResponse,
+                       SentCipherSpec,
+                       ReceiveCipherSpecResponse,
+                       SentKeyExchange,
+                       ReceiveFinishResponse,
+                       SentFinished,
+               };
+
+               private class NegotiateAsyncResult : IAsyncResult
+               {
+                       private object locker = new object ();
+                       private AsyncCallback _userCallback;
+                       private object _userState;
+                       private Exception _asyncException;
+                       private ManualResetEvent handle;
+                       private NegotiateState _state;
+                       private bool completed;
+
+                       public NegotiateAsyncResult(AsyncCallback userCallback, object userState, NegotiateState state)
                        {
-                               if (this.context.HandshakeState != HandshakeState.None)
-                               {
-                                       this.context.Clear();
-                               }
+                               _userCallback = userCallback;
+                               _userState = userState;
+                               _state = state;
+                       }
 
-                               // Obtain supported cipher suites
-                               this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers(this.context.SecurityProtocol);
+                       public NegotiateState State
+                       {
+                               get { return _state; }
+                               set { _state = value; }
+                       }
 
-                               // Set handshake state
-                               this.context.HandshakeState = HandshakeState.Started;
+                       public object AsyncState
+                       {
+                               get { return _userState; }
+                       }
 
-                               // Send client hello
-                               return this.protocol.BeginSendRecord(HandshakeType.ClientHello, callback, state);
+                       public Exception AsyncException
+                       {
+                               get { return _asyncException; }
                        }
-                       catch (TlsException ex)
+
+                       public bool CompletedWithError
                        {
-                               this.protocol.SendAlert(ex.Alert);
+                               get {
+                                       if (!IsCompleted)
+                                               return false; // Perhaps throw InvalidOperationExcetion?
 
-                               throw new IOException("The authentication or decryption has failed.", ex);
+                                       return null != _asyncException;
+                               }
                        }
-                       catch (Exception ex)
+
+                       public WaitHandle AsyncWaitHandle
                        {
-                               this.protocol.SendAlert(AlertDescription.InternalError);
+                               get {
+                                       lock (locker) {
+                                               if (handle == null)
+                                                       handle = new ManualResetEvent (completed);
+                                       }
+                                       return handle;
+                               }
+
+                       }
+
+                       public bool CompletedSynchronously
+                       {
+                               get { return false; }
+                       }
+
+                       public bool IsCompleted
+                       {
+                               get {
+                                       lock (locker) {
+                                               return completed;
+                                       }
+                               }
+                       }
+
+                       public void SetComplete(Exception ex)
+                       {
+                               lock (locker) {
+                                       if (completed)
+                                               return;
 
-                               throw new IOException("The authentication or decryption has failed.", ex);
+                                       completed = true;
+                                       if (handle != null)
+                                               handle.Set ();
+
+                                       if (_userCallback != null)
+                                               _userCallback.BeginInvoke (this, null, null);
+
+                                       _asyncException = ex;
+                               }
+                       }
+
+                       public void SetComplete()
+                       {
+                               SetComplete(null);
                        }
                }
 
-               private void SafeReceiveRecord (Stream s, bool ignoreEmpty = false)
+               internal override IAsyncResult BeginNegotiateHandshake(AsyncCallback callback, object state)
                {
-                       byte[] record = this.protocol.ReceiveRecord (s);
-                       if (!ignoreEmpty && ((record == null) || (record.Length == 0))) {
-                               throw new TlsException (
-                                       AlertDescription.HandshakeFailiure,
-                                       "The server stopped the handshake.");
+                       if (this.context.HandshakeState != HandshakeState.None) {
+                               this.context.Clear ();
                        }
+
+                       // Obtain supported cipher suites
+                       this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers (this.context.SecurityProtocol);
+
+                       // Set handshake state
+                       this.context.HandshakeState = HandshakeState.Started;
+
+                       NegotiateAsyncResult result = new NegotiateAsyncResult (callback, state, NegotiateState.SentClientHello);
+
+                       // Begin sending the client hello
+                       this.protocol.BeginSendRecord (HandshakeType.ClientHello, NegotiateAsyncWorker, result);
+
+                       return result;
+               }
+
+               internal override void EndNegotiateHandshake (IAsyncResult result)
+               {
+                       NegotiateAsyncResult negotiate = result as NegotiateAsyncResult;
+
+                       if (negotiate == null)
+                               throw new ArgumentNullException ();
+                       if (!negotiate.IsCompleted)
+                               negotiate.AsyncWaitHandle.WaitOne();
+                       if (negotiate.CompletedWithError)
+                               throw negotiate.AsyncException;
                }
 
-               internal override void OnNegotiateHandshakeCallback(IAsyncResult asyncResult)
+               private void NegotiateAsyncWorker (IAsyncResult result)
                {
-                       this.protocol.EndSendRecord(asyncResult);
+                       NegotiateAsyncResult negotiate = result.AsyncState as NegotiateAsyncResult;
 
-                       // Read server response
-                       while (this.context.LastHandshakeMsg != HandshakeType.ServerHelloDone) 
+                       try
                        {
-                               // Read next record (skip empty, e.g. warnings alerts)
-                               SafeReceiveRecord (this.innerStream, true);
+                               switch (negotiate.State)
+                               {
+                               case NegotiateState.SentClientHello:
+                                       this.protocol.EndSendRecord (result);
 
-                               // special case for abbreviated handshake where no ServerHelloDone is sent from the server
-                               if (this.context.AbbreviatedHandshake && (this.context.LastHandshakeMsg == HandshakeType.ServerHello))
+                                       // we are now ready to ready the receive the hello response.
+                                       negotiate.State = NegotiateState.ReceiveClientHelloResponse;
+
+                                       // Start reading the client hello response
+                                       this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
                                        break;
-                       }
 
-                       // the handshake is much easier if we can reuse a previous session settings
-                       if (this.context.AbbreviatedHandshake) 
-                       {
-                               ClientSessionCache.SetContextFromCache (this.context);
-                               this.context.Negotiating.Cipher.ComputeKeys ();
-                               this.context.Negotiating.Cipher.InitializeCipher ();
+                               case NegotiateState.ReceiveClientHelloResponse:
+                                       this.SafeEndReceiveRecord (result, true);
+
+                                       if (this.context.LastHandshakeMsg != HandshakeType.ServerHelloDone &&
+                                               (!this.context.AbbreviatedHandshake || this.context.LastHandshakeMsg != HandshakeType.ServerHello)) {
+                                               // Read next record (skip empty, e.g. warnings alerts)
+                                               this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                               break;
+                                       }
+
+                                       // special case for abbreviated handshake where no ServerHelloDone is sent from the server
+                                       if (this.context.AbbreviatedHandshake) {
+                                               ClientSessionCache.SetContextFromCache (this.context);
+                                               this.context.Negotiating.Cipher.ComputeKeys ();
+                                               this.context.Negotiating.Cipher.InitializeCipher ();
+
+                                               negotiate.State = NegotiateState.SentCipherSpec;
+
+                                               // Send Change Cipher Spec message with the current cipher
+                                               // or as plain text if this is the initial negotiation
+                                               this.protocol.BeginSendChangeCipherSpec(NegotiateAsyncWorker, negotiate);
+                                       } else {
+                                               // Send client certificate if requested
+                                               // even if the server ask for it it _may_ still be optional
+                                               bool clientCertificate = this.context.ServerSettings.CertificateRequest;
+
+                                               using (var memstream = new MemoryStream())
+                                               {
+                                                       // NOTE: sadly SSL3 and TLS1 differs in how they handle this and
+                                                       // the current design doesn't allow a very cute way to handle 
+                                                       // SSL3 alert warning for NoCertificate (41).
+                                                       if (this.context.SecurityProtocol == SecurityProtocolType.Ssl3)
+                                                       {
+                                                               clientCertificate = ((this.context.ClientSettings.Certificates != null) &&
+                                                                       (this.context.ClientSettings.Certificates.Count > 0));
+                                                               // this works well with OpenSSL (but only for SSL3)
+                                                       }
+
+                                                       byte[] record = null;
+
+                                                       if (clientCertificate)
+                                                       {
+                                                               record = this.protocol.EncodeHandshakeRecord(HandshakeType.Certificate);
+                                                               memstream.Write(record, 0, record.Length);
+                                                       }
+
+                                                       // Send Client Key Exchange
+                                                       record = this.protocol.EncodeHandshakeRecord(HandshakeType.ClientKeyExchange);
+                                                       memstream.Write(record, 0, record.Length);
+
+                                                       // Now initialize session cipher with the generated keys
+                                                       this.context.Negotiating.Cipher.InitializeCipher();
+
+                                                       // Send certificate verify if requested (optional)
+                                                       if (clientCertificate && (this.context.ClientSettings.ClientCertificate != null))
+                                                       {
+                                                               record = this.protocol.EncodeHandshakeRecord(HandshakeType.CertificateVerify);
+                                                               memstream.Write(record, 0, record.Length);
+                                                       }
+
+                                                       // send the chnage cipher spec.
+                                                       this.protocol.SendChangeCipherSpec(memstream);
+
+                                                       // Send Finished message
+                                                       record = this.protocol.EncodeHandshakeRecord(HandshakeType.Finished);
+                                                       memstream.Write(record, 0, record.Length);
+
+                                                       negotiate.State = NegotiateState.SentKeyExchange;
+
+                                                       // send all the records.
+                                                       this.innerStream.BeginWrite (memstream.GetBuffer (), 0, (int)memstream.Length, NegotiateAsyncWorker, negotiate);
+                                               }
+                                       }
+                                       break;
 
-                               // Send Cipher Spec protocol
-                               this.protocol.SendChangeCipherSpec ();
+                               case NegotiateState.SentKeyExchange:
+                                       this.innerStream.EndWrite (result);
 
-                               // Read record until server finished is received
-                               while (this.context.HandshakeState != HandshakeState.Finished) 
-                               {
-                                       // If all goes well this will process messages:
-                                       //              Change Cipher Spec
-                                       //              Server finished
-                                       SafeReceiveRecord (this.innerStream);
-                               }
+                                       negotiate.State = NegotiateState.ReceiveFinishResponse;
 
-                               // Send Finished message
-                               this.protocol.SendRecord (HandshakeType.Finished);
-                       }
-                       else
-                       {
-                               // Send client certificate if requested
-                               // even if the server ask for it it _may_ still be optional
-                               bool clientCertificate = this.context.ServerSettings.CertificateRequest;
-
-                               // NOTE: sadly SSL3 and TLS1 differs in how they handle this and
-                               // the current design doesn't allow a very cute way to handle 
-                               // SSL3 alert warning for NoCertificate (41).
-                               if (this.context.SecurityProtocol == SecurityProtocolType.Ssl3)
-                               {
-                                       clientCertificate = ((this.context.ClientSettings.Certificates != null) &&
-                                               (this.context.ClientSettings.Certificates.Count > 0));
-                                       // this works well with OpenSSL (but only for SSL3)
-                               }
+                                       this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
 
-                               if (clientCertificate)
-                               {
-                                       this.protocol.SendRecord(HandshakeType.Certificate);
-                               }
+                                       break;
 
-                               // Send Client Key Exchange
-                               this.protocol.SendRecord(HandshakeType.ClientKeyExchange);
+                               case NegotiateState.ReceiveFinishResponse:
+                                       this.SafeEndReceiveRecord (result);
+
+                                       // Read record until server finished is received
+                                       if (this.context.HandshakeState != HandshakeState.Finished) {
+                                               // If all goes well this will process messages:
+                                               //              Change Cipher Spec
+                                               //              Server finished
+                                               this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                       }
+                                       else {
+                                               // Reset Handshake messages information
+                                               this.context.HandshakeMessages.Reset ();
+
+                                               // Clear Key Info
+                                               this.context.ClearKeyInfo();
+
+                                               negotiate.SetComplete ();
+                                       }
+                                       break;
 
-                               // Now initialize session cipher with the generated keys
-                               this.context.Negotiating.Cipher.InitializeCipher();
 
-                               // Send certificate verify if requested (optional)
-                               if (clientCertificate && (this.context.ClientSettings.ClientCertificate != null))
-                               {
-                                       this.protocol.SendRecord(HandshakeType.CertificateVerify);
-                               }
+                               case NegotiateState.SentCipherSpec:
+                                       this.protocol.EndSendChangeCipherSpec (result);
 
-                               // Send Cipher Spec protocol
-                               this.protocol.SendChangeCipherSpec ();
+                                       negotiate.State = NegotiateState.ReceiveCipherSpecResponse;
 
-                               // Send Finished message
-                               this.protocol.SendRecord (HandshakeType.Finished);
+                                       // Start reading the cipher spec response
+                                       this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                       break;
 
-                               // Read record until server finished is received
-                               while (this.context.HandshakeState != HandshakeState.Finished) {
-                                       // If all goes well this will process messages:
-                                       //              Change Cipher Spec
-                                       //              Server finished
-                                       SafeReceiveRecord (this.innerStream);
-                               }
-                       }
+                               case NegotiateState.ReceiveCipherSpecResponse:
+                                       this.SafeEndReceiveRecord (result, true);
+
+                                       if (this.context.HandshakeState != HandshakeState.Finished)
+                                       {
+                                               this.protocol.BeginReceiveRecord (this.innerStream, NegotiateAsyncWorker, negotiate);
+                                       }
+                                       else
+                                       {
+                                               negotiate.State = NegotiateState.SentFinished;
+                                               this.protocol.BeginSendRecord(HandshakeType.Finished, NegotiateAsyncWorker, negotiate);
+                                       }
+                                       break;
+
+                               case NegotiateState.SentFinished:
+                                       this.protocol.EndSendRecord (result);
+
+                                       // Reset Handshake messages information
+                                       this.context.HandshakeMessages.Reset ();
 
-                       // Reset Handshake messages information
-                       this.context.HandshakeMessages.Reset ();
+                                       // Clear Key Info
+                                       this.context.ClearKeyInfo();
 
-                       // Clear Key Info
-                       this.context.ClearKeyInfo();
+                                       negotiate.SetComplete ();
 
+                                       break;
+                               }
+                       }
+                       catch (TlsException ex)
+                       {
+                               // FIXME: should the send alert also be done asynchronously here and below?
+                               this.protocol.SendAlert(ex.Alert);
+                               negotiate.SetComplete (new IOException("The authentication or decryption has failed.", ex));
+                       }
+                       catch (Exception ex)
+                       {
+                               this.protocol.SendAlert(AlertDescription.InternalError);
+                               negotiate.SetComplete (new IOException("The authentication or decryption has failed.", ex));
+                       }
                }
 
                #endregion
index 02cfccfe9527543c3d73adc455080a231acee2e4..b0d8bba6ffc99cc8b92675f31d8318bc32d43ef7 100644 (file)
@@ -196,7 +196,7 @@ namespace Mono.Security.Protocol.Tls
                                        Fig. 1 - Message flow for a full handshake              
                */
 
-               internal override IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state)
+               internal override IAsyncResult BeginNegotiateHandshake(AsyncCallback callback, object state)
                {
                        // Reset the context if needed
                        if (this.context.HandshakeState != HandshakeState.None)
@@ -215,7 +215,7 @@ namespace Mono.Security.Protocol.Tls
 
                }
 
-               internal override void OnNegotiateHandshakeCallback(IAsyncResult asyncResult)
+               internal override void EndNegotiateHandshake(IAsyncResult asyncResult)
                {
                        // Receive Client Hello message and ignore it
                        this.protocol.EndReceiveRecord(asyncResult);
index 7462702d9a421a34c64d5c5ff717c429fd35c133..5c0032c02f0c5724a29fcf5052a639e39a606cc0 100644 (file)
@@ -96,7 +96,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                try
                                {
-                                       this.OnNegotiateHandshakeCallback(asyncResult);
+                                       this.EndNegotiateHandshake(asyncResult);
                                }
                                catch (TlsException ex)
                                {
@@ -179,8 +179,8 @@ namespace Mono.Security.Protocol.Tls
 
                #region Abstracts/Virtuals
 
-               internal abstract IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state);
-               internal abstract void OnNegotiateHandshakeCallback(IAsyncResult asyncResult);
+               internal abstract IAsyncResult BeginNegotiateHandshake (AsyncCallback callback, object state);
+               internal abstract void EndNegotiateHandshake (IAsyncResult result);
 
                internal abstract X509Certificate OnLocalCertificateSelection(X509CertificateCollection clientCertificates,
                                                                                                                        X509Certificate serverCertificate,
@@ -492,7 +492,7 @@ namespace Mono.Security.Protocol.Tls
                                {
                                        if (this.context.HandshakeState == HandshakeState.None)
                                        {
-                                               this.OnBeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
+                                               this.BeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
 
                                                return true;
                                        }
index c3c4772b5aa3fe29a75026057b81955ac4870b5c..6aee5af52bd3bda083667b07cc29d699d178f597 100644 (file)
@@ -1232,7 +1232,7 @@ namespace System.Net
                        }
                }
 
-               internal void SendRequestHeaders (bool propagate_error)
+               internal byte[] GetRequestHeaders ()
                {
                        StringBuilder req = new StringBuilder ();
                        string query;
@@ -1254,18 +1254,7 @@ namespace System.Net
                                                                actualVersion.Major, actualVersion.Minor);
                        req.Append (GetHeaders ());
                        string reqstr = req.ToString ();
-                       byte [] bytes = Encoding.UTF8.GetBytes (reqstr);
-                       try {
-                               writeStream.SetHeaders (bytes);
-                       } catch (WebException wexc) {
-                               SetWriteStreamError (wexc.Status, wexc);
-                               if (propagate_error)
-                                       throw;
-                       } catch (Exception exc) {
-                               SetWriteStreamError (WebExceptionStatus.SendFailure, exc);
-                               if (propagate_error)
-                                       throw;
-                       }
+                       return Encoding.UTF8.GetBytes (reqstr);
                }
 
                internal void SetWriteStream (WebConnectionStream stream)
@@ -1280,14 +1269,32 @@ namespace System.Net
                                writeStream.SendChunked = false;
                        }
 
-                       SendRequestHeaders (false);
+                       byte[] requestHeaders = GetRequestHeaders ();
+                       WebAsyncResult result = new WebAsyncResult (new AsyncCallback (SetWriteStreamCB), null);
+                       writeStream.SetHeadersAsync (requestHeaders, result);
+               }
 
+               void SetWriteStreamCB(IAsyncResult ar)
+               {
+                       WebAsyncResult result = ar as WebAsyncResult;
+
+                       if (result.Exception != null) {
+                               WebException wexc = result.Exception as WebException;
+                               if (wexc != null) {
+                                       SetWriteStreamError (wexc.Status, wexc);
+                                       return;
+                               }
+                               SetWriteStreamError (WebExceptionStatus.SendFailure, result.Exception);
+                               return;
+                       }
+               
                        haveRequest = true;
-                       
+
                        if (bodyBuffer != null) {
                                // The body has been written and buffered. The request "user"
                                // won't write it again, so we must do it.
                                if (ntlm_auth_state != NtlmAuthState.Challenge) {
+                                       // FIXME: this is a blocking call on the thread pool that could lead to thread pool exhaustion
                                        writeStream.Write (bodyBuffer, 0, bodyBufferLength);
                                        bodyBuffer = null;
                                        writeStream.Close ();
@@ -1295,11 +1302,12 @@ namespace System.Net
                        } else if (method != "HEAD" && method != "GET" && method != "MKCOL" && method != "CONNECT" &&
                                        method != "TRACE") {
                                if (getResponseCalled && !writeStream.RequestWritten)
+                                       // FIXME: this is a blocking call on the thread pool that could lead to thread pool exhaustion
                                        writeStream.WriteRequest ();
                        }
 
                        if (asyncWrite != null) {
-                               asyncWrite.SetCompleted (false, stream);
+                               asyncWrite.SetCompleted (false, writeStream);
                                asyncWrite.DoCallback ();
                                asyncWrite = null;
                        }
index 90507d8b5c7ca0e0de51a154ae595dd6793a04a2..5b1b85a7b8531049073d4dfaa5cd080bd828a8f5 100644 (file)
@@ -632,7 +632,7 @@ namespace System.Net
                {
                }
 
-               internal void SetHeaders (byte [] buffer)
+               internal void SetHeadersAsync (byte[] buffer, WebAsyncResult result)
                {
                        if (headersSent)
                                return;
@@ -646,14 +646,44 @@ namespace System.Net
                                       method == "COPY" || method == "MOVE" || method == "LOCK" ||
                                       method == "UNLOCK");
                        if (sendChunked || cl > -1 || no_writestream || webdav) {
-                               WriteHeaders ();
+
+                               headersSent = true;
+
+                               try {
+                                       result.InnerAsyncResult = cnc.BeginWrite (request, headers, 0, headers.Length, new AsyncCallback(SetHeadersCB), result);
+                                       if (result.InnerAsyncResult == null) {
+                                               // when does BeginWrite return null? Is the case when the request is aborted?
+                                               if (!result.IsCompleted)
+                                                       result.SetCompleted (true, 0);
+                                               result.DoCallback ();
+                                       }
+                               } catch (Exception exc) {
+                                       result.SetCompleted (true, exc);
+                                       result.DoCallback ();
+                               }
+                       }
+               }
+
+               void SetHeadersCB (IAsyncResult r)
+               {
+                       WebAsyncResult result = (WebAsyncResult) r.AsyncState;
+                       result.InnerAsyncResult = null;
+                       try {
+                               cnc.EndWrite2 (request, r);
+                               result.SetCompleted (false, 0);
                                if (!initRead) {
                                        initRead = true;
                                        WebConnection.InitRead (cnc);
                                }
+                               long cl = request.ContentLength;
                                if (!sendChunked && cl == 0)
                                        requestWritten = true;
+                       } catch (WebException e) {
+                               result.SetCompleted (false, e);
+                       } catch (Exception e) {
+                               result.SetCompleted (false, new WebException ("Error writing headers", e, WebExceptionStatus.SendFailure));
                        }
+                       result.DoCallback ();
                }
 
                internal bool RequestWritten {
@@ -669,17 +699,6 @@ namespace System.Net
                        return (length > 0) ? cnc.BeginWrite (request, bytes, 0, length, cb, state) : null;
                }
 
-               void WriteHeaders ()
-               {
-                       if (headersSent)
-                               return;
-
-                       headersSent = true;
-                       string err_msg = null;
-                       if (!cnc.Write (request, headers, 0, headers.Length, ref err_msg))
-                               throw new WebException ("Error writing request: " + err_msg, null, WebExceptionStatus.SendFailure, null);
-               }
-
                internal void WriteRequest ()
                {
                        if (requestWritten)
@@ -707,9 +726,15 @@ namespace System.Net
                                                        method == "TRACE");
                                if (!no_writestream)
                                        request.InternalContentLength = length;
-                               request.SendRequestHeaders (true);
+
+                               byte[] requestHeaders = request.GetRequestHeaders ();
+                               WebAsyncResult ar = new WebAsyncResult (null, null);
+                               SetHeadersAsync (requestHeaders, ar);
+                               ar.AsyncWaitHandle.WaitOne ();
+                               if (ar.Exception != null)
+                                       throw ar.Exception;
                        }
-                       WriteHeaders ();
+
                        if (cnc.Data.StatusCode != 0 && cnc.Data.StatusCode != 100)
                                return;