TLS protocol: add handshake state validation
authorSebastien Pouliot <sebastien@xamarin.com>
Fri, 6 Mar 2015 15:34:14 +0000 (10:34 -0500)
committerMiguel de Icaza <miguel@gnome.org>
Fri, 6 Mar 2015 15:34:14 +0000 (10:34 -0500)
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/ClientRecordProtocol.cs
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/Context.cs
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/RecordProtocol.cs
mcs/class/Mono.Security/Mono.Security.Protocol.Tls/ServerRecordProtocol.cs

index 7cece5060e6fb8785a1779cbdfc3b5edad5a9cd5..acaa0c2c37c28798539dcbe0488c9332f52f71c0 100644 (file)
@@ -129,6 +129,7 @@ namespace Mono.Security.Protocol.Tls
                        HandshakeType type, byte[] buffer)
                {
                        ClientContext context = (ClientContext)this.context;
+                       var last = context.LastHandshakeMsg;
 
                        switch (type)
                        {
@@ -148,23 +149,44 @@ namespace Mono.Security.Protocol.Tls
                                        return null;
 
                                case HandshakeType.ServerHello:
+                                       if (last != HandshakeType.HelloRequest)
+                                               break;
                                        return new TlsServerHello(this.context, buffer);
 
+                                       // Optional
                                case HandshakeType.Certificate:
+                                       if (last != HandshakeType.ServerHello)
+                                               break;
                                        return new TlsServerCertificate(this.context, buffer);
 
+                                       // Optional
                                case HandshakeType.ServerKeyExchange:
-                                       return new TlsServerKeyExchange(this.context, buffer);
+                                       // only for RSA_EXPORT
+                                       if (last == HandshakeType.Certificate && context.Current.Cipher.IsExportable)
+                                               return new TlsServerKeyExchange(this.context, buffer);
+                                       break;
 
+                                       // Optional
                                case HandshakeType.CertificateRequest:
-                                       return new TlsServerCertificateRequest(this.context, buffer);
+                                       if (last == HandshakeType.ServerKeyExchange || last == HandshakeType.Certificate)
+                                               return new TlsServerCertificateRequest(this.context, buffer);
+                                       break;
 
                                case HandshakeType.ServerHelloDone:
-                                       return new TlsServerHelloDone(this.context, buffer);
+                                       if (last == HandshakeType.CertificateRequest || last == HandshakeType.Certificate || last == HandshakeType.ServerHello)
+                                               return new TlsServerHelloDone(this.context, buffer);
+                                       break;
 
                                case HandshakeType.Finished:
-                                       return new TlsServerFinished(this.context, buffer);
-
+                                       // depends if a full (ServerHelloDone) or an abbreviated handshake (ServerHello) is being done
+                                       bool check = context.AbbreviatedHandshake ? (last == HandshakeType.ServerHello) : (last == HandshakeType.ServerHelloDone);
+                                       // ChangeCipherSpecDone is not an handshake message (it's a content type) but still needs to be happens before finished
+                                       if (check && context.ChangeCipherSpecDone) {
+                                               context.ChangeCipherSpecDone = false;
+                                               return new TlsServerFinished (this.context, buffer);
+                                       }
+                                       break;
+                                       
                                default:
                                        throw new TlsException(
                                                AlertDescription.UnexpectedMessage,
@@ -172,6 +194,7 @@ namespace Mono.Security.Protocol.Tls
                                                        "Unknown server handshake message received ({0})", 
                                                        type.ToString()));
                        }
+                       throw new TlsException (AlertDescription.HandshakeFailiure, String.Format ("Protocol error, unexpected protocol transition from {0} to {1}", last, type));
                }
 
                #endregion
index b4caf28b5c7bb0e27e61e7214dcb81ab45e57151..3923daf1a8bbea3de03326bc160c4c8f2a157fa8 100644 (file)
@@ -122,6 +122,8 @@ namespace Mono.Security.Protocol.Tls
                        set { this.protocolNegotiated = value; }
                }
 
+               public bool ChangeCipherSpecDone { get; set; }
+
                public SecurityProtocolType SecurityProtocol
                {
                        get 
index 589510685a63193d51b8308798be26820c61cf76..e8ae131f2a50d4bf11b748b348ca2b37a3b45987 100644 (file)
@@ -88,6 +88,8 @@ namespace Mono.Security.Protocol.Tls
                        } else {
                                ctx.StartSwitchingSecurityParameters (false);
                        }
