[Socket] Fix Socket.BeginConnect to multiple IPAddress (#4960)
[mono.git] / mcs / class / System / System.Net.Sockets / Socket.cs
index 08c2630eaccea76530243ebdab360af8c5b468ce..d19756b6c30013be319b87d94a0747081fc47394 100644 (file)
@@ -458,7 +458,7 @@ namespace System.Net.Sockets
                        if (list != null) {
                                foreach (Socket sock in list) {
                                        if (sock == null) // MS throws a NullRef
-                                               throw new ArgumentNullException ("name", "Contains a null element");
+                                               throw new ArgumentNullException (name, "Contains a null element");
                                        sockets.Add (sock);
                                }
                        }
@@ -709,11 +709,11 @@ namespace System.Net.Sockets
                        sockares.Complete (acc_socket, total);
                });
 
-               public Socket EndAccept (IAsyncResult result)
+               public Socket EndAccept (IAsyncResult asyncResult)
                {
                        int bytes;
                        byte[] buffer;
-                       return EndAccept (out buffer, out bytes, result);
+                       return EndAccept (out buffer, out bytes, asyncResult);
                }
 
                public Socket EndAccept (out byte[] buffer, out int bytesTransferred, IAsyncResult asyncResult)
@@ -901,12 +901,11 @@ namespace System.Net.Sockets
                                SocketAsyncResult ares;
 
                                if (!GetCheckedIPs (e, out addresses)) {
-                                       e.socket_async_result.EndPoint = e.RemoteEndPoint;
+                                       //NOTE: DualMode may cause Socket's RemoteEndpoint to differ in AddressFamily from the
+                                       // SocketAsyncEventArgs, but the SocketAsyncEventArgs itself is not changed
                                        ares = (SocketAsyncResult) BeginConnect (e.RemoteEndPoint, ConnectAsyncCallback, e);
                                } else {
-                                       DnsEndPoint dep = (e.RemoteEndPoint as DnsEndPoint);
-                                       e.socket_async_result.Addresses = addresses;
-                                       e.socket_async_result.Port = dep.Port;
+                                       DnsEndPoint dep = (DnsEndPoint)e.RemoteEndPoint;
                                        ares = (SocketAsyncResult) BeginConnect (addresses, dep.Port, ConnectAsyncCallback, e);
                                }
 
@@ -948,7 +947,7 @@ namespace System.Net.Sockets
                        }
                });
 
-               public IAsyncResult BeginConnect (string host, int port, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (string host, int port, AsyncCallback requestCallback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
@@ -961,80 +960,27 @@ namespace System.Net.Sockets
                        if (is_listening)
                                throw new InvalidOperationException ();
 
-                       return BeginConnect (Dns.GetHostAddresses (host), port, callback, state);
+                       return BeginConnect (Dns.GetHostAddresses (host), port, requestCallback, state);
                }
 
-               public IAsyncResult BeginConnect (EndPoint end_point, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (end_point == null)
-                               throw new ArgumentNullException ("end_point");
+                       if (remoteEP == null)
+                               throw new ArgumentNullException ("remoteEP");
                        if (is_listening)
                                throw new InvalidOperationException ();
 
                        SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.Connect) {
-                               EndPoint = end_point,
+                               EndPoint = remoteEP,
                        };
 
-                       // Bug #75154: Connect() should not succeed for .Any addresses.
-                       if (end_point is IPEndPoint) {
-                               IPEndPoint ep = (IPEndPoint) end_point;
-                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
-                                       sockares.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
-                                       return sockares;
-                               }
-                               
-                               end_point = RemapIPEndPoint (ep);
-                       }
-
-                       int error = 0;
-
-                       if (connect_in_progress) {
-                               // This could happen when multiple IPs are used
-                               // Calling connect() again will reset the connection attempt and cause
-                               // an error. Better to just close the socket and move on.
-                               connect_in_progress = false;
-                               m_Handle.Dispose ();
-                               m_Handle = new SafeSocketHandle (Socket_internal (addressFamily, socketType, protocolType, out error), true);
-                               if (error != 0)
-                                       throw new SocketException (error);
-                       }
-
-                       bool blk = is_blocking;
-                       if (blk)
-                               Blocking = false;
-                       Connect_internal (m_Handle, end_point.Serialize (), out error, false);
-                       if (blk)
-                               Blocking = true;
-
-                       if (error == 0) {
-                               // succeeded synch
-                               is_connected = true;
-                               is_bound = true;
-                               sockares.Complete (true);
-                               return sockares;
-                       }
-
-                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
-                               // error synch
-                               is_connected = false;
-                               is_bound = false;
-                               sockares.Complete (new SocketException (error), true);
-                               return sockares;
-                       }
-
-                       // continue asynch
-                       is_connected = false;
-                       is_bound = false;
-                       connect_in_progress = true;
-
-                       IOSelector.Add (sockares.Handle, new IOSelectorJob (IOOperation.Write, BeginConnectCallback, sockares));
-
+                       BeginSConnect (sockares);
                        return sockares;
                }
 
