Merge pull request #1345 from mattleibow/websocket-continuation-frame-fix
[mono.git] / mcs / class / System / System.Net.WebSockets / ClientWebSocket.cs
index c7b682eb7a343a766b950f9ccd5444798b5464f5..dba02f2761718c56c9a46d28a8a3a395d5fe8be0 100644 (file)
@@ -26,7 +26,6 @@
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 // THE SOFTWARE.
 
-#if NET_4_5
 
 using System;
 using System.Net;
@@ -45,6 +44,7 @@ namespace System.Net.WebSockets
 {
        public class ClientWebSocket : WebSocket, IDisposable
        {
+               const string Magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
                const string VersionTag = "13";
 
                ClientWebSocketOptions options;
@@ -53,24 +53,33 @@ namespace System.Net.WebSockets
 
                HttpWebRequest req;
                WebConnection connection;
-               StreamWebSocket internalWebSocket;
+               Socket underlyingSocket;
+
+               Random random = new Random ();
+
+               const int HeaderMaxLength = 14;
+               byte[] headerBuffer;
+               byte[] sendBuffer;
+               long remaining;
+               WebSocketMessageType currentMessageType;
 
                public ClientWebSocket ()
                {
                        options = new ClientWebSocketOptions ();
                        state = WebSocketState.None;
+                       headerBuffer = new byte[HeaderMaxLength];
                }
 
                public override void Dispose ()
                {
-                       if (internalWebSocket != null)
-                               internalWebSocket.Dispose ();
+                       if (connection != null)
+                               connection.Close (false);
                }
 
+               [MonoTODO]
                public override void Abort ()
                {
-                       if (internalWebSocket != null)
-                               internalWebSocket.Abort ();
+                       throw new NotImplementedException ();
                }
 
                public ClientWebSocketOptions Options {
@@ -81,29 +90,28 @@ namespace System.Net.WebSockets
 
                public override WebSocketState State {
                        get {
-                               if (internalWebSocket != null)
-                                       return internalWebSocket.State;
                                return state;
                        }
                }
 
                public override WebSocketCloseStatus? CloseStatus {
                        get {
-                               if (internalWebSocket != null)
-                                       return internalWebSocket.CloseStatus;
                                if (state != WebSocketState.Closed)
                                        return (WebSocketCloseStatus?)null;
                                return WebSocketCloseStatus.Empty;
                        }
                }
 
-               [MonoTODO]
                public override string CloseStatusDescription {
-                       get { return null; }
+                       get {
+                               return null;
+                       }
                }
 
                public override string SubProtocol {
-                       get { return subProtocol; }
+                       get {
+                               return subProtocol;
+                       }
                }
 
                public async Task ConnectAsync (Uri uri, CancellationToken cancellationToken)
@@ -125,7 +133,7 @@ namespace System.Net.WebSockets
                        }
 
                        var secKey = Convert.ToBase64String (Encoding.ASCII.GetBytes (Guid.NewGuid ().ToString ().Substring (0, 16)));
-                       string expectedAccept = StreamWebSocket.CreateAcceptKey (secKey);
+                       string expectedAccept = Convert.ToBase64String (SHA1.Create ().ComputeHash (Encoding.ASCII.GetBytes (secKey + Magic)));
 
                        req.Headers["Upgrade"] = "WebSocket";
                        req.Headers["Sec-WebSocket-Version"] = VersionTag;
@@ -150,6 +158,9 @@ namespace System.Net.WebSockets
                                throw new WebSocketException (WebSocketError.Success, e);
                        }
 
+                       connection = req.StoredConnection;
+                       underlyingSocket = connection.socket;
+
                        if (resp.StatusCode != HttpStatusCode.SwitchingProtocols)
                                throw new WebSocketException ("The server returned status code '" + (int)resp.StatusCode + "' when status code '101' was expected");
                        if (!string.Equals (resp.Headers["Upgrade"], "WebSocket", StringComparison.OrdinalIgnoreCase)
@@ -161,34 +172,221 @@ namespace System.Net.WebSockets
                                        throw new WebSocketException (WebSocketError.UnsupportedProtocol);
                                subProtocol = resp.Headers["Sec-WebSocket-Protocol"];
                        }
+
                        state = WebSocketState.Open;
-                       connection = req.StoredConnection;
-                       internalWebSocket = new StreamWebSocket(connection.nstream, connection.nstream, connection.socket, subProtocol, true, new ArraySegment<byte>(new byte[0]));
                }
 
                public override Task SendAsync (ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
                {
-                       return internalWebSocket.SendAsync (buffer, messageType, endOfMessage, cancellationToken);
+                       EnsureWebSocketConnected ();
+                       ValidateArraySegment (buffer);
+                       if (connection == null)
+                               throw new WebSocketException (WebSocketError.Faulted);
+                       var count = Math.Max (options.SendBufferSize, buffer.Count) + HeaderMaxLength;
+                       if (sendBuffer == null || sendBuffer.Length != count)
+                               sendBuffer = new byte[count];
+                       return Task.Run (() => {
+                               EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseReceived);
+                               var maskOffset = WriteHeader (messageType, buffer, endOfMessage);
+
+                               if (buffer.Count > 0)
+                                       MaskData (buffer, maskOffset);
+                               //underlyingSocket.Send (headerBuffer, 0, maskOffset + 4, SocketFlags.None);
+                               var headerLength = maskOffset + 4;
+                               Array.Copy (headerBuffer, sendBuffer, headerLength);
+                               underlyingSocket.Send (sendBuffer, 0, buffer.Count + headerLength, SocketFlags.None);
+                       });
+               }
+               
+               const int messageTypeContinuation = 0;
+               const int messageTypeText = 1;
+               const int messageTypeBinary = 2;
+               const int messageTypeClose = 8;
+
+               WebSocketMessageType WireToMessageType (byte msgType)
+               {
+                       
+                       if (msgType == messageTypeContinuation)
+                               return currentMessageType;
+                       if (msgType == messageTypeText)
+                               return WebSocketMessageType.Text;
+                       if (msgType == messageTypeBinary)
+                               return WebSocketMessageType.Binary;
+                       return WebSocketMessageType.Close;
                }
 
