Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking...
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / SslStreamBase.cs
index 0e4802669d5a5e793cf62b7c1f860b91fb65c5a8..5c0032c02f0c5724a29fcf5052a639e39a606cc0 100644 (file)
@@ -1,6 +1,6 @@
 // Transport Security Layer (TLS)
 // Copyright (c) 2003-2004 Carlos Guzman Alvarez
-// Copyright (C) 2006 Novell, Inc (http://www.novell.com)
+// Copyright (C) 2006-2007 Novell, Inc (http://www.novell.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -33,13 +33,18 @@ using System.Threading;
 
 namespace Mono.Security.Protocol.Tls
 {
-       public abstract class SslStreamBase: Stream, IDisposable
+#if INSIDE_SYSTEM
+       internal
+#else
+       public
+#endif
+       abstract class SslStreamBase: Stream, IDisposable
        {
                private delegate void AsyncHandshakeDelegate(InternalAsyncResult asyncResult, bool fromWrite);
                
                #region Fields
 
-               private const int WaitTimeOut = 5 * 60 * 1000;
+               static ManualResetEvent record_processing = new ManualResetEvent (true);        
 
                internal Stream innerStream;
                internal MemoryStream inputBuffer;
@@ -91,7 +96,7 @@ namespace Mono.Security.Protocol.Tls
                        {
                                try
                                {
-                                       this.OnNegotiateHandshakeCallback(asyncResult);
+                                       this.EndNegotiateHandshake(asyncResult);
                                }
                                catch (TlsException ex)
                                {
@@ -174,8 +179,8 @@ namespace Mono.Security.Protocol.Tls
 
                #region Abstracts/Virtuals
 
-               internal abstract IAsyncResult OnBeginNegotiateHandshake(AsyncCallback callback, object state);
-               internal abstract void OnNegotiateHandshakeCallback(IAsyncResult asyncResult);
+               internal abstract IAsyncResult BeginNegotiateHandshake (AsyncCallback callback, object state);
+               internal abstract void EndNegotiateHandshake (IAsyncResult result);
 
                internal abstract X509Certificate OnLocalCertificateSelection(X509CertificateCollection clientCertificates,
                                                                                                                        X509Certificate serverCertificate,
@@ -183,6 +188,8 @@ namespace Mono.Security.Protocol.Tls
                                                                                                                        X509CertificateCollection serverRequestedCertificates);
 
                internal abstract bool OnRemoteCertificateValidation(X509Certificate certificate, int[] errors);
+               internal abstract ValidationResult OnRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection);
+               internal abstract bool HaveRemoteValidation2Callback { get; }
 
                internal abstract AsymmetricAlgorithm OnLocalPrivateKeySelection(X509Certificate certificate, string targetHost);
 
@@ -203,6 +210,11 @@ namespace Mono.Security.Protocol.Tls
                        return OnRemoteCertificateValidation(certificate, errors);
                }
 
+               internal ValidationResult RaiseRemoteCertificateValidation2 (Mono.Security.X509.X509CertificateCollection collection)
+               {
+                       return OnRemoteCertificateValidation2 (collection);
+               }
+
                internal AsymmetricAlgorithm RaiseLocalPrivateKeySelection(
                        X509Certificate certificate,
                        string targetHost)
@@ -480,7 +492,7 @@ namespace Mono.Security.Protocol.Tls
                                {
                                        if (this.context.HandshakeState == HandshakeState.None)
                                        {
-                                               this.OnBeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
+                                               this.BeginNegotiateHandshake(new AsyncCallback(AsyncHandshakeCallback), asyncResult);
 
                                                return true;
                                        }
@@ -603,7 +615,7 @@ namespace Mono.Security.Protocol.Tls
                                {
                                        asyncResult.SetComplete(preReadSize);
                                }
-                               else if (!this.context.ConnectionEnd)
+                               else if (!this.context.ReceivedConnectionEnd)
                                {
                                        // this will read data from the network until we have (at least) one
                                        // record to send back to the caller
@@ -722,11 +734,15 @@ namespace Mono.Security.Protocol.Tls
 
                                if (!dataToReturn && (n > 0))
                                {
-                                       // there is no record to return to caller and (possibly) more data waiting
-                                       // so continue reading from network (and appending to stream)
-                                       recordStream.Position = recordStream.Length;
-                                       this.innerStream.BeginRead(recbuf, 0, recbuf.Length,
-                                               new AsyncCallback(InternalReadCallback), state);
+                                       if (context.ReceivedConnectionEnd) {
+                                               internalResult.SetComplete (0);
+                                       } else {
+                                               // there is no record to return to caller and (possibly) more data waiting
+                                               // so continue reading from network (and appending to stream)
+                                               recordStream.Position = recordStream.Length;
+                                               this.innerStream.BeginRead(recbuf, 0, recbuf.Length,
+                                                       new AsyncCallback(InternalReadCallback), state);
+                                       }
                                }
                                else
                                {
@@ -862,7 +878,7 @@ namespace Mono.Security.Protocol.Tls
                        // Always wait until the read is complete
                        if (!asyncResult.IsCompleted)
                        {
-                               if (!asyncResult.AsyncWaitHandle.WaitOne (WaitTimeOut, false))
+                               if (!asyncResult.AsyncWaitHandle.WaitOne ())
                                        throw new TlsException (AlertDescription.InternalError, "Couldn't complete EndRead");
                        }
 
@@ -887,7 +903,7 @@ namespace Mono.Security.Protocol.Tls
 
                        if (!asyncResult.IsCompleted)
                        {
-                               if (!internalResult.AsyncWaitHandle.WaitOne (WaitTimeOut, false))
+                               if (!internalResult.AsyncWaitHandle.WaitOne ())
                                        throw new TlsException (AlertDescription.InternalError, "Couldn't complete EndWrite");
                        }
 
@@ -899,7 +915,7 @@ namespace Mono.Security.Protocol.Tls
 
                public override void Close()
                {
-                       ((IDisposable)this).Dispose();
+                       base.Close ();
                }
 
                public override void Flush()
@@ -946,6 +962,7 @@ namespace Mono.Security.Protocol.Tls
 
                        lock (this.read) {
                                try {
+                                       record_processing.Reset ();
                                        // do we already have some decrypted data ?
                                        if (this.inputBuffer.Position > 0) {
                                                // or maybe we used all the buffer before ?
@@ -953,8 +970,10 @@ namespace Mono.Security.Protocol.Tls
                                                        this.inputBuffer.SetLength (0);
                                                } else {
                                                        int n = this.inputBuffer.Read (buffer, offset, count);
-                                                       if (n > 0)
+                                                       if (n > 0) {
+                                                               record_processing.Set ();
                                                                return n;
+                                                       }
                                                }
                                        }
 
@@ -965,7 +984,16 @@ namespace Mono.Security.Protocol.Tls
                                                        needMoreData = false;
                                                        // if we loop, then it either means we need more data
                                                        byte[] recbuf = new byte[16384];
-                                                       int n = innerStream.Read (recbuf, 0, recbuf.Length);
+                                                       int n = 0;
+                                                       if (count == 1) {
+                                                               int value = innerStream.ReadByte ();
+                                                               if (value >= 0) {
+                                                                       recbuf[0] = (byte) value;
+                                                                       n = 1;
+                                                               }
+                                                       } else {
+                                                               n = innerStream.Read (recbuf, 0, recbuf.Length);
+                                                       }
                                                        if (n > 0) {
                                                                // Add the new received data to the waiting data
                                                                if ((recordStream.Length > 0) && (recordStream.Position != recordStream.Length))
@@ -973,6 +1001,7 @@ namespace Mono.Security.Protocol.Tls
                                                                recordStream.Write (recbuf, 0, n);
                                                        } else {
                                                                // or that the read operation is done (lost connection in the case of a network stream).
+                                                               record_processing.Set ();
                                                                return 0;
                                                        }
                                                }
@@ -1021,7 +1050,9 @@ namespace Mono.Security.Protocol.Tls
                                                        if (dataToReturn) {
                                                                // we have record(s) to return -or- no more available to read from network
                                                                // reset position for further reading
-                                                               return this.inputBuffer.Read (buffer, offset, count);
+                                                               int i = inputBuffer.Read (buffer, offset, count);
+                                                               record_processing.Set ();
+                                                               return i;
                                                        }
                                                }
                                        }
@@ -1143,13 +1174,7 @@ namespace Mono.Security.Protocol.Tls
                        this.Dispose(false);
                }
 
-               public void Dispose()
-               {
-                       this.Dispose(true);
-                       GC.SuppressFinalize(this);
-               }
-
-               protected virtual void Dispose(bool disposing)
+               protected override void Dispose (bool disposing)
                {
                        if (!this.disposed)
                        {
@@ -1158,10 +1183,13 @@ namespace Mono.Security.Protocol.Tls
                                        if (this.innerStream != null)
                                        {
                                                if (this.context.HandshakeState == HandshakeState.Finished &&
-                                                       !this.context.ConnectionEnd)
+                                                       !this.context.SentConnectionEnd)
                                                {
-                                                       // Write close notify                                                   
-                                                       this.protocol.SendAlert(AlertDescription.CloseNotify);
+                                                       // Write close notify
+                                                       try {
+                                                               this.protocol.SendAlert(AlertDescription.CloseNotify);
+                                                       } catch {
+                                                       }
                                                }
 
                                                if (this.ownsStream)
@@ -1175,6 +1203,7 @@ namespace Mono.Security.Protocol.Tls
                                }
 
                                this.disposed = true;
+                               base.Dispose (disposing);
                        }
                }