Moved chain building and validation from Mono.Security to System
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / SslStreamBase.cs
old mode 100755 (executable)
new mode 100644 (file)
index 9b6df73..f347ccb
@@ -1,6 +1,6 @@
 // Transport Security Layer (TLS)
 // Copyright (c) 2003-2004 Carlos Guzman Alvarez
-
+// Copyright (C) 2006-2007 Novell, Inc (http://www.novell.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -39,6 +39,10 @@ namespace Mono.Security.Protocol.Tls
                
                #region Fields
 
+               static ManualResetEvent record_processing = new ManualResetEvent (true);
+
+               private const int WaitTimeOut = 5 * 60 * 1000;
+
                internal Stream innerStream;
                internal MemoryStream inputBuffer;
                internal Context context;
@@ -104,8 +108,6 @@ namespace Mono.Security.Protocol.Tls
                                        throw new IOException("The authentication or decryption has failed.", ex);
                                }
 
-                               negotiationComplete.Set();
-
                                if (internalResult.ProceedAfterHandshake)
                                {
                                        //kick off the read or write process (whichever called us) after the handshake is complete
@@ -117,6 +119,7 @@ namespace Mono.Security.Protocol.Tls
                                        {
                                                InternalBeginRead(internalResult);
                                        }
+                                       negotiationComplete.Set();
                                }
                                else
                                {
@@ -182,6 +185,8 @@ namespace Mono.Security.Protocol.Tls
                                                                                                                        X509CertificateCollection serverRequestedCertificates);
 
                internal abstract bool OnRemoteCertificateValidation(X509Certificate certificate, int[] errors);
+               internal abstract ValidationResult OnRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection);
+               internal abstract bool HaveRemoteValidation2Callback { get; }
 
                internal abstract AsymmetricAlgorithm OnLocalPrivateKeySelection(X509Certificate certificate, string targetHost);
 
@@ -202,6 +207,11 @@ namespace Mono.Security.Protocol.Tls
                        return OnRemoteCertificateValidation(certificate, errors);
                }
 
+               internal ValidationResult RaiseRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection)
+               {
+                       return OnRemoteCertificateValidation2 (collection);
+               }
+
                internal AsymmetricAlgorithm RaiseLocalPrivateKeySelection(
                        X509Certificate certificate,
                        string targetHost)
@@ -224,7 +234,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.CipherAlgorithmType;
+                                       return this.context.Current.Cipher.CipherAlgorithmType;
                                }
 
                                return CipherAlgorithmType.None;
@@ -237,7 +247,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.EffectiveKeyBits;
+                                       return this.context.Current.Cipher.EffectiveKeyBits;
                                }
 
                                return 0;
@@ -250,7 +260,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.HashAlgorithmType;
+                                       return this.context.Current.Cipher.HashAlgorithmType;
                                }
 
                                return HashAlgorithmType.None;
@@ -263,7 +273,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.HashSize * 8;
+                                       return this.context.Current.Cipher.HashSize * 8;
                                }
 
                                return 0;
@@ -289,7 +299,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                if (this.context.HandshakeState == HandshakeState.Finished)
                                {
-                                       return this.context.Cipher.ExchangeAlgorithmType;
+                                       return this.context.Current.Cipher.ExchangeAlgorithmType;
                                }
 
                                return ExchangeAlgorithmType.None;
