using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
-using Mono.Security.Protocol.Tls.Alerts;
using Mono.Security.Protocol.Tls.Handshake;
namespace Mono.Security.Protocol.Tls
public RecordProtocol(Stream innerStream, Context context)
{
- this.innerStream = innerStream;
- this.context = context;
+ this.innerStream = innerStream;
+ this.context = context;
+ this.context.RecordProtocol = this;
}
#endregion
#region Abstract Methods
- public abstract void SendRecord(TlsHandshakeType type);
+ public abstract void SendRecord(HandshakeType type);
protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
+ protected abstract void ProcessChangeCipherSpec();
#endregion
{
if (this.context.ConnectionEnd)
{
- throw this.context.CreateException("The session is finished and it's no longer valid.");
+ throw new TlsException(
+ AlertDescription.InternalError,
+ "The session is finished and it's no longer valid.");
}
// Try to read the Record Content Type
return null;
}
- TlsContentType contentType = (TlsContentType)type;
- short protocol = this.readShort();
- short length = this.readShort();
+ ContentType contentType = (ContentType)type;
+ short protocol = this.readShort();
+ short length = this.readShort();
// Read Record data
int received = 0;
buffer, received, buffer.Length - received);
}
+ DebugHelper.WriteLine(
+ ">>>> Read record ({0}|{1})",
+ this.context.DecodeProtocolCode(protocol),
+ contentType);
+ DebugHelper.WriteLine("Record data", buffer);
+
TlsStream message = new TlsStream(buffer);
// Check that the message has a valid protocol version
- if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
+ if (protocol != this.context.Protocol &&
+ this.context.ProtocolNegotiated)
{
- throw this.context.CreateException("Invalid protocol version on message received from server");
+ throw new TlsException(
+ AlertDescription.ProtocolVersion,
+ "Invalid protocol version on message received from server");
}
// Decrypt message contents if needed
- if (contentType == TlsContentType.Alert && length == 2)
+ if (contentType == ContentType.Alert && length == 2)
{
}
else
{
if (this.context.IsActual &&
- contentType != TlsContentType.ChangeCipherSpec)
+ contentType != ContentType.ChangeCipherSpec)
{
message = this.decryptRecordFragment(
contentType,
message.ToArray());
+
+ DebugHelper.WriteLine("Decrypted record data", message.ToArray());
}
}
+ // Set last handshake message received to None
+ this.context.LastHandshakeMsg = HandshakeType.None;
+
+ // Process record
byte[] result = message.ToArray();
- // Process record
switch (contentType)
{
- case TlsContentType.Alert:
+ case ContentType.Alert:
this.processAlert(
- (TlsAlertLevel)message.ReadByte(),
- (TlsAlertDescription)message.ReadByte());
+ (AlertLevel)message.ReadByte(),
+ (AlertDescription)message.ReadByte());
break;
- case TlsContentType.ChangeCipherSpec:
- // Reset sequence numbers
- this.context.ReadSequenceNumber = 0;
+ case ContentType.ChangeCipherSpec:
+ this.ProcessChangeCipherSpec();
break;
- case TlsContentType.ApplicationData:
+ case ContentType.ApplicationData:
break;
- case TlsContentType.Handshake:
+ case ContentType.Handshake:
while (!message.EOF)
{
this.ProcessHandshakeMessage(message);
break;
default:
- throw this.context.CreateException("Unknown record received from server.");
+ throw new TlsException(
+ AlertDescription.UnexpectedMessage,
+ "Unknown record received from server.");
}
return result;
}
private void processAlert(
- TlsAlertLevel alertLevel,
- TlsAlertDescription alertDesc)
+ AlertLevel alertLevel,
+ AlertDescription alertDesc)
{
switch (alertLevel)
{
- case TlsAlertLevel.Fatal:
- throw this.context.CreateException(alertLevel, alertDesc);
+ case AlertLevel.Fatal:
+ throw new TlsException(alertLevel, alertDesc);
- case TlsAlertLevel.Warning:
+ case AlertLevel.Warning:
default:
switch (alertDesc)
{
- case TlsAlertDescription.CloseNotify:
+ case AlertDescription.CloseNotify:
this.context.ConnectionEnd = true;
break;
}
#region Send Alert Methods
- public void SendAlert(TlsAlertDescription description)
+ public void SendAlert(AlertDescription description)
{
- this.SendAlert(new TlsAlert(this.Context, description));
+ this.SendAlert(new Alert(description));
}
public void SendAlert(
- TlsAlertLevel level,
- TlsAlertDescription description)
+ AlertLevel level,
+ AlertDescription description)
{
- this.SendAlert(new TlsAlert(this.Context, level, description));
+ this.SendAlert(new Alert(level, description));
}
- public void SendAlert(TlsAlert alert)
- {
- // Write record
- this.SendRecord(TlsContentType.Alert, alert.ToArray());
+ public void SendAlert(Alert alert)
+ {
+ DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
- // Update session
- alert.Update();
+ // Write record
+ this.SendRecord(
+ ContentType.Alert,
+ new byte[]{(byte)alert.Level, (byte)alert.Description});
- // Reset message contents
- alert.Reset();
+ if (alert.IsCloseNotify)
+ {
+ this.context.ConnectionEnd = true;
+ }
}
#endregion
public void SendChangeCipherSpec()
{
+ DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
+ // Send Change Cipher Spec message as a plain message
+ this.context.IsActual = false;
+
// Send Change Cipher Spec message
- this.SendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});
+ this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
// Reset sequence numbers
this.context.WriteSequenceNumber = 0;
this.context.IsActual = true;
// Send Finished message
- this.SendRecord(TlsHandshakeType.Finished);
+ this.SendRecord(HandshakeType.Finished);
}
- public void SendRecord(TlsContentType contentType, byte[] recordData)
+ public void SendRecord(ContentType contentType, byte[] recordData)
{
if (this.context.ConnectionEnd)
{
- throw this.context.CreateException("The session is finished and it's no longer valid.");
+ throw new TlsException(
+ AlertDescription.InternalError,
+ "The session is finished and it's no longer valid.");
}
byte[] record = this.EncodeRecord(contentType, recordData);
this.innerStream.Write(record, 0, record.Length);
}
- public byte[] EncodeRecord(TlsContentType contentType, byte[] recordData)
+ public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
{
return this.EncodeRecord(
contentType,
}
public byte[] EncodeRecord(
- TlsContentType contentType,
- byte[] recordData,
- int offset,
- int count)
+ ContentType contentType,
+ byte[] recordData,
+ int offset,
+ int count)
{
if (this.context.ConnectionEnd)
{
- throw this.context.CreateException("The session is finished and it's no longer valid.");
+ throw new TlsException(
+ AlertDescription.InternalError,
+ "The session is finished and it's no longer valid.");
}
TlsStream record = new TlsStream();
record.Write((short)fragment.Length);
record.Write(fragment);
+ DebugHelper.WriteLine("Record data", fragment);
+
// Update buffer position
position += fragmentLength;
}
#region Cryptography Methods
private byte[] encryptRecordFragment(
- TlsContentType contentType,
- byte[] fragment)
+ ContentType contentType,
+ byte[] fragment)
{
+ byte[] mac = null;
+
// Calculate message MAC
- byte[] mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+ if (this.Context is ClientContext)
+ {
+ mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+ }
+ else
+ {
+ mac = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
+ }
+
+ DebugHelper.WriteLine(">>>> Record MAC", mac);
// Encrypt the message
byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
- // Set new IV
+ // Set new Client Cipher IV
if (this.context.Cipher.CipherMode == CipherMode.CBC)
{
byte[] iv = new byte[this.context.Cipher.IvSize];
- System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+ Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+
this.context.Cipher.UpdateClientCipherIV(iv);
}
}
private TlsStream decryptRecordFragment(
- TlsContentType contentType,
- byte[] fragment)
+ ContentType contentType,
+ byte[] fragment)
{
- byte[] dcrFragment = null;
- byte[] dcrMAC = null;
-
- // Decrypt message
- this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+ byte[] dcrFragment = null;
+ byte[] dcrMAC = null;
+ bool badRecordMac = false;
- // Set new IV
- if (this.context.Cipher.CipherMode == CipherMode.CBC)
+ try
{
- byte[] iv = new byte[this.context.Cipher.IvSize];
- System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
- this.context.Cipher.UpdateServerCipherIV(iv);
+ this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+ }
+ catch
+ {
+ if (this.context is ServerContext)
+ {
+ this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
+ }
+
+ throw;
}
- // Check MAC code
- byte[] mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+ // Generate record MAC
+ byte[] mac = null;
- // Check that the mac is correct
+ if (this.Context is ClientContext)
+ {
+ mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+ }
+ else
+ {
+ mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
+ }
+
+ DebugHelper.WriteLine(">>>> Record MAC", mac);
+
+ // Check record MAC
if (mac.Length != dcrMAC.Length)
{
- throw new TlsException("Invalid MAC received from server.");
+ badRecordMac = true;
}
- for (int i = 0; i < mac.Length; i++)
+ else
{
- if (mac[i] != dcrMAC[i])
+ for (int i = 0; i < mac.Length; i++)
{
- throw new TlsException("Invalid MAC received from server.");
+ if (mac[i] != dcrMAC[i])
+ {
+ badRecordMac = true;
+ break;
+ }
}
}
+ if (badRecordMac)
+ {
+ throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
+ }
+
// Update sequence number
this.context.ReadSequenceNumber++;