2004-04-21 Carlos Guzman Alvarez <carlosga@telefonica.net>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 /* Transport Security Layer (TLS)
2  * Copyright (c) 2003-2004 Carlos Guzman Alvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Security.Cryptography;
28 using System.Security.Cryptography.X509Certificates;
29
30 using Mono.Security.Protocol.Tls.Handshake;
31
32 namespace Mono.Security.Protocol.Tls
33 {
34         internal abstract class RecordProtocol
35         {
36                 #region Fields
37
38                 protected Stream        innerStream;
39                 protected Context       context;
40
41                 #endregion
42
43                 #region Properties
44
45                 public Stream InnerStream
46                 {
47                         get { return this.innerStream; }
48                         set { this.innerStream = value; }
49                 }
50
51                 public Context Context
52                 {
53                         get { return this.context; }
54                         set { this.context = value; }
55                 }
56
57                 #endregion
58
59                 #region Constructors
60
61                 public RecordProtocol(Stream innerStream, Context context)
62                 {
63                         this.innerStream                        = innerStream;
64                         this.context                            = context;
65                         this.context.RecordProtocol = this;
66                 }
67
68                 #endregion
69
70                 #region Abstract Methods
71
72                 public abstract void SendRecord(HandshakeType type);
73                 protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
74                 protected abstract void ProcessChangeCipherSpec();
75                                 
76                 #endregion
77
78                 #region Receive Record Methods
79
80                 public byte[] ReceiveRecord()
81                 {
82                         if (this.context.ConnectionEnd)
83                         {
84                                 throw new TlsException(
85                                         AlertDescription.InternalError,
86                                         "The session is finished and it's no longer valid.");
87                         }
88                         
89                         // Try to read the Record Content Type
90                         int type = this.innerStream.ReadByte();
91
92                         // There are no more data for read
93                         if (type == -1)
94                         {
95                                 return null;
96                         }
97
98                         ContentType     contentType     = (ContentType)type;
99                         short           protocol        = this.readShort();
100                         short           length          = this.readShort();
101                         
102                         // Read Record data
103                         int             received        = 0;
104                         byte[]  buffer          = new byte[length];
105                         while (received != length)
106                         {
107                                 received += this.innerStream.Read(
108                                         buffer, received, buffer.Length - received);
109                         }
110
111                         DebugHelper.WriteLine(
112                                 ">>>> Read record ({0}|{1})", 
113                                 this.context.DecodeProtocolCode(protocol),
114                                 contentType);
115                         DebugHelper.WriteLine("Record data", buffer);
116
117                         TlsStream message = new TlsStream(buffer);
118                 
119                         // Check that the message has a valid protocol version
120                         if (protocol != this.context.Protocol && 
121                                 this.context.ProtocolNegotiated)
122                         {
123                                 throw new TlsException(
124                                         AlertDescription.ProtocolVersion,
125                                         "Invalid protocol version on message received from server");
126                         }
127
128                         // Decrypt message contents if needed
129                         if (contentType == ContentType.Alert && length == 2)
130                         {
131                         }
132                         else
133                         {
134                                 if (this.context.IsActual &&
135                                         contentType != ContentType.ChangeCipherSpec)
136                                 {
137                                         message = this.decryptRecordFragment(
138                                                 contentType, 
139                                                 message.ToArray());
140
141                                         DebugHelper.WriteLine("Decrypted record data", message.ToArray());
142                                 }
143                         }
144
145                         // Set last handshake message received to None
146                         this.context.LastHandshakeMsg = HandshakeType.None;
147                         
148                         // Process record
149                         byte[] result = message.ToArray();
150
151                         switch (contentType)
152                         {
153                                 case ContentType.Alert:
154                                         this.processAlert(
155                                                 (AlertLevel)message.ReadByte(),
156                                                 (AlertDescription)message.ReadByte());
157                                         break;
158
159                                 case ContentType.ChangeCipherSpec:
160                                         this.ProcessChangeCipherSpec();
161                                         break;
162
163                                 case ContentType.ApplicationData:
164                                         break;
165
166                                 case ContentType.Handshake:
167                                         while (!message.EOF)
168                                         {
169                                                 this.ProcessHandshakeMessage(message);
170                                         }
171
172                                         // Update handshakes of current messages
173                                         this.context.HandshakeMessages.Write(message.ToArray());
174                                         break;
175
176                                 default:
177                                         throw new TlsException(
178                                                 AlertDescription.UnexpectedMessage,
179                                                 "Unknown record received from server.");
180                         }
181
182                         return result;
183                 }
184
185                 private short readShort()
186                 {
187                         byte[] b = new byte[2];
188                         this.innerStream.Read(b, 0, b.Length);
189
190                         short val = BitConverter.ToInt16(b, 0);
191
192                         return System.Net.IPAddress.HostToNetworkOrder(val);
193                 }
194
195                 private void processAlert(
196                         AlertLevel                      alertLevel, 
197                         AlertDescription        alertDesc)
198                 {
199                         switch (alertLevel)
200                         {
201                                 case AlertLevel.Fatal:
202                                         throw new TlsException(alertLevel, alertDesc);
203
204                                 case AlertLevel.Warning:
205                                 default:
206                                 switch (alertDesc)
207                                 {
208                                         case AlertDescription.CloseNotify:
209                                                 this.context.ConnectionEnd = true;
210                                                 break;
211                                 }
212                                 break;
213                         }
214                 }
215
216                 #endregion
217
218                 #region Send Alert Methods
219
220                 public void SendAlert(AlertDescription description)
221                 {
222                         this.SendAlert(new Alert(description));
223                 }
224
225                 public void SendAlert(
226                         AlertLevel                      level, 
227                         AlertDescription        description)
228                 {
229                         this.SendAlert(new Alert(level, description));
230                 }
231
232                 public void SendAlert(Alert alert)
233                 {
234                         DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
235
236                         // Write record
237                         this.SendRecord(
238                                 ContentType.Alert, 
239                                 new byte[]{(byte)alert.Level, (byte)alert.Description});
240
241                         if (alert.IsCloseNotify)
242                         {
243                                 this.context.ConnectionEnd = true;
244                         }
245                 }
246
247                 #endregion
248
249                 #region Send Record Methods
250
251                 public void SendChangeCipherSpec()
252                 {
253                         DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
254
255                         // Send Change Cipher Spec message as a plain message
256                         this.context.IsActual = false;
257
258                         // Send Change Cipher Spec message
259                         this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
260
261                         // Reset sequence numbers
262                         this.context.WriteSequenceNumber = 0;
263
264                         // Make the pending state to be the current state
265                         this.context.IsActual = true;
266
267                         // Send Finished message
268                         this.SendRecord(HandshakeType.Finished);                        
269                 }
270
271                 public void SendRecord(ContentType contentType, byte[] recordData)
272                 {
273                         if (this.context.ConnectionEnd)
274                         {
275                                 throw new TlsException(
276                                         AlertDescription.InternalError,
277                                         "The session is finished and it's no longer valid.");
278                         }
279
280                         byte[] record = this.EncodeRecord(contentType, recordData);
281
282                         this.innerStream.Write(record, 0, record.Length);
283                 }
284
285                 public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
286                 {
287                         return this.EncodeRecord(
288                                 contentType,
289                                 recordData,
290                                 0,
291                                 recordData.Length);
292                 }
293
294                 public byte[] EncodeRecord(
295                         ContentType     contentType, 
296                         byte[]          recordData,
297                         int                     offset,
298                         int                     count)
299                 {
300                         if (this.context.ConnectionEnd)
301                         {
302                                 throw new TlsException(
303                                         AlertDescription.InternalError,
304                                         "The session is finished and it's no longer valid.");
305                         }
306
307                         TlsStream record = new TlsStream();
308
309                         int     position = offset;
310
311                         while (position < ( offset + count ))
312                         {
313                                 short   fragmentLength = 0;
314                                 byte[]  fragment;
315
316                                 if ((count - position) > Context.MAX_FRAGMENT_SIZE)
317                                 {
318                                         fragmentLength = Context.MAX_FRAGMENT_SIZE;
319                                 }
320                                 else
321                                 {
322                                         fragmentLength = (short)(count - position);
323                                 }
324
325                                 // Fill the fragment data
326                                 fragment = new byte[fragmentLength];
327                                 Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);
328
329                                 if (this.context.IsActual)
330                                 {
331                                         // Encrypt fragment
332                                         fragment = this.encryptRecordFragment(contentType, fragment);
333                                 }
334
335                                 // Write tls message
336                                 record.Write((byte)contentType);
337                                 record.Write(this.context.Protocol);
338                                 record.Write((short)fragment.Length);
339                                 record.Write(fragment);
340
341                                 DebugHelper.WriteLine("Record data", fragment);
342
343                                 // Update buffer position
344                                 position += fragmentLength;
345                         }
346
347                         return record.ToArray();
348                 }
349                 
350                 #endregion
351
352                 #region Cryptography Methods
353
354                 private byte[] encryptRecordFragment(
355                         ContentType     contentType, 
356                         byte[]          fragment)
357                 {
358                         byte[] mac      = null;
359
360                         // Calculate message MAC
361                         if (this.Context is ClientContext)
362                         {
363                                 mac     = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
364                         }       
365                         else
366                         {
367                                 mac     = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
368                         }
369
370                         DebugHelper.WriteLine(">>>> Record MAC", mac);
371
372                         // Encrypt the message
373                         byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
374
375                         // Set new Client Cipher IV
376                         if (this.context.Cipher.CipherMode == CipherMode.CBC)
377                         {
378                                 byte[] iv = new byte[this.context.Cipher.IvSize];
379                                 Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
380
381                                 this.context.Cipher.UpdateClientCipherIV(iv);
382                         }
383
384                         // Update sequence number
385                         this.context.WriteSequenceNumber++;
386
387                         return ecr;
388                 }
389
390                 private TlsStream decryptRecordFragment(
391                         ContentType     contentType, 
392                         byte[]          fragment)
393                 {
394                         byte[]  dcrFragment             = null;
395                         byte[]  dcrMAC                  = null;
396                         bool    badRecordMac    = false;
397
398                         try
399                         {
400                                 this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
401                         }
402                         catch
403                         {
404                                 if (this.context is ServerContext)
405                                 {
406                                         this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
407                                 }
408
409                                 throw;
410                         }
411                         
412                         // Generate record MAC
413                         byte[] mac = null;
414
415                         if (this.Context is ClientContext)
416                         {
417                                 mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
418                         }
419                         else
420                         {
421                                 mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
422                         }
423
424                         DebugHelper.WriteLine(">>>> Record MAC", mac);
425
426                         // Check record MAC
427                         if (mac.Length != dcrMAC.Length)
428                         {
429                                 badRecordMac = true;
430                         }
431                         else
432                         {
433                                 for (int i = 0; i < mac.Length; i++)
434                                 {
435                                         if (mac[i] != dcrMAC[i])
436                                         {
437                                                 badRecordMac = true;
438                                                 break;
439                                         }
440                                 }
441                         }
442
443                         if (badRecordMac)
444                         {
445                                 throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
446                         }
447
448                         // Update sequence number
449                         this.context.ReadSequenceNumber++;
450
451                         return new TlsStream(dcrFragment);
452                 }
453
454                 #endregion
455         }
456 }