[Socket] Fix Socket.BeginConnect to multiple IPAddress (#4960)
[mono.git] / mcs / class / System / System.Net.Sockets / Socket.cs
index 9857fc27ed0b86ef0d14695a00a541a21784f888..d19756b6c30013be319b87d94a0747081fc47394 100644 (file)
@@ -458,7 +458,7 @@ namespace System.Net.Sockets
                        if (list != null) {
                                foreach (Socket sock in list) {
                                        if (sock == null) // MS throws a NullRef
-                                               throw new ArgumentNullException ("name", "Contains a null element");
+                                               throw new ArgumentNullException (name, "Contains a null element");
                                        sockets.Add (sock);
                                }
                        }
@@ -709,11 +709,11 @@ namespace System.Net.Sockets
                        sockares.Complete (acc_socket, total);
                });
 
-               public Socket EndAccept (IAsyncResult result)
+               public Socket EndAccept (IAsyncResult asyncResult)
                {
                        int bytes;
                        byte[] buffer;
-                       return EndAccept (out buffer, out bytes, result);
+                       return EndAccept (out buffer, out bytes, asyncResult);
                }
 
                public Socket EndAccept (out byte[] buffer, out int bytesTransferred, IAsyncResult asyncResult)
@@ -844,51 +844,6 @@ namespace System.Net.Sockets
                        Connect (Dns.GetHostAddresses (host), port);
                }
 
-               public void Connect (IPAddress[] addresses, int port)
-               {
-                       ThrowIfDisposedAndClosed ();
-
-                       if (addresses == null)
-                               throw new ArgumentNullException ("addresses");
-                       if (this.AddressFamily != AddressFamily.InterNetwork && this.AddressFamily != AddressFamily.InterNetworkV6)
-                               throw new NotSupportedException ("This method is only valid for addresses in the InterNetwork or InterNetworkV6 families");
-                       if (is_listening)
-                               throw new InvalidOperationException ();
-
-                       // FIXME: do non-blocking sockets Poll here?
-                       int error = 0;
-                       foreach (IPAddress address in addresses) {
-                               IPEndPoint iep = new IPEndPoint (address, port);
-                               
-                               iep = RemapIPEndPoint (iep);
-
-                               Connect_internal (m_Handle, iep.Serialize (), out error, is_blocking);
-                               if (error == 0) {
-                                       is_connected = true;
-                                       is_bound = true;
-                                       seed_endpoint = iep;
-                                       return;
-                               }
-                               if (error != (int)SocketError.InProgress && error != (int)SocketError.WouldBlock)
-                                       continue;
-
-                               if (!is_blocking) {
-                                       Poll (-1, SelectMode.SelectWrite);
-                                       error = (int)GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
-                                       if (error == 0) {
-                                               is_connected = true;
-                                               is_bound = true;
-                                               seed_endpoint = iep;
-                                               return;
-                                       }
-                               }
-                       }
-
-                       if (error != 0)
-                               throw new SocketException (error);
-               }
-
-
                public void Connect (EndPoint remoteEP)
                {
                        ThrowIfDisposedAndClosed ();
@@ -946,12 +901,11 @@ namespace System.Net.Sockets
                                SocketAsyncResult ares;
 
                                if (!GetCheckedIPs (e, out addresses)) {
-                                       e.socket_async_result.EndPoint = e.RemoteEndPoint;
+                                       //NOTE: DualMode may cause Socket's RemoteEndpoint to differ in AddressFamily from the
+                                       // SocketAsyncEventArgs, but the SocketAsyncEventArgs itself is not changed
                                        ares = (SocketAsyncResult) BeginConnect (e.RemoteEndPoint, ConnectAsyncCallback, e);
                                } else {
-                                       DnsEndPoint dep = (e.RemoteEndPoint as DnsEndPoint);
-                                       e.socket_async_result.Addresses = addresses;
-                                       e.socket_async_result.Port = dep.Port;
+                                       DnsEndPoint dep = (DnsEndPoint)e.RemoteEndPoint;
                                        ares = (SocketAsyncResult) BeginConnect (addresses, dep.Port, ConnectAsyncCallback, e);
                                }
 
@@ -967,12 +921,6 @@ namespace System.Net.Sockets
                        return true;
                }
 