-               public IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback requestCallback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
@@ -1049,46 +995,92 @@ namespace System.Net.Sockets
                        if (is_listening)
                                throw new InvalidOperationException ();
 
-                       SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.Connect) {
+                       SocketAsyncResult sockares = new SocketAsyncResult (this, requestCallback, state, SocketOperation.Connect) {
                                Addresses = addresses,
                                Port = port,
                        };
 
                        is_connected = false;
 
-                       return BeginMConnect (sockares);
+                       BeginMConnect (sockares);
+                       return sockares;
                }
 
-               internal IAsyncResult BeginMConnect (SocketAsyncResult sockares)
+               static void BeginMConnect (SocketAsyncResult sockares)
                {
-                       SocketAsyncResult ares = null;
                        Exception exc = null;
-                       AsyncCallback callback;
 
                        for (int i = sockares.CurrentAddress; i < sockares.Addresses.Length; i++) {
                                try {
                                        sockares.CurrentAddress++;
+                                       sockares.EndPoint = new IPEndPoint (sockares.Addresses [i], sockares.Port);
 
-                                       ares = (SocketAsyncResult) BeginConnect (new IPEndPoint (sockares.Addresses [i], sockares.Port), null, sockares);
-                                       if (ares.IsCompleted && ares.CompletedSynchronously) {
-                                               ares.CheckIfThrowDelayedException ();
-
-                                               callback = ares.AsyncCallback;
-                                               if (callback != null)
-                                                       ThreadPool.UnsafeQueueUserWorkItem (_ => callback (ares), null);
-                                       }
-
-                                       break;
+                                       BeginSConnect (sockares);
+                                       return;
                                } catch (Exception e) {
                                        exc = e;
-                                       ares = null;
                                }
                        }
 
-                       if (ares == null)
-                               throw exc;
+                       throw exc;
+               }
 
