Allocate non-code memory using mono_domain_alloc0 in mono_create_jit_trampoline_from_...
[mono.git] / mcs / class / System / Mono.Dns / SimpleResolver.cs
1 //
2 // Mono.Dns.SimpleResolver
3 //
4 // Authors:
5 //      Gonzalo Paniagua Javier (gonzalo.mono@gmail.com)
6 //
7 // Copyright 2011 Gonzalo Paniagua Javier
8 //
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 //
13 // http://www.apache.org/licenses/LICENSE-2.0
14 //
15 // Unless required by applicable law or agreed to in writing, software
16 // distributed under the License is distributed on an "AS IS" BASIS,
17 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 // See the License for the specific language governing permissions and
19 // limitations under the License.
20 //
21 using System;
22 using System.Collections.Generic;
23 using System.Collections.ObjectModel;
24 using System.Diagnostics;
25 using System.Net;
26 using System.Net.Sockets;
27 using System.Net.NetworkInformation;
28 using System.Text;
29 using System.Threading;
30
namespace Mono.Dns {
	// Minimal asynchronous DNS stub resolver. It sends A and PTR queries over UDP to the
	// first IPv4 DNS server configured on a non-loopback interface and reports each
	// result through SimpleResolverEventArgs (args.OnCompleted).
	// Pending queries live in 'queries' (keyed by DNS message ID); 'queries' is also the
	// lock that guards 'sent_queries'.
	sealed class SimpleResolver : IDisposable {
		// Shared empty results; EmptyAddresses is also compared by reference in ProcessResponse.
		static readonly string [] EmptyStrings = new string [0];
		static readonly IPAddress [] EmptyAddresses = new IPAddress [0];

		// Milliseconds to wait for an answer before retrying once and then giving up.
		const int QueryTimeout = 5000;
		// Maximum attempts at picking a DNS message ID that is not already pending.
		const int MaxIdAttempts = 500;

		// DNS message IDs should be hard to guess so forged answers are harder to
		// inject (RFC 5452). A System.Random created per call is both predictable and
		// repeats values within the same clock tick.
		static readonly System.Security.Cryptography.RandomNumberGenerator id_rng =
			System.Security.Cryptography.RandomNumberGenerator.Create ();

		IPEndPoint [] endpoints;
		Socket client;
		// ID -> caller state for every query that is still waiting for an answer.
		Dictionary<int, SimpleResolverEventArgs> queries;
		// ID -> packet that was sent, so a timeout can resend exactly the same question.
		Dictionary<int, DnsQuery> sent_queries;
		AsyncCallback receive_cb;
		AsyncCallback send_cb;
		TimerCallback timeout_cb;
		bool disposed;
#if REUSE_RESPONSES
		// Pool of receive buffers (only when REUSE_RESPONSES is defined).
		Stack<byte []> buffers_avail = new Stack<byte []> ();
#endif

		// Reads the DNS servers from the network interfaces, opens the UDP socket and
		// starts listening for answers.
		// Throws InvalidOperationException if no DNS server is configured.
		public SimpleResolver ()
		{
			queries = new Dictionary<int, SimpleResolverEventArgs> ();
			sent_queries = new Dictionary<int, DnsQuery> ();
			receive_cb = new AsyncCallback (OnReceive);
			send_cb = new AsyncCallback (OnSend);
			timeout_cb = new TimerCallback (OnTimeout);
			InitFromSystem ();
			InitSocket ();
		}

		// Closes the socket. Queries still pending are reported as timeouts when their
		// timers fire (see OnTimeout).
		void IDisposable.Dispose ()
		{
			if (disposed)
				return;

			disposed = true;
			Socket socket = client;
			// Clear the field first so callbacks racing with Close see null, not a dead socket.
			client = null;
			if (socket != null)
				socket.Close ();
		}

		public void Close ()
		{
			((IDisposable) this).Dispose ();
		}

		// Answers a query for "" without touching the network.
		// FIXME: only the IPv4 loopback address is returned, not the addresses of the
		// local interfaces.
		void GetLocalHost (SimpleResolverEventArgs args)
		{
			IPHostEntry entry = new IPHostEntry ();
			entry.HostName = "localhost";
			entry.AddressList = new IPAddress [] { IPAddress.Loopback };
			entry.Aliases = EmptyStrings;
			args.ResolverError = 0;
			args.HostEntry = entry;
		}

