2004-02-23 Carlos Guzman Alvarez <carlosga@telefonica.net>
[mono.git] / mcs / class / Mono.Security / Mono.Security.Protocol.Tls / RecordProtocol.cs
1 /* Transport Layer Security (TLS)
2  * Copyright (c) 2003-2004 Carlos Guzman Alvarez
3  * 
4  * Permission is hereby granted, free of charge, to any person 
5  * obtaining a copy of this software and associated documentation 
6  * files (the "Software"), to deal in the Software without restriction, 
7  * including without limitation the rights to use, copy, modify, merge, 
8  * publish, distribute, sublicense, and/or sell copies of the Software, 
9  * and to permit persons to whom the Software is furnished to do so, 
10  * subject to the following conditions:
11  * 
12  * The above copyright notice and this permission notice shall be included 
13  * in all copies or substantial portions of the Software.
14  * 
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
16  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
17  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
18  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
19  * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
20  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
22  * DEALINGS IN THE SOFTWARE.
23  */
24
25 using System;
26 using System.IO;
27 using System.Security.Cryptography;
28 using System.Security.Cryptography.X509Certificates;
29
30 using Mono.Security.Protocol.Tls.Alerts;
31 using Mono.Security.Protocol.Tls.Handshake;
32
namespace Mono.Security.Protocol.Tls
{
    /// <summary>
    /// Base class that reads and writes TLS/SSL records over an underlying
    /// transport stream. Derived classes supply handshake message processing
    /// and the construction of handshake records.
    /// </summary>
    internal abstract class RecordProtocol
    {
        #region Constants

        // Upper bound for the length field of a received record. RFC 2246
        // (TLS 1.0) section 6.2.3 limits TLSCiphertext.length to 2^14 + 2048
        // bytes; SSL 3.0 uses the same bound.
        private const int MAX_CIPHERTEXT_LENGTH = 16384 + 2048;

        #endregion

        #region Fields

        protected Stream        innerStream;
        protected Context       context;

        #endregion

        #region Properties

        /// <summary>Transport stream that records are read from and written to.</summary>
        public Stream InnerStream
        {
            get { return this.innerStream; }
            set { this.innerStream = value; }
        }

        /// <summary>
        /// Connection state used by the record layer (protocol version,
        /// cipher, read/write sequence numbers, handshake message log).
        /// </summary>
        public Context Context
        {
            get { return this.context; }
            set { this.context = value; }
        }

        #endregion

        #region Constructors

        /// <summary>Initializes the record layer over the given stream and context.</summary>
        /// <param name="innerStream">Transport stream carrying the records.</param>
        /// <param name="context">Connection state shared with the handshake layer.</param>
        // Protected: the class is abstract, so only derived classes can invoke it.
        protected RecordProtocol(Stream innerStream, Context context)
        {
            this.innerStream    = innerStream;
            this.context        = context;
        }

        #endregion

        #region Abstract Methods

        /// <summary>Builds and sends a handshake record of the given type.</summary>
        public abstract void SendRecord(TlsHandshakeType type);

        /// <summary>Consumes one handshake message from <paramref name="handMsg"/>.</summary>
        protected abstract void ProcessHandshakeMessage(TlsStream handMsg);

        #endregion

        #region Receive Record Methods

        /// <summary>
        /// Reads one record from <see cref="InnerStream"/>, decrypts it when a
        /// cipher is active, and dispatches it by content type: alerts are
        /// processed, ChangeCipherSpec resets the read sequence number, and
        /// handshake messages are passed to <see cref="ProcessHandshakeMessage"/>
        /// and appended to the context's handshake message log.
        /// </summary>
        /// <returns>
        /// The record payload (decrypted if applicable), or <c>null</c> when the
        /// stream ends before a new record starts.
        /// </returns>
        /// <exception cref="Exception">
        /// Created by <c>Context.CreateException</c> when the session is finished,
        /// the stream ends inside a record, the record is too long, the protocol
        /// version is wrong, or the content type is unknown.
        /// </exception>
        public byte[] ReceiveRecord()
        {
            if (this.context.ConnectionEnd)
            {
                throw this.context.CreateException("The session is finished and it's no longer valid.");
            }

            // Content type byte; -1 means the stream ended cleanly between records.
            int type = this.innerStream.ReadByte();
            if (type == -1)
            {
                return null;
            }

            TlsContentType  contentType = (TlsContentType)type;
            short           protocol    = this.readShort();

            // The length field is an unsigned 16-bit value; reading it as a signed
            // short would turn lengths >= 0x8000 into negative array sizes.
            int length = this.readUInt16();
            if (length > MAX_CIPHERTEXT_LENGTH)
            {
                throw this.context.CreateException("Record length exceeds the maximum allowed by the protocol.");
            }

            // Read the record body, failing (instead of spinning) on premature EOF.
            byte[] buffer = new byte[length];
            this.readFully(buffer, 0, length);

            TlsStream message = new TlsStream(buffer);

            // Check that the message has a valid protocol version
            if (protocol != this.context.Protocol && this.context.ProtocolNegotiated)
            {
                throw this.context.CreateException("Invalid protocol version on message received from server");
            }

            // A 2-byte Alert is handled as plaintext (level + description);
            // presumably an encrypted alert is always longer because of the MAC.
            // ChangeCipherSpec is never decrypted.
            bool plaintextAlert = (contentType == TlsContentType.Alert && length == 2);
            if (!plaintextAlert &&
                this.context.IsActual &&
                contentType != TlsContentType.ChangeCipherSpec)
            {
                message = this.decryptRecordFragment(contentType, message.ToArray());
            }

            byte[] result = message.ToArray();

            // Process record
            switch (contentType)
            {
                case TlsContentType.Alert:
                    this.processAlert(
                        (TlsAlertLevel)message.ReadByte(),
                        (TlsAlertDescription)message.ReadByte());
                    break;

                case TlsContentType.ChangeCipherSpec:
                    // Reset sequence numbers
                    this.context.ReadSequenceNumber = 0;
                    break;

                case TlsContentType.ApplicationData:
                    break;

                case TlsContentType.Handshake:
                    // A single record may carry several handshake messages.
                    while (!message.EOF)
                    {
                        this.ProcessHandshakeMessage(message);
                    }

                    // Update handshakes of current messages
                    this.context.HandshakeMessages.Write(message.ToArray());
                    break;

                default:
                    throw this.context.CreateException("Unknown record received from server.");
            }

            return result;
        }

