Merge branch 'patch-1' of https://github.com/ReubenBond/mono into ReubenBond-patch-1
authorAlexander Köplinger <alex.koeplinger@outlook.com>
Fri, 7 Nov 2014 18:34:27 +0000 (19:34 +0100)
committerAlexander Köplinger <alex.koeplinger@outlook.com>
Fri, 7 Nov 2014 19:14:29 +0000 (20:14 +0100)
1  2 
mcs/class/System/System.Net.WebSockets/ClientWebSocket.cs
mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs

index d7e999c46f36b4624cea92e5ad32c920a791308b,c903720f24d7d231dfeb48b71658b439470279d5..ff011d2f93044da4f6f28a9a7966e2800004572d
@@@ -61,7 -61,6 +61,7 @@@ namespace System.Net.WebSocket
                const int HeaderMaxLength = 14;
                byte[] headerBuffer;
                byte[] sendBuffer;
 +              long remaining;
  
                public ClientWebSocket ()
                {
                                underlyingSocket.Send (sendBuffer, 0, buffer.Count + headerLength, SocketFlags.None);
                        });
                }
 +              
 +              const int messageTypeText = 1;
 +              const int messageTypeBinary = 2;
 +              const int messageTypeClose = 8;
  
 +              static WebSocketMessageType WireToMessageType (byte msgType)
 +              {
 +                      
 +                      if (msgType == messageTypeText)
 +                              return WebSocketMessageType.Text;
 +                      if (msgType == messageTypeBinary)
 +                              return WebSocketMessageType.Binary;
 +                      return WebSocketMessageType.Close;
 +              }
 +
 +              static byte MessageTypeToWire (WebSocketMessageType type)
 +              {
 +                      if (type == WebSocketMessageType.Text)
 +                              return messageTypeText;
 +                      if (type == WebSocketMessageType.Binary)
 +                              return messageTypeBinary;
 +                      return messageTypeClose;
 +              }
 +              
                public override Task<WebSocketReceiveResult> ReceiveAsync (ArraySegment<byte> buffer, CancellationToken cancellationToken)
                {
                        EnsureWebSocketConnected ();
                        ValidateArraySegment (buffer);
                        return Task.Run (() => {
                                EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent);
 -                              // First read the two first bytes to know what we are doing next
 -                              connection.Read (req, headerBuffer, 0, 2);
 -                              var isLast = (headerBuffer[0] >> 7) > 0;
 -                              var isMasked = (headerBuffer[1] >> 7) > 0;
 -                              int mask = 0;
 -                              var type = (WebSocketMessageType)(headerBuffer[0] & 0xF);
 -                              long length = headerBuffer[1] & 0x7F;
 -                              int offset = 0;
 -                              if (length == 126) {
 -                                      offset = 2;
 -                                      connection.Read (req, headerBuffer, 2, offset);
 +
 +                              bool isLast;
 +                              WebSocketMessageType type;
 +                              long length;
 +
 +                              if (remaining == 0) {
 +                                      // First read the two first bytes to know what we are doing next
 +                                      connection.Read (req, headerBuffer, 0, 2);
 +                                      isLast = (headerBuffer[0] >> 7) > 0;
 +                                      var isMasked = (headerBuffer[1] >> 7) > 0;
 +                                      int mask = 0;
 +                                      type = WireToMessageType ((byte)(headerBuffer[0] & 0xF));
 +                                      length = headerBuffer[1] & 0x7F;
 +                                      int offset = 0;
 +                                      if (length == 126) {
 +                                              offset = 2;
 +                                              connection.Read (req, headerBuffer, 2, offset);
                                        length = (headerBuffer[2] << 8) | headerBuffer[3];
 -                              } else if (length == 127) {
 -                                      offset = 8;
 -                                      connection.Read (req, headerBuffer, 2, offset);
 -                                      length = 0;
 -                                      for (int i = 2; i <= 9; i++)
 -                                              length = (length << 8) | headerBuffer[i];
 -                              }
 +                                      } else if (length == 127) {
 +                                              offset = 8;
 +                                              connection.Read (req, headerBuffer, 2, offset);
 +                                              length = 0;
 +                                              for (int i = 2; i <= 9; i++)
 +                                                      length = (length << 8) | headerBuffer[i];
 +                                      }
  
 -                              if (isMasked) {
 -                                      connection.Read (req, headerBuffer, 2 + offset, 4);
 -                                      for (int i = 0; i < 4; i++) {
 -                                              var pos = i + offset + 2;
 -                                              mask = (mask << 8) | headerBuffer[pos];
 +                                      if (isMasked) {
 +                                              connection.Read (req, headerBuffer, 2 + offset, 4);
 +                                              for (int i = 0; i < 4; i++) {
 +                                                      var pos = i + offset + 2;
 +                                                      mask = (mask << 8) | headerBuffer[pos];
 +                                              }
                                        }
 +                              } else {
 +                                      isLast = (headerBuffer[0] >> 7) > 0;
 +                                      type = WireToMessageType ((byte)(headerBuffer[0] & 0xF));
 +                                      length = remaining;
                                }
  
                                if (type == WebSocketMessageType.Close) {
                                } else {
                                        var readLength = (int)(buffer.Count < length ? buffer.Count : length);
                                        connection.Read (req, buffer.Array, buffer.Offset, readLength);
 +                                      remaining = length - readLength;
  
 -                                      return new WebSocketReceiveResult ((int)length, type, isLast);
 +                                      return new WebSocketReceiveResult ((int)readLength, type, isLast && remaining == 0);
                                }
                        });
                }
  
                int WriteHeader (WebSocketMessageType type, ArraySegment<byte> buffer, bool endOfMessage)
                {
 -                      var opCode = (byte)type;
 +                      var opCode = MessageTypeToWire (type);
                        var length = buffer.Count;
  
-                       headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0 : 0x80));
+                       headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0x80 : 0));
                        if (length < 126) {
                                headerBuffer[1] = (byte)length;
                        } else if (length <= ushort.MaxValue) {
index 2140daac9bb866aa77fb5dcebd666424dcff6f1d,7c89da6ef7158c11e5eb40427e2df607c1d3158e..a990f67e65848183a4c9ac8d2c50060406987d95
@@@ -49,10 -49,7 +49,10 @@@ namespace MonoTests.System.Net.WebSocke
                [Test]
                public void ServerHandshakeReturnCrapStatusCodeTest ()
                {
 +                      // On purpose, 
 +                      #pragma warning disable 4014
                        HandleHttpRequestAsync ((req, resp) => resp.StatusCode = 418);
 +                      #pragma warning restore 4014
                        try {
                                Assert.IsTrue (socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait (5000));
                        } catch (AggregateException e) {
                [Test]
                public void ServerHandshakeReturnWrongUpgradeHeader ()
                {
 +                      #pragma warning disable 4014
                        HandleHttpRequestAsync ((req, resp) => {
                                        resp.StatusCode = 101;
                                        resp.Headers["Upgrade"] = "gtfo";
                                });
 +                      #pragma warning restore 4014
                        try {
                                Assert.IsTrue (socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait (5000));
                        } catch (AggregateException e) {
                [Test]
                public void ServerHandshakeReturnWrongConnectionHeader ()
                {
 +                      #pragma warning disable 4014
                        HandleHttpRequestAsync ((req, resp) => {
                                        resp.StatusCode = 101;
                                        resp.Headers["Upgrade"] = "websocket";
                                        // Mono http request doesn't like the forcing, test still valid since the default connection header value is empty
                                        //ForceSetHeader (resp.Headers, "Connection", "Foo");
                                });
 +                      #pragma warning restore 4014
                        try {
                                Assert.IsTrue (socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait (5000));
                        } catch (AggregateException e) {
                        }
                        Assert.Fail ("Should have thrown");
                }
+               
+               [Test]
+               public async Task SendAsyncEndOfMessageTest() {
+                       var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(30)).Token;
+                       await SendAsyncEndOfMessageTest(false, WebSocketMessageType.Text, cancellationToken);
+                       await SendAsyncEndOfMessageTest(true, WebSocketMessageType.Text, cancellationToken);
+                       await SendAsyncEndOfMessageTest(false, WebSocketMessageType.Binary, cancellationToken);
+                       await SendAsyncEndOfMessageTest(true, WebSocketMessageType.Binary, cancellationToken);
+               }
+               
+               public async Task SendAsyncEndOfMessageTest(bool expectedEndOfMessage, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken){
+                       using (var client = new ClientWebSocket()) {    
+                               // Configure the listener.
+                               var serverReceive = HandleHttpWebSocketRequestAsync<WebSocketReceiveResult>(async socket => await socket.ReceiveAsync(new ArraySegment<byte>(new byte[32]), cancellationToken), cancellationToken);
+                               
+                               // Connect to the listener and make the request.
+                               await client.ConnectAsync (new Uri ("ws://localhost:" + Port + "/"), cancellationToken);
+                               await client.SendAsync(new ArraySegment<byte>(Encoding.UTF8.GetBytes("test")), webSocketMessageType, expectedEndOfMessage, cancellationToken);
+                               
+                               // Wait for the listener to handle the request and return its result.
+                               var result = await serverReceive;
+                               
+                               // Cleanup and check results.
+                               await client.CloseAsync(WebSocketCloseStatus.NormalClosure, "Finished", cancellationToken);
+                               Assert.AreEqual(expectedEndOfMessage, result.EndOfMessage, "EndOfMessage should be " + expectedEndOfMessage);
+                       }
+               }
+               
+               async Task<T> HandleHttpWebSocketRequestAsync<T>(Func<WebSocket, Task<T>> action, CancellationToken cancellationToken) {
+                       var ctx = await this.listener.GetContextAsync();
+                       var wsContext = await ctx.AcceptWebSocketAsync(null);
+                       var result = await action(wsContext.WebSocket);
+                       await wsContext.WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Finished", cancellationToken);
+                       return result;
+               }
  
                async Task HandleHttpRequestAsync (Action<HttpListenerRequest, HttpListenerResponse> handler)
                {