2009-02-23 Gonzalo Paniagua Javier <gonzalo@novell.com>
[mono.git] / mcs / class / System / System.Net / HttpListenerRequest.cs
index eefb361575701f4356f5f42ce158623b1660ee40..18586fb2df2a33cae7c98bee51a39d0d06cbf1d0 100644 (file)
@@ -25,7 +25,9 @@
 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
-#if NET_2_0
+
+#if NET_2_0 && SECURITY_DEP
+
 using System.Collections;
 using System.Collections.Specialized;
 using System.Globalization;
@@ -36,7 +38,8 @@ namespace System.Net {
        public sealed class HttpListenerRequest
        {
                string [] accept_types;
-               int client_cert_error;
+//             int client_cert_error;
+//             bool no_get_certificate;
                Encoding content_encoding;
                long content_length;
                bool cl_set;
@@ -44,7 +47,6 @@ namespace System.Net {
                WebHeaderCollection headers;
                string method;
                Stream input_stream;
-               bool is_authenticated;
                Version version;
                NameValueCollection query_string; // check if null is ok, check if read-only, check case-sensitiveness
                string raw_url;
@@ -52,22 +54,22 @@ namespace System.Net {
                Uri url;
                Uri referrer;
                string [] user_languages;
-               bool no_get_certificate;
                HttpListenerContext context;
                bool is_chunked;
                static byte [] _100continue = Encoding.ASCII.GetBytes ("HTTP/1.1 100 Continue\r\n\r\n");
+               static readonly string [] no_body_methods = new string [] {
+                       "GET", "HEAD", "DELETE" };
 
                internal HttpListenerRequest (HttpListenerContext context)
                {
                        this.context = context;
                        headers = new WebHeaderCollection ();
                        input_stream = Stream.Null;
+                       version = HttpVersion.Version10;
                }
 
                static char [] separators = new char [] { ' ' };
-               // From WebRequestMethods.Http
-               static readonly string [] methods = new string [] { "GET", "POST", "HEAD",
-                                                               "PUT", "CONNECT", "MKCOL" };
+
                internal void SetRequestLine (string req)
                {
                        string [] parts = req.Split (separators, 3);
@@ -77,8 +79,17 @@ namespace System.Net {
                        }
 
                        method = parts [0];
-                       if (Array.IndexOf (methods, method) == -1) {
-                               context.ErrorMessage = "Invalid request line (verb).";
+                       foreach (char c in method){
+                               int ic = (int) c;
+
+                               if ((ic >= 'A' && ic <= 'Z') ||
+                                   (ic > 32 && c < 127 && c != '(' && c != ')' && c != '<' &&
+                                    c != '<' && c != '>' && c != '@' && c != ',' && c != ';' &&
+                                    c != ':' && c != '\\' && c != '"' && c != '/' && c != '[' &&
+                                    c != ']' && c != '?' && c != '=' && c != '{' && c != '}'))
+                                       continue;
+
+                               context.ErrorMessage = "(Invalid verb)";
                                return;
                        }
 
@@ -104,6 +115,8 @@ namespace System.Net {
                        if (query == null || query.Length == 0)
                                return;
 
+                       if (query [0] == '?')
+                               query = query.Substring (1);
                        string [] components = query.Split ('&');
                        foreach (string kv in components) {
                                int pos = kv.IndexOf ('=');
@@ -121,14 +134,24 @@ namespace System.Net {
                internal void FinishInitialization ()
                {
                        string host = UserHostName;
-                       if (version > HttpVersion.Version10 && (host == null || host == "")) {
+                       if (version > HttpVersion.Version10 && (host == null || host.Length == 0)) {
                                context.ErrorMessage = "Invalid host name";
                                return;
                        }
 
-                       if (host == null || host == "")
+                       string path;
+                       Uri raw_uri;
+                       if (Uri.MaybeUri (raw_url) && Uri.TryCreate (raw_url, UriKind.Absolute, out raw_uri))
+                               path = raw_uri.PathAndQuery;
+                       else
+                               path = raw_url;
+
+                       if ((host == null || host.Length == 0))
                                host = UserHostAddress;
 
+                       if (raw_uri != null)
+                               host = raw_uri.Host;
+       
                        int colon = host.IndexOf (':');
                        if (colon >= 0)
                                host = host.Substring (0, colon);
@@ -137,18 +160,14 @@ namespace System.Net {
                                                                (IsSecureConnection) ? "https" : "http",
                                                                host,
                                                                LocalEndPoint.Port);
-                       try {
-                               url = new Uri (base_uri + raw_url);
-                       } catch {
-                               context.ErrorMessage = "Invalid url";
+
+                       if (!Uri.TryCreate (base_uri + path, UriKind.Absolute, out url)){
+                               context.ErrorMessage = "Invalid url: " + base_uri + path;
                                return;
                        }
 
                        CreateQueryString (url.Query);
 
-                       if (method == "GET" || method == "HEAD")
-                               return;
-
                        string t_encoding = null;
                        if (version >= HttpVersion.Version11) {
                                t_encoding = Headers ["Transfer-Encoding"];
@@ -159,19 +178,35 @@ namespace System.Net {
                                }
                        }
 
-                       bool is_chunked = (t_encoding == "chunked");
+                       is_chunked = (t_encoding == "chunked");
+
+                       foreach (string m in no_body_methods)
+                               if (string.Compare (method, m, StringComparison.InvariantCultureIgnoreCase) == 0)
+                                       return;
+
                        if (!is_chunked && !cl_set) {
                                context.Connection.SendError (null, 411);
                                return;
                        }
 
-                       input_stream = context.Connection.GetRequestStream (is_chunked);
+                       if (is_chunked || content_length > 0) {
+                               input_stream = context.Connection.GetRequestStream (is_chunked, content_length);
+                       }
+
                        if (Headers ["Expect"] == "100-continue") {
                                ResponseStream output = context.Connection.GetResponseStream ();
                                output.InternalWrite (_100continue, 0, _100continue.Length);
                        }
                }
 
+               internal static string Unquote (String str) {
+                       int start = str.IndexOf ('\"');
+                       int end = str.LastIndexOf ('\"');
+                       if (start >= 0 && end >=0)
+                               str = str.Substring (start + 1, end - 1);
+                       return str.Trim ();
+               }
+
                internal void AddHeader (string header)
                {
                        int colon = header.IndexOf (':');
@@ -188,7 +223,7 @@ namespace System.Net {
                                case "accept-language":
                                        user_languages = val.Split (','); // yes, only split with a ','
                                        break;
-                               case "accept-types":
+                               case "accept":
                                        accept_types = val.Split (','); // yes, only split with a ','
                                        break;
                                case "content-length":
@@ -210,7 +245,48 @@ namespace System.Net {
                                                referrer = new Uri ("http://someone.is.screwing.with.the.headers.com/");
                                        }
                                        break;
-                               //TODO: cookie headers
+                               case "cookie":
+                                       if (cookies == null)
+                                               cookies = new CookieCollection();
+
+                                       string[] cookieStrings = val.Split(new char[] {',', ';'});
+                                       Cookie current = null;
+                                       int version = 0;
+                                       foreach (string cookieString in cookieStrings) {
+                                               string str = cookieString.Trim ();
+                                               if (str.Length == 0)
+                                                       continue;
+                                               if (str.StartsWith ("$Version")) {
+                                                       version = Int32.Parse (Unquote (str.Substring (str.IndexOf ("=") + 1)));
+                                               } else if (str.StartsWith ("$Path")) {
+                                                       if (current != null)
+                                                               current.Path = str.Substring (str.IndexOf ("=") + 1).Trim ();
+                                               } else if (str.StartsWith ("$Domain")) {
+                                                       if (current != null)
+                                                               current.Domain = str.Substring (str.IndexOf ("=") + 1).Trim ();
+                                               } else if (str.StartsWith ("$Port")) {
+                                                       if (current != null)
+                                                               current.Port = str.Substring (str.IndexOf ("=") + 1).Trim ();
+                                               } else {
+                                                       if (current != null) {
+                                                               cookies.Add (current);
+                                                       }
+                                                       current = new Cookie ();
+                                                       int idx = str.IndexOf ("=");
+                                                       if (idx > 0) {
+                                                               current.Name = str.Substring (0, idx).Trim ();
+                                                               current.Value =  str.Substring (idx + 1).Trim ();
+                                                       } else {
+                                                               current.Name = str.Trim ();
+                                                               current.Value = String.Empty;
+                                                       }
+                                                       current.Version = version;
+                                               }
+                                       }
+                                       if (current != null) {
+                                               cookies.Add (current);
+                                       }
+                                       break;
                        }
                }
 
@@ -218,12 +294,16 @@ namespace System.Net {
                        get { return accept_types; }
                }
 
+               [MonoTODO ("Always returns 0")]
                public int ClientCertificateError {
                        get {
+/*                             
                                if (no_get_certificate)
                                        throw new InvalidOperationException (
                                                "Call GetClientCertificate() before calling this method.");
                                return client_cert_error;
+*/
+                               return 0;
                        }
                }
 
@@ -253,7 +333,7 @@ namespace System.Net {
                }
 
                public bool HasEntityBody {
-                       get { return (method == "GET" || method == "HEAD" || content_length <= 0 || is_chunked); }
+                       get { return (content_length > 0 || is_chunked); }
                }
 
                public NameValueCollection Headers {
@@ -268,8 +348,9 @@ namespace System.Net {
                        get { return input_stream; }
                }
 
+               [MonoTODO ("Always returns false")]
                public bool IsAuthenticated {
-                       get { return is_authenticated; }
+                       get { return false; }
                }
 
                public bool IsLocal {