Merge pull request #963 from kebby/master
[mono.git] / mcs / class / System / System.Net.WebSockets / ClientWebSocket.cs
1 //
2 // ClientWebSocket.cs
3 //
4 // Authors:
5 //        Jérémie Laval <jeremie dot laval at xamarin dot com>
6 //
7 // Copyright 2013 Xamarin Inc (http://www.xamarin.com).
8 //
9 // Lightly inspired by WebSocket4Net, which is distributed under the Apache License 2.0
10 //
11 // Permission is hereby granted, free of charge, to any person obtaining a copy
12 // of this software and associated documentation files (the "Software"), to deal
13 // in the Software without restriction, including without limitation the rights
14 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 // copies of the Software, and to permit persons to whom the Software is
16 // furnished to do so, subject to the following conditions:
17 //
18 // The above copyright notice and this permission notice shall be included in
19 // all copies or substantial portions of the Software.
20 //
21 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 // THE SOFTWARE.
28
29 #if NET_4_5
30
31 using System;
32 using System.Net;
33 using System.Net.Sockets;
34 using System.Security.Principal;
35 using System.Security.Cryptography.X509Certificates;
36 using System.Runtime.CompilerServices;
37 using System.Collections.Generic;
38 using System.Threading;
39 using System.Threading.Tasks;
40 using System.Globalization;
41 using System.Text;
42 using System.Security.Cryptography;
43
44 namespace System.Net.WebSockets
45 {
/// <summary>
/// Client side of an RFC 6455 WebSocket connection, built on top of an upgraded HttpWebRequest connection.
/// </summary>
public class ClientWebSocket : WebSocket, IDisposable
{
	const string Magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
	const string VersionTag = "13";

	// Largest frame header: 2 fixed bytes + 8-byte extended length + 4-byte masking key.
	const int HeaderMaxLength = 14;

	// RFC 6455 5.2 opcodes. These do NOT match the numeric values of WebSocketMessageType.
	const int OpContinuation = 0x0;
	const int OpText = 0x1;
	const int OpBinary = 0x2;
	const int OpClose = 0x8;
	const int OpPing = 0x9;
	const int OpPong = 0xA;

	// Scratch space used by CloseAsync to consume data frames while waiting for the server's Close.
	const int CloseDrainBufferSize = 1024;

	ClientWebSocketOptions options;
	WebSocketState state;
	string subProtocol;
	WebSocketCloseStatus? receivedCloseStatus;
	string receivedCloseDescription;

	HttpWebRequest req;
	WebConnection connection;
	Socket underlyingSocket;

	// Send side. sendLock keeps whole frames from interleaving on the wire when a Pong is
	// sent from the receive path while a caller's SendAsync is running.
	readonly object sendLock = new object ();
	// RFC 6455 5.3 requires unpredictable masking keys, hence a cryptographic generator.
	readonly RandomNumberGenerator maskGenerator = RandomNumberGenerator.Create ();
	byte[] sendBuffer;
	bool sendMessageInProgress;

	// Receive side. These buffers are kept separate from the send path so that one concurrent
	// send and one concurrent receive do not clobber each other.
	readonly byte[] receiveHeader = new byte[HeaderMaxLength];
	readonly byte[] receiveMask = new byte[4];
	int receiveOpcode;
	bool receiveFrameFin;
	bool receiveFrameMasked;
	int receiveMaskIndex;
	long receiveFrameRemaining;
	bool receiveFrameInProgress;
	bool receiveMessageInProgress;
	WebSocketMessageType receiveMessageType;

	public ClientWebSocket ()
	{
		options = new ClientWebSocketOptions ();
		state = WebSocketState.None;
	}

	public override void Dispose ()
	{
		if (connection != null)
			connection.Close (false);
	}

	/// <summary>Tears down the connection immediately, without a closing handshake.</summary>
	public override void Abort ()
	{
		if (state == WebSocketState.Closed || state == WebSocketState.Aborted)
			return;
		state = WebSocketState.Aborted;
		if (connection != null)
			connection.Close (false);
	}

	public ClientWebSocketOptions Options {
		get {
			return options;
		}
	}

	public override WebSocketState State {
		get {
			return state;
		}
	}

	/// <summary>Status code from the server's Close frame, or null if none has been received yet.</summary>
	public override WebSocketCloseStatus? CloseStatus {
		get {
			return receivedCloseStatus;
		}
	}

	/// <summary>Reason text from the server's Close frame, or null if none has been received yet.</summary>
	public override string CloseStatusDescription {
		get {
			return receivedCloseDescription;
		}
	}

	public override string SubProtocol {
		get {
			return subProtocol;
		}
	}

	/// <summary>Connects to a ws:// or wss:// endpoint and performs the RFC 6455 opening handshake.</summary>
	/// <param name="uri">Absolute URI whose scheme is "ws" or "wss".</param>
	/// <param name="cancellationToken">Aborts the underlying HTTP request when signaled.</param>
	/// <exception cref="ArgumentNullException"><paramref name="uri"/> is null.</exception>
	/// <exception cref="ArgumentException"><paramref name="uri"/> is not an absolute ws/wss URI.</exception>
	/// <exception cref="InvalidOperationException">ConnectAsync was already called on this instance.</exception>
	/// <exception cref="WebSocketException">The request failed or the handshake response was invalid.</exception>
	public async Task ConnectAsync (Uri uri, CancellationToken cancellationToken)
	{
		if (uri == null)
			throw new ArgumentNullException ("uri");
		// Previously any non-"wss" scheme (including "https") was silently mapped to plain http.
		if (!uri.IsAbsoluteUri || (uri.Scheme != "ws" && uri.Scheme != "wss"))
			throw new ArgumentException ("Only Uris starting with 'ws://' or 'wss://' are supported.", "uri");
		if (state != WebSocketState.None)
			throw new InvalidOperationException ("The WebSocket has already been started.");

		state = WebSocketState.Connecting;
		try {
			await PerformHandshake (uri, cancellationToken).ConfigureAwait (false);
		} catch {
			// A failed handshake leaves the instance unusable; don't leave it stuck in Connecting.
			state = WebSocketState.Closed;
			throw;
		}
		state = WebSocketState.Open;
	}

	// Sends the HTTP upgrade request and validates the server's 101 response, then records the connection.
	async Task PerformHandshake (Uri uri, CancellationToken cancellationToken)
	{
		var httpUri = new UriBuilder (uri);
		httpUri.Scheme = uri.Scheme == "wss" ? "https" : "http";
		req = (HttpWebRequest)WebRequest.Create (httpUri.Uri);
		req.ReuseConnection = true;
		if (options.Cookies != null)
			req.CookieContainer = options.Cookies;

		foreach (var header in options.CustomRequestHeaders)
			req.Headers[header.Key] = header.Value;

		// RFC 6455 4.1: the key is the base64 encoding of a randomly chosen 16-byte nonce.
		var secKey = Convert.ToBase64String (Guid.NewGuid ().ToByteArray ());
		string expectedAccept;
		using (var sha1 = SHA1.Create ())
			expectedAccept = Convert.ToBase64String (sha1.ComputeHash (Encoding.ASCII.GetBytes (secKey + Magic)));

		req.Headers["Upgrade"] = "WebSocket";
		req.Headers["Sec-WebSocket-Version"] = VersionTag;
		req.Headers["Sec-WebSocket-Key"] = secKey;
		req.Headers["Sec-WebSocket-Origin"] = uri.Host;
		if (options.SubProtocols.Count > 0)
			req.Headers["Sec-WebSocket-Protocol"] = string.Join (",", options.SubProtocols);

		if (options.Credentials != null)
			req.Credentials = options.Credentials;
		if (options.ClientCertificates != null)
			req.ClientCertificates = options.ClientCertificates;
		if (options.Proxy != null)
			req.Proxy = options.Proxy;
		req.UseDefaultCredentials = options.UseDefaultCredentials;
		req.Connection = "Upgrade";

		HttpWebResponse resp;
		var request = req;
		// GetResponseAsync takes no token, so cancellation is implemented by aborting the request.
		using (cancellationToken.Register (() => request.Abort ())) {
			try {
				resp = (HttpWebResponse)(await request.GetResponseAsync ().ConfigureAwait (false));
			} catch (Exception e) {
				cancellationToken.ThrowIfCancellationRequested ();
				throw new WebSocketException (WebSocketError.Faulted, e);
			}
		}

		connection = req.StoredConnection;
		underlyingSocket = connection.socket;

		if (resp.StatusCode != HttpStatusCode.SwitchingProtocols)
			throw new WebSocketException ("The server returned status code '" + (int)resp.StatusCode + "' when status code '101' was expected");
		if (!string.Equals (resp.Headers["Upgrade"], "WebSocket", StringComparison.OrdinalIgnoreCase)
		    || !string.Equals (resp.Headers["Connection"], "Upgrade", StringComparison.OrdinalIgnoreCase)
		    || !string.Equals (resp.Headers["Sec-WebSocket-Accept"], expectedAccept, StringComparison.Ordinal))
			throw new WebSocketException ("HTTP header error during handshake");

		var acceptedProtocol = resp.Headers["Sec-WebSocket-Protocol"];
		if (acceptedProtocol != null) {
			if (!options.SubProtocols.Contains (acceptedProtocol))
				throw new WebSocketException (WebSocketError.UnsupportedProtocol);
			subProtocol = acceptedProtocol;
		}
	}

	/// <summary>Sends one frame. Consecutive calls with endOfMessage=false form one fragmented message.</summary>
	/// <remarks>
	/// NOTE(review): frames are written directly to the raw socket, while reads go through
	/// connection.Read. If connection.Read reads from a TLS stream, wss:// sends bypass
	/// encryption. Confirm this and route writes through the connection's stream.
	/// </remarks>
	public override Task SendAsync (ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
	{
		EnsureWebSocketConnected ();
		ValidateArraySegment (buffer);
		if (connection == null)
			throw new WebSocketException (WebSocketError.Faulted);
		var required = Math.Max (options.SendBufferSize, buffer.Count) + HeaderMaxLength;
		if (sendBuffer == null || sendBuffer.Length < required)
			sendBuffer = new byte[required];
		// Capture now so a resize by a later call cannot swap the buffer out from under this send.
		var frame = sendBuffer;
		return Task.Run (() => {
			EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseReceived);
			int opcode;
			var fin = endOfMessage;
			if (messageType == WebSocketMessageType.Close) {
				opcode = OpClose;
				// Control frames must never be fragmented (RFC 6455 5.5).
				fin = true;
			} else {
				// Only the first frame of a message carries the text/binary opcode; the rest are continuations.
				opcode = sendMessageInProgress
					? OpContinuation
					: (messageType == WebSocketMessageType.Text ? OpText : OpBinary);
				sendMessageInProgress = !endOfMessage;
			}
			var frameLength = WriteFrame (frame, opcode, fin, buffer);
			SendFrame (frame, frameLength);
		}, cancellationToken);
	}

	/// <summary>
	/// Receives payload bytes from the current data message into <paramref name="buffer"/>.
	/// Ping frames are answered, Pong frames are discarded, and a Close frame ends the exchange.
	/// </summary>
	/// <returns>
	/// Result whose Count is the number of bytes copied into the buffer. EndOfMessage is true only
	/// once the final fragment has been fully consumed. For a Close frame, Count is 0 and the
	/// close status and description are populated.
	/// </returns>
	public override Task<WebSocketReceiveResult> ReceiveAsync (ArraySegment<byte> buffer, CancellationToken cancellationToken)
	{
		EnsureWebSocketConnected ();
		ValidateArraySegment (buffer);
		return Task.Run (() => {
			EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent);

			// Advance to a data frame with unread payload, servicing control frames on the way.
			while (!receiveFrameInProgress) {
				ReadFrameHeader ();
				switch (receiveOpcode) {
				case OpContinuation:
					if (!receiveMessageInProgress)
						throw new WebSocketException ("Received a continuation frame outside of a fragmented message");
					receiveFrameInProgress = true;
					break;
				case OpText:
				case OpBinary:
					if (receiveMessageInProgress)
						throw new WebSocketException ("Received a new data frame before the previous message was completed");
					receiveMessageType = receiveOpcode == OpText ? WebSocketMessageType.Text : WebSocketMessageType.Binary;
					receiveMessageInProgress = true;
					receiveFrameInProgress = true;
					break;
				case OpClose:
					return ReceiveCloseFrame ();
				case OpPing:
					AnswerPing ();
					break;
				case OpPong:
					// Unsolicited pongs carry nothing for the caller; consume and drop them.
					ReadControlPayload ();
					break;
				default:
					throw new WebSocketException (WebSocketError.InvalidMessageType);
				}
			}

			// Copy as much of the current frame as fits; the rest stays pending for the next call.
			var count = (int)Math.Min (buffer.Count, receiveFrameRemaining);
			ReadExactly (buffer.Array, buffer.Offset, count);
			UnmaskPayload (buffer.Array, buffer.Offset, count);
			receiveFrameRemaining -= count;

			var endOfMessage = false;
			if (receiveFrameRemaining == 0) {
				receiveFrameInProgress = false;
				if (receiveFrameFin) {
					receiveMessageInProgress = false;
					endOfMessage = true;
				}
			}
			return new WebSocketReceiveResult (count, receiveMessageType, endOfMessage);
		}, cancellationToken);
	}

	// The difference between the two close methods: CloseAsync waits for the server's
	// acknowledgement before completing, while CloseOutputAsync only sends the Close frame.

	public async override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
	{
		EnsureWebSocketConnected ();
		await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
		if (state == WebSocketState.CloseReceived) {
			// The server started the close; our Close frame completes the handshake.
			state = WebSocketState.Closed;
			return;
		}
		state = WebSocketState.CloseSent;
		// Consume any data still in flight until the server's Close arrives. ReceiveAsync moves the
		// state to Closed when it sees that frame. The buffer must be non-empty so that pending
		// data frames are actually drained rather than returned with zero bytes forever.
		// TODO: figure out which exceptions should surface if the server responds with something faulty here
		var drain = new ArraySegment<byte> (new byte[CloseDrainBufferSize]);
		while (state == WebSocketState.CloseSent)
			await ReceiveAsync (drain, cancellationToken).ConfigureAwait (false);
	}

	public async override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
	{
		EnsureWebSocketConnected ();
		await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
		state = state == WebSocketState.CloseReceived ? WebSocketState.Closed : WebSocketState.CloseSent;
	}

	// Builds the Close payload (2-byte big-endian status + UTF-8 reason) and sends it as a Close frame.
	async Task SendCloseFrame (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
	{
		var descriptionLength = string.IsNullOrEmpty (statusDescription) ? 0 : Encoding.UTF8.GetByteCount (statusDescription);
		// Control payloads are capped at 125 bytes, two of which hold the status code.
		if (descriptionLength > 123)
			throw new ArgumentException ("The close status description must not exceed 123 UTF-8 bytes.", "statusDescription");
		var payload = new byte[2 + descriptionLength];
		payload[0] = (byte)((ushort)closeStatus >> 8);
		payload[1] = (byte)((ushort)closeStatus & 0xFF);
		if (descriptionLength > 0)
			Encoding.UTF8.GetBytes (statusDescription, 0, statusDescription.Length, payload, 2);
		await SendAsync (new ArraySegment<byte> (payload), WebSocketMessageType.Close, true, cancellationToken).ConfigureAwait (false);
	}

	// Writes one client-to-server frame into target: header, 4-byte masking key, then masked
	// payload. Returns the number of bytes written. target must hold payload.Count + HeaderMaxLength bytes.
	int WriteFrame (byte[] target, int opcode, bool fin, ArraySegment<byte> payload)
	{
		var length = payload.Count;
		target[0] = (byte)((fin ? 0x80 : 0) | opcode);

		// The 0x80 bit in byte 1 is the MASK flag; clients must mask every frame (RFC 6455 5.3).
		int position;
		if (length < 126) {
			target[1] = (byte)(0x80 | length);
			position = 2;
		} else if (length <= ushort.MaxValue) {
			target[1] = 0x80 | 126;
			target[2] = (byte)(length >> 8);
			target[3] = (byte)length;
			position = 4;
		} else {
			target[1] = 0x80 | 127;
			long remaining = length;
			for (var index = 9; index >= 2; index--) {
				target[index] = (byte)(remaining & 0xFF);
				remaining >>= 8;
			}
			position = 10;
		}

		var key = NextMaskingKey ();
		Buffer.BlockCopy (key, 0, target, position, 4);
		position += 4;

		for (var i = 0; i < length; i++)
			target[position + i] = (byte)(payload.Array[payload.Offset + i] ^ key[i & 3]);
		return position + length;
	}

	// Returns a fresh 4-byte masking key from the cryptographic generator.
	byte[] NextMaskingKey ()
	{
		var key = new byte[4];
		lock (maskGenerator)
			maskGenerator.GetBytes (key);
		return key;
	}

	// Writes the first length bytes of frame to the socket, looping until all of them are sent.
	void SendFrame (byte[] frame, int length)
	{
		lock (sendLock) {
			var sent = 0;
			while (sent < length) {
				var written = underlyingSocket.Send (frame, sent, length - sent, SocketFlags.None);
				if (written <= 0)
					throw new WebSocketException (WebSocketError.ConnectionClosedPrematurely);
				sent += written;
			}
		}
	}

	// Fills target[offset .. offset + count) from the connection, retrying on short reads.
	// NOTE(review): assumes WebConnection.Read returns the number of bytes read (0 at end of
	// stream), like Stream.Read — confirm.
	void ReadExactly (byte[] target, int offset, int count)
	{
		while (count > 0) {
			var read = connection.Read (req, target, offset, count);
			if (read <= 0)
				throw new WebSocketException (WebSocketError.ConnectionClosedPrematurely);
			offset += read;
			count -= read;
		}
	}

	// Reads the next frame header (and masking key, if present) into the receive-side fields.
	void ReadFrameHeader ()
	{
		ReadExactly (receiveHeader, 0, 2);
		receiveFrameFin = (receiveHeader[0] & 0x80) != 0;
		receiveOpcode = receiveHeader[0] & 0x0F;
		receiveFrameMasked = (receiveHeader[1] & 0x80) != 0;

		long length = receiveHeader[1] & 0x7F;
		if (length == 126) {
			ReadExactly (receiveHeader, 2, 2);
			length = (receiveHeader[2] << 8) | receiveHeader[3];
		} else if (length == 127) {
			ReadExactly (receiveHeader, 2, 8);
			length = 0;
			for (var i = 2; i < 10; i++)
				length = (length << 8) | receiveHeader[i];
		}
		// The 64-bit length's most significant bit must be zero (RFC 6455 5.2).
		if (length < 0)
			throw new WebSocketException ("Invalid frame length");
		// Control frames (opcode high bit set) must be short and unfragmented (RFC 6455 5.5).
		if ((receiveOpcode & 0x8) != 0 && (length > 125 || !receiveFrameFin))
			throw new WebSocketException ("Invalid control frame");

		if (receiveFrameMasked)
			ReadExactly (receiveMask, 0, 4);
		receiveMaskIndex = 0;
		receiveFrameRemaining = length;
	}

	// Reverses the server's masking (if any) in place. The key position carries over between
	// calls, so a frame can be unmasked in several pieces.
	void UnmaskPayload (byte[] data, int offset, int count)
	{
		if (!receiveFrameMasked)
			return;
		for (var i = 0; i < count; i++)
			data[offset + i] ^= receiveMask[(receiveMaskIndex + i) & 3];
		receiveMaskIndex = (receiveMaskIndex + count) & 3;
	}

	// Reads and unmasks the whole payload of the current control frame.
	byte[] ReadControlPayload ()
	{
		// ReadFrameHeader caps control payloads at 125 bytes, so the cast cannot overflow.
		var payload = new byte[(int)receiveFrameRemaining];
		ReadExactly (payload, 0, payload.Length);
		UnmaskPayload (payload, 0, payload.Length);
		receiveFrameRemaining = 0;
		return payload;
	}

	// Records the server's close status and advances the closing handshake state.
	WebSocketReceiveResult ReceiveCloseFrame ()
	{
		var payload = ReadControlPayload ();
		// A Close body is optional; without one the status is reported as Empty (1005).
		var status = WebSocketCloseStatus.Empty;
		var description = string.Empty;
		if (payload.Length >= 2) {
			status = (WebSocketCloseStatus)((payload[0] << 8) | payload[1]);
			description = Encoding.UTF8.GetString (payload, 2, payload.Length - 2);
		}
		receivedCloseStatus = status;
		receivedCloseDescription = description;
		// If we already sent our Close, the handshake is complete; otherwise we still owe one.
		state = state == WebSocketState.CloseSent ? WebSocketState.Closed : WebSocketState.CloseReceived;
		return new WebSocketReceiveResult (0, WebSocketMessageType.Close, true, status, description);
	}

	// Replies to a Ping with a Pong echoing its payload (RFC 6455 5.5.2), while the socket is still open.
	void AnswerPing ()
	{
		var payload = ReadControlPayload ();
		if (state != WebSocketState.Open)
			return;
		var frame = new byte[payload.Length + HeaderMaxLength];
		var frameLength = WriteFrame (frame, OpPong, true, new ArraySegment<byte> (payload));
		SendFrame (frame, frameLength);
	}

	void EnsureWebSocketConnected ()
	{
		if (state < WebSocketState.Open)
			throw new InvalidOperationException ("The WebSocket is not connected");
	}

	void EnsureWebSocketState (params WebSocketState[] validStates)
	{
		foreach (var validState in validStates)
			if (state == validState)
				return;
		throw new WebSocketException ("The WebSocket is in an invalid state ('" + state + "') for this operation. Valid states are: " + string.Join (", ", validStates));
	}

	void ValidateArraySegment (ArraySegment<byte> segment)
	{
		if (segment.Array == null)
			throw new ArgumentNullException ("buffer.Array");
		if (segment.Offset < 0)
			throw new ArgumentOutOfRangeException ("buffer.Offset");
		if (segment.Offset + segment.Count > segment.Array.Length)
			throw new ArgumentOutOfRangeException ("buffer.Count");
	}
}
353 }
354
355 #endif