+
+                       ctx.ChangeCipherSpecDone = true;
                }
 
                public virtual HandshakeMessage GetMessage(HandshakeType type)
@@ -348,9 +350,6 @@ namespace Mono.Security.Protocol.Tls
                                // Try to read the Record Content Type
                                int type = internalResult.InitialBuffer[0];
 
-                               // Set last handshake message received to None
-                               this.context.LastHandshakeMsg = HandshakeType.ClientHello;
-
                                ContentType     contentType     = (ContentType)type;
                                byte[] buffer = this.ReadRecordBuffer(type, record);
                                if (buffer == null)
@@ -458,9 +457,6 @@ namespace Mono.Security.Protocol.Tls
                        // Try to read the Record Content Type
                        int type = recordTypeBuffer[0];
 
-                       // Set last handshake message received to None
-                       this.context.LastHandshakeMsg = HandshakeType.ClientHello;
-
                        ContentType     contentType     = (ContentType)type;
                        byte[] buffer = this.ReadRecordBuffer(type, record);
                        if (buffer == null)
index 6e316dc3659dd1219a9a6dd436d789626c92a512..31c2547902beeb8bc2fba4d4d4b27b27d7e783e8 100644 (file)
@@ -33,6 +33,8 @@ namespace Mono.Security.Protocol.Tls
 {
        internal class ServerRecordProtocol : RecordProtocol
        {
+               TlsClientCertificate cert;
+               
                #region Constructors
 
                public ServerRecordProtocol(
@@ -93,30 +95,45 @@ namespace Mono.Security.Protocol.Tls
                private HandshakeMessage createClientHandshakeMessage(
                        HandshakeType type, byte[] buffer)
                {
+                       var last = context.LastHandshakeMsg;
                        switch (type)
                        {
                                case HandshakeType.ClientHello:
                                        return new TlsClientHello(this.context, buffer);
 
                                case HandshakeType.Certificate:
-                                       return new TlsClientCertificate(this.context, buffer);
+                                       if (last != HandshakeType.ClientHello)
+                                               break;
+                                       cert = new TlsClientCertificate(this.context, buffer);
+                                       return cert;
 
                                case HandshakeType.ClientKeyExchange:
-                                       return new TlsClientKeyExchange(this.context, buffer);
+                                       if (last == HandshakeType.ClientHello || last == HandshakeType.Certificate)
+                                               return new TlsClientKeyExchange(this.context, buffer);
+                                       break;
 
                                case HandshakeType.CertificateVerify:
-                                       return new TlsClientCertificateVerify(this.context, buffer);
+                                       if (last == HandshakeType.ClientKeyExchange && cert != null)
+                                               return new TlsClientCertificateVerify(this.context, buffer);
+                                       break;
 
                                case HandshakeType.Finished:
-                                       return new TlsClientFinished(this.context, buffer);
-
+                                       // Certificates are optional, but if provided, they should send a CertificateVerify
+                                       bool check = (cert == null) ? (last == HandshakeType.ClientKeyExchange) : (last == HandshakeType.CertificateVerify);
+                                       // ChangeCipherSpecDone is not an handshake message (it's a content type) but still needs to be happens before finished
+                                       if (check && context.ChangeCipherSpecDone) {
+                                               context.ChangeCipherSpecDone = false;
+                                               return new TlsClientFinished(this.context, buffer);
+                                       }
+                                       break;
+                                       
                                default:
-                                       throw new TlsException(
-                                               AlertDescription.UnexpectedMessage,
-                                               String.Format(CultureInfo.CurrentUICulture,
-                                                       "Unknown server handshake message received ({0})", 
-                                                       type.ToString()));
+                                       throw new TlsException(AlertDescription.UnexpectedMessage, String.Format(CultureInfo.CurrentUICulture,
+                                                                                                                "Unknown server handshake message received ({0})", 
+                                                                                                                type.ToString()));
+                                       break;
                        }
+                       throw new TlsException (AlertDescription.HandshakeFailiure, String.Format ("Protocol error, unexpected protocol transition from {0} to {1}", last, type));
                }
 
                private HandshakeMessage createServerHandshakeMessage(