Merge pull request #1345 from mattleibow/websocket-continuation-frame-fix
[mono.git] / mcs / class / System / System.Net.WebSockets / ClientWebSocket.cs
1 //
2 // ClientWebSocket.cs
3 //
4 // Authors:
5 //        Jérémie Laval <jeremie dot laval at xamarin dot com>
6 //
7 // Copyright 2013 Xamarin Inc (http://www.xamarin.com).
8 //
9 // Lightly inspired from WebSocket4Net distributed under the Apache License 2.0
10 //
11 // Permission is hereby granted, free of charge, to any person obtaining a copy
12 // of this software and associated documentation files (the "Software"), to deal
13 // in the Software without restriction, including without limitation the rights
14 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 // copies of the Software, and to permit persons to whom the Software is
16 // furnished to do so, subject to the following conditions:
17 //
18 // The above copyright notice and this permission notice shall be included in
19 // all copies or substantial portions of the Software.
20 //
21 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 // THE SOFTWARE.
28
29
30 using System;
31 using System.Net;
32 using System.Net.Sockets;
33 using System.Security.Principal;
34 using System.Security.Cryptography.X509Certificates;
35 using System.Runtime.CompilerServices;
36 using System.Collections.Generic;
37 using System.Threading;
38 using System.Threading.Tasks;
39 using System.Globalization;
40 using System.Text;
41 using System.Security.Cryptography;
42
namespace System.Net.WebSockets
{
	/// <summary>
	/// Client-side WebSocket. The opening handshake is performed with an
	/// <see cref="HttpWebRequest"/>; once the connection is upgraded, frames are written
	/// directly to the connection's socket and read through the stored connection.
	/// </summary>
	/// <remarks>
	/// Limitations of this implementation:
	/// cancellation tokens are accepted but never observed;
	/// <see cref="Abort"/> is not implemented;
	/// any opcode other than continuation, text or binary (RFC 6455 also defines ping 0x9
	/// and pong 0xA) is reported as <see cref="WebSocketMessageType.Close"/>;
	/// the payload of a masked incoming frame is not unmasked.
	/// </remarks>
	public class ClientWebSocket : WebSocket, IDisposable
	{
		// GUID appended to the key when computing Sec-WebSocket-Accept.
		const string Magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
		const string VersionTag = "13";

		ClientWebSocketOptions options;
		WebSocketState state;
		string subProtocol;

		HttpWebRequest req;
		WebConnection connection;
		Socket underlyingSocket;

		// Source of the handshake nonce and of the per-frame masking keys.
		static readonly RandomNumberGenerator rng = RandomNumberGenerator.Create ();

		// 2 fixed bytes + up to 8 bytes of extended length + 4 bytes of masking key.
		const int HeaderMaxLength = 14;
		// Header of the frame being sent.
		byte[] headerBuffer;
		// Header of the frame being received. Kept separate from headerBuffer so that a
		// send issued between two partial reads of one frame cannot overwrite the
		// FIN bit / opcode that the next ReceiveAsync call looks at.
		byte[] receiveHeaderBuffer;
		byte[] sendBuffer;
		// Payload bytes of the current incoming frame not yet handed to the caller.
		long remaining;
		WebSocketMessageType currentMessageType;
		// True after a data frame was sent with endOfMessage == false, i.e. the next
		// data frame must use the continuation opcode.
		bool sendFragmentInProgress;
		// Status and description carried by the last close frame read from the peer.
		WebSocketCloseStatus? receivedCloseStatus;
		string receivedCloseDescription;

		public ClientWebSocket ()
		{
			options = new ClientWebSocketOptions ();
			state = WebSocketState.None;
			headerBuffer = new byte[HeaderMaxLength];
			receiveHeaderBuffer = new byte[HeaderMaxLength];
		}

		/// <summary>Closes the underlying connection, if one was established.</summary>
		public override void Dispose ()
		{
			if (connection != null)
				connection.Close (false);
		}

		[MonoTODO]
		public override void Abort ()
		{
			throw new NotImplementedException ();
		}

		/// <summary>Options applied to the handshake request by <see cref="ConnectAsync"/>.</summary>
		public ClientWebSocketOptions Options {
			get {
				return options;
			}
		}

		public override WebSocketState State {
			get {
				return state;
			}
		}

