Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking...
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / SslServerStream.cs
index a4eafd385c9de6e5c1e7f7191704bd18422006b0..b0d8bba6ffc99cc8b92675f31d8318bc32d43ef7 100644 (file)
@@ -34,7 +34,12 @@ using Mono.Security.Protocol.Tls.Handshake;
 
 namespace Mono.Security.Protocol.Tls
 {
-       public class SslServerStream : Stream, IDisposable
+#if INSIDE_SYSTEM
+       internal
+#else
+       public
+#endif
+       class SslServerStream : SslStreamBase
        {
                #region Internal Events
                
@@ -43,88 +48,12 @@ namespace Mono.Security.Protocol.Tls
                
                #endregion
 
-               #region Fields
-
-               private ServerRecordProtocol    protocol;
-               private BufferedStream                  inputBuffer;
-               private ServerContext                   context;
-               private Stream                                  innerStream;
-               private bool                                    disposed;
-               private bool                                    ownsStream;
-               private bool                                    checkCertRevocationStatus;
-               private object                                  read;
-               private object                                  write;          
-
-               #endregion
-
                #region Properties
 
-               public override bool CanRead
-               {
-                       get { return this.innerStream.CanRead; }
-               }
-
-               public override bool CanWrite
-               {
-                       get { return this.innerStream.CanWrite; }
-               }
-
-               public override bool CanSeek
-               {
-                       get { return this.innerStream.CanSeek; }
-               }
-
-               public override long Length
-               {
-                       get { throw new NotSupportedException(); }
-               }
-
-               public override long Position
-               {
-                       get { throw new NotSupportedException(); }
-                       set { throw new NotSupportedException(); }
-               }
-
-               #endregion
-
-               #region Security Properties
-
-               public bool CheckCertRevocationStatus 
-               {
-                       get { return this.checkCertRevocationStatus ; }
-                       set { this.checkCertRevocationStatus = value; }
-               }
-
-               public CipherAlgorithmType CipherAlgorithm 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.Cipher.CipherAlgorithmType;
-                               }
-
-                               return CipherAlgorithmType.None;
-                       }
-               }
-
-               public int CipherStrength 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.Cipher.EffectiveKeyBits;
-                               }
-
-                               return 0;
-                       }
-               }
-               
                public X509Certificate ClientCertificate
                {
-                       get 
-                       { 
+                       get
+                       {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
                                        return this.context.ClientSettings.ClientCertificate;
@@ -132,89 +61,7 @@ namespace Mono.Security.Protocol.Tls
 
                                return null;
                        }
-               }               
-               
-               public HashAlgorithmType HashAlgorithm 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.Cipher.HashAlgorithmType; 
-                               }
-
-                               return HashAlgorithmType.None;
-                       }
                }
-               
-               public int HashStrength
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.Cipher.HashSize * 8; 
-                               }
-
-                               return 0;
-                       }
-               }
-               
-               public int KeyExchangeStrength 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.ServerSettings.Certificates[0].RSA.KeySize;
-                               }
-
-                               return 0;
-                       }
-               }
-               
-               public ExchangeAlgorithmType KeyExchangeAlgorithm 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.Cipher.ExchangeAlgorithmType; 
-                               }
-
-                               return ExchangeAlgorithmType.None;
-                       }
-               }
-               
-               public SecurityProtocolType SecurityProtocol 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       return this.context.SecurityProtocol; 
-                               }
-
-                               return 0;
-                       }
-               }
-
-               public X509Certificate ServerCertificate 
-               {
-                       get 
-                       { 
-                               if (this.context.HandshakeState == HandshakeState.Finished)
-                               {
-                                       if (this.context.ServerSettings.Certificates != null &&
-                                               this.context.ServerSettings.Certificates.Count > 0)
-                                       {
-                                               return new X509Certificate(this.context.ServerSettings.Certificates[0].RawData);
-                                       }
-                               }
-
-                               return null;
-                       }
-               } 
 
                #endregion
 
