Fix race condition
[mono.git] / mcs / class / System / System.Net / EndPointListener.cs
1 //
2 // System.Net.EndPointListener
3 //
4 // Author:
5 //      Gonzalo Paniagua Javier (gonzalo@novell.com)
6 //
7 // Copyright (c) 2005 Novell, Inc. (http://www.novell.com)
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 // 
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 // 
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28
29 #if NET_2_0 && SECURITY_DEP
30
31 using System.IO;
32 using System.Net.Sockets;
33 using System.Collections;
34 using System.Security.Cryptography;
35 using System.Security.Cryptography.X509Certificates;
36 using System.Threading;
37 using Mono.Security.Authenticode;
38
namespace System.Net {
    // One listening TCP socket bound to an (address, port) pair. Accepted sockets
    // become HttpConnections; requests are routed to the HttpListener whose
    // registered prefix best matches the request URL.
    //
    // The prefix tables (prefixes / unhandled / all) are copy-on-write: readers
    // take a snapshot of the field without locking; writers clone the current
    // table, modify the clone and publish it with Interlocked.CompareExchange,
    // retrying if another writer published first.
    sealed class EndPointListener
    {
        IPEndPoint endpoint;
        Socket sock;
        Hashtable prefixes;  // Dictionary <ListenerPrefix, HttpListener>
        ArrayList unhandled; // List<ListenerPrefix> unhandled; host = '*'
        ArrayList all;       // List<ListenerPrefix> all;  host = '+'
        X509Certificate2 cert;
        AsymmetricAlgorithm key;
        bool secure;
        // Connections accepted by this endpoint, keyed by themselves; entries are
        // dropped via RemoveConnection or Close. Guarded by lock (unregistered).
        Hashtable unregistered;

