[Socket] Improved ConnectAsync
authorGonzalo Paniagua Javier <gonzalo.mono@gmail.com>
Tue, 26 Apr 2011 17:47:15 +0000 (13:47 -0400)
committerGonzalo Paniagua Javier <gonzalo.mono@gmail.com>
Tue, 26 Apr 2011 23:45:16 +0000 (19:45 -0400)
ConnectAsync() now uses the improved BeginConnect.

mcs/class/System/System.Net.Sockets/Socket.cs
mcs/class/System/System.Net.Sockets/SocketAsyncEventArgs.cs
mcs/class/System/System.Net.Sockets/Socket_2_1.cs

index 5844cb9374a7063881daf75373dc61cc981d9557..53ed16251f97221b2153653d6d9fd1917bf0dcc8 100644 (file)
@@ -9,7 +9,7 @@
 //
 // Copyright (C) 2001, 2002 Phillip Pearson and Ximian, Inc.
 //    http://www.myelin.co.nz
-// (c) 2004-2006 Novell, Inc. (http://www.novell.com)
+// (c) 2004-2011 Novell, Inc. (http://www.novell.com)
 //
 
 //
@@ -665,56 +665,6 @@ namespace System.Net.Sockets
                        return(req);
                }
 
-               public IAsyncResult BeginConnect(EndPoint end_point,
-                                                AsyncCallback callback,
-                                                object state) {
-
-                       if (disposed && closed)
-                               throw new ObjectDisposedException (GetType ().ToString ());
-
-                       if (end_point == null)
-                               throw new ArgumentNullException ("end_point");
-
-                       SocketAsyncResult req = new SocketAsyncResult (this, state, callback, SocketOperation.Connect);
-                       req.EndPoint = end_point;
-
-                       // Bug #75154: Connect() should not succeed for .Any addresses.
-                       if (end_point is IPEndPoint) {
-                               IPEndPoint ep = (IPEndPoint) end_point;
-                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
-                                       req.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
-                                       return req;
-                               }
-                       }
-
-                       int error = 0;
-                       bool blk = blocking;
-                       if (blk)
-                               Blocking = false;
-                       SocketAddress serial = end_point.Serialize ();
-                       Connect_internal (socket, serial, out error);
-                       if (blk)
-                               Blocking = true;
-                       if (error == 0) {
-                               // succeeded synch
-                               connected = true;
-                               req.Complete (true);
-                               return req;
-                       }
-
-                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
-                               // error synch
-                               connected = false;
-                               req.Complete (new SocketException (error), true);
-                               return req;
-                       }
-
-                       // continue asynch
-                       connected = false;
-                       socket_pool_queue (Worker.Dispatcher, req);
-                       return(req);
-               }
-
                public IAsyncResult BeginConnect (IPAddress address, int port,
                                                  AsyncCallback callback,
                                                  object state)
@@ -738,63 +688,6 @@ namespace System.Net.Sockets
                        return(BeginConnect (iep, callback, state));
                }
 
-               public IAsyncResult BeginConnect (IPAddress[] addresses,
-                                                 int port,
-                                                 AsyncCallback callback,
-                                                 object state)
-               {
-                       if (disposed && closed)
-                               throw new ObjectDisposedException (GetType ().ToString ());
-
-                       if (addresses == null)
-                               throw new ArgumentNullException ("addresses");
-
-                       if (addresses.Length == 0)
-                               throw new ArgumentException ("Empty addresses list");
-
-                       if (this.AddressFamily != AddressFamily.InterNetwork &&
-                               this.AddressFamily != AddressFamily.InterNetworkV6)
-                               throw new NotSupportedException ("This method is only valid for addresses in the InterNetwork or InterNetworkV6 families");
-
-                       if (port <= 0 || port > 65535)
-                               throw new ArgumentOutOfRangeException ("port", "Must be > 0 and < 65536");
-
-                       if (islistening)
-                               throw new InvalidOperationException ();
-
-                       SocketAsyncResult req = new SocketAsyncResult (this, state, callback, SocketOperation.Connect);
-                       req.Addresses = addresses;
-                       req.Port = port;
-                       connected = false;
-                       return BeginMConnect (req);
-               }
-
-               IAsyncResult BeginMConnect (SocketAsyncResult req)
-               {
-                       IAsyncResult ares = null;
-                       Exception exc = null;
-                       for (int i = req.CurrentAddress; i < req.Addresses.Length; i++) {
-                               IPAddress addr = req.Addresses [i];
-                               IPEndPoint ep = new IPEndPoint (addr, req.Port);
-                               try {
-                                       req.CurrentAddress++;
-                                       ares = BeginConnect (ep, null, req);
-                                       if (ares.IsCompleted && ares.CompletedSynchronously) {
-                                               ((SocketAsyncResult) ares).CheckIfThrowDelayedException ();
-                                               req.DoMConnectCallback ();
-                                       }
-                                       break;
-                               } catch (Exception e) {
-                                       exc = e;
-                               }
-                       }
-
-                       if (ares == null)
-                               throw exc;
-
-                       return req;
-               }
-
                public IAsyncResult BeginConnect (string host, int port,
                                                  AsyncCallback callback,
                                                  object state)