		/// <summary>
		/// <c>null</c> until the socket is closed; afterwards the status carried by the
		/// peer's close frame, or <see cref="WebSocketCloseStatus.Empty"/> when none was recorded.
		/// </summary>
		public override WebSocketCloseStatus? CloseStatus {
			get {
				if (state != WebSocketState.Closed)
					return (WebSocketCloseStatus?)null;
				return receivedCloseStatus ?? WebSocketCloseStatus.Empty;
			}
		}

		/// <summary>
		/// Description carried by the peer's close frame, or <c>null</c> when no close
		/// frame has been read.
		/// </summary>
		public override string CloseStatusDescription {
			get {
				return receivedCloseDescription;
			}
		}

		/// <summary>Sub-protocol selected by the server during the handshake, if any.</summary>
		public override string SubProtocol {
			get {
				return subProtocol;
			}
		}

		/// <summary>
		/// Performs the HTTP upgrade handshake against <paramref name="uri"/> and, on success,
		/// moves the socket to <see cref="WebSocketState.Open"/>.
		/// </summary>
		/// <param name="uri">Endpoint; the "wss" scheme is mapped to https, anything else to http.</param>
		/// <param name="cancellationToken">Accepted for API compatibility; not observed.</param>
		/// <exception cref="ArgumentNullException"><paramref name="uri"/> is null.</exception>
		/// <exception cref="WebSocketException">
		/// The request failed, the server did not answer 101, or the handshake headers
		/// (Upgrade, Connection, Sec-WebSocket-Accept, Sec-WebSocket-Protocol) were not acceptable.
		/// </exception>
		public async Task ConnectAsync (Uri uri, CancellationToken cancellationToken)
		{
			if (uri == null)
				throw new ArgumentNullException ("uri");

			state = WebSocketState.Connecting;
			var httpUri = new UriBuilder (uri);
			httpUri.Scheme = uri.Scheme == "wss" ? "https" : "http";
			req = (HttpWebRequest)WebRequest.Create (httpUri.Uri);
			req.ReuseConnection = true;
			if (options.Cookies != null)
				req.CookieContainer = options.Cookies;

			foreach (var header in options.CustomRequestHeaders)
				req.Headers[header.Key] = header.Value;

			// RFC 6455 4.1: the key is a randomly selected 16-byte nonce, base64-encoded.
			var nonce = new byte[16];
			rng.GetBytes (nonce);
			var secKey = Convert.ToBase64String (nonce);
			string expectedAccept;
			using (var sha1 = SHA1.Create ())
				expectedAccept = Convert.ToBase64String (sha1.ComputeHash (Encoding.ASCII.GetBytes (secKey + Magic)));

			req.Headers["Upgrade"] = "WebSocket";
			req.Headers["Sec-WebSocket-Version"] = VersionTag;
			req.Headers["Sec-WebSocket-Key"] = secKey;
			req.Headers["Sec-WebSocket-Origin"] = uri.Host;
			if (options.SubProtocols.Count > 0)
				req.Headers["Sec-WebSocket-Protocol"] = string.Join (",", options.SubProtocols);

			if (options.Credentials != null)
				req.Credentials = options.Credentials;
			if (options.ClientCertificates != null)
				req.ClientCertificates = options.ClientCertificates;
			if (options.Proxy != null)
				req.Proxy = options.Proxy;
			req.UseDefaultCredentials = options.UseDefaultCredentials;
			req.Connection = "Upgrade";

			HttpWebResponse resp = null;
			try {
				resp = (HttpWebResponse)(await req.GetResponseAsync ().ConfigureAwait (false));
			} catch (Exception e) {
				throw new WebSocketException (WebSocketError.Success, e);
			}

			// Keep hold of the connection used for the handshake: all frames travel over it.
			connection = req.StoredConnection;
			underlyingSocket = connection.socket;

			if (resp.StatusCode != HttpStatusCode.SwitchingProtocols)
				throw new WebSocketException ("The server returned status code '" + (int)resp.StatusCode + "' when status code '101' was expected");
			if (!string.Equals (resp.Headers["Upgrade"], "WebSocket", StringComparison.OrdinalIgnoreCase)
				|| !string.Equals (resp.Headers["Connection"], "Upgrade", StringComparison.OrdinalIgnoreCase)
				|| !string.Equals (resp.Headers["Sec-WebSocket-Accept"], expectedAccept))
				throw new WebSocketException ("HTTP header error during handshake");
			var selectedProtocol = resp.Headers["Sec-WebSocket-Protocol"];
			if (selectedProtocol != null) {
				if (!options.SubProtocols.Contains (selectedProtocol))
					throw new WebSocketException (WebSocketError.UnsupportedProtocol);
				subProtocol = selectedProtocol;
			}

			state = WebSocketState.Open;
		}

