From 06e12b837f24c7b5463f59bbb71745417a63895d Mon Sep 17 00:00:00 2001 From: Mikhail Filippov Date: Mon, 25 Apr 2016 16:06:31 +0300 Subject: [PATCH] Socket.AcceptAsync shouldn't create new socket if SocketAsyncEventArgs.AcceptSocket isn't null --- mcs/class/System/System.Net.Sockets/Socket.cs | 18 ++-- mcs/class/System/System_test.dll.sources | 1 + .../SocketAcceptAsyncTest.cs | 82 +++++++++++++++++++ 3 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 mcs/class/System/Test/System.Net.Sockets/SocketAcceptAsyncTest.cs diff --git a/mcs/class/System/System.Net.Sockets/Socket.cs b/mcs/class/System/System.Net.Sockets/Socket.cs index 0f1c977a71d..b41d204ace8 100644 --- a/mcs/class/System/System.Net.Sockets/Socket.cs +++ b/mcs/class/System/System.Net.Sockets/Socket.cs @@ -991,16 +991,20 @@ namespace System.Net.Sockets static IOAsyncCallback BeginAcceptCallback = new IOAsyncCallback (ares => { SocketAsyncResult sockares = (SocketAsyncResult) ares; - Socket socket = null; - + Socket acc_socket = null; try { - socket = sockares.socket.Accept (); + if (sockares.AcceptSocket == null) { + acc_socket = sockares.socket.Accept (); + } else { + acc_socket = sockares.AcceptSocket; + sockares.socket.Accept (acc_socket); + } + } catch (Exception e) { sockares.Complete (e); return; } - - sockares.Complete (socket); + sockares.Complete (acc_socket); }); public IAsyncResult BeginAccept (int receiveSize, AsyncCallback callback, object state) @@ -3427,7 +3431,9 @@ namespace System.Net.Sockets void InitSocketAsyncEventArgs (SocketAsyncEventArgs e, AsyncCallback callback, object state, SocketOperation operation) { e.socket_async_result.Init (this, callback, state, operation); - + if (e.AcceptSocket != null) { + e.socket_async_result.AcceptSocket = e.AcceptSocket; + } e.current_socket = this; e.SetLastOperation (SocketOperationToSocketAsyncOperation (operation)); e.SocketError = SocketError.Success; diff --git a/mcs/class/System/System_test.dll.sources b/mcs/class/System/System_test.dll.sources index 7d07dd75041..6ec32758faf 100644 --- a/mcs/class/System/System_test.dll.sources +++ b/mcs/class/System/System_test.dll.sources @@ -247,6 +247,7 @@ System.Net.Sockets/MulticastOptionTest.cs System.Net.Sockets/NetworkStreamTest.cs System.Net.Sockets/TcpClientTest.cs System.Net.Sockets/TcpListenerTest.cs +System.Net.Sockets/SocketAcceptAsyncTest.cs System.Net.Sockets/SocketTest.cs System.Net.Sockets/SocketAsyncEventArgsTest.cs System.Net.Sockets/SocketConnectAsyncTest.cs diff --git a/mcs/class/System/Test/System.Net.Sockets/SocketAcceptAsyncTest.cs b/mcs/class/System/Test/System.Net.Sockets/SocketAcceptAsyncTest.cs new file mode 100644 index 00000000000..6866ff58513 --- /dev/null +++ b/mcs/class/System/Test/System.Net.Sockets/SocketAcceptAsyncTest.cs @@ -0,0 +1,82 @@ +using System; +using System.Threading; +using System.Net; +using System.Net.Sockets; +using NUnit.Framework; + +namespace MonoTests.System.Net.Sockets +{ + [TestFixture] + public class SocketAcceptAsyncTest + { + private Socket _listenSocket; + private Socket _clientSocket; + private Socket _serverSocket; + private Socket _acceptedSocket; + private ManualResetEvent _readyEvent; + private ManualResetEvent _mainEvent; + + [TestFixtureSetUp] + public void SetUp() + { + _readyEvent = new ManualResetEvent(false); + _mainEvent = new ManualResetEvent(false); + + ThreadPool.QueueUserWorkItem(_ => StartListen()); + if (!_readyEvent.WaitOne(1500)) + throw new TimeoutException(); + + _clientSocket = new Socket( + AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _clientSocket.Connect(_listenSocket.LocalEndPoint); + _clientSocket.NoDelay = true; + } + + [TestFixtureTearDown] + public void TearDown() + { + if (_acceptedSocket != null) + _acceptedSocket.Close(); + if (_listenSocket != null) + _listenSocket.Close(); + _readyEvent.Close(); + _mainEvent.Close(); + } + + private void StartListen() + { + _listenSocket = new Socket( + AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + _listenSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + _listenSocket.Listen(1); + + _serverSocket = new Socket( + AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + + var async = new SocketAsyncEventArgs(); + async.AcceptSocket = _serverSocket; + async.Completed += (s, e) => OnAccepted(e); + + _readyEvent.Set(); + + if (!_listenSocket.AcceptAsync(async)) + OnAccepted(async); + } + + private void OnAccepted(SocketAsyncEventArgs e) + { + _acceptedSocket = e.AcceptSocket; + _mainEvent.Set(); + } + + [Test] + [Category("Test")] + public void AcceptAsyncShouldUseAcceptSocketFromEventArgs() + { + if (!_mainEvent.WaitOne(1500)) + throw new TimeoutException(); + Assert.AreEqual(_serverSocket, _acceptedSocket); + _mainEvent.Reset(); + } + } +} -- 2.25.1