@@ -338,10 +348,12 @@ namespace Mono.Security.Protocol.Tls
 
                private class InternalAsyncResult : IAsyncResult
                {
+                       private object locker = new object ();
                        private AsyncCallback _userCallback;
                        private object _userState;
                        private Exception _asyncException;
-                       private ManualResetEvent _complete;
+                       private ManualResetEvent handle;
+                       private bool completed;
                        private int _bytesRead;
                        private bool _fromWrite;
                        private bool _proceedAfterHandshake;
@@ -354,7 +366,6 @@ namespace Mono.Security.Protocol.Tls
                        {
                                _userCallback = userCallback;
                                _userState = userState;
-                               _complete = new ManualResetEvent(false);
                                _buffer = buffer;
                                _offset = offset;
                                _count = count;
@@ -404,12 +415,22 @@ namespace Mono.Security.Protocol.Tls
 
                        public bool CompletedWithError
                        {
-                               get { return null != _asyncException; }
+                               get {
+                                       if (IsCompleted == false)
+                                               return false;
+                                       return null != _asyncException;
+                               }
                        }
 
                        public WaitHandle AsyncWaitHandle
                        {
-                               get { return _complete; }
+                               get {
+                                       lock (locker) {
+                                               if (handle == null)
+                                                       handle = new ManualResetEvent (completed);
+                                       }
+                                       return handle;
+                               }
                        }
 
                        public bool CompletedSynchronously
@@ -419,27 +440,26 @@ namespace Mono.Security.Protocol.Tls
 
                        public bool IsCompleted
                        {
-                               get { return _complete.WaitOne(0, false); }
+                               get {
+                                       lock (locker)
+                                               return completed;
+                               }
                        }
 
                        private void SetComplete(Exception ex, int bytesRead)
                        {
-                               if (this.IsCompleted)
-                                       return;
-
-                               lock (this)
-                               {
-                                       if (this.IsCompleted)
+                               lock (locker) {
+                                       if (completed)
                                                return;
 
+                                       completed = true;
                                        _asyncException = ex;
                                        _bytesRead = bytesRead;
-                                       _complete.Set();
+                                       if (handle != null)
+                                               handle.Set ();
                                }
-
-                               if (null != _userCallback)
-                                       _userCallback (this);
-
+                               if (_userCallback != null)
+                                       _userCallback.BeginInvoke (this, null, null);
                        }
 
                        public void SetComplete(Exception ex)
@@ -497,7 +517,8 @@ namespace Mono.Security.Protocol.Tls
 
                private void EndNegotiateHandshake(InternalAsyncResult asyncResult)
                {
-                       asyncResult.AsyncWaitHandle.WaitOne();
+                       if (asyncResult.IsCompleted == false)
+                               asyncResult.AsyncWaitHandle.WaitOne();
 
                        if (asyncResult.CompletedWithError)
                        {
@@ -557,6 +578,9 @@ namespace Mono.Security.Protocol.Tls
                        return asyncResult;
                }
 
+               // bigger than max record length for SSL/TLS
+               private byte[] recbuf = new byte[16384];
+
                private void InternalBeginRead(InternalAsyncResult asyncResult)
                {
                        try
@@ -590,9 +614,6 @@ namespace Mono.Security.Protocol.Tls
                                }
                                else if (!this.context.ConnectionEnd)
                                {
-                                       // bigger than max record length for SSL/TLS
-                                       byte[] recbuf = new byte[16384];
-
                                        // this will read data from the network until we have (at least) one
                                        // record to send back to the caller
                                        this.innerStream.BeginRead(recbuf, 0, recbuf.Length,
@@ -770,13 +791,13 @@ namespace Mono.Security.Protocol.Tls
                {
                        if (this.disposed)
                                return;
-
+                       
                        InternalAsyncResult internalResult = (InternalAsyncResult)ar.AsyncState;
 
                        try
                        {
                                this.innerStream.EndWrite(ar);
-                               internalResult.SetComplete(0);
+                               internalResult.SetComplete();
                        }
                        catch (Exception ex)
                        {
@@ -842,15 +863,18 @@ namespace Mono.Security.Protocol.Tls
                        this.checkDisposed();
 
                        InternalAsyncResult internalResult = asyncResult as InternalAsyncResult;
-
-                       // Always wait until the read is complete
-                       internalResult.AsyncWaitHandle.WaitOne();
-
                        if (internalResult == null)
                        {
                                throw new ArgumentNullException("asyncResult is null or was not obtained by calling BeginRead.");
                        }
 
+                       // Always wait until the read is complete
+                       if (!asyncResult.IsCompleted)
+                       {
+                               if (!asyncResult.AsyncWaitHandle.WaitOne (WaitTimeOut, false))
+                                       throw new TlsException (AlertDescription.InternalError, "Couldn't complete EndRead");
+                       }
+
                        if (internalResult.CompletedWithError)
                        {
                                throw internalResult.AsyncException;
@@ -864,12 +888,16 @@ namespace Mono.Security.Protocol.Tls
                        this.checkDisposed();
 
                        InternalAsyncResult internalResult = asyncResult as InternalAsyncResult;
+                       if (internalResult == null)
+                       {
+                               throw new ArgumentNullException("asyncResult is null or was not obtained by calling BeginWrite.");
+                       }
 
-                       internalResult.AsyncWaitHandle.WaitOne();
 
-                       if (asyncResult == null)
+                       if (!asyncResult.IsCompleted)
                        {
-                               throw new ArgumentNullException("asyncResult is null or was not obtained by calling BeginWrite.");
+                               if (!internalResult.AsyncWaitHandle.WaitOne (WaitTimeOut, false))
+                                       throw new TlsException (AlertDescription.InternalError, "Couldn't complete EndWrite");
                        }
 
                        if (internalResult.CompletedWithError)
@@ -880,7 +908,7 @@ namespace Mono.Security.Protocol.Tls
 
                public override void Close()
                {
-                       ((IDisposable)this).Dispose();
+                       base.Close ();
                }
 
                public override void Flush()
@@ -897,9 +925,140 @@ namespace Mono.Security.Protocol.Tls
 
                public override int Read(byte[] buffer, int offset, int count)
                {
-                       IAsyncResult res = this.BeginRead(buffer, offset, count, null, null);
+                       this.checkDisposed ();
+                       
+                       if (buffer == null)
+                       {
+                               throw new ArgumentNullException ("buffer");
+                       }
+                       if (offset < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is less than 0.");
+                       }
+                       if (offset > buffer.Length)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is greater than the length of buffer.");
+                       }
+                       if (count < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than 0.");
+                       }
+                       if (count > (buffer.Length - offset))
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than the length of buffer minus the value of the offset parameter.");
+                       }
+
+                       if (this.context.HandshakeState != HandshakeState.Finished)
+                       {
+                               this.NegotiateHandshake (); // Handshake negotiation
+                       }
+
+                       lock (this.read) {
+                               try {
+                                       record_processing.Reset ();
+                                       // do we already have some decrypted data ?
+                                       if (this.inputBuffer.Position > 0) {
+                                               // or maybe we used all the buffer before ?
+                                               if (this.inputBuffer.Position == this.inputBuffer.Length) {
+                                                       this.inputBuffer.SetLength (0);
+                                               } else {
+                                                       int n = this.inputBuffer.Read (buffer, offset, count);
+                                                       if (n > 0) {
+                                                               record_processing.Set ();
+                                                               return n;
+                                                       }
+                                               }
+                                       }
+
+                                       bool needMoreData = false;
+                                       while (true) {
+                                               // we first try to process the read with the data we already have
+                                               if ((recordStream.Position == 0) || needMoreData) {
+                                                       needMoreData = false;
+                                                       // if we loop, then it either means we need more data
+                                                       byte[] recbuf = new byte[16384];
+                                                       int n = 0;
+                                                       if (count == 1) {
+                                                               int value = innerStream.ReadByte ();
+                                                               if (value >= 0) {
+                                                                       recbuf[0] = (byte) value;
+                                                                       n = 1;
+                                                               }
+                                                       } else {
+                                                               n = innerStream.Read (recbuf, 0, recbuf.Length);
+                                                       }
+                                                       if (n > 0) {
+                                                               // Add the new received data to the waiting data
+                                                               if ((recordStream.Length > 0) && (recordStream.Position != recordStream.Length))
+                                                                       recordStream.Seek (0, SeekOrigin.End);
+                                                               recordStream.Write (recbuf, 0, n);
+                                                       } else {
+                                                               // or that the read operation is done (lost connection in the case of a network stream).
+                                                               record_processing.Set ();
+                                                               return 0;
+                                                       }
+                                               }
+
+                                               bool dataToReturn = false;
+
+                                               recordStream.Position = 0;
+                                               byte[] record = null;
+
+                                               // don't try to decode record unless we have at least 5 bytes
+                                               // i.e. type (1), protocol (2) and length (2)
+                                               if (recordStream.Length >= 5) {
+                                                       record = this.protocol.ReceiveRecord (recordStream);
+                                                       needMoreData = (record == null);
+                                               }
 
-                       return this.EndRead(res);
+                                               // a record of 0 length is valid (and there may be more record after it)
+                                               while (record != null) {
+                                                       // we probably received more stuff after the record, and we must keep it!
+                                                       long remainder = recordStream.Length - recordStream.Position;
+                                                       byte[] outofrecord = null;
+                                                       if (remainder > 0) {
+                                                               outofrecord = new byte[remainder];
+                                                               recordStream.Read (outofrecord, 0, outofrecord.Length);
+                                                       }
+
+                                                       long position = this.inputBuffer.Position;
+
+                                                       if (record.Length > 0) {
+                                                               // Write new data to the inputBuffer
+                                                               this.inputBuffer.Seek (0, SeekOrigin.End);
+                                                               this.inputBuffer.Write (record, 0, record.Length);
+
+                                                               // Restore buffer position
+                                                               this.inputBuffer.Seek (position, SeekOrigin.Begin);
+                                                               dataToReturn = true;
+                                                       }
+
+                                                       recordStream.SetLength (0);
+                                                       record = null;
+
+                                                       if (remainder > 0) {
+                                                               recordStream.Write (outofrecord, 0, outofrecord.Length);
+                                                       }
+
+                                                       if (dataToReturn) {
+                                                               // we have record(s) to return -or- no more available to read from network
+                                                               // reset position for further reading
+                                                               int i = inputBuffer.Read (buffer, offset, count);
+                                                               record_processing.Set ();
+                                                               return i;
+                                                       }
+                                               }
+                                       }
+                               }
+                               catch (TlsException ex)
+                               {
+                                       throw new IOException("The authentication or decryption has failed.", ex);
+                               }
+                               catch (Exception ex)
+                               {
+                                       throw new IOException("IO exception during read.", ex);
+                               }
+                       }
                }
 
                public override long Seek(long offset, SeekOrigin origin)
@@ -919,9 +1078,53 @@ namespace Mono.Security.Protocol.Tls
 
                public override void Write(byte[] buffer, int offset, int count)
                {
-                       IAsyncResult res = this.BeginWrite(buffer, offset, count, null, null);
+                       this.checkDisposed ();
+                       
+                       if (buffer == null)
+                       {
+                               throw new ArgumentNullException ("buffer");
+                       }
+                       if (offset < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is less than 0.");
+                       }
+                       if (offset > buffer.Length)
+                       {
+                               throw new ArgumentOutOfRangeException("offset is greater than the length of buffer.");
+                       }
+                       if (count < 0)
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than 0.");
+                       }
+                       if (count > (buffer.Length - offset))
+                       {
+                               throw new ArgumentOutOfRangeException("count is less than the length of buffer minus the value of the offset parameter.");
+                       }
 
-                       this.EndWrite(res);
+                       if (this.context.HandshakeState != HandshakeState.Finished)
+                       {
+                               this.NegotiateHandshake ();
+                       }
+
+                       lock (this.write)
+                       {
+                               try
+                               {
+                                       // Send the buffer as a TLS record
+                                       byte[] record = this.protocol.EncodeRecord (ContentType.ApplicationData, buffer, offset, count);
+                                       this.innerStream.Write (record, 0, record.Length);
+                               }
+                               catch (TlsException ex)
+                               {
+                                       this.protocol.SendAlert(ex.Alert);
+                                       this.Close();
+                                       throw new IOException("The authentication or decryption has failed.", ex);
+                               }
+                               catch (Exception ex)
+                               {
+                                       throw new IOException("IO exception during Write.", ex);
+                               }
+                       }
                }
 
                public override bool CanRead
@@ -964,13 +1167,7 @@ namespace Mono.Security.Protocol.Tls
                        this.Dispose(false);
                }
 
-               public void Dispose()
-               {
-                       this.Dispose(true);
-                       GC.SuppressFinalize(this);
-               }
-
-               protected virtual void Dispose(bool disposing)
+               protected override void Dispose (bool disposing)
                {
                        if (!this.disposed)
                        {
@@ -996,6 +1193,7 @@ namespace Mono.Security.Protocol.Tls
                                }
 
                                this.disposed = true;
+                               base.Dispose (disposing);
                        }
                }