2004-02-26 Carlos Guzman Alvarez <carlosga@telefonica.net>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 /* Transport Security Layer (TLS)
2  * Copyright (c) 2003-2004 Carlos Guzman Alvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Security.Cryptography;
28 using System.Security.Cryptography.X509Certificates;
29
30 using Mono.Security.Protocol.Tls.Alerts;
31 using Mono.Security.Protocol.Tls.Handshake;
32
33 namespace Mono.Security.Protocol.Tls
34 {
35         internal abstract class RecordProtocol
36         {
37                 #region Fields
38
39                 protected Stream        innerStream;
40                 protected Context       context;
41
42                 #endregion
43
44                 #region Properties
45
46                 public Stream InnerStream
47                 {
48                         get { return this.innerStream; }
49                         set { this.innerStream = value; }
50                 }
51
52                 public Context Context
53                 {
54                         get { return this.context; }
55                         set { this.context = value; }
56                 }
57
58                 #endregion
59
60                 #region Constructors
61
62                 public RecordProtocol(Stream innerStream, Context context)
63                 {
64                         this.innerStream        = innerStream;
65                         this.context            = context;
66                 }
67
68                 #endregion
69
70                 #region Abstract Methods
71
72                 public abstract void SendRecord(TlsHandshakeType type);
73                 protected abstract void ProcessHandshakeMessage(TlsStream handMsg);
74                                 
75                 #endregion
76
77                 #region Reveive Record Methods
78
79                 public byte[] ReceiveRecord()
80                 {
81                         if (this.context.ConnectionEnd)
82                         {
83                                 throw this.context.CreateException("The session is finished and it's no longer valid.");
84                         }
85                         
86                         // Try to read the Record Content Type
87                         int type = this.innerStream.ReadByte();
88
89                         // There are no more data for read
90                         if (type == -1)
91                         {
92                                 return null;
93                         }
94
95                         TlsContentType  contentType     = (TlsContentType)type;
96                         short                   protocol        = this.readShort();
97                         short                   length          = this.readShort();
98                         
99                         // Read Record data
100                         int             received        = 0;
101                         byte[]  buffer          = new byte[length];
102                         while (received != length)
103                         {
104                                 received += this.innerStream.Read(
105                                         buffer, received, buffer.Length - received);
106                         }
107
108                         TlsStream message = new TlsStream(buffer);
109                 
110                         // Check that the message has a valid protocol version
111                         if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
112                         {
113                                 throw this.context.CreateException("Invalid protocol version on message received from server");
114                         }
115
116                         // Decrypt message contents if needed
117                         if (contentType == TlsContentType.Alert && length == 2)
118                         {
119                         }
120                         else
121                         {
122                                 if (this.context.IsActual &&
123                                         contentType != TlsContentType.ChangeCipherSpec)
124                                 {
125                                         message = this.decryptRecordFragment(
126                                                 contentType, 
127                                                 message.ToArray());
128                                 }
129                         }
130
131                         byte[] result = message.ToArray();
132
133                         // Process record
134                         switch (contentType)
135                         {
136                                 case TlsContentType.Alert:
137                                         this.processAlert(
138                                                 (TlsAlertLevel)message.ReadByte(),
139                                                 (TlsAlertDescription)message.ReadByte());
140                                         break;
141
142                                 case TlsContentType.ChangeCipherSpec:
143                                         // Reset sequence numbers
144                                         this.context.ReadSequenceNumber = 0;
145                                         break;
146
147                                 case TlsContentType.ApplicationData:
148                                         break;
149
150                                 case TlsContentType.Handshake:
151                                         while (!message.EOF)
152                                         {
153                                                 this.ProcessHandshakeMessage(message);
154                                         }
155
156                                         // Update handshakes of current messages
157                                         this.context.HandshakeMessages.Write(message.ToArray());
158                                         break;
159
160                                 default:
161                                         throw this.context.CreateException("Unknown record received from server.");
162                         }
163
164                         return result;
165                 }
166
167                 private short readShort()
168                 {
169                         byte[] b = new byte[2];
170                         this.innerStream.Read(b, 0, b.Length);
171
172                         short val = BitConverter.ToInt16(b, 0);
173
174                         return System.Net.IPAddress.HostToNetworkOrder(val);
175                 }
176
177                 private void processAlert(
178                         TlsAlertLevel           alertLevel, 
179                         TlsAlertDescription alertDesc)
180                 {
181                         switch (alertLevel)
182                         {
183                                 case TlsAlertLevel.Fatal:
184                                         throw this.context.CreateException(alertLevel, alertDesc);                                      
185
186                                 case TlsAlertLevel.Warning:
187                                 default:
188                                 switch (alertDesc)
189                                 {
190                                         case TlsAlertDescription.CloseNotify:
191                                                 this.context.ConnectionEnd = true;
192                                                 break;
193                                 }
194                                 break;
195                         }
196                 }
197
198                 #endregion
199
200                 #region Send Alert Methods
201
202                 public void SendAlert(TlsAlertDescription description)
203                 {
204                         this.SendAlert(new TlsAlert(this.Context, description));
205                 }
206
207                 public void SendAlert(
208                         TlsAlertLevel           level, 
209                         TlsAlertDescription description)
210                 {
211                         this.SendAlert(new TlsAlert(this.Context, level, description));
212                 }
213
214                 public void SendAlert(TlsAlert alert)
215                 {                       
216                         // Write record
217                         this.SendRecord(TlsContentType.Alert, alert.ToArray());
218
219                         // Update session
220                         alert.Update();
221
222                         // Reset message contents
223                         alert.Reset();
224                 }
225
226                 #endregion
227
228                 #region Send Record Methods
229
230                 public void SendChangeCipherSpec()
231                 {
232                         // Send Change Cipher Spec message
233                         this.SendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});
234
235                         // Reset sequence numbers
236                         this.context.WriteSequenceNumber = 0;
237
238                         // Make the pending state to be the current state
239                         this.context.IsActual = true;
240
241                         // Send Finished message
242                         this.SendRecord(TlsHandshakeType.Finished);                     
243                 }
244
245                 public void SendRecord(TlsContentType contentType, byte[] recordData)
246                 {
247                         if (this.context.ConnectionEnd)
248                         {
249                                 throw this.context.CreateException("The session is finished and it's no longer valid.");
250                         }
251
252                         byte[] record = this.EncodeRecord(contentType, recordData);
253
254                         this.innerStream.Write(record, 0, record.Length);
255                 }
256
257                 public byte[] EncodeRecord(TlsContentType contentType, byte[] recordData)
258                 {
259                         return this.EncodeRecord(
260                                 contentType,
261                                 recordData,
262                                 0,
263                                 recordData.Length);
264                 }
265
266                 public byte[] EncodeRecord(
267                         TlsContentType  contentType, 
268                         byte[]                  recordData,
269                         int                             offset,
270                         int                             count)
271                 {
272                         if (this.context.ConnectionEnd)
273                         {
274                                 throw this.context.CreateException("The session is finished and it's no longer valid.");
275                         }
276
277                         TlsStream record = new TlsStream();
278
279                         int     position = offset;
280
281                         while (position < ( offset + count ))
282                         {
283                                 short   fragmentLength = 0;
284                                 byte[]  fragment;
285
286                                 if ((count - position) > Context.MAX_FRAGMENT_SIZE)
287                                 {
288                                         fragmentLength = Context.MAX_FRAGMENT_SIZE;
289                                 }
290                                 else
291                                 {
292                                         fragmentLength = (short)(count - position);
293                                 }
294
295                                 // Fill the fragment data
296                                 fragment = new byte[fragmentLength];
297                                 Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);
298
299                                 if (this.context.IsActual)
300                                 {
301                                         // Encrypt fragment
302                                         fragment = this.encryptRecordFragment(contentType, fragment);
303                                 }
304
305                                 // Write tls message
306                                 record.Write((byte)contentType);
307                                 record.Write(this.context.Protocol);
308                                 record.Write((short)fragment.Length);
309                                 record.Write(fragment);
310
311                                 // Update buffer position
312                                 position += fragmentLength;
313                         }
314
315                         return record.ToArray();
316                 }
317                 
318                 #endregion
319
320                 #region Cryptography Methods
321
322                 private byte[] encryptRecordFragment(
323                         TlsContentType  contentType, 
324                         byte[]                  fragment)
325                 {
326                         // Calculate message MAC
327                         byte[] mac      = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);
328
329                         // Encrypt the message
330                         byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);
331
332                         // Set new IV
333                         if (this.context.Cipher.CipherMode == CipherMode.CBC)
334                         {
335                                 byte[] iv = new byte[this.context.Cipher.IvSize];
336                                 System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
337                                 this.context.Cipher.UpdateClientCipherIV(iv);
338                         }
339
340                         // Update sequence number
341                         this.context.WriteSequenceNumber++;
342
343                         return ecr;
344                 }
345
346                 private TlsStream decryptRecordFragment(
347                         TlsContentType  contentType, 
348                         byte[]                  fragment)
349                 {
350                         byte[]  dcrFragment     = null;
351                         byte[]  dcrMAC          = null;
352
353                         // Decrypt message
354                         this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);
355
356                         // Set new IV
357                         if (this.context.Cipher.CipherMode == CipherMode.CBC)
358                         {
359                                 byte[] iv = new byte[this.context.Cipher.IvSize];
360                                 System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
361                                 this.context.Cipher.UpdateServerCipherIV(iv);
362                         }
363                         
364                         // Check MAC code
365                         byte[] mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);
366
367                         // Check that the mac is correct
368                         if (mac.Length != dcrMAC.Length)
369                         {
370                                 throw new TlsException("Invalid MAC received from server.");
371                         }
372                         for (int i = 0; i < mac.Length; i++)
373                         {
374                                 if (mac[i] != dcrMAC[i])
375                                 {
376                                         throw new TlsException("Invalid MAC received from server.");
377                                 }
378                         }
379
380                         // Update sequence number
381                         this.context.ReadSequenceNumber++;
382
383                         return new TlsStream(dcrFragment);
384                 }
385
386                 #endregion
387         }
388 }