[Socket] Fix Socket.BeginConnect to multiple IPAddress (#4960)
[mono.git] / mcs / class / System / System.Net.Sockets / Socket.cs
index 5f6a4c312df305374798bcc87afa47e5add25966..d19756b6c30013be319b87d94a0747081fc47394 100644 (file)
@@ -458,7 +458,7 @@ namespace System.Net.Sockets
                        if (list != null) {
                                foreach (Socket sock in list) {
                                        if (sock == null) // MS throws a NullRef
-                                               throw new ArgumentNullException ("name", "Contains a null element");
+                                               throw new ArgumentNullException (name, "Contains a null element");
                                        sockets.Add (sock);
                                }
                        }
@@ -709,11 +709,11 @@ namespace System.Net.Sockets
                        sockares.Complete (acc_socket, total);
                });
 
-               public Socket EndAccept (IAsyncResult result)
+               public Socket EndAccept (IAsyncResult asyncResult)
                {
                        int bytes;
                        byte[] buffer;
-                       return EndAccept (out buffer, out bytes, result);
+                       return EndAccept (out buffer, out bytes, asyncResult);
                }
 
                public Socket EndAccept (out byte[] buffer, out int bytesTransferred, IAsyncResult asyncResult)
@@ -901,12 +901,11 @@ namespace System.Net.Sockets
                                SocketAsyncResult ares;
 
                                if (!GetCheckedIPs (e, out addresses)) {
-                                       e.socket_async_result.EndPoint = e.RemoteEndPoint;
+                                       //NOTE: DualMode may cause Socket's RemoteEndpoint to differ in AddressFamily from the
+                                       // SocketAsyncEventArgs, but the SocketAsyncEventArgs itself is not changed
                                        ares = (SocketAsyncResult) BeginConnect (e.RemoteEndPoint, ConnectAsyncCallback, e);
                                } else {
-                                       DnsEndPoint dep = (e.RemoteEndPoint as DnsEndPoint);
-                                       e.socket_async_result.Addresses = addresses;
-                                       e.socket_async_result.Port = dep.Port;
+                                       DnsEndPoint dep = (DnsEndPoint)e.RemoteEndPoint;
                                        ares = (SocketAsyncResult) BeginConnect (addresses, dep.Port, ConnectAsyncCallback, e);
                                }
 
@@ -948,7 +947,7 @@ namespace System.Net.Sockets
                        }
                });
 
-               public IAsyncResult BeginConnect (string host, int port, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (string host, int port, AsyncCallback requestCallback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
@@ -961,80 +960,27 @@ namespace System.Net.Sockets
                        if (is_listening)
                                throw new InvalidOperationException ();
 
-                       return BeginConnect (Dns.GetHostAddresses (host), port, callback, state);
+                       return BeginConnect (Dns.GetHostAddresses (host), port, requestCallback, state);
                }
 
-               public IAsyncResult BeginConnect (EndPoint end_point, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (end_point == null)
-                               throw new ArgumentNullException ("end_point");
+                       if (remoteEP == null)
+                               throw new ArgumentNullException ("remoteEP");
                        if (is_listening)
                                throw new InvalidOperationException ();
 
                        SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.Connect) {
-                               EndPoint = end_point,
+                               EndPoint = remoteEP,
                        };
 
-                       // Bug #75154: Connect() should not succeed for .Any addresses.
-                       if (end_point is IPEndPoint) {
-                               IPEndPoint ep = (IPEndPoint) end_point;
-                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
-                                       sockares.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
-                                       return sockares;
-                               }
-                               
-                               end_point = RemapIPEndPoint (ep);
-                       }
-
-                       int error = 0;
-
-                       if (connect_in_progress) {
-                               // This could happen when multiple IPs are used
-                               // Calling connect() again will reset the connection attempt and cause
-                               // an error. Better to just close the socket and move on.
-                               connect_in_progress = false;
-                               m_Handle.Dispose ();
-                               m_Handle = new SafeSocketHandle (Socket_internal (addressFamily, socketType, protocolType, out error), true);
-                               if (error != 0)
-                                       throw new SocketException (error);
-                       }
-
-                       bool blk = is_blocking;
-                       if (blk)
-                               Blocking = false;
-                       Connect_internal (m_Handle, end_point.Serialize (), out error, false);
-                       if (blk)
-                               Blocking = true;
-
-                       if (error == 0) {
-                               // succeeded synch
-                               is_connected = true;
-                               is_bound = true;
-                               sockares.Complete (true);
-                               return sockares;
-                       }
-
-                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
-                               // error synch
-                               is_connected = false;
-                               is_bound = false;
-                               sockares.Complete (new SocketException (error), true);
-                               return sockares;
-                       }
-
-                       // continue asynch
-                       is_connected = false;
-                       is_bound = false;
-                       connect_in_progress = true;
-
-                       IOSelector.Add (sockares.Handle, new IOSelectorJob (IOOperation.Write, BeginConnectCallback, sockares));
-
+                       BeginSConnect (sockares);
                        return sockares;
                }
 
