7d76afda0ba5900552d0406d966bf36b2557bdc8
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / TlsSocket.cs
1 /* Transport Security Layer (TLS)
2  * Copyright (c) 2003 Carlos Guzmán Álvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Net;
28 using System.Collections;
29 using System.Net.Sockets;
30 using System.Security.Cryptography;
31
32 using Mono.Security.Protocol.Tls;
33 using Mono.Security.Protocol.Tls.Alerts;
34 using Mono.Security.Protocol.Tls.Handshake;
35 using Mono.Security.Protocol.Tls.Handshake.Client;
36
namespace Mono.Security.Protocol.Tls
{
	/// <summary>
	/// A <see cref="Socket"/> that carries its traffic inside TLS records once
	/// the owning session has been marked secure.
	/// </summary>
	/// <remarks>
	/// While the session is not secure, Send/Receive pass straight through to the
	/// underlying socket. Once it is, outgoing data is fragmented, MAC'd and
	/// encrypted into ApplicationData records, and incoming records are decrypted,
	/// verified and their application data buffered in <see cref="InputBuffer"/>.
	/// </remarks>
	public sealed class TlsSocket : Socket
	{
		#region FIELDS

		// RFC 2246 section 6.2.3 (and SSL 3.0 section 5.2.3): the fragment of a
		// protected record must not exceed 2^14 + 2048 bytes.
		private const int MaxRecordLength = 16384 + 2048;

		private TlsSession		session;
		private BufferedStream	inputBuffer;

		#endregion

		#region PROPERTIES

		/// <summary>The TLS session whose state this socket reads and updates.</summary>
		internal TlsSession Session
		{
			get { return this.session; }
		}

		/// <summary>
		/// Decrypted application data received but not yet returned by Receive.
		/// </summary>
		internal BufferedStream InputBuffer
		{
			get { return inputBuffer; }
		}

		#endregion

		#region CONSTRUCTORS

		private TlsSocket(
			AddressFamily	addressFamily,
			SocketType		socketType,
			ProtocolType	protocolType
			) : base(addressFamily, socketType, protocolType)
		{
			this.inputBuffer = new BufferedStream(new MemoryStream());
		}

		/// <summary>
		/// Creates a socket bound to <paramref name="session"/>.
		/// </summary>
		public TlsSocket(
			TlsSession		session,
			AddressFamily	addressFamily,
			SocketType		socketType,
			ProtocolType	protocolType
			) : this(addressFamily, socketType, protocolType)
		{
			this.session = session;
		}

		#endregion

		#region REPLACED_METHODS

		/// <summary>
		/// Discards buffered input, closes the underlying socket and, unless the
		/// session is already Closing or Closed, closes the session as well.
		/// </summary>
		public new void Close()
		{
			this.resetBuffer();
			// NOTE(review): the socket is closed before session.Close() runs. If
			// TlsSession.Close() sends a close_notify alert through this socket,
			// that send happens on an already-closed socket; confirm the ordering.
			base.Close();
			if (this.session.State != TlsSessionState.Closing &&
				this.session.State != TlsSessionState.Closed)
			{
				this.session.Close();
			}
		}

		/// <summary>Receives into the whole of <paramref name="buffer"/>.</summary>
		public new int Receive(byte[] buffer)
		{
			return this.Receive(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None);
		}

		/// <summary>Receives into the whole of <paramref name="buffer"/>.</summary>
		public new int Receive(byte[] buffer, SocketFlags socketFlags)
		{
			return this.Receive(buffer, 0, buffer != null ? buffer.Length : 0, socketFlags);
		}

		/// <summary>Receives up to <paramref name="size"/> bytes at the start of <paramref name="buffer"/>.</summary>
		public new int Receive(byte[] buffer, int size, SocketFlags socketFlags)
		{
			return this.Receive(buffer, 0, size, socketFlags);
		}

		/// <summary>
		/// Receives up to <paramref name="size"/> bytes of application data.
		/// </summary>
		/// <returns>
		/// The number of bytes copied into <paramref name="buffer"/>. 0 once a
		/// close_notify has been received and all buffered data has been consumed.
		/// </returns>
		/// <remarks>
		/// On a secure session <paramref name="socketFlags"/> is not applied to the
		/// record reads. Blocks until at least one byte of application data is
		/// available or the connection has ended.
		/// </remarks>
		public new int Receive(byte[] buffer, int offset, int size, SocketFlags socketFlags)
		{
			if (!session.IsSecure)
			{
				return base.Receive(buffer, offset, size, socketFlags);
			}

			// Everything buffered has been consumed: start again from an empty buffer
			if (inputBuffer.Position == inputBuffer.Length)
			{
				this.resetBuffer();
			}

			// Read TLS records until enough plaintext is buffered for this request
			while ((inputBuffer.Length - inputBuffer.Position) < size)
			{
				// After a close_notify no further records may be read. Return what is
				// still buffered; 0 tells the caller the stream has ended.
				if (session.Context.ConnectionEnd)
				{
					break;
				}

				long	position	= inputBuffer.Position;
				byte[]	record		= this.receiveRecord();

				if (record.Length > 0)
				{
					// Append the new data after what is already buffered...
					inputBuffer.Seek(0, SeekOrigin.End);
					inputBuffer.Write(record, 0, record.Length);

					// ...then restore the read position
					inputBuffer.Seek(position, SeekOrigin.Begin);
				}

				// Once there is something to return, don't block for further records.
				// A record carrying no application data (alert, handshake) must not end
				// the call on its own, or the caller would see a spurious 0 (EOF).
				if (base.Available == 0 &&
					inputBuffer.Length > inputBuffer.Position)
				{
					break;
				}
			}

			return inputBuffer.Read(buffer, offset, size);
		}

		/// <summary>Sends the whole of <paramref name="buffer"/>.</summary>
		public new int Send(byte[] buffer)
		{
			return this.Send(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None);
		}

		/// <summary>Sends the whole of <paramref name="buffer"/>.</summary>
		public new int Send(byte[] buffer, SocketFlags socketFlags)
		{
			return this.Send(buffer, 0, buffer != null ? buffer.Length : 0, socketFlags);
		}

		/// <summary>Sends the first <paramref name="size"/> bytes of <paramref name="buffer"/>.</summary>
		public new int Send(byte[] buffer, int size, SocketFlags socketFlags)
		{
			return this.Send(buffer, 0, size, socketFlags);
		}

		/// <summary>
		/// Sends <paramref name="size"/> bytes of <paramref name="buffer"/> starting
		/// at <paramref name="offset"/>. On a secure session they are sent as
		/// ApplicationData records.
		/// </summary>
		/// <returns>
		/// The number of bytes from <paramref name="buffer"/> that were sent, as the
		/// Socket.Send contract requires. This is not the on-the-wire count, which
		/// also includes record headers, MAC and padding.
		/// </returns>
		public new int Send(byte[] buffer, int offset, int size, SocketFlags socketFlags)
		{
			if (!session.IsSecure)
			{
				return base.Send(buffer, offset, size, socketFlags);
			}

			// Send the buffer as TLS record(s)
			byte[] recordData = new byte[size];
			System.Array.Copy(buffer, offset, recordData, 0, size);

			this.sendRecord(TlsContentType.ApplicationData, recordData);

			return size;
		}

		#endregion

		#region TLS_RECORD_METHODS

		/// <summary>
		/// Reads one TLS record from the socket, decrypts and verifies it when the
		/// cipher is active, and processes alert, change_cipher_spec and handshake
		/// content.
		/// </summary>
		/// <returns>
		/// The record's application data, or an empty array for any other content
		/// type (those records are consumed here and are not part of the data stream).
		/// </returns>
		private byte[] receiveRecord()
		{
			if (session.Context.ConnectionEnd)
			{
				throw session.CreateException("The session is finished and it's no longer valid.");
			}

			// Record header: content type (1 byte), protocol version (2), length (2)
			TlsContentType	contentType	= (TlsContentType)this.ReadByte();
			TlsProtocol		protocol	= (TlsProtocol)this.ReadShort();
			// The length field is unsigned. Reading it as a signed short would turn
			// values >= 0x8000 into negative array sizes.
			int				length		= this.ReadShort() & 0xFFFF;

			if (length > MaxRecordLength)
			{
				throw session.CreateException("Record length exceeds the maximum allowed by the protocol.");
			}

			// Read Record data
			byte[] buffer = new byte[length];
			this.receiveExact(buffer);

			TlsStream message = new TlsStream(buffer);

			// Check that the message as a valid protocol version
			if (protocol != session.Context.Protocol)
			{
				throw session.CreateException("Invalid protocol version on message received from server");
			}

			// Decrypt message contents if needed.
			// A 2-byte alert can only be a plaintext alert. It is left undecrypted
			// because IsActual is shared by both directions: it becomes true once the
			// client has sent change_cipher_spec, before the server's records are
			// protected.
			// NOTE(review): this also lets an unauthenticated plaintext alert through
			// after the handshake; confirm whether separate read/write state is needed.
			if (contentType == TlsContentType.Alert &&
				length == 2)
			{
			}
			else
			{
				if (session.Context.IsActual &&
					contentType != TlsContentType.ChangeCipherSpec)
				{
					message = decryptRecordFragment(
						contentType,
						protocol,
						message.ToArray());
				}
			}

			// Only application data is handed back to the caller's data stream
			byte[] result = (contentType == TlsContentType.ApplicationData)
				? message.ToArray()
				: new byte[0];

			// Process record
			switch (contentType)
			{
				case TlsContentType.Alert:
					processAlert((TlsAlertLevel)message.ReadByte(),
						(TlsAlertDescription)message.ReadByte());
					break;

				case TlsContentType.ChangeCipherSpec:
					// Reset sequence numbers
					session.Context.ReadSequenceNumber = 0;
					break;

				case TlsContentType.ApplicationData:
					break;

				case TlsContentType.Handshake:
					while (!message.EOF)
					{
						processHandshakeMessage(message);
					}
					// Update handshakes of current messages
					this.session.Context.HandshakeHashes.Update(message.ToArray());
					break;

				default:
					throw session.CreateException("Unknown record received from server.");
			}

			return result;
		}

		#endregion

		#region TLS_CRYPTO_METHODS

		/// <summary>
		/// Computes the client record MAC over <paramref name="fragment"/>, encrypts
		/// fragment + MAC with the session cipher and increments the write sequence
		/// number. In CBC mode the last IvSize bytes of the ciphertext become the
		/// next client IV.
		/// </summary>
		private byte[] encryptRecordFragment(TlsContentType contentType, byte[] fragment)
		{
			// Calculate message MAC
			byte[] mac = encodeClientRecordMAC(contentType, fragment);

			// Encrypt the message
			byte[] ecr = session.Context.Cipher.EncryptRecord(fragment, mac);

			// Set new IV
			if (session.Context.Cipher.CipherMode == CipherMode.CBC)
			{
				byte[] iv = new byte[session.Context.Cipher.IvSize];
				System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
				session.Context.Cipher.UpdateClientCipherIV(iv);
			}

			// Update sequence number
			session.Context.WriteSequenceNumber++;

			return ecr;
		}

		/// <summary>
		/// Decrypts a protected record fragment and verifies its MAC, then increments
		/// the read sequence number.
		/// </summary>
		/// <returns>A stream over the decrypted content, without the MAC.</returns>
		/// <exception cref="TlsException">The MAC does not match.</exception>
		private TlsStream decryptRecordFragment(TlsContentType contentType,
			TlsProtocol protocol,
			byte[] fragment)
		{
			byte[]	dcrFragment	= null;
			byte[]	dcrMAC		= null;

			// Decrypt message
			session.Context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);

			// Set new IV: in CBC mode the last ciphertext block chains into the next record
			if (session.Context.Cipher.CipherMode == CipherMode.CBC)
			{
				byte[] iv = new byte[session.Context.Cipher.IvSize];
				System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
				session.Context.Cipher.UpdateServerCipherIV(iv);
			}

			// Check MAC code
			byte[] mac = this.encodeServerRecordMAC(contentType, dcrFragment);

			// MAC lengths are not secret, so a length mismatch may fail fast
			if (dcrMAC == null || mac.Length != dcrMAC.Length)
			{
				throw new TlsException("Invalid MAC received from server.");
			}

			// Compare every byte regardless of where the first difference is, so the
			// time taken does not reveal how much of a forged MAC was correct.
			int diff = 0;
			for (int i = 0; i < mac.Length; i++)
			{
				diff |= mac[i] ^ dcrMAC[i];
			}
			if (diff != 0)
			{
				throw new TlsException("Invalid MAC received from server.");
			}

			// Update sequence number
			session.Context.ReadSequenceNumber++;

			return new TlsStream(dcrFragment);
		}

		#endregion

		#region TLS_SEND_METHODS

		/// <summary>
		/// Sends <paramref name="alert"/> as an Alert record, then lets it update the
		/// session and resets its contents.
		/// </summary>
		/// <returns>Bytes written to the socket, including record overhead.</returns>
		internal int SendAlert(TlsAlert alert)
		{
			// Write record
			int bytesSent = this.sendRecord(TlsContentType.Alert, alert.ToArray());

			// Update session
			alert.UpdateSession();

			// Reset message contents
			alert.Reset();

			return bytesSent;
		}

		/// <summary>
		/// Builds the client handshake message of <paramref name="type"/>, sends it,
		/// then lets it update the session and resets it.
		/// </summary>
		/// <returns>Bytes written to the socket, including record overhead.</returns>
		private int sendRecord(TlsHandshakeType type)
		{
			TlsHandshakeMessage msg = createClientHandshakeMessage(type);

			// Write record
			int bytesSent = this.sendRecord(msg.ContentType, msg.EncodeMessage());

			// Update session
			msg.UpdateSession();

			// Reset message contents
			msg.Reset();

			return bytesSent;
		}

		/// <summary>
		/// Sends change_cipher_spec, makes the pending cipher state current (write
		/// sequence number reset to 0, IsActual set), then sends the client Finished
		/// message under the new state.
		/// </summary>
		/// <returns>Total bytes written to the socket.</returns>
		private int sendChangeCipherSpec()
		{
			// Send Change Cipher Spec message
			int bytesSent = this.sendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});

			// Reset sequence numbers
			session.Context.WriteSequenceNumber = 0;

			// Make the pending state to be the current state
			session.Context.IsActual = true;

			// Send Finished message
			bytesSent += this.sendRecord(TlsHandshakeType.Finished);

			return bytesSent;
		}

		/// <summary>
		/// Splits <paramref name="recordData"/> into fragments of at most
		/// session.MaxFragmentSize bytes, encrypts each one when the cipher is active,
		/// and writes each as a TLS 1.0 record.
		/// </summary>
		/// <returns>Total bytes written to the socket, including record headers.</returns>
		private int sendRecord(TlsContentType contentType, byte[] recordData)
		{
			if (session.Context.ConnectionEnd)
			{
				throw session.CreateException("The session is finished and it's no longer valid.");
			}

			int			bytesSent = 0;
			byte[][]	fragments = fragmentData(recordData);
			for (int i = 0; i < fragments.Length; i++)
			{
				byte[] fragment = fragments[i];

				if (session.Context.IsActual)
				{
					// Encrypt fragment
					fragment = encryptRecordFragment(contentType, fragment);
				}

				// Write tls message
				TlsStream record = new TlsStream();
				record.Write((byte)contentType);
				record.Write((short)TlsProtocol.Tls1);
				record.Write((short)fragment.Length);
				record.Write(fragment);

				// Write record. Socket.Send may accept fewer bytes than offered, so
				// keep going until the whole record is on the wire.
				byte[]	data	= record.ToArray();
				int		written	= 0;
				while (written < data.Length)
				{
					int n = base.Send(data, written, data.Length - written, SocketFlags.None);
					if (n <= 0)
					{
						throw session.CreateException("Unable to write the TLS record to the socket.");
					}
					written += n;
				}
				bytesSent += written;

				// Reset record data
				record.Reset();
			}

			return bytesSent;
		}

		/// <summary>
		/// Splits <paramref name="messageData"/> into consecutive chunks of at most
		/// session.MaxFragmentSize bytes. Empty input gives an empty array.
		/// </summary>
		private byte[][] fragmentData(byte[] messageData)
		{
			ArrayList d = new ArrayList();

			int position = 0;

			while (position < messageData.Length)
			{
				short	fragmentLength = 0;
				byte[]	fragmentData;
				if ((messageData.Length - position) > session.MaxFragmentSize)
				{
					fragmentLength = session.MaxFragmentSize;
				}
				else
				{
					fragmentLength = (short)(messageData.Length - position);
				}
				fragmentData = new byte[fragmentLength];

				System.Array.Copy(messageData, position, fragmentData, 0, fragmentLength);

				d.Add(fragmentData);

				position += fragmentLength;
			}

			byte[][] result = new byte[d.Count][];
			for (int i = 0; i < d.Count; i++)
			{
				result[i] = (byte[])d[i];
			}

			return result;
		}

		#endregion

		#region MESSAGE_PROCESSING

		/// <summary>
		/// Reads one handshake message (type byte, 24-bit length, body) from
		/// <paramref name="handMsg"/>, builds the matching server message and lets it
		/// update the session.
		/// </summary>
		private void processHandshakeMessage(TlsStream handMsg)
		{
			TlsHandshakeType	handshakeType	= (TlsHandshakeType)handMsg.ReadByte();
			TlsHandshakeMessage	message			= null;

			// Read message length
			int length = handMsg.ReadInt24();

			// Read message data. A short read means the declared length runs past the
			// end of the record; parsing zero-filled data would be wrong.
			byte[] data = new byte[length];
			if (handMsg.Read(data, 0, length) != length)
			{
				throw session.CreateException("Truncated handshake message received from server.");
			}

			// Create and process the server message
			message = createServerHandshakeMessage(handshakeType, data);

			// Update session (HelloRequest produces no message)
			if (message != null)
			{
				message.UpdateSession();
			}
		}

		/// <summary>
		/// Handles a received alert. Fatal alerts are thrown as exceptions.
		/// close_notify marks the connection as ended. Any other alert is raised to
		/// the session as a warning.
		/// </summary>
		private void processAlert(TlsAlertLevel alertLevel,
			TlsAlertDescription alertDesc)
		{
			switch (alertLevel)
			{
				case TlsAlertLevel.Fatal:
					throw session.CreateException(alertLevel, alertDesc);

				case TlsAlertLevel.Warning:
				default:
					switch (alertDesc)
					{
						case TlsAlertDescription.CloseNotify:
							session.Context.ConnectionEnd = true;
							break;

						default:
							session.RaiseWarningAlert(alertLevel, alertDesc);
							break;
					}
					break;
			}
		}

		#endregion

		#region MISC_METHODS

		/// <summary>Empties the application-data input buffer.</summary>
		private void resetBuffer()
		{
			this.inputBuffer.SetLength(0);
			this.inputBuffer.Position = 0;
		}

		/// <summary>
		/// Computes ServerHMAC over read sequence number, content type, protocol
		/// version (always written as Tls1), fragment length and fragment.
		/// </summary>
		private byte[] encodeServerRecordMAC(TlsContentType contentType, byte[] fragment)
		{
			TlsStream	data	= new TlsStream();
			byte[]		result	= null;

			data.Write(session.Context.ReadSequenceNumber);
			data.Write((byte)contentType);
			data.Write((short)TlsProtocol.Tls1);
			data.Write((short)fragment.Length);
			data.Write(fragment);

			result = session.Context.Cipher.ServerHMAC.ComputeHash(data.ToArray());

			data.Reset();

			return result;
		}

		/// <summary>
		/// Computes ClientHMAC over write sequence number, content type, protocol
		/// version (always written as Tls1), fragment length and fragment.
		/// </summary>
		private byte[] encodeClientRecordMAC(TlsContentType contentType, byte[] fragment)
		{
			TlsStream	data	= new TlsStream();
			byte[]		result	= null;

			data.Write(session.Context.WriteSequenceNumber);
			data.Write((byte)contentType);
			data.Write((short)TlsProtocol.Tls1);
			data.Write((short)fragment.Length);
			data.Write(fragment);

			result = session.Context.Cipher.ClientHMAC.ComputeHash(data.ToArray());

			data.Reset();

			return result;
		}

		/// <summary>
		/// Fills <paramref name="buffer"/> completely from the underlying socket,
		/// looping over partial receives.
		/// </summary>
		/// <remarks>
		/// Throws (via session.CreateException) if the peer shuts the connection down
		/// (Receive returns 0) before the buffer is full.
		/// </remarks>
		private void receiveExact(byte[] buffer)
		{
			int received = 0;
			while (received < buffer.Length)
			{
				int n = base.Receive(buffer, received, buffer.Length - received, SocketFlags.None);
				if (n == 0)
				{
					throw session.CreateException("The connection was closed before a complete TLS record was received.");
				}
				received += n;
			}
		}

		/// <summary>Reads exactly one byte from the socket.</summary>
		private byte ReadByte()
		{
			byte[] b = new byte[1];
			this.receiveExact(b);

			return b[0];
		}

		/// <summary>Reads exactly two bytes from the socket as a big-endian short.</summary>
		private short ReadShort()
		{
			byte[] b = new byte[2];
			this.receiveExact(b);

			short val = BitConverter.ToInt16(b, 0);

			// Multi-byte record fields are in network (big-endian) byte order
			return System.Net.IPAddress.NetworkToHostOrder(val);
		}

		#endregion

		#region HANDSHAKE_METHODS

		/*
			Client                                               Server

			ClientHello                  -------->
			                                                ServerHello
			                                               Certificate*
			                                         ServerKeyExchange*
			                                        CertificateRequest*
			                             <--------      ServerHelloDone
			Certificate*
			ClientKeyExchange
			CertificateVerify*
			[ChangeCipherSpec]
			Finished                     -------->
			                                         [ChangeCipherSpec]
			                             <--------             Finished
			Application Data             <------->     Application Data

					Fig. 1 - Message flow for a full handshake
		*/

		/// <summary>
		/// Runs the client side of a full handshake (Fig. 1 above). It clears
		/// session.IsSecure at the start, sends and receives the handshake records,
		/// clears key material, and sets IsSecure once the server's Finished has been
		/// processed.
		/// </summary>
		internal void DoHandshake()
		{
			// Reset isSecure field
			this.session.IsSecure = false;

			// Send client hello
			this.sendRecord(TlsHandshakeType.ClientHello);

			// Read server response
			while (!session.HelloDone)
			{
				// Read next record
				this.receiveRecord();
			}

			// Send client certificate if requested
			if (session.Context.ServerSettings.CertificateRequest)
			{
				this.sendRecord(TlsHandshakeType.Certificate);
			}

			// Send Client Key Exchange
			this.sendRecord(TlsHandshakeType.ClientKeyExchange);

			// Now initialize session cipher with the generated keys
			this.session.Context.Cipher.InitializeCipher();

			// Send certificate verify if requested
			if (session.Context.ServerSettings.CertificateRequest)
			{
				this.sendRecord(TlsHandshakeType.CertificateVerify);
			}

			// Send Cipher Spec protocol
			this.sendChangeCipherSpec();

			// Read Cipher Spec protocol
			this.receiveRecord();

			// Read server finished
			if (!session.HandshakeFinished)
			{
				this.receiveRecord();
			}

			// Clear Key Info
			this.session.Context.ClearKeyInfo();

			// Set isSecure
			this.session.IsSecure = true;
		}

		/// <summary>
		/// Creates the client handshake message for <paramref name="type"/>.
		/// </summary>
		/// <exception cref="InvalidOperationException">
		/// <paramref name="type"/> is not a message the client sends.
		/// </exception>
		private TlsHandshakeMessage createClientHandshakeMessage(TlsHandshakeType type)
		{
			switch (type)
			{
				case TlsHandshakeType.ClientHello:
					return new TlsClientHello(session);

				case TlsHandshakeType.Certificate:
					return new TlsClientCertificate(session);

				case TlsHandshakeType.ClientKeyExchange:
					return new TlsClientKeyExchange(session);

				case TlsHandshakeType.CertificateVerify:
					return new TlsClientCertificateVerify(session);

				case TlsHandshakeType.Finished:
					return new TlsClientFinished(session);

				default:
					throw new InvalidOperationException("Unknown client handshake message type: " + type.ToString() );
			}
		}

		/// <summary>
		/// Creates the server handshake message for <paramref name="type"/> from its
		/// body. A HelloRequest is answered immediately with a new ClientHello and
		/// yields null.
		/// </summary>
		private TlsHandshakeMessage createServerHandshakeMessage(TlsHandshakeType type, byte[] buffer)
		{
			switch (type)
			{
				case TlsHandshakeType.HelloRequest:
					this.sendRecord(TlsHandshakeType.ClientHello);
					return null;

				case TlsHandshakeType.ServerHello:
					return new TlsServerHello(session, buffer);

				case TlsHandshakeType.Certificate:
					return new TlsServerCertificate(session, buffer);

				case TlsHandshakeType.ServerKeyExchange:
					return new TlsServerKeyExchange(session, buffer);

				case TlsHandshakeType.CertificateRequest:
					return new TlsServerCertificateRequest(session, buffer);

				case TlsHandshakeType.ServerHelloDone:
					return new TlsServerHelloDone(session, buffer);

				case TlsHandshakeType.Finished:
					return new TlsServerFinished(session, buffer);

				default:
					throw session.CreateException("Unknown server handshake message received ({0})", type.ToString());
			}
		}

		#endregion
	}
}