        /// <summary>
        /// Reads exactly <paramref name="count"/> bytes from the inner stream,
        /// looping over partial reads; throws if the stream ends first.
        /// </summary>
        private void readFully(byte[] buffer, int offset, int count)
        {
            while (count > 0)
            {
                int read = this.innerStream.Read(buffer, offset, count);
                if (read <= 0)
                {
                    throw this.context.CreateException("Unexpected end of stream while reading a record.");
                }
                offset  += read;
                count   -= read;
            }
        }

        /// <summary>Reads a big-endian (network order) unsigned 16-bit value.</summary>
        private int readUInt16()
        {
            byte[] b = new byte[2];
            this.readFully(b, 0, b.Length);

            return (b[0] << 8) | b[1];
        }

        /// <summary>Reads a big-endian (network order) 16-bit value as a signed short.</summary>
        private short readShort()
        {
            return unchecked((short)this.readUInt16());
        }

        /// <summary>
        /// Handles a received alert: a fatal alert is thrown as an exception;
        /// any other level only reacts to CloseNotify by marking the connection
        /// as ended.
        /// </summary>
        private void processAlert(
            TlsAlertLevel       alertLevel,
            TlsAlertDescription alertDesc)
        {
            switch (alertLevel)
            {
                case TlsAlertLevel.Fatal:
                    throw this.context.CreateException(alertLevel, alertDesc);

                case TlsAlertLevel.Warning:
                default:
                    switch (alertDesc)
                    {
                        case TlsAlertDescription.CloseNotify:
                            this.context.ConnectionEnd = true;
                            break;
                    }
                    break;
            }
        }

        #endregion

        #region Send Record Methods

        /// <summary>
        /// Sends <paramref name="alert"/> as an Alert record, then calls its
        /// <c>Update</c> and <c>Reset</c> methods.
        /// </summary>
        public void SendAlert(TlsAlert alert)
        {
            // Write record
            this.SendRecord(TlsContentType.Alert, alert.ToArray());

            // Update session
            alert.Update();

            // Reset message contents
            alert.Reset();
        }

        /// <summary>
        /// Sends a ChangeCipherSpec record, resets the write sequence number,
        /// activates the pending cipher state and sends the Finished message.
        /// </summary>
        public void SendChangeCipherSpec()
        {
            // Send Change Cipher Spec message
            this.SendRecord(TlsContentType.ChangeCipherSpec, new byte[] {1});

            // Reset sequence numbers
            this.context.WriteSequenceNumber = 0;

            // Make the pending state to be the current state
            this.context.IsActual = true;

            // Send Finished message
            this.SendRecord(TlsHandshakeType.Finished);
        }

        /// <summary>
        /// Encodes <paramref name="recordData"/> as one or more records of the
        /// given content type and writes them to <see cref="InnerStream"/>.
        /// </summary>
        public void SendRecord(TlsContentType contentType, byte[] recordData)
        {
            if (this.context.ConnectionEnd)
            {
                throw this.context.CreateException("The session is finished and it's no longer valid.");
            }

            byte[] record = this.EncodeRecord(contentType, recordData);

            this.innerStream.Write(record, 0, record.Length);
        }

        /// <summary>Encodes the whole of <paramref name="recordData"/>; see the 4-argument overload.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="recordData"/> is null.</exception>
        public byte[] EncodeRecord(TlsContentType contentType, byte[] recordData)
        {
            if (recordData == null)
            {
                throw new ArgumentNullException("recordData");
            }

            return this.EncodeRecord(
                contentType,
                recordData,
                0,
                recordData.Length);
        }