-                       return sockares;
+               static void BeginSConnect (SocketAsyncResult sockares)
+               {
+                       EndPoint remoteEP = sockares.EndPoint;
+                       // Bug #75154: Connect() should not succeed for .Any addresses.
+                       if (remoteEP is IPEndPoint) {
+                               IPEndPoint ep = (IPEndPoint) remoteEP;
+                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
+                                       sockares.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
+                                       return;
+                               }
+
+                               sockares.EndPoint = remoteEP = sockares.socket.RemapIPEndPoint (ep);
+                       }
+
+                       int error = 0;
+
+                       if (sockares.socket.connect_in_progress) {
+                               // This could happen when multiple IPs are used
+                               // Calling connect() again will reset the connection attempt and cause
+                               // an error. Better to just close the socket and move on.
+                               sockares.socket.connect_in_progress = false;
+                               sockares.socket.m_Handle.Dispose ();
+                               sockares.socket.m_Handle = new SafeSocketHandle (sockares.socket.Socket_internal (sockares.socket.addressFamily, sockares.socket.socketType, sockares.socket.protocolType, out error), true);
+                               if (error != 0)
+                                       throw new SocketException (error);
+                       }
+
+                       bool blk = sockares.socket.is_blocking;
+                       if (blk)
+                               sockares.socket.Blocking = false;
+                       Connect_internal (sockares.socket.m_Handle, remoteEP.Serialize (), out error, false);
+                       if (blk)
+                               sockares.socket.Blocking = true;
+
+                       if (error == 0) {
+                               // succeeded synch
+                               sockares.socket.is_connected = true;
+                               sockares.socket.is_bound = true;
+                               sockares.Complete (true);
+                               return;
+                       }
+
+                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
+                               // error synch
+                               sockares.socket.is_connected = false;
+                               sockares.socket.is_bound = false;
+                               sockares.Complete (new SocketException (error), true);
+                               return;
+                       }
+
+                       // continue asynch
+                       sockares.socket.is_connected = false;
+                       sockares.socket.is_bound = false;
+                       sockares.socket.connect_in_progress = true;
+
+                       IOSelector.Add (sockares.Handle, new IOSelectorJob (IOOperation.Write, BeginConnectCallback, sockares));
                }
 
                static IOAsyncCallback BeginConnectCallback = new IOAsyncCallback (ares => {
@@ -1099,18 +1091,11 @@ namespace System.Net.Sockets
                                return;
                        }
 
-                       SocketAsyncResult mconnect = sockares.AsyncState as SocketAsyncResult;
-                       bool is_mconnect = mconnect != null && mconnect.Addresses != null;
-
                        try {
-                               EndPoint ep = sockares.EndPoint;
-                               int error_code = (int) sockares.socket.GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
+                               int error = (int) sockares.socket.GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
 
-                               if (error_code == 0) {
-                                       if (is_mconnect)
-                                               sockares = mconnect;
-
-                                       sockares.socket.seed_endpoint = ep;
+                               if (error == 0) {
+                                       sockares.socket.seed_endpoint = sockares.EndPoint;
                                        sockares.socket.is_connected = true;
                                        sockares.socket.is_bound = true;
                                        sockares.socket.connect_in_progress = false;
@@ -1119,34 +1104,29 @@ namespace System.Net.Sockets
                                        return;
                                }
 
-                               if (!is_mconnect) {
+                               if (sockares.Addresses == null) {
                                        sockares.socket.connect_in_progress = false;
-                                       sockares.Complete (new SocketException (error_code));
+                                       sockares.Complete (new SocketException (error));
                                        return;
                                }
 
-                               if (mconnect.CurrentAddress >= mconnect.Addresses.Length) {
-                                       mconnect.Complete (new SocketException (error_code));
+                               if (sockares.CurrentAddress >= sockares.Addresses.Length) {
+                                       sockares.Complete (new SocketException (error));
                                        return;
                                }
 
-                               mconnect.socket.BeginMConnect (mconnect);
+                               BeginMConnect (sockares);
                        } catch (Exception e) {
                                sockares.socket.connect_in_progress = false;
-
-                               if (is_mconnect)
-                                       sockares = mconnect;
-
                                sockares.Complete (e);
-                               return;
                        }
                });
 
-               public void EndConnect (IAsyncResult result)
+               public void EndConnect (IAsyncResult asyncResult)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndConnect", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndConnect", "asyncResult");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
@@ -1629,21 +1609,21 @@ namespace System.Net.Sockets
                        }
                });
 
-               public IAsyncResult BeginReceiveFrom (byte[] buffer, int offset, int size, SocketFlags socket_flags, ref EndPoint remote_end, AsyncCallback callback, object state)
+               public IAsyncResult BeginReceiveFrom (byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
                        ThrowIfBufferNull (buffer);
                        ThrowIfBufferOutOfRange (buffer, offset, size);
 
-                       if (remote_end == null)
-                               throw new ArgumentNullException ("remote_end");
+                       if (remoteEP == null)
+                               throw new ArgumentNullException ("remoteEP");
 
                        SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.ReceiveFrom) {
                                Buffer = buffer,
                                Offset = offset,
                                Size = size,
-                               SockFlags = socket_flags,
-                               EndPoint = remote_end,
+                               SockFlags = socketFlags,
+                               EndPoint = remoteEP,
                        };
 
                        QueueIOSelectorJob (ReadSem, sockares.Handle, new IOSelectorJob (IOOperation.Read, BeginReceiveFromCallback, sockares));
@@ -1671,21 +1651,21 @@ namespace System.Net.Sockets
                        sockares.Complete (total);
                });
 
