Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking...
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / SslStreamBase.cs
index 952623fa050ee9f18dfb372b03ec41ccef179167..5c0032c02f0c5724a29fcf5052a639e39a606cc0 100644 (file)
@@ -1,6 +1,6 @@
 // Transport Security Layer (TLS)
 // Copyright (c) 2003-2004 Carlos Guzman Alvarez
-
+// Copyright (C) 2006-2007 Novell, Inc (http://www.novell.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -33,12 +33,19 @@ using System.Threading;
 
 namespace Mono.Security.Protocol.Tls
 {
-       public abstract class SslStreamBase: Stream, IDisposable
+#if INSIDE_SYSTEM
+       internal
+#else
+       public
+#endif
+       abstract class SslStreamBase: Stream, IDisposable
        {
                private delegate void AsyncHandshakeDelegate(InternalAsyncResult asyncResult, bool fromWrite);
                
                #region Fields
 
+               static ManualResetEvent record_processing = new ManualResetEvent (true);        
+
                internal Stream innerStream;
                internal MemoryStream inputBuffer;
                internal Context context;
@@ -89,7 +96,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                try
                                {
-                                       this.OnNegotiateHandshakeCallback(asyncResult);
+                                       this.EndNegotiateHandshake(asyncResult);
                                }
                                catch (TlsException ex)
                                {
@@ -172,8 +179,8 @@ namespace Mono.Security.Protocol.Tls
 
                #region Abstracts/Virtuals
 
-               internal abstract IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state);
-               internal abstract void OnNegotiateHandshakeCallback(IAsyncResult asyncResult);
+               internal abstract IAsyncResult BeginNegotiateHandshake (AsyncCallback callback, object state);
+               internal abstract void EndNegotiateHandshake (IAsyncResult result);
 
                internal abstract X509Certificate OnLocalCertificateSelection(X509CertificateCollection clientCertificates,
                                                                                                                        X509Certificate serverCertificate,
@@ -181,6 +188,8 @@ namespace Mono.Security.Protocol.Tls
                                                                                                                        X509CertificateCollection serverRequestedCertificates);
 
                internal abstract bool OnRemoteCertificateValidation(X509Certificate certificate, int[] errors);
+               internal abstract ValidationResult OnRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection);
+               internal abstract bool HaveRemoteValidation2Callback { get; }
 
                internal abstract AsymmetricAlgorithm OnLocalPrivateKeySelection(X509Certificate certificate, string targetHost);
 
@@ -201,6 +210,11 @@ namespace Mono.Security.Protocol.Tls
                        return OnRemoteCertificateValidation(certificate, errors);
                }
 
+               internal ValidationResult RaiseRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection)
+               {
+                       return OnRemoteCertificateValidation2 (collection);
+               }
+
                internal AsymmetricAlgorithm RaiseLocalPrivateKeySelection(
                        X509Certificate certificate,
                        string targetHost)
@@ -223,7 +237,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.CipherAlgorithmType;
+                                       return this.context.Current.Cipher.CipherAlgorithmType;
                                }
 
                                return CipherAlgorithmType.None;
@@ -236,7 +250,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.EffectiveKeyBits;
+                                       return this.context.Current.Cipher.EffectiveKeyBits;
                                }
 
                                return 0;
@@ -249,7 +263,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.HashAlgorithmType;
+                                       return this.context.Current.Cipher.HashAlgorithmType;
                                }
 
                                return HashAlgorithmType.None;
@@ -262,7 +276,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.HashSize * 8;
+                                       return this.context.Current.Cipher.HashSize * 8;
                                }
 
                                return 0;
@@ -288,7 +302,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.ExchangeAlgorithmType;
+                                       return this.context.Current.Cipher.ExchangeAlgorithmType;
                                }
 
                                return ExchangeAlgorithmType.None;
@@ -442,10 +456,10 @@ namespace Mono.Security.Protocol.Tls
                                                return;
 
                                        completed = true;
