[Mono.Security.Interface]: Improve synergy between `SslStream` and `IMonoSslStream...
[mono.git] / mcs / class / System / System.Net / HttpConnection.cs
1 //
2 // System.Net.HttpConnection
3 //
4 // Author:
5 //      Gonzalo Paniagua Javier (gonzalo.mono@gmail.com)
6 //
7 // Copyright (c) 2005-2009 Novell, Inc. (http://www.novell.com)
8 // Copyright (c) 2012 Xamarin, Inc. (http://xamarin.com)
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining
11 // a copy of this software and associated documentation files (the
12 // "Software"), to deal in the Software without restriction, including
13 // without limitation the rights to use, copy, modify, merge, publish,
14 // distribute, sublicense, and/or sell copies of the Software, and to
15 // permit persons to whom the Software is furnished to do so, subject to
16 // the following conditions:
17 // 
18 // The above copyright notice and this permission notice shall be
19 // included in all copies or substantial portions of the Software.
20 // 
21 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
22 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
24 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
25 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
26 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
27 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28 //
29
30 #if SECURITY_DEP
31 #if MONO_SECURITY_ALIAS
32 extern alias MonoSecurity;
33 #endif
34
35 #if MONO_SECURITY_ALIAS
36 using MSI = MonoSecurity::Mono.Security.Interface;
37 #else
38 using MSI = Mono.Security.Interface;
39 #endif
40
41 using System.IO;
42 using System.Net.Sockets;
43 using System.Text;
44 using System.Threading;
45 using System.Net.Security;
46 using System.Security.Authentication;
47 using System.Security.Cryptography;
48 using System.Security.Cryptography.X509Certificates;
49
namespace System.Net {
	/// <summary>
	/// One client connection served through an <see cref="EndPointListener"/>: reads and parses
	/// the request line and headers into an <see cref="HttpListenerContext"/>, binds that context
	/// to a listener, and after the response either recycles the connection for another request
	/// (keep-alive) or shuts the socket down.
	/// </summary>
	sealed class HttpConnection
	{
		static AsyncCallback onread_cb = new AsyncCallback (OnRead);
		const int BufferSize = 8192;
		// Upper bound for the buffered request head and for any single request/header line.
		// Exceeding it produces a 400 response.
		const int MaxHeaderSize = 32768;
		Socket sock;
		Stream stream;                  // NetworkStream, or the SslStream wrapping it when secure
		EndPointListener epl;
		MemoryStream ms;                // raw bytes of the request head read so far
		byte [] buffer;                 // read buffer handed to stream.BeginRead
		HttpListenerContext context;
		StringBuilder current_line;     // line currently being assembled by ReadLine
		ListenerPrefix prefix;
		RequestStream i_stream;
		ResponseStream o_stream;
		bool chunked;
		int reuses;                     // number of times this connection was recycled
		bool context_bound;
		bool secure;
		X509Certificate cert;
		int s_timeout = 90000; // 90k ms for first request, 15k ms from then on
		Timer timer;
		IPEndPoint local_ep;
		HttpListener last_listener;
		int [] client_cert_errors;
		X509Certificate2 client_cert;
		SslStream ssl_stream;

		/// <summary>
		/// Wraps <paramref name="sock"/> in a stream. When <paramref name="secure"/> is set, the
		/// TLS server handshake is performed synchronously using <paramref name="cert"/>.
		/// </summary>
		public HttpConnection (Socket sock, EndPointListener epl, bool secure, X509Certificate cert)
		{
			this.sock = sock;
			this.epl = epl;
			this.secure = secure;
			this.cert = cert;
			if (secure == false) {
				stream = new NetworkStream (sock, false);
			} else {
				// The validation callback accepts every client certificate. It only records the
				// certificate and its policy errors for ClientCertificate / ClientCertificateErrors.
				ssl_stream = epl.Listener.CreateSslStream (new NetworkStream (sock, false), false, (t, c, ch, e) => {
					if (c == null)
						return true;
					var c2 = c as X509Certificate2;
					if (c2 == null)
						c2 = new X509Certificate2 (c.GetRawCertData ());
					client_cert = c2;
					client_cert_errors = new int[] { (int)e };
					return true;
				});
				stream = ssl_stream;
			}
			timer = new Timer (OnTimeout, null, Timeout.Infinite, Timeout.Infinite);
			// Authenticate exactly once per connection. This must not live in Init (), because
			// Init () also runs before every keep-alive request, and an SslStream cannot be
			// authenticated a second time.
			if (ssl_stream != null)
				ssl_stream.AuthenticateAsServer (cert, true, (SslProtocols)ServicePointManager.SecurityProtocol, false);
			Init ();
		}

		internal SslStream SslStream {
			get { return ssl_stream; }
		}

		// Policy errors reported for the client certificate, or null if none was presented.
		internal int [] ClientCertificateErrors {
			get { return client_cert_errors; }
		}

		// Client certificate captured during the TLS handshake, or null if none was presented.
		internal X509Certificate2 ClientCertificate {
			get { return client_cert; }
		}

		// Resets all per-request parser and context state so a new request can be read.
		void Init ()
		{
			context_bound = false;
			i_stream = null;
			o_stream = null;
			prefix = null;
			chunked = false;
			ms = new MemoryStream ();
			position = 0;
			input_state = InputState.RequestLine;
			line_state = LineState.None;
			// A previous request that failed mid-line may have left a partial line behind.
			current_line = null;
			context = new HttpListenerContext (this);
		}

		public bool IsClosed {
			get { return (sock == null); }
		}

		public int Reuses {
			get { return reuses; }
		}

		public IPEndPoint LocalEndPoint {
			get {
				if (local_ep != null)
					return local_ep;

				local_ep = (IPEndPoint) sock.LocalEndPoint;
				return local_ep;
			}
		}

		public IPEndPoint RemoteEndPoint {
			get { return (IPEndPoint) sock.RemoteEndPoint; }
		}

		public bool IsSecure {
			get { return secure; }
		}

		public ListenerPrefix Prefix {
			get { return prefix; }
			set { prefix = value; }
		}

		// Timer callback: the peer stayed silent too long, so drop the connection.
		void OnTimeout (object unused)
		{
			CloseSocket ();
			Unbind ();
		}

		/// <summary>
		/// Arms the idle timer and starts an asynchronous read of request data. The first request
		/// gets 90 s; once the connection has been reused, 15 s. If the read cannot be started,
		/// the connection is closed.
		/// </summary>
		public void BeginReadRequest ()
		{
			if (buffer == null)
				buffer = new byte [BufferSize];
			try {
				if (reuses == 1)
					s_timeout = 15000;
				timer.Change (s_timeout, Timeout.Infinite);
				stream.BeginRead (buffer, 0, BufferSize, onread_cb, this);
			} catch {
				timer.Change (Timeout.Infinite, Timeout.Infinite);
				CloseSocket ();
				Unbind ();
			}
		}

		/// <summary>
		/// Returns the request body stream, creating it on first use. Bytes already buffered past
		/// the request head are handed to the new stream, and the head buffer is released.
		/// </summary>
		public RequestStream GetRequestStream (bool chunked, long contentlength)
		{
			if (i_stream == null) {
				byte [] buffer = ms.GetBuffer ();
				int length = (int) ms.Length;
				ms = null;
				if (chunked) {
					this.chunked = true;
					context.Response.SendChunked = true;
					i_stream = new ChunkedInputStream (context, stream, buffer, position, length - position);
				} else {
					i_stream = new RequestStream (stream, buffer, position, length - position, contentlength);
				}
			}
			return i_stream;
		}

		/// <summary>
		/// Returns the response stream for the current context. If the context has no listener,
		/// a fresh, uncached stream that ignores write exceptions is returned each time.
		/// </summary>
		public ResponseStream GetResponseStream ()
		{
			// TODO: can we get this stream before reading the input?
			if (o_stream == null) {
				HttpListener listener = context.Listener;

				if (listener == null)
					return new ResponseStream (stream, context.Response, true);

				o_stream = new ResponseStream (stream, context.Response, listener.IgnoreWriteExceptions);
			}
			return o_stream;
		}

		static void OnRead (IAsyncResult ares)
		{
			HttpConnection cnc = (HttpConnection) ares.AsyncState;
			cnc.OnReadInternal (ares);
		}

		// Completion of a read: buffers the bytes and parses them. Once the request head is
		// complete, binds and registers the context; otherwise issues the next read.
		void OnReadInternal (IAsyncResult ares)
		{
			timer.Change (Timeout.Infinite, Timeout.Infinite);
			int nread = -1;
			try {
				nread = stream.EndRead (ares);
				ms.Write (buffer, 0, nread);
				if (ms.Length > MaxHeaderSize) {
					SendError ("Bad request", 400);
					Close (true);
					return;
				}
			} catch {
				if (ms != null && ms.Length > 0)
					SendError ();
				if (sock != null) {
					CloseSocket ();
					Unbind ();
				}
				return;
			}

			if (nread == 0) {
				// The peer closed its side; there is nobody left to answer.
				CloseSocket ();
				Unbind ();
				return;
			}

			if (ProcessInput (ms)) {
				if (!context.HaveError)
					context.Request.FinishInitialization ();

				if (context.HaveError) {
					SendError ();
					Close (true);
					return;
				}

				if (!epl.BindContext (context)) {
					SendError ("Invalid host", 400);
					Close (true);
					return;
				}
				HttpListener listener = context.Listener;
				if (last_listener != listener) {
					RemoveConnection ();
					listener.AddConnection (this);
					last_listener = listener;
				}

				context_bound = true;
				listener.RegisterContext (context);
				return;
			}
			// The head is not complete yet. BeginReadRequest re-arms the idle timer, which was
			// stopped above, and closes the connection instead of letting a synchronous failure
			// escape this async callback.
			BeginReadRequest ();
		}

		// Detaches this connection from whichever owner currently tracks it.
		void RemoveConnection ()
		{
			if (last_listener == null)
				epl.RemoveConnection (this);
			else
				last_listener.RemoveConnection (this);
		}

		enum InputState {
			RequestLine,
			Headers
		}

		enum LineState {
			None,
			CR,
			LF
		}

		InputState input_state = InputState.RequestLine;
		LineState line_state = LineState.None;
		int position;   // offset in ms of the first byte not yet parsed

		// Parses as many complete lines from ms as are available.
		// true -> done processing (head complete, or context.HaveError set)
		// false -> need more input
		bool ProcessInput (MemoryStream ms)
		{
			byte [] buffer = ms.GetBuffer ();
			int len = (int) ms.Length;
			int used = 0;
			string line;

			while (true) {
				if (context.HaveError)
					return true;

				if (position >= len)
					break;

				try {
					line = ReadLine (buffer, position, len - position, ref used);
					position += used;
				} catch {
					context.ErrorMessage = "Bad request";
					context.ErrorStatus = 400;
					return true;
				}

				if (line == null)
					break;

				if (line == "") {
					// Blank lines before the request line are skipped; after headers one ends the head.
					if (input_state == InputState.RequestLine)
						continue;
					current_line = null;
					return true;
				}

				if (input_state == InputState.RequestLine) {
					context.Request.SetRequestLine (line);
					input_state = InputState.Headers;
				} else {
					try {
						context.Request.AddHeader (line);
					} catch (Exception e) {
						context.ErrorMessage = e.Message;
						context.ErrorStatus = 400;
						return true;
					}
				}
			}

			// The last ReadLine call consumed the whole buffer from its start. Any unfinished
			// text now lives in current_line, so the byte buffer can be recycled.
			if (used == len) {
				ms.SetLength (0);
				position = 0;
			}
			return false;
		}

		// Appends bytes to current_line until an LF is seen. CR bytes are dropped and each other
		// byte is mapped to a char one-to-one. Returns the completed line, or null if more input
		// is needed. 'used' receives the number of bytes consumed. Throws when the accumulated
		// line exceeds MaxHeaderSize.
		string ReadLine (byte [] buffer, int offset, int len, ref int used)
		{
			if (current_line == null)
				current_line = new StringBuilder (128);
			int last = offset + len;
			used = 0;
			for (int i = offset; i < last && line_state != LineState.LF; i++) {
				used++;
				byte b = buffer [i];
				if (b == 13) {
					line_state = LineState.CR;
				} else if (b == 10) {
					line_state = LineState.LF;
				} else {
					current_line.Append ((char) b);
				}
			}

			// ProcessInput recycles ms once it is fully consumed, so the ms.Length check in
			// OnReadInternal does not bound a line that arrives across many reads.
			// Enforce the limit on the assembled line itself.
			if (current_line.Length > MaxHeaderSize) {
				current_line.Length = 0;
				line_state = LineState.None;
				throw new ProtocolViolationException ("Request line or header is too long.");
			}

			string result = null;
			if (line_state == LineState.LF) {
				line_state = LineState.None;
				result = current_line.ToString ();
				current_line.Length = 0;
			}

			return result;
		}

		/// <summary>
		/// Best-effort: sends an HTML error page with <paramref name="status"/> and an optional
		/// <paramref name="msg"/> detail. Failures are ignored because the response may already
		/// be closed.
		/// </summary>
		public void SendError (string msg, int status)
		{
			try {
				HttpListenerResponse response = context.Response;
				response.StatusCode = status;
				response.ContentType = "text/html";
				string description = HttpListenerResponseHelper.GetStatusDescription (status);
				string str;
				if (msg != null)
					str = String.Format ("<h1>{0} ({1})</h1>", description, msg);
				else
					str = String.Format ("<h1>{0}</h1>", description);

				byte [] error = context.Response.ContentEncoding.GetBytes (str);
				response.Close (error, false);
			} catch {
				// response was already closed
			}
		}

		// Sends the error recorded on the current context.
		public void SendError ()
		{
			SendError (context.ErrorMessage, context.ErrorStatus);
		}

		void Unbind ()
		{
			if (context_bound) {
				epl.UnbindContext (context);
				context_bound = false;
			}
		}

		public void Close ()
		{
			Close (false);
		}

		// Closes the socket at most once, ignoring errors, and detaches the connection.
		void CloseSocket ()
		{
			if (sock == null)
				return;

			try {
				sock.Close ();
			} catch {
			} finally {
				sock = null;
			}
			RemoveConnection ();
		}

		/// <summary>
		/// Finishes the current response. Unless <paramref name="force_close"/> is set, the request
		/// is not keep-alive, or the response asks for "Connection: close", the connection is
		/// recycled for the next request after the remaining request input is flushed. Otherwise
		/// the socket is shut down.
		/// </summary>
		internal void Close (bool force_close)
		{
			if (sock != null) {
				Stream st = GetResponseStream ();
				if (st != null)
					st.Close ();

				o_stream = null;
			}

			// Re-check sock rather than merging with the block above: closing the response stream
			// may re-enter Close () and tear the socket down.
			// NOTE(review): re-entry is via HttpListenerResponse.Close; confirm.
			if (sock != null) {
				force_close |= !context.Request.KeepAlive;
				if (!force_close)
					force_close = (context.Response.Headers ["connection"] == "close");

				if (!force_close && context.Request.FlushInput ()) {
					// Keep-alive, chunked or not: recycle this connection for the next request.
					reuses++;
					Unbind ();
					Init ();
					BeginReadRequest ();
					return;
				}

				Socket s = sock;
				sock = null;
				try {
					if (s != null)
						s.Shutdown (SocketShutdown.Both);
				} catch {
				} finally {
					if (s != null)
						s.Close ();
				}
				Unbind ();
				RemoveConnection ();
				return;
			}
		}
	}
}
500 #endif
501