-               public static bool ConnectAsync (SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e)
-               {
-                       var sock = new Socket (e.RemoteEndPoint.AddressFamily, socketType, protocolType);
-                       return sock.ConnectAsync (e);
-               }
-
                public static void CancelConnectAsync (SocketAsyncEventArgs e)
                {
                        if (e == null)
@@ -999,7 +947,7 @@ namespace System.Net.Sockets
                        }
                });
 
-               public IAsyncResult BeginConnect (string host, int port, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (string host, int port, AsyncCallback requestCallback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
@@ -1012,80 +960,27 @@ namespace System.Net.Sockets
                        if (is_listening)
                                throw new InvalidOperationException ();
 
-                       return BeginConnect (Dns.GetHostAddresses (host), port, callback, state);
+                       return BeginConnect (Dns.GetHostAddresses (host), port, requestCallback, state);
                }
 
-               public IAsyncResult BeginConnect (EndPoint end_point, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (end_point == null)
-                               throw new ArgumentNullException ("end_point");
+                       if (remoteEP == null)
+                               throw new ArgumentNullException ("remoteEP");
                        if (is_listening)
                                throw new InvalidOperationException ();
 
                        SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.Connect) {
-                               EndPoint = end_point,
+                               EndPoint = remoteEP,
                        };
 
-                       // Bug #75154: Connect() should not succeed for .Any addresses.
-                       if (end_point is IPEndPoint) {
-                               IPEndPoint ep = (IPEndPoint) end_point;
-                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
-                                       sockares.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
-                                       return sockares;
-                               }
-                               
-                               end_point = RemapIPEndPoint (ep);
-                       }
-
-                       int error = 0;
-
-                       if (connect_in_progress) {
-                               // This could happen when multiple IPs are used
-                               // Calling connect() again will reset the connection attempt and cause
-                               // an error. Better to just close the socket and move on.
-                               connect_in_progress = false;
-                               m_Handle.Dispose ();
-                               m_Handle = new SafeSocketHandle (Socket_internal (addressFamily, socketType, protocolType, out error), true);
-                               if (error != 0)
-                                       throw new SocketException (error);
-                       }
-
-                       bool blk = is_blocking;
-                       if (blk)
-                               Blocking = false;
-                       Connect_internal (m_Handle, end_point.Serialize (), out error, false);
-                       if (blk)
-                               Blocking = true;
-
-                       if (error == 0) {
-                               // succeeded synch
-                               is_connected = true;
-                               is_bound = true;
-                               sockares.Complete (true);
-                               return sockares;
-                       }
-
-                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
-                               // error synch
-                               is_connected = false;
-                               is_bound = false;
-                               sockares.Complete (new SocketException (error), true);
-                               return sockares;
-                       }
-
-                       // continue asynch
-                       is_connected = false;
-                       is_bound = false;
-                       connect_in_progress = true;
-
-                       IOSelector.Add (sockares.Handle, new IOSelectorJob (IOOperation.Write, BeginConnectCallback, sockares));
-
+                       BeginSConnect (sockares);
                        return sockares;
                }
 
-               public IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback callback, object state)
+               public IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback requestCallback, object state)
                {
                        ThrowIfDisposedAndClosed ();
 
@@ -1100,46 +995,92 @@ namespace System.Net.Sockets
                        if (is_listening)
                                throw new InvalidOperationException ();
 
-                       SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.Connect) {
+                       SocketAsyncResult sockares = new SocketAsyncResult (this, requestCallback, state, SocketOperation.Connect) {
                                Addresses = addresses,
                                Port = port,
                        };
 
                        is_connected = false;
 
-                       return BeginMConnect (sockares);
+                       BeginMConnect (sockares);
+                       return sockares;
                }
 