		/// <summary>
		/// Sends <paramref name="buffer"/> as a single masked frame. A message may be split over
		/// several calls by passing <paramref name="endOfMessage"/> = false on all but the last;
		/// the frames after the first are sent with the continuation opcode.
		/// </summary>
		/// <param name="cancellationToken">Accepted for API compatibility; not observed.</param>
		/// <exception cref="InvalidOperationException">The socket was never connected.</exception>
		/// <exception cref="WebSocketException">
		/// There is no connection, or (reported through the returned task) the socket is
		/// neither Open nor CloseReceived.
		/// </exception>
		public override Task SendAsync (ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
		{
			EnsureWebSocketConnected ();
			ValidateArraySegment (buffer);
			if (connection == null)
				throw new WebSocketException (WebSocketError.Faulted);
			// Grow-only: a buffer that is already large enough is reused as is.
			var required = Math.Max (options.SendBufferSize, buffer.Count) + HeaderMaxLength;
			if (sendBuffer == null || sendBuffer.Length < required)
				sendBuffer = new byte[required];
			return Task.Run (() => {
				EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseReceived);
				var maskOffset = WriteHeader (messageType, buffer, endOfMessage);

				if (buffer.Count > 0)
					MaskData (buffer, maskOffset);
				// Header (including the 4-byte masking key) and masked payload go out in one Send.
				var headerLength = maskOffset + 4;
				Array.Copy (headerBuffer, sendBuffer, headerLength);
				underlyingSocket.Send (sendBuffer, 0, buffer.Count + headerLength, SocketFlags.None);
			});
		}

		// Frame opcodes handled by this class (RFC 6455 section 5.2).
		const int messageTypeContinuation = 0;
		const int messageTypeText = 1;
		const int messageTypeBinary = 2;
		const int messageTypeClose = 8;

		/// <summary>
		/// Maps a wire opcode to a message type. A continuation frame keeps the type of the
		/// message it continues; every opcode other than continuation, text or binary is
		/// reported as Close.
		/// </summary>
		WebSocketMessageType WireToMessageType (byte msgType)
		{
			if (msgType == messageTypeContinuation)
				return currentMessageType;
			if (msgType == messageTypeText)
				return WebSocketMessageType.Text;
			if (msgType == messageTypeBinary)
				return WebSocketMessageType.Binary;
			return WebSocketMessageType.Close;
		}

		/// <summary>Maps a message type to the opcode of the first frame of such a message.</summary>
		static byte MessageTypeToWire (WebSocketMessageType type)
		{
			if (type == WebSocketMessageType.Text)
				return messageTypeText;
			if (type == WebSocketMessageType.Binary)
				return messageTypeBinary;
			return messageTypeClose;
		}

		/// <summary>
		/// Reads payload of the current incoming frame into <paramref name="buffer"/>. When the
		/// buffer is smaller than the frame, the rest of the frame is delivered by subsequent
		/// calls. A close frame is consumed entirely, moves the socket to Closed and is reported
		/// with its status and description.
		/// </summary>
		/// <param name="cancellationToken">Accepted for API compatibility; not observed.</param>
		/// <exception cref="InvalidOperationException">The socket was never connected.</exception>
		/// <exception cref="WebSocketException">
		/// Reported through the returned task when the socket is neither Open nor CloseSent,
		/// or when the connection ends in the middle of a frame.
		/// </exception>
		public override Task<WebSocketReceiveResult> ReceiveAsync (ArraySegment<byte> buffer, CancellationToken cancellationToken)
		{
			EnsureWebSocketConnected ();
			ValidateArraySegment (buffer);
			return Task.Run (() => {
				EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent);

				// Either start a new frame or keep draining the one a previous call left unfinished.
				long length = remaining == 0 ? ReadFrameHeader () : remaining;

				// receiveHeaderBuffer[0] still describes the frame being drained.
				bool isLast = (receiveHeaderBuffer[0] >> 7) > 0;
				currentMessageType = WireToMessageType ((byte)(receiveHeaderBuffer[0] & 0xF));

				if (currentMessageType == WebSocketMessageType.Close) {
					state = WebSocketState.Closed;
					var payload = new byte[length];
					ReadFully (payload, 0, payload.Length);
					// RFC 6455 5.5.1: the body of a close frame is optional, so the
					// two status bytes may be absent.
					var closeStatus = payload.Length >= 2
						? (WebSocketCloseStatus)(payload[0] << 8 | payload[1])
						: WebSocketCloseStatus.Empty;
					var closeDesc = payload.Length > 2 ? Encoding.UTF8.GetString (payload, 2, payload.Length - 2) : string.Empty;
					receivedCloseStatus = closeStatus;
					receivedCloseDescription = closeDesc;
					return new WebSocketReceiveResult ((int)length, currentMessageType, isLast, closeStatus, closeDesc);
				}

				var readLength = (int)(buffer.Count < length ? buffer.Count : length);
				ReadFully (buffer.Array, buffer.Offset, readLength);
				remaining = length - readLength;

				// The message only ends once the final frame has been fully delivered.
				return new WebSocketReceiveResult (readLength, currentMessageType, isLast && remaining == 0);
			});
		}