-               public int EndReceiveFrom(IAsyncResult result, ref EndPoint end_point)
+               public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (end_point == null)
-                               throw new ArgumentNullException ("remote_end");
+                       if (endPoint == null)
+                               throw new ArgumentNullException ("endPoint");
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndReceiveFrom", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndReceiveFrom", "asyncResult");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
 
                        sockares.CheckIfThrowDelayedException();
 
-                       end_point = sockares.EndPoint;
+                       endPoint = sockares.EndPoint;
 
                        return sockares.Total;
                }
@@ -2095,7 +2075,7 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                        }
                });
 
-               public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socket_flags, EndPoint remote_end, AsyncCallback callback, object state)
+               public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
                        ThrowIfBufferNull (buffer);
@@ -2105,8 +2085,8 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                                Buffer = buffer,
                                Offset = offset,
                                Size = size,
-                               SockFlags = socket_flags,
-                               EndPoint = remote_end,
+                               SockFlags = socketFlags,
+                               EndPoint = remoteEP,
                        };
 
                        QueueIOSelectorJob (WriteSem, sockares.Handle, new IOSelectorJob (IOOperation.Write, s => BeginSendToCallback ((SocketAsyncResult) s, 0), sockares));
@@ -2140,11 +2120,11 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                        sockares.Complete ();
                }
 
-               public int EndSendTo (IAsyncResult result)
+               public int EndSendTo (IAsyncResult asyncResult)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndSendTo", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndSendTo", "result");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
@@ -2295,10 +2275,15 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                                (is_blocking       ? 0 : SocketInformationOptions.NonBlocking) |
                                (useOverlappedIO ? SocketInformationOptions.UseOnlyOverlappedIO : 0);
 
-                       si.ProtocolInformation = Mono.DataConverter.Pack ("iiiil", (int)addressFamily, (int)socketType, (int)protocolType, is_bound ? 1 : 0, (long)Handle);
-                       m_Handle = null;
+                       MonoIOError error;
+                       IntPtr duplicateHandle;
+                       if (!MonoIO.DuplicateHandle (System.Diagnostics.Process.GetCurrentProcess ().Handle, Handle, new IntPtr (targetProcessId), out duplicateHandle, 0, 0, 0x00000002 /* DUPLICATE_SAME_ACCESS */, out error))
+                               throw MonoIO.GetException (error);
 
-                       return si;
+                       si.ProtocolInformation = Mono.DataConverter.Pack ("iiiil", (int)addressFamily, (int)socketType, (int)protocolType, is_bound ? 1 : 0, (long)duplicateHandle);
+                       m_Handle = null;
+                       return si;
                }
 
 #endregion
@@ -2453,9 +2438,6 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (optionLevel == SocketOptionLevel.Socket && optionName == SocketOptionName.ReuseAddress && optionValue != 0 && !SupportsPortReuse (protocolType))
-                               throw new SocketException ((int) SocketError.OperationNotSupported, "Operating system sockets do not support ReuseAddress.\nIf your socket is not intended to bind to the same address and port multiple times remove this option, otherwise you should ignore this exception inside a try catch and check that ReuseAddress is true before binding to the same address and port multiple times.");
-
                        int error;
                        SetSocketOption_internal (m_Handle, optionLevel, optionName, null, null, optionValue, out error);
 
@@ -2679,14 +2661,25 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
 
                void QueueIOSelectorJob (SemaphoreSlim sem, IntPtr handle, IOSelectorJob job)
                {
-                       sem.WaitAsync ().ContinueWith (t => {
+                       var task = sem.WaitAsync();
+                       // fast path without Task<Action> allocation.
+                       if (task.IsCompleted) {
                                if (CleanedUp) {
                                        job.MarkDisposed ();
                                        return;
                                }
-
                                IOSelector.Add (handle, job);
-                       });
+                       }
+                       else
+                       {
+                               task.ContinueWith( t => {
+                                       if (CleanedUp) {
+                                               job.MarkDisposed ();
+                                               return;
+                                       }
+                                       IOSelector.Add(handle, job);
+                               });
+                       }
                }
 
                void InitSocketAsyncEventArgs (SocketAsyncEventArgs e, AsyncCallback callback, object state, SocketOperation operation)