[Socket] Improved ConnectAsync
[mono.git] / mcs / class / System / System.Net.Sockets / Socket_2_1.cs
index 4ba9375edb3dd220b94fcf17c3fbd0ceabf2da37..d9edf3ea47e3cb2e957ebdfc24cc9f767ebfe977 100644 (file)
@@ -9,7 +9,7 @@
 //
 // Copyright (C) 2001, 2002 Phillip Pearson and Ximian, Inc.
 //    http://www.myelin.co.nz
-// (c) 2004-2006 Novell, Inc. (http://www.novell.com)
+// (c) 2004-2011 Novell, Inc. (http://www.novell.com)
 //
 
 //
@@ -48,6 +48,9 @@ using System.Text;
 using System.Net.Configuration;
 using System.Net.NetworkInformation;
 #endif
+#if MOONLIGHT && !INSIDE_SYSTEM
+using System.Net.Policy;
+#endif
 
 namespace System.Net.Sockets {
 
@@ -369,8 +372,7 @@ namespace System.Net.Sockets {
                                set { total = value; }
                        }
 
-                       public SocketError ErrorCode
-                       {
+                       public SocketError ErrorCode {
                                get {
                                        SocketException ex = delayedException as SocketException;
                                        if (ex != null)
@@ -1490,6 +1492,214 @@ namespace System.Net.Sockets {
 #endif
                }
 
+#if !MOONLIGHT
+               public
+#endif
+               IAsyncResult BeginConnect(EndPoint end_point, AsyncCallback callback, object state)
+               {
+                       if (disposed && closed)
+                               throw new ObjectDisposedException (GetType ().ToString ());
+
+                       if (end_point == null)
+                               throw new ArgumentNullException ("end_point");
+
+                       SocketAsyncResult req = new SocketAsyncResult (this, state, callback, SocketOperation.Connect);
+                       req.EndPoint = end_point;
+
+                       // Bug #75154: Connect() should not succeed for .Any addresses.
+                       if (end_point is IPEndPoint) {
+                               IPEndPoint ep = (IPEndPoint) end_point;
+                               if (ep.Address.Equals (IPAddress.Any) || ep.Address.Equals (IPAddress.IPv6Any)) {
+                                       req.Complete (new SocketException ((int) SocketError.AddressNotAvailable), true);
+                                       return req;
+                               }
+                       }
+
+                       int error = 0;
+                       bool blk = blocking;
+                       if (blk)
+                               Blocking = false;
+                       SocketAddress serial = end_point.Serialize ();
+                       Connect_internal (socket, serial, out error);
+                       if (blk)
+                               Blocking = true;
+                       if (error == 0) {
+                               // succeeded synch
+                               connected = true;
+                               req.Complete (true);
+                               return req;
+                       }
+
+                       if (error != (int) SocketError.InProgress && error != (int) SocketError.WouldBlock) {
+                               // error synch
+                               connected = false;
+                               req.Complete (new SocketException (error), true);
+                               return req;
+                       }
+
+                       // continue asynch
+                       connected = false;
+                       socket_pool_queue (Worker.Dispatcher, req);
+                       return req;
+               }
+
+#if !MOONLIGHT
+               public
+#else
+               internal
+#endif
+               IAsyncResult BeginConnect (IPAddress[] addresses, int port, AsyncCallback callback, object state)
+
+               {
+                       if (disposed && closed)
+                               throw new ObjectDisposedException (GetType ().ToString ());
+
+                       if (addresses == null)
+                               throw new ArgumentNullException ("addresses");
+
+                       if (addresses.Length == 0)
+                               throw new ArgumentException ("Empty addresses list");
+
+                       if (this.AddressFamily != AddressFamily.InterNetwork &&
+                               this.AddressFamily != AddressFamily.InterNetworkV6)
+                               throw new NotSupportedException ("This method is only valid for addresses in the InterNetwork or InterNetworkV6 families");
+
+                       if (port <= 0 || port > 65535)
+                               throw new ArgumentOutOfRangeException ("port", "Must be > 0 and < 65536");
+#if !MOONLIGHT
+                       if (islistening)
+                               throw new InvalidOperationException ();
+#endif
+
+                       SocketAsyncResult req = new SocketAsyncResult (this, state, callback, SocketOperation.Connect);
+                       req.Addresses = addresses;
+                       req.Port = port;
+                       connected = false;
+                       return BeginMConnect (req);
+               }
+
+               IAsyncResult BeginMConnect (SocketAsyncResult req)
+               {
+                       IAsyncResult ares = null;
+                       Exception exc = null;
+                       for (int i = req.CurrentAddress; i < req.Addresses.Length; i++) {
+                               IPAddress addr = req.Addresses [i];
+                               IPEndPoint ep = new IPEndPoint (addr, req.Port);
+                               try {
+                                       req.CurrentAddress++;
+                                       ares = BeginConnect (ep, null, req);
+                                       if (ares.IsCompleted && ares.CompletedSynchronously) {
+                                               ((SocketAsyncResult) ares).CheckIfThrowDelayedException ();
+                                               req.DoMConnectCallback ();
+                                       }
+                                       break;
+                               } catch (Exception e) {
+                                       exc = e;
+                               }
+                       }
+
+                       if (ares == null)
+                               throw exc;
+
+                       return req;
+               }
+
+               // Returns false when it is ok to use RemoteEndPoint
+               //         true when addresses must be used (and addresses could be null/empty)
+               bool GetCheckedIPs (SocketAsyncEventArgs e, out IPAddress [] addresses)
+               {
+                       addresses = null;
+#if MOONLIGHT || NET_4_0
+                       // Connect to the first address that match the host name, like:
+                       // http://blogs.msdn.com/ncl/archive/2009/07/20/new-ncl-features-in-net-4-0-beta-2.aspx
+                       // while skipping entries that do not match the address family
+                       DnsEndPoint dep = (RemoteEndPoint as DnsEndPoint);
+                       if (dep != null) {
+                               addresses = Dns.GetHostAddresses (dep.Host);
+                               IPEndPoint endpoint;
+#if MOONLIGHT && !INSIDE_SYSTEM
+                               if (!e.PolicyRestricted && !SecurityManager.HasElevatedPermissions) {
+                                       List<IPAddress> valid = new List<IPAddress> ();
+                                       foreach (IPAddress a in addresses) {
+                                               // if we're not downloading a socket policy then check the policy
+                                               // and if we're not running with elevated permissions (SL4 OoB option)
+                                               endpoint = new IPEndPoint (a, dep.Port);
+                                               if (!CrossDomainPolicyManager.CheckEndPoint (endpoint, e.SocketClientAccessPolicyProtocol))
+                                                       continue;
+                                               valid.Add (a);
+                                       }
+                                       addresses = valid.ToArray ();
+                               }
+#endif
+                               return true;
+                       } else {
+                               e.ConnectByNameError = null;
+#if MOONLIGHT && !INSIDE_SYSTEM
+                               if (!e.PolicyRestricted && !SecurityManager.HasElevatedPermissions) {
+                                       if (CrossDomainPolicyManager.CheckEndPoint (RemoteEndPoint, e.SocketClientAccessPolicyProtocol))
+                                               return false;
+                               } else
+#endif
+                                       return false;
+                       }
+                       return true; // do not use remote endpoint
+#else
+                       return false; // < NET_4_0 -> use remote endpoint
+#endif
+               }
+
+               bool ConnectAsyncReal (SocketAsyncEventArgs e)
+               {
+                       IPAddress [] addresses = null;
+                       bool use_remoteep = true;
+#if MOONLIGHT || NET_4_0
+                       use_remoteep = !GetCheckedIPs (e, out addresses);
+#endif
+                       e.curSocket = this;
+                       Worker w = e.Worker;
+                       w.Init (this, e, SocketOperation.Connect);
+                       SocketAsyncResult result = w.result;
+                       IAsyncResult ares = null;
+                       try {
+                               if (use_remoteep) {
+                                       result.EndPoint = e.RemoteEndPoint;
+                                       ares = BeginConnect (e.RemoteEndPoint, SocketAsyncEventArgs.Dispatcher, e);
+                               }
+#if MOONLIGHT || NET_4_0
+                               else {
+
+                                       DnsEndPoint dep = (e.RemoteEndPoint as DnsEndPoint);
+                                       result.Addresses = addresses;
+                                       result.Port = dep.Port;
+
+                                       ares = BeginConnect (addresses, dep.Port, SocketAsyncEventArgs.Dispatcher, e);
+                               }
+#endif
+                               if (ares.IsCompleted && ares.CompletedSynchronously) {
+                                       ((SocketAsyncResult) ares).CheckIfThrowDelayedException ();
+                                       return false;
+                               }
+                       } catch (Exception exc) {
+                               result.Complete (exc, true);
+                               return false;
+                       }
+                       return true;
+               }
+
+#if !MOONLIGHT
+               public bool ConnectAsync (SocketAsyncEventArgs e)
+               {
+                       // NO check is made whether e != null in MS.NET (NRE is thrown in such case)
+                       if (disposed && closed)
+                               throw new ObjectDisposedException (GetType ().ToString ());
+                       if (islistening)
+                               throw new InvalidOperationException ("You may not perform this operation after calling the Listen method.");
+                       if (e.RemoteEndPoint == null)
+                               throw new ArgumentNullException ("remoteEP");
+
+                       return ConnectAsyncReal (e);
+               }
+#endif
 #if MOONLIGHT
                static void CheckConnect (SocketAsyncEventArgs e)
                {
@@ -1512,10 +1722,7 @@ namespace System.Net.Sockets {
                        if ((raf != AddressFamily.Unspecified) && (raf != AddressFamily))
                                throw new NotSupportedException ("AddressFamily mismatch between socket and endpoint");
 
-                       e.DoOperation (SocketAsyncOperation.Connect, this);
-
-                       // We always return true for now
-                       return true;
+                       return ConnectAsyncReal (e);
                }
 
                public static bool ConnectAsync (SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e)
@@ -1528,10 +1735,7 @@ namespace System.Net.Sockets {
                        if (raf == AddressFamily.Unspecified)
                                raf = AddressFamily.InterNetwork;
                        Socket s = new Socket (raf, socketType, protocolType);
-                       e.DoOperation (SocketAsyncOperation.Connect, s);
-
-                       // We always return true for now
-                       return true;
+                       return s.ConnectAsyncReal (e);
                }
 
                public static void CancelConnectAsync (SocketAsyncEventArgs e)
@@ -1539,6 +1743,7 @@ namespace System.Net.Sockets {
                        if (e == null)
                                throw new ArgumentNullException ("e");
 
+                       // FIXME: this is canceling a synchronous connect, not an async one
                        Socket s = e.ConnectSocket;
                        if ((s != null) && (s.blocking_thread != null))
                                s.blocking_thread.Abort ();