-               internal IAsyncResult BeginMConnect (SocketAsyncResult sockares)
+               static void BeginMConnect (SocketAsyncResult sockares)
                {
-                       SocketAsyncResult ares = null;
                        Exception exc = null;
-                       AsyncCallback callback;
 
                        for (int i = sockares.CurrentAddress; i < sockares.Addresses.Length; i++) {
                                try {
                                        sockares.CurrentAddress++;
+                                       sockares.EndPoint = new IPEndPoint (sockares.Addresses [i], sockares.Port);
 
-                                       ares = (SocketAsyncResult) BeginConnect (new IPEndPoint (sockares.Addresses [i], sockares.Port), null, sockares);
-                                       if (ares.IsCompleted && ares.CompletedSynchronously) {
-                                               ares.CheckIfThrowDelayedException ();
-
-                                               callback = ares.AsyncCallback;
-                                               if (callback != null)
-                                                       ThreadPool.UnsafeQueueUserWorkItem (_ => callback (ares), null);
-                                       }
-
-                                       break;
+                                       BeginSConnect (sockares);
+                                       return;
                                } catch (Exception e) {
                                        exc = e;
-                                       ares = null;
                                }
                        }
 
-                       if (ares == null)
-                               throw exc;
+                       throw exc;
+               }
 
-                       return sockares;
+               static void BeginSConnect (SocketAsyncResult sockares)
+               {
+                       EndPoint remoteEP = sockares.EndPoint;
+                       // Bug #75154: Connect() should not succeed for .Any addresses.
+                       if (remoteEP is IPEndPoint) {
+                               IPEndPoint ep = (IPEndPoint) remoteEP;
+                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
+                                       sockares.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
+                                       return;
+                               }
+
+                               sockares.EndPoint = remoteEP = sockares.socket.RemapIPEndPoint (ep);
+                       }
+
+                       int error = 0;
+
+                       if (sockares.socket.connect_in_progress) {
+                               // This could happen when multiple IPs are used
+                               // Calling connect() again will reset the connection attempt and cause
+                               // an error. Better to just close the socket and move on.
+                               sockares.socket.connect_in_progress = false;
+                               sockares.socket.m_Handle.Dispose ();
+                               sockares.socket.m_Handle = new SafeSocketHandle (sockares.socket.Socket_internal (sockares.socket.addressFamily, sockares.socket.socketType, sockares.socket.protocolType, out error), true);
+                               if (error != 0)
+                                       throw new SocketException (error);
+                       }
+
+                       bool blk = sockares.socket.is_blocking;
+                       if (blk)
+                               sockares.socket.Blocking = false;
+                       Connect_internal (sockares.socket.m_Handle, remoteEP.Serialize (), out error, false);
+                       if (blk)
+                               sockares.socket.Blocking = true;
+
+                       if (error == 0) {
+                               // succeeded synch
+                               sockares.socket.is_connected = true;
+                               sockares.socket.is_bound = true;
+                               sockares.Complete (true);
+                               return;
+                       }
+
+                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
+                               // error synch
+                               sockares.socket.is_connected = false;
+                               sockares.socket.is_bound = false;
+                               sockares.Complete (new SocketException (error), true);
+                               return;
+                       }
+
+                       // continue asynch
+                       sockares.socket.is_connected = false;
+                       sockares.socket.is_bound = false;
+                       sockares.socket.connect_in_progress = true;
+
+                       IOSelector.Add (sockares.Handle, new IOSelectorJob (IOOperation.Write, BeginConnectCallback, sockares));
                }
 
                static IOAsyncCallback BeginConnectCallback = new IOAsyncCallback (ares => {
@@ -1150,18 +1091,11 @@ namespace System.Net.Sockets
                                return;
                        }
 
-                       SocketAsyncResult mconnect = sockares.AsyncState as SocketAsyncResult;
-                       bool is_mconnect = mconnect != null && mconnect.Addresses != null;
-
                        try {
-                               EndPoint ep = sockares.EndPoint;
-                               int error_code = (int) sockares.socket.GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
+                               int error = (int) sockares.socket.GetSocketOption (SocketOptionLevel.Socket, SocketOptionName.Error);
 
-                               if (error_code == 0) {
-                                       if (is_mconnect)
-                                               sockares = mconnect;
-
-                                       sockares.socket.seed_endpoint = ep;
+                               if (error == 0) {
+                                       sockares.socket.seed_endpoint = sockares.EndPoint;
                                        sockares.socket.is_connected = true;
                                        sockares.socket.is_bound = true;
                                        sockares.socket.connect_in_progress = false;
@@ -1170,34 +1104,29 @@ namespace System.Net.Sockets
                                        return;
                                }
 
-                               if (!is_mconnect) {
+                               if (sockares.Addresses == null) {
                                        sockares.socket.connect_in_progress = false;
-                                       sockares.Complete (new SocketException (error_code));
+                                       sockares.Complete (new SocketException (error));
                                        return;
                                }
 
-                               if (mconnect.CurrentAddress >= mconnect.Addresses.Length) {
-                                       mconnect.Complete (new SocketException (error_code));
+                               if (sockares.CurrentAddress >= sockares.Addresses.Length) {
+                                       sockares.Complete (new SocketException (error));
                                        return;
                                }
 
-                               mconnect.socket.BeginMConnect (mconnect);
+                               BeginMConnect (sockares);
                        } catch (Exception e) {
                                sockares.socket.connect_in_progress = false;
-
-                               if (is_mconnect)
-                                       sockares = mconnect;
-
                                sockares.Complete (e);
-                               return;
                        }
                });
 
-               public void EndConnect (IAsyncResult result)
+               public void EndConnect (IAsyncResult asyncResult)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndConnect", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndConnect", "asyncResult");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
@@ -1680,21 +1609,21 @@ namespace System.Net.Sockets
                        }
                });
 
