+//
+// Mono.Net.Dns.SimpleResolver
+//
+// Authors:
+// Gonzalo Paniagua Javier (gonzalo.mono@gmail.com)
+//
+// Copyright 2011 Gonzalo Paniagua Javier
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Diagnostics;
+using System.Net;
+using System.Net.Sockets;
+using System.Net.NetworkInformation;
+using System.Text;
+using System.Threading;
+
+namespace Mono.Net.Dns {
+ sealed class SimpleResolver : IDisposable {
+ static string [] EmptyStrings = new string [0];
+ static IPAddress [] EmptyAddresses = new IPAddress [0];
+ IPEndPoint [] endpoints;
+ Socket client;
+ Dictionary<int, SimpleResolverEventArgs> queries;
+ AsyncCallback receive_cb;
+ TimerCallback timeout_cb;
+ bool disposed;
+#if REUSE_RESPONSES
+ Stack<DnsResponse> responses_avail = new Stack<DnsResponse> ();
+#endif
+
+ public SimpleResolver ()
+ {
+ queries = new Dictionary<int, SimpleResolverEventArgs> ();
+ receive_cb = new AsyncCallback (OnReceive);
+ timeout_cb = new TimerCallback (OnTimeout);
+ InitFromSystem ();
+ InitSocket ();
+ }
+
+ void IDisposable.Dispose ()
+ {
+ if (!disposed) {
+ disposed = true;
+ if (client != null) {
+ client.Close ();
+ client = null;
+ }
+ }
+ }
+
+ public void Close ()
+ {
+ ((IDisposable) this).Dispose ();
+ }
+
+ void GetLocalHost (SimpleResolverEventArgs args)
+ {
+ //FIXME
+ IPHostEntry entry = new IPHostEntry ();
+ entry.HostName = "localhost";
+ entry.AddressList = new IPAddress [] { IPAddress.Loopback };
+ entry.Aliases = EmptyStrings;
+ args.ResolverError = 0;
+ args.HostEntry = entry;
+ return;
+
+/*
+ List<IPEndPoint> eps = new List<IPEndPoint> ();
+ foreach (NetworkInterface iface in NetworkInterface.GetAllNetworkInterfaces ()) {
+ if (NetworkInterfaceType.Loopback == iface.NetworkInterfaceType)
+ continue;
+
+ foreach (IPAddress addr in iface.GetIPProperties ().DnsAddresses) {
+ if (AddressFamily.InterNetworkV6 == addr.AddressFamily)
+ continue;
+ IPEndPoint ep = new IPEndPoint (addr, 53);
+ if (eps.Contains (ep))
+ continue;
+
+ eps.Add (ep);
+ }
+ }
+ endpoints = eps.ToArray ();
+*/
+ }
+
+ // Type A query
+ // Might fill in Aliases
+ // -IPAddress -> return the same IPAddress
+ // -"" -> Local host ip addresses (filter out IPv6 if needed)
+ public bool GetHostAddressesAsync (SimpleResolverEventArgs args)
+ {
+ if (args == null)
+ throw new ArgumentNullException ("args");
+
+ if (args.HostName == null)
+ throw new ArgumentNullException ("args.HostName is null");
+
+ if (args.HostName.Length > 255)
+ throw new ArgumentException ("args.HostName is too long");
+
+ args.Reset (ResolverAsyncOperation.GetHostAddresses);
+ string host = args.HostName;
+ if (host == "") {
+ GetLocalHost (args);
+ return false;
+ }
+ IPAddress addr;
+ if (IPAddress.TryParse (host, out addr)) {
+ IPHostEntry entry = new IPHostEntry ();
+ entry.HostName = host;
+ entry.Aliases = EmptyStrings;
+ entry.AddressList = new IPAddress [1] { addr };
+ args.HostEntry = entry;
+ return false;
+ }
+
+ SendAQuery (args, true);
+ return true;
+ }
+
+ // For names -> type A Query
+ // For IP addresses -> PTR + A -> will at least return itself
+ // Careful: for IP addresses with PTR, the hostname might yield different IP addresses!
+ public bool GetHostEntryAsync (SimpleResolverEventArgs args)
+ {
+ if (args == null)
+ throw new ArgumentNullException ("args");
+
+ if (args.HostName == null)
+ throw new ArgumentNullException ("args.HostName is null");
+
+ if (args.HostName.Length > 255)
+ throw new ArgumentException ("args.HostName is too long");
+
+ args.Reset (ResolverAsyncOperation.GetHostEntry);
+ string host = args.HostName;
+ if (host == "") {
+ GetLocalHost (args);
+ return false;
+ }
+
+ IPAddress addr;
+ if (IPAddress.TryParse (host, out addr)) {
+ IPHostEntry entry = new IPHostEntry ();
+ entry.HostName = host;
+ entry.Aliases = EmptyStrings;
+ entry.AddressList = new IPAddress [1] { addr };
+ args.HostEntry = entry;
+ args.PTRAddress = addr;
+ SendPTRQuery (args, true);
+ return true;
+ }
+
+ // 3. For IP addresses:
+ // 3.1 Parsing IP succeeds
+ // 3.2 Reverse lookup of the IP fills in HostName -> fails? HostName = IP
+ // 3.3 The hostname resulting from this is used to query DNS again to get the IP addresses
+ //
+ // Exclude IPv6 addresses if not supported by the system
+ // .Aliases is always empty
+ // Length > 255
+ SendAQuery (args, true);
+ return true;
+ }
+
+ bool AddQuery (DnsQuery query, SimpleResolverEventArgs args)
+ {
+ lock (queries) {
+ if (queries.ContainsKey (query.Header.ID))
+ return false;
+ queries [query.Header.ID] = args;
+ }
+ return true;
+ }
+
+ static DnsQuery GetQuery (string host, DnsQType q, DnsQClass c)
+ {
+ return new DnsQuery (host, q, c);
+ }
+
+ void SendAQuery (SimpleResolverEventArgs args, bool add_it)
+ {
+ SendAQuery (args, args.HostName, add_it);
+ }
+
+ void SendAQuery (SimpleResolverEventArgs args, string host, bool add_it)
+ {
+ DnsQuery query = GetQuery (host, DnsQType.A, DnsQClass.IN);
+ SendQuery (args, query, add_it);
+ }
+
+ static string GetPTRName (IPAddress address)
+ {
+ // TODO: IPv6 PTR query?
+ byte [] bytes = address.GetAddressBytes ();
+ // "XXX.XXX.XXX.XXX.in-addr.arpa".Length
+ StringBuilder sb = new StringBuilder (28);
+ for (int i = bytes.Length - 1; i >= 0; i--) {
+ sb.AppendFormat ("{0}.", bytes [i]);
+ }
+ sb.Append ("in-addr.arpa");
+ return sb.ToString ();
+ }
+
+ void SendPTRQuery (SimpleResolverEventArgs args, bool add_it)
+ {
+ DnsQuery query = GetQuery (GetPTRName (args.PTRAddress), DnsQType.PTR, DnsQClass.IN);
+ SendQuery (args, query, add_it);
+ }
+
+ void SendQuery (SimpleResolverEventArgs args, DnsQuery query, bool add_it)
+ {
+ // TODO: not sure about reusing IDs when add_it == false
+ int count = 0;
+ if (add_it) {
+ do {
+ query.Header.ID = (ushort)new Random().Next(1, 65534);
+ if (count > 500)
+ throw new InvalidOperationException ("Too many pending queries (or really bad luck)");
+ } while (AddQuery (query, args) == false);
+ args.QueryID = query.Header.ID;
+ } else {
+ query.Header.ID = args.QueryID;
+ }
+ if (args.Timer == null)
+ args.Timer = new Timer (timeout_cb, args, 5000, Timeout.Infinite);
+ else
+ args.Timer.Change (5000, Timeout.Infinite);
+ client.BeginSend (query.Packet, 0, query.Length, SocketFlags.None, null, null);
+ }
+
+ byte [] GetFreshBuffer ()
+ {
+#if !REUSE_RESPONSES
+ return new byte [512];
+#else
+
+ DnsResponse response = null;
+ lock (responses_avail) {
+ if (responses_avail.Count > 0) {
+ response = responses_avail.Pop ();
+ }
+ }
+ if (response == null) {
+ response = new DnsResponse ();
+ } else {
+ response.Reset ();
+ }
+ return response;
+#endif
+ }
+
+ void FreeBuffer (byte [] buffer)
+ {
+#if REUSE_RESPONSES
+ // TODO: set some limit here. Configurable?
+ lock (responses_avail) {
+ responses_avail.Push (response);
+ }
+#endif
+ }
+
+ void InitSocket ()
+ {
+ client = new Socket (AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
+ client.Blocking = true;
+ client.Bind (new IPEndPoint (IPAddress.Any, 0));
+ client.Connect (endpoints [0]);
+ BeginReceive ();
+ }
+
+ void BeginReceive ()
+ {
+ byte [] buffer = GetFreshBuffer ();
+ client.BeginReceive (buffer, 0, buffer.Length, SocketFlags.None, receive_cb, buffer);
+ }
+
+ void OnTimeout (object obj)
+ {
+ SimpleResolverEventArgs args = (SimpleResolverEventArgs) obj;
+ SimpleResolverEventArgs args2;
+ lock (queries) {
+ if (!queries.TryGetValue (args.QueryID, out args2)) {
+ return; // Already processed.
+ }
+ if (args != args2)
+ throw new Exception ("Should not happen: args != args2");
+ args.Retries++;
+ if (args.Retries > 1) {
+ // Error timeout
+ args.ResolverError = ResolverError.Timeout;
+ args.OnCompleted (this);
+ } else {
+ SendAQuery (args, false);
+ }
+ }
+ }
+
+ void OnReceive (IAsyncResult ares)
+ {
+ if (disposed)
+ return;
+
+ int nread = 0;
+ EndPoint remote_ep = client.RemoteEndPoint;
+ try {
+ nread = client.EndReceive (ares);
+ } catch (Exception e) {
+ Console.Error.WriteLine (e);
+ }
+
+ BeginReceive ();
+
+ byte [] buffer = (byte []) ares.AsyncState;
+ if (nread > 12) {
+ DnsResponse response = new DnsResponse (buffer, nread);
+ int id = response.Header.ID;
+ SimpleResolverEventArgs args = null;
+ lock (queries) {
+ if (queries.TryGetValue (id, out args)) {
+ queries.Remove (id);
+ }
+ }
+
+ if (args != null) {
+ Timer t = args.Timer;
+ if (t != null)
+ t.Change (Timeout.Infinite, Timeout.Infinite);
+
+ try {
+ ProcessResponse (args, response, remote_ep);
+ } catch (Exception e) {
+ args.ResolverError = (ResolverError) (-1);
+ args.ErrorMessage = e.Message;
+ }
+
+ IPHostEntry entry = args.HostEntry;
+ if (args.ResolverError != 0 && args.PTRAddress != null && entry != null && entry.HostName != null) {
+ args.PTRAddress = null;
+ SendAQuery (args, entry.HostName, true);
+ args.Timer.Change (5000, Timeout.Infinite);
+ } else {
+ args.OnCompleted (this);
+ }
+ }
+ }
+ FreeBuffer (buffer);
+ }
+
+ void ProcessResponse (SimpleResolverEventArgs args, DnsResponse response, EndPoint server_ep)
+ {
+ DnsRCode status = response.Header.RCode;
+ if (status != 0) {
+ if (args.PTRAddress != null) {
+ // PTR query failed -> no error, we have the IP
+ return;
+ }
+ args.ResolverError = (ResolverError) status;
+ return;
+ }
+
+ // TODO: verify IP of the server is in our list and the same one that got the query
+ IPEndPoint ep = (IPEndPoint) server_ep;
+ if (ep.Port != 53) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "Port";
+ return;
+ }
+
+ DnsHeader header = response.Header;
+ if (!header.IsQuery) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "IsQuery";
+ return;
+ }
+
+ // TODO: handle Truncation. Retry with bigger buffer?
+
+ if (header.QuestionCount > 1) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "QuestionCount";
+ return;
+ }
+ ReadOnlyCollection<DnsQuestion> q = response.GetQuestions ();
+ if (q.Count != 1) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "QuestionCount 2";
+ return;
+ }
+ DnsQuestion question = q [0];
+ /* The answer might have dot at the end, etc...
+ if (String.Compare (question.Name, args.HostName) != 0) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "HostName - " + question.Name + " != " + args.HostName;
+ return;
+ }
+ */
+
+ DnsQType t = question.Type;
+ if (t != DnsQType.A && t != DnsQType.AAAA && t != DnsQType.PTR) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "QType " + question.Type;
+ return;
+ }
+
+ if (question.Class != DnsQClass.IN) {
+ args.ResolverError = ResolverError.ResponseHeaderError;
+ args.ErrorMessage = "QClass " + question.Class;
+ return;
+ }
+
+ ReadOnlyCollection<DnsResourceRecord> records = response.GetAnswers ();
+ if (records.Count == 0) {
+ if (args.PTRAddress != null) {
+ // PTR query failed -> no error
+ return;
+ }
+ args.ResolverError = ResolverError.NameError; // is this ok?
+ args.ErrorMessage = "NoAnswers";
+ return;
+ }
+
+ List<string> aliases = null;
+ List<IPAddress> addresses = null;
+ foreach (DnsResourceRecord r in records) {
+ if (r.Class != DnsClass.IN)
+ continue;
+ if (r.Type == DnsType.A || r.Type == DnsType.AAAA) {
+ if (addresses == null)
+ addresses = new List<IPAddress> ();
+ addresses.Add (((DnsResourceRecordIPAddress) r).Address);
+ } else if (r.Type == DnsType.CNAME) {
+ if (aliases == null)
+ aliases = new List<string> ();
+ aliases.Add (((DnsResourceRecordCName) r).CName);
+ } else if (r.Type == DnsType.PTR) {
+ args.HostEntry.HostName = ((DnsResourceRecordPTR) r).DName;
+ args.HostEntry.Aliases = aliases == null ? EmptyStrings : aliases.ToArray ();
+ args.HostEntry.AddressList = EmptyAddresses;
+ return;
+ }
+ }
+
+ IPHostEntry entry = args.HostEntry ?? new IPHostEntry ();
+ if (entry.HostName == null && aliases != null && aliases.Count > 0) {
+ entry.HostName = aliases [0];
+ aliases.RemoveAt (0);
+ }
+ entry.Aliases = aliases == null ? EmptyStrings : aliases.ToArray ();
+ entry.AddressList = addresses == null ? EmptyAddresses : addresses.ToArray ();
+ args.HostEntry = entry;
+ if ((question.Type == DnsQType.A || question.Type == DnsQType.AAAA) && entry.AddressList == EmptyAddresses) {
+ args.ResolverError = ResolverError.NameError;
+ args.ErrorMessage = "No addresses in response";
+ } else if (question.Type == DnsQType.PTR && entry.HostName == null) {
+ args.ResolverError = ResolverError.NameError;
+ args.ErrorMessage = "No PTR in response";
+ }
+
+ }
+
+ void InitFromSystem ()
+ {
+ List<IPEndPoint> eps = new List<IPEndPoint> ();
+ foreach (NetworkInterface iface in NetworkInterface.GetAllNetworkInterfaces ()) {
+ if (NetworkInterfaceType.Loopback == iface.NetworkInterfaceType)
+ continue;
+
+ foreach (IPAddress addr in iface.GetIPProperties ().DnsAddresses) {
+ if (AddressFamily.InterNetworkV6 == addr.AddressFamily)
+ continue;
+ IPEndPoint ep = new IPEndPoint (addr, 53);
+ if (eps.Contains (ep))
+ continue;
+
+ eps.Add (ep);
+ }
+ }
+ endpoints = eps.ToArray ();
+ }
+ }
+}
+