2004-03-10 Carlos Guzman Alvarez <carlosga@telefonica.net>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 /* Transport Security Layer (TLS)
2  * Copyright (c) 2003-2004 Carlos Guzman Alvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Security.Cryptography;
28 using System.Security.Cryptography.X509Certificates;
29
30 using Mono.Security.Protocol.Tls.Handshake;
31
32 namespace Mono.Security.Protocol.Tls
33 {
34         internal abstract class RecordProtocol
35         {
36                 #region Fields
37
                // Stream that raw TLS records are read from (ReceiveRecord)
                // and written to (SendRecord).
                protected Stream        innerStream;
                // Connection context consulted by this class for state such as
                // ConnectionEnd, Protocol, IsActual, the sequence numbers and Cipher.
                protected Context       context;
40
41                 #endregion
42
43                 #region Properties
44
45                 public Stream InnerStream
46                 {
47                         get { return this.innerStream; }
48                         set { this.innerStream = value; }
49                 }
50
51                 public Context Context
52                 {
53                         get { return this.context; }
54                         set { this.context = value; }
55                 }
56
57                 #endregion
58
59                 #region Constructors
60
61                 public RecordProtocol(Stream innerStream, Context context)
62                 {
63                         this.innerStream        = innerStream;
64                         this.context            = context;
65                 }
66
67                 #endregion
68
69                 #region Abstract Methods
70
                // Sends the record for the given handshake message type
                // (SendChangeCipherSpec uses it to send the Finished message).
                public abstract void SendRecord(HandshakeType type);
                // Called by ReceiveRecord, repeatedly until handMsg reaches EOF,
                // for each Handshake record received.
                protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
                // Called by ReceiveRecord when a ChangeCipherSpec record is received.
                protected abstract void ProcessChangeCipherSpec();
74                                 
75                 #endregion
76
77                 #region Receive Record Methods
78
79                 public byte[] ReceiveRecord()
80                 {
81                         if (this.context.ConnectionEnd)
82                         {
83                                 throw this.context.CreateException("The session is finished and it's no longer valid.");
84                         }
85                         
86                         // Try to read the Record Content Type
87                         int type = this.innerStream.ReadByte();
88
89                         // There are no more data for read
90                         if (type == -1)
91                         {
92                                 return null;
93                         }
94
95                         ContentType     contentType     = (ContentType)type;
96                         short           protocol        = this.readShort();
97                         short           length          = this.readShort();
98                         
99                         // Read Record data
100                         int             received        = 0;
101                         byte[]  buffer          = new byte[length];
102                         while (received != length)
103                         {
104                                 received += this.innerStream.Read(
105                                         buffer, received, buffer.Length - received);
106                         }
107
108                         TlsStream message = new TlsStream(buffer);
109                 
110                         // Check that the message has a valid protocol version
111                         if (protocol != this.context.Protocol && 
112                                 this.context.ProtocolNegotiated)
113                         {
114                                 throw this.context.CreateException("Invalid protocol version on message received from server");
115                         }
116
117                         // Decrypt message contents if needed
118                         if (contentType == ContentType.Alert && length == 2)
119                         {
120                         }
121                         else
122                         {
123                                 if (this.context.IsActual &&
124                                         contentType != ContentType.ChangeCipherSpec)
125                                 {
126                                         message = this.decryptRecordFragment(
127                                                 contentType, 
128                                                 message.ToArray());
129                                 }
130                         }
131
132                         // Set last handshake message received to None
133                         this.context.LastHandshakeMsg = HandshakeType.None;
134                         
135                         // Process record
136                         byte[] result = message.ToArray();
137
138                         switch (contentType)
139                         {
140                                 case ContentType.Alert:
141                                         this.processAlert(
142                                                 (AlertLevel)message.ReadByte(),
143                                                 (AlertDescription)message.ReadByte());
144                                         break;
145
146                                 case ContentType.ChangeCipherSpec:
147                                         this.ProcessChangeCipherSpec();
148                                         break;
149
150                                 case ContentType.ApplicationData:
151                                         break;
152
153                                 case ContentType.Handshake:
154                                         while (!message.EOF)
155                                         {
156                                                 this.ProcessHandshakeMessage(message);
157                                         }
158
159                                         // Update handshakes of current messages
160                                         this.context.HandshakeMessages.Write(message.ToArray());
161                                         break;
162
163                                 default:
164                                         throw this.context.CreateException("Unknown record received from server.");
165                         }
166
167                         return result;
168                 }
169
170                 private short readShort()
171                 {
172                         byte[] b = new byte[2];
173                         this.innerStream.Read(b, 0, b.Length);
174
175                         short val = BitConverter.ToInt16(b, 0);
176
177                         return System.Net.IPAddress.HostToNetworkOrder(val);
178                 }
179
180                 private void processAlert(
181                         AlertLevel                      alertLevel, 
182                         AlertDescription        alertDesc)
183                 {
184                         switch (alertLevel)
185                         {
186                                 case AlertLevel.Fatal:
187                                         throw this.context.CreateException(alertLevel, alertDesc);                                      
188
189                                 case AlertLevel.Warning:
190                                 default:
191                                 switch (alertDesc)
192                                 {
193                                         case AlertDescription.CloseNotify:
194                                                 this.context.ConnectionEnd = true;
195                                                 break;
196                                 }
197                                 break;
198                         }
199                 }
200
201                 #endregion
202
203                 #region Send Alert Methods
204
205                 public void SendAlert(AlertDescription description)
206                 {
207                         this.SendAlert(new Alert(this.Context, description));
208                 }
209
210                 public void SendAlert(
211                         AlertLevel                      level, 
212                         AlertDescription        description)
213                 {
214                         this.SendAlert(new Alert(this.Context, level, description));
215                 }
216
217                 public void SendAlert(Alert alert)
218                 {                       
219                         // Write record
220                         this.SendRecord(ContentType.Alert, alert.ToArray());
221
222                         // Update session
223                         alert.Update();
224
225                         // Reset message contents
226                         alert.Reset();
227                 }
228
229                 #endregion
230
231                 #region Send Record Methods
232
233                 public void SendChangeCipherSpec()
234                 {
235                         // Send Change Cipher Spec message as a plain message
236                         this.context.IsActual = false;
237
238                         // Send Change Cipher Spec message
239                         this.SendRecord(ContentType.ChangeCipherSpec, new byte[] {1});
240
241                         // Reset sequence numbers
242                         this.context.WriteSequenceNumber = 0;
243
244                         // Make the pending state to be the current state
245                         this.context.IsActual = true;
246
247                         // Send Finished message
248                         this.SendRecord(HandshakeType.Finished);                        
249                 }
250
251                 public void SendRecord(ContentType contentType, byte[] recordData)
252                 {
253                         if (this.context.ConnectionEnd)
254                         {
255                                 throw this.context.CreateException("The session is finished and it's no longer valid.");
256                         }
257
258                         byte[] record = this.EncodeRecord(contentType, recordData);
259
260                         this.innerStream.Write(record, 0, record.Length);
261                 }
262
263                 public byte[] EncodeRecord(ContentType contentType, byte[] recordData)
264                 {
265                         return this.EncodeRecord(
266                                 contentType,
267                                 recordData,
268                                 0,
269                                 recordData.Length);
270                 }
271
272                 public byte[] EncodeRecord(
273                         ContentType     contentType, 
274                         byte[]          recordData,
275                         int                     offset,
276                         int                     count)
277                 {
278                         if (this.context.ConnectionEnd)
279                         {
280                                 throw this.context.CreateException("The session is finished and it's no longer valid.");
281                         }
282
283                         TlsStream record = new TlsStream();
284
285                         int     position = offset;
286
287                         while (position < ( offset + count ))
288                         {
289                                 short   fragmentLength = 0;
290                                 byte[]  fragment;
291
292                                 if ((count - position) > Context.MAX_FRAGMENT_SIZE)
293                                 {
294                                         fragmentLength = Context.MAX_FRAGMENT_SIZE;
295                                 }
296                                 else
297                                 {
298                                         fragmentLength = (short)(count - position);
299                                 }
300
301                                 // Fill the fragment data
302                                 fragment = new byte[fragmentLength];
303                                 Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);
304
305                                 if (this.context.IsActual)
306                                 {
307                                         // Encrypt fragment
308                                         fragment = this.encryptRecordFragment(contentType, fragment);
309                                 }
310
311                                 // Write tls message
312                                 record.Write((byte)contentType);
313                                 record.Write(this.context.Protocol);
314                                 record.Write((short)fragment.Length);
315                                 record.Write(fragment);
316
317                                 // Update buffer position
318                                 position += fragmentLength;
319                         }
320
321                         return record.ToArray();
322                 }
323                 
324                 #endregion
325
326                 #region Cryptography Methods
327
328                 private byte[] encryptRecordFragment(
329                         ContentType     contentType, 
330                         byte[]          fragment)
331                 {
332                         byte[] mac      = null;
333
334                         // Calculate message MAC
335                         if (this.Context is ClientContext)
336                         {
337                                 mac     = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
338                         }       
339                         else
340                         {
341                                 mac     = this.context.Cipher.ComputeServerRecordMAC(contentType, fragment);
342                         }
343
344                         // Encrypt the message
345                         byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
346
347                         // Set new Client Cipher IV
348                         if (this.context.Cipher.CipherMode == CipherMode.CBC)
349                         {
350                                 byte[] iv = new byte[this.context.Cipher.IvSize];
351                                 System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
352
353                                 this.context.Cipher.UpdateClientCipherIV(iv);
354                         }
355
356                         // Update sequence number
357                         this.context.WriteSequenceNumber++;
358
359                         return ecr;
360                 }
361
362                 private TlsStream decryptRecordFragment(
363                         ContentType     contentType, 
364                         byte[]          fragment)
365                 {
366                         byte[]  dcrFragment     = null;
367                         byte[]  dcrMAC          = null;
368
369                         // Decrypt message
370                         this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
371                         
372                         // Check MAC code
373                         byte[] mac = null;
374                         if (this.Context is ClientContext)
375                         {
376                                 mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
377                         }
378                         else
379                         {
380                                 mac = this.context.Cipher.ComputeClientRecordMAC(contentType, dcrFragment);
381                         }
382
383                         // Check that the mac is correct
384                         if (mac.Length != dcrMAC.Length)
385                         {
386                                 throw new TlsException("Invalid MAC received from server.");
387                         }
388
389                         for (int i = 0; i < mac.Length; i++)
390                         {
391                                 if (mac[i] != dcrMAC[i])
392                                 {
393                                         throw new TlsException("Invalid MAC received from server.");
394                                 }
395                         }
396
397                         // Update sequence number
398                         this.context.ReadSequenceNumber++;
399
400                         return new TlsStream(dcrFragment);
401                 }
402
403                 #endregion
404         }
405 }