        /// <summary>
        /// Splits <c>recordData[offset .. offset + count)</c> into fragments of
        /// at most <c>Context.MAX_FRAGMENT_SIZE</c> bytes, encrypts each one when
        /// a cipher is active, and returns the concatenated record bytes
        /// (content type, protocol version, 16-bit length, fragment). An empty
        /// range yields an empty array.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="recordData"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="offset"/> or <paramref name="count"/> lies outside <paramref name="recordData"/>.
        /// </exception>
        public byte[] EncodeRecord(
            TlsContentType  contentType,
            byte[]          recordData,
            int             offset,
            int             count)
        {
            if (this.context.ConnectionEnd)
            {
                throw this.context.CreateException("The session is finished and it's no longer valid.");
            }
            if (recordData == null)
            {
                throw new ArgumentNullException("recordData");
            }
            if (offset < 0 || offset > recordData.Length)
            {
                throw new ArgumentOutOfRangeException("offset");
            }
            if (count < 0 || count > recordData.Length - offset)
            {
                throw new ArgumentOutOfRangeException("count");
            }

            TlsStream record = new TlsStream();

            int end         = offset + count;
            int position    = offset;

            while (position < end)
            {
                // Remaining bytes must be measured from the end of the range,
                // not from 'count', otherwise a non-zero offset breaks framing.
                int remaining       = end - position;
                int fragmentLength  = remaining > Context.MAX_FRAGMENT_SIZE
                    ? Context.MAX_FRAGMENT_SIZE
                    : remaining;

                // Fill the fragment data
                byte[] fragment = new byte[fragmentLength];
                Buffer.BlockCopy(recordData, position, fragment, 0, fragmentLength);

                if (this.context.IsActual)
                {
                    // Encrypt fragment
                    fragment = this.encryptRecordFragment(contentType, fragment);
                }

                // Write tls message
                record.Write((byte)contentType);
                record.Write(this.context.Protocol);
                record.Write((short)fragment.Length);
                record.Write(fragment);

                // Update buffer position
                position += fragmentLength;
            }

            return record.ToArray();
        }

        #endregion

        #region Cryptography Methods

        /// <summary>
        /// Computes the record MAC, encrypts fragment + MAC, carries the last
        /// ciphertext block forward as the next client IV in CBC mode, and
        /// increments the write sequence number.
        /// </summary>
        private byte[] encryptRecordFragment(
            TlsContentType  contentType,
            byte[]          fragment)
        {
            // Calculate message MAC
            byte[] mac = this.context.Cipher.ComputeClientRecordMAC(contentType, fragment);

            // Encrypt the message
            byte[] ecr = this.context.Cipher.EncryptRecord(fragment, mac);

            // Set new IV
            if (this.context.Cipher.CipherMode == CipherMode.CBC)
            {
                byte[] iv = new byte[this.context.Cipher.IvSize];
                System.Array.Copy(ecr, ecr.Length - iv.Length, iv, 0, iv.Length);
                this.context.Cipher.UpdateClientCipherIV(iv);
            }

            // Update sequence number
            this.context.WriteSequenceNumber++;

            return ecr;
        }

        /// <summary>
        /// Decrypts a received fragment, carries the last ciphertext block
        /// forward as the next server IV in CBC mode, verifies the record MAC
        /// and increments the read sequence number.
        /// </summary>
        /// <exception cref="TlsException">The MAC is missing or does not match.</exception>
        private TlsStream decryptRecordFragment(
            TlsContentType  contentType,
            byte[]          fragment)
        {
            byte[]  dcrFragment = null;
            byte[]  dcrMAC      = null;

            // Decrypt message
            this.context.Cipher.DecryptRecord(fragment, ref dcrFragment, ref dcrMAC);

            // Set new IV
            if (this.context.Cipher.CipherMode == CipherMode.CBC)
            {
                byte[] iv = new byte[this.context.Cipher.IvSize];
                System.Array.Copy(fragment, fragment.Length - iv.Length, iv, 0, iv.Length);
                this.context.Cipher.UpdateServerCipherIV(iv);
            }

            // Check MAC code
            byte[] mac = this.context.Cipher.ComputeServerRecordMAC(contentType, dcrFragment);

            if (dcrMAC == null || mac.Length != dcrMAC.Length)
            {
                throw new TlsException("Invalid MAC received from server.");
            }

            // Constant-time comparison: examine every byte so the time taken
            // does not reveal the position of the first mismatch.
            int diff = 0;
            for (int i = 0; i < mac.Length; i++)
            {
                diff |= mac[i] ^ dcrMAC[i];
            }
            if (diff != 0)
            {
                throw new TlsException("Invalid MAC received from server.");
            }

            // Update sequence number
            this.context.ReadSequenceNumber++;

            return new TlsStream(dcrFragment);
        }

        #endregion
    }
}