2004-03-17 Carlos Guzman Alvarez <carlosga@telefonica.net>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 /* Transport Security Layer (TLS)
2  * Copyright (c) 2003-2004 Carlos Guzman Alvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Security.Cryptography;
28 using System.Security.Cryptography.X509Certificates;
29
30 using Mono.Security.Protocol.Tls.Handshake;
31
namespace Mono.Security.Protocol.Tls
{
    /// <summary>
    /// Base class for the TLS/SSL record layer. Reads records from and writes
    /// records to <see cref="InnerStream"/>, encrypting and decrypting record
    /// fragments with the cipher held by <see cref="Context"/> once the
    /// current cipher state is active (<c>Context.IsActual</c>).
    /// Subclasses supply handshake and ChangeCipherSpec processing.
    /// </summary>
    internal abstract class RecordProtocol
    {
        #region Fields

        protected Stream        innerStream;
        protected Context       context;

        #endregion

        #region Properties

        /// <summary>Transport stream that records are read from and written to.</summary>
        public Stream InnerStream
        {
            get { return this.innerStream; }
            set { this.innerStream = value; }
        }

        /// <summary>Connection state used by this record layer (protocol, cipher, sequence numbers, ...).</summary>
        public Context Context
        {
            get { return this.context; }
            set { this.context = value; }
        }

        #endregion

        #region Constructors

        /// <summary>
        /// Creates a record layer over <paramref name="innerStream"/> and registers
        /// it as <c>context.RecordProtocol</c>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public RecordProtocol(Stream innerStream, Context context)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }

            this.innerStream            = innerStream;
            this.context                = context;
            this.context.RecordProtocol = this;
        }

        #endregion

        #region Abstract Methods

        public abstract void SendRecord(HandshakeType type);
        protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
        protected abstract void ProcessChangeCipherSpec();

        #endregion

        #region Receive Record Methods

        /// <summary>
        /// Reads one record from the inner stream, decrypts it when the current
        /// cipher state is active, and dispatches it by content type.
        /// </summary>
        /// <returns>
        /// The (decrypted) record payload, or <c>null</c> when the stream is at
        /// end-of-stream before a new record begins.
        /// </returns>
        /// <remarks>
        /// Throws the exception built by <c>Context.CreateException</c> when the
        /// session is finished, the stream ends in the middle of a record, the
        /// record length is invalid, the protocol version does not match the
        /// negotiated one, or the content type is unknown. Fatal alerts are
        /// raised as exceptions by <c>processAlert</c>.
        /// </remarks>
        public byte[] ReceiveRecord()
        {
            if (this.context.ConnectionEnd)
            {
                throw this.context.CreateException("The session is finished and it's no longer valid.");
            }

            // Try to read the Record Content Type
            int type = this.innerStream.ReadByte();

            // End-of-stream between records: there is no more data to read
            if (type == -1)
            {
                return null;
            }

            ContentType contentType = (ContentType)type;
            short       protocol    = this.readShort();
            short       length      = this.readShort();

            // The length field is decoded as a signed 16-bit value; a negative
            // value can never be a valid record length (and would make the
            // buffer allocation below fail).
            if (length < 0)
            {
                throw this.context.CreateException("Invalid record length received.");
            }

            // Read Record data. readFully keeps reading until the whole fragment
            // has arrived and fails instead of looping forever if the stream ends.
            byte[] buffer = new byte[length];
            this.readFully(buffer, 0, buffer.Length);

            TlsStream message = new TlsStream(buffer);

            // Check that the message has a valid protocol version
            if (protocol != this.context.Protocol &&
                this.context.ProtocolNegotiated)
            {
                throw this.context.CreateException("Invalid protocol version on message received from server");
            }

            // Decrypt message contents if needed. A 2-byte Alert record
            // (level + description) and ChangeCipherSpec records are handled
            // as plaintext.
            bool isPlainAlert = (contentType == ContentType.Alert && length == 2);

            if (!isPlainAlert &&
                this.context.IsActual &&
                contentType != ContentType.ChangeCipherSpec)
            {
                message = this.decryptRecordFragment(
                    contentType,
                    message.ToArray());
            }

            // Set last handshake message received to None
            this.context.LastHandshakeMsg = HandshakeType.None;

            // Process record
            byte[] result = message.ToArray();

            switch (contentType)
            {
                case ContentType.Alert:
                    this.processAlert(
                        (AlertLevel)message.ReadByte(),
                        (AlertDescription)message.ReadByte());
                    break;

                case ContentType.ChangeCipherSpec:
                    this.ProcessChangeCipherSpec();
                    break;

                case ContentType.ApplicationData:
                    break;

                case ContentType.Handshake:
                    // A single record may carry several handshake messages
                    while (!message.EOF)
                    {
                        this.ProcessHandshakeMessage(message);
                    }

                    // Update handshakes of current messages
                    this.context.HandshakeMessages.Write(message.ToArray());
                    break;

                default:
                    throw this.context.CreateException("Unknown record received from server.");
            }

            return result;
        }

        /// <summary>
        /// Reads exactly <paramref name="count"/> bytes from the inner stream into
        /// <paramref name="buffer"/>, starting at <paramref name="offset"/>.
        /// Stream.Read may return fewer bytes than requested, so this loops;
        /// a return of 0 means end-of-stream and is reported as an error.
        /// </summary>
        private void readFully(byte[] buffer, int offset, int count)
        {
            int received = 0;

            while (received < count)
            {
                int read = this.innerStream.Read(
                    buffer, offset + received, count - received);

                if (read <= 0)
                {
                    throw this.context.CreateException("Unexpected end of stream while reading a record.");
                }

                received += read;
            }
        }

        /// <summary>
        /// Reads a 16-bit big-endian (network byte order) value from the inner stream.
        /// </summary>
        private short readShort()
        {
            byte[] b = new byte[2];
            this.readFully(b, 0, b.Length);

            // Decode explicitly instead of relying on host endianness
            return (short)((b[0] << 8) | b[1]);
        }

        /// <summary>
        /// Handles a received alert. A fatal alert is raised as an exception;
        /// a CloseNotify alert of any other level marks the connection as ended.
        /// </summary>
        private void processAlert(
            AlertLevel          alertLevel,
            AlertDescription    alertDesc)
        {
            switch (alertLevel)
            {
                case AlertLevel.Fatal:
                    throw this.context.CreateException(alertLevel, alertDesc);

                case AlertLevel.Warning:
                default:
                    switch (alertDesc)
                    {
                        case AlertDescription.CloseNotify:
                            this.context.ConnectionEnd = true;
                            break;
                    }
                    break;
            }
        }

        #endregion

        #region Send Alert Methods

        /// <summary>Sends an alert with the given description, using the Alert class's default level.</summary>
        public void SendAlert(AlertDescription description)
        {
            this.SendAlert(new Alert(this.Context, description));
        }

        /// <summary>Sends an alert with an explicit level and description.</summary>
        public void SendAlert(
            AlertLevel          level,
            AlertDescription    description)
        {
            this.SendAlert(new Alert(this.Context, level, description));
        }

        /// <summary>
        /// Writes <paramref name="alert"/> as an Alert record, then lets the alert
        /// update the session (<c>alert.Update()</c>) and resets its contents.
        /// </summary>
        public void SendAlert(Alert alert)
        {
            // Write record
            this.SendRecord(ContentType.Alert, alert.ToArray());

            // Update session
            alert.Update();

            // Reset message contents
            alert.Reset();
        }

        #endregion

        #region Send Record Methods

        /// <summary>
        /// Sends a ChangeCipherSpec record in the clear, resets the write sequence
        /// number, activates the current cipher state and sends the Finished
        /// handshake message.
        /// </summary>
        public void SendChangeCipherSpec()
        {
            // Send Change Cipher Spec message as a plain message
            this.context.IsActual = false;

            // Send Change Cipher Spec message
            this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});

            // Reset sequence numbers
            this.context.WriteSequenceNumber = 0;

            // Make the pending state to be the current state
            this.context.IsActual = true;

            // Send Finished message
            this.SendRecord(HandshakeType.Finished);
        }

        /// <summary>
        /// Encodes <paramref name="recordData"/> as one or more records of type
        /// <paramref name="contentType"/> and writes them to the inner stream.
        /// </summary>
        public void SendRecord(ContentType contentType, byte[] recordData)
        {
            if (this.context.ConnectionEnd)
            {
                throw this.context.CreateException("The session is finished and it's no longer valid.");
            }

            byte[] record = this.EncodeRecord(contentType, recordData);

            this.innerStream.Write(record, 0, record.Length);
        }

        /// <summary>Encodes the whole of <paramref name="recordData"/>; see the four-argument overload.</summary>
        public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
        {
            if (recordData == null)
            {
                throw new ArgumentNullException("recordData");
            }

            return this.EncodeRecord(
                contentType,
                recordData,
                0,
                recordData.Length);
        }

        /// <summary>
        /// Encodes <paramref name="count"/> bytes of <paramref name="recordData"/>
        /// starting at <paramref name="offset"/> as a sequence of records of at
        /// most <c>Context.MAX_FRAGMENT_SIZE</c> plaintext bytes each. Each
        /// fragment is encrypted when the current cipher state is active.
        /// </summary>
        /// <returns>The concatenated wire bytes of all records.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="recordData"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="offset"/>/<paramref name="count"/> do not describe a range inside <paramref name="recordData"/>.
        /// </exception>
        public byte[] EncodeRecord(
            ContentType contentType,
            byte[]      recordData,
            int         offset,
            int         count)
        {
            if (this.context.ConnectionEnd)
            {
                throw this.context.CreateException("The session is finished and it's no longer valid.");
            }

            if (recordData == null)
            {
                throw new ArgumentNullException("recordData");
            }
            if (offset < 0 || offset > recordData.Length)
            {
                throw new ArgumentOutOfRangeException("offset");
            }
            if (count < 0 || count > recordData.Length - offset)
            {
                throw new ArgumentOutOfRangeException("count");
            }

            TlsStream record = new TlsStream();

            int position = offset;
            int end      = offset + count;

            while (position < end)
            {
                // The remaining byte count is measured against 'end' (offset + count),
                // so a non-zero offset produces correct fragment sizes.
                int fragmentLength = Math.Min(end - position, Context.MAX_FRAGMENT_SIZE);

                // Fill the fragment data
                byte[] fragment = new byte[fragmentLength];
                Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);

                if (this.context.IsActual)
                {
                    // Encrypt fragment
                    fragment = this.encryptRecordFragment(contentType, fragment);
                }

                // Write tls message
                record.Write((byte)contentType);
                record.Write(this.context.Protocol);
                record.Write((short)fragment.Length);
                record.Write(fragment);

                // Update buffer position
                position += fragmentLength;
            }

            return record.ToArray();
        }

        #endregion

        #region Cryptography Methods

        /// <summary>
        /// Computes the record MAC for this side of the connection (client MAC
        /// for a ClientContext, server MAC otherwise), encrypts fragment + MAC,
        /// and increments the write sequence number.
        /// </summary>
        private byte[] encryptRecordFragment(
            ContentType contentType,
            byte[]      fragment)
        {
            byte[] mac = null;

            // Calculate message MAC
            if (this.Context is ClientContext)
            {
                mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
            }
            else
            {
                mac = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
            }

            // Encrypt the message
            byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);

            // In CBC mode, the last ciphertext block becomes the next IV.
            // NOTE(review): UpdateClientCipherIV is called for both client and
            // server contexts — confirm this is intended for the server side.
            if (this.context.Cipher.CipherMode == CipherMode.CBC)
            {
                byte[] iv = new byte[this.context.Cipher.IvSize];
                Buffer.BlockCopy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);

                this.context.Cipher.UpdateClientCipherIV(iv);
            }

            // Update sequence number
            this.context.WriteSequenceNumber++;

            return ecr;
        }

        /// <summary>
        /// Decrypts a record fragment, verifies its MAC against the peer's MAC
        /// (server MAC for a ClientContext, client MAC otherwise), increments the
        /// read sequence number and returns the plaintext.
        /// On a server context, a DecryptionFailed or BadRecordMAC alert is sent
        /// to the peer before the failure is rethrown/raised.
        /// </summary>
        /// <exception cref="TlsException">The record MAC does not match.</exception>
        private TlsStream decryptRecordFragment(
            ContentType contentType,
            byte[]      fragment)
        {
            byte[] dcrFragment = null;
            byte[] dcrMAC      = null;

            try
            {
                this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
            }
            catch
            {
                if (this.context is ServerContext)
                {
                    this.Context.RecordProtocol.SendAlert(AlertDescription.DecryptionFailed);
                }

                throw;
            }

            // Generate record MAC
            byte[] mac = null;

            if (this.Context is ClientContext)
            {
                mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
            }
            else
            {
                mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
            }

            // Check record MAC (constant-time comparison to avoid a timing side channel)
            if (!macsEqual(mac, dcrMAC))
            {
                if (this.context is ServerContext)
                {
                    this.Context.RecordProtocol.SendAlert(AlertDescription.BadRecordMAC);
                }

                throw new TlsException("Bad record MAC");
            }

            // Update sequence number
            this.context.ReadSequenceNumber++;

            return new TlsStream(dcrFragment);
        }

        /// <summary>
        /// Compares two MACs without returning early on the first differing byte,
        /// so the running time does not reveal the position of a mismatch.
        /// Null or different-length inputs compare as unequal.
        /// </summary>
        private static bool macsEqual(byte[] expected, byte[] actual)
        {
            if (expected == null || actual == null || expected.Length != actual.Length)
            {
                return false;
            }

            int diff = 0;
            for (int i = 0; i < expected.Length; i++)
            {
                diff |= expected[i] ^ actual[i];
            }

            return diff == 0;
        }

        #endregion
    }
}