		// Argument validation shared by the public entry points.
		static void CheckArgs (SimpleResolverEventArgs args)
		{
			if (args == null)
				throw new ArgumentNullException ("args");

			if (args.HostName == null)
				throw new ArgumentNullException ("args.HostName is null");

			if (args.HostName.Length > 255)
				throw new ArgumentException ("args.HostName is too long");
		}

		// Builds the entry reported for a host name that is already an IP literal.
		static IPHostEntry CreateLiteralEntry (string host, IPAddress addr)
		{
			IPHostEntry entry = new IPHostEntry ();
			entry.HostName = host;
			entry.Aliases = EmptyStrings;
			entry.AddressList = new IPAddress [1] { addr };
			return entry;
		}

		// Resolves args.HostName to its addresses with a type A query.
		//  - "" -> local host addresses (see GetLocalHost)
		//  - an IP literal -> that same address, no network traffic
		// Returns true when a query was sent; completion is then reported through
		// args.OnCompleted. Returns false when args already holds the result.
		// Throws ArgumentNullException/ArgumentException for bad arguments and
		// ObjectDisposedException if the resolver was closed.
		public bool GetHostAddressesAsync (SimpleResolverEventArgs args)
		{
			CheckArgs (args);
			args.Reset (ResolverAsyncOperation.GetHostAddresses);
			string host = args.HostName;
			if (host == "") {
				GetLocalHost (args);
				return false;
			}

			IPAddress addr;
			if (IPAddress.TryParse (host, out addr)) {
				args.HostEntry = CreateLiteralEntry (host, addr);
				return false;
			}

			SendAQuery (args, true);
			return true;
		}

		// Resolves args.HostName to a full host entry.
		//  - a name -> type A query
		//  - an IP literal -> PTR query; if it yields a name, that name is then resolved
		//    with an A query (it may map to different addresses). If the PTR lookup finds
		//    nothing, the entry keeps the literal address.
		//  - "" -> local host (see GetLocalHost)
		// Return value and exceptions are as for GetHostAddressesAsync.
		public bool GetHostEntryAsync (SimpleResolverEventArgs args)
		{
			CheckArgs (args);
			args.Reset (ResolverAsyncOperation.GetHostEntry);
			string host = args.HostName;
			if (host == "") {
				GetLocalHost (args);
				return false;
			}

			IPAddress addr;
			if (IPAddress.TryParse (host, out addr)) {
				args.HostEntry = CreateLiteralEntry (host, addr);
				args.PTRAddress = addr;
				SendPTRQuery (args, true);
				return true;
			}

			// TODO: exclude IPv6 addresses if not supported by the system.
			SendAQuery (args, true);
			return true;
		}

		// Registers 'query' as pending for 'args'. Returns false if its ID is already in use.
		bool AddQuery (DnsQuery query, SimpleResolverEventArgs args)
		{
			int id = query.Header.ID;
			lock (queries) {
				if (queries.ContainsKey (id))
					return false;
				queries [id] = args;
				sent_queries [id] = query;
			}
			return true;
		}

		// Forgets a pending query. The caller must hold the 'queries' lock.
		void RemoveQuery (int id)
		{
			queries.Remove (id);
			sent_queries.Remove (id);
		}

		static DnsQuery GetQuery (string host, DnsQType q, DnsQClass c)
		{
			return new DnsQuery (host, q, c);
		}

		void SendAQuery (SimpleResolverEventArgs args, bool add_it)
		{
			SendAQuery (args, args.HostName, add_it);
		}

		void SendAQuery (SimpleResolverEventArgs args, string host, bool add_it)
		{
			DnsQuery query = GetQuery (host, DnsQType.A, DnsQClass.IN);
			SendQuery (args, query, add_it);
		}

		// Builds the IPv4 reverse-lookup name "d.c.b.a.in-addr.arpa".
		static string GetPTRName (IPAddress address)
		{
			// TODO: IPv6 PTR query?
			byte [] bytes = address.GetAddressBytes ();
			// "XXX.XXX.XXX.XXX.in-addr.arpa".Length
			StringBuilder sb = new StringBuilder (28);
			for (int i = bytes.Length - 1; i >= 0; i--) {
				sb.AppendFormat ("{0}.", bytes [i]);
			}
			sb.Append ("in-addr.arpa");
			return sb.ToString ();
		}

