2008-11-01 Marek Habersack <mhabersack@novell.com>
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel.Security.Tokens / TlsClientSession.cs
1 //
2 // TlsClientSession.cs
3 //
4 // Author:
5 //      Atsushi Enomoto <atsushi@ximian.com>
6 //
7 // Copyright (C) 2007 Novell, Inc.  http://www.novell.com
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 // 
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 // 
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28 using System;
29 using System.Collections.Generic;
30 using System.IO;
31 using System.Security.Cryptography;
32 using System.Security.Cryptography.X509Certificates;
33 using System.Text;
34 using Mono.Security.Protocol.Tls;
35 using Mono.Security.Protocol.Tls.Handshake;
36 using Mono.Security.Protocol.Tls.Handshake.Client;
37
38 namespace System.ServiceModel.Security.Tokens
39 {
40         internal abstract class TlsSession
41         {
42                 protected abstract Context Context { get; }
43
44                 protected abstract RecordProtocol Protocol { get; }
45
46                 public byte [] MasterSecret {
47                         get { return Context.MasterSecret; }
48                 }
49
50                 public byte [] CreateHash (byte [] key, byte [] seedSrc, string label)
51                 {
52                         byte [] labelBytes = Encoding.UTF8.GetBytes (label);
53                         byte [] seed = new byte [seedSrc.Length + labelBytes.Length];
54                         Array.Copy (seedSrc, seed, seedSrc.Length);
55                         Array.Copy (labelBytes, 0, seed, seedSrc.Length, labelBytes.Length);
56                         return Context.Current.Cipher.Expand ("SHA1", key, seed, 256 / 8);
57                 }
58
59                 public byte [] CreateHashAlt (byte [] key, byte [] seed, string label)
60                 {
61                         return Context.Current.Cipher.PRF (key, label, seed, 256 / 8);
62                 }
63
64                 protected void WriteHandshake (MemoryStream ms)
65                 {
66                         Context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers (SecurityProtocolType.Tls);
67                         ms.WriteByte (0x16); // Handshake
68                         ms.WriteByte (3); // version-major
69                         ms.WriteByte (1); // version-minor
70                 }
71
72                 protected void WriteChangeCipherSpec (MemoryStream ms)
73                 {
74                         ms.WriteByte (0x14); // Handshake
75                         ms.WriteByte (3); // version-major
76                         ms.WriteByte (1); // version-minor
77                         ms.WriteByte (0); // size-upper
78                         ms.WriteByte (1); // size-lower
79                         ms.WriteByte (1); // ChangeCipherSpec content (1 byte)
80                 }
81
82                 protected void ReadHandshake (MemoryStream ms)
83                 {
84                         if (ms.ReadByte () != 0x16)
85                                 throw new Exception ("INTERNAL ERROR: handshake is expected");
86                         Context.ChangeProtocol ((short) (ms.ReadByte () * 0x100 + ms.ReadByte ()));
87                 }
88
89                 protected void ReadChangeCipherSpec (MemoryStream ms)
90                 {
91                         if (ms.ReadByte () != 0x14)
92                                 throw new Exception ("INTERNAL ERROR: ChangeCipherSpec is expected");
93                         Context.ChangeProtocol ((short) (ms.ReadByte () * 0x100 + ms.ReadByte ()));
94                         if (ms.ReadByte () * 0x100 + ms.ReadByte () != 1)
95                                 throw new Exception ("INTERNAL ERROR: unexpected ChangeCipherSpec length");
96                         ms.ReadByte (); // ChangeCipherSpec content (1 byte) ... anything is OK?
97                 }
98
99                 protected byte [] ReadNextOperation (MemoryStream ms, HandshakeType expected)
100                 {
101                         if (ms.ReadByte () != (int) expected)
102                                 throw new Exception ("INTERNAL ERROR: unexpected server response");
103                         int size = ms.ReadByte () * 0x10000 + ms.ReadByte () * 0x100 + ms.ReadByte ();
104                         // FIXME: use correct valid input range
105                         if (size > 0x100000)
106                                 throw new Exception ("rejected massive input size.");
107                         byte [] bytes = new byte [size];
108                         ms.Read (bytes, 0, size);
109                         return bytes;
110                 }
111
112                 protected void WriteOperations (MemoryStream ms, params HandshakeMessage [] msgs)
113                 {
114                         List<byte []> rawbufs = new List<byte []> ();
115                         int total = 0;
116                         for (int i = 0; i < msgs.Length; i++) {
117                                 HandshakeMessage msg = msgs [i];
118                                 msg.Process ();
119                                 rawbufs.Add (msg.EncodeMessage ());
120                                 total += rawbufs [i].Length;
121                                 msg.Update ();
122                         }
123                         // FIXME: split packets when the size exceeded 0x10000 (or so)
124                         ms.WriteByte ((byte) (total / 0x100));
125                         ms.WriteByte ((byte) (total % 0x100));
126                         foreach (byte [] bytes in rawbufs)
127                                 ms.Write (bytes, 0, bytes.Length);
128                 }
129
130                 protected void VerifyEndOfTransmit (MemoryStream ms)
131                 {
132                         if (ms.Position == ms.Length)
133                                 return;
134
135                         /*
136                         byte [] bytes = new byte [ms.Length - ms.Position];
137                         ms.Read (bytes, 0, bytes.Length);
138                         foreach (byte b in bytes)
139                                 Console.Write ("{0:X02} ", b);
140                         Console.WriteLine (" - total {0} bytes remained.", bytes.Length);
141                         */
142
143                         throw new Exception ("INTERNAL ERROR: unexpected server response");
144                 }
145         }
146
147         internal class TlsClientSession : TlsSession
148         {
149                 SslClientStream ssl;
150                 MemoryStream stream;
151                 bool mutual;
152
153                 public TlsClientSession (string host, X509Certificate2 clientCert)
154                 {
155                         stream = new MemoryStream ();
156                         if (clientCert == null)
157                                 ssl = new SslClientStream (stream, host, true, SecurityProtocolType.Tls);
158                         else {
159                                 ssl = new SslClientStream (stream, host, true, SecurityProtocolType.Tls, new X509CertificateCollection (new X509Certificate [] {clientCert}));
160                                 mutual = true;
161                                 ssl.ClientCertSelection += delegate (
162                                         X509CertificateCollection clientCertificates,
163                                 X509Certificate serverCertificate,
164                                 string targetHost,
165                                 X509CertificateCollection serverRequestedCertificates) {
166                                         return clientCertificates [0];
167                                 };
168                         }
169                 }
170
171                 protected override Context Context {
172                         get { return ssl.context; }
173                 }
174
175                 protected override RecordProtocol Protocol {
176                         get { return ssl.protocol; }
177                 }
178
179                 public byte [] ProcessClientHello ()
180                 {
181                         Context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers (Context.SecurityProtocol);
182                         Context.HandshakeState = HandshakeState.Started;
183                         Protocol.SendRecord (HandshakeType.ClientHello);
184                         stream.Flush ();
185                         return stream.ToArray ();
186                 }
187
188                 // ServerHello, ServerCertificate and ServerHelloDone
189                 public void ProcessServerHello (byte [] raw)
190                 {
191                         stream.SetLength (0);
192                         stream.Write (raw, 0, raw.Length);
193                         stream.Seek (0, SeekOrigin.Begin);
194
195                         Protocol.ReceiveRecord (stream); // ServerHello
196                         Protocol.ReceiveRecord (stream); // ServerCertificate
197                         if (mutual)
198                                 Protocol.ReceiveRecord (stream); // CertificateRequest
199                         Protocol.ReceiveRecord (stream); // ServerHelloDone
200                         if (stream.Position != stream.Length)
201                                 throw new SecurityNegotiationException (String.Format ("Unexpected SSL negotiation binary: {0} bytes of excess in {1} bytes of the octets", stream.Length - stream.Position, stream.Length));
202                 }
203
204                 public byte [] ProcessClientKeyExchange ()
205                 {
206                         stream.SetLength (0);
207                         if (mutual)
208                                 Protocol.SendRecord (HandshakeType.Certificate);
209                         Protocol.SendRecord (HandshakeType.ClientKeyExchange);
210                         Context.Negotiating.Cipher.ComputeKeys ();
211                         Context.Negotiating.Cipher.InitializeCipher ();
212                         Protocol.SendChangeCipherSpec ();
213                         Context.SupportedCiphers = CipherSuiteFactory.GetSupportedCiphers (SecurityProtocolType.Tls);
214                         Protocol.SendRecord (HandshakeType.Finished);
215                         stream.Flush ();
216                         return stream.ToArray ();
217                 }
218
219                 public void ProcessServerFinished (byte [] raw)
220                 {
221                         stream.SetLength (0);
222                         stream.Write (raw, 0, raw.Length);
223                         stream.Seek (0, SeekOrigin.Begin);
224
225                         Protocol.ReceiveRecord (stream); // ChangeCipherSpec
226                         Protocol.ReceiveRecord (stream); // ServerFinished
227                 }
228
229                 public byte [] ProcessApplicationData (byte [] raw)
230                 {
231                         stream.SetLength (0);
232                         stream.Write (raw, 0, raw.Length);
233                         stream.Seek (0, SeekOrigin.Begin);
234                         return Protocol.ReceiveRecord (stream); // ApplicationData
235                 }
236         }
237 }