        // Binds and listens on addr:port and starts accepting connections.
        // When 'secure' is true, the certificate and key for 'port' are loaded first.
        // Throws whatever Socket.Bind/Listen/AcceptAsync throw (e.g. port in use).
        public EndPointListener (IPAddress addr, int port, bool secure)
        {
            if (secure) {
                this.secure = secure;
                LoadCertificateAndKey (addr, port);
            }

            // These must exist before the first accept is started: OnAccept may run
            // on another thread as soon as AcceptAsync is called, and it locks on
            // 'unregistered'.
            prefixes = new Hashtable ();
            unregistered = new Hashtable ();

            endpoint = new IPEndPoint (addr, port);
            sock = new Socket (addr.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try {
                sock.Bind (endpoint);
                sock.Listen (500);
            } catch {
                // Do not leak the socket when the endpoint cannot be bound.
                sock.Close ();
                throw;
            }

            SocketAsyncEventArgs args = new SocketAsyncEventArgs ();
            args.UserToken = this;
            args.Completed += OnAccept;
            // AcceptAsync returns false when it completed synchronously; Completed is
            // NOT raised in that case, so the result has to be processed here.
            if (!sock.AcceptAsync (args))
                OnAccept (sock, args);
        }

        // Best-effort load of <ApplicationData>/.mono/httplistener/<port>.cer and
        // <port>.pvk. Any failure is deliberately ignored; 'cert' and/or 'key' then
        // stay null and ProcessAccepted refuses connections on a secure endpoint.
        // 'addr' is currently unused.
        void LoadCertificateAndKey (IPAddress addr, int port)
        {
            try {
                string dirname = Environment.GetFolderPath (Environment.SpecialFolder.ApplicationData);
                string path = Path.Combine (dirname, ".mono");
                path = Path.Combine (path, "httplistener");
                string cert_file = Path.Combine (path, String.Format ("{0}.cer", port));
                string pvk_file = Path.Combine (path, String.Format ("{0}.pvk", port));
                cert = new X509Certificate2 (cert_file);
                key = PrivateKey.CreateFromFile (pvk_file).RSA;
            } catch {
                // ignore errors
            }
        }

        // Accept loop. Takes the accepted socket out of 'args', re-arms the accept
        // on the listening socket right away (so the next client can be accepted
        // concurrently), then hands the accepted socket to ProcessAccepted.
        // If the re-armed accept completes synchronously, Completed will not be
        // raised for it, so this method loops and handles that result as well.
        static void OnAccept (object sender, EventArgs e)
        {
            SocketAsyncEventArgs args = (SocketAsyncEventArgs) e;
            EndPointListener epl = (EndPointListener) args.UserToken;
            bool completedSynchronously;
            do {
                Socket accepted = null;
                if (args.SocketError == SocketError.Success)
                    accepted = args.AcceptSocket;
                // Let the next AcceptAsync create a fresh socket.
                args.AcceptSocket = null;

                completedSynchronously = false;
                try {
                    if (epl.sock != null)
                        completedSynchronously = !epl.sock.AcceptAsync (args);
                } catch {
                    // The listening socket can no longer accept (e.g. Close() already
                    // ran): drop the connection just received and end the loop.
                    if (accepted != null) {
                        try {
                            accepted.Close ();
                        } catch {}
                        accepted = null;
                    }
                }

                if (accepted != null)
                    epl.ProcessAccepted (accepted);
            } while (completedSynchronously);
        }

        // Wraps a newly accepted socket in an HttpConnection, records it in
        // 'unregistered' and starts reading the request. A secure endpoint whose
        // certificate or private key failed to load closes the socket instead.
        void ProcessAccepted (Socket accepted)
        {
            if (secure && (cert == null || key == null)) {
                accepted.Close ();
                return;
            }

            HttpConnection conn = new HttpConnection (accepted, this, secure, cert, key);
            lock (unregistered) {
                unregistered [conn] = conn;
            }
            conn.BeginReadRequest ();
        }

        // Forgets a connection previously recorded by ProcessAccepted.
        internal void RemoveConnection (HttpConnection conn)
        {
            lock (unregistered) {
                unregistered.Remove (conn);
            }
        }

        // Finds the listener for the request URL. On success sets context.Listener
        // and context.Connection.Prefix and returns true; returns false when no
        // registered prefix matches.
        public bool BindContext (HttpListenerContext context)
        {
            HttpListenerRequest req = context.Request;
            ListenerPrefix prefix;
            HttpListener listener = SearchListener (req.Url, out prefix);
            if (listener == null)
                return false;

            context.Listener = listener;
            context.Connection.Prefix = prefix;
            return true;
        }

        // Asks the context's listener to unregister it. Contexts without a request,
        // or that BindContext never bound to a listener, are ignored.
        public void UnbindContext (HttpListenerContext context)
        {
            if (context == null || context.Request == null)
                return;

            HttpListener listener = context.Listener;
            if (listener == null)
                return;

            listener.UnregisterContext (context);
        }

        // Lookup precedence:
        //   1. prefixes for an explicit host (host and port must match),
        //   2. '*' prefixes ('unhandled'),
        //   3. '+' prefixes ('all').
        // Within each group the longest matching prefix path wins. The URL-decoded
        // path is also tried with a trailing '/', so "/app" can match a prefix path
        // "/app/". Paths are compared ordinally (byte-wise), not by culture rules.
        HttpListener SearchListener (Uri uri, out ListenerPrefix prefix)
        {
            prefix = null;
            if (uri == null)
                return null;

            string host = uri.Host;
            int port = uri.Port;
            string path = HttpUtility.UrlDecode (uri.AbsolutePath);
            // Guarded so an empty path cannot index out of range.
            string path_slash = (path.Length > 0 && path [path.Length - 1] == '/') ? path : path + "/";

            HttpListener best_match = null;
            int best_length = -1;

            if (!String.IsNullOrEmpty (host)) {
                Hashtable p_ro = prefixes; // snapshot; never mutated after publication
                foreach (ListenerPrefix p in p_ro.Keys) {
                    string ppath = p.Path;
                    if (ppath.Length < best_length)
                        continue;

                    if (p.Host != host || p.Port != port)
                        continue;

                    if (path.StartsWith (ppath, StringComparison.Ordinal) ||
                        path_slash.StartsWith (ppath, StringComparison.Ordinal)) {
                        best_length = ppath.Length;
                        best_match = (HttpListener) p_ro [p];
                        prefix = p;
                    }
                }
                if (best_length != -1)
                    return best_match;
            }

            best_match = MatchWithSlashFallback (host, path, path_slash, unhandled, out prefix);
            if (best_match != null)
                return best_match;

            return MatchWithSlashFallback (host, path, path_slash, all, out prefix);
        }

        // Tries 'path' against 'list'; if nothing matches and 'path_slash' differs
        // from 'path', tries 'path_slash' too.
        HttpListener MatchWithSlashFallback (string host, string path, string path_slash, ArrayList list, out ListenerPrefix prefix)
        {
            HttpListener match = MatchFromList (host, path, list, out prefix);
            if (match == null && path != path_slash)
                match = MatchFromList (host, path_slash, list, out prefix);
            return match;
        }

        // Returns the listener of the longest prefix in 'list' whose path is an
        // ordinal prefix of 'path' (null if none, or if 'list' is null).
        // 'host' is currently not consulted.
        HttpListener MatchFromList (string host, string path, ArrayList list, out ListenerPrefix prefix)
        {
            prefix = null;
            if (list == null)
                return null;

            HttpListener best_match = null;
            int best_length = -1;

            foreach (ListenerPrefix p in list) {
                string ppath = p.Path;
                if (ppath.Length < best_length)
                    continue;

                if (path.StartsWith (ppath, StringComparison.Ordinal)) {
                    best_length = ppath.Length;
                    best_match = p.Listener;
                    prefix = p;
                }
            }

            return best_match;
        }

        // Appends 'prefix' to 'coll'; throws HttpListenerException (400) when a
        // prefix with the same path is already present. No-op for a null 'coll'.
        void AddSpecial (ArrayList coll, ListenerPrefix prefix)
        {
            if (coll == null)
                return;

            foreach (ListenerPrefix p in coll) {
                if (p.Path == prefix.Path) //TODO: code
                    throw new HttpListenerException (400, "Prefix already in use.");
            }
            coll.Add (prefix);
        }

        // Removes the first entry of 'coll' whose path equals prefix.Path.
        // Returns true if one was removed.
        bool RemoveSpecial (ArrayList coll, ListenerPrefix prefix)
        {
            if (coll == null)
                return false;

            int c = coll.Count;
            for (int i = 0; i < c; i++) {
                ListenerPrefix p = (ListenerPrefix) coll [i];
                if (p.Path == prefix.Path) {
                    coll.RemoveAt (i);
                    return true;
                }
            }
            return false;
        }

        // Once no prefix of any kind remains, asks EndPointManager to drop this endpoint.
        void CheckIfRemove ()
        {
            if (prefixes.Count > 0)
                return;

            ArrayList list = unhandled;
            if (list != null && list.Count > 0)
                return;

            list = all;
            if (list != null && list.Count > 0)
                return;

            EndPointManager.RemoveEndPoint (this, endpoint);
        }

        // Stops accepting and force-closes every connection still recorded here.
        public void Close ()
        {
            sock.Close ();
            lock (unregistered) {
                // Enumerate a snapshot: RemoveConnection is how connections
                // deregister themselves, and if HttpConnection.Close calls it, it
                // would mutate 'unregistered' under a live enumerator. Monitor
                // locks are reentrant, so such a callback cannot deadlock here.
                ArrayList connections = new ArrayList (unregistered.Keys);
                foreach (HttpConnection c in connections)
                    c.Close (true);
                unregistered.Clear ();
            }
        }

        // Registers 'prefix' for 'listener'. Host "*" goes to 'unhandled', "+" to
        // 'all', anything else to 'prefixes'. Throws HttpListenerException (400)
        // when the prefix is already taken (for explicit hosts: by another listener).
        public void AddPrefix (ListenerPrefix prefix, HttpListener listener)
        {
            ArrayList current;
            ArrayList future;
            if (prefix.Host == "*") {
                do {
                    current = unhandled;
                    future = (current != null) ? (ArrayList) current.Clone () : new ArrayList ();
                    prefix.Listener = listener;
                    AddSpecial (future, prefix);
                } while (Interlocked.CompareExchange (ref unhandled, future, current) != current);
                return;
            }

            if (prefix.Host == "+") {
                do {
                    current = all;
                    future = (current != null) ? (ArrayList) current.Clone () : new ArrayList ();
                    prefix.Listener = listener;
                    AddSpecial (future, prefix);
                } while (Interlocked.CompareExchange (ref all, future, current) != current);
                return;
            }

            Hashtable prefs, p2;
            do {
                prefs = prefixes;
                if (prefs.ContainsKey (prefix)) {
                    HttpListener other = (HttpListener) prefs [prefix];
                    if (other != listener) // TODO: code.
                        throw new HttpListenerException (400, "There's another listener for " + prefix);
                    return;
                }
                p2 = (Hashtable) prefs.Clone ();
                p2 [prefix] = listener;
            } while (Interlocked.CompareExchange (ref prefixes, p2, prefs) != prefs);
        }

        // Unregisters 'prefix' (silently ignoring unknown prefixes) and then lets
        // CheckIfRemove retire the endpoint if it is now empty.
        // 'listener' is currently not consulted.
        public void RemovePrefix (ListenerPrefix prefix, HttpListener listener)
        {
            ArrayList current;
            ArrayList future;
            if (prefix.Host == "*") {
                do {
                    current = unhandled;
                    future = (current != null) ? (ArrayList) current.Clone () : new ArrayList ();
                    if (!RemoveSpecial (future, prefix))
                        break; // Prefix not found
                } while (Interlocked.CompareExchange (ref unhandled, future, current) != current);
                CheckIfRemove ();
                return;
            }

            if (prefix.Host == "+") {
                do {
                    current = all;
                    future = (current != null) ? (ArrayList) current.Clone () : new ArrayList ();
                    if (!RemoveSpecial (future, prefix))
                        break; // Prefix not found
                } while (Interlocked.CompareExchange (ref all, future, current) != current);
                CheckIfRemove ();
                return;
            }

            Hashtable prefs, p2;
            do {
                prefs = prefixes;
                if (!prefs.ContainsKey (prefix))
                    break;

                p2 = (Hashtable) prefs.Clone ();
                p2.Remove (prefix);
            } while (Interlocked.CompareExchange (ref prefixes, p2, prefs) != prefs);
            CheckIfRemove ();
        }
    }
}
357 #endif
358