		void SendPTRQuery (SimpleResolverEventArgs args, bool add_it)
		{
			DnsQuery query = GetQuery (GetPTRName (args.PTRAddress), DnsQType.PTR, DnsQClass.IN);
			SendQuery (args, query, add_it);
		}

		// Returns a DNS message ID in [1, 65533], drawn from a cryptographic RNG.
		static ushort NextQueryID ()
		{
			byte [] bytes = new byte [2];
			int id;
			lock (id_rng) {
				do {
					id_rng.GetBytes (bytes);
					id = (bytes [0] << 8) | bytes [1];
				} while (id < 1 || id > 65533);
			}
			return (ushort) id;
		}

		// Sends 'query' and (re)arms the timeout timer of 'args'.
		// add_it == true: assign a fresh unused ID and register the query as pending.
		// add_it == false: resend under the ID already stored in args.QueryID.
		// Throws ObjectDisposedException if the resolver was closed, and
		// InvalidOperationException if no free ID is found after MaxIdAttempts tries.
		void SendQuery (SimpleResolverEventArgs args, DnsQuery query, bool add_it)
		{
			Socket socket = client;
			if (disposed || socket == null)
				throw new ObjectDisposedException (GetType ().FullName);

			// TODO: not sure about reusing IDs when add_it == false
			if (add_it) {
				int attempts = 0;
				do {
					if (++attempts > MaxIdAttempts)
						throw new InvalidOperationException ("Too many pending queries (or really bad luck)");
					query.Header.ID = NextQueryID ();
				} while (!AddQuery (query, args));
				args.QueryID = query.Header.ID;
			} else {
				query.Header.ID = args.QueryID;
			}

			if (args.Timer == null)
				args.Timer = new Timer (timeout_cb, args, QueryTimeout, Timeout.Infinite);
			else
				args.Timer.Change (QueryTimeout, Timeout.Infinite);
			socket.BeginSend (query.Packet, 0, query.Length, SocketFlags.None, send_cb, socket);
		}

		// Completes a BeginSend. A lost send is not fatal here: the query's timer will
		// retry it or report a timeout.
		static void OnSend (IAsyncResult ares)
		{
			Socket socket = (Socket) ares.AsyncState;
			try {
				socket.EndSend (ares);
			} catch (ObjectDisposedException) {
				// The resolver was closed while the datagram was in flight.
			} catch (Exception e) {
				Console.Error.WriteLine (e);
			}
		}

		byte [] GetFreshBuffer ()
		{
#if REUSE_RESPONSES
			lock (buffers_avail) {
				if (buffers_avail.Count > 0)
					return buffers_avail.Pop ();
			}
#endif
			return new byte [512];
		}

		void FreeBuffer (byte [] buffer)
		{
#if REUSE_RESPONSES
			// TODO: set some limit here. Configurable?
			// NOTE(review): assumes DnsResponse does not keep 'buffer' alive after
			// ProcessResponse has copied the data out - confirm before enabling.
			lock (buffers_avail) {
				buffers_avail.Push (buffer);
			}
#endif
		}

