public RecordProtocol(Stream innerStream, Context context)
{
- this.innerStream = innerStream;
- this.context = context;
+ this.innerStream = innerStream;
+ this.context = context;
+ this.context.RecordProtocol = this;
}
#endregion
public abstract void SendRecord(HandshakeType type);
protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
+ protected abstract void ProcessChangeCipherSpec();
#endregion
{
if (this.context.ConnectionEnd)
{
- throw this.context.CreateException("The session is finished and it's no longer valid.");
+ throw new TlsException(
+ AlertDescription.InternalError,
+ "The session is finished and it's no longer valid.");
}
// Try to read the Record Content Type
buffer, received, buffer.Length - received);
}
+ DebugHelper.WriteLine(
+ ">>>> Read record ({0}|{1})",
+ this.context.DecodeProtocolCode(protocol),
+ contentType);
+ DebugHelper.WriteLine("Record data", buffer);
+
TlsStream message = new TlsStream(buffer);
// Check that the message has a valid protocol version
if (protocol != this.context.Protocol &&
this.context.ProtocolNegotiated)
{
- throw this.context.CreateException("Invalid protocol version on message received from server");
+ throw new TlsException(
+ AlertDescription.ProtocolVersion,
+ "Invalid protocol version on message received from server");
}
// Decrypt message contents if needed
message = this.decryptRecordFragment(
contentType,
message.ToArray());
+
+ DebugHelper.WriteLine("Decrypted record data", message.ToArray());
}
}
break;
case ContentType.ChangeCipherSpec:
- // Reset sequence numbers
- this.context.ReadSequenceNumber = 0;
+ this.ProcessChangeCipherSpec();
break;
case ContentType.ApplicationData:
break;
default:
- throw this.context.CreateException("Unknown record received from server.");
+ throw new TlsException(
+ AlertDescription.UnexpectedMessage,
+ "Unknown record received from server.");
}
return result;
switch (alertLevel)
{
case AlertLevel.Fatal:
- throw this.context.CreateException(alertLevel, alertDesc);
+ throw new TlsException(alertLevel, alertDesc);
case AlertLevel.Warning:
default:
public void SendAlert(AlertDescription description)
{
- this.SendAlert(new Alert(this.Context, description));
+ this.SendAlert(new Alert(description));
}
public void SendAlert(
AlertLevel level,
AlertDescription description)
{
- this.SendAlert(new Alert(this.Context, level, description));
+ this.SendAlert(new Alert(level, description));
}
public void SendAlert(Alert alert)
- {
- // Write record
- this.SendRecord(ContentType.Alert, alert.ToArray());
+ {
+ DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
- // Update session
- alert.Update();
+ // Write record
+ this.SendRecord(
+ ContentType.Alert,
+ new byte[]{(byte)alert.Level, (byte)alert.Description});
- // Reset message contents
- alert.Reset();
+ if (alert.IsCloseNotify)
+ {
+ this.context.ConnectionEnd = true;
+ }
}
#endregion
public void SendChangeCipherSpec()
{
+ DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
+ // Send Change Cipher Spec message as a plain message
+ this.context.IsActual = false;
+
// Send Change Cipher Spec message
this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
{
if (this.context.ConnectionEnd)
{
- throw this.context.CreateException("The session is finished and it's no longer valid.");
+ throw new TlsException(
+ AlertDescription.InternalError,
+ "The session is finished and it's no longer valid.");
}
byte[] record = this.EncodeRecord(contentType, recordData);
{
if (this.context.ConnectionEnd)
{
- throw this.context.CreateException("The session is finished and it's no longer valid.");
+ throw new TlsException(
+ AlertDescription.InternalError,
+ "The session is finished and it's no longer valid.");
}
TlsStream record = new TlsStream();
record.Write((short)fragment.Length);
record.Write(fragment);
+ DebugHelper.WriteLine("Record data", fragment);
+
// Update buffer position
position += fragmentLength;
}
ContentType contentType,
byte[] fragment)
{
+ byte[] mac = null;
+
// Calculate message MAC
- byte[] mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+ if (this.Context is ClientContext)
+ {
+ mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+ }
+ else
+ {
+ mac = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
+ }
+
+ DebugHelper.WriteLine(">>>> Record MAC", mac);
// Encrypt the message
byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
- // Set new IV
+ // Set new Client Cipher IV
if (this.context.Cipher.CipherMode == CipherMode.CBC)
{
byte[] iv = new byte[this.context.Cipher.IvSize];
- System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+ Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
+
this.context.Cipher.UpdateClientCipherIV(iv);
}
ContentType contentType,
byte[] fragment)
{
- byte[] dcrFragment = null;
- byte[] dcrMAC = null;
-
- // Decrypt message
- this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+ byte[] dcrFragment = null;
+ byte[] dcrMAC = null;
+ bool badRecordMac = false;
- // Set new IV
- if (this.context.Cipher.CipherMode == CipherMode.CBC)
+ try
{
- byte[] iv = new byte[this.context.Cipher.IvSize];
- System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
- this.context.Cipher.UpdateServerCipherIV(iv);
+ this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+ }
+ catch
+ {
+ if (this.context is ServerContext)
+ {
+ this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
+ }
+
+ throw;
}
- // Check MAC code
- byte[] mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+ // Generate record MAC
+ byte[] mac = null;
- // Check that the mac is correct
+ if (this.Context is ClientContext)
+ {
+ mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+ }
+ else
+ {
+ mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
+ }
+
+ DebugHelper.WriteLine(">>>> Record MAC", mac);
+
+ // Check record MAC
if (mac.Length != dcrMAC.Length)
{
- throw new TlsException("Invalid MAC received from server.");
+ badRecordMac = true;
}
- for (int i = 0; i < mac.Length; i++)
+ else
{
- if (mac[i] != dcrMAC[i])
+ for (int i = 0; i < mac.Length; i++)
{
- throw new TlsException("Invalid MAC received from server.");
+ if (mac[i] != dcrMAC[i])
+ {
+ badRecordMac = true;
+ break;
+ }
}
}
+ if (badRecordMac)
+ {
+ throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
+ }
+
// Update sequence number
this.context.ReadSequenceNumber++;