		/// <summary>
		/// Reads the header of the next incoming frame into the receive header buffer and
		/// returns the payload length it announces.
		/// </summary>
		long ReadFrameHeader ()
		{
			// The first two bytes tell how much more header follows.
			ReadFully (receiveHeaderBuffer, 0, 2);
			bool isMasked = (receiveHeaderBuffer[1] >> 7) > 0;
			long length = receiveHeaderBuffer[1] & 0x7F;
			int extendedLengthSize = 0;
			if (length == 126) {
				// 16-bit big-endian length
				extendedLengthSize = 2;
				ReadFully (receiveHeaderBuffer, 2, extendedLengthSize);
				length = (receiveHeaderBuffer[2] << 8) | receiveHeaderBuffer[3];
			} else if (length == 127) {
				// 64-bit big-endian length
				extendedLengthSize = 8;
				ReadFully (receiveHeaderBuffer, 2, extendedLengthSize);
				length = 0;
				for (int i = 2; i <= 9; i++)
					length = (length << 8) | receiveHeaderBuffer[i];
			}

			// The masking key has to be consumed to stay aligned with the payload.
			// It is not applied: the payload of a masked frame is returned as received.
			if (isMasked)
				ReadFully (receiveHeaderBuffer, 2 + extendedLengthSize, 4);

			return length;
		}

		/// <summary>
		/// Reads exactly <paramref name="count"/> bytes from the connection into
		/// <paramref name="target"/>, issuing as many reads as needed.
		/// </summary>
		/// <exception cref="WebSocketException">The connection ended before all bytes arrived.</exception>
		void ReadFully (byte[] target, int offset, int count)
		{
			// NOTE(review): assumes WebConnection.Read behaves like Stream.Read, i.e. returns
			// the number of bytes read (possibly fewer than requested) and 0 at end of stream.
			while (count > 0) {
				int read = connection.Read (req, target, offset, count);
				if (read <= 0)
					throw new WebSocketException (WebSocketError.ConnectionClosedPrematurely);
				offset += read;
				count -= read;
			}
		}

		// CloseAsync waits for the server's acknowledgement before completing, whereas
		// CloseOutputAsync only sends the close frame and completes.

		/// <summary>Sends a close frame, then reads the server's reply and marks the socket Closed.</summary>
		public async override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
		{
			EnsureWebSocketConnected ();
			await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
			state = WebSocketState.CloseSent;
			// TODO: figure out which exceptions are thrown if the server returns something faulty here
			await ReceiveAsync (new ArraySegment<byte> (new byte[0]), cancellationToken).ConfigureAwait (false);
			state = WebSocketState.Closed;
		}

		/// <summary>Sends a close frame and moves the socket to CloseSent without waiting for a reply.</summary>
		public async override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
		{
			EnsureWebSocketConnected ();
			await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
			state = WebSocketState.CloseSent;
		}