-               public IAsyncResult BeginReceiveFrom (byte[] buffer, int offset, int size, SocketFlags socket_flags, ref EndPoint remote_end, AsyncCallback callback, object state)
+               public IAsyncResult BeginReceiveFrom (byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
                        ThrowIfBufferNull (buffer);
                        ThrowIfBufferOutOfRange (buffer, offset, size);
 
-                       if (remote_end == null)
-                               throw new ArgumentNullException ("remote_end");
+                       if (remoteEP == null)
+                               throw new ArgumentNullException ("remoteEP");
 
                        SocketAsyncResult sockares = new SocketAsyncResult (this, callback, state, SocketOperation.ReceiveFrom) {
                                Buffer = buffer,
                                Offset = offset,
                                Size = size,
-                               SockFlags = socket_flags,
-                               EndPoint = remote_end,
+                               SockFlags = socketFlags,
+                               EndPoint = remoteEP,
                        };
 
                        QueueIOSelectorJob (ReadSem, sockares.Handle, new IOSelectorJob (IOOperation.Read, BeginReceiveFromCallback, sockares));
@@ -1722,21 +1651,21 @@ namespace System.Net.Sockets
                        sockares.Complete (total);
                });
 
-               public int EndReceiveFrom(IAsyncResult result, ref EndPoint end_point)
+               public int EndReceiveFrom(IAsyncResult asyncResult, ref EndPoint endPoint)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (end_point == null)
-                               throw new ArgumentNullException ("remote_end");
+                       if (endPoint == null)
+                               throw new ArgumentNullException ("endPoint");
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndReceiveFrom", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndReceiveFrom", "asyncResult");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
 
                        sockares.CheckIfThrowDelayedException();
 
-                       end_point = sockares.EndPoint;
+                       endPoint = sockares.EndPoint;
 
                        return sockares.Total;
                }
@@ -1805,7 +1734,7 @@ namespace System.Net.Sockets
                        if (endPoint == null)
                                throw new ArgumentNullException ("endPoint");
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndReceiveMessageFrom", "asyncResult");
+                       /*SocketAsyncResult sockares =*/ ValidateEndIAsyncResult (asyncResult, "EndReceiveMessageFrom", "asyncResult");
 
                        throw new NotImplementedException ();
                }
@@ -2146,7 +2075,7 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                        }
                });
 