-               public IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback requestCallback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
@@ -1049,46 +995,92 @@ namespace System.Net.Sockets
                        if (is_listening)
                                throw new InvalidOperationException ();
 
-                       SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.Connect) {
+                       SocketAsyncResult sockares = new SocketAsyncResult (this, requestCallback, state, SocketOperation.Connect) {
                                Addresses = addresses,
                                Port = port,
                        };
 
                        is_connected = false;
 
-                       return BeginMConnect (sockares);
+                       BeginMConnect (sockares);
+                       return sockares;
                }
 
-               internal IAsyncResult BeginMConnect (SocketAsyncResult sockares)
+               static void BeginMConnect (SocketAsyncResult sockares)
                {
-                       SocketAsyncResult ares = null;
                        Exception exc = null;
-                       AsyncCallback callback;
 
                        for (int i = sockares.CurrentAddress; i < sockares.Addresses.Length; i++) {
                                try {
                                        sockares.CurrentAddress++;
+                                       sockares.EndPoint = new IPEndPoint (sockares.Addresses [i], sockares.Port);
 
-                                       ares = (SocketAsyncResult) BeginConnect (new IPEndPoint (sockares.Addresses [i], sockares.Port), null, sockares);
-                                       if (ares.IsCompleted && ares.CompletedSynchronously) {
-                                               ares.CheckIfThrowDelayedException ();
-
-                                               callback = ares.AsyncCallback;
-                                               if (callback != null)
-                                                       ThreadPool.UnsafeQueueUserWorkItem (_ => callback (ares), null);
-                                       }
-
-                                       break;
+                                       BeginSConnect (sockares);
+                                       return;
                                } catch (Exception e) {
                                        exc = e;
-                                       ares = null;
                                }
                        }
 
-                       if (ares == null)
-                               throw exc;
+                       throw exc;
+               }
 
-                       return sockares;
+               static void BeginSConnect (SocketAsyncResult sockares)
+               {
+                       EndPoint remoteEP = sockares.EndPoint;
+                       // Bug #75154: Connect() should not succeed for .Any addresses.
+                       if (remoteEP is IPEndPoint) {
+                               IPEndPoint ep = (IPEndPoint) remoteEP;
+                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
+                                       sockares.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
+                                       return;
+                               }
+
+                               sockares.EndPoint = remoteEP = sockares.socket.RemapIPEndPoint (ep);
+                       }
+
+                       int error = 0;
+
+                       if (sockares.socket.connect_in_progress) {
+                               // This could happen when multiple IPs are used
+                               // Calling connect() again will reset the connection attempt and cause
+                               // an error. Better to just close the socket and move on.
+                               sockares.socket.connect_in_progress = false;
+                               sockares.socket.m_Handle.Dispose ();
+                               sockares.socket.m_Handle = new SafeSocketHandle (sockares.socket.Socket_internal (sockares.socket.addressFamily, sockares.socket.socketType, sockares.socket.protocolType, out error), true);
+                               if (error != 0)
+                                       throw new SocketException (error);
+                       }
+
+                       bool blk = sockares.socket.is_blocking;
+                       if (blk)
+                               sockares.socket.Blocking = false;
+                       Connect_internal (sockares.socket.m_Handle, remoteEP.Serialize (), out error, false);
+                       if (blk)
+                               sockares.socket.Blocking = true;
+
+                       if (error == 0) {
+                               // succeeded synch
+                               sockares.socket.is_connected = true;
+                               sockares.socket.is_bound = true;
+                               sockares.Complete (true);
+                               return;
+                       }
+
+                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
+                               // error synch
+                               sockares.socket.is_connected = false;
+                               sockares.socket.is_bound = false;
+                               sockares.Complete (new SocketException (error), true);
+                               return;
+                       }
+
+                       // continue asynch
+                       sockares.socket.is_connected = false;
+                       sockares.socket.is_bound = false;
+                       sockares.socket.connect_in_progress = true;
+
+                       IOSelector.Add (sockares.Handle, new IOSelectorJob (IOOperation.Write, BeginConnectCallback, sockares));
                }
 
                static IOAsyncCallback BeginConnectCallback = new IOAsyncCallback (ares => {
@@ -1099,18 +1091,11 @@ namespace System.Net.Sockets
                                return;
                        }
 
-                       SocketAsyncResult mconnect = sockares.AsyncState as SocketAsyncResult;
-                       bool is_mconnect = mconnect != null && mconnect.Addresses != null;
-
                        try {
-                               EndPoint ep = sockares.EndPoint;
-                               int error_code = (int) sockares.socket.GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
-
-                               if (error_code == 0) {
-                                       if (is_mconnect)
-                                               sockares = mconnect;
+                               int error = (int) sockares.socket.GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
 
-                                       sockares.socket.seed_endpoint = ep;
+                               if (error == 0) {
+                                       sockares.socket.seed_endpoint = sockares.EndPoint;
                                        sockares.socket.is_connected = true;
                                        sockares.socket.is_bound = true;
                                        sockares.socket.connect_in_progress = false;
@@ -1119,34 +1104,29 @@ namespace System.Net.Sockets
                                        return;
                                }
 
-                               if (!is_mconnect) {
+                               if (sockares.Addresses == null) {
                                        sockares.socket.connect_in_progress = false;
-                                       sockares.Complete (new SocketException (error_code));
+                                       sockares.Complete (new SocketException (error));
                                        return;
                                }
 
-                               if (mconnect.CurrentAddress >= mconnect.Addresses.Length) {
-                                       mconnect.Complete (new SocketException (error_code));
+                               if (sockares.CurrentAddress >= sockares.Addresses.Length) {
+                                       sockares.Complete (new SocketException (error));
                                        return;
                                }
 
-                               mconnect.socket.BeginMConnect (mconnect);
+                               BeginMConnect (sockares);
                        } catch (Exception e) {
                                sockares.socket.connect_in_progress = false;
-
-                               if (is_mconnect)
-                                       sockares = mconnect;
-
                                sockares.Complete (e);
-                               return;
                        }
                });
 
-               public void EndConnect (IAsyncResult result)
+               public void EndConnect (IAsyncResult asyncResult)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndConnect", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndConnect", "asyncResult");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
@@ -1629,21 +1609,21 @@ namespace System.Net.Sockets
                        }
                });
 
