2004-02-20 Carlos Guzmán Álvarez <carlosga@telefonica.net>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 /* Transport Security Layer (TLS)
2  * Copyright (c) 2003-2004 Carlos Guzman Alvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Security.Cryptography;
28 using System.Security.Cryptography.X509Certificates;
29
30 using Mono.Security.Protocol.Tls.Alerts;
31 using Mono.Security.Protocol.Tls.Handshake;
32
namespace Mono.Security.Protocol.Tls
{
	/// <summary>
	/// Base class for the TLS record layer. It reads, decrypts and dispatches
	/// incoming records, and fragments, encrypts and writes outgoing records,
	/// using <see cref="InnerStream"/> as the transport.
	/// </summary>
	internal abstract class RecordProtocol
	{
		#region Fields

		protected Stream		innerStream;
		protected TlsContext	context;

		#endregion

		#region Properties

		/// <summary>Transport stream that records are read from and written to.</summary>
		public Stream InnerStream
		{
			get { return this.innerStream; }
			set { this.innerStream = value; }
		}

		/// <summary>
		/// Connection state. It supplies the protocol version, the active
		/// cipher, sequence numbers and connection flags used by this layer.
		/// </summary>
		public TlsContext Context
		{
			get { return this.context; }
			set { this.context = value; }
		}

		#endregion

		#region Constructors

		/// <summary>Initializes the record layer over a transport stream and a context.</summary>
		/// <param name="innerStream">Transport stream used for record I/O.</param>
		/// <param name="context">Connection state shared with the subclass.</param>
		protected RecordProtocol(Stream innerStream, TlsContext context)
		{
			this.innerStream	= innerStream;
			this.context		= context;
		}

		#endregion

		#region Abstract Methods

		/// <summary>Sends the handshake message of the given type (implemented by subclasses).</summary>
		public abstract void SendRecord(TlsHandshakeType type);

		/// <summary>
		/// Consumes one handshake message from <paramref name="handMsg"/>.
		/// <see cref="ReceiveRecord"/> calls it repeatedly until the stream reaches EOF.
		/// </summary>
		protected abstract void ProcessHandshakeMessage(TlsStream handMsg);

		#endregion

		#region Receive Record Methods

		/// <summary>
		/// Reads one record from the inner stream and decrypts it when a cipher
		/// is active. It then processes the record according to its content
		/// type: alerts, change cipher spec, application data or handshake.
		/// </summary>
		/// <returns>
		/// The record payload (decrypted when applicable), or <c>null</c> if the
		/// stream ended before a new record started.
		/// </returns>
		public byte[] ReceiveRecord()
		{
			if (this.context.ConnectionEnd)
			{
				throw this.context.CreateException("The session is finished and it's no longer valid.");
			}

			// Try to read the record content type
			int type = this.innerStream.ReadByte();

			// No more data to read: a clean end of stream between records
			if (type == -1)
			{
				return null;
			}

			TlsContentType	contentType	= (TlsContentType)type;
			short			protocol	= this.readShort();
			short			length		= this.readShort();

			// A negative value means the 16-bit length field was >= 0x8000.
			// Allocating with it would throw OverflowException, so reject it
			// as a protocol error instead.
			if (length < 0)
			{
				throw this.context.CreateException("Invalid record length received from server.");
			}

			// Read the record data. readFully throws instead of looping
			// forever if the peer closes the stream mid-record.
			byte[] buffer = new byte[length];
			this.readFully(buffer, 0, buffer.Length);

			TlsStream message = new TlsStream(buffer);

			// Check that the message has a valid protocol version
			if (protocol != this.context.Protocol &&
				this.context.HelloDone)
			{
				throw this.context.CreateException("Invalid protocol version on message received from server");
			}

			// Decrypt message contents if needed. Two-byte alert records and
			// ChangeCipherSpec records are never passed through decryption.
			bool plainAlert = (contentType == TlsContentType.Alert && length == 2);
			if (!plainAlert &&
				this.context.IsActual &&
				contentType != TlsContentType.ChangeCipherSpec)
			{
				message = this.decryptRecordFragment(
					contentType,
					message.ToArray());
			}

			byte[] result = message.ToArray();

			// Process record
			switch (contentType)
			{
				case TlsContentType.Alert:
					this.processAlert(
						(TlsAlertLevel)message.ReadByte(),
						(TlsAlertDescription)message.ReadByte());
					break;

				case TlsContentType.ChangeCipherSpec:
					// Reset sequence numbers
					this.context.ReadSequenceNumber = 0;
					break;

				case TlsContentType.ApplicationData:
					break;

				case TlsContentType.Handshake:
					while (!message.EOF)
					{
						this.ProcessHandshakeMessage(message);
					}

					// Update handshakes of current messages
					this.context.HandshakeMessages.Write(message.ToArray());
					break;

				default:
					throw this.context.CreateException("Unknown record received from server.");
			}

			return result;
		}

		/// <summary>
		/// Reads exactly <paramref name="count"/> bytes from the inner stream
		/// into <paramref name="buffer"/>. Stream.Read may return fewer bytes
		/// than requested, and returns 0 at end of stream.
		/// </summary>
		private void readFully(byte[] buffer, int offset, int count)
		{
			int received = 0;
			while (received < count)
			{
				int read = this.innerStream.Read(buffer, offset + received, count - received);
				if (read <= 0)
				{
					throw this.context.CreateException("Unexpected end of stream while reading a record.");
				}
				received += read;
			}
		}

		/// <summary>
		/// Reads a 16-bit integer encoded in network (big-endian) byte order,
		/// as used by the record header fields.
		/// </summary>
		private short readShort()
		{
			byte[] b = new byte[2];
			this.readFully(b, 0, b.Length);

			// Explicit big-endian decode; independent of host endianness.
			return (short)((b[0] << 8) | b[1]);
		}

		/// <summary>
		/// Handles a received alert. Fatal alerts are raised as exceptions. A
		/// close_notify warning marks the connection as finished, and other
		/// warnings are ignored.
		/// </summary>
		private void processAlert(
			TlsAlertLevel		alertLevel,
			TlsAlertDescription	alertDesc)
		{
			switch (alertLevel)
			{
				case TlsAlertLevel.Fatal:
					throw this.context.CreateException(alertLevel, alertDesc);

				case TlsAlertLevel.Warning:
				default:
					if (alertDesc == TlsAlertDescription.CloseNotify)
					{
						this.context.ConnectionEnd = true;
					}
					break;
			}
		}

		#endregion

		#region Send Record Methods

		/// <summary>
		/// Writes an alert record. It then calls the alert's Update and Reset
		/// methods.
		/// </summary>
		public void SendAlert(TlsAlert alert)
		{
			// Write record
			this.SendRecord(TlsContentType.Alert, alert.ToArray());

			// Update session
			alert.Update();

			// Reset message contents
			alert.Reset();
		}

		/// <summary>
		/// Sends ChangeCipherSpec, resets the write sequence number and marks
		/// the pending cipher state as current. It then sends the Finished
		/// handshake message, which is therefore encrypted.
		/// </summary>
		public void SendChangeCipherSpec()
		{
			// Send Change Cipher Spec message
			this.SendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});

			// Reset sequence numbers
			this.context.WriteSequenceNumber = 0;

			// Make the pending state to be the current state
			this.context.IsActual = true;

			// Send Finished message
			this.SendRecord(TlsHandshakeType.Finished);
		}

		/// <summary>Encodes <paramref name="recordData"/> as one or more records and writes them.</summary>
		/// <exception>Throws the context's exception if the session is already finished.</exception>
		public void SendRecord(TlsContentType contentType, byte[] recordData)
		{
			if (this.context.ConnectionEnd)
			{
				throw this.context.CreateException("The session is finished and it's no longer valid.");
			}

			byte[] record = this.EncodeRecord(contentType, recordData);

			this.innerStream.Write(record, 0, record.Length);
		}

		/// <summary>Encodes the whole of <paramref name="recordData"/>; see the offset/count overload.</summary>
		/// <exception cref="ArgumentNullException"><paramref name="recordData"/> is null.</exception>
		public byte[] EncodeRecord(TlsContentType contentType, byte[] recordData)
		{
			if (recordData == null)
			{
				throw new ArgumentNullException("recordData");
			}

			return this.EncodeRecord(
				contentType,
				recordData,
				0,
				recordData.Length);
		}

		/// <summary>
		/// Splits <paramref name="count"/> bytes of <paramref name="recordData"/>,
		/// starting at <paramref name="offset"/>, into fragments of at most
		/// TlsContext.MAX_FRAGMENT_SIZE bytes. Each fragment is encrypted when a
		/// cipher is active and prefixed with its record header (content type,
		/// protocol version, length).
		/// </summary>
		/// <returns>The concatenated encoded records.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="recordData"/> is null.</exception>
		/// <exception cref="ArgumentOutOfRangeException">The offset/count range lies outside the array.</exception>
		public byte[] EncodeRecord(
			TlsContentType	contentType,
			byte[]			recordData,
			int				offset,
			int				count)
		{
			if (this.context.ConnectionEnd)
			{
				throw this.context.CreateException("The session is finished and it's no longer valid.");
			}
			if (recordData == null)
			{
				throw new ArgumentNullException("recordData");
			}
			if (offset < 0 || offset > recordData.Length)
			{
				throw new ArgumentOutOfRangeException("offset");
			}
			if (count < 0 || count > recordData.Length - offset)
			{
				throw new ArgumentOutOfRangeException("count");
			}

			TlsStream record = new TlsStream();

			int end			= offset + count;
			int position	= offset;

			while (position < end)
			{
				// Remaining bytes are measured from the current position to the
				// end of the requested range (not from index 0), so non-zero
				// offsets fragment correctly.
				int fragmentLength = Math.Min(end - position, TlsContext.MAX_FRAGMENT_SIZE);

				// Fill the fragment data
				byte[] fragment = new byte[fragmentLength];
				Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);

				if (this.context.IsActual)
				{
					// Encrypt fragment
					fragment = this.encryptRecordFragment(contentType, fragment);
				}

				// Write tls message
				record.Write((byte)contentType);
				record.Write(this.context.Protocol);
				record.Write((short)fragment.Length);
				record.Write(fragment);

				// Update buffer position
				position += fragmentLength;
			}

			return record.ToArray();
		}

		#endregion

		#region Cryptography Methods

		/// <summary>
		/// Computes the client record MAC and encrypts fragment plus MAC with
		/// the active cipher. In CBC mode, the last IvSize bytes of the output
		/// become the next client IV. Increments the write sequence number.
		/// </summary>
		private byte[] encryptRecordFragment(
			TlsContentType	contentType,
			byte[]			fragment)
		{
			// Calculate message MAC
			byte[] mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);

			// Encrypt the message
			byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);

			// Set new IV
			if (this.context.Cipher.CipherMode == CipherMode.CBC)
			{
				byte[] iv = new byte[this.context.Cipher.IvSize];
				System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
				this.context.Cipher.UpdateClientCipherIV(iv);
			}

			// Update sequence number
			this.context.WriteSequenceNumber++;

			return ecr;
		}

		/// <summary>
		/// Decrypts a received fragment and verifies its MAC against the server
		/// record MAC. In CBC mode, the last IvSize bytes of the ciphertext
		/// become the next server IV. Increments the read sequence number.
		/// </summary>
		/// <returns>The decrypted payload without the MAC.</returns>
		/// <exception cref="TlsException">The MAC does not match.</exception>
		private TlsStream decryptRecordFragment(
			TlsContentType	contentType,
			byte[]			fragment)
		{
			byte[]	dcrFragment	= null;
			byte[]	dcrMAC		= null;

			// Decrypt message
			this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);

			// Set new IV
			if (this.context.Cipher.CipherMode == CipherMode.CBC)
			{
				byte[] iv = new byte[this.context.Cipher.IvSize];
				System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
				this.context.Cipher.UpdateServerCipherIV(iv);
			}

			// Check MAC code
			byte[] mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);

			if (dcrMAC == null || mac.Length != dcrMAC.Length)
			{
				throw new TlsException("Invalid MAC received from server.");
			}

			// Compare every byte before deciding, so the time taken does not
			// reveal the position of the first mismatching byte.
			int diff = 0;
			for (int i = 0; i < mac.Length; i++)
			{
				diff |= mac[i] ^ dcrMAC[i];
			}
			if (diff != 0)
			{
				throw new TlsException("Invalid MAC received from server.");
			}

			// Update sequence number
			this.context.ReadSequenceNumber++;

			return new TlsStream(dcrFragment);
		}

		#endregion
	}
}