@@ -1215,27 +1108,6 @@ namespace System.Net.Sockets
                        seed_endpoint = local_end;
                }
 
-#if !MOONLIGHT
-               public bool ConnectAsync (SocketAsyncEventArgs e)
-               {
-                       // NO check is made whether e != null in MS.NET (NRE is thrown in such case)
-                       
-                       if (disposed && closed)
-                               throw new ObjectDisposedException (GetType ().ToString ());
-                       if (islistening)
-                               throw new InvalidOperationException ("You may not perform this operation after calling the Listen method.");
-                       if (e.RemoteEndPoint == null)
-                               throw new ArgumentNullException ("remoteEP", "Value cannot be null.");
-                       if (e.BufferList != null)
-                               throw new ArgumentException ("Multiple buffers cannot be used with this method.");
-
-                       e.DoOperation (SocketAsyncOperation.Connect, this);
-
-                       // We always return true for now
-                       return true;
-               }
-#endif
-               
                public void Connect (IPAddress address, int port)
                {
                        Connect (new IPEndPoint (address, port));
index 4ef2d3c508385cefa53e19f00da5073fb18edbb2..18db222523aac94e27455b0af862249ecebea54c 100644 (file)
@@ -32,9 +32,6 @@ using System.Collections.Generic;
 using System.Reflection;
 using System.Security;
 using System.Threading;
-#if MOONLIGHT && !INSIDE_SYSTEM
-using System.Net.Policy;
-#endif
 
 namespace System.Net.Sockets
 {
@@ -241,7 +238,7 @@ namespace System.Net.Sockets
                                args.DisconnectCallback (ares);
 #endif
                        else if (op == SocketAsyncOperation.Connect)
-                               args.ConnectCallback (); /* This should not be hit yet. See DoOperation() */
+                               args.ConnectCallback ();
                        /*
                        else if (op == Socket.SocketOperation.ReceiveMessageFrom)
                        else if (op == Socket.SocketOperation.SendPackets)
@@ -266,76 +263,13 @@ namespace System.Net.Sockets
 
                void ConnectCallback ()
                {
-                       SocketError error = SocketError.AccessDenied;
                        try {
-#if MOONLIGHT || NET_4_0
-                               // Connect to the first address that match the host name, like:
-                               // http://blogs.msdn.com/ncl/archive/2009/07/20/new-ncl-features-in-net-4-0-beta-2.aspx
-                               // while skipping entries that do not match the address family
-                               DnsEndPoint dep = (RemoteEndPoint as DnsEndPoint);
-                               if (dep != null) {
-                                       IPAddress[] addresses = Dns.GetHostAddresses (dep.Host);
-                                       IPEndPoint endpoint;
-#if MOONLIGHT && !INSIDE_SYSTEM
-                                       if (!PolicyRestricted && !SecurityManager.HasElevatedPermissions) {
-                                               List<IPAddress> valid = new List<IPAddress> ();
-                                               foreach (IPAddress a in addresses) {
-                                                       // if we're not downloading a socket policy then check the policy
-                                                       // and if we're not running with elevated permissions (SL4 OoB option)
-                                                       endpoint = new IPEndPoint (a, dep.Port);
-                                                       if (!CrossDomainPolicyManager.CheckEndPoint (endpoint, policy_protocol))
-                                                               continue;
-                                                       valid.Add (a);
-                                               }
-                                               addresses = valid.ToArray ();
-                                       }
-#endif
-                                       foreach (IPAddress addr in addresses) {
-                                               try {
-                                                       if (curSocket.AddressFamily == addr.AddressFamily) {
-                                                               endpoint = new IPEndPoint (addr, dep.Port);
-                                                               error = TryConnect (endpoint);
-                                                               if (error == SocketError.Success) {
-                                                                       ConnectByNameError = null;
-                                                                       break;
-                                                               }
-                                                       }
-                                               } catch (SocketException se) {
-                                                       ConnectByNameError = se;
-                                                       error = SocketError.AccessDenied;
-                                               }
-                                       }
-                               } else {
-                                       ConnectByNameError = null;
-#if MOONLIGHT && !INSIDE_SYSTEM
-                                       if (!PolicyRestricted && !SecurityManager.HasElevatedPermissions) {
-                                               if (CrossDomainPolicyManager.CheckEndPoint (RemoteEndPoint, policy_protocol))
-                                                       error = TryConnect (RemoteEndPoint);
-                                       } else
-#endif
-                                               error = TryConnect (RemoteEndPoint);
-                               }
-#else
-                               error = TryConnect (RemoteEndPoint);
-#endif
+                               SocketError = (SocketError) Worker.result.error;
                        } finally {
-                               SocketError = error;
                                OnCompleted (this);
                        }
                }
 
-               SocketError TryConnect (EndPoint endpoint)
-               {
-                       try {
-                               curSocket.Connect (endpoint);
-                               return (curSocket.Connected ? 0 : SocketError);
-                       } catch (SocketException se){
-                               return se.SocketErrorCode;
-                       } catch (ObjectDisposedException) {
-                               return SocketError.OperationAborted;
-                       }
-               }
-
                internal void SendCallback (IAsyncResult ares)
                {
                        try {
@@ -405,29 +339,6 @@ namespace System.Net.Sockets
                }
 
 #endif
-               internal void DoOperation (SocketAsyncOperation operation, Socket socket)
-               {
-                       ThreadStart callback = null;
-                       curSocket = socket;
-                       
-                       switch (operation) {
-                               case SocketAsyncOperation.Connect:
-#if MOONLIGHT
-                                       socket.seed_endpoint = RemoteEndPoint;
-#endif
-                                       callback = new ThreadStart (ConnectCallback);
-                                       SocketError = SocketError.Success;
-                                       LastOperation = operation;
-                                       break;
-
-                               default:
-                                       throw new NotSupportedException ();
-                       }
-
-                       Thread t = new Thread (callback);
-                       t.IsBackground = true;
-                       t.Start ();
-               }
 #endregion
        }
 }
index 4ba9375edb3dd220b94fcf17c3fbd0ceabf2da37..d9edf3ea47e3cb2e957ebdfc24cc9f767ebfe977 100644 (file)
@@ -9,7 +9,7 @@
 //
 // Copyright (C) 2001, 2002 Phillip Pearson and Ximian, Inc.
 //    http://www.myelin.co.nz
-// (c) 2004-2006 Novell, Inc. (http://www.novell.com)
+// (c) 2004-2011 Novell, Inc. (http://www.novell.com)
 //
 
 //
@@ -48,6 +48,9 @@ using System.Text;
 using System.Net.Configuration;
 using System.Net.NetworkInformation;
 #endif
+#if MOONLIGHT && !INSIDE_SYSTEM
+using System.Net.Policy;
+#endif
 
 namespace System.Net.Sockets {
 
@@ -369,8 +372,7 @@ namespace System.Net.Sockets {
                                set { total = value; }
                        }
 
-                       public SocketError ErrorCode
-                       {
+                       public SocketError ErrorCode {
                                get {
                                        SocketException ex = delayedException as SocketException;
                                        if (ex != null)
@@ -1490,6 +1492,214 @@ namespace System.Net.Sockets {
 #endif
                }
 
+#if !MOONLIGHT
+               public
+#endif
+               IAsyncResult BeginConnect(EndPoint end_point, AsyncCallback callback, object state)
+               {
+                       if (disposed && closed)
+                               throw new ObjectDisposedException (GetType ().ToString ());
+
+                       if (end_point == null)
+                               throw new ArgumentNullException ("end_point");
+
+                       SocketAsyncResult req = new SocketAsyncResult (this, state, callback, SocketOperation.Connect);
+                       req.EndPoint = end_point;
+
+                       // Bug #75154: Connect() should not succeed for .Any addresses.
+                       if (end_point is IPEndPoint) {
+                               IPEndPoint ep = (IPEndPoint) end_point;
+                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
+                                       req.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
+                                       return req;
+                               }
+                       }
+
+                       int error = 0;
+                       bool blk = blocking;
+                       if (blk)
+                               Blocking = false;
+                       SocketAddress serial = end_point.Serialize ();
+                       Connect_internal (socket, serial, out error);
+                       if (blk)
+                               Blocking = true;
+                       if (error == 0) {
+                               // succeeded synch
+                               connected = true;
+                               req.Complete (true);
+                               return req;
+                       }
+
+                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
+                               // error synch
+                               connected = false;
+                               req.Complete (new SocketException (error), true);
+                               return req;
+                       }
+
+                       // continue asynch
+                       connected = false;
+                       socket_pool_queue (Worker.Dispatcher, req);
+                       return req;
+               }
+
+#if !MOONLIGHT
+               public
+#else
+               internal
+#endif
+               IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback callback, object state)
+
+               {
+                       if (disposed && closed)
+                               throw new ObjectDisposedException (GetType ().ToString ());
+
+                       if (addresses == null)
+                               throw new ArgumentNullException ("addresses");
+
+                       if (addresses.Length == 0)
+                               throw new ArgumentException ("Empty addresses list");
+
+                       if (this.AddressFamily != AddressFamily.InterNetwork &&
+                               this.AddressFamily != AddressFamily.InterNetworkV6)
+                               throw new NotSupportedException ("This method is only valid for addresses in the InterNetwork or InterNetworkV6 families");
+
+                       if (port <= 0 || port > 65535)
+                               throw new ArgumentOutOfRangeException ("port", "Must be > 0 and < 65536");
+#if !MOONLIGHT
+                       if (islistening)
+                               throw new InvalidOperationException ();
+#endif
+
+                       SocketAsyncResult req = new SocketAsyncResult (this, state, callback, SocketOperation.Connect);
+                       req.Addresses = addresses;
+                       req.Port = port;
+                       connected = false;
+                       return BeginMConnect (req);
+               }
+
+               IAsyncResult BeginMConnect (SocketAsyncResult req)
+               {
+                       IAsyncResult ares = null;
+                       Exception exc = null;
+                       for (int i = req.CurrentAddress; i < req.Addresses.Length; i++) {
+                               IPAddress addr = req.Addresses [i];
+                               IPEndPoint ep = new IPEndPoint (addr, req.Port);
+                               try {
+                                       req.CurrentAddress++;
+                                       ares = BeginConnect (ep, null, req);
+                                       if (ares.IsCompleted && ares.CompletedSynchronously) {
+                                               ((SocketAsyncResult) ares).CheckIfThrowDelayedException ();
+                                               req.DoMConnectCallback ();
+                                       }
+                                       break;
+                               } catch (Exception e) {
+                                       exc = e;
+                               }
+                       }
+
+                       if (ares == null)
+                               throw exc;
+
+                       return req;
+               }
+
+               // Returns false when it is ok to use RemoteEndPoint
+               //         true when addresses must be used (and addresses could be null/empty)
+               bool GetCheckedIPs (SocketAsyncEventArgs e, out IPAddress [] addresses)
+               {
+                       addresses = null;
+#if MOONLIGHT || NET_4_0
+                       // Connect to the first address that match the host name, like:
+                       // http://blogs.msdn.com/ncl/archive/2009/07/20/new-ncl-features-in-net-4-0-beta-2.aspx
+                       // while skipping entries that do not match the address family
+                       DnsEndPoint dep = (RemoteEndPoint as DnsEndPoint);
+                       if (dep != null) {
+                               addresses = Dns.GetHostAddresses (dep.Host);
+                               IPEndPoint endpoint;
+#if MOONLIGHT && !INSIDE_SYSTEM
+                               if (!e.PolicyRestricted && !SecurityManager.HasElevatedPermissions) {
+                                       List<IPAddress> valid = new List<IPAddress> ();
+                                       foreach (IPAddress a in addresses) {
+                                               // if we're not downloading a socket policy then check the policy
+                                               // and if we're not running with elevated permissions (SL4 OoB option)
+                                               endpoint = new IPEndPoint (a, dep.Port);
+                                               if (!CrossDomainPolicyManager.CheckEndPoint (endpoint, e.SocketClientAccessPolicyProtocol))
+                                                       continue;
+                                               valid.Add (a);
+                                       }
+                                       addresses = valid.ToArray ();
+                               }
+#endif
+                               return true;
+                       } else {
+                               e.ConnectByNameError = null;
+#if MOONLIGHT && !INSIDE_SYSTEM
+                               if (!e.PolicyRestricted && !SecurityManager.HasElevatedPermissions) {
+                                       if (CrossDomainPolicyManager.CheckEndPoint (RemoteEndPoint, e.SocketClientAccessPolicyProtocol))
+                                               return false;
+                               } else
+#endif
+                                       return false;
+                       }
+                       return true; // do not use remote endpoint
+#else
+                       return false; // < NET_4_0 -> use remote endpoint
+#endif
+               }
+
+               bool ConnectAsyncReal (SocketAsyncEventArgs e)
+               {
+                       IPAddress [] addresses = null;
+                       bool use_remoteep = true;
+#if MOONLIGHT || NET_4_0
+                       use_remoteep = !GetCheckedIPs (e, out addresses);
+#endif
+                       e.curSocket = this;
+                       Worker w = e.Worker;
+                       w.Init (this, e, SocketOperation.Connect);
+                       SocketAsyncResult result = w.result;
+                       IAsyncResult ares = null;
+                       try {
+                               if (use_remoteep) {
+                                       result.EndPoint = e.RemoteEndPoint;
+                                       ares = BeginConnect (e.RemoteEndPoint, SocketAsyncEventArgs.Dispatcher, e);
+                               }
+#if MOONLIGHT || NET_4_0
+                               else {
+
+                                       DnsEndPoint dep = (e.RemoteEndPoint as DnsEndPoint);
+                                       result.Addresses = addresses;
+                                       result.Port = dep.Port;
+
+                                       ares = BeginConnect (addresses, dep.Port, SocketAsyncEventArgs.Dispatcher, e);
+                               }
+#endif
+                               if (ares.IsCompleted && ares.CompletedSynchronously) {
+                                       ((SocketAsyncResult) ares).CheckIfThrowDelayedException ();
+                                       return false;
+                               }
+                       } catch (Exception exc) {
+                               result.Complete (exc, true);
+                               return false;
+                       }
+                       return true;
+               }
+
+#if !MOONLIGHT
+               public bool ConnectAsync (SocketAsyncEventArgs e)
+               {
+                       // NO check is made whether e != null in MS.NET (NRE is thrown in such case)
+                       if (disposed && closed)
+                               throw new ObjectDisposedException (GetType ().ToString ());
+                       if (islistening)
+                               throw new InvalidOperationException ("You may not perform this operation after calling the Listen method.");
+                       if (e.RemoteEndPoint == null)
+                               throw new ArgumentNullException ("remoteEP");
+
+                       return ConnectAsyncReal (e);
+               }
+#endif
 #if MOONLIGHT
                static void CheckConnect (SocketAsyncEventArgs e)
                {
@@ -1512,10 +1722,7 @@ namespace System.Net.Sockets {
                        if ((raf != AddressFamily.Unspecified) && (raf != AddressFamily))
                                throw new NotSupportedException ("AddressFamily mismatch between socket and endpoint");
 
-                       e.DoOperation (SocketAsyncOperation.Connect, this);
-
-                       // We always return true for now
-                       return true;
+                       return ConnectAsyncReal (e);
                }
 
                public static bool ConnectAsync (SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e)
@@ -1528,10 +1735,7 @@ namespace System.Net.Sockets {
                        if (raf == AddressFamily.Unspecified)
                                raf = AddressFamily.InterNetwork;
                        Socket s = new Socket (raf, socketType, protocolType);
-                       e.DoOperation (SocketAsyncOperation.Connect, s);
-
-                       // We always return true for now
-                       return true;
+                       return s.ConnectAsyncReal (e);
                }
 
                public static void CancelConnectAsync (SocketAsyncEventArgs e)
@@ -1539,6 +1743,7 @@ namespace System.Net.Sockets {
                        if (e == null)
                                throw new ArgumentNullException ("e");
 
+                       // FIXME: this is canceling a synchronous connect, not an async one
                        Socket s = e.ConnectSocket;
                        if ((s != null) && (s.blocking_thread != null))
                                s.blocking_thread.Abort ();