Warnings
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
index 84f1012b7f23aba392ed69846f9f1eabe8126f95..1744d00196f2071586bf278f89dbff458be52def 100644 (file)
@@ -27,7 +27,6 @@ using System.IO;
 using System.Security.Cryptography;
 using System.Security.Cryptography.X509Certificates;
 
-using Mono.Security.Protocol.Tls.Alerts;
 using Mono.Security.Protocol.Tls.Handshake;
 
 namespace Mono.Security.Protocol.Tls
@@ -61,16 +60,18 @@ namespace Mono.Security.Protocol.Tls
 
                public RecordProtocol(Stream innerStream, Context context)
                {
-                       this.innerStream        = innerStream;
-                       this.context            = context;
+                       this.innerStream                        = innerStream;
+                       this.context                            = context;
+                       this.context.RecordProtocol = this;
                }
 
                #endregion
 
                #region Abstract Methods
 
-               public abstract void SendRecord(TlsHandshakeType type);
+               public abstract void SendRecord(HandshakeType type);
                protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
+               protected abstract void ProcessChangeCipherSpec();
                                
                #endregion
 
@@ -80,7 +81,9 @@ namespace Mono.Security.Protocol.Tls
                {
                        if (this.context.ConnectionEnd)
                        {
-                               throw this.context.CreateException("The session is finished and it's no longer valid.");
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
                        }
                        
                        // Try to read the Record Content Type
@@ -92,9 +95,9 @@ namespace Mono.Security.Protocol.Tls
                                return null;
                        }
 
-                       TlsContentType  contentType     = (TlsContentType)type;
-                       short                   protocol        = this.readShort();
-                       short                   length          = this.readShort();
+                       ContentType     contentType     = (ContentType)type;
+                       short           protocol        = this.readShort();
+                       short           length          = this.readShort();
                        
                        // Read Record data
                        int             received        = 0;
@@ -105,49 +108,62 @@ namespace Mono.Security.Protocol.Tls
                                        buffer, received, buffer.Length - received);
                        }
 
+                       DebugHelper.WriteLine(
+                               ">>>> Read record ({0}|{1})", 
+                               this.context.DecodeProtocolCode(protocol),
+                               contentType);
+                       DebugHelper.WriteLine("Record data", buffer);
+
                        TlsStream message = new TlsStream(buffer);
                
                        // Check that the message has a valid protocol version
-                       if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
+                       if (protocol != this.context.Protocol && 
+                               this.context.ProtocolNegotiated)
                        {
-                               throw this.context.CreateException("Invalid protocol version on message received from server");
+                               throw new TlsException(
+                                       AlertDescription.ProtocolVersion,
+                                       "Invalid protocol version on message received from server");
                        }
 
                        // Decrypt message contents if needed
-                       if (contentType == TlsContentType.Alert && length == 2)
+                       if (contentType == ContentType.Alert && length == 2)
                        {
                        }
                        else
                        {
                                if (this.context.IsActual &&
-                                       contentType != TlsContentType.ChangeCipherSpec)
+                                       contentType != ContentType.ChangeCipherSpec)
                                {
                                        message = this.decryptRecordFragment(
                                                contentType, 
                                                message.ToArray());
+
+                                       DebugHelper.WriteLine("Decrypted record data", message.ToArray());
                                }
                        }
 
+                       // Set last handshake message received to None
+                       this.context.LastHandshakeMsg = HandshakeType.None;
+                       
+                       // Process record
                        byte[] result = message.ToArray();
 
-                       // Process record
                        switch (contentType)
                        {
-                               case TlsContentType.Alert:
+                               case ContentType.Alert:
                                        this.processAlert(
-                                               (TlsAlertLevel)message.ReadByte(),
-                                               (TlsAlertDescription)message.ReadByte());
+                                               (AlertLevel)message.ReadByte(),
+                                               (AlertDescription)message.ReadByte());
                                        break;
 
-                               case TlsContentType.ChangeCipherSpec:
-                                       // Reset sequence numbers
-                                       this.context.ReadSequenceNumber = 0;
+                               case ContentType.ChangeCipherSpec:
+                                       this.ProcessChangeCipherSpec();
                                        break;
 
