Add to right place
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / TlsSocket.cs
1 /* Transport Layer Security (TLS)
2  * Copyright (c) 2003 Carlos Guzmán Álvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Net;
28 using System.Collections;
29 using System.Net.Sockets;
30 using System.Security.Cryptography;
31
32 using Mono.Security.Protocol.Tls;
33 using Mono.Security.Protocol.Tls.Alerts;
34 using Mono.Security.Protocol.Tls.Handshake;
35 using Mono.Security.Protocol.Tls.Handshake.Client;
36
namespace Mono.Security.Protocol.Tls
{
	/// <summary>
	/// Client-side TLS socket. While the owning <see cref="TlsSession"/> is not
	/// secure it behaves like a plain <see cref="Socket"/>; once the session is
	/// secure, data passed to <c>Send</c> is wrapped in TLS records and data
	/// returned by <c>Receive</c> is unwrapped from TLS records.
	/// </summary>
	public sealed class TlsSocket : Socket
	{
		#region FIELDS

		private TlsSession		session;

		// Decrypted application data that has been received but not yet
		// returned to the caller of Receive.
		private BufferedStream	inputBuffer;

		#endregion

		#region PROPERTIES

		internal TlsSession Session
		{
			get { return this.session; }
		}

		internal BufferedStream InputBuffer
		{
			get { return this.inputBuffer; }
		}

		#endregion

		#region CONSTRUCTORS

		private TlsSocket(
			AddressFamily	addressFamily,
			SocketType		socketType,
			ProtocolType	protocolType
			) : base(addressFamily, socketType, protocolType)
		{
			this.inputBuffer = new BufferedStream(new MemoryStream());
		}

		/// <summary>
		/// Creates a socket bound to the given TLS session.
		/// </summary>
		/// <param name="session">Session whose state and cipher drive record processing.</param>
		/// <param name="addressFamily">Address family passed to the base socket.</param>
		/// <param name="socketType">Socket type passed to the base socket.</param>
		/// <param name="protocolType">Protocol type passed to the base socket.</param>
		public TlsSocket(
			TlsSession		session,
			AddressFamily	addressFamily,
			SocketType		socketType,
			ProtocolType	protocolType
			) : this(addressFamily, socketType, protocolType)
		{
			this.session = session;
		}

		#endregion

		#region REPLACED_METHODS

		/// <summary>
		/// Discards buffered input and closes the underlying socket. Unless the
		/// session is already closing or closed, it then closes the session as well.
		/// </summary>
		// NOTE(review): base.Close() runs before session.Close(). If
		// TlsSession.Close sends a close_notify alert through this socket, that
		// alert cannot be delivered. Confirm against TlsSession before reordering.
		public new void Close()
		{
			this.resetBuffer();
			base.Close();
			if (this.session.State != TlsSessionState.Closing &&
				this.session.State != TlsSessionState.Closed)
			{
				this.session.Close();
			}
		}

		public new int Receive(byte[] buffer)
		{
			return this.Receive(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None);
		}

		public new int Receive(byte[] buffer, SocketFlags socketFlags)
		{
			return this.Receive(buffer, 0, buffer != null ? buffer.Length : 0, socketFlags);
		}

		public new int Receive(byte[] buffer, int size, SocketFlags socketFlags)
		{
			return this.Receive(buffer, 0, size, socketFlags);
		}

		/// <summary>
		/// Receives up to <paramref name="size"/> bytes into <paramref name="buffer"/>
		/// starting at <paramref name="offset"/>.
		/// </summary>
		/// <remarks>
		/// While the session is not secure this delegates to the base socket.
		/// Once secure, TLS records are read until one of three things happens:
		/// at least <paramref name="size"/> decrypted application bytes are
		/// buffered, the base socket has no more bytes immediately available, or
		/// the peer has sent close_notify. Buffered data is then copied out. After
		/// close_notify, a return value of 0 means end of stream.
		/// <paramref name="socketFlags"/> is ignored in secure mode.
		/// </remarks>
		/// <returns>The number of application bytes copied into <paramref name="buffer"/>.</returns>
		public new int Receive(byte[] buffer, int offset, int size, SocketFlags socketFlags)
		{
			if (!this.session.IsSecure)
			{
				return base.Receive(buffer, offset, size, socketFlags);
			}

			// Everything buffered so far has been consumed, so start from an empty buffer.
			if (this.inputBuffer.Position == this.inputBuffer.Length)
			{
				this.resetBuffer();
			}

			while ((this.inputBuffer.Length - this.inputBuffer.Position) < size)
			{
				// After close_notify no further records may be read (receiveRecord
				// would throw). Hand back whatever is buffered, possibly nothing.
				if (this.session.Context.ConnectionEnd)
				{
					break;
				}

				byte[] record = this.receiveRecord();
				if (record.Length > 0)
				{
					this.appendToInputBuffer(record);
				}

				if (base.Available == 0)
				{
					break;
				}
			}

			return this.inputBuffer.Read(buffer, offset, size);
		}

		public new int Send(byte[] buffer)
		{
			return this.Send(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None);
		}

		public new int Send(byte[] buffer, SocketFlags socketFlags)
		{
			return this.Send(buffer, 0, buffer != null ? buffer.Length : 0, socketFlags);
		}

		public new int Send(byte[] buffer, int size, SocketFlags socketFlags)
		{
			return this.Send(buffer, 0, size, socketFlags);
		}

		/// <summary>
		/// Sends <paramref name="size"/> bytes of <paramref name="buffer"/>
		/// starting at <paramref name="offset"/>.
		/// </summary>
		/// <remarks>
		/// Once the session is secure, the data is sent as one or more
		/// ApplicationData records, and <paramref name="socketFlags"/> is ignored.
		/// </remarks>
		/// <returns>
		/// The number of caller bytes sent. In secure mode this is
		/// <paramref name="size"/>, not the larger number of framed/encrypted
		/// bytes written, so callers can advance their offset by the result as
		/// with a plain socket.
		/// </returns>
		public new int Send(byte[] buffer, int offset, int size, SocketFlags socketFlags)
		{
			if (!this.session.IsSecure)
			{
				return base.Send(buffer, offset, size, socketFlags);
			}

			byte[] recordData = new byte[size];
			System.Array.Copy(buffer, offset, recordData, 0, size);

			this.sendRecord(TlsContentType.ApplicationData, recordData);

			return size;
		}

		#endregion

		#region TLS_RECORD_METHODS

		/// <summary>
		/// Reads one complete TLS record from the socket, decrypting it when a
		/// cipher is active, and processes it according to its content type.
		/// </summary>
		/// <returns>
		/// The record payload for ApplicationData records. For alert,
		/// change-cipher-spec and handshake records, which are consumed here, an
		/// empty array is returned.
		/// </returns>
		private byte[] receiveRecord()
		{
			if (this.session.Context.ConnectionEnd)
			{
				throw this.session.CreateException("The session is finished and it's no longer valid.");
			}

			// Record header: content type (1 byte), protocol version (2 bytes) and
			// fragment length (2 bytes). The two-byte fields are big-endian.
			TlsContentType	contentType	= (TlsContentType)this.readByte();
			TlsProtocol		protocol	= (TlsProtocol)this.readUInt16();
			int				length		= this.readUInt16();

			// Read the whole fragment, which may arrive over several socket reads.
			byte[] fragment = new byte[length];
			this.receiveExactly(fragment, 0, length);

			TlsStream message = new TlsStream(fragment);

			// Check that the message has a valid protocol version
			if (protocol != this.session.Context.Protocol)
			{
				throw session.CreateException("Invalid protocol version on message received from server");
			}

			// A two-byte alert is treated as plaintext. ChangeCipherSpec is never
			// decrypted. Everything else is decrypted once the cipher is active.
			bool plainAlert = contentType == TlsContentType.Alert && length == 2;
			if (!plainAlert &&
				session.Context.IsActual &&
				contentType != TlsContentType.ChangeCipherSpec)
			{
				message = this.decryptRecordFragment(
					contentType,
					protocol,
					message.ToArray());
			}

			byte[] applicationData = new byte[0];

			switch (contentType)
			{
				case TlsContentType.Alert:
					this.processAlert(
						(TlsAlertLevel)message.ReadByte(),
						(TlsAlertDescription)message.ReadByte());
					break;

				case TlsContentType.ChangeCipherSpec:
					// A new read cipher state starts counting from zero.
					this.session.Context.ReadSequenceNumber = 0;
					break;

				case TlsContentType.ApplicationData:
					applicationData = message.ToArray();
					break;

				case TlsContentType.Handshake:
					// A record may carry several handshake messages.
					while (!message.EOF)
					{
						this.processHandshakeMessage(message);
					}
					// Append the raw handshake bytes to the context's
					// HandshakeMessages (presumably hashed for the Finished
					// verification; confirm in TlsSessionContext).
					this.session.Context.HandshakeMessages.Write(message.ToArray());
					break;

				default:
					throw session.CreateException("Unknown record received from server.");
			}

			return applicationData;
		}

		#endregion

		#region TLS_CRYPTO_METHODS

		/// <summary>
		/// Computes the client MAC for <paramref name="fragment"/> and encrypts
		/// fragment plus MAC with the session cipher.
		/// </summary>
		/// <returns>The encrypted fragment.</returns>
		/// <remarks>
		/// In CBC mode the last cipher block becomes the next client IV. The write
		/// sequence number is then incremented.
		/// </remarks>
		private byte[] encryptRecordFragment(TlsContentType contentType, byte[] fragment)
		{
			// Calculate message MAC
			byte[] mac = this.session.Context.Cipher.ComputeClientRecordMAC(contentType, fragment);

			// Encrypt the message
			byte[] ecr = this.session.Context.Cipher.EncryptRecord(fragment, mac);

			// In CBC mode the last ciphertext block is the IV for the next record.
			if (this.session.Context.Cipher.CipherMode == CipherMode.CBC)
			{
				byte[] iv = new byte[this.session.Context.Cipher.IvSize];
				System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
				this.session.Context.Cipher.UpdateClientCipherIV(iv);
			}

			// Update sequence number
			this.session.Context.WriteSequenceNumber++;

			return ecr;
		}

		/// <summary>
		/// Decrypts a received fragment and verifies its MAC against the server MAC
		/// computed by the session cipher.
		/// </summary>
		/// <returns>A stream over the decrypted fragment without the MAC.</returns>
		/// <exception cref="TlsException">The MAC is missing or does not match.</exception>
		private TlsStream decryptRecordFragment(TlsContentType contentType,
			TlsProtocol protocol,
			byte[] fragment)
		{
			byte[]	dcrFragment	= null;
			byte[]	dcrMAC		= null;

			// Decrypt message
			this.session.Context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);

			// In CBC mode the last ciphertext block is the IV for the next record.
			if (this.session.Context.Cipher.CipherMode == CipherMode.CBC)
			{
				byte[] iv = new byte[session.Context.Cipher.IvSize];
				System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
				this.session.Context.Cipher.UpdateServerCipherIV(iv);
			}

			// Check MAC code
			byte[] mac = this.session.Context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);

			if (dcrMAC == null || mac.Length != dcrMAC.Length)
			{
				throw new TlsException("Invalid MAC received from server.");
			}

			// Compare every byte without an early exit, so the time taken does not
			// reveal how many leading MAC bytes matched.
			int difference = 0;
			for (int i = 0; i < mac.Length; i++)
			{
				difference |= mac[i] ^ dcrMAC[i];
			}
			if (difference != 0)
			{
				throw new TlsException("Invalid MAC received from server.");
			}

			// Update sequence number
			this.session.Context.ReadSequenceNumber++;

			return new TlsStream(dcrFragment);
		}

		#endregion

		#region TLS_SEND_METHODS

		/// <summary>
		/// Sends <paramref name="alert"/> as an Alert record, then calls its
		/// UpdateSession and Reset.
		/// </summary>
		/// <returns>The number of bytes written to the socket, including record framing.</returns>
		internal int SendAlert(TlsAlert alert)
		{
			// Write record
			int bytesSent = this.sendRecord(TlsContentType.Alert, alert.ToArray());

			// Update session
			alert.UpdateSession();

			// Reset message contents
			alert.Reset();

			return bytesSent;
		}

		/// <summary>
		/// Builds the client handshake message of the given type, sends it, and
		/// lets it update the session.
		/// </summary>
		/// <returns>The number of bytes written to the socket, including record framing.</returns>
		private int sendRecord(TlsHandshakeType type)
		{
			TlsHandshakeMessage msg = createClientHandshakeMessage(type);

			// Write record
			int bytesSent = this.sendRecord(msg.ContentType, msg.EncodeMessage());

			// Update session
			msg.UpdateSession();

			// Reset message contents
			msg.Reset();

			return bytesSent;
		}

		/// <summary>
		/// Sends ChangeCipherSpec and switches the context to the pending
		/// (negotiated) cipher state. Then sends the client Finished message,
		/// which is therefore encrypted.
		/// </summary>
		/// <returns>Total bytes written to the socket for both records.</returns>
		private int sendChangeCipherSpec()
		{
			// Send Change Cipher Spec message
			int bytesSent = this.sendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});

			// Reset sequence numbers
			this.session.Context.WriteSequenceNumber = 0;

			// Make the pending state to be the current state
			this.session.Context.IsActual = true;

			// Send Finished message
			bytesSent += this.sendRecord(TlsHandshakeType.Finished);

			return bytesSent;
		}

		/// <summary>
		/// Splits <paramref name="recordData"/> into fragments and writes each one
		/// as a TLS record of <paramref name="contentType"/>. Fragments are
		/// encrypted when the cipher state is active.
		/// </summary>
		/// <returns>The number of bytes written to the socket, including headers, MACs and padding.</returns>
		private int sendRecord(TlsContentType contentType, byte[] recordData)
		{
			if (this.session.Context.ConnectionEnd)
			{
				throw this.session.CreateException("The session is finished and it's no longer valid.");
			}

			int			bytesSent = 0;
			byte[][]	fragments = fragmentData(recordData);
			for (int i = 0; i < fragments.Length; i++)
			{
				byte[] fragment = fragments[i];

				if (this.session.Context.IsActual)
				{
					// Encrypt fragment
					fragment = this.encryptRecordFragment(contentType, fragment);
				}

				// Record header (type, version, length) followed by the fragment.
				TlsStream record = new TlsStream();
				record.Write((byte)contentType);
				record.Write((short)this.session.Context.Protocol);
				record.Write((short)fragment.Length);
				record.Write(fragment);

				// The socket may accept fewer bytes than requested, so loop until
				// the whole record is written.
				bytesSent += this.sendExactly(record.ToArray());

				// Reset record data
				record.Reset();
			}

			return bytesSent;
		}

		/// <summary>
		/// Splits <paramref name="messageData"/> into consecutive chunks of at most
		/// TlsSessionContext.MAX_FRAGMENT_SIZE bytes.
		/// </summary>
		/// <returns>The chunks in order. Empty input yields an empty array.</returns>
		private byte[][] fragmentData(byte[] messageData)
		{
			int maxSize		= TlsSessionContext.MAX_FRAGMENT_SIZE;
			int count		= (messageData.Length + maxSize - 1) / maxSize;
			byte[][] result	= new byte[count][];

			int position = 0;
			for (int i = 0; i < count; i++)
			{
				int fragmentLength = Math.Min(maxSize, messageData.Length - position);

				result[i] = new byte[fragmentLength];
				System.Array.Copy(messageData, position, result[i], 0, fragmentLength);

				position += fragmentLength;
			}

			return result;
		}

		#endregion

		#region MESSAGE_PROCESSING

		/// <summary>
		/// Reads one handshake message from <paramref name="handMsg"/>: a type byte,
		/// a 24-bit length, then the body. Builds the matching server message and
		/// lets it update the session.
		/// </summary>
		private void processHandshakeMessage(TlsStream handMsg)
		{
			TlsHandshakeType	handshakeType	= (TlsHandshakeType)handMsg.ReadByte();
			TlsHandshakeMessage	message			= null;

			// Read message length
			int length = handMsg.ReadInt24();

			// Read message data. A short read means the message does not fit in
			// this record; fail instead of parsing zero-filled data.
			byte[] data = new byte[length];
			int read = handMsg.Read(data, 0, length);
			if (read != length)
			{
				throw this.session.CreateException("Truncated handshake message received from server.");
			}

			// Create and process the server message
			message = this.createServerHandshakeMessage(handshakeType, data);

			// HelloRequest yields no message object.
			if (message != null)
			{
				message.UpdateSession();
			}
		}

		/// <summary>
		/// Handles a received alert. Fatal alerts throw. close_notify marks the
		/// connection as ended. Any other alert is raised as a warning on the session.
		/// </summary>
		private void processAlert(TlsAlertLevel alertLevel, TlsAlertDescription alertDesc)
		{
			switch (alertLevel)
			{
				case TlsAlertLevel.Fatal:
					throw this.session.CreateException(alertLevel, alertDesc);

				case TlsAlertLevel.Warning:
				default:
					switch (alertDesc)
					{
						case TlsAlertDescription.CloseNotify:
							this.session.Context.ConnectionEnd = true;
							break;

						default:
							this.session.RaiseWarningAlert(alertLevel, alertDesc);
							break;
					}
					break;
			}
		}

		#endregion

		#region MISC_METHODS

		/// <summary>Empties the decrypted input buffer.</summary>
		private void resetBuffer()
		{
			this.inputBuffer.SetLength(0);
			this.inputBuffer.Position = 0;
		}

		/// <summary>
		/// Appends <paramref name="data"/> to the end of the input buffer without
		/// moving the current read position.
		/// </summary>
		private void appendToInputBuffer(byte[] data)
		{
			long position = this.inputBuffer.Position;

			this.inputBuffer.Seek(0, SeekOrigin.End);
			this.inputBuffer.Write(data, 0, data.Length);

			this.inputBuffer.Seek(position, SeekOrigin.Begin);
		}

		/// <summary>
		/// Reads exactly <paramref name="count"/> bytes from the base socket into
		/// <paramref name="buffer"/>, looping over partial reads.
		/// </summary>
		/// <exception>Thrown via session.CreateException if the peer closes before all bytes arrive.</exception>
		private void receiveExactly(byte[] buffer, int offset, int count)
		{
			while (count > 0)
			{
				int received = base.Receive(buffer, offset, count, SocketFlags.None);
				if (received <= 0)
				{
					throw this.session.CreateException("Connection closed by the remote host while reading a TLS record.");
				}

				offset	+= received;
				count	-= received;
			}
		}

		/// <summary>
		/// Writes all of <paramref name="data"/> to the base socket, looping over
		/// partial sends.
		/// </summary>
		/// <returns>The number of bytes written, which is always data.Length.</returns>
		private int sendExactly(byte[] data)
		{
			int sent = 0;
			while (sent < data.Length)
			{
				int written = base.Send(data, sent, data.Length - sent, SocketFlags.None);
				if (written <= 0)
				{
					throw this.session.CreateException("Unable to write TLS record to the underlying socket.");
				}

				sent += written;
			}

			return sent;
		}

		/// <summary>Reads one byte from the base socket.</summary>
		private byte readByte()
		{
			byte[] b = new byte[1];
			this.receiveExactly(b, 0, 1);

			return b[0];
		}

		/// <summary>
		/// Reads a big-endian (network order) unsigned 16-bit value from the base
		/// socket. The result is in 0..65535, so large record lengths cannot wrap
		/// negative.
		/// </summary>
		private int readUInt16()
		{
			byte[] b = new byte[2];
			this.receiveExactly(b, 0, 2);

			return (b[0] << 8) | b[1];
		}

		#endregion

		#region HANDSHAKE_METHODS

		/*
			Client                                  Server

			ClientHello             -------->
			                                        ServerHello
			                                        Certificate*
			                                        ServerKeyExchange*
			                                        CertificateRequest*
			                        <--------       ServerHelloDone
			Certificate*
			ClientKeyExchange
			CertificateVerify*
			[ChangeCipherSpec]
			Finished                -------->
			                                        [ChangeCipherSpec]
			                        <--------       Finished
			Application Data        <------->       Application Data

			Fig. 1 - Message flow for a full handshake
		*/

		/// <summary>
		/// Runs the client side of the full handshake shown above.
		/// session.IsSecure is false while the handshake runs and is set to true
		/// only after it completes.
		/// </summary>
		internal void DoHandshake()
		{
			// Reset isSecure field
			this.session.IsSecure = false;

			// Send client hello
			this.sendRecord(TlsHandshakeType.ClientHello);

			// Read records until ServerHelloDone has been processed.
			while (!this.session.Context.HelloDone)
			{
				this.receiveRecord();
			}

			// Send client certificate if requested
			if (this.session.Context.ServerSettings.CertificateRequest)
			{
				this.sendRecord(TlsHandshakeType.Certificate);
			}

			// Send Client Key Exchange
			this.sendRecord(TlsHandshakeType.ClientKeyExchange);

			// Now initialize session cipher with the generated keys
			this.session.Context.Cipher.InitializeCipher();

			// Send certificate verify if requested
			if (this.session.Context.ServerSettings.CertificateRequest)
			{
				this.sendRecord(TlsHandshakeType.CertificateVerify);
			}

			// Send ChangeCipherSpec followed by the client Finished message.
			this.sendChangeCipherSpec();

			// Read the server's ChangeCipherSpec.
			this.receiveRecord();

			// Read server finished
			if (!this.session.Context.HandshakeFinished)
			{
				this.receiveRecord();
			}

			// Clear Key Info
			this.session.Context.ClearKeyInfo();

			// Set isSecure
			this.session.IsSecure = true;
		}

		/// <summary>Creates the client handshake message for <paramref name="type"/>.</summary>
		/// <exception cref="InvalidOperationException">The type is not one the client sends.</exception>
		private TlsHandshakeMessage createClientHandshakeMessage(TlsHandshakeType type)
		{
			switch (type)
			{
				case TlsHandshakeType.ClientHello:
					return new TlsClientHello(session);

				case TlsHandshakeType.Certificate:
					return new TlsClientCertificate(session);

				case TlsHandshakeType.ClientKeyExchange:
					return new TlsClientKeyExchange(session);

				case TlsHandshakeType.CertificateVerify:
					return new TlsClientCertificateVerify(session);

				case TlsHandshakeType.Finished:
					return new TlsClientFinished(session);

				default:
					throw new InvalidOperationException("Unknown client handshake message type: " + type.ToString() );
			}
		}

		/// <summary>
		/// Creates the server handshake message for <paramref name="type"/> from
		/// its body bytes. A HelloRequest is answered immediately with a new
		/// ClientHello, and null is returned.
		/// </summary>
		private TlsHandshakeMessage createServerHandshakeMessage(TlsHandshakeType type, byte[] buffer)
		{
			switch (type)
			{
				case TlsHandshakeType.HelloRequest:
					this.sendRecord(TlsHandshakeType.ClientHello);
					return null;

				case TlsHandshakeType.ServerHello:
					return new TlsServerHello(session, buffer);

				case TlsHandshakeType.Certificate:
					return new TlsServerCertificate(session, buffer);

				case TlsHandshakeType.ServerKeyExchange:
					return new TlsServerKeyExchange(session, buffer);

				case TlsHandshakeType.CertificateRequest:
					return new TlsServerCertificateRequest(session, buffer);

				case TlsHandshakeType.ServerHelloDone:
					return new TlsServerHelloDone(session, buffer);

				case TlsHandshakeType.Finished:
					return new TlsServerFinished(session, buffer);

				default:
					throw this.session.CreateException("Unknown server handshake message received ({0})", type.ToString());
			}
		}

		#endregion
	}
}