-               public IAsyncResult BeginReceiveFrom (byte[] buffer, int offset, int size, SocketFlags socket_flags, ref EndPoint remote_end, AsyncCallback callback, object state)
+               public IAsyncResult BeginReceiveFrom (byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
                        ThrowIfBufferNull (buffer);
                        ThrowIfBufferOutOfRange (buffer, offset, size);
 
-                       if (remote_end == null)
-                               throw new ArgumentNullException ("remote_end");
+                       if (remoteEP == null)
+                               throw new ArgumentNullException ("remoteEP");
 
                        SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.ReceiveFrom) {
                                Buffer = buffer,
                                Offset = offset,
                                Size = size,
-                               SockFlags = socket_flags,
-                               EndPoint = remote_end,
+                               SockFlags = socketFlags,
+                               EndPoint = remoteEP,
                        };
 
                        QueueIOSelectorJob (ReadSem, sockares.Handle, new IOSelectorJob (IOOperation.Read, BeginReceiveFromCallback, sockares));
@@ -1671,21 +1651,21 @@ namespace System.Net.Sockets
                        sockares.Complete (total);
                });
 
-               public int EndReceiveFrom(IAsyncResult result, ref EndPoint end_point)
+               public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (end_point == null)
-                               throw new ArgumentNullException ("remote_end");
+                       if (endPoint == null)
+                               throw new ArgumentNullException ("endPoint");
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndReceiveFrom", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndReceiveFrom", "asyncResult");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
 
                        sockares.CheckIfThrowDelayedException();
 
-                       end_point = sockares.EndPoint;
+                       endPoint = sockares.EndPoint;
 
                        return sockares.Total;
                }
@@ -2095,7 +2075,7 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                        }
                });
 
-               public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socket_flags, EndPoint remote_end, AsyncCallback callback, object state)
+               public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
                        ThrowIfBufferNull (buffer);
@@ -2105,8 +2085,8 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                                Buffer = buffer,
                                Offset = offset,
                                Size = size,
-                               SockFlags = socket_flags,
-                               EndPoint = remote_end,
+                               SockFlags = socketFlags,
+                               EndPoint = remoteEP,
                        };
 
                        QueueIOSelectorJob (WriteSem, sockares.Handle, new IOSelectorJob (IOOperation.Write, s => BeginSendToCallback ((SocketAsyncResult) s, 0), sockares));
@@ -2140,11 +2120,11 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                        sockares.Complete ();
                }
 
-               public int EndSendTo (IAsyncResult result)
+               public int EndSendTo (IAsyncResult asyncResult)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndSendTo", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndSendTo", "result");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
@@ -2681,14 +2661,25 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
 
                void QueueIOSelectorJob (SemaphoreSlim sem, IntPtr handle, IOSelectorJob job)
                {
-                       sem.WaitAsync ().ContinueWith (t => {
+                       var task = sem.WaitAsync();
+                       // fast path without Task<Action> allocation.
+                       if (task.IsCompleted) {
                                if (CleanedUp) {
                                        job.MarkDisposed ();
                                        return;
                                }
-
                                IOSelector.Add (handle, job);
-                       });
+                       }
+                       else
+                       {
+                               task.ContinueWith( t => {
+                                       if (CleanedUp) {
+                                               job.MarkDisposed ();
+                                               return;
+                                       }
+                                       IOSelector.Add(handle, job);
+                               });
+                       }
                }
 
                void InitSocketAsyncEventArgs (SocketAsyncEventArgs e, AsyncCallback callback, object state, SocketOperation operation)