-               public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socket_flags, EndPoint remote_end, AsyncCallback callback, object state)
+               public IAsyncResult BeginSendTo(byte[] buffer, int offset, int size, SocketFlags socketFlags, EndPoint remoteEP, AsyncCallback callback, object state)
                {
                        ThrowIfDisposedAndClosed ();
                        ThrowIfBufferNull (buffer);
@@ -2156,8 +2085,8 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                                Buffer = buffer,
                                Offset = offset,
                                Size = size,
-                               SockFlags = socket_flags,
-                               EndPoint = remote_end,
+                               SockFlags = socketFlags,
+                               EndPoint = remoteEP,
                        };
 
                        QueueIOSelectorJob (WriteSem, sockares.Handle, new IOSelectorJob (IOOperation.Write, s => BeginSendToCallback ((SocketAsyncResult) s, 0), sockares));
@@ -2191,11 +2120,11 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                        sockares.Complete ();
                }
 
-               public int EndSendTo (IAsyncResult result)
+               public int EndSendTo (IAsyncResult asyncResult)
                {
                        ThrowIfDisposedAndClosed ();
 
-                       SocketAsyncResult sockares = ValidateEndIAsyncResult (result, "EndSendTo", "result");
+                       SocketAsyncResult sockares = ValidateEndIAsyncResult (asyncResult, "EndSendTo", "result");
 
                        if (!sockares.IsCompleted)
                                sockares.AsyncWaitHandle.WaitOne();
@@ -2346,10 +2275,15 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                                (is_blocking       ? 0 : SocketInformationOptions.NonBlocking) |
                                (useOverlappedIO ? SocketInformationOptions.UseOnlyOverlappedIO : 0);
 
-                       si.ProtocolInformation = Mono.DataConverter.Pack ("iiiil", (int)addressFamily, (int)socketType, (int)protocolType, is_bound ? 1 : 0, (long)Handle);
-                       m_Handle = null;
+                       MonoIOError error;
+                       IntPtr duplicateHandle;
+                       if (!MonoIO.DuplicateHandle (System.Diagnostics.Process.GetCurrentProcess ().Handle, Handle, new IntPtr (targetProcessId), out duplicateHandle, 0, 0, 0x00000002 /* DUPLICATE_SAME_ACCESS */, out error))
+                               throw MonoIO.GetException (error);
 
-                       return si;
+                       si.ProtocolInformation = Mono.DataConverter.Pack ("iiiil", (int)addressFamily, (int)socketType, (int)protocolType, is_bound ? 1 : 0, (long)duplicateHandle);
+                       m_Handle = null;
+                       return si;
                }
 
 #endregion
@@ -2504,9 +2438,6 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
                {
                        ThrowIfDisposedAndClosed ();
 
-                       if (optionLevel == SocketOptionLevel.Socket && optionName == SocketOptionName.ReuseAddress && optionValue != 0 && !SupportsPortReuse (protocolType))
-                               throw new SocketException ((int) SocketError.OperationNotSupported, "Operating system sockets do not support ReuseAddress.\nIf your socket is not intended to bind to the same address and port multiple times remove this option, otherwise you should ignore this exception inside a try catch and check that ReuseAddress is true before binding to the same address and port multiple times.");
-
                        int error;
                        SetSocketOption_internal (m_Handle, optionLevel, optionName, null, null, optionValue, out error);
 
@@ -2730,14 +2661,25 @@ m_Handle, buffer, offset + sent, size - sent, socketFlags, out nativeError, is_b
 
                void QueueIOSelectorJob (SemaphoreSlim sem, IntPtr handle, IOSelectorJob job)
                {
-                       sem.WaitAsync ().ContinueWith (t => {
+                       var task = sem.WaitAsync();
+                       // fast path without Task<Action> allocation.
+                       if (task.IsCompleted) {
                                if (CleanedUp) {
                                        job.MarkDisposed ();
                                        return;
                                }
-
                                IOSelector.Add (handle, job);
-                       });
+                       }
+                       else
+                       {
+                               task.ContinueWith( t => {
+                                       if (CleanedUp) {
+                                               job.MarkDisposed ();
+                                               return;
+                                       }
+                                       IOSelector.Add(handle, job);
+                               });
+                       }
                }
 
                void InitSocketAsyncEventArgs (SocketAsyncEventArgs e, AsyncCallback callback, object state, SocketOperation operation)