Convert blocking operations in HttpWebRequest and SslClientStream to non-blocking...
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
index 98aa811ee06b40eab99ab1e9116fa6b80a7edc53..166f12f0d23c7d0790fc722c2d140494f75b1b11 100644 (file)
@@ -1,6 +1,6 @@
 // Transport Security Layer (TLS)
 // Copyright (c) 2003-2004 Carlos Guzman Alvarez
-
+// Copyright (C) 2006-2007 Novell, Inc (http://www.novell.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -25,8 +25,6 @@
 using System;
 using System.Collections;
 using System.IO;
-using System.Security.Cryptography;
-using System.Security.Cryptography.X509Certificates;
 using System.Threading;
 
 using Mono.Security.Protocol.Tls.Handshake;
@@ -37,6 +35,8 @@ namespace Mono.Security.Protocol.Tls
        {
                #region Fields
 
+               private static ManualResetEvent record_processing = new ManualResetEvent (true);
+
                protected Stream        innerStream;
                protected Context       context;
 
@@ -75,7 +75,20 @@ namespace Mono.Security.Protocol.Tls
                }
 
                protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
-               protected abstract void ProcessChangeCipherSpec();
+
+               protected virtual void ProcessChangeCipherSpec ()
+               {
+                       Context ctx = this.Context;
+
+                       // Reset sequence numbers
+                       ctx.ReadSequenceNumber = 0;
+
+                       if (ctx is ClientContext) {
+                               ctx.EndSwitchingSecurityParameters (true);
+                       } else {
+                               ctx.StartSwitchingSecurityParameters (false);
+                       }
+               }
 
                public virtual HandshakeMessage GetMessage(HandshakeType type)
                {
@@ -298,13 +311,14 @@ namespace Mono.Security.Protocol.Tls
 
                public IAsyncResult BeginReceiveRecord(Stream record, AsyncCallback callback, object state)
                {
-                       if (this.context.ConnectionEnd)
+                       if (this.context.ReceivedConnectionEnd)
                        {
                                throw new TlsException(
                                        AlertDescription.InternalError,
                                        "The session is finished and it's no longer valid.");
                        }
 
+                       record_processing.Reset ();
                        byte[] recordTypeBuffer = new byte[1];
 
                        ReceiveRecordAsyncResult internalResult = new ReceiveRecordAsyncResult(callback, state, recordTypeBuffer, record);
@@ -350,13 +364,10 @@ namespace Mono.Security.Protocol.Tls
                                if (contentType == ContentType.Alert && buffer.Length == 2)
                                {
                                }
-                               else
+                               else if ((this.Context.Read != null) && (this.Context.Read.Cipher != null))
                                {
-                                       if (this.context.IsActual && contentType != ContentType.ChangeCipherSpec)
-                                       {
-                                               buffer = this.decryptRecordFragment(contentType, buffer);
-                                               DebugHelper.WriteLine("Decrypted record data", buffer);
-                                       }
+                                       buffer = this.decryptRecordFragment (contentType, buffer);
+                                       DebugHelper.WriteLine ("Decrypted record data", buffer);
                                }
 
                                // Process record
@@ -395,7 +406,6 @@ namespace Mono.Security.Protocol.Tls
                                                throw new TlsException(
                                                        AlertDescription.UnexpectedMessage,
                                                        "Unknown record received from server.");
-                                               break;
                                }
 
                                internalResult.SetComplete(buffer);
@@ -419,16 +429,96 @@ namespace Mono.Security.Protocol.Tls
 
                        if (internalResult.CompletedWithError)
                                throw internalResult.AsyncException;
-                       else
-                               return internalResult.ResultingBuffer;
+
+                       byte[] result = internalResult.ResultingBuffer;
+                       record_processing.Set ();
+                       return result;
                }
 
                public byte[] ReceiveRecord(Stream record)
                {
+                       if (this.context.ReceivedConnectionEnd)
+                       {
+                               throw new TlsException(
+                                       AlertDescription.InternalError,
+                                       "The session is finished and it's no longer valid.");
+                       }
 
-                       IAsyncResult ar = this.BeginReceiveRecord(record, null, null);
-                       return this.EndReceiveRecord(ar);
+                       record_processing.Reset ();
+                       byte[] recordTypeBuffer = new byte[1];
+
+                       int bytesRead = record.Read(recordTypeBuffer, 0, recordTypeBuffer.Length);
 
+                       //We're at the end of the stream. Time to bail.
+                       if (bytesRead == 0)
+                       {
+                               return null;
+                       }
+
+                       // Try to read the Record Content Type
+                       int type = recordTypeBuffer[0];
+
+                       // Set last handshake message received to None
+                       this.context.LastHandshakeMsg = HandshakeType.ClientHello;
+
+                       ContentType     contentType     = (ContentType)type;
+                       byte[] buffer = this.ReadRecordBuffer(type, record);
+                       if (buffer == null)
+                       {
+                               // record incomplete (at the moment)
+                               return null;
+                       }
+
+                       // Decrypt message contents if needed
+                       if (contentType == ContentType.Alert && buffer.Length == 2)
+                       {
+                       }
+                       else if ((this.Context.Read != null) && (this.Context.Read.Cipher != null))
+                       {
+                               buffer = this.decryptRecordFragment (contentType, buffer);
+                               DebugHelper.WriteLine ("Decrypted record data", buffer);
+                       }
+
+                       // Process record
+                       switch (contentType)
+                       {
+                       case ContentType.Alert:
+                               this.ProcessAlert((AlertLevel)buffer [0], (AlertDescription)buffer [1]);
+                               if (record.CanSeek) 
+                               {
+                                       // don't reprocess that memory block
+                                       record.SetLength (0); 
+                               }
+                               buffer = null;
+                               break;
+
+                       case ContentType.ChangeCipherSpec:
+                               this.ProcessChangeCipherSpec();
+                               break;
+
+                       case ContentType.ApplicationData:
+                               break;
+
+                       case ContentType.Handshake:
+                               TlsStream message = new TlsStream (buffer);
+                               while (!message.EOF)
+                               {
+                                       this.ProcessHandshakeMessage(message);
+                               }
+                               break;
+
+                       case (ContentType)0x80:
+                               this.context.HandshakeMessages.Write (buffer);
+                               break;
+
+                       default:
+                               throw new TlsException(
+                                       AlertDescription.UnexpectedMessage,
+                                       "Unknown record received from server.");
+                       }
+
+                       record_processing.Set ();
+                       return buffer;
                }
 
                private byte[] ReadRecordBuffer (int contentType, Stream record)