@@ -226,7 +73,7 @@ namespace Mono.Security.Protocol.Tls
                        set { this.ClientCertValidation = value; }
                }
 
-               public PrivateKeySelectionCallback PrivateKeyCertSelectionDelegate 
+               public PrivateKeySelectionCallback PrivateKeyCertSelectionDelegate
                {
                        get { return this.PrivateKeySelection; }
                        set { this.PrivateKeySelection = value; }
@@ -234,6 +81,7 @@ namespace Mono.Security.Protocol.Tls
 
                #endregion
 
+               public event CertificateValidationCallback2 ClientCertValidation2;
                #region Constructors
 
                public SslServerStream(
@@ -260,34 +108,43 @@ namespace Mono.Security.Protocol.Tls
                {
                }
 
+               public SslServerStream(
+                       Stream                  stream,
+                       X509Certificate serverCertificate,
+                       bool                    clientCertificateRequired,
+                       bool                    requestClientCertificate,
+                       bool                    ownsStream)
+                               : this (stream, serverCertificate, clientCertificateRequired, requestClientCertificate, ownsStream, SecurityProtocolType.Default)
+               {
+               }
+
                public SslServerStream(
                        Stream                                  stream,
                        X509Certificate                 serverCertificate,
                        bool                                    clientCertificateRequired,
                        bool                                    ownsStream,
                        SecurityProtocolType    securityProtocolType)
+                       : this (stream, serverCertificate, clientCertificateRequired, false, ownsStream, securityProtocolType)
                {
-                       if (stream == null)
-                       {
-                               throw new ArgumentNullException("stream is null.");
-                       }
-                       if (!stream.CanRead || !stream.CanWrite)
-                       {
-                               throw new ArgumentNullException("stream is not both readable and writable.");
-                       }
+               }
 
+               public SslServerStream(
+                       Stream                                  stream,
+                       X509Certificate                 serverCertificate,
+                       bool                                    clientCertificateRequired,
+                       bool                                    requestClientCertificate,
+                       bool                                    ownsStream,
+                       SecurityProtocolType    securityProtocolType)
+                       : base(stream, ownsStream)
+               {
                        this.context = new ServerContext(
                                this,
                                securityProtocolType,
                                serverCertificate,
-                               clientCertificateRequired);
-
-                       this.inputBuffer        = new BufferedStream(new MemoryStream());
-                       this.innerStream        = stream;
-                       this.ownsStream         = ownsStream;
-                       this.read                       = new object ();
-                       this.write                      = new object ();
-                       this.protocol           = new ServerRecordProtocol(innerStream, context);
+                               clientCertificateRequired,
+                               requestClientCertificate);
+
+                       this.protocol = new ServerRecordProtocol(innerStream, (ServerContext)this.context);
                }
 
                #endregion
@@ -303,308 +160,14 @@ namespace Mono.Security.Protocol.Tls
 
                #region IDisposable Methods
 
-               void IDisposable.Dispose()
+               protected override void Dispose(bool disposing)
                {
-                       this.Dispose(true);
-                       GC.SuppressFinalize(this);
-               }
+                       base.Dispose(disposing);
 
-               protected virtual void Dispose(bool disposing)
-               {
-                       if (!this.disposed)
+                       if (disposing)
                        {
-                               if (disposing)
-                               {
-                                       if (this.innerStream != null)
-                                       {
-                                               if (this.context.HandshakeState == HandshakeState.Finished)
-                                               {
-                                                       // Write close notify
-                                                       this.protocol.SendAlert(AlertDescription.CloseNotify);
-                                               }
-
-                                               if (this.ownsStream)
-                                               {
-                                                       // Close inner stream
-                                                       this.innerStream.Close();
-                                               }
-                                       }
-                                       this.ownsStream                         = false;
-                                       this.innerStream                        = null;
-                                       this.ClientCertValidation       = null;
-                                       this.PrivateKeySelection        = null;
-                               }
-
-                               this.disposed = true;
-                       }
-               }
-
-               #endregion
-
-               #region Methods
-
-               public override IAsyncResult BeginRead(
-                       byte[]                  buffer,
-                       int                             offset,
-                       int                             count,
-                       AsyncCallback   callback,
-                       object                  state)
-               {
-                       this.checkDisposed();
-                       
-                       if (buffer == null)
-                       {
-                               throw new ArgumentNullException("buffer is a null reference.");
-                       }
-                       if (offset < 0)
-                       {
-                               throw new ArgumentOutOfRangeException("offset is less than 0.");
-                       }
-                       if (offset > buffer.Length)
-                       {
-                               throw new ArgumentOutOfRangeException("offset is greater than the length of buffer.");
-                       }
-                       if (count < 0)
-                       {
-                               throw new ArgumentOutOfRangeException("count is less than 0.");
-                       }
-                       if (count > (buffer.Length - offset))
-                       {
-                               throw new ArgumentOutOfRangeException("count is less than the length of buffer minus the value of the offset parameter.");
-                       }
-
-                       lock (this)
-                       {
-                               if (this.context.HandshakeState == HandshakeState.None)
-                               {
-                                       this.doHandshake();     // Handshake negotiation
-                               }
-                       }
-
-                       IAsyncResult asyncResult;
-
-                       lock (this.read)
-                       {
-                               try
-                               {
-                                       // If actual buffer is full readed reset it
-                                       if (this.inputBuffer.Position == this.inputBuffer.Length &&
-                                               this.inputBuffer.Length > 0)
-                                       {
-                                               this.resetBuffer();
-                                       }
-
-                                       if (!this.context.ConnectionEnd)
-                                       {
-                                               // Check if we have space in the middle buffer
-                                               // if not Read next TLS record and update the inputBuffer
-                                               while ((this.inputBuffer.Length - this.inputBuffer.Position) < count)
-                                               {
-                                                       // Read next record and write it into the inputBuffer
-                                                       long    position        = this.inputBuffer.Position;                                    
-                                                       byte[]  record          = this.protocol.ReceiveRecord(this.innerStream);
-                                       
-                                                       if (record != null && record.Length > 0)
-                                                       {
-                                                               // Write new data to the inputBuffer
-                                                               this.inputBuffer.Seek(0, SeekOrigin.End);
-                                                               this.inputBuffer.Write(record, 0, record.Length);
-
-                                                               // Restore buffer position
-                                                               this.inputBuffer.Seek(position, SeekOrigin.Begin);
-                                                       }
-                                                       else
-                                                       {
-                                                               if (record == null)
-                                                               {
-                                                                       break;
-                                                               }
-                                                       }
-
-                                                       // TODO: Review if we need to check the Length
-                                                       // property of the innerStream for other types
-                                                       // of streams, to check that there are data available
-                                                       // for read
-                                                       if (this.innerStream is NetworkStream &&
-                                                               !((NetworkStream)this.innerStream).DataAvailable)
-                                                       {
-                                                               break;
-                                                       }
-                                               }
-                                       }
-
-                                       asyncResult = this.inputBuffer.BeginRead(
-                                               buffer, offset, count, callback, state);
-                               }
-                               catch (TlsException ex)
-                               {
-                                       this.protocol.SendAlert(ex.Alert);
-                                       this.Close();
-
-                                       throw new IOException("The authentication or decryption has failed.");
-                               }
-                               catch (Exception)
-                               {
-                                       throw new IOException("IO exception during read.");
-                               }
-                       }
-
-                       return asyncResult;
-               }
-
-               public override IAsyncResult BeginWrite(
-                       byte[]                  buffer,
-                       int                             offset,
-                       int                             count,
-                       AsyncCallback   callback,
-                       object                  state)
-               {
-                       this.checkDisposed();
-
-                       if (buffer == null)
-                       {
-                               throw new ArgumentNullException("buffer is a null reference.");
-                       }
-                       if (offset < 0)
-                       {
-                               throw new ArgumentOutOfRangeException("offset is less than 0.");
-                       }
-                       if (offset > buffer.Length)
-                       {
-                               throw new ArgumentOutOfRangeException("offset is greater than the length of buffer.");
-                       }
-                       if (count < 0)
-                       {
-                               throw new ArgumentOutOfRangeException("count is less than 0.");
-                       }
-                       if (count > (buffer.Length - offset))
-                       {
-                               throw new ArgumentOutOfRangeException("count is less than the length of buffer minus the value of the offset parameter.");
-                       }
-
-                       lock (this)
-                       {
-                               if (this.context.HandshakeState == HandshakeState.None)
-                               {
-                                       // Start handshake negotiation
-                                       this.doHandshake();
-                               }
-                       }
-
-                       IAsyncResult asyncResult;
-
-                       lock (this.write)
-                       {
-                               try
-                               {
-                                       // Send the buffer as a TLS record                                      
-                                       byte[] record = this.protocol.EncodeRecord(
-                                               ContentType.ApplicationData, buffer, offset, count);
-                               
-                                       asyncResult = this.innerStream.BeginWrite(
-                                               record, 0, record.Length, callback, state);
-                               }
-                               catch (TlsException ex)
-                               {
-                                       this.protocol.SendAlert(ex.Alert);
-                                       this.Close();
-
-                                       throw new IOException("The authentication or decryption has failed.");
-                               }
-                               catch (Exception)
-                               {
-                                       throw new IOException("IO exception during Write.");
-                               }
-                       }
-
-                       return asyncResult;
-               }
-
-               public override int EndRead(IAsyncResult asyncResult)
-               {
-                       this.checkDisposed();
-
-                       if (asyncResult == null)
-                       {
-                               throw new ArgumentNullException("asyncResult is null or was not obtained by calling BeginRead.");
-                       }
-
-                       return this.inputBuffer.EndRead(asyncResult);
-               }
-
-               public override void EndWrite(IAsyncResult asyncResult)
-               {
-                       this.checkDisposed();
-
-                       if (asyncResult == null)
-                       {
-                               throw new ArgumentNullException("asyncResult is null or was not obtained by calling BeginRead.");
-                       }
-
-                       this.innerStream.EndWrite (asyncResult);
-               }
-
-               public override void Close()
-               {
-                       ((IDisposable)this).Dispose();
-               }
-
-               public override void Flush()
-               {
-                       this.checkDisposed();
-
-                       this.innerStream.Flush();
-               }
-
-               public int Read(byte[] buffer)
-               {
-                       return this.Read(buffer, 0, buffer.Length);
-               }
-
-               public override int Read(byte[] buffer, int offset, int count)
-               {
-                       IAsyncResult res = this.BeginRead(buffer, offset, count, null, null);
-
-                       return this.EndRead(res);
-               }
-
-               public override long Seek(long offset, SeekOrigin origin)
-               {
-                       throw new NotSupportedException();
-               }
-               
-               public override void SetLength(long value)
-               {
-                       throw new NotSupportedException();
-               }
-
-               public void Write(byte[] buffer)
-               {
-                       this.Write(buffer, 0, buffer.Length);
-               }
-
-               public override void Write(byte[] buffer, int offset, int count)
-               {
-                       IAsyncResult res = this.BeginWrite (buffer, offset, count, null, null);
-
-                       this.EndWrite(res);
-               }
-
-               #endregion
-
-               #region Misc Methods
-
-               private void resetBuffer()
-               {
-                       this.inputBuffer.SetLength(0);
-                       this.inputBuffer.Position = 0;
-               }
-
-               private void checkDisposed()
-               {
-                       if (this.disposed)
-                       {
-                               throw new ObjectDisposedException("The SslClientStream is closed.");
+                               this.ClientCertValidation = null;
+                               this.PrivateKeySelection = null;
                        }
                }
 
@@ -633,106 +196,128 @@ namespace Mono.Security.Protocol.Tls
                                        Fig. 1 - Message flow for a full handshake              
                */
 
-               private void doHandshake()
+               internal override IAsyncResult BeginNegotiateHandshake(AsyncCallback callback, object state)
                {
-                       try
+                       // Reset the context if needed
+                       if (this.context.HandshakeState != HandshakeState.None)
                        {
-                               // Reset the context if needed
-                               if (this.context.HandshakeState != HandshakeState.None)
-                               {
-                                       this.context.Clear();
-                               }
+                               this.context.Clear();
+                       }
 
-                               // Obtain supported cipher suites
-                               this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers(this.context.SecurityProtocol);
+                       // Obtain supported cipher suites
+                       this.context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers(this.context.SecurityProtocol);
 
-                               // Set handshake state
-                               this.context.HandshakeState = HandshakeState.Started;
+                       // Set handshake state
+                       this.context.HandshakeState = HandshakeState.Started;
 
-                               // Receive Client Hello message
-                               this.protocol.ReceiveRecord (this.innerStream);
+                       // Receive Client Hello message
+                       return this.protocol.BeginReceiveRecord(this.innerStream, callback, state);
 
-                               // If received message is not an ClientHello send a
-                               // Fatal Alert
-                               if (this.context.LastHandshakeMsg != HandshakeType.ClientHello)
-                               {
-                                       this.protocol.SendAlert(AlertDescription.UnexpectedMessage);
-                               }
+               }
 
-                               // Send ServerHello message
-                               this.protocol.SendRecord(HandshakeType.ServerHello);
+               internal override void EndNegotiateHandshake(IAsyncResult asyncResult)
+               {
+                       // Receive Client Hello message and ignore it
+                       this.protocol.EndReceiveRecord(asyncResult);
 
-                               // Send ServerCertificate message
-                               this.protocol.SendRecord(HandshakeType.Certificate);
+                       // If received message is not an ClientHello send a
+                       // Fatal Alert
+                       if (this.context.LastHandshakeMsg != HandshakeType.ClientHello)
+                       {
+                               this.protocol.SendAlert(AlertDescription.UnexpectedMessage);
+                       }
 
-                               // If the negotiated cipher is a KeyEx cipher send ServerKeyExchange
-                               if (this.context.Cipher.ExchangeAlgorithmType == ExchangeAlgorithmType.RsaKeyX)
-                               {
-                                       this.protocol.SendRecord(HandshakeType.ServerKeyExchange);
-                               }
+                       // Send ServerHello message
+                       this.protocol.SendRecord(HandshakeType.ServerHello);
 
-                               // If the negotiated cipher is a KeyEx cipher or
-                               // the client certificate is required send the CertificateRequest message
-                               if (this.context.Cipher.ExchangeAlgorithmType == ExchangeAlgorithmType.RsaKeyX ||
-                                       this.context.ClientCertificateRequired)
-                               {
-                                       this.protocol.SendRecord(HandshakeType.CertificateRequest);
-                               }
+                       // Send ServerCertificate message
+                       this.protocol.SendRecord(HandshakeType.Certificate);
+
+                       // If the negotiated cipher is a KeyEx cipher send ServerKeyExchange
+                       if (this.context.Negotiating.Cipher.IsExportable)
+                       {
+                               this.protocol.SendRecord(HandshakeType.ServerKeyExchange);
+                       }
+
+                       // If the negotiated cipher is a KeyEx cipher or
+                       // the client certificate is required send the CertificateRequest message
+                       if (this.context.Negotiating.Cipher.IsExportable ||
+                               ((ServerContext)this.context).ClientCertificateRequired ||
+                               ((ServerContext)this.context).RequestClientCertificate)
+                       {
+                               this.protocol.SendRecord(HandshakeType.CertificateRequest);
+                       }
 
-                               // Send ServerHelloDone message
-                               this.protocol.SendRecord(HandshakeType.ServerHelloDone);
+                       // Send ServerHelloDone message
+                       this.protocol.SendRecord(HandshakeType.ServerHelloDone);
 
-                               // Receive client response, until the Client Finished message
-                               // is received
-                               while (this.context.LastHandshakeMsg != HandshakeType.Finished)
+                       // Receive client response, until the Client Finished message
+                       // is received. IE can be interrupted at this stage and never
+                       // complete the handshake
+                       while (this.context.LastHandshakeMsg != HandshakeType.Finished)
+                       {
+                               byte[] record = this.protocol.ReceiveRecord(this.innerStream);
+                               if ((record == null) || (record.Length == 0))
                                {
-                                       this.protocol.ReceiveRecord (this.innerStream);
+                                       throw new TlsException(
+                                               AlertDescription.HandshakeFailiure,
+                                               "The client stopped the handshake.");
                                }
-                               
-                               // Send ChangeCipherSpec and ServerFinished messages
-                               this.protocol.SendChangeCipherSpec();
+                       }
 
-                               // The handshake is finished
-                               this.context.HandshakeState = HandshakeState.Finished;
+                       // Send ChangeCipherSpec and ServerFinished messages
+                       this.protocol.SendChangeCipherSpec();
+                       this.protocol.SendRecord (HandshakeType.Finished);
 
-                               // Clear Key Info
-                               this.context.ClearKeyInfo();
-                       }
-                       catch (TlsException ex)
-                       {
-                               this.protocol.SendAlert(ex.Alert);
-                               this.Close();
+                       // The handshake is finished
+                       this.context.HandshakeState = HandshakeState.Finished;
 
-                               throw new IOException("The authentication or decryption has failed.");
-                       }
-                       catch (Exception)
-                       {
-                               this.protocol.SendAlert(AlertDescription.InternalError);
-                               this.Close();
+                       // Reset Handshake messages information
+                       this.context.HandshakeMessages.Reset ();
 
-                               throw new IOException("The authentication or decryption has failed.");
-                       }
+                       // Clear Key Info
+                       this.context.ClearKeyInfo();
                }
 
                #endregion
 
                #region Event Methods
 
