copying the latest Sys.Web.Services from trunk.
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 // Transport Security Layer (TLS)
2 // Copyright (c) 2003-2004 Carlos Guzman Alvarez
3
4 //
5 // Permission is hereby granted, free of charge, to any person obtaining
6 // a copy of this software and associated documentation files (the
7 // "Software"), to deal in the Software without restriction, including
8 // without limitation the rights to use, copy, modify, merge, publish,
9 // distribute, sublicense, and/or sell copies of the Software, and to
10 // permit persons to whom the Software is furnished to do so, subject to
11 // the following conditions:
12 // 
13 // The above copyright notice and this permission notice shall be
14 // included in all copies or substantial portions of the Software.
15 // 
16 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23 //
24
25 using System;
26 using System.Collections;
27 using System.IO;
28 using System.Security.Cryptography;
29 using System.Security.Cryptography.X509Certificates;
30
31 using Mono.Security.Protocol.Tls.Handshake;
32
33 namespace Mono.Security.Protocol.Tls
34 {
35         internal abstract class RecordProtocol
36         {
37                 #region Fields
38
/// <summary>Stream that encoded records are written to.</summary>
protected Stream innerStream;

/// <summary>Session context consulted and updated while records are processed.</summary>
protected Context context;
41
42                 #endregion
43
44                 #region Properties
45
/// <summary>Gets or sets the session context used by this record protocol.</summary>
public Context Context
{
    get
    {
        return context;
    }
    set
    {
        context = value;
    }
}
51
52                 #endregion
53
54                 #region Constructors
55
/// <summary>
/// Creates the record protocol over <paramref name="innerStream"/> and
/// registers this instance on <paramref name="context"/>.
/// </summary>
/// <param name="innerStream">Stream used for writing encoded records.</param>
/// <param name="context">Session context; its RecordProtocol property is set to this instance.</param>
public RecordProtocol(Stream innerStream, Context context)
{
    this.innerStream = innerStream;
    this.context = context;

    // Back-reference so the context can reach its record protocol.
    context.RecordProtocol = this;
}
62
63                 #endregion
64
65                 #region Abstract Methods
66
/// <summary>Sends the handshake message identified by <paramref name="type"/>.</summary>
public abstract void SendRecord(HandshakeType type);

/// <summary>
/// Processes the next handshake message contained in <paramref name="handMsg"/>.
/// Called repeatedly by ReceiveRecord until the stream reports EOF.
/// </summary>
protected abstract void ProcessHandshakeMessage(TlsStream handMsg);

/// <summary>Invoked by ReceiveRecord when a ChangeCipherSpec record is received.</summary>
protected abstract void ProcessChangeCipherSpec();
70                                 
71                 #endregion
72
73                 #region Receive Record Methods
74
/// <summary>
/// Reads a single record from <paramref name="record"/>, decrypts it when
/// the current cipher state is active, and dispatches it by content type.
/// </summary>
/// <param name="record">Stream positioned on the content-type byte of the next record.</param>
/// <returns>
/// The record payload (decrypted when applicable). <c>null</c> when no byte
/// could be read, when the record is not completely available yet, or after
/// an alert record has been processed.
/// </returns>
/// <exception cref="TlsException">
/// The connection has already ended, an alert payload is shorter than two
/// bytes, a fatal alert was received, or the content type is not handled.
/// </exception>
public byte[] ReceiveRecord(Stream record)
{
    if (this.context.ConnectionEnd)
    {
        throw new TlsException(
            AlertDescription.InternalError,
            "The session is finished and it's no longer valid.");
    }

    // Try to read the record content type; -1 means nothing is available.
    int type = record.ReadByte ();
    if (type == -1)
    {
        return null;
    }

    // Reset the last handshake message marker before processing the record.
    // NOTE(review): the previous comment here said "None", but the value
    // assigned is ClientHello - confirm which one is intended.
    this.context.LastHandshakeMsg = HandshakeType.ClientHello;

    ContentType contentType = (ContentType)type;
    byte[] buffer = this.ReadRecordBuffer(type, record);
    if (buffer == null)
    {
        // Record incomplete (at the moment)
        return null;
    }

    // A two-byte alert is processed as-is (level + description) and
    // ChangeCipherSpec records are never decrypted; everything else is
    // decrypted once the cipher state is active.
    bool twoByteAlert = (contentType == ContentType.Alert && buffer.Length == 2);
    if (!twoByteAlert && this.context.IsActual && contentType != ContentType.ChangeCipherSpec)
    {
        buffer = this.decryptRecordFragment(contentType, buffer);
        DebugHelper.WriteLine("Decrypted record data", buffer);
    }

    // Process record
    switch (contentType)
    {
        case ContentType.Alert:
            // The two indexes below need a level byte and a description
            // byte; reject a shorter payload with a protocol error instead
            // of letting the indexer throw IndexOutOfRangeException.
            if (buffer.Length < 2)
            {
                throw new TlsException(AlertDescription.DecodeError);
            }
            this.ProcessAlert((AlertLevel)buffer [0], (AlertDescription)buffer [1]);
            if (record.CanSeek)
            {
                // don't reprocess that memory block
                record.SetLength (0);
            }
            buffer = null;
            break;

        case ContentType.ChangeCipherSpec:
            this.ProcessChangeCipherSpec();
            break;

        case ContentType.ApplicationData:
            break;

        case ContentType.Handshake:
            TlsStream message = new TlsStream (buffer);
            while (!message.EOF)
            {
                this.ProcessHandshakeMessage(message);
            }

            // Update handshakes of current messages
            this.context.HandshakeMessages.Write(buffer);
            break;

        // FIXME / MCS bug - http://bugzilla.ximian.com/show_bug.cgi?id=67711
        // The 0x80 content type cannot be written as its own case label,
        // so it is handled inside the default branch.
        default:
            if (contentType != (ContentType)0x80)
            {
                throw new TlsException(
                    AlertDescription.UnexpectedMessage,
                    "Unknown record received from server.");
            }
            this.context.HandshakeMessages.Write (buffer);
            break;
    }

    return buffer;
}
164 \r
/// <summary>
/// Reads the body of a record whose content-type byte has already been consumed.
/// </summary>
/// <param name="contentType">Content-type byte read from the stream.</param>
/// <param name="record">Stream positioned right after the content-type byte.</param>
/// <returns>The record body, or <c>null</c> when the reader reports it is not complete yet.</returns>
/// <exception cref="TlsException">The content type is not a defined <c>ContentType</c> value.</exception>
private byte[] ReadRecordBuffer (int contentType, Stream record)
{
    // 0x80 is routed to the version 2 client hello reader.
    if (contentType == 0x80)
    {
        return this.ReadClientHelloV2(record);
    }

    bool knownType = Enum.IsDefined(typeof(ContentType), (ContentType)contentType);
    if (!knownType)
    {
        throw new TlsException(AlertDescription.DecodeError);
    }

    return this.ReadStandardRecordBuffer(record);
}
180 \r
/// <summary>
/// Reads and processes a version 2 client hello whose first byte (0x80)
/// has already been consumed.
/// </summary>
/// <remarks>
/// Layout of the message as parsed here: byte 0 is the message type (must
/// be 1), bytes 1-2 the protocol version, bytes 3-4 the cipher spec
/// length, bytes 5-6 the session id length, bytes 7-8 the challenge
/// length, followed by the cipher specs, the session id and the challenge.
/// </remarks>
/// <param name="record">Stream positioned on the length byte of the message.</param>
/// <returns>The raw message bytes, or <c>null</c> when the record is not complete yet.</returns>
/// <exception cref="TlsException">The message is truncated or malformed.</exception>
private byte[] ReadClientHelloV2 (Stream record)
{
    int msgLength = record.ReadByte ();
    if (msgLength == -1)
    {
        // The length byte is not available; allocating a buffer of
        // size -1 would throw, so report "incomplete" instead.
        return null;
    }

    // Process further only if the whole record is available. The header
    // is two bytes (the 0x80 marker plus the length byte), so the full
    // record occupies msgLength + 2 bytes of the stream.
    if (record.CanSeek && (msgLength + 2 > record.Length))
    {
        return null;
    }

    // Stream.Read may return fewer bytes than requested; keep reading
    // until the message is complete or the stream ends.
    byte[] message = new byte[msgLength];
    int received = 0;
    while (received < msgLength)
    {
        int n = record.Read (message, received, msgLength - received);
        if (n <= 0)
        {
            throw new TlsException(AlertDescription.DecodeError);
        }
        received += n;
    }

    // The fixed part of the message is 9 bytes long; type must be 1.
    if (msgLength < 9 || message [0] != 1)
    {
        throw new TlsException(AlertDescription.DecodeError);
    }

    int protocol = (message [1] << 8 | message [2]);
    int cipherSpecLength = (message [3] << 8 | message [4]);
    int sessionIdLength = (message [5] << 8 | message [6]);
    int challengeLength = (message [7] << 8 | message [8]);
    int length = (challengeLength > 32) ? 32 : challengeLength;

    // Validate every length before it is used as a copy size: the values
    // come from the peer and must fit inside the message that was read.
    if (challengeLength < 16 || cipherSpecLength == 0 || (cipherSpecLength % 3) != 0)
    {
        throw new TlsException(AlertDescription.DecodeError);
    }
    if (9 + cipherSpecLength + sessionIdLength + challengeLength > msgLength)
    {
        throw new TlsException(AlertDescription.DecodeError);
    }

    // Read CipherSpecs
    byte[] cipherSpecV2 = new byte[cipherSpecLength];
    Buffer.BlockCopy (message, 9, cipherSpecV2, 0, cipherSpecLength);

    // Read session ID
    byte[] sessionId = new byte[sessionIdLength];
    Buffer.BlockCopy (message, 9 + cipherSpecLength, sessionId, 0, sessionIdLength);

    // Read challenge
    byte[] challenge = new byte[challengeLength];
    Buffer.BlockCopy (message, 9 + cipherSpecLength + sessionIdLength, challenge, 0, challengeLength);

    // Update the Session ID
    if (sessionId.Length > 0)
    {
        this.context.SessionId = sessionId;
    }

    // Update the protocol version
    this.Context.ChangeProtocol((short)protocol);

    // Select the Cipher suite
    this.ProcessCipherSpecV2Buffer(this.Context.SecurityProtocol, cipherSpecV2);

    // Update the Client Random
    this.context.ClientRandom = new byte [32]; // Always 32
    // 1. if challenge is bigger than 32 bytes only use the last 32 bytes
    // 2. right justify (0) challenge in ClientRandom if less than 32
    Buffer.BlockCopy (challenge, challenge.Length - length, this.context.ClientRandom, 32 - length, length);

    // Record that a client hello was received and the protocol is negotiated
    this.context.LastHandshakeMsg = HandshakeType.ClientHello;
    this.context.ProtocolNegotiated = true;

    return message;
}
245 \r
/// <summary>
/// Reads the protocol version, the length and the payload of a standard
/// record whose content-type byte has already been consumed.
/// </summary>
/// <param name="record">Stream positioned right after the content-type byte.</param>
/// <returns>The record payload, or <c>null</c> when the whole record is not available yet.</returns>
/// <exception cref="TlsException">
/// The length field is out of range, the stream ends before the payload is
/// complete, or the record version differs from the negotiated protocol.
/// </exception>
private byte[] ReadStandardRecordBuffer (Stream record)
{
    short protocol = this.ReadShort(record);
    short length = this.ReadShort(record);

    // Process further only if the whole record is available.
    // Note: the first 5 bytes (type, version, length) aren't part of the
    // length. The explicit "< 5" test keeps a partially read header from
    // being interpreted as a length.
    if (record.CanSeek && (record.Length < 5 || length + 5 > record.Length))
    {
        return null;
    }

    // A length field with the high bit set arrives here as a negative
    // value; reject it rather than attempting a negative-size allocation.
    if (length < 0)
    {
        throw new TlsException(
            AlertDescription.DecodeError, "Invalid record length received");
    }

    // Read Record data
    int received = 0;
    byte[] buffer = new byte[length];
    while (received != length)
    {
        int n = record.Read(buffer, received, buffer.Length - received);
        if (n <= 0)
        {
            // Read returned no data: the stream ended inside the record.
            // Without this check the loop would never terminate.
            throw new TlsException(
                AlertDescription.DecodeError, "Unexpected end of stream while reading record data");
        }
        received += n;
    }

    // Check that the message has a valid protocol version
    if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
    {
        throw new TlsException(
            AlertDescription.ProtocolVersion, "Invalid protocol version on message received");
    }

    DebugHelper.WriteLine("Record data", buffer);

    return buffer;
}
277
/// <summary>
/// Reads a 16-bit big-endian (network order) value from <paramref name="record"/>.
/// </summary>
/// <param name="record">Stream to read two bytes from.</param>
/// <returns>
/// The decoded value. Bytes that could not be read because the stream
/// ended are treated as zero.
/// </returns>
private short ReadShort(Stream record)
{
    byte[] b = new byte[2];

    // Stream.Read may deliver fewer bytes than requested; keep reading
    // until both bytes are in or the stream reports its end.
    int read = 0;
    while (read < b.Length)
    {
        int n = record.Read(b, read, b.Length - read);
        if (n <= 0)
        {
            break;
        }
        read += n;
    }

    // First byte is the most significant one, independent of host endianness.
    return unchecked((short)((b[0] << 8) | b[1]));
}
287
/// <summary>
/// Reacts to a received alert: fatal alerts are raised as exceptions, and
/// a close-notify at any other level marks the connection as ended.
/// </summary>
/// <param name="alertLevel">Level byte of the alert.</param>
/// <param name="alertDesc">Description byte of the alert.</param>
/// <exception cref="TlsException">The alert level is fatal.</exception>
private void ProcessAlert(AlertLevel alertLevel, AlertDescription alertDesc)
{
    if (alertLevel == AlertLevel.Fatal)
    {
        throw new TlsException(alertLevel, alertDesc);
    }

    // Warning and every other level: only close-notify has an effect.
    if (alertDesc == AlertDescription.CloseNotify)
    {
        this.context.ConnectionEnd = true;
    }
}
306
307                 #endregion
308
309                 #region Send Alert Methods
310
/// <summary>Builds an alert from <paramref name="description"/> and sends it.</summary>
/// <param name="description">Alert description to send.</param>
public void SendAlert(AlertDescription description)
{
    Alert alert = new Alert(description);
    this.SendAlert(alert);
}
315
/// <summary>Builds an alert from the given level and description and sends it.</summary>
/// <param name="level">Alert level to send.</param>
/// <param name="description">Alert description to send.</param>
public void SendAlert(AlertLevel level, AlertDescription description)
{
    Alert alert = new Alert(level, description);
    this.SendAlert(alert);
}
322
/// <summary>
/// Sends <paramref name="alert"/> as a two-byte alert record (level,
/// description) and marks the connection as ended for a close-notify.
/// </summary>
/// <param name="alert">Alert to send.</param>
public void SendAlert(Alert alert)
{
    DebugHelper.WriteLine(">>>> Write Alert ({0}|{1})", alert.Description, alert.Message);

    // Alert payload: level byte followed by description byte
    byte[] payload = new byte[2];
    payload[0] = (byte)alert.Level;
    payload[1] = (byte)alert.Description;

    this.SendRecord(ContentType.Alert, payload);

    // After a close-notify has been written no further records may be sent
    if (alert.IsCloseNotify)
    {
        this.context.ConnectionEnd = true;
    }
}
337
338                 #endregion
339
340                 #region Send Record Methods
341
/// <summary>
/// Sends the ChangeCipherSpec record unencrypted, switches the write side
/// to the new cipher state and then sends the Finished handshake message.
/// </summary>
public void SendChangeCipherSpec()
{
    DebugHelper.WriteLine(">>>> Write Change Cipher Spec");

    // The ChangeCipherSpec record itself is written as a plain message
    context.IsActual = false;

    byte[] changeCipherSpec = new byte[] {1};
    SendRecord(ContentType.ChangeCipherSpec, changeCipherSpec);

    // Sequence numbering restarts for the new state
    context.WriteSequenceNumber = 0;

    // From here on records are encrypted
    context.IsActual = true;

    // Finished is the first message sent under the new state
    SendRecord(HandshakeType.Finished);
}
361
/// <summary>
/// Encodes <paramref name="recordData"/> as one or more records of the
/// given content type and writes them to the inner stream.
/// </summary>
/// <param name="contentType">Content type of the record(s).</param>
/// <param name="recordData">Payload to send.</param>
/// <exception cref="TlsException">The connection has already ended.</exception>
public void SendRecord(ContentType contentType, byte[] recordData)
{
    bool sessionEnded = this.context.ConnectionEnd;
    if (sessionEnded)
    {
        throw new TlsException(
            AlertDescription.InternalError,
            "The session is finished and it's no longer valid.");
    }

    byte[] encoded = this.EncodeRecord(contentType, recordData);
    this.innerStream.Write(encoded, 0, encoded.Length);
}
375
/// <summary>Encodes the whole of <paramref name="recordData"/> into record form.</summary>
/// <param name="contentType">Content type of the record(s).</param>
/// <param name="recordData">Payload to encode.</param>
/// <returns>The encoded record bytes.</returns>
public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
{
    int wholeLength = recordData.Length;
    return this.EncodeRecord(contentType, recordData, 0, wholeLength);
}
384
/// <summary>
/// Encodes <paramref name="count"/> bytes of <paramref name="recordData"/>,
/// starting at <paramref name="offset"/>, into one or more records. The
/// data is split into fragments of at most <c>Context.MAX_FRAGMENT_SIZE</c>
/// bytes and each fragment is encrypted when the cipher state is active.
/// </summary>
/// <param name="contentType">Content type written in every record header.</param>
/// <param name="recordData">Buffer holding the payload.</param>
/// <param name="offset">Index of the first payload byte in <paramref name="recordData"/>.</param>
/// <param name="count">Number of payload bytes to encode.</param>
/// <returns>The concatenated encoded records.</returns>
/// <exception cref="TlsException">The connection has already ended.</exception>
public byte[] EncodeRecord(
    ContentType contentType,
    byte[] recordData,
    int offset,
    int count)
{
    if (this.context.ConnectionEnd)
    {
        throw new TlsException(
            AlertDescription.InternalError,
            "The session is finished and it's no longer valid.");
    }

    TlsStream record = new TlsStream();

    int position = offset;
    int end = offset + count;

    while (position < end)
    {
        short fragmentLength = 0;
        byte[] fragment;

        // Bytes still to encode are measured up to the end of the
        // requested range. Using "count - position" ignored the offset
        // and produced a wrong (even negative) size for offset != 0.
        int remaining = end - position;

        if (remaining > Context.MAX_FRAGMENT_SIZE)
        {
            fragmentLength = Context.MAX_FRAGMENT_SIZE;
        }
        else
        {
            fragmentLength = (short)remaining;
        }

        // Fill the fragment data
        fragment = new byte[fragmentLength];
        Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);

        if (this.context.IsActual)
        {
            // Encrypt fragment
            fragment = this.encryptRecordFragment(contentType, fragment);
        }

        // Write the record: type, protocol version, length, payload
        record.Write((byte)contentType);
        record.Write(this.context.Protocol);
        record.Write((short)fragment.Length);
        record.Write(fragment);

        DebugHelper.WriteLine("Record data", fragment);

        // Update buffer position
        position += fragmentLength;
    }

    return record.ToArray();
}
440                 
441                 #endregion
442
443                 #region Cryptography Methods
444
/// <summary>
/// Computes the MAC of <paramref name="fragment"/>, encrypts fragment and
/// MAC with the current cipher, and advances the write sequence number.
/// </summary>
/// <param name="contentType">Content type the MAC is computed for.</param>
/// <param name="fragment">Plain fragment data.</param>
/// <returns>The encrypted fragment.</returns>
private byte[] encryptRecordFragment(
    ContentType contentType,
    byte[] fragment)
{
    // The MAC routine depends on which side of the connection we are.
    byte[] mac = (this.Context is ClientContext)
        ? this.context.Cipher.ComputeClientRecordMAC(contentType, fragment)
        : this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);

    DebugHelper.WriteLine(">>>> Record MAC", mac);

    byte[] encrypted = this.context.Cipher.EncryptRecord(fragment, mac);

    // In CBC mode the trailing IvSize bytes of the output are handed to
    // the cipher as its next IV.
    if (this.context.Cipher.CipherMode == CipherMode.CBC)
    {
        byte[] nextIv = new byte[this.context.Cipher.IvSize];
        int tailStart = encrypted.Length - nextIv.Length;
        Buffer.BlockCopy(encrypted, tailStart, nextIv, 0, nextIv.Length);

        this.context.Cipher.UpdateClientCipherIV(nextIv);
    }

    this.context.WriteSequenceNumber++;

    return encrypted;
}
480
/// <summary>
/// Decrypts <paramref name="fragment"/>, verifies its MAC and advances
/// the read sequence number.
/// </summary>
/// <param name="contentType">Content type the MAC is computed for.</param>
/// <param name="fragment">Encrypted fragment as read from the record.</param>
/// <returns>The decrypted fragment data (without the MAC).</returns>
/// <exception cref="TlsException">The record MAC does not match.</exception>
private byte[] decryptRecordFragment(
    ContentType contentType,
    byte[] fragment)
{
    byte[] dcrFragment = null;
    byte[] dcrMAC = null;

    try
    {
        this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
    }
    catch
    {
        // On the server side the peer is told about the failure before
        // the original exception is propagated.
        if (this.context is ServerContext)
        {
            this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
        }

        throw;
    }

    // Generate record MAC: incoming data was MACed by the other side,
    // so a client verifies with the server routine and vice versa.
    byte[] mac = null;

    if (this.Context is ClientContext)
    {
        mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
    }
    else
    {
        mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
    }

    DebugHelper.WriteLine(">>>> Record MAC", mac);

    // Check record MAC. All bytes are always compared so the time taken
    // does not reveal the position of the first mismatching byte.
    bool badRecordMac = (mac.Length != dcrMAC.Length);
    if (!badRecordMac)
    {
        int difference = 0;
        for (int i = 0; i < mac.Length; i++)
        {
            difference |= mac[i] ^ dcrMAC[i];
        }
        badRecordMac = (difference != 0);
    }

    if (badRecordMac)
    {
        throw new TlsException(AlertDescription.BadRecordMAC, "Bad record MAC");
    }

    // Update sequence number
    this.context.ReadSequenceNumber++;

    return dcrFragment;
}
544
545                 #endregion
546
547                 #region CipherSpecV2 processing
548
/// <summary>
/// Walks the three-byte cipher specs in <paramref name="buffer"/> and
/// selects the first one that maps to a supported cipher suite.
/// </summary>
/// <param name="protocol">Protocol in use; selects the "SSL_" or "TLS_" suite name prefix.</param>
/// <param name="buffer">Cipher spec bytes taken from the version 2 client hello.</param>
/// <exception cref="TlsException">No cipher is set on the context after the scan.</exception>
private void ProcessCipherSpecV2Buffer(SecurityProtocolType protocol, byte[] buffer)
{
    TlsStream codes = new TlsStream(buffer);

    string prefix = "TLS_";
    if (protocol == SecurityProtocolType.Ssl3)
    {
        prefix = "SSL_";
    }

    while (codes.Position < codes.Length)
    {
        byte check = codes.ReadByte();

        if (check != 0)
        {
            // Non-zero first byte: a version 2 cipher code; combine the
            // three bytes and translate them to a suite, if possible.
            byte[] tail = new byte[2];
            codes.Read(tail, 0, tail.Length);

            int v2Code = ((check & 0xff) << 16) | ((tail[0] & 0xff) << 8) | (tail[1] & 0xff);
            CipherSuite mapped = this.MapV2CipherCode(prefix, v2Code);

            if (mapped == null)
            {
                continue;
            }

            this.Context.Cipher = mapped;
            break;
        }

        // Zero first byte: the remaining two bytes are an SSL/TLS suite code.
        short code = codes.ReadInt16();
        int index = this.Context.SupportedCiphers.IndexOf(code);
        if (index != -1)
        {
            this.Context.Cipher = this.Context.SupportedCiphers[index];
            break;
        }
    }

    if (this.Context.Cipher == null)
    {
        throw new TlsException(AlertDescription.InsuficientSecurity, "Insuficient Security");
    }
}
591
/// <summary>
/// Translates a three-byte version 2 cipher code into the matching
/// supported cipher suite.
/// </summary>
/// <param name="prefix">Suite name prefix ("SSL_" or "TLS_").</param>
/// <param name="code">Version 2 cipher code, first byte in bits 16-23.</param>
/// <returns>
/// The matching suite, or <c>null</c> when the code has no equivalent
/// suite or the lookup fails.
/// </returns>
private CipherSuite MapV2CipherCode(string prefix, int code)
{
    try
    {
        switch (code)
        {
            case 0x010080:
                // TLS_RC4_128_WITH_MD5
                return this.Context.SupportedCiphers[prefix + "RSA_WITH_RC4_128_MD5"];

            case 0x020080:
                // TLS_RC4_128_EXPORT40_WITH_MD5
                return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC4_40_MD5"];

            case 0x030080:
                // TLS_RC2_CBC_128_CBC_WITH_MD5
                // This is the non-export RC2 cipher. It used to be mapped
                // to the 40-bit EXPORT suite, which is a different (and
                // weaker) cipher the client did not offer with this code.
                return null;

            case 0x040080:
                // TLS_RC2_CBC_128_CBC_EXPORT40_WITH_MD5
                return this.Context.SupportedCiphers[prefix + "RSA_EXPORT_WITH_RC2_CBC_40_MD5"];

            case 0x050080:
                // TLS_IDEA_128_CBC_WITH_MD5
                return null;

            case 0x060040:
                // TLS_DES_64_CBC_WITH_MD5
                return null;

            case 0x0700C0:
                // TLS_DES_192_EDE3_CBC_WITH_MD5
                return null;

            default:
                return null;
        }
    }
    catch
    {
        // A failing lookup is treated the same as an unmapped code so
        // the caller can move on to the next cipher spec.
        return null;
    }
}
635
636                 #endregion
637         }
638 }