@@ -514,8 +604,12 @@ namespace Mono.Security.Protocol.Tls
 
                private byte[] ReadStandardRecordBuffer (Stream record)
                {
-                       short protocol  = this.ReadShort(record);
-                       short length    = this.ReadShort(record);
+                       byte[] header = new byte[4];
+                       if (record.Read (header, 0, 4) != 4)
+                               throw new TlsException ("buffer underrun");
+                       
+                       short protocol = (short)((header [0] << 8) | header [1]);
+                       short length = (short)((header [2] << 8) | header [3]);
 
                        // process further only if the whole record is available
                        // note: the first 5 bytes aren't part of the length
@@ -552,16 +646,6 @@ namespace Mono.Security.Protocol.Tls
                        return buffer;
                }
 
-               private short ReadShort(Stream record)
-               {
-                       byte[] b = new byte[2];
-                       record.Read(b, 0, b.Length);
-
-                       short val = BitConverter.ToInt16(b, 0);
-
-                       return System.Net.IPAddress.HostToNetworkOrder(val);
-               }
-
                private void ProcessAlert(AlertLevel alertLevel, AlertDescription alertDesc)
                {
                        switch (alertLevel)
@@ -574,7 +658,7 @@ namespace Mono.Security.Protocol.Tls
                                switch (alertDesc)
                                {
                                        case AlertDescription.CloseNotify:
-                                               this.context.ConnectionEnd = true;
+                                               this.context.ReceivedConnectionEnd = true;
                                                break;
                                }
                                break;
@@ -599,16 +683,27 @@ namespace Mono.Security.Protocol.Tls
 
                public void SendAlert(Alert alert)
                {
-                       DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
+                       AlertLevel level;
+                       AlertDescription description;
+                       bool close;
+
+                       if (alert == null) {
+                               DebugHelper.WriteLine(">>>> Write Alert NULL");
+                               level = AlertLevel.Fatal;
+                               description = AlertDescription.InternalError;
+                               close = true;
+                       } else {
+                               DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
+                               level = alert.Level;
+                               description = alert.Description;
+                               close = alert.IsCloseNotify;
+                       }
 
                        // Write record
-                       this.SendRecord(
-                               ContentType.Alert, 
-                               new byte[]{(byte)alert.Level, (byte)alert.Description});
+                       this.SendRecord (ContentType.Alert, new byte[2] { (byte) level, (byte) description });
 
-                       if (alert.IsCloseNotify)
-                       {
-                               this.context.ConnectionEnd = true;
+                       if (close) {
+                               this.context.SentConnectionEnd = true;
                        }
                }
 
@@ -620,17 +715,73 @@ namespace Mono.Security.Protocol.Tls
                {
                        DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
 
-                       // Send Change Cipher Spec message as a plain message
-                       this.context.IsActual = false;
-
-                       // Send Change Cipher Spec message
+                       // Send Change Cipher Spec message with the current cipher
+                       // or as plain text if this is the initial negotiation
                        this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
 
+                       Context ctx = this.context;
+
+                       // Reset sequence numbers
+                       ctx.WriteSequenceNumber = 0;
+
+                       // all further data sent will be encrypted with the negotiated
+                       // security parameters (now the current parameters)
+                       if (ctx is ClientContext) {
+                               ctx.StartSwitchingSecurityParameters (true);
+                       } else {
+                               ctx.EndSwitchingSecurityParameters (false);
+                       }
+               }
+
+               public void SendChangeCipherSpec(Stream recordStream)
+               {
+                       DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
+
+                       byte[] record = this.EncodeRecord (ContentType.ChangeCipherSpec, new byte[] { 1 });
+
+                       // Send Change Cipher Spec message with the current cipher
+                       // or as plain text if this is the initial negotiation
+                       recordStream.Write(record, 0, record.Length);
+
+                       Context ctx = this.context;
+
                        // Reset sequence numbers
-                       this.context.WriteSequenceNumber = 0;
+                       ctx.WriteSequenceNumber = 0;
+
+                       // all further data sent will be encrypted with the negotiated
+                       // security parameters (now the current parameters)
+                       if (ctx is ClientContext) {
+                               ctx.StartSwitchingSecurityParameters (true);
+                       } else {
+                               ctx.EndSwitchingSecurityParameters (false);
+                       }
+               }
 
-                       // Make the pending state to be the current state
-                       this.context.IsActual = true;
+               public IAsyncResult BeginSendChangeCipherSpec(AsyncCallback callback, object state)
+               {
+                       DebugHelper.WriteLine (">>>> Write Change Cipher Spec");
+
+                       // Send Change Cipher Spec message with the current cipher
+                       // or as plain text if this is the initial negotiation
+                       return this.BeginSendRecord (ContentType.ChangeCipherSpec, new byte[] { 1 }, callback, state);
+               }
+
+               public void EndSendChangeCipherSpec (IAsyncResult asyncResult)
+               {
+                       this.EndSendRecord (asyncResult);
+
+                       Context ctx = this.context;
+
+                       // Reset sequence numbers
+                       ctx.WriteSequenceNumber = 0;
+
+                       // all further data sent will be encrypted with the negotiated
+                       // security parameters (now the current parameters)
+                       if (ctx is ClientContext) {
+                               ctx.StartSwitchingSecurityParameters (true);
+                       } else {
+                               ctx.EndSwitchingSecurityParameters (false);
+                       }
                }
 
                public IAsyncResult BeginSendRecord(HandshakeType handshakeType, AsyncCallback callback, object state)
@@ -672,7 +823,7 @@ namespace Mono.Security.Protocol.Tls
 
                public IAsyncResult BeginSendRecord(ContentType contentType, byte[] recordData, AsyncCallback callback, object state)
                {
-                       if (this.context.ConnectionEnd)
+                       if (this.context.SentConnectionEnd)
                        {
                                throw new TlsException(
                                        AlertDescription.InternalError,
@@ -722,7 +873,7 @@ namespace Mono.Security.Protocol.Tls
                        int                     offset,
                        int                     count)
                {
-                       if (this.context.ConnectionEnd)
+                       if (this.context.SentConnectionEnd)
                        {
                                throw new TlsException(
                                        AlertDescription.InternalError,
@@ -751,10 +902,10 @@ namespace Mono.Security.Protocol.Tls
                                fragment = new byte[fragmentLength];
                                Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);
 
-                               if (this.context.IsActual)
+                               if ((this.Context.Write != null) && (this.Context.Write.Cipher != null))
                                {
                                        // Encrypt fragment
-                                       fragment = this.encryptRecordFragment(contentType, fragment);
+                                       fragment = this.encryptRecordFragment (contentType, fragment);
                                }
 
                                // Write tls message
@@ -771,7 +922,22 @@ namespace Mono.Security.Protocol.Tls
 
                        return record.ToArray();
                }
-               
+
+               public byte[] EncodeHandshakeRecord(HandshakeType handshakeType)
+               {
+                       HandshakeMessage msg = this.GetMessage(handshakeType);
+
+                       msg.Process();
+
+                       var bytes = this.EncodeRecord (msg.ContentType, msg.EncodeMessage ());
+
+                       msg.Update();
+
+                       msg.Reset();
+
+                       return bytes;
+               }
+                               
                #endregion
 
                #region Cryptography Methods
@@ -785,26 +951,17 @@ namespace Mono.Security.Protocol.Tls
                        // Calculate message MAC
                        if (this.Context is ClientContext)
                        {
-                               mac     = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
+                               mac = this.context.Write.Cipher.ComputeClientRecordMAC(contentType, fragment);
                        }       
                        else
                        {
-                               mac     = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
+                               mac = this.context.Write.Cipher.ComputeServerRecordMAC (contentType, fragment);
                        }
 
                        DebugHelper.WriteLine(">>>> Record MAC", mac);
 
                        // Encrypt the message
-                       byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
-
-                       // Set new Client Cipher IV
-                       if (this.context.Cipher.CipherMode == CipherMode.CBC)
-                       {
-                               byte[] iv = new byte[this.context.Cipher.IvSize];
-                               Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
-
-                               this.context.Cipher.UpdateClientCipherIV(iv);
-                       }
+                       byte[] ecr = this.context.Write.Cipher.EncryptRecord (fragment, mac);
 
                        // Update sequence number
                        this.context.WriteSequenceNumber++;
@@ -818,11 +975,10 @@ namespace Mono.Security.Protocol.Tls
                {
                        byte[]  dcrFragment             = null;
                        byte[]  dcrMAC                  = null;
-                       bool    badRecordMac    = false;
 
                        try
                        {
-                               this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
+                               this.context.Read.Cipher.DecryptRecord (fragment, out dcrFragment, out dcrMAC);
                        }
                        catch
                        {
@@ -830,7 +986,6 @@ namespace Mono.Security.Protocol.Tls
                                {
                                        this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
                                }
-
                                throw;
                        }
                        
@@ -839,33 +994,17 @@ namespace Mono.Security.Protocol.Tls
 
                        if (this.Context is ClientContext)
                        {
-                               mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
+                               mac = this.context.Read.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
                        }
                        else
                        {
-                               mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
+                               mac = this.context.Read.Cipher.ComputeClientRecordMAC (contentType, dcrFragment);
                        }
 
                        DebugHelper.WriteLine(">>>> Record MAC", mac);
 
                        // Check record MAC
-                       if (mac.Length != dcrMAC.Length)
-                       {
-                               badRecordMac = true;
-                       }
-                       else
-                       {
-                               for (int i = 0; i < mac.Length; i++)
-                               {
-                                       if (mac[i] != dcrMAC[i])
-                                       {
-                                               badRecordMac = true;
-                                               break;
-                                       }
-                               }
-                       }
-
-                       if (badRecordMac)
+                       if (!Compare (mac, dcrMAC))
                        {
                                throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
                        }
@@ -876,11 +1015,26 @@ namespace Mono.Security.Protocol.Tls
                        return dcrFragment;
                }
 
+               private bool Compare (byte[] array1, byte[] array2)
+               {
+                       if (array1 == null)
+                               return (array2 == null);
+                       if (array2 == null)
+                               return false;
+                       if (array1.Length != array2.Length)
+                               return false;
+                       for (int i = 0; i < array1.Length; i++) {
+                               if (array1[i] != array2[i])
+                                       return false;
+                       }
+                       return true;
+               }
+
                #endregion
 
                #region CipherSpecV2 processing
 
-               private void ProcessCipherSpecV2Buffer(SecurityProtocolType protocol, byte[] buffer)
+               private void ProcessCipherSpecV2Buffer (SecurityProtocolType protocol, byte[] buffer)
                {
                        TlsStream codes = new TlsStream(buffer);
 
@@ -897,7 +1051,7 @@ namespace Mono.Security.Protocol.Tls
                                        int index = this.Context.SupportedCiphers.IndexOf(code);
                                        if (index != -1)
                                        {
-                                               this.Context.Cipher     = this.Context.SupportedCiphers[index];
+                                               this.Context.Negotiating.Cipher = this.Context.SupportedCiphers[index];
                                                break;
                                        }
                                }
@@ -911,13 +1065,13 @@ namespace Mono.Security.Protocol.Tls
 
                                        if (cipher != null)
                                        {
-                                               this.Context.Cipher = cipher;
+                                               this.Context.Negotiating.Cipher = cipher;
                                                break;
                                        }
                                }
                        }
 
-                       if (this.Context.Cipher == null)
+                       if (this.Context.Negotiating == null)
                        {
                                throw new TlsException(AlertDescription.InsuficientSecurity, "Insuficient Security");
                        }