-               internal bool RaiseClientCertificateValidation(
-                       X509Certificate certificate, 
-                       int[]                   certificateErrors)
+               internal override X509Certificate OnLocalCertificateSelection(X509CertificateCollection clientCertificates, X509Certificate serverCertificate, string targetHost, X509CertificateCollection serverRequestedCertificates)
+               {
+                       throw new NotSupportedException();
+               }
+
+               internal override bool OnRemoteCertificateValidation(X509Certificate certificate, int[] errors)
                {
                        if (this.ClientCertValidation != null)
                        {
-                               return this.ClientCertValidation(certificate, certificateErrors);
+                               return this.ClientCertValidation(certificate, errors);
                        }
 
-                       return (certificateErrors != null && certificateErrors.Length == 0);
+                       return (errors != null && errors.Length == 0);
                }
 
-               internal AsymmetricAlgorithm RaisePrivateKeySelection(
+               internal override bool HaveRemoteValidation2Callback {
+                       get { return ClientCertValidation2 != null; }
+               }
+
+               internal override ValidationResult OnRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection)
+               {
+                       CertificateValidationCallback2 cb = ClientCertValidation2;
+                       if (cb != null)
+                               return cb (collection);
+                       return null;
+               }
+
+               internal bool RaiseClientCertificateValidation(
                        X509Certificate certificate, 
-                       string                  targetHost)
+                       int[]                   certificateErrors)
+               {
+                       return base.RaiseRemoteCertificateValidation(certificate, certificateErrors);
+               }
+
+               internal override AsymmetricAlgorithm OnLocalPrivateKeySelection(X509Certificate certificate, string targetHost)
                {
                        if (this.PrivateKeySelection != null)
                        {
@@ -742,6 +327,13 @@ namespace Mono.Security.Protocol.Tls
                        return null;
                }
 
+               internal AsymmetricAlgorithm RaisePrivateKeySelection(
+                       X509Certificate certificate,
+                       string targetHost)
+               {
+                       return base.RaiseLocalPrivateKeySelection(certificate, targetHost);
+               }
+
                #endregion
        }
 }