2009-09-23 Sebastien Pouliot <sebastien@ximian.com>
authorSebastien Pouliot <sebastien@ximian.com>
Wed, 23 Sep 2009 21:39:08 +0000 (21:39 -0000)
committerSebastien Pouliot <sebastien@ximian.com>
Wed, 23 Sep 2009 21:39:08 +0000 (21:39 -0000)
* Socket_2_1.cs: Remove NET_2_1 socket policy checks from here.
* SocketAsyncEventArgs.cs: Support DnsEndPoint correctly (NET_2_1 but
that will be useful for NET_4_0 soon). Add socket policy checks here
since it could be called several times to connect to a host.

svn path=/trunk/mcs/; revision=142518

mcs/class/System/System.Net.Sockets/ChangeLog
mcs/class/System/System.Net.Sockets/SocketAsyncEventArgs.cs
mcs/class/System/System.Net.Sockets/Socket_2_1.cs

index 4c788250d85df87fb3275141d197755b24d6aa1f..5308239143a9d127030c75ad2b37ad5d9f79f257 100644 (file)
@@ -1,3 +1,10 @@
+2009-09-23  Sebastien Pouliot  <sebastien@ximian.com>
+
+       * Socket_2_1.cs: Remove NET_2_1 socket policy checks from here.
+       * SocketAsyncEventArgs.cs: Support DnsEndPoint correctly (NET_2_1 but
+       that will be useful for NET_4_0 soon). Add socket policy checks here
+       since it could be called several times to connect to a host.
+
 2009-09-23 Gonzalo Paniagua Javier <gonzalo@novell.com>
 
        * NetworkStream.cs: after disposing the stream, CanRead/CanWrite
index b9b80d73642dea93509a9258110eace0dc88d209..007815d150b4fe0799141d7d64a72ef014df752d 100644 (file)
 #if NET_2_0
 using System;
 using System.Collections.Generic;
+using System.Reflection;
+using System.Security;
 using System.Threading;
 
 namespace System.Net.Sockets
 {
        public class SocketAsyncEventArgs : EventArgs, IDisposable
        {
+#if NET_2_1 && !MONOTOUCH
+               static MethodInfo check_socket_policy;
+
+               static SocketAsyncEventArgs ()
+               {
+                       Type type = Type.GetType ("System.Windows.Browser.Net.CrossDomainPolicyManager, System.Windows.Browser, Version=2.0.5.0, Culture=Neutral, PublicKeyToken=7cec85d7bea7798e");
+                       check_socket_policy = type.GetMethod ("CheckEndPoint");
+               }
+
+               static internal bool CheckEndPoint (EndPoint endpoint)
+               {
+                       if (check_socket_policy == null)
+                               throw new SecurityException ();
+                       return ((bool) check_socket_policy.Invoke (null, new object [1] { endpoint }));
+               }
+#endif
+
                public event EventHandler<SocketAsyncEventArgs> Completed;
 
                IList <ArraySegment <byte>> _bufferList;
@@ -80,6 +99,14 @@ namespace System.Net.Sockets
                                }
                        }
                }
+
+               internal bool PolicyRestricted { get; private set; }
+
+               internal SocketAsyncEventArgs (bool policy) : 
+                       this ()
+               {
+                       PolicyRestricted = policy;
+               }
 #endif
                
                public SocketAsyncEventArgs ()
