Add support for an alternative managed DNS resolver
[mono.git] / mcs / class / System / System.Net / Dns.cs
index ed61fc4012e262862e37085b84b1c41a47075c3f..8419fe396eb151a7c585d7d8b20ac3f99d45bb7b 100644 (file)
@@ -3,11 +3,14 @@
 // Authors:
 //     Mads Pultz (mpultz@diku.dk)
 //     Lawrence Pit (loz@cable.a2000.nl)
-//     Marek Safar (marek.safar@gmail.com)
 
+// Author: Mads Pultz (mpultz@diku.dk)
+//        Lawrence Pit (loz@cable.a2000.nl)
+//        Marek Safar (marek.safar@gmail.com)
+//        Gonzalo Paniagua Javier (gonzalo.mono@gmail.com)
 //
 // (C) Mads Pultz, 2001
-// Copyright 2011 Xamarin Inc.
+// Copyright (c) 2011 Novell, Inc.
 
 //
 // Permission is hereby granted, free of charge, to any person obtaining
@@ -41,11 +44,24 @@ using System.Runtime.Remoting.Messaging;
 using System.Threading.Tasks;
 #endif
 
+using Mono.Dns;
+
 namespace System.Net {
        public static class Dns {
+               static bool use_mono_dns;
+               static SimpleResolver resolver;
+
                static Dns ()
                {
                        System.Net.Sockets.Socket.CheckProtocolSupport();
+                       if (Environment.GetEnvironmentVariable ("MONO_DNS") != null) {
+                               resolver = new SimpleResolver ();
+                               use_mono_dns = true;
+                       }
+               }
+
+               internal static bool UseMonoDns {
+                       get { return use_mono_dns; }
                }
 
 #if !MOONLIGHT // global remove of async methods
@@ -56,30 +72,82 @@ namespace System.Net {
                private delegate IPHostEntry GetHostEntryIPCallback (IPAddress hostAddress);
                private delegate IPAddress [] GetHostAddressesCallback (string hostName);
 
+               static void OnCompleted (object sender, SimpleResolverEventArgs e)
+               {
+                       DnsAsyncResult ares = (DnsAsyncResult) e.UserToken;
+                       IPHostEntry entry = e.HostEntry;
+                       if (entry == null || e.ResolverError != 0) {
+                               ares.SetCompleted (false, new Exception ("Error: " + e.ResolverError));
+                               return;
+                       }
+                       ares.SetCompleted (false, entry);
+               }
+
+               static IAsyncResult BeginAsyncCallAddresses (string host, AsyncCallback callback, object state)
+               {
+                       SimpleResolverEventArgs e = new SimpleResolverEventArgs ();
+                       e.Completed += OnCompleted;
+                       e.HostName = host;
+                       DnsAsyncResult ares = new DnsAsyncResult (callback, state);
+                       e.UserToken = ares;
+                       if (resolver.GetHostAddressesAsync (e) == false)
+                               ares.SetCompleted (true, e.HostEntry); // Completed synchronously
+                       return ares;
+               }
+
+               static IAsyncResult BeginAsyncCall (string host, AsyncCallback callback, object state)
+               {
+                       SimpleResolverEventArgs e = new SimpleResolverEventArgs ();
+                       e.Completed += OnCompleted;
+                       e.HostName = host;
+                       DnsAsyncResult ares = new DnsAsyncResult (callback, state);
+                       e.UserToken = ares;
+                       if (resolver.GetHostEntryAsync (e) == false)
+                               ares.SetCompleted (true, e.HostEntry); // Completed synchronously
+                       return ares;
+               }
+
+               static IPHostEntry EndAsyncCall (DnsAsyncResult ares)
+               {
+                       if (ares == null)
+                               throw new ArgumentException ("Invalid asyncResult");
+                       if (!ares.IsCompleted)
+                               ares.AsyncWaitHandle.WaitOne ();
+                       if (ares.Exception != null)
+                               throw ares.Exception;
+                       IPHostEntry entry = ares.HostEntry;
+                       if (entry == null || entry.AddressList == null || entry.AddressList.Length == 0)
+                               throw new SocketException(11001);
+                       return entry;
+               }
+
                [Obsolete ("Use BeginGetHostEntry instead")]
-               public static IAsyncResult BeginGetHostByName (string hostName,
-                       AsyncCallback requestCallback, object stateObject)
+               public static IAsyncResult BeginGetHostByName (string hostName, AsyncCallback requestCallback, object stateObject)
                {
                        if (hostName == null)
                                throw new ArgumentNullException ("hostName");
 
+                       if (use_mono_dns)
+                               return BeginAsyncCall (hostName, requestCallback, stateObject);
+
                        GetHostByNameCallback c = new GetHostByNameCallback (GetHostByName);
                        return c.BeginInvoke (hostName, requestCallback, stateObject);
                }
 
                [Obsolete ("Use BeginGetHostEntry instead")]
-               public static IAsyncResult BeginResolve (string hostName,
-                       AsyncCallback requestCallback, object stateObject)
+               public static IAsyncResult BeginResolve (string hostName, AsyncCallback requestCallback, object stateObject)
                {
                        if (hostName == null)
                                throw new ArgumentNullException ("hostName");
 
+                       if (use_mono_dns)
+                               return BeginAsyncCall (hostName, requestCallback, stateObject);
+
                        ResolveCallback c = new ResolveCallback (Resolve);
                        return c.BeginInvoke (hostName, requestCallback, stateObject);
                }
 
-               public static IAsyncResult BeginGetHostAddresses (string hostNameOrAddress,
-                       AsyncCallback requestCallback, object state)
+               public static IAsyncResult BeginGetHostAddresses (string hostNameOrAddress, AsyncCallback requestCallback, object stateObject)
                {
                        if (hostNameOrAddress == null)
                                throw new ArgumentNullException ("hostName");
@@ -89,12 +157,14 @@ namespace System.Net {
                                        "cannot use them as target address.",
                                        "hostNameOrAddress");
 
+                       if (use_mono_dns)
+                               return BeginAsyncCallAddresses (hostNameOrAddress, requestCallback, stateObject);
+
                        GetHostAddressesCallback c = new GetHostAddressesCallback (GetHostAddresses);
                        return c.BeginInvoke (hostNameOrAddress, requestCallback, state);
                }
 
-               public static IAsyncResult BeginGetHostEntry (string hostNameOrAddress,
-                       AsyncCallback requestCallback, object stateObject)
+               public static IAsyncResult BeginGetHostEntry (string hostNameOrAddress, AsyncCallback requestCallback, object stateObject)
                {
                        if (hostNameOrAddress == null)
                                throw new ArgumentNullException ("hostName");
@@ -104,16 +174,21 @@ namespace System.Net {
                                        "cannot use them as target address.",
                                        "hostNameOrAddress");
 
+                       if (use_mono_dns)
+                               return BeginAsyncCall (hostNameOrAddress, requestCallback, stateObject);
+
                        GetHostEntryNameCallback c = new GetHostEntryNameCallback (GetHostEntry);
                        return c.BeginInvoke (hostNameOrAddress, requestCallback, stateObject);
                }
 
-               public static IAsyncResult BeginGetHostEntry (IPAddress address,
-                       AsyncCallback requestCallback, object stateObject)
+               public static IAsyncResult BeginGetHostEntry (IPAddress address, AsyncCallback requestCallback, object stateObject)
                {
                        if (address == null)
                                throw new ArgumentNullException ("address");
 
+                       if (use_mono_dns)
+                               return BeginAsyncCall (address.ToString (), requestCallback, stateObject);
+
                        GetHostEntryIPCallback c = new GetHostEntryIPCallback (GetHostEntry);
                        return c.BeginInvoke (address, requestCallback, stateObject);
                }
@@ -124,6 +199,9 @@ namespace System.Net {
                        if (asyncResult == null)
                                throw new ArgumentNullException ("asyncResult");
 
+                       if (use_mono_dns)
+                               return EndAsyncCall (asyncResult as DnsAsyncResult);
+
                        AsyncResult async = (AsyncResult) asyncResult;
                        GetHostByNameCallback cb = (GetHostByNameCallback) async.AsyncDelegate;
                        return cb.EndInvoke(asyncResult);
@@ -134,6 +212,10 @@ namespace System.Net {
                {
                        if (asyncResult == null)
                                throw new ArgumentNullException ("asyncResult");
+
+                       if (use_mono_dns)
+                               return EndAsyncCall (asyncResult as DnsAsyncResult);
+
                        AsyncResult async = (AsyncResult) asyncResult;
                        ResolveCallback cb = (ResolveCallback) async.AsyncDelegate;
                        return cb.EndInvoke(asyncResult);
@@ -144,6 +226,13 @@ namespace System.Net {
                        if (asyncResult == null)
                                throw new ArgumentNullException ("asyncResult");
 
+                       if (use_mono_dns) {
+                               IPHostEntry entry = EndAsyncCall (asyncResult as DnsAsyncResult);
+                               if (entry == null)
+                                       return null;
+                               return entry.AddressList;
+                       }
+
                        AsyncResult async = (AsyncResult) asyncResult;
                        GetHostAddressesCallback cb = (GetHostAddressesCallback) async.AsyncDelegate;
                        return cb.EndInvoke(asyncResult);
@@ -153,6 +242,10 @@ namespace System.Net {
                {
                        if (asyncResult == null)
                                throw new ArgumentNullException ("asyncResult");
+
+                       if (use_mono_dns)
+                               return EndAsyncCall (asyncResult as DnsAsyncResult);
+
                        AsyncResult async = (AsyncResult) asyncResult;
                        if (async.AsyncDelegate is GetHostEntryIPCallback)
                                return ((GetHostEntryIPCallback) async.AsyncDelegate).EndInvoke (asyncResult);