[System] Implement a first pass of ClientWebSocket.
authorJérémie Laval <jeremie.laval@gmail.com>
Mon, 29 Jul 2013 04:46:19 +0000 (00:46 -0400)
committerJérémie Laval <jeremie.laval@gmail.com>
Mon, 29 Jul 2013 04:47:12 +0000 (00:47 -0400)
mcs/class/System/System.Net.WebSockets/ClientWebSocket.cs
mcs/class/System/System.Net.WebSockets/ClientWebSocketOptions.cs
mcs/class/System/System.Net.WebSockets/WebSocketException.cs
mcs/class/System/System.Net.WebSockets/WebSocketMessageType.cs
mcs/class/System/System.Net.WebSockets/WebSocketReceiveResult.cs
mcs/class/System/System_test.dll.sources
mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs [new file with mode: 0644]

index f04cb3c054343cab67538bc1473ce826cc0ce76c..3778c740538ffde77005774f0408e68a27f87b69 100644 (file)
@@ -1,10 +1,12 @@
 //
 // ClientWebSocket.cs
 //
-// Author:
-//       Martin Baulig <martin.baulig@xamarin.com>
+// Authors:
+//       Jérémie Laval <jeremie dot laval at xamarin dot com>
 //
-// Copyright (c) 2013 Xamarin Inc. (http://www.xamarin.com)
+// Copyright 2013 Xamarin Inc (http://www.xamarin.com).
+//
+// Lightly inspired from WebSocket4Net distributed under the Apache License 2.0
 //
 // Permission is hereby granted, free of charge, to any person obtaining a copy
 // of this software and associated documentation files (the "Software"), to deal
 #if NET_4_5
 
 using System;
