TLS protocol: add handshake state validation
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / ServerRecordProtocol.cs
index 6e316dc3659dd1219a9a6dd436d789626c92a512..31c2547902beeb8bc2fba4d4d4b27b27d7e783e8 100644 (file)
@@ -33,6 +33,8 @@ namespace Mono.Security.Protocol.Tls
 {
        internal class ServerRecordProtocol : RecordProtocol
        {
+               TlsClientCertificate cert;
+               
                #region Constructors
 
                public ServerRecordProtocol(
@@ -93,30 +95,45 @@ namespace Mono.Security.Protocol.Tls
                private HandshakeMessage createClientHandshakeMessage(
                        HandshakeType type, byte[] buffer)
                {
+                       var last = context.LastHandshakeMsg;
                        switch (type)
                        {
                                case HandshakeType.ClientHello:
                                        return new TlsClientHello(this.context, buffer);
 
                                case HandshakeType.Certificate:
-                                       return new TlsClientCertificate(this.context, buffer);
+                                       if (last != HandshakeType.ClientHello)
+                                               break;
+                                       cert = new TlsClientCertificate(this.context, buffer);
+                                       return cert;
 
                                case HandshakeType.ClientKeyExchange:
-                                       return new TlsClientKeyExchange(this.context, buffer);
+                                       if (last == HandshakeType.ClientHello || last == HandshakeType.Certificate)
+                                               return new TlsClientKeyExchange(this.context, buffer);
+                                       break;
 
                                case HandshakeType.CertificateVerify:
-                                       return new TlsClientCertificateVerify(this.context, buffer);
+                                       if (last == HandshakeType.ClientKeyExchange && cert != null)
+                                               return new TlsClientCertificateVerify(this.context, buffer);
+                                       break;
 
                                case HandshakeType.Finished:
-                                       return new TlsClientFinished(this.context, buffer);
-
+                                       // Certificates are optional, but if provided, they should send a CertificateVerify
+                                       bool check = (cert == null) ? (last == HandshakeType.ClientKeyExchange) : (last == HandshakeType.CertificateVerify);
+                                       // ChangeCipherSpecDone is not an handshake message (it's a content type) but still needs to be happens before finished
+                                       if (check && context.ChangeCipherSpecDone) {
+                                               context.ChangeCipherSpecDone = false;
+                                               return new TlsClientFinished(this.context, buffer);
+                                       }
+                                       break;
+                                       
                                default:
-                                       throw new TlsException(
-                                               AlertDescription.UnexpectedMessage,
-                                               String.Format(CultureInfo.CurrentUICulture,
-                                                       "Unknown server handshake message received ({0})", 
-                                                       type.ToString()));
+                                       throw new TlsException(AlertDescription.UnexpectedMessage, String.Format(CultureInfo.CurrentUICulture,
+                                                                                                                "Unknown server handshake message received ({0})", 
+                                                                                                                type.ToString()));
+                                       break;
                        }
+                       throw new TlsException (AlertDescription.HandshakeFailiure, String.Format ("Protocol error, unexpected protocol transition from {0} to {1}", last, type));
                }
 
                private HandshakeMessage createServerHandshakeMessage(