@@ -185,36 +212,70 @@ namespace System.Net.Sockets
                void ConnectCallback ()
                {
                        LastOperation = SocketAsyncOperation.Connect;
-#if NET_2_1
-                       if (SocketError == SocketError.AccessDenied) {
-                               curSocket.Connected = false;
+                       SocketError error = SocketError.Success;
+                       try {
+#if NET_2_1 && !MONOTOUCH
+                               // Connect to the first address that match the host name, like:
+                               // http://blogs.msdn.com/ncl/archive/2009/07/20/new-ncl-features-in-net-4-0-beta-2.aspx
+                               // while skipping entries that do not match the address family
+                               DnsEndPoint dep = (RemoteEndPoint as DnsEndPoint);
+                               if (dep != null) {
+                                       IPAddress[] addresses = Dns.GetHostAddresses (dep.Host);
+                                       foreach (IPAddress addr in addresses) {
+                                               try {
+                                                       if (curSocket.AddressFamily == addr.AddressFamily) {
+                                                               error = TryConnect (new IPEndPoint (addr, dep.Port));
+                                                               if (error == SocketError.Success)
+                                                                       break;
+                                                       }
+                                               }
+                                               catch (SocketException) {
+                                                       error = SocketError.AccessDenied;
+                                               }
+                                       }
+                               } else {
+                                       error = TryConnect (RemoteEndPoint);
+                               }
+#else
+                               error = TryConnect (RemoteEndPoint);
+#endif
+                       } finally {
+                               SocketError = error;
                                OnCompleted (this);
-                               return;
                        }
-#endif
-                       SocketError = SocketError.Success;
-                       SocketError error = SocketError.Success;
+               }
 
+               SocketError TryConnect (EndPoint endpoint)
+               {
+                       curSocket.Connected = false;
+                       SocketError error = SocketError.Success;
+#if NET_2_1 && !MONOTOUCH
+                       // if we're not downloading a socket policy then check the policy
+                       if (!PolicyRestricted) {
+                               error = SocketError.AccessDenied;
+                               if (!CheckEndPoint (endpoint)) {
+                                       return error;
+                               }
+                       }
+#endif
                        try {
                                if (!curSocket.Blocking) {
                                        int success;
                                        curSocket.Poll (-1, SelectMode.SelectWrite, out success);
-                                       SocketError = (SocketError)success;
+                                       error = (SocketError)success;
                                        if (success == 0)
                                                curSocket.Connected = true;
                                        else
-                                               return;
+                                               return error;
                                } else {
-                                       curSocket.seed_endpoint = RemoteEndPoint;
-                                       curSocket.Connect (RemoteEndPoint);
+                                       curSocket.seed_endpoint = endpoint;
+                                       curSocket.Connect (endpoint);
                                        curSocket.Connected = true;
                                }
                        } catch (SocketException se){
                                error = se.SocketErrorCode;
-                       } finally {
-                               SocketError = error;
-                               OnCompleted (this);
                        }
+                       return error;
                }
 
                void SendCallback ()
index 179708ef009cf7765442c19444db74c298bcbfc9..41ec125560ed7d607c4f8868bb4653d033d42709 100644 (file)
@@ -39,7 +39,6 @@ using System.Collections;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using System.Threading;
-using System.Reflection;
 using System.IO;
 using System.Security;
 using System.Text;
@@ -447,21 +446,12 @@ namespace System.Net.Sockets {
 #if NET_2_1 && !MONOTOUCH
                        if (protocol_type != ProtocolType.Tcp)
                                throw new SocketException ((int) SocketError.AccessDenied);
-
-                       DnsEndPoint dep = (remoteEP as DnsEndPoint);
-                       if (dep != null)
-                               serial = dep.AsIPEndPoint ().Serialize ();
-                       else
-                               serial = remoteEP.Serialize ();
 #elif NET_2_0
                        /* TODO: check this for the 1.1 profile too */
                        if (islistening)
                                throw new InvalidOperationException ();
-
-                       serial = remoteEP.Serialize ();
-#else
-                       serial = remoteEP.Serialize ();
 #endif
+                       serial = remoteEP.Serialize ();
 
                        int error = 0;
 
@@ -688,9 +678,7 @@ namespace System.Net.Sockets {
                }
 
 #if NET_2_1 && !MONOTOUCH
-               static MethodInfo check_socket_policy;
-
-               static void CheckConnect (SocketAsyncEventArgs e, bool checkPolicy)
+               static void CheckConnect (SocketAsyncEventArgs e)
                {
                        // NO check is made whether e != null in MS.NET (NRE is thrown in such case)
 
@@ -698,28 +686,14 @@ namespace System.Net.Sockets {
                                throw new ArgumentNullException ("remoteEP");
                        if (e.BufferList != null)
                                throw new ArgumentException ("Multiple buffers cannot be used with this method.");
-
-                       if (!checkPolicy)
-                               return;
-
-                       e.SocketError = SocketError.AccessDenied;
-                       if (check_socket_policy == null) {
-                               Type type = Type.GetType ("System.Windows.Browser.Net.CrossDomainPolicyManager, System.Windows.Browser, Version=2.0.5.0, Culture=Neutral, PublicKeyToken=7cec85d7bea7798e");
-                               check_socket_policy = type.GetMethod ("CheckEndPoint");
-                               if (check_socket_policy == null)
-                                       throw new SecurityException ();
-                       }
-                       if ((bool) check_socket_policy.Invoke (null, new object [1] { e.RemoteEndPoint }))
-                               e.SocketError = SocketError.Success;
                }
 
-               // only _directly_ used (with false) to download the socket policy
-               internal bool ConnectAsync (SocketAsyncEventArgs e, bool checkPolicy)
+               public bool ConnectAsync (SocketAsyncEventArgs e)
                {
                        if (disposed && closed)
                                throw new ObjectDisposedException (GetType ().ToString ());
 
-                       CheckConnect (e, checkPolicy);
+                       CheckConnect (e);
 
                        e.DoOperation (SocketAsyncOperation.Connect, this);
 
@@ -727,14 +701,10 @@ namespace System.Net.Sockets {
                        return true;
                }
 
-               public bool ConnectAsync (SocketAsyncEventArgs e)
-               {
-                       return ConnectAsync (e, true);
-               }
-
                public static bool ConnectAsync (SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e)
                {
-                       CheckConnect (e, true);
+                       // exception ordering requires to check before creating the socket (good thing resource wise too)
+                       CheckConnect (e);
 
                        Socket s = new Socket (AddressFamily.InterNetwork, socketType, protocolType);
                        e.DoOperation (SocketAsyncOperation.Connect, s);