-                               case TlsContentType.ApplicationData:
+                               case ContentType.ApplicationData:
                                        break;
 
-                               case TlsContentType.Handshake:
+                               case ContentType.Handshake:
                                        while (!message.EOF)
                                        {
                                                this.ProcessHandshakeMessage(message);
@@ -158,7 +174,9 @@ namespace Mono.Security.Protocol.Tls
                                        break;
 
                                default:
-                                       throw this.context.CreateException("Unknown record received from server.");
+                                       throw new TlsException(
+                                               AlertDescription.UnexpectedMessage,
+                                               "Unknown record received from server.");
                        }
 
                        return result;
@@ -175,19 +193,19 @@ namespace Mono.Security.Protocol.Tls
                }
 
                private void processAlert(
-                       TlsAlertLevel           alertLevel, 
-                       TlsAlertDescription alertDesc)
+                       AlertLevel                      alertLevel, 
+                       AlertDescription        alertDesc)
                {
                        switch (alertLevel)
                        {
-                               case TlsAlertLevel.Fatal:
-                                       throw this.context.CreateException(alertLevel, alertDesc);                                      
+                               case AlertLevel.Fatal:
+                                       throw new TlsException(alertLevel, alertDesc);
 
-                               case TlsAlertLevel.Warning:
+                               case AlertLevel.Warning:
                                default:
                                switch (alertDesc)
                                {
-                                       case TlsAlertDescription.CloseNotify:
+                                       case AlertDescription.CloseNotify:
                                                this.context.ConnectionEnd = true;
                                                break;
                                }
@@ -199,28 +217,31 @@ namespace Mono.Security.Protocol.Tls
 
                #region Send Alert Methods
 
-               public void SendAlert(TlsAlertDescription description)
+               public void SendAlert(AlertDescription description)
                {
-                       this.SendAlert(new TlsAlert(this.Context, description));
+                       this.SendAlert(new Alert(description));
                }
 
                public void SendAlert(
-                       TlsAlertLevel           level, 
-                       TlsAlertDescription description)
+                       AlertLevel                      level, 
+                       AlertDescription        description)
                {
-                       this.SendAlert(new TlsAlert(this.Context, level, description));
+                       this.SendAlert(new Alert(level, description));
                }
 
-               public void SendAlert(TlsAlert alert)
-               {                       
-                       // Write record
-                       this.SendRecord(TlsContentType.Alert, alert.ToArray());
+               public void SendAlert(Alert alert)
+               {
+                       DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
 
-                       // Update session
-                       alert.Update();
+                       // Write record
+                       this.SendRecord(
+                               ContentType.Alert, 
+                               new byte[]{(byte)alert.Level, (byte)alert.Description});
 
-                       // Reset message contents
-                       alert.Reset();
+                       if (alert.IsCloseNotify)
+                       {
+                               this.context.ConnectionEnd = true;
+                       }
                }
 
                #endregion
@@ -229,8 +250,13 @@ namespace Mono.Security.Protocol.Tls
 
                public void SendChangeCipherSpec()
                {
+                       DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
+                       // Send Change Cipher Spec message as a plain message
+                       this.context.IsActual = false;
+
                        // Send Change Cipher Spec message
-                       this.SendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});
+                       this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
 
                        // Reset sequence numbers
                        this.context.WriteSequenceNumber = 0;
@@ -239,14 +265,16 @@ namespace Mono.Security.Protocol.Tls
                        this.context.IsActual = true;
 
                        // Send Finished message
-                       this.SendRecord(TlsHandshakeType.Finished);                     
+                       this.SendRecord(HandshakeType.Finished);                        
                }
 
-               public void SendRecord(TlsContentType contentType, byte[] recordData)
+               public void SendRecord(ContentType contentType, byte[] recordData)
                {
                        if (this.context.ConnectionEnd)
                        {
-                               throw this.context.CreateException("The session is finished and it's no longer valid.");
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
                        }
 
                        byte[] record = this.EncodeRecord(contentType, recordData);
@@ -254,7 +282,7 @@ namespace Mono.Security.Protocol.Tls
                        this.innerStream.Write(record, 0, record.Length);
                }
 
-               public byte[] EncodeRecord(TlsContentType contentType, byte[] recordData)
+               public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
                {
                        return this.EncodeRecord(
                                contentType,
@@ -264,14 +292,16 @@ namespace Mono.Security.Protocol.Tls
                }
 
                public byte[] EncodeRecord(
-                       TlsContentType  contentType, 
-                       byte[]                  recordData,
-                       int                             offset,
-                       int                             count)
+                       ContentType     contentType, 
+                       byte[]          recordData,
+                       int                     offset,
+                       int                     count)
                {
                        if (this.context.ConnectionEnd)
                        {
-                               throw this.context.CreateException("The session is finished and it's no longer valid.");
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
                        }
 
                        TlsStream record = new TlsStream();
@@ -308,6 +338,8 @@ namespace Mono.Security.Protocol.Tls
                                record.Write((short)fragment.Length);
                                record.Write(fragment);
 
+                               DebugHelper.WriteLine("Record data", fragment);
+
                                // Update buffer position
                                position += fragmentLength;
                        }
@@ -320,20 +352,32 @@ namespace Mono.Security.Protocol.Tls
                #region Cryptography Methods
 
                private byte[] encryptRecordFragment(
-                       TlsContentType  contentType, 
-                       byte[]                  fragment)
+                       ContentType     contentType, 
+                       byte[]          fragment)
                {
+                       byte[] mac      = null;
+
                        // Calculate message MAC
-                       byte[] mac      = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+                       if (this.Context is ClientContext)
+                       {
+                               mac     = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+                       }       
+                       else
+                       {
+                               mac     = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
+                       }
+
+                       DebugHelper.WriteLine(">>>> Record MAC", mac);
 
                        // Encrypt the message
                        byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
 
-                       // Set new IV
+                       // Set new Client Cipher IV
                        if (this.context.Cipher.CipherMode == CipherMode.CBC)
                        {
                                byte[] iv = new byte[this.context.Cipher.IvSize];
-                               System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+                               Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+
                                this.context.Cipher.UpdateClientCipherIV(iv);
                        }
 
@@ -344,39 +388,63 @@ namespace Mono.Security.Protocol.Tls
                }
 
                private TlsStream decryptRecordFragment(
-                       TlsContentType  contentType, 
-                       byte[]                  fragment)
+                       ContentType     contentType, 
+                       byte[]          fragment)
                {
-                       byte[]  dcrFragment     = null;
-                       byte[]  dcrMAC          = null;
-
-                       // Decrypt message
-                       this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+                       byte[]  dcrFragment             = null;
+                       byte[]  dcrMAC                  = null;
+                       bool    badRecordMac    = false;
 
-                       // Set new IV
-                       if (this.context.Cipher.CipherMode == CipherMode.CBC)
+                       try
                        {
-                               byte[] iv = new byte[this.context.Cipher.IvSize];
-                               System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
-                               this.context.Cipher.UpdateServerCipherIV(iv);
+                               this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+                       }
+                       catch
+                       {
+                               if (this.context is ServerContext)
+                               {
+                                       this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
+                               }
+
+                               throw;
                        }
                        
-                       // Check MAC code
-                       byte[] mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+                       // Generate record MAC
+                       byte[] mac = null;
 
-                       // Check that the mac is correct
+                       if (this.Context is ClientContext)
+                       {
+                               mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+                       }
+                       else
+                       {
+                               mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
+                       }
+
+                       DebugHelper.WriteLine(">>>> Record MAC", mac);
+
+                       // Check record MAC
                        if (mac.Length != dcrMAC.Length)
                        {
-                               throw new TlsException("Invalid MAC received from server.");
+                               badRecordMac = true;
                        }
-                       for (int i = 0; i < mac.Length; i++)
+                       else
                        {
-                               if (mac[i] != dcrMAC[i])
+                               for (int i = 0; i < mac.Length; i++)
                                {
-                                       throw new TlsException("Invalid MAC received from server.");
+                                       if (mac[i] != dcrMAC[i])
+                                       {
+                                               badRecordMac = true;
+                                               break;
+                                       }
                                }
                        }
 
+                       if (badRecordMac)
+                       {
+                               throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
+                       }
+
                        // Update sequence number
                        this.context.ReadSequenceNumber++;