2004-11-23 Sebastien Pouliot <sebastien@ximian.com>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 // Transport Layer Security (TLS)
2 // Copyright (c) 2003-2004 Carlos Guzman Alvarez
3
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining
6 // a copy of this software and associated documentation files (the
7 // "Software"), to deal in the Software without restriction, including
8 // without limitation the rights to use, copy, modify, merge, publish,
9 // distribute, sublicense, and/or sell copies of the Software, and to
10 // permit persons to whom the Software is furnished to do so, subject to
11 // the following conditions:
12 // 
13 // The above copyright notice and this permission notice shall be
14 // included in all copies or substantial portions of the Software.
15 // 
16 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23 //
24
25 using System;
26 using System.Collections;
27 using System.IO;
28 using System.Security.Cryptography;
29 using System.Security.Cryptography.X509Certificates;
30
31 using Mono.Security.Protocol.Tls.Handshake;
32
33 namespace Mono.Security.Protocol.Tls
34 {
35         internal abstract class RecordProtocol
36         {
37                 #region Fields
38
39                 protected Stream        innerStream;
40                 protected Context       context;
41
42                 #endregion
43
44                 #region Properties
45
46                 public Context Context
47                 {
48                         get { return this.context; }
49                         set { this.context = value; }
50                 }
51
52                 #endregion
53
54                 #region Constructors
55
56                 public RecordProtocol(Stream innerStream, Context context)
57                 {
58                         this.innerStream                        = innerStream;
59                         this.context                            = context;
60                         this.context.RecordProtocol = this;
61                 }
62
63                 #endregion
64
65                 #region Abstract Methods
66
67                 public abstract void SendRecord(HandshakeType type);
68                 protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
69                 protected abstract void ProcessChangeCipherSpec();
70                                 
71                 #endregion
72
73                 #region Reveive Record Methods
74
75                 public byte[] ReceiveRecord()
76                 {
77                         if (this.context.ConnectionEnd)
78                         {
79                                 throw new TlsException(
80                                         AlertDescription.InternalError,
81                                         "The session is finished and it's no longer valid.");
82                         }
83         
84                         // Try to read the Record Content Type
85                         int type = this.innerStream.ReadByte();
86                         if (type == -1)
87                         {
88                                 return null;
89                         }
90
91                         // Set last handshake message received to None
92                         this.context.LastHandshakeMsg = HandshakeType.ClientHello;
93
94                         ContentType     contentType     = (ContentType)type;
95                         byte[] buffer = this.ReadRecordBuffer(type);
96
97                         TlsStream message = new TlsStream(buffer);
98                 
99                         // Decrypt message contents if needed
100                         if (contentType == ContentType.Alert && buffer.Length == 2)
101                         {
102                         }
103                         else
104                         {
105                                 if (this.context.IsActual && contentType != ContentType.ChangeCipherSpec)
106                                 {
107                                         message = this.decryptRecordFragment(contentType, message.ToArray());
108
109                                         DebugHelper.WriteLine("Decrypted record data", message.ToArray());
110                                 }
111                         }
112
113                         // Process record
114                         byte[] result = message.ToArray();
115
116                         switch (contentType)
117                         {
118                                 case ContentType.Alert:
119                                         this.ProcessAlert((AlertLevel)message.ReadByte(), (AlertDescription)message.ReadByte());
120                                         result = null;
121                                         break;
122
123                                 case ContentType.ChangeCipherSpec:
124                                         this.ProcessChangeCipherSpec();
125                                         break;
126
127                                 case ContentType.ApplicationData:
128                                         break;
129
130                                 case ContentType.Handshake:
131                                         while (!message.EOF)
132                                         {
133                                                 this.ProcessHandshakeMessage(message);
134                                         }
135
136                                         // Update handshakes of current messages
137                                         this.context.HandshakeMessages.Write(message.ToArray());
138                                         break;
139
140 // FIXME / MCS bug - http://bugzilla.ximian.com/show_bug.cgi?id=67711
141 //                              case (ContentType)0x80:
142 //                                      this.context.HandshakeMessages.Write (result);
143 //                                      break;
144
145                                 default:
146                                         if (contentType != (ContentType)0x80)
147                                         {
148                                                 throw new TlsException(
149                                                         AlertDescription.UnexpectedMessage,
150                                                         "Unknown record received from server.");
151                                         }
152                                         this.context.HandshakeMessages.Write (result);
153                                         break;
154                         }
155
156                         return result;
157                 }
158
159                 private byte[] ReadRecordBuffer(int contentType)
160                 {
161                         switch (contentType)
162                         {
163                                 case 0x80:
164                                         return this.ReadClientHelloV2();
165
166                                 default:
167                                         if (!Enum.IsDefined(typeof(ContentType), (ContentType)contentType))
168                                         {
169                                                 throw new TlsException(AlertDescription.DecodeError);
170                                         }
171                                         return this.ReadStandardRecordBuffer();
172                         }
173                 }
174
175                 private byte[] ReadClientHelloV2()
176                 {
177                         int msgLength                   = this.innerStream.ReadByte();
178                         byte[] message = new byte [msgLength];
179                         this.innerStream.Read (message, 0, msgLength);
180
181                         int msgType             = message [0];
182                         if (msgType != 1)
183                         {
184                                 throw new TlsException(AlertDescription.DecodeError);
185                         }
186                         int protocol = (message [1] << 8 | message [2]);
187                         int cipherSpecLength = (message [3] << 8 | message [4]);
188                         int sessionIdLength = (message [5] << 8 | message [6]);
189                         int challengeLength = (message [7] << 8 | message [8]);
190                         int length = (challengeLength > 32) ? 32 : challengeLength;
191
192                         // Read CipherSpecs
193                         byte[] cipherSpecV2 = new byte[cipherSpecLength];
194                         Buffer.BlockCopy (message, 9, cipherSpecV2, 0, cipherSpecLength);
195
196                         // Read session ID
197                         byte[] sessionId = new byte[sessionIdLength];
198                         Buffer.BlockCopy (message, 9 + cipherSpecLength, sessionId, 0, sessionIdLength);
199
200                         // Read challenge ID
201                         byte[] challenge = new byte[challengeLength];
202                         Buffer.BlockCopy (message, 9 + cipherSpecLength + sessionIdLength, challenge, 0, challengeLength);
203                 
204                         if (challengeLength < 16 || cipherSpecLength == 0 || (cipherSpecLength % 3) != 0)
205                         {
206                                 throw new TlsException(AlertDescription.DecodeError);
207                         }
208
209                         // Updated the Session ID
210                         if (sessionId.Length > 0)
211                         {
212                                 this.context.SessionId = sessionId;
213                         }
214
215                         // Update the protocol version
216                         this.Context.ChangeProtocol((short)protocol);
217
218                         // Select the Cipher suite
219                         this.ProcessCipherSpecV2Buffer(this.Context.SecurityProtocol, cipherSpecV2);
220
221                         // Updated the Client Random\r
222                         this.context.ClientRandom = new byte [32]; // Always 32\r
223                         // 1. if challenge is bigger than 32 bytes only use the last 32 bytes\r
224                         // 2. right justify (0) challenge in ClientRandom if less than 32\r
225                         Buffer.BlockCopy (challenge, challenge.Length - length, this.context.ClientRandom, 32 - length, length);\r
226 \r
227                         // Set 
228                         this.context.LastHandshakeMsg = HandshakeType.ClientHello;
229                         this.context.ProtocolNegotiated = true;
230
231                         return message;
232                 }
233
234                 private byte[] ReadStandardRecordBuffer()
235                 {
236                         short protocol  = this.ReadShort();
237                         short length    = this.ReadShort();
238                         
239                         // Read Record data
240                         int             received        = 0;
241                         byte[]  buffer          = new byte[length];
242                         while (received != length)
243                         {
244                                 received += this.innerStream.Read(buffer, received, buffer.Length - received);
245                         }
246
247                         // Check that the message has a valid protocol version
248                         if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
249                         {
250                                 throw new TlsException(
251                                         AlertDescription.ProtocolVersion, "Invalid protocol version on message received");
252                         }
253
254                         DebugHelper.WriteLine("Record data", buffer);
255
256                         return buffer;
257                 }
258
259                 private short ReadShort()
260                 {
261                         byte[] b = new byte[2];
262                         this.innerStream.Read(b, 0, b.Length);
263
264                         short val = BitConverter.ToInt16(b, 0);
265
266                         return System.Net.IPAddress.HostToNetworkOrder(val);
267                 }
268
269                 private void ProcessAlert(AlertLevel alertLevel, AlertDescription alertDesc)
270                 {
271                         switch (alertLevel)
272                         {
273                                 case AlertLevel.Fatal:
274                                         throw new TlsException(alertLevel, alertDesc);
275
276                                 case AlertLevel.Warning:
277                                 default:
278                                 switch (alertDesc)
279                                 {
280                                         case AlertDescription.CloseNotify:
281                                                 this.context.ConnectionEnd = true;
282                                                 break;
283                                 }
284                                 break;
285                         }
286                 }
287
288                 #endregion
289
290                 #region Send Alert Methods
291
292                 public void SendAlert(AlertDescription description)
293                 {
294                         this.SendAlert(new Alert(description));
295                 }
296
297                 public void SendAlert(
298                         AlertLevel                      level, 
299                         AlertDescription        description)
300                 {
301                         this.SendAlert(new Alert(level, description));
302                 }
303
304                 public void SendAlert(Alert alert)
305                 {
306                         DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);
307
308                         // Write record
309                         this.SendRecord(
310                                 ContentType.Alert, 
311                                 new byte[]{(byte)alert.Level, (byte)alert.Description});
312
313                         if (alert.IsCloseNotify)
314                         {
315                                 this.context.ConnectionEnd = true;
316                         }
317                 }
318
319                 #endregion
320
321                 #region Send Record Methods
322
323                 public void SendChangeCipherSpec()
324                 {
325                         DebugHelper.WriteLine(">>>> Write Change Cipher Spec");
326
327                         // Send Change Cipher Spec message as a plain message
328                         this.context.IsActual = false;
329
330                         // Send Change Cipher Spec message
331                         this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
332
333                         // Reset sequence numbers
334                         this.context.WriteSequenceNumber = 0;
335
336                         // Make the pending state to be the current state
337                         this.context.IsActual = true;
338
339                         // Send Finished message
340                         this.SendRecord(HandshakeType.Finished);                        
341                 }
342
343                 public void SendRecord(ContentType contentType, byte[] recordData)
344                 {
345                         if (this.context.ConnectionEnd)
346                         {
347                                 throw new TlsException(
348                                         AlertDescription.InternalError,
349                                         "The session is finished and it's no longer valid.");
350                         }
351
352                         byte[] record = this.EncodeRecord(contentType, recordData);
353
354                         this.innerStream.Write(record, 0, record.Length);
355                 }
356
357                 public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
358                 {
359                         return this.EncodeRecord(
360                                 contentType,
361                                 recordData,
362                                 0,
363                                 recordData.Length);
364                 }
365
366                 public byte[] EncodeRecord(
367                         ContentType     contentType, 
368                         byte[]          recordData,
369                         int                     offset,
370                         int                     count)
371                 {
372                         if (this.context.ConnectionEnd)
373                         {
374                                 throw new TlsException(
375                                         AlertDescription.InternalError,
376                                         "The session is finished and it's no longer valid.");
377                         }
378
379                         TlsStream record = new TlsStream();
380
381                         int     position = offset;
382
383                         while (position < ( offset + count ))
384                         {
385                                 short   fragmentLength = 0;
386                                 byte[]  fragment;
387
388                                 if ((count - position) > Context.MAX_FRAGMENT_SIZE)
389                                 {
390                                         fragmentLength = Context.MAX_FRAGMENT_SIZE;
391                                 }
392                                 else
393                                 {
394                                         fragmentLength = (short)(count - position);
395                                 }
396
397                                 // Fill the fragment data
398                                 fragment = new byte[fragmentLength];
399                                 Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);
400
401                                 if (this.context.IsActual)
402                                 {
403                                         // Encrypt fragment
404                                         fragment = this.encryptRecordFragment(contentType, fragment);
405                                 }
406
407                                 // Write tls message
408                                 record.Write((byte)contentType);
409                                 record.Write(this.context.Protocol);
410                                 record.Write((short)fragment.Length);
411                                 record.Write(fragment);
412
413                                 DebugHelper.WriteLine("Record data", fragment);
414
415                                 // Update buffer position
416                                 position += fragmentLength;
417                         }
418
419                         return record.ToArray();
420                 }
421                 
422                 #endregion
423
424                 #region Cryptography Methods
425
426                 private byte[] encryptRecordFragment(
427                         ContentType     contentType, 
428                         byte[]          fragment)
429                 {
430                         byte[] mac      = null;
431
432                         // Calculate message MAC
433                         if (this.Context is ClientContext)
434                         {
435                                 mac     = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
436                         }       
437                         else
438                         {
439                                 mac     = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
440                         }
441
442                         DebugHelper.WriteLine(">>>> Record MAC", mac);
443
444                         // Encrypt the message
445                         byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
446
447                         // Set new Client Cipher IV
448                         if (this.context.Cipher.CipherMode == CipherMode.CBC)
449                         {
450                                 byte[] iv = new byte[this.context.Cipher.IvSize];
451                                 Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
452
453                                 this.context.Cipher.UpdateClientCipherIV(iv);
454                         }
455
456                         // Update sequence number
457                         this.context.WriteSequenceNumber++;
458
459                         return ecr;
460                 }
461
462                 private TlsStream decryptRecordFragment(
463                         ContentType     contentType, 
464                         byte[]          fragment)
465                 {
466                         byte[]  dcrFragment             = null;
467                         byte[]  dcrMAC                  = null;
468                         bool    badRecordMac    = false;
469
470                         try
471                         {
472                                 this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
473                         }
474                         catch
475                         {
476                                 if (this.context is ServerContext)
477                                 {
478                                         this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
479                                 }
480
481                                 throw;
482                         }
483                         
484                         // Generate record MAC
485                         byte[] mac = null;
486
487                         if (this.Context is ClientContext)
488                         {
489                                 mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
490                         }
491                         else
492                         {
493                                 mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
494                         }
495
496                         DebugHelper.WriteLine(">>>> Record MAC", mac);
497
498                         // Check record MAC
499                         if (mac.Length != dcrMAC.Length)
500                         {
501                                 badRecordMac = true;
502                         }
503                         else
504                         {
505                                 for (int i = 0; i < mac.Length; i++)
506                                 {
507                                         if (mac[i] != dcrMAC[i])
508                                         {
509                                                 badRecordMac = true;
510                                                 break;
511                                         }
512                                 }
513                         }
514
515                         if (badRecordMac)
516                         {
517                                 throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
518                         }
519
520                         // Update sequence number
521                         this.context.ReadSequenceNumber++;
522
523                         return new TlsStream(dcrFragment);
524                 }
525
526                 #endregion
527
528                 #region CipherSpecV2 processing
529
530                 private void ProcessCipherSpecV2Buffer(SecurityProtocolType protocol, byte[] buffer)
531                 {
532                         TlsStream codes = new TlsStream(buffer);
533
534                         string prefix = (protocol == SecurityProtocolType.Ssl3) ? "SSL_" : "TLS_";
535
536                         while (codes.Position < codes.Length)
537                         {
538                                 byte check = codes.ReadByte();
539
540                                 if (check == 0)
541                                 {
542                                         // SSL/TLS cipher spec
543                                         short code = codes.ReadInt16(); 
544                                         int index = this.Context.SupportedCiphers.IndexOf(code);
545                                         if (index != -1)
546                                         {
547                                                 this.Context.Cipher     = this.Context.SupportedCiphers[index];
548                                                 break;
549                                         }
550                                 }
551                                 else
552                                 {
553                                         byte[] tmp = new byte[2];
554                                         codes.Read(tmp, 0, tmp.Length);
555
556                                         int tmpCode = ((check & 0xff) << 16) | ((tmp[0] & 0xff) << 8) | (tmp[1] & 0xff);
557                                         CipherSuite cipher = this.MapV2CipherCode(prefix, tmpCode);
558
559                                         if (cipher != null)
560                                         {
561                                                 this.Context.Cipher = cipher;
562                                                 break;
563                                         }
564                                 }
565                         }
566
567                         if (this.Context.Cipher == null)
568                         {
569                                 throw new TlsException(AlertDescription.InsuficientSecurity, "Insuficient Security");
570                         }
571                 }
572
573                 private CipherSuite MapV2CipherCode(string prefix, int code)
574                 {
575                         try
576                         {
577                                 switch (code)
578                                 {
579                                         case 65664:
580                                                 // TLS_RC4_128_WITH_MD5
581                                                 return this.Context.SupportedCiphers[prefix + "RSA_WITH_RC4_128_MD5"];
582                                         
583                                         case 131200:
584                                                 // TLS_RC4_128_EXPORT40_WITH_MD5
585                                                 return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC4_40_MD5"];
586                                         
587                                         case 196736:
588                                                 // TLS_RC2_CBC_128_CBC_WITH_MD5
589                                                 return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC2_CBC_40_MD5"];
590                                         
591                                         case 262272:
592                                                 // TLS_RC2_CBC_128_CBC_EXPORT40_WITH_MD5
593                                                 return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC2_CBC_40_MD5"];
594                                         
595                                         case 327808:
596                                                 // TLS_IDEA_128_CBC_WITH_MD5
597                                                 return null;
598                                         
599                                         case 393280:
600                                                 // TLS_DES_64_CBC_WITH_MD5
601                                                 return null;
602
603                                         case 458944:
604                                                 // TLS_DES_192_EDE3_CBC_WITH_MD5
605                                                 return null;
606
607                                         default:
608                                                 return null;
609                                 }
610                         }
611                         catch
612                         {
613                                 return null;
614                         }
615                 }
616
617                 #endregion
618         }
619 }