		/// <summary>
		/// Builds the close payload (2-byte big-endian status followed by the optional UTF-8
		/// description) and sends it as a final Close frame.
		/// </summary>
		async Task SendCloseFrame (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
		{
			bool hasDescription = !string.IsNullOrEmpty (statusDescription);
			var payload = new byte[hasDescription ? 2 + Encoding.UTF8.GetByteCount (statusDescription) : 2];
			payload[0] = (byte)(((ushort)closeStatus) >> 8);
			payload[1] = (byte)(((ushort)closeStatus) & 0xFF);
			if (hasDescription)
				Encoding.UTF8.GetBytes (statusDescription, 0, statusDescription.Length, payload, 2);
			await SendAsync (new ArraySegment<byte> (payload), WebSocketMessageType.Close, true, cancellationToken).ConfigureAwait (false);
		}

		/// <summary>
		/// Writes the header of an outgoing frame (FIN/opcode, length, masking key) into
		/// <c>headerBuffer</c> and returns the offset at which the 4-byte masking key starts.
		/// </summary>
		int WriteHeader (WebSocketMessageType type, ArraySegment<byte> buffer, bool endOfMessage)
		{
			var opCode = MessageTypeToWire (type);
			var length = buffer.Count;

			// RFC 6455 5.4: only the first frame of a fragmented message carries the
			// text/binary opcode, the following ones use the continuation opcode.
			// Close frames are control frames and do not take part in fragmentation.
			if (type != WebSocketMessageType.Close) {
				if (sendFragmentInProgress)
					opCode = messageTypeContinuation;
				sendFragmentInProgress = !endOfMessage;
			}

			headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0x80 : 0));

			int maskOffset;
			if (length < 126) {
				headerBuffer[1] = (byte)length;
				maskOffset = 2;
			} else if (length <= ushort.MaxValue) {
				// 126 announces a 16-bit big-endian length
				headerBuffer[1] = (byte)126;
				headerBuffer[2] = (byte)(length / 256);
				headerBuffer[3] = (byte)(length % 256);
				maskOffset = 4;
			} else {
				// 127 announces a 64-bit big-endian length
				headerBuffer[1] = (byte)127;

				int left = length;
				for (int i = 9; i > 1; i--) {
					headerBuffer[i] = (byte)(left % 256);
					left = left / 256;
				}
				maskOffset = 10;
			}

			GenerateMask (headerBuffer, maskOffset);

			// Since we are client only, we always mask the payload
			headerBuffer[1] |= 0x80;

			return maskOffset;
		}

		/// <summary>Writes a fresh 4-byte masking key into <paramref name="mask"/> at <paramref name="offset"/>.</summary>
		void GenerateMask (byte[] mask, int offset)
		{
			// RFC 6455 5.3 asks for an unpredictable key drawn from a strong entropy source.
			var key = new byte[4];
			rng.GetBytes (key);
			Array.Copy (key, 0, mask, offset, key.Length);
		}

		/// <summary>
		/// XORs the payload with the masking key stored in <c>headerBuffer</c> at
		/// <paramref name="maskOffset"/> and writes the result into <c>sendBuffer</c>,
		/// right after the space reserved for the header.
		/// </summary>
		void MaskData (ArraySegment<byte> buffer, int maskOffset)
		{
			var sendBufferOffset = maskOffset + 4;
			for (var i = 0; i < buffer.Count; i++)
				sendBuffer[i + sendBufferOffset] = (byte)(buffer.Array[buffer.Offset + i] ^ headerBuffer[maskOffset + (i % 4)]);
		}

		/// <summary>Throws when the socket has not reached the Open state yet.</summary>
		void EnsureWebSocketConnected ()
		{
			if (state < WebSocketState.Open)
				throw new InvalidOperationException ("The WebSocket is not connected");
		}

		/// <summary>Throws unless the current state is one of <paramref name="validStates"/>.</summary>
		void EnsureWebSocketState (params WebSocketState[] validStates)
		{
			foreach (var validState in validStates)
				if (state == validState)
					return;
			throw new WebSocketException ("The WebSocket is in an invalid state ('" + state + "') for this operation. Valid states are: " + string.Join (", ", validStates));
		}

		/// <summary>Checks that the segment has a backing array and lies within it.</summary>
		void ValidateArraySegment (ArraySegment<byte> segment)
		{
			if (segment.Array == null)
				throw new ArgumentNullException ("buffer.Array");
			if (segment.Offset < 0)
				throw new ArgumentOutOfRangeException ("buffer.Offset");
			if (segment.Offset + segment.Count > segment.Array.Length)
				throw new ArgumentOutOfRangeException ("buffer.Count");
		}
	}
}
392