+               static byte MessageTypeToWire (WebSocketMessageType type)
+               {
+                       if (type == WebSocketMessageType.Text)
+                               return messageTypeText;
+                       if (type == WebSocketMessageType.Binary)
+                               return messageTypeBinary;
+                       return messageTypeClose;
+               }
+               
                public override Task<WebSocketReceiveResult> ReceiveAsync (ArraySegment<byte> buffer, CancellationToken cancellationToken)
                {
-                       return internalWebSocket.ReceiveAsync (buffer, cancellationToken);
+                       EnsureWebSocketConnected ();
+                       ValidateArraySegment (buffer);
+                       return Task.Run (() => {
+                               EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent);
+
+                               bool isLast;
+                               long length;
+
+                               if (remaining == 0) {
+                                       // First read the two first bytes to know what we are doing next
+                                       connection.Read (req, headerBuffer, 0, 2);
+                                       isLast = (headerBuffer[0] >> 7) > 0;
+                                       var isMasked = (headerBuffer[1] >> 7) > 0;
+                                       int mask = 0;
+                                       currentMessageType = WireToMessageType ((byte)(headerBuffer[0] & 0xF));
+                                       length = headerBuffer[1] & 0x7F;
+                                       int offset = 0;
+                                       if (length == 126) {
+                                               offset = 2;
+                                               connection.Read (req, headerBuffer, 2, offset);
+                                       length = (headerBuffer[2] << 8) | headerBuffer[3];
+                                       } else if (length == 127) {
+                                               offset = 8;
+                                               connection.Read (req, headerBuffer, 2, offset);
+                                               length = 0;
+                                               for (int i = 2; i <= 9; i++)
+                                                       length = (length << 8) | headerBuffer[i];
+                                       }
+
+                                       if (isMasked) {
+                                               connection.Read (req, headerBuffer, 2 + offset, 4);
+                                               for (int i = 0; i < 4; i++) {
+                                                       var pos = i + offset + 2;
+                                                       mask = (mask << 8) | headerBuffer[pos];
+                                               }
+                                       }
+                               } else {
+                                       isLast = (headerBuffer[0] >> 7) > 0;
+                                       currentMessageType = WireToMessageType ((byte)(headerBuffer[0] & 0xF));
+                                       length = remaining;
+                               }
+
+                               if (currentMessageType == WebSocketMessageType.Close) {
+                                       state = WebSocketState.Closed;
+                                       var tmpBuffer = new byte[length];
+                                       connection.Read (req, tmpBuffer, 0, tmpBuffer.Length);
+                                       var closeStatus = (WebSocketCloseStatus)(tmpBuffer[0] << 8 | tmpBuffer[1]);
+                                       var closeDesc = tmpBuffer.Length > 2 ? Encoding.UTF8.GetString (tmpBuffer, 2, tmpBuffer.Length - 2) : string.Empty;
+                                       return new WebSocketReceiveResult ((int)length, currentMessageType, isLast, closeStatus, closeDesc);
+                               } else {
+                                       var readLength = (int)(buffer.Count < length ? buffer.Count : length);
+                                       connection.Read (req, buffer.Array, buffer.Offset, readLength);
+                                       remaining = length - readLength;
+
+                                       return new WebSocketReceiveResult ((int)readLength, currentMessageType, isLast && remaining == 0);
+                               }
+                       });
                }
 
                // The damn difference between those two methods is that CloseAsync will wait for server acknowledgement before completing
                // while CloseOutputAsync will send the close packet and simply complete.
 
-               public override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+               public async override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+               {
+                       EnsureWebSocketConnected ();
+                       await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
+                       state = WebSocketState.CloseSent;
+                       // TODO: figure what's exceptions are thrown if the server returns something faulty here
+                       await ReceiveAsync (new ArraySegment<byte> (new byte[0]), cancellationToken).ConfigureAwait (false);
+                       state = WebSocketState.Closed;
+               }
+
+               public async override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+               {
+                       EnsureWebSocketConnected ();
+                       await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
+                       state = WebSocketState.CloseSent;
+               }
+
+               async Task SendCloseFrame (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+               {
+                       var statusDescBuffer = string.IsNullOrEmpty (statusDescription) ? new byte[2] : new byte[2 + Encoding.UTF8.GetByteCount (statusDescription)];
+                       statusDescBuffer[0] = (byte)(((ushort)closeStatus) >> 8);
+                       statusDescBuffer[1] = (byte)(((ushort)closeStatus) & 0xFF);
+                       if (!string.IsNullOrEmpty (statusDescription))
+                               Encoding.UTF8.GetBytes (statusDescription, 0, statusDescription.Length, statusDescBuffer, 2);
+                       await SendAsync (new ArraySegment<byte> (statusDescBuffer), WebSocketMessageType.Close, true, cancellationToken).ConfigureAwait (false);
+               }
+
+               int WriteHeader (WebSocketMessageType type, ArraySegment<byte> buffer, bool endOfMessage)
+               {
+                       var opCode = MessageTypeToWire (type);
+                       var length = buffer.Count;
+
+                       headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0x80 : 0));
+                       if (length < 126) {
+                               headerBuffer[1] = (byte)length;
+                       } else if (length <= ushort.MaxValue) {
+                               headerBuffer[1] = (byte)126;
+                               headerBuffer[2] = (byte)(length / 256);
+                               headerBuffer[3] = (byte)(length % 256);
+                       } else {
+                               headerBuffer[1] = (byte)127;
+
+                               int left = length;
+                               int unit = 256;
+
+                               for (int i = 9; i > 1; i--) {
+                                       headerBuffer[i] = (byte)(left % unit);
+                                       left = left / unit;
+                               }
+                       }
+
+                       var l = Math.Max (0, headerBuffer[1] - 125);
+                       var maskOffset = 2 + l * l * 2;
+                       GenerateMask (headerBuffer, maskOffset);
+
+                       // Since we are client only, we always mask the payload
+                       headerBuffer[1] |= 0x80;
+
+                       return maskOffset;
+               }
+
+               void GenerateMask (byte[] mask, int offset)
+               {
+                       mask[offset + 0] = (byte)random.Next (0, 255);
+                       mask[offset + 1] = (byte)random.Next (0, 255);
+                       mask[offset + 2] = (byte)random.Next (0, 255);
+                       mask[offset + 3] = (byte)random.Next (0, 255);
+               }
+
+               void MaskData (ArraySegment<byte> buffer, int maskOffset)
+               {
+                       var sendBufferOffset = maskOffset + 4;
+                       for (var i = 0; i < buffer.Count; i++)
+                               sendBuffer[i + sendBufferOffset] = (byte)(buffer.Array[buffer.Offset + i] ^ headerBuffer[maskOffset + (i % 4)]);
+               }
+
+               void EnsureWebSocketConnected ()
+               {
+                       if (state < WebSocketState.Open)
+                               throw new InvalidOperationException ("The WebSocket is not connected");
+               }
+
+               void EnsureWebSocketState (params WebSocketState[] validStates)
                {
-                       return internalWebSocket.CloseAsync (closeStatus, statusDescription, cancellationToken);
+                       foreach (var validState in validStates)
+                               if (state == validState)
+                                       return;
+                       throw new WebSocketException ("The WebSocket is in an invalid state ('" + state + "') for this operation. Valid states are: " + string.Join (", ", validStates));
                }
 
-               public override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+               void ValidateArraySegment (ArraySegment<byte> segment)
                {
-                       return internalWebSocket.CloseOutputAsync (closeStatus, statusDescription, cancellationToken);
+                       if (segment.Array == null)
+                               throw new ArgumentNullException ("buffer.Array");
+                       if (segment.Offset < 0)
+                               throw new ArgumentOutOfRangeException ("buffer.Offset");
+                       if (segment.Offset + segment.Count > segment.Array.Length)
+                               throw new ArgumentOutOfRangeException ("buffer.Count");
                }
        }
 }
 
-#endif