		// Opens a UDP socket connected to the first configured DNS server and starts
		// receiving. Throws InvalidOperationException when no DNS server was found.
		void InitSocket ()
		{
			if (endpoints == null || endpoints.Length == 0)
				throw new InvalidOperationException ("No DNS servers found.");

			Socket socket = new Socket (AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
			try {
				socket.Blocking = true;
				socket.Bind (new IPEndPoint (IPAddress.Any, 0));
				socket.Connect (endpoints [0]);
			} catch {
				// Do not leak the socket when the constructor fails.
				socket.Close ();
				throw;
			}
			client = socket;
			BeginReceive (socket);
		}

		void BeginReceive (Socket socket)
		{
			byte [] buffer = GetFreshBuffer ();
			socket.BeginReceive (buffer, 0, buffer.Length, SocketFlags.None, receive_cb, buffer);
		}

		// Timer callback for a query that got no answer in QueryTimeout ms.
		// The first time it resends the same packet. The second time, or if the resolver
		// was closed or the resend fails, it drops the query and reports
		// ResolverError.Timeout.
		void OnTimeout (object obj)
		{
			SimpleResolverEventArgs args = (SimpleResolverEventArgs) obj;
			lock (queries) {
				int id = args.QueryID;
				SimpleResolverEventArgs pending;
				if (!queries.TryGetValue (id, out pending) || pending != args)
					return; // Already answered (or the ID no longer belongs to this query).

				args.Retries++;
				DnsQuery query;
				if (args.Retries <= 1 && !disposed && sent_queries.TryGetValue (id, out query)) {
					try {
						// Resend the exact question that timed out (A or PTR, original name).
						SendQuery (args, query, false);
						return;
					} catch (ObjectDisposedException) {
						// Closed while retrying: report the timeout below.
					} catch (SocketException) {
						// Could not resend: report the timeout below.
					}
				}
				// Unregister so the ID is freed and a late answer cannot complete args twice.
				RemoveQuery (id);
			}

			// Error timeout. Raised outside the lock so user callbacks never run under it.
			args.ResolverError = ResolverError.Timeout;
			args.OnCompleted (this);
		}

		// Receive callback: re-arms the receive, then hands the datagram to DispatchResponse.
		void OnReceive (IAsyncResult ares)
		{
			Socket socket = client;
			if (disposed || socket == null)
				return;

			byte [] buffer = (byte []) ares.AsyncState;
			int nread = 0;
			EndPoint remote_ep = null;
			try {
				nread = socket.EndReceive (ares);
				remote_ep = socket.RemoteEndPoint;
			} catch (ObjectDisposedException) {
				return; // The resolver was closed; stop receiving.
			} catch (Exception e) {
				Console.Error.WriteLine (e);
			}

			// Re-arm before processing so answers arriving meanwhile are not missed.
			try {
				BeginReceive (socket);
			} catch (ObjectDisposedException) {
				// Closed concurrently; still deliver the answer already read.
			} catch (Exception e) {
				Console.Error.WriteLine (e);
			}

			try {
				// 12 bytes is the fixed DNS message header (RFC 1035 4.1.1); anything
				// this short carries no answer.
				if (nread > 12)
					DispatchResponse (new DnsResponse (buffer, nread), remote_ep);
			} finally {
				FreeBuffer (buffer);
			}
		}

		// Matches a response to its pending query, fills in args and either completes it
		// or, after a successful PTR lookup, starts the forward A query for the found name.
		void DispatchResponse (DnsResponse response, EndPoint remote_ep)
		{
			int id = response.Header.ID;
			SimpleResolverEventArgs args;
			lock (queries) {
				if (!queries.TryGetValue (id, out args))
					return; // Not pending (e.g. already answered or timed out).
				RemoveQuery (id);
			}

			Timer t = args.Timer;
			if (t != null)
				t.Change (Timeout.Infinite, Timeout.Infinite);

			try {
				ProcessResponse (args, response, remote_ep);
			} catch (Exception e) {
				args.ResolverError = (ResolverError) (-1);
				args.ErrorMessage = e.Message;
			}

			// A PTR answer replaces the IP literal in HostEntry.HostName (ProcessResponse).
			// Only then is the new name resolved forward. A PTR lookup that found nothing
			// keeps HostName == args.HostName and completes with the literal address.
			IPHostEntry entry = args.HostEntry;
			if (args.ResolverError == 0 && args.PTRAddress != null && entry != null &&
			    entry.HostName != null && entry.HostName != args.HostName) {
				args.PTRAddress = null;
				try {
					SendAQuery (args, entry.HostName, true);
					return;
				} catch (Exception e) {
					args.ResolverError = (ResolverError) (-1);
					args.ErrorMessage = e.Message;
				}
			}
			args.OnCompleted (this);
		}

		// Copies the answer into args.HostEntry, or records the failure in
		// args.ResolverError / args.ErrorMessage. While a PTR lookup is in progress
		// (args.PTRAddress != null), an error status or an empty answer is not treated
		// as an error.
		void ProcessResponse (SimpleResolverEventArgs args, DnsResponse response, EndPoint server_ep)
		{
			DnsRCode status = response.Header.RCode;
			if (status != 0) {
				if (args.PTRAddress != null) {
					// PTR query failed -> no error, we have the IP
					return;
				}
				args.ResolverError = (ResolverError) status;
				return;
			}

			// TODO: verify IP of the server is in our list and the same one that got the query
			IPEndPoint ep = (IPEndPoint) server_ep;
			if (ep.Port != 53) {
				args.ResolverError = ResolverError.ResponseHeaderError;
				args.ErrorMessage = "Port";
				return;
			}

			DnsHeader header = response.Header;
			if (!header.IsQuery) {
				args.ResolverError = ResolverError.ResponseHeaderError;
				args.ErrorMessage = "IsQuery";
				return;
			}

			// TODO: handle Truncation. Retry with bigger buffer?

			if (header.QuestionCount > 1) {
				args.ResolverError = ResolverError.ResponseHeaderError;
				args.ErrorMessage = "QuestionCount";
				return;
			}
			ReadOnlyCollection<DnsQuestion> q = response.GetQuestions ();
			if (q.Count != 1) {
				args.ResolverError = ResolverError.ResponseHeaderError;
				args.ErrorMessage = "QuestionCount 2";
				return;
			}
			DnsQuestion question = q [0];
			// The question name is deliberately not compared with args.HostName: the
			// server may echo it back in another form (e.g. with a trailing dot).

			DnsQType t = question.Type;
			if (t != DnsQType.A && t != DnsQType.AAAA && t != DnsQType.PTR) {
				args.ResolverError = ResolverError.ResponseHeaderError;
				args.ErrorMessage = "QType " + question.Type;
				return;
			}

			if (question.Class != DnsQClass.IN) {
				args.ResolverError = ResolverError.ResponseHeaderError;
				args.ErrorMessage = "QClass " + question.Class;
				return;
			}

			ReadOnlyCollection<DnsResourceRecord> records = response.GetAnswers ();
			if (records.Count == 0) {
				if (args.PTRAddress != null) {
					// PTR query failed -> no error
					return;
				}
				args.ResolverError = ResolverError.NameError; // is this ok?
				args.ErrorMessage = "NoAnswers";
				return;
			}

			List<string> aliases = null;
			List<IPAddress> addresses = null;
			foreach (DnsResourceRecord r in records) {
				if (r.Class != DnsClass.IN)
					continue;
				if (r.Type == DnsType.A || r.Type == DnsType.AAAA) {
					if (addresses == null)
						addresses = new List<IPAddress> ();
					addresses.Add (((DnsResourceRecordIPAddress) r).Address);
				} else if (r.Type == DnsType.CNAME) {
					if (aliases == null)
						aliases = new List<string> ();
					aliases.Add (((DnsResourceRecordCName) r).CName);
				} else if (r.Type == DnsType.PTR) {
					// First PTR record wins: it becomes the host name of the entry.
					args.HostEntry.HostName = ((DnsResourceRecordPTR) r).DName;
					args.HostEntry.Aliases = aliases == null ? EmptyStrings : aliases.ToArray ();
					args.HostEntry.AddressList = EmptyAddresses;
					return;
				}
			}

			IPHostEntry entry = args.HostEntry ?? new IPHostEntry ();
			if (entry.HostName == null && aliases != null && aliases.Count > 0) {
				// No name yet: promote the first CNAME to the canonical host name.
				entry.HostName = aliases [0];
				aliases.RemoveAt (0);
			}
			entry.Aliases = aliases == null ? EmptyStrings : aliases.ToArray ();
			entry.AddressList = addresses == null ? EmptyAddresses : addresses.ToArray ();
			args.HostEntry = entry;
			if ((t == DnsQType.A || t == DnsQType.AAAA) && entry.AddressList == EmptyAddresses) {
				args.ResolverError = ResolverError.NameError;
				args.ErrorMessage = "No addresses in response";
			} else if (t == DnsQType.PTR && entry.HostName == null) {
				args.ResolverError = ResolverError.NameError;
				args.ErrorMessage = "No PTR in response";
			}
		}

		// Collects the distinct IPv4 DNS servers (port 53) configured on every
		// non-loopback network interface.
		void InitFromSystem ()
		{
			List<IPEndPoint> eps = new List<IPEndPoint> ();
			foreach (NetworkInterface iface in NetworkInterface.GetAllNetworkInterfaces ()) {
				if (NetworkInterfaceType.Loopback == iface.NetworkInterfaceType)
					continue;

				foreach (IPAddress addr in iface.GetIPProperties ().DnsAddresses) {
					if (AddressFamily.InterNetworkV6 == addr.AddressFamily)
						continue;
					IPEndPoint ep = new IPEndPoint (addr, 53);
					if (eps.Contains (ep))
						continue;

					eps.Add (ep);
				}
			}
			endpoints = eps.ToArray ();
		}
	}
}
498