Merge branch 'patch-1' of https://github.com/ReubenBond/mono into ReubenBond-patch-1
[mono.git] / mcs / class / System / System.Net.WebSockets / ClientWebSocket.cs
index c903720f24d7d231dfeb48b71658b439470279d5..ff011d2f93044da4f6f28a9a7966e2800004572d 100644 (file)
@@ -61,6 +61,7 @@ namespace System.Net.WebSockets
                const int HeaderMaxLength = 14;
                byte[] headerBuffer;
                byte[] sendBuffer;
+               long remaining;
 
                public ClientWebSocket ()
                {
@@ -196,39 +197,73 @@ namespace System.Net.WebSockets
                                underlyingSocket.Send (sendBuffer, 0, buffer.Count + headerLength, SocketFlags.None);
                        });
                }
+               
+               const int messageTypeText = 1;
+               const int messageTypeBinary = 2;
+               const int messageTypeClose = 8;
 
+               static WebSocketMessageType WireToMessageType (byte msgType)
+               {
+                       
+                       if (msgType == messageTypeText)
+                               return WebSocketMessageType.Text;
+                       if (msgType == messageTypeBinary)
+                               return WebSocketMessageType.Binary;
+                       return WebSocketMessageType.Close;
+               }
+
+               static byte MessageTypeToWire (WebSocketMessageType type)
+               {
+                       if (type == WebSocketMessageType.Text)
+                               return messageTypeText;
+                       if (type == WebSocketMessageType.Binary)
+                               return messageTypeBinary;
+                       return messageTypeClose;
+               }
+               
                public override Task<WebSocketReceiveResult> ReceiveAsync (ArraySegment<byte> buffer, CancellationToken cancellationToken)
                {
                        EnsureWebSocketConnected ();
                        ValidateArraySegment (buffer);
                        return Task.Run (() => {
                                EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent);
-                               // First read the two first bytes to know what we are doing next
-                               connection.Read (req, headerBuffer, 0, 2);
-                               var isLast = (headerBuffer[0] >> 7) > 0;
-                               var isMasked = (headerBuffer[1] >> 7) > 0;
-                               int mask = 0;
-                               var type = (WebSocketMessageType)(headerBuffer[0] & 0xF);
-                               long length = headerBuffer[1] & 0x7F;
-                               int offset = 0;
-                               if (length == 126) {
-                                       offset = 2;
-                                       connection.Read (req, headerBuffer, 2, offset);
+
+                               bool isLast;
+                               WebSocketMessageType type;
+                               long length;
+
+                               if (remaining == 0) {
+                                       // First read the two first bytes to know what we are doing next
+                                       connection.Read (req, headerBuffer, 0, 2);
+                                       isLast = (headerBuffer[0] >> 7) > 0;
+                                       var isMasked = (headerBuffer[1] >> 7) > 0;
+                                       int mask = 0;
+                                       type = WireToMessageType ((byte)(headerBuffer[0] & 0xF));
+                                       length = headerBuffer[1] & 0x7F;
+                                       int offset = 0;
+                                       if (length == 126) {
+                                               offset = 2;
+                                               connection.Read (req, headerBuffer, 2, offset);
                                        length = (headerBuffer[2] << 8) | headerBuffer[3];
-                               } else if (length == 127) {
-                                       offset = 8;
-                                       connection.Read (req, headerBuffer, 2, offset);
-                                       length = 0;
-                                       for (int i = 2; i <= 9; i++)
-                                               length = (length << 8) | headerBuffer[i];
-                               }
+                                       } else if (length == 127) {
+                                               offset = 8;
+                                               connection.Read (req, headerBuffer, 2, offset);
+                                               length = 0;
+                                               for (int i = 2; i <= 9; i++)
+                                                       length = (length << 8) | headerBuffer[i];
+                                       }
 
-                               if (isMasked) {
-                                       connection.Read (req, headerBuffer, 2 + offset, 4);
-                                       for (int i = 0; i < 4; i++) {
-                                               var pos = i + offset + 2;
-                                               mask = (mask << 8) | headerBuffer[pos];
+                                       if (isMasked) {
+                                               connection.Read (req, headerBuffer, 2 + offset, 4);
+                                               for (int i = 0; i < 4; i++) {
+                                                       var pos = i + offset + 2;
+                                                       mask = (mask << 8) | headerBuffer[pos];
+                                               }
                                        }
+                               } else {
+                                       isLast = (headerBuffer[0] >> 7) > 0;
+                                       type = WireToMessageType ((byte)(headerBuffer[0] & 0xF));
+                                       length = remaining;
                                }
 
                                if (type == WebSocketMessageType.Close) {
@@ -241,8 +276,9 @@ namespace System.Net.WebSockets
                                } else {
                                        var readLength = (int)(buffer.Count < length ? buffer.Count : length);
                                        connection.Read (req, buffer.Array, buffer.Offset, readLength);
+                                       remaining = length - readLength;
 
-                                       return new WebSocketReceiveResult ((int)length, type, isLast);
+                                       return new WebSocketReceiveResult ((int)readLength, type, isLast && remaining == 0);
                                }
                        });
                }
@@ -279,7 +315,7 @@ namespace System.Net.WebSockets
 
                int WriteHeader (WebSocketMessageType type, ArraySegment<byte> buffer, bool endOfMessage)
                {
-                       var opCode = (byte)type;
+                       var opCode = MessageTypeToWire (type);
                        var length = buffer.Count;
 
                        headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0x80 : 0));