-                                       if (handle != null)
-                                               handle.Set ();
                                        _asyncException = ex;
                                        _bytesRead = bytesRead;
+                                       if (handle != null)
+                                               handle.Set ();
                                }
                                if (_userCallback != null)
                                        _userCallback.BeginInvoke (this, null, null);
@@ -478,7 +492,7 @@ namespace Mono.Security.Protocol.Tls
                                {
                                        if (this.context.HandshakeState == HandshakeState.None)
                                        {
-                                               this.OnBeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
+                                               this.BeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
 
                                                return true;
                                        }
@@ -567,6 +581,9 @@ namespace Mono.Security.Protocol.Tls
                        return asyncResult;
                }
 
+               // bigger than max record length for SSL/TLS
+               private byte[] recbuf = new byte[16384];
+
                private void InternalBeginRead(InternalAsyncResult asyncResult)
                {
                        try
@@ -598,11 +615,8 @@ namespace Mono.Security.Protocol.Tls
                                {
                                        asyncResult.SetComplete(preReadSize);
                                }
-                               else if (!this.context.ConnectionEnd)
+                               else if (!this.context.ReceivedConnectionEnd)
                                {
-                                       // bigger than max record length for SSL/TLS
-                                       byte[] recbuf = new byte[16384];
-
                                        // this will read data from the network until we have (at least) one
                                        // record to send back to the caller
                                        this.innerStream.BeginRead(recbuf, 0, recbuf.Length,
@@ -720,11 +734,15 @@ namespace Mono.Security.Protocol.Tls
 
                                if (!dataToReturn && (n > 0))
                                {
-                                       // there is no record to return to caller and (possibly) more data waiting
-                                       // so continue reading from network (and appending to stream)
-                                       recordStream.Position = recordStream.Length;
-                                       this.innerStream.BeginRead(recbuf, 0, recbuf.Length,
-                                               new AsyncCallback(InternalReadCallback), state);
+                                       if (context.ReceivedConnectionEnd) {
+                                               internalResult.SetComplete (0);
+                                       } else {
+                                               // there is no record to return to caller and (possibly) more data waiting
+                                               // so continue reading from network (and appending to stream)
+                                               recordStream.Position = recordStream.Length;
+                                               this.innerStream.BeginRead(recbuf, 0, recbuf.Length,
+                                                       new AsyncCallback(InternalReadCallback), state);
+                                       }
                                }
                                else
                                {
@@ -858,8 +876,11 @@ namespace Mono.Security.Protocol.Tls
                        }
 
                        // Always wait until the read is complete
-                       if (asyncResult.IsCompleted == false)
-                               asyncResult.AsyncWaitHandle.WaitOne();
+                       if (!asyncResult.IsCompleted)
+                       {
+                               if (!asyncResult.AsyncWaitHandle.WaitOne ())
+                                       throw new TlsException (AlertDescription.InternalError, "Couldn't complete EndRead");
+                       }
 
                        if (internalResult.CompletedWithError)
                        {
@@ -880,8 +901,11 @@ namespace Mono.Security.Protocol.Tls
                        }
 
 
-                       if (asyncResult.IsCompleted == false)
-                               internalResult.AsyncWaitHandle.WaitOne();
+                       if (!asyncResult.IsCompleted)
+                       {
+                               if (!internalResult.AsyncWaitHandle.WaitOne ())
+                                       throw new TlsException (AlertDescription.InternalError, "Couldn't complete EndWrite");
+                       }
 
                        if (internalResult.CompletedWithError)
                        {
@@ -891,7 +915,7 @@ namespace Mono.Security.Protocol.Tls
 
                public override void Close()
                {
-                       ((IDisposable)this).Dispose();
+                       base.Close ();
                }
 
                public override void Flush()
@@ -908,9 +932,140 @@ namespace Mono.Security.Protocol.Tls
 
                public override int Read(byte[] buffer, int offset, int count)
                {
-                       IAsyncResult res = this.BeginRead(buffer, offset, count, null, null);
+                       this.checkDisposed ();
+                       
+                       if (buffer == null)
+                       {
+                               throw new ArgumentNullException ("buffer");
+                       }
+                       if (offset < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is less than 0.");
+                       }
+                       if (offset > buffer.Length)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is greater than the length of buffer.");
+                       }
+                       if (count < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than 0.");
+                       }
+                       if (count > (buffer.Length - offset))
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than the length of buffer minus the value of the offset parameter.");
+                       }
+
+                       if (this.context.HandshakeState != HandshakeState.Finished)
+                       {
+                               this.NegotiateHandshake (); // Handshake negotiation
+                       }
+
+                       lock (this.read) {
+                               try {
+                                       record_processing.Reset ();
+                                       // do we already have some decrypted data ?
+                                       if (this.inputBuffer.Position > 0) {
+                                               // or maybe we used all the buffer before ?
+                                               if (this.inputBuffer.Position == this.inputBuffer.Length) {
+                                                       this.inputBuffer.SetLength (0);
+                                               } else {
+                                                       int n = this.inputBuffer.Read (buffer, offset, count);
+                                                       if (n > 0) {
+                                                               record_processing.Set ();
+                                                               return n;
+                                                       }
+                                               }
+                                       }
+
+                                       bool needMoreData = false;
+                                       while (true) {
+                                               // we first try to process the read with the data we already have
+                                               if ((recordStream.Position == 0) || needMoreData) {
+                                                       needMoreData = false;
+                                                       // if we loop, then it either means we need more data
+                                                       byte[] recbuf = new byte[16384];
+                                                       int n = 0;
+                                                       if (count == 1) {
+                                                               int value = innerStream.ReadByte ();
+                                                               if (value >= 0) {
+                                                                       recbuf[0] = (byte) value;
+                                                                       n = 1;
+                                                               }
+                                                       } else {
+                                                               n = innerStream.Read (recbuf, 0, recbuf.Length);
+                                                       }
+                                                       if (n > 0) {
+                                                               // Add the new received data to the waiting data
+                                                               if ((recordStream.Length > 0) && (recordStream.Position != recordStream.Length))
+                                                                       recordStream.Seek (0, SeekOrigin.End);
+                                                               recordStream.Write (recbuf, 0, n);
+                                                       } else {
+                                                               // or that the read operation is done (lost connection in the case of a network stream).
+                                                               record_processing.Set ();
+                                                               return 0;
+                                                       }
+                                               }
+
+                                               bool dataToReturn = false;
+
+                                               recordStream.Position = 0;
+                                               byte[] record = null;
+
+                                               // don't try to decode record unless we have at least 5 bytes
+                                               // i.e. type (1), protocol (2) and length (2)
+                                               if (recordStream.Length >= 5) {
+                                                       record = this.protocol.ReceiveRecord (recordStream);
+                                                       needMoreData = (record == null);
+                                               }
 
-                       return this.EndRead(res);
+                                               // a record of 0 length is valid (and there may be more record after it)
+                                               while (record != null) {
+                                                       // we probably received more stuff after the record, and we must keep it!
+                                                       long remainder = recordStream.Length - recordStream.Position;
+                                                       byte[] outofrecord = null;
+                                                       if (remainder > 0) {
+                                                               outofrecord = new byte[remainder];
+                                                               recordStream.Read (outofrecord, 0, outofrecord.Length);
+                                                       }
+
+                                                       long position = this.inputBuffer.Position;
+
+                                                       if (record.Length > 0) {
+                                                               // Write new data to the inputBuffer
+                                                               this.inputBuffer.Seek (0, SeekOrigin.End);
+                                                               this.inputBuffer.Write (record, 0, record.Length);
+
+                                                               // Restore buffer position
+                                                               this.inputBuffer.Seek (position, SeekOrigin.Begin);
+                                                               dataToReturn = true;
+                                                       }
+
+                                                       recordStream.SetLength (0);
+                                                       record = null;
+
+                                                       if (remainder > 0) {
+                                                               recordStream.Write (outofrecord, 0, outofrecord.Length);
+                                                       }
+
+                                                       if (dataToReturn) {
+                                                               // we have record(s) to return -or- no more available to read from network
+                                                               // reset position for further reading
+                                                               int i = inputBuffer.Read (buffer, offset, count);
+                                                               record_processing.Set ();
+                                                               return i;
+                                                       }
+                                               }
+                                       }
+                               }
+                               catch (TlsException ex)
+                               {
+                                       throw new IOException("The authentication or decryption has failed.", ex);
+                               }
+                               catch (Exception ex)
+                               {
+                                       throw new IOException("IO exception during read.", ex);
+                               }
+                       }
                }
 
                public override long Seek(long offset, SeekOrigin origin)
@@ -930,9 +1085,53 @@ namespace Mono.Security.Protocol.Tls
 
                public override void Write(byte[] buffer, int offset, int count)
                {
-                       IAsyncResult res = this.BeginWrite(buffer, offset, count, null, null);
+                       this.checkDisposed ();
+                       
+                       if (buffer == null)
+                       {
+                               throw new ArgumentNullException ("buffer");
+                       }
+                       if (offset < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is less than 0.");
+                       }
+                       if (offset > buffer.Length)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is greater than the length of buffer.");
+                       }
+                       if (count < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than 0.");
+                       }
+                       if (count > (buffer.Length - offset))
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than the length of buffer minus the value of the offset parameter.");
+                       }
 
-                       this.EndWrite(res);
+                       if (this.context.HandshakeState != HandshakeState.Finished)
+                       {
+                               this.NegotiateHandshake ();
+                       }
+
+                       lock (this.write)
+                       {
+                               try
+                               {
+                                       // Send the buffer as a TLS record
+                                       byte[] record = this.protocol.EncodeRecord (ContentType.ApplicationData, buffer, offset, count);
+                                       this.innerStream.Write (record, 0, record.Length);
+                               }
+                               catch (TlsException ex)
+                               {
+                                       this.protocol.SendAlert(ex.Alert);
+                                       this.Close();
+                                       throw new IOException("The authentication or decryption has failed.", ex);
+                               }
+                               catch (Exception ex)
+                               {
+                                       throw new IOException("IO exception during Write.", ex);
+                               }
+                       }
                }
 
                public override bool CanRead
@@ -975,13 +1174,7 @@ namespace Mono.Security.Protocol.Tls
                        this.Dispose(false);
                }
 
-               public void Dispose()
-               {
-                       this.Dispose(true);
-                       GC.SuppressFinalize(this);
-               }
-
-               protected virtual void Dispose(bool disposing)
+               protected override void Dispose (bool disposing)
                {
                        if (!this.disposed)
                        {
@@ -990,10 +1183,13 @@ namespace Mono.Security.Protocol.Tls
                                        if (this.innerStream != null)
                                        {
                                                if (this.context.HandshakeState == HandshakeState.Finished &&
-                                                       !this.context.ConnectionEnd)
+                                                       !this.context.SentConnectionEnd)
                                                {
-                                                       // Write close notify                                                   
-                                                       this.protocol.SendAlert(AlertDescription.CloseNotify);
+                                                       // Write close notify
+                                                       try {
+                                                               this.protocol.SendAlert(AlertDescription.CloseNotify);
+                                                       } catch {
+                                                       }
                                                }
 
                                                if (this.ownsStream)
@@ -1007,6 +1203,7 @@ namespace Mono.Security.Protocol.Tls
                                }
 
                                this.disposed = true;
+                               base.Dispose (disposing);
                        }
                }