+using System.Net;
+using System.Net.Sockets;
+using System.Security.Principal;
+using System.Security.Cryptography.X509Certificates;
+using System.Runtime.CompilerServices;
+using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
+using System.Globalization;
+using System.Text;
+using System.Security.Cryptography;
 
 namespace System.Net.WebSockets
 {
-       [MonoTODO]
-       public class ClientWebSocket : WebSocket
+       public class ClientWebSocket : WebSocket, IDisposable
        {
-               public ClientWebSocketOptions Options {
-                       get { throw new NotImplementedException (); }
+               const string Magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+               const string VersionTag = "13";
+
+               ClientWebSocketOptions options;
+               WebSocketState state;
+               string subProtocol;
+
+               HttpWebRequest req;
+               WebConnection connection;
+               Socket underlyingSocket;
+
+               Random random = new Random ();
+
+               const int HeaderMaxLength = 14;
+               byte[] headerBuffer;
+               byte[] sendBuffer;
+
+               public ClientWebSocket ()
+               {
+                       options = new ClientWebSocketOptions ();
+                       state = WebSocketState.None;
+                       headerBuffer = new byte[HeaderMaxLength];
                }
 
-               public Task ConnectAsync (Uri uri, CancellationToken cancellationToken)
+               public override void Dispose ()
                {
-                       throw new NotImplementedException ();
+                       if (connection != null)
+                               connection.Close (false);
                }
 
-               #region implemented abstract members of WebSocket
+               [MonoTODO]
                public override void Abort ()
                {
                        throw new NotImplementedException ();
                }
-               public override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+
+               public ClientWebSocketOptions Options {
+                       get {
+                               return options;
+                       }
+               }
+
+               public override WebSocketState State {
+                       get {
+                               return state;
+                       }
+               }
+
+               public override WebSocketCloseStatus? CloseStatus {
+                       get {
+                               if (state != WebSocketState.Closed)
+                                       return (WebSocketCloseStatus?)null;
+                               return WebSocketCloseStatus.Empty;
+                       }
+               }
+
+               public override string CloseStatusDescription {
+                       get {
+                               return null;
+                       }
+               }
+
+               public override string SubProtocol {
+                       get {
+                               return subProtocol;
+                       }
+               }
+
+               public async Task ConnectAsync (Uri uri, CancellationToken cancellationToken)
                {
-                       throw new NotImplementedException ();
+                       state = WebSocketState.Connecting;
+                       var httpUri = new UriBuilder (uri);
+                       if (uri.Scheme == "wss")
+                               httpUri.Scheme = "https";
+                       else
+                               httpUri.Scheme = "http";
+                       req = (HttpWebRequest)WebRequest.Create (httpUri.Uri);
+                       req.ReuseConnection = true;
+                       if (options.Cookies != null)
+                               req.CookieContainer = options.Cookies;
+
+                       if (options.CustomRequestHeaders.Count > 0) {
+                               foreach (var header in options.CustomRequestHeaders)
+                                       req.Headers[header.Key] = header.Value;
+                       }
+
+                       var secKey = Convert.ToBase64String (Encoding.ASCII.GetBytes (Guid.NewGuid ().ToString ().Substring (0, 16)));
+                       string expectedAccept = Convert.ToBase64String (SHA1.Create ().ComputeHash (Encoding.ASCII.GetBytes (secKey + Magic)));
+
+                       req.Headers["Upgrade"] = "WebSocket";
+                       req.Headers["Sec-WebSocket-Version"] = VersionTag;
+                       req.Headers["Sec-WebSocket-Key"] = secKey;
+                       req.Headers["Sec-WebSocket-Origin"] = uri.Host;
+                       if (options.SubProtocols.Count > 0)
+                               req.Headers["Sec-WebSocket-Protocol"] = string.Join (",", options.SubProtocols);
+
+                       if (options.Credentials != null)
+                               req.Credentials = options.Credentials;
+                       if (options.ClientCertificates != null)
+                               req.ClientCertificates = options.ClientCertificates;
+                       if (options.Proxy != null)
+                               req.Proxy = options.Proxy;
+                       req.UseDefaultCredentials = options.UseDefaultCredentials;
+                       req.Connection = "Upgrade";
+
+                       HttpWebResponse resp = null;
+                       try {
+                               resp = (HttpWebResponse)(await req.GetResponseAsync ().ConfigureAwait (false));
+                       } catch (Exception e) {
+                               throw new WebSocketException (WebSocketError.Success, e);
+                       }
+
+                       connection = req.StoredConnection;
+                       underlyingSocket = connection.socket;
+
+                       if (resp.StatusCode != HttpStatusCode.SwitchingProtocols)
+                               throw new WebSocketException ("The server returned status code '" + (int)resp.StatusCode + "' when status code '101' was expected");
+                       if (!string.Equals (resp.Headers["Upgrade"], "WebSocket", StringComparison.OrdinalIgnoreCase)
+                               || !string.Equals (resp.Headers["Connection"], "Upgrade", StringComparison.OrdinalIgnoreCase)
+                               || !string.Equals (resp.Headers["Sec-WebSocket-Accept"], expectedAccept))
+                               throw new WebSocketException ("HTTP header error during handshake");
+                       if (resp.Headers["Sec-WebSocket-Protocol"] != null) {
+                               if (!options.SubProtocols.Contains (resp.Headers["Sec-WebSocket-Protocol"]))
+                                       throw new WebSocketException (WebSocketError.UnsupportedProtocol);
+                               subProtocol = resp.Headers["Sec-WebSocket-Protocol"];
+                       }
+
+                       state = WebSocketState.Open;
                }
-               public override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+
+               public override Task SendAsync (ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
                {
-                       throw new NotImplementedException ();
+                       EnsureWebSocketConnected ();
+                       ValidateArraySegment (buffer);
+                       if (connection == null)
+                               throw new WebSocketException (WebSocketError.Faulted);
+                       var count = Math.Max (options.SendBufferSize, buffer.Count) + HeaderMaxLength;
+                       if (sendBuffer == null || sendBuffer.Length != count)
+                               sendBuffer = new byte[count];
+                       return Task.Run (() => {
+                               EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseReceived);
+                               var maskOffset = WriteHeader (messageType, buffer, endOfMessage);
+
+                               if (buffer.Count > 0)
+                                       MaskData (buffer, maskOffset);
+                               //underlyingSocket.Send (headerBuffer, 0, maskOffset + 4, SocketFlags.None);
+                               var headerLength = maskOffset + 4;
+                               Array.Copy (headerBuffer, sendBuffer, headerLength);
+                               underlyingSocket.Send (sendBuffer, 0, buffer.Count + headerLength, SocketFlags.None);
+                       });
                }
+
                public override Task<WebSocketReceiveResult> ReceiveAsync (ArraySegment<byte> buffer, CancellationToken cancellationToken)
                {
-                       throw new NotImplementedException ();
+                       EnsureWebSocketConnected ();
+                       ValidateArraySegment (buffer);
+                       return Task.Run (() => {
+                               EnsureWebSocketState (WebSocketState.Open, WebSocketState.CloseSent);
+                               // First read the two first bytes to know what we are doing next
+                               connection.Read (req, headerBuffer, 0, 2);
+                               var isLast = (headerBuffer[0] >> 7) > 0;
+                               var isMasked = (headerBuffer[1] >> 7) > 0;
+                               int mask = 0;
+                               var type = (WebSocketMessageType)(headerBuffer[0] & 0xF);
+                               long length = headerBuffer[1] & 0x7F;
+                               int offset = 0;
+                               if (length == 126) {
+                                       offset = 2;
+                                       connection.Read (req, headerBuffer, 2, offset);
+                                       length = (headerBuffer[2] << 8) | headerBuffer[3];
+                               } else if (length == 127) {
+                                       offset = 8;
+                                       connection.Read (req, headerBuffer, 2, offset);
+                                       length = 0;
+                                       for (int i = 2; i <= 9; i++)
+                                               length = (length << 8) | headerBuffer[i];
+                               }
+
+                               if (isMasked) {
+                                       connection.Read (req, headerBuffer, 2 + offset, 4);
+                                       for (int i = 0; i < 4; i++) {
+                                               var pos = i + offset + 2;
+                                               mask = (mask << 8) | headerBuffer[pos];
+                                       }
+                               }
+
+                               if (type == WebSocketMessageType.Close) {
+                                       state = WebSocketState.Closed;
+                                       var tmpBuffer = new byte[length];
+                                       connection.Read (req, tmpBuffer, 0, tmpBuffer.Length);
+                                       var closeStatus = (WebSocketCloseStatus)(tmpBuffer[0] << 8 | tmpBuffer[1]);
+                                       var closeDesc = tmpBuffer.Length > 2 ? Encoding.UTF8.GetString (tmpBuffer, 2, tmpBuffer.Length - 2) : string.Empty;
+                                       return new WebSocketReceiveResult ((int)length, type, isLast, closeStatus, closeDesc);
+                               } else {
+                                       var readLength = (int)(buffer.Count < length ? buffer.Count : length);
+                                       connection.Read (req, buffer.Array, buffer.Offset, readLength);
+
+                                       return new WebSocketReceiveResult ((int)length, type, isLast);
+                               }
+                       });
                }
-               public override Task SendAsync (ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
+
+               // The damn difference between those two methods is that CloseAsync will wait for server acknowledgement before completing
+               // while CloseOutputAsync will send the close packet and simply complete.
+
+               public async override Task CloseAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
                {
-                       throw new NotImplementedException ();
+                       EnsureWebSocketConnected ();
+                       await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
+                       state = WebSocketState.CloseSent;
+                       // TODO: figure what's exceptions are thrown if the server returns something faulty here
+                       await ReceiveAsync (new ArraySegment<byte> (new byte[0]), cancellationToken).ConfigureAwait (false);
+                       state = WebSocketState.Closed;
                }
-               public override void Dispose ()
+
+               public async override Task CloseOutputAsync (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
                {
-                       throw new NotImplementedException ();
+                       EnsureWebSocketConnected ();
+                       await SendCloseFrame (closeStatus, statusDescription, cancellationToken).ConfigureAwait (false);
+                       state = WebSocketState.CloseSent;
                }
-               public override WebSocketCloseStatus? CloseStatus {
-                       get {
-                               throw new NotImplementedException ();
-                       }
+
+               async Task SendCloseFrame (WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
+               {
+                       var statusDescBuffer = string.IsNullOrEmpty (statusDescription) ? new byte[2] : new byte[2 + Encoding.UTF8.GetByteCount (statusDescription)];
+                       statusDescBuffer[0] = (byte)(((ushort)closeStatus) >> 8);
+                       statusDescBuffer[1] = (byte)(((ushort)closeStatus) & 0xFF);
+                       if (!string.IsNullOrEmpty (statusDescription))
+                               Encoding.UTF8.GetBytes (statusDescription, 0, statusDescription.Length, statusDescBuffer, 2);
+                       await SendAsync (new ArraySegment<byte> (statusDescBuffer), WebSocketMessageType.Close, true, cancellationToken).ConfigureAwait (false);
                }
-               public override string CloseStatusDescription {
-                       get {
-                               throw new NotImplementedException ();
+
+               int WriteHeader (WebSocketMessageType type, ArraySegment<byte> buffer, bool endOfMessage)
+               {
+                       var opCode = (byte)type;
+                       var length = buffer.Count;
+
+                       headerBuffer[0] = (byte)(opCode | (endOfMessage ? 0 : 0x80));
+                       if (length < 126) {
+                               headerBuffer[1] = (byte)length;
+                       } else if (length <= ushort.MaxValue) {
+                               headerBuffer[1] = (byte)126;
+                               headerBuffer[2] = (byte)(length / 256);
+                               headerBuffer[3] = (byte)(length % 256);
+                       } else {
+                               headerBuffer[1] = (byte)127;
+
+                               int left = length;
+                               int unit = 256;
+
+                               for (int i = 9; i > 1; i--) {
+                                       headerBuffer[i] = (byte)(left % unit);
+                                       left = left / unit;
+                               }
                        }
+
+                       var l = Math.Max (0, headerBuffer[1] - 125);
+                       var maskOffset = 2 + l * l * 2;
+                       GenerateMask (headerBuffer, maskOffset);
+
+                       // Since we are client only, we always mask the payload
+                       headerBuffer[1] |= 0x80;
+
+                       return maskOffset;
                }
-               public override WebSocketState State {
-                       get {
-                               throw new NotImplementedException ();
-                       }
+
+               void GenerateMask (byte[] mask, int offset)
+               {
+                       mask[offset + 0] = (byte)random.Next (0, 255);
+                       mask[offset + 1] = (byte)random.Next (0, 255);
+                       mask[offset + 2] = (byte)random.Next (0, 255);
+                       mask[offset + 3] = (byte)random.Next (0, 255);
                }
-               public override string SubProtocol {
-                       get {
-                               throw new NotImplementedException ();
-                       }
+
+               void MaskData (ArraySegment<byte> buffer, int maskOffset)
+               {
+                       var sendBufferOffset = maskOffset + 4;
+                       for (var i = 0; i < buffer.Count; i++)
+                               sendBuffer[i + sendBufferOffset] = (byte)(buffer.Array[buffer.Offset + i] ^ headerBuffer[maskOffset + (i % 4)]);
+               }
+
+               void EnsureWebSocketConnected ()
+               {
+                       if (state < WebSocketState.Open)
+                               throw new InvalidOperationException ("The WebSocket is not connected");
+               }
+
+               void EnsureWebSocketState (params WebSocketState[] validStates)
+               {
+                       foreach (var validState in validStates)
+                               if (state == validState)
+                                       return;
+                       throw new WebSocketException ("The WebSocket is in an invalid state ('" + state + "') for this operation. Valid states are: " + string.Join (", ", validStates));
+               }
+
+               void ValidateArraySegment (ArraySegment<byte> segment)
+               {
+                       if (segment.Array == null)
+                               throw new ArgumentNullException ("buffer.Array");
+                       if (segment.Offset < 0)
+                               throw new ArgumentOutOfRangeException ("buffer.Offset");
+                       if (segment.Offset + segment.Count > segment.Array.Length)
+                               throw new ArgumentOutOfRangeException ("buffer.Count");
                }
-               #endregion
        }
 }
 
 #endif
-
index a1d617cbadb88a109acdae79b3a8b612cadbaf3e..586752d7f67563c0dc17bc602d395b586c86191b 100644 (file)
@@ -33,11 +33,15 @@ using System.Net;
 using System.Security.Principal;
 using System.Security.Cryptography.X509Certificates;
 using System.Runtime.CompilerServices;
+using System.Collections.Generic;
 
 namespace System.Net.WebSockets
 {
        public sealed class ClientWebSocketOptions
        {
+               List<string> subprotocols = new List<string> ();
+               Dictionary<string, string> customRequestHeaders = new Dictionary<string, string> ();
+
                public X509CertificateCollection ClientCertificates { get; set; }
 
                public CookieContainer Cookies { get; set; }
@@ -50,28 +54,53 @@ namespace System.Net.WebSockets
 
                public bool UseDefaultCredentials { get; set; }
 
-               [MonoTODO]
+               internal IList<string> SubProtocols {
+                       get {
+                               return subprotocols.AsReadOnly ();
+                       }
+               }
+
+               internal Dictionary<string, string> CustomRequestHeaders {
+                       get {
+                               return customRequestHeaders;
+                       }
+               }
+
+               internal int ReceiveBufferSize {
+                       get;
+                       private set;
+               }
+
+               internal ArraySegment<byte> CustomReceiveBuffer {
+                       get;
+                       private set;
+               }
+
+               internal int SendBufferSize {
+                       get;
+                       private set;
+               }
+
                public void AddSubProtocol (string subProtocol)
                {
-                       throw new NotImplementedException ();
+                       subprotocols.Add (subProtocol);
                }
 
-               [MonoTODO]
                public void SetBuffer (int receiveBufferSize, int sendBufferSize)
                {
-                       throw new NotImplementedException ();
+                       SetBuffer (receiveBufferSize, sendBufferSize, new ArraySegment<byte> ());
                }
 
-               [MonoTODO]
                public void SetBuffer (int receiveBufferSize, int sendBufferSize, ArraySegment<byte> buffer)
                {
-                       throw new NotImplementedException ();
+                       ReceiveBufferSize = receiveBufferSize;
+                       SendBufferSize = sendBufferSize;
+                       CustomReceiveBuffer = buffer;
                }
 
-               [MonoTODO]
                public void SetRequestHeader (string headerName, string headerValue)
                {
-                       throw new NotImplementedException ();
+                       customRequestHeaders[headerName] = headerValue;
                }
        }
 }
index b4980174f6bb7c72c7fa9e251b6ad2dc3ed5ebad..e617ab38b0bfee190e81ed48bcf55f03d3d5f2fa 100644 (file)
@@ -36,72 +36,68 @@ namespace System.Net.WebSockets
 {
        public sealed class WebSocketException : Win32Exception
        {
-               public WebSocketException ()
+               const string DefaultMessage = "Generic WebSocket exception";
+
+               public WebSocketException () : this (WebSocketError.Success, -1, DefaultMessage, null)
                {
                        
                }
 
-               public WebSocketException (int nativeError) : base (nativeError)
+               public WebSocketException (int nativeError) : this (WebSocketError.Success, nativeError, DefaultMessage, null)
                {
                        
                }
 
-               public WebSocketException (string message) : base (message)
+               public WebSocketException (string message) : this (WebSocketError.Success, -1, message, null)
                {
                        
                }
 
-               public WebSocketException (WebSocketError error)
+               public WebSocketException (WebSocketError error) : this (error, -1, DefaultMessage, null)
                {
-                       WebSocketErrorCode = error;
                }
 
-               public WebSocketException (int nativeError, Exception innerException)
+               public WebSocketException (int nativeError, Exception innerException) : this (WebSocketError.Success, nativeError, DefaultMessage, innerException)
                {
                        
                }
 
-               public WebSocketException (int nativeError, string message) : base (nativeError, message)
+               public WebSocketException (int nativeError, string message) : this (WebSocketError.Success, nativeError, message, null)
                {
                        
                }
 
-               public WebSocketException (string message, Exception innerException) : base (message, innerException)
+               public WebSocketException (string message, Exception innerException) : this (WebSocketError.Success, -1, message, innerException)
                {
                        
                }
 
-               public WebSocketException (WebSocketError error, Exception innerException)
+               public WebSocketException (WebSocketError error, Exception innerException) : this (error, -1, DefaultMessage, innerException)
                {
-                       WebSocketErrorCode = error;
+
                }
 
-               public WebSocketException (WebSocketError error, int nativeError) : base (nativeError)
+               public WebSocketException (WebSocketError error, int nativeError) : this (error, nativeError, DefaultMessage, null)
                {
-                       WebSocketErrorCode = error;
                }
 
-               public WebSocketException (WebSocketError error, string message) : base (message)
+               public WebSocketException (WebSocketError error, string message) : this (error, -1, message, null)
                {
-                       WebSocketErrorCode = error;
                }
 
-               public WebSocketException (WebSocketError error, int nativeError, Exception innerException) : base (nativeError)
+               public WebSocketException (WebSocketError error, int nativeError, Exception innerException) : this (error, nativeError, DefaultMessage, innerException)
                {
-                       WebSocketErrorCode = error;
                }
 
-               public WebSocketException (WebSocketError error, int nativeError, string message) : base (nativeError, message)
+               public WebSocketException (WebSocketError error, int nativeError, string message) : this (error, nativeError, message, null)
                {
-                       WebSocketErrorCode = error;
                }
 
-               public WebSocketException (WebSocketError error, string message, Exception innerException)
+               public WebSocketException (WebSocketError error, string message, Exception innerException) : this (error, -1, message, innerException)
                {
-                       WebSocketErrorCode = error;
                }
 
-               public WebSocketException (WebSocketError error, int nativeError, string message, Exception innerException) : base (nativeError, message)
+               public WebSocketException (WebSocketError error, int nativeError, string message, Exception innerException) : base (message, innerException)
                {
                        WebSocketErrorCode = error;
                }
index 18e2d9ecbe8d4dd84741dec1b780b9e24e931a89..50cbc003c0f7c254cdcf1506a723ba8530542c91 100644 (file)
@@ -35,9 +35,9 @@ namespace System.Net.WebSockets
 {
        public enum WebSocketMessageType
        {
-               Text,
-               Binary,
-               Close
+               Text = 1,
+               Binary = 2,
+               Close = 8
        }
 }
 
index e237344e46687ba099b656681b8bb80aaa02dc6f..af97ebcdca92cc21acb5e75fa1b663fc39ebb658 100644 (file)
@@ -36,20 +36,22 @@ namespace System.Net.WebSockets
 {
        public class WebSocketReceiveResult
        {
-               [MonoTODO]
                public WebSocketReceiveResult (int count, WebSocketMessageType messageType, bool endOfMessage)
+                    : this (count, messageType, endOfMessage, null, null)
                {
-                       throw new NotImplementedException ();
                }
 
-               [MonoTODO]
                public WebSocketReceiveResult (int count,
                                               WebSocketMessageType messageType,
                                               bool endOfMessage,
                                               WebSocketCloseStatus? closeStatus,
                                               string closeStatusDescription)
                {
-                       throw new NotImplementedException ();
+                       MessageType = messageType;
+                       CloseStatus = closeStatus;
+                       CloseStatusDescription = closeStatusDescription;
+                       Count = count;
+                       EndOfMessage = endOfMessage;
                }
 
                public WebSocketCloseStatus? CloseStatus {
index a306d52be4078327ed987181b077286996e8307f..33f8326b2d6e9041d793afd58945e5239378669a 100644 (file)
@@ -499,3 +499,4 @@ System.Collections.Concurrent/BlockingCollectionTests.cs
 System.Collections.Concurrent/ConcurrentBagTests.cs
 System.Collections.Concurrent/CollectionStressTestHelper.cs
 System.Collections.Concurrent/ParallelTestHelper.cs
+System.Net.WebSockets/ClientWebSocketTest.cs
diff --git a/mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs b/mcs/class/System/Test/System.Net.WebSockets/ClientWebSocketTest.cs
new file mode 100644 (file)
index 0000000..212d5db
--- /dev/null
@@ -0,0 +1,242 @@
+using System;
+using System.Net;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Collections.Generic;
+using System.Net.WebSockets;
+using System.Reflection;
+using System.Text;
+
+using NUnit.Framework;
+
+#if NET_4_5
+
+namespace MonoTests.System.Net.WebSockets
+{
+       [TestFixture]
+       public class ClientWebSocketTest
+       {
+               const string EchoServerUrl = "ws://echo.websocket.org";
+               const int Port = 42123;
+               HttpListener listener;
+               ClientWebSocket socket;
+               MethodInfo headerSetMethod;
+
+               [SetUp]
+               public void Setup ()
+               {
+                       listener = new HttpListener ();
+                       listener.Prefixes.Add ("http://localhost:" + Port + "/");
+                       listener.Start ();
+                       socket = new ClientWebSocket ();
+               }
+
+               [TearDown]
+               public void Teardown ()
+               {
+                       if (listener != null) {
+                               listener.Stop ();
+                               listener = null;
+                       }
+                       if (socket != null) {
+                               if (socket.State == WebSocketState.Open)
+                                       socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                               socket.Dispose ();
+                               socket = null;
+                       }
+               }
+
+               [Test]
+               public void ServerHandshakeReturnCrapStatusCodeTest ()
+               {
+                       HandleHttpRequestAsync ((req, resp) => resp.StatusCode = 418);
+                       try {
+                               socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait ();
+                       } catch (AggregateException e) {
+                               AssertWebSocketException (e, WebSocketError.Success, typeof (WebException));
+                               return;
+                       }
+                       Assert.Fail ("Should have thrown");
+               }
+
+               [Test]
+               public void ServerHandshakeReturnWrongUpgradeHeader ()
+               {
+                       HandleHttpRequestAsync ((req, resp) => {
+                                       resp.StatusCode = 101;
+                                       resp.Headers["Upgrade"] = "gtfo";
+                               });
+                       try {
+                               socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait ();
+                       } catch (AggregateException e) {
+                               AssertWebSocketException (e, WebSocketError.Success);
+                               return;
+                       }
+                       Assert.Fail ("Should have thrown");
+               }
+
+               [Test]
+               public void ServerHandshakeReturnWrongConnectionHeader ()
+               {
+                       HandleHttpRequestAsync ((req, resp) => {
+                                       resp.StatusCode = 101;
+                                       resp.Headers["Upgrade"] = "websocket";
+                                       // Mono http request doesn't like the forcing, test still valid since the default connection header value is empty
+                                       //ForceSetHeader (resp.Headers, "Connection", "Foo");
+                               });
+                       try {
+                               socket.ConnectAsync (new Uri ("ws://localhost:" + Port), CancellationToken.None).Wait ();
+                       } catch (AggregateException e) {
+                               AssertWebSocketException (e, WebSocketError.Success);
+                               return;
+                       }
+                       Assert.Fail ("Should have thrown");
+               }
+
+               [Test]
+               public void EchoTest ()
+               {
+                       const string Payload = "This is a websocket test";
+
+                       Assert.AreEqual (WebSocketState.None, socket.State);
+                       socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                       Assert.AreEqual (WebSocketState.Open, socket.State);
+
+                       var sendBuffer = Encoding.ASCII.GetBytes (Payload);
+                       socket.SendAsync (new ArraySegment<byte> (sendBuffer), WebSocketMessageType.Text, true, CancellationToken.None).Wait ();
+
+                       var receiveBuffer = new byte[Payload.Length];
+                       var resp = socket.ReceiveAsync (new ArraySegment<byte> (receiveBuffer), CancellationToken.None).Result;
+
+                       Assert.AreEqual (Payload.Length, resp.Count);
+                       Assert.IsTrue (resp.EndOfMessage);
+                       Assert.AreEqual (WebSocketMessageType.Text, resp.MessageType);
+                       Assert.AreEqual (Payload, Encoding.ASCII.GetString (receiveBuffer, 0, resp.Count));
+
+                       socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                       Assert.AreEqual (WebSocketState.Closed, socket.State);
+               }
+
+               [Test]
+               public void CloseOutputAsyncTest ()
+               {
+                       socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                       Assert.AreEqual (WebSocketState.Open, socket.State);
+
+                       socket.CloseOutputAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                       Assert.AreEqual (WebSocketState.CloseSent, socket.State);
+
+                       var resp = socket.ReceiveAsync (new ArraySegment<byte> (new byte[0]), CancellationToken.None).Result;
+                       Assert.AreEqual (WebSocketState.Closed, socket.State);
+                       Assert.AreEqual (WebSocketMessageType.Close, resp.MessageType);
+                       Assert.AreEqual (WebSocketCloseStatus.NormalClosure, resp.CloseStatus);
+                       Assert.AreEqual (string.Empty, resp.CloseStatusDescription);
+               }
+
+               [Test]
+               public void CloseAsyncTest ()
+               {
+                       socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                       Assert.AreEqual (WebSocketState.Open, socket.State);
+
+                       socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                       Assert.AreEqual (WebSocketState.Closed, socket.State);
+               }
+
+               [Test, ExpectedException (typeof (InvalidOperationException))]
+               public void SendAsyncArgTest_NotConnected ()
+               {
+                       socket.SendAsync (new ArraySegment<byte> (new byte[0]), WebSocketMessageType.Text, true, CancellationToken.None);
+               }
+
+               [Test, ExpectedException (typeof (ArgumentNullException))]
+               public void SendAsyncArgTest_NoArray ()
+               {
+                       socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                       socket.SendAsync (new ArraySegment<byte> (), WebSocketMessageType.Text, true, CancellationToken.None);
+               }
+
+               [Test, ExpectedException (typeof (InvalidOperationException))]
+               public void ReceiveAsyncArgTest_NotConnected ()
+               {
+                       socket.ReceiveAsync (new ArraySegment<byte> (new byte[0]), CancellationToken.None);
+               }
+
+               [Test, ExpectedException (typeof (ArgumentNullException))]
+               public void ReceiveAsyncArgTest_NoArray ()
+               {
+                       socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                       socket.ReceiveAsync (new ArraySegment<byte> (), CancellationToken.None);
+               }
+
+               [Test]
+               public void ReceiveAsyncWrongState_Closed ()
+               {
+                       try {
+                               socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                               socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                               socket.ReceiveAsync (new ArraySegment<byte> (new byte[0]), CancellationToken.None).Wait ();
+                       } catch (AggregateException e) {
+                               AssertWebSocketException (e, WebSocketError.Success);
+                               return;
+                       }
+                       Assert.Fail ("Should have thrown");
+               }
+
+               [Test]
+               public void SendAsyncWrongState_Closed ()
+               {
+                       try {
+                               socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                               socket.CloseAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                               socket.SendAsync (new ArraySegment<byte> (new byte[0]), WebSocketMessageType.Text, true, CancellationToken.None).Wait ();
+                       } catch (AggregateException e) {
+                               AssertWebSocketException (e, WebSocketError.Success);
+                               return;
+                       }
+                       Assert.Fail ("Should have thrown");
+               }
+
+               [Test]
+               public void SendAsyncWrongState_CloseSent ()
+               {
+                       try {
+                               socket.ConnectAsync (new Uri (EchoServerUrl), CancellationToken.None).Wait ();
+                               socket.CloseOutputAsync (WebSocketCloseStatus.NormalClosure, string.Empty, CancellationToken.None).Wait ();
+                               socket.SendAsync (new ArraySegment<byte> (new byte[0]), WebSocketMessageType.Text, true, CancellationToken.None).Wait ();
+                       } catch (AggregateException e) {
+                               AssertWebSocketException (e, WebSocketError.Success);
+                               return;
+                       }
+                       Assert.Fail ("Should have thrown");
+               }
+
+               async Task HandleHttpRequestAsync (Action<HttpListenerRequest, HttpListenerResponse> handler)
+               {
+                       var ctx = await listener.GetContextAsync ();
+                       handler (ctx.Request, ctx.Response);
+                       ctx.Response.Close ();
+               }
+
+               void AssertWebSocketException (AggregateException e, WebSocketError error, Type inner = null)
+               {
+                       var wsEx = e.InnerException as WebSocketException;
+                       Console.WriteLine (e.InnerException.ToString ());
+                       Assert.IsNotNull (wsEx, "Not a websocketexception");
+                       Assert.AreEqual (error, wsEx.WebSocketErrorCode);
+                       if (inner != null) {
+                               Assert.IsNotNull (wsEx.InnerException);
+                               Assert.IsInstanceOfType (inner, wsEx.InnerException);
+                       }
+               }
+
+               void ForceSetHeader (WebHeaderCollection headers, string name, string value)
+               {
+                       if (headerSetMethod == null)
+                               headerSetMethod = typeof (WebHeaderCollection).GetMethod ("AddValue", BindingFlags.NonPublic);
+                       headerSetMethod.Invoke (headers, new[] { name, value });
+               }
+       }
+}
+
+#endif