Warnings
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
index 547d9ef77337316823b5c6ace523db2213fc0a01..1744d00196f2071586bf278f89dbff458be52def 100644 (file)
@@ -60,8 +60,9 @@ namespace Mono.Security.Protocol.Tls
 
                public RecordProtocol(Stream innerStream, Context context)
                {
-                       this.innerStream        = innerStream;
-                       this.context            = context;
+                       this.innerStream                        = innerStream;
+                       this.context                            = context;
+                       this.context.RecordProtocol = this;
                }
 
                #endregion
@@ -80,7 +81,9 @@ namespace Mono.Security.Protocol.Tls
                {
                        if (this.context.ConnectionEnd)
                        {
-                               throw this.context.CreateException("The session is finished and it's no longer valid.");
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
                        }
                        
                        // Try to read the Record Content Type
@@ -105,13 +108,21 @@ namespace Mono.Security.Protocol.Tls
                                        buffer, received, buffer.Length - received);
                        }
 
+                       DebugHelper.WriteLine(
+                               ">>>> Read record ({0}|{1})", 
+                               this.context.DecodeProtocolCode(protocol),
+                               contentType);
+                       DebugHelper.WriteLine("Record data", buffer);
+
                        TlsStream message = new TlsStream(buffer);
                
                        // Check that the message has a valid protocol version
                        if (protocol != this.context.Protocol && 
                                this.context.ProtocolNegotiated)
                        {
-                               throw this.context.CreateException("Invalid protocol version on message received from server");
+                               throw new TlsException(
+                                       AlertDescription.ProtocolVersion,
+                                       "Invalid protocol version on message received from server");
                        }
 
                        // Decrypt message contents if needed
@@ -126,6 +137,8 @@ namespace Mono.Security.Protocol.Tls
                                        message = this.decryptRecordFragment(
                                                contentType, 
                                                message.ToArray());
+
+                                       DebugHelper.WriteLine("Decrypted record data", message.ToArray());
                                }
                        }
 
@@ -161,7 +174,9 @@ namespace Mono.Security.Protocol.Tls
                                        break;
 
                                default:
-                                       throw this.context.CreateException("Unknown record received from server.");
+                                       throw new TlsException(
+                                               AlertDescription.UnexpectedMessage,
+                                               "Unknown record received from server.");
                        }
 
                        return result;
@@ -184,7 +199,7 @@ namespace Mono.Security.Protocol.Tls
                        switch (alertLevel)
                        {
                                case AlertLevel.Fatal:
-                                       throw this.context.CreateException(alertLevel, alertDesc);                                      
+                                       throw new TlsException(alertLevel, alertDesc);
 
                                case AlertLevel.Warning:
                                default:
@@ -204,26 +219,29 @@ namespace Mono.Security.Protocol.Tls
 
                public void SendAlert(AlertDescription description)
                {
-                       this.SendAlert(new Alert(this.Context, description));
+                       this.SendAlert(new Alert(description));
                }
 
                public void SendAlert(
                        AlertLevel                      level, 
                        AlertDescription        description)
                {
-                       this.SendAlert(new Alert(this.Context, level, description));
+                       this.SendAlert(new Alert(level, description));
                }
 
                public void SendAlert(Alert alert)
-               {                       
-                       // Write record
-                       this.SendRecord(ContentType.Alert, alert.ToArray());
+               {
+                       DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
 
-                       // Update session
-                       alert.Update();
+                       // Write record
+                       this.SendRecord(
+                               ContentType.Alert, 
+                               new byte[]{(byte)alert.Level, (byte)alert.Description});
 
-                       // Reset message contents
-                       alert.Reset();
+                       if (alert.IsCloseNotify)
+                       {
+                               this.context.ConnectionEnd = true;
+                       }
                }
 
                #endregion
@@ -232,6 +250,8 @@ namespace Mono.Security.Protocol.Tls
 
                public void SendChangeCipherSpec()
                {
+                       DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
                        // Send Change Cipher Spec message as a plain message
                        this.context.IsActual = false;
 
@@ -252,7 +272,9 @@ namespace Mono.Security.Protocol.Tls
                {
                        if (this.context.ConnectionEnd)
                        {
-                               throw this.context.CreateException("The session is finished and it's no longer valid.");
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
                        }
 
                        byte[] record = this.EncodeRecord(contentType, recordData);
@@ -277,7 +299,9 @@ namespace Mono.Security.Protocol.Tls
                {
                        if (this.context.ConnectionEnd)
                        {
-                               throw this.context.CreateException("The session is finished and it's no longer valid.");
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
                        }
 
                        TlsStream record = new TlsStream();
@@ -314,6 +338,8 @@ namespace Mono.Security.Protocol.Tls
                                record.Write((short)fragment.Length);
                                record.Write(fragment);
 
+                               DebugHelper.WriteLine("Record data", fragment);
+
                                // Update buffer position
                                position += fragmentLength;
                        }
@@ -341,6 +367,8 @@ namespace Mono.Security.Protocol.Tls
                                mac     = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
                        }
 
+                       DebugHelper.WriteLine(">>>> Record MAC", mac);
+
                        // Encrypt the message
                        byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
 
@@ -348,7 +376,7 @@ namespace Mono.Security.Protocol.Tls
                        if (this.context.Cipher.CipherMode == CipherMode.CBC)
                        {
                                byte[] iv = new byte[this.context.Cipher.IvSize];
-                               System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+                               Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
 
                                this.context.Cipher.UpdateClientCipherIV(iv);
                        }
@@ -363,14 +391,27 @@ namespace Mono.Security.Protocol.Tls
                        ContentType     contentType, 
                        byte[]          fragment)
                {
-                       byte[]  dcrFragment     = null;
-                       byte[]  dcrMAC          = null;
+                       byte[]  dcrFragment             = null;
+                       byte[]  dcrMAC                  = null;
+                       bool    badRecordMac    = false;
 
-                       // Decrypt message
-                       this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+                       try
+                       {
+                               this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+                       }
+                       catch
+                       {
+                               if (this.context is ServerContext)
+                               {
+                                       this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
+                               }
+
+                               throw;
+                       }
                        
-                       // Check MAC code
+                       // Generate record MAC
                        byte[] mac = null;
+
                        if (this.Context is ClientContext)
                        {
                                mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
@@ -380,20 +421,30 @@ namespace Mono.Security.Protocol.Tls
                                mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
                        }
 
-                       // Check that the mac is correct
+                       DebugHelper.WriteLine(">>>> Record MAC", mac);
+
+                       // Check record MAC
                        if (mac.Length != dcrMAC.Length)
                        {
-                               throw new TlsException("Invalid MAC received from server.");
+                               badRecordMac = true;
                        }
-
-                       for (int i = 0; i < mac.Length; i++)
+                       else
                        {
-                               if (mac[i] != dcrMAC[i])
+                               for (int i = 0; i < mac.Length; i++)
                                {
-                                       throw new TlsException("Invalid MAC received from server.");
+                                       if (mac[i] != dcrMAC[i])
+                                       {
+                                               badRecordMac = true;
+                                               break;
+                                       }
                                }
                        }
 
+                       if (badRecordMac)
+                       {
+                               throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
+                       }
+
                        // Update sequence number
                        this.context.ReadSequenceNumber++;