TLS protocol: add handshake state validation
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / Context.cs
index 52617ddfde8922400dd540700b9e67f7db61395c..3923daf1a8bbea3de03326bc160c4c8f2a157fa8 100644 (file)
@@ -76,7 +76,8 @@ namespace Mono.Security.Protocol.Tls
 
                // Misc
                private bool    abbreviatedHandshake;
-               private bool    connectionEnd;
+               private bool    receivedConnectionEnd;
+               private bool    sentConnectionEnd;
                private bool    protocolNegotiated;
                
                // Sequence numbers
@@ -121,6 +122,8 @@ namespace Mono.Security.Protocol.Tls
                        set { this.protocolNegotiated = value; }
                }
 
+               public bool ChangeCipherSpecDone { get; set; }
+
                public SecurityProtocolType SecurityProtocol
                {
                        get 
@@ -203,10 +206,16 @@ namespace Mono.Security.Protocol.Tls
                        set { this.handshakeState = value; }
                }
 
-               public bool ConnectionEnd
+               public bool ReceivedConnectionEnd
+               {
+                       get { return this.receivedConnectionEnd; }
+                       set { this.receivedConnectionEnd = value; }
+               }
+
+               public bool SentConnectionEnd
                {
-                       get { return this.connectionEnd; }
-                       set { this.connectionEnd = value; }
+                       get { return this.sentConnectionEnd; }
+                       set { this.sentConnectionEnd = value; }
                }
 
                public CipherSuiteCollection SupportedCiphers
@@ -316,7 +325,7 @@ namespace Mono.Security.Protocol.Tls
                {
                        DateTime now = DateTime.UtcNow;
                                                                                                                                                     
-                       return (int)(now.Ticks - UNIX_BASE_TICKS / TimeSpan.TicksPerSecond);
+                       return (int)((now.Ticks - UNIX_BASE_TICKS) / TimeSpan.TicksPerSecond);
                }
 
                public byte[] GetSecureRandomBytes(int count)
@@ -343,21 +352,48 @@ namespace Mono.Security.Protocol.Tls
                public virtual void ClearKeyInfo()
                {
                        // Clear Master Secret
-                       this.masterSecret       = null;
+                       if (masterSecret != null) {
+                               Array.Clear (masterSecret, 0, masterSecret.Length);
+                               masterSecret = null;
+                       }
 
                        // Clear client and server random
-                       this.clientRandom       = null;
-                       this.serverRandom       = null;
-                       this.randomCS           = null;
-                       this.randomSC           = null;
+                       if (clientRandom != null) {
+                               Array.Clear (clientRandom, 0, clientRandom.Length);
+                               clientRandom = null;
+                       }
+                       if (serverRandom != null) {
+                               Array.Clear (serverRandom, 0, serverRandom.Length);
+                               serverRandom = null;
+                       }
+                       if (randomCS != null) {
+                               Array.Clear (randomCS, 0, randomCS.Length);
+                               randomCS = null;
+                       }
+                       if (randomSC != null) {
+                               Array.Clear (randomSC, 0, randomSC.Length);
+                               randomSC = null;
+                       }
 
                        // Clear client keys
-                       this.clientWriteKey     = null;
-                       this.clientWriteIV      = null;
+                       if (clientWriteKey != null) {
+                               Array.Clear (clientWriteKey, 0, clientWriteKey.Length);
+                               clientWriteKey = null;
+                       }
+                       if (clientWriteIV != null) {
+                               Array.Clear (clientWriteIV, 0, clientWriteIV.Length);
+                               clientWriteIV = null;
+                       }
                        
                        // Clear server keys
-                       this.serverWriteKey     = null;
-                       this.serverWriteIV      = null;
+                       if (serverWriteKey != null) {
+                               Array.Clear (serverWriteKey, 0, serverWriteKey.Length);
+                               serverWriteKey = null;
+                       }
+                       if (serverWriteIV != null) {
+                               Array.Clear (serverWriteIV, 0, serverWriteIV.Length);
+                               serverWriteIV = null;
+                       }
 
                        // Reset handshake messages
                        this.handshakeMessages.Reset();
@@ -371,7 +407,7 @@ namespace Mono.Security.Protocol.Tls
                        }
                }
 
-               public SecurityProtocolType DecodeProtocolCode(short code)
+               public SecurityProtocolType DecodeProtocolCode (short code, bool allowFallback = false)
                {
                        switch (code)
                        {
@@ -382,6 +418,10 @@ namespace Mono.Security.Protocol.Tls
                                        return SecurityProtocolType.Ssl3;
 
                                default:
+                                       // if allowed we'll continue using TLS (1.0) even if the other side is capable of using a newer
+                                       // version of the TLS protocol
+                                       if (allowFallback && (code > (short) Context.TLS1_PROTOCOL_CODE))
+                                               return SecurityProtocolType.Tls;
                                        throw new NotSupportedException("Unsupported security protocol type");
                        }
                }
@@ -394,9 +434,7 @@ namespace Mono.Security.Protocol.Tls
                                (this.SecurityProtocolFlags & SecurityProtocolType.Default) == SecurityProtocolType.Default)
                        {
                                this.SecurityProtocol = protocolType;
-                               this.SupportedCiphers.Clear();
-                               this.SupportedCiphers = null;
-                               this.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers(protocolType);
+                               this.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers ((this is ServerContext), protocolType);
                        }
                        else
                        {