Merge pull request #900 from Blewzman/FixAggregateExceptionGetBaseException
[mono.git] / mcs / class / System / System.Net / HttpWebRequest.cs
index 6b71f21243118bcd28e6b49a3200f5eaab448a0b..54eb1aa81df085eb4a2642350ae481c5d3d02804 100644 (file)
@@ -36,6 +36,7 @@ using System.Collections;
 using System.Configuration;
 using System.Globalization;
 using System.IO;
+using System.Net;
 using System.Net.Cache;
 using System.Net.Sockets;
 using System.Runtime.Remoting.Messaging;
@@ -46,12 +47,8 @@ using System.Threading;
 
 namespace System.Net 
 {
-#if MOONLIGHT
-       internal class HttpWebRequest : WebRequest, ISerializable {
-#else
        [Serializable]
        public class HttpWebRequest : WebRequest, ISerializable {
-#endif
                Uri requestUri;
                Uri actualUri;
                bool hostChanged;
@@ -92,7 +89,6 @@ namespace System.Net
                bool gotRequestStream;
                int redirects;
                bool expectContinue;
-               bool authCompleted;
                byte[] bodyBuffer;
                int bodyBufferLength;
                bool getResponseCalled;
@@ -110,7 +106,7 @@ namespace System.Net
                        Challenge,
                        Response
                }
-               NtlmAuthState ntlm_auth_state;
+               AuthorizationState auth_state, proxy_auth_state;
                string host;
 
                // Constructors
@@ -141,6 +137,7 @@ namespace System.Net
                        this.proxy = GlobalProxySelection.Select;
                        this.webHeaders = new WebHeaderCollection (WebHeaderCollection.HeaderInfo.Request);
                        ThrowOnError = true;
+                       ResetAuthorization ();
                }
                
                [Obsolete ("Serialization is obsoleted for this type", false)]
@@ -168,6 +165,13 @@ namespace System.Net
                        timeout = info.GetInt32 ("timeout");
                        redirects = info.GetInt32 ("redirects");
                        host = info.GetString ("host");
+                       ResetAuthorization ();
+               }
+
+               void ResetAuthorization ()
+               {
+                       auth_state = new AuthorizationState (this, false);
+                       proxy_auth_state = new AuthorizationState (this, true);
                }
                
                // Properties
@@ -194,6 +198,13 @@ namespace System.Net
                        get { return allowBuffering; }
                        set { allowBuffering = value; }
                }
+               
+#if NET_4_5
+               public virtual bool AllowReadStreamBuffering {
+                       get { return allowBuffering; }
+                       set { allowBuffering = value; }
+               }
+#endif
 
                static Exception GetMustImplement ()
                {
@@ -291,6 +302,9 @@ namespace System.Net
                        set { continueDelegate = value; }
                }
                
+#if NET_4_5
+               virtual
+#endif
                public CookieContainer CookieContainer {
                        get { return cookieContainer; }
                        set { cookieContainer = value; }
@@ -316,6 +330,8 @@ namespace System.Net
                        }
                }
 #endif
+
+#if !NET_2_1
                [MonoTODO]
                public static new RequestCachePolicy DefaultCachePolicy
                {
@@ -326,6 +342,7 @@ namespace System.Net
                                throw GetMustImplement ();
                        }
                }
+#endif
                
                [MonoTODO]
                public static int DefaultMaximumErrorResponseLength
@@ -358,6 +375,9 @@ namespace System.Net
                        }
                }
                
+#if NET_4_5
+               virtual
+#endif
                public bool HaveResponse {
                        get { return haveResponse; }
                }
@@ -409,6 +429,10 @@ namespace System.Net
                        if (idx >= 0)
                                return false;
 
+                       IPAddress ipaddr;
+                       if (IPAddress.TryParse (val, out ipaddr))
+                               return true;
+
                        string u = scheme + "://" + val + "/";
                        return Uri.IsWellFormedUriString (u, UriKind.Absolute);
                }
@@ -477,6 +501,14 @@ namespace System.Net
                        }
                }
                
+#if NET_4_5
+               [MonoTODO]
+               public int ContinueTimeout {
+                       get { throw new NotImplementedException (); }
+                       set { throw new NotImplementedException (); }
+               }
+#endif
+               
                public string MediaType {
                        get { return mediaType; }
                        set { 
@@ -560,9 +592,9 @@ namespace System.Net
                internal ServicePoint ServicePointNoLock {
                        get { return servicePoint; }
                }
-#if NET_4_5 || MOBILE
+#if NET_4_0
                [MonoTODO ("for portable library support")]
-               public bool SupportsCookieContainer { 
+               public virtual bool SupportsCookieContainer { 
                        get {
                                throw new NotImplementedException ();
                        }
@@ -835,10 +867,10 @@ namespace System.Net
                        if (writeStream == null || writeStream.RequestWritten || !InternalAllowBuffering)
                                return;
 #if NET_4_0
-                       if (contentLength < 0 && writeStream.CanWrite == true && writeStream.WriteBufferLength <= 0)
+                       if (contentLength < 0 && writeStream.CanWrite == true && writeStream.WriteBufferLength < 0)
                                return;
 
-                       if (contentLength < 0 && writeStream.WriteBufferLength > 0)
+                       if (contentLength < 0 && writeStream.WriteBufferLength >= 0)
                                InternalContentLength = writeStream.WriteBufferLength;
 #else
                        if (contentLength < 0 && writeStream.CanWrite == true)
@@ -926,6 +958,14 @@ namespace System.Net
 
                        return result.Response;
                }
+               
+#if NET_3_5
+               public Stream EndGetRequestStream (IAsyncResult asyncResult, out TransportContext transportContext)
+               {
+                       transportContext = null;
+                       return EndGetRequestStream (asyncResult);
+               }
+#endif
 
                public override WebResponse GetResponse()
                {
@@ -1047,29 +1087,19 @@ namespace System.Net
                        redirects++;
                        Exception e = null;
                        string uriString = null;
-
                        switch (code) {
                        case HttpStatusCode.Ambiguous: // 300
                                e = new WebException ("Ambiguous redirect.");
                                break;
                        case HttpStatusCode.MovedPermanently: // 301
                        case HttpStatusCode.Redirect: // 302
-                       case HttpStatusCode.TemporaryRedirect: // 307
-                               /* MS follows the redirect for POST too
-                               if (method != "GET" && method != "HEAD") // 10.3
-                                       return false;
-                               */
-
-                               contentLength = -1;
-                               bodyBufferLength = 0;
-                               bodyBuffer = null;
-                               if (code != HttpStatusCode.TemporaryRedirect)
+                               if (method == "POST")
                                        method = "GET";
-                               uriString = webResponse.Headers ["Location"];
+                               break;
+                       case HttpStatusCode.TemporaryRedirect: // 307
                                break;
                        case HttpStatusCode.SeeOther: //303
                                method = "GET";
-                               uriString = webResponse.Headers ["Location"];
                                break;
                        case HttpStatusCode.NotModified: // 304
                                return false;
@@ -1085,6 +1115,11 @@ namespace System.Net
                        if (e != null)
                                throw e;
 
+                       //contentLength = -1;
+                       //bodyBufferLength = 0;
+                       //bodyBuffer = null;
+                       uriString = webResponse.Headers ["Location"];
+
                        if (uriString == null)
                                throw new WebException ("No Location header found for " + (int) code,
                                                        WebExceptionStatus.ProtocolError);
@@ -1110,13 +1145,14 @@ namespace System.Net
                                webHeaders.RemoveAndAdd ("Transfer-Encoding", "chunked");
                                webHeaders.RemoveInternal ("Content-Length");
                        } else if (contentLength != -1) {
-                               if (ntlm_auth_state != NtlmAuthState.Challenge) {
+                               if (auth_state.NtlmAuthState == NtlmAuthState.Challenge || proxy_auth_state.NtlmAuthState == NtlmAuthState.Challenge) {
+                                       // We don't send any body with the NTLM Challenge request.
+                                       webHeaders.SetInternal ("Content-Length", "0");
+                               } else {
                                        if (contentLength > 0)
                                                continue100 = true;
 
                                        webHeaders.SetInternal ("Content-Length", contentLength.ToString ());
-                               } else {
-                                       webHeaders.SetInternal ("Content-Length", "0");
                                }
                                webHeaders.RemoveInternal ("Transfer-Encoding");
                        } else {
@@ -1139,7 +1175,9 @@ namespace System.Net
                        bool spoint10 = (proto_version == null || proto_version == HttpVersion.Version10);
 
                        if (keepAlive && (version == HttpVersion.Version10 || spoint10)) {
-                               webHeaders.RemoveAndAdd (connectionHeader, "keep-alive");
+                               if (webHeaders[connectionHeader] == null
+                                   || webHeaders[connectionHeader].IndexOf ("keep-alive", StringComparison.OrdinalIgnoreCase) == -1)
+                                       webHeaders.RemoveAndAdd (connectionHeader, "keep-alive");
                        } else if (!keepAlive && version == HttpVersion.Version11) {
                                webHeaders.RemoveAndAdd (connectionHeader, "close");
                        }
@@ -1206,7 +1244,7 @@ namespace System.Net
                        }
                }
 
-               internal void SendRequestHeaders (bool propagate_error)
+               internal byte[] GetRequestHeaders ()
                {
                        StringBuilder req = new StringBuilder ();
                        string query;
@@ -1228,18 +1266,7 @@ namespace System.Net
                                                                actualVersion.Major, actualVersion.Minor);
                        req.Append (GetHeaders ());
                        string reqstr = req.ToString ();
-                       byte [] bytes = Encoding.UTF8.GetBytes (reqstr);
-                       try {
-                               writeStream.SetHeaders (bytes);
-                       } catch (WebException wexc) {
-                               SetWriteStreamError (wexc.Status, wexc);
-                               if (propagate_error)
-                                       throw;
-                       } catch (Exception exc) {
-                               SetWriteStreamError (WebExceptionStatus.SendFailure, exc);
-                               if (propagate_error)
-                                       throw;
-                       }
+                       return Encoding.UTF8.GetBytes (reqstr);
                }
 
                internal void SetWriteStream (WebConnectionStream stream)
@@ -1254,14 +1281,32 @@ namespace System.Net
                                writeStream.SendChunked = false;
                        }
 
-                       SendRequestHeaders (false);
+                       byte[] requestHeaders = GetRequestHeaders ();
+                       WebAsyncResult result = new WebAsyncResult (new AsyncCallback (SetWriteStreamCB), null);
+                       writeStream.SetHeadersAsync (requestHeaders, result);
+               }
+
+               void SetWriteStreamCB(IAsyncResult ar)
+               {
+                       WebAsyncResult result = ar as WebAsyncResult;
 
+                       if (result.Exception != null) {
+                               WebException wexc = result.Exception as WebException;
+                               if (wexc != null) {
+                                       SetWriteStreamError (wexc.Status, wexc);
+                                       return;
+                               }
+                               SetWriteStreamError (WebExceptionStatus.SendFailure, result.Exception);
+                               return;
+                       }
+               
                        haveRequest = true;
-                       
+
                        if (bodyBuffer != null) {
                                // The body has been written and buffered. The request "user"
                                // won't write it again, so we must do it.
-                               if (ntlm_auth_state != NtlmAuthState.Challenge) {
+                               if (auth_state.NtlmAuthState != NtlmAuthState.Challenge && proxy_auth_state.NtlmAuthState != NtlmAuthState.Challenge) {
+                                       // FIXME: this is a blocking call on the thread pool that could lead to thread pool exhaustion
                                        writeStream.Write (bodyBuffer, 0, bodyBufferLength);
                                        bodyBuffer = null;
                                        writeStream.Close ();
@@ -1269,11 +1314,12 @@ namespace System.Net
                        } else if (method != "HEAD" && method != "GET" && method != "MKCOL" && method != "CONNECT" &&
                                        method != "TRACE") {
                                if (getResponseCalled && !writeStream.RequestWritten)
+                                       // FIXME: this is a blocking call on the thread pool that could lead to thread pool exhaustion
                                        writeStream.WriteRequest ();
                        }
 
                        if (asyncWrite != null) {
-                               asyncWrite.SetCompleted (false, stream);
+                               asyncWrite.SetCompleted (false, writeStream);
                                asyncWrite.DoCallback ();
                                asyncWrite = null;
                        }
@@ -1331,14 +1377,17 @@ namespace System.Net
                        }
                }
 
-               void HandleNtlmAuth (WebAsyncResult r)
+               bool HandleNtlmAuth (WebAsyncResult r)
                {
+                       bool isProxy = webResponse.StatusCode == HttpStatusCode.ProxyAuthenticationRequired;
+                       if ((isProxy ? proxy_auth_state.NtlmAuthState : auth_state.NtlmAuthState) == NtlmAuthState.None)
+                               return false;
+
                        WebConnectionStream wce = webResponse.GetResponseStream () as WebConnectionStream;
                        if (wce != null) {
                                WebConnection cnc = wce.Connection;
                                cnc.PriorityRequest = this;
-                               bool isProxy = (proxy != null && !proxy.IsBypassed (actualUri));
-                               ICredentials creds = (!isProxy) ? credentials : proxy.Credentials;
+                               ICredentials creds = !isProxy ? credentials : proxy.Credentials;
                                if (creds != null) {
                                        cnc.NtlmCredential = creds.GetCredential (requestUri, "NTLM");
                                        cnc.UnsafeAuthenticatedConnectionSharing = unsafe_auth_blah;
@@ -1349,6 +1398,7 @@ namespace System.Net
                        haveResponse = false;
                        webResponse.ReadAll ();
                        webResponse = null;
+                       return true;
                }
 
                internal void SetResponseData (WebConnectionData data)
@@ -1394,12 +1444,14 @@ namespace System.Net
                                        return;
                                }
 
+                               bool isProxy = ProxyQuery && !proxy.IsBypassed (actualUri);
+
                                bool redirected;
                                try {
                                        redirected = CheckFinalStatus (r);
                                        if (!redirected) {
-                                               if (ntlm_auth_state != NtlmAuthState.None && authCompleted && webResponse != null
-                                                       && (int)webResponse.StatusCode < 400) {
+                                               if ((isProxy ? proxy_auth_state.IsNtlmAuthenticated : auth_state.IsNtlmAuthenticated) &&
+                                                               webResponse != null && (int)webResponse.StatusCode < 400) {
                                                        WebConnectionStream wce = webResponse.GetResponseStream () as WebConnectionStream;
                                                        if (wce != null) {
                                                                WebConnection cnc = wce.Connection;
@@ -1417,10 +1469,8 @@ namespace System.Net
                                                r.DoCallback ();
                                        } else {
                                                if (webResponse != null) {
-                                                       if (ntlm_auth_state != NtlmAuthState.None) {
-                                                               HandleNtlmAuth (r);
+                                                       if (HandleNtlmAuth (r))
                                                                return;
-                                                       }
                                                        webResponse.Close ();
                                                }
                                                finished_reading = false;
@@ -1452,35 +1502,84 @@ namespace System.Net
                        }
                }
 
-               bool CheckAuthorization (WebResponse response, HttpStatusCode code)
+               struct AuthorizationState
                {
-                       authCompleted = false;
-                       if (code == HttpStatusCode.Unauthorized && credentials == null)
-                               return false;
+                       readonly HttpWebRequest request;
+                       readonly bool isProxy;
+                       bool isCompleted;
+                       NtlmAuthState ntlm_auth_state;
 
-                       bool isProxy = (code == HttpStatusCode.ProxyAuthenticationRequired);
-                       if (isProxy && (proxy == null || proxy.Credentials == null))
-                               return false;
+                       public bool IsCompleted {
+                               get { return isCompleted; }
+                       }
 
-                       string [] authHeaders = response.Headers.GetValues_internal ( (isProxy) ? "Proxy-Authenticate" : "WWW-Authenticate", false);
-                       if (authHeaders == null || authHeaders.Length == 0)
-                               return false;
+                       public NtlmAuthState NtlmAuthState {
+                               get { return ntlm_auth_state; }
+                       }
 
-                       ICredentials creds = (!isProxy) ? credentials : proxy.Credentials;
-                       Authorization auth = null;
-                       foreach (string authHeader in authHeaders) {
-                               auth = AuthenticationManager.Authenticate (authHeader, this, creds);
-                               if (auth != null)
-                                       break;
+                       public bool IsNtlmAuthenticated {
+                               get { return isCompleted && ntlm_auth_state != NtlmAuthState.None; }
                        }
-                       if (auth == null)
-                               return false;
-                       webHeaders [(isProxy) ? "Proxy-Authorization" : "Authorization"] = auth.Message;
-                       authCompleted = auth.Complete;
-                       bool is_ntlm = (auth.Module.AuthenticationType == "NTLM");
-                       if (is_ntlm)
-                               ntlm_auth_state = (NtlmAuthState)((int) ntlm_auth_state + 1);
-                       return true;
+
+                       public AuthorizationState (HttpWebRequest request, bool isProxy)
+                       {
+                               this.request = request;
+                               this.isProxy = isProxy;
+                               isCompleted = false;
+                               ntlm_auth_state = NtlmAuthState.None;
+                       }
+
+                       public bool CheckAuthorization (WebResponse response, HttpStatusCode code)
+                       {
+                               isCompleted = false;
+                               if (code == HttpStatusCode.Unauthorized && request.credentials == null)
+                                       return false;
+
+                               // FIXME: This should never happen!
+                               if (isProxy != (code == HttpStatusCode.ProxyAuthenticationRequired))
+                                       return false;
+
+                               if (isProxy && (request.proxy == null || request.proxy.Credentials == null))
+                                       return false;
+
+                               string [] authHeaders = response.Headers.GetValues_internal (isProxy ? "Proxy-Authenticate" : "WWW-Authenticate", false);
+                               if (authHeaders == null || authHeaders.Length == 0)
+                                       return false;
+
+                               ICredentials creds = (!isProxy) ? request.credentials : request.proxy.Credentials;
+                               Authorization auth = null;
+                               foreach (string authHeader in authHeaders) {
+                                       auth = AuthenticationManager.Authenticate (authHeader, request, creds);
+                                       if (auth != null)
+                                               break;
+                               }
+                               if (auth == null)
+                                       return false;
+                               request.webHeaders [isProxy ? "Proxy-Authorization" : "Authorization"] = auth.Message;
+                               isCompleted = auth.Complete;
+                               bool is_ntlm = (auth.Module.AuthenticationType == "NTLM");
+                               if (is_ntlm)
+                                       ntlm_auth_state = (NtlmAuthState)((int) ntlm_auth_state + 1);
+                               return true;
+                       }
+
+                       public void Reset ()
+                       {
+                               isCompleted = false;
+                               ntlm_auth_state = NtlmAuthState.None;
+                               request.webHeaders.RemoveInternal (isProxy ? "Proxy-Authorization" : "Authorization");
+                       }
+
+                       public override string ToString ()
+                       {
+                               return string.Format ("{0}AuthState [{1}:{2}]", isProxy ? "Proxy" : "", isCompleted, ntlm_auth_state);
+                       }
+               }
+
+               bool CheckAuthorization (WebResponse response, HttpStatusCode code)
+               {
+                       bool isProxy = code == HttpStatusCode.ProxyAuthenticationRequired;
+                       return isProxy ? proxy_auth_state.CheckAuthorization (response, code) : auth_state.CheckAuthorization (response, code);
                }
 
                // Returns true if redirected
@@ -1498,8 +1597,8 @@ namespace System.Net
                        HttpStatusCode code = 0;
                        if (throwMe == null && webResponse != null) {
                                code = webResponse.StatusCode;
-                               if (!authCompleted && ((code == HttpStatusCode.Unauthorized && credentials != null) ||
-                                    (ProxyQuery && code == HttpStatusCode.ProxyAuthenticationRequired))) {
+                               if ((!auth_state.IsCompleted && code == HttpStatusCode.Unauthorized && credentials != null) ||
+                                       (ProxyQuery && !proxy_auth_state.IsCompleted && code == HttpStatusCode.ProxyAuthenticationRequired)) {
                                        if (!usedPreAuth && CheckAuthorization (webResponse, code)) {
                                                // Keep the written body, so it can be rewritten in the retry
                                                if (InternalAllowBuffering) {
@@ -1507,7 +1606,7 @@ namespace System.Net
                                                        // We save it in the first request (first 401), don't send anything
                                                        // in the challenge request and send it in the response request along
                                                        // with the buffers kept form the first request.
-                                                       if (ntlm_auth_state != NtlmAuthState.Response) {
+                                                       if (auth_state.NtlmAuthState == NtlmAuthState.Challenge || proxy_auth_state.NtlmAuthState == NtlmAuthState.Challenge) {
                                                                bodyBuffer = writeStream.WriteBuffer;
                                                                bodyBufferLength = writeStream.WriteBufferLength;
                                                        }
@@ -1554,13 +1653,15 @@ namespace System.Net
                                bool b = false;
                                int c = (int) code;
                                if (allowAutoRedirect && c >= 300) {
+                                       b = Redirect (result, code);
                                        if (InternalAllowBuffering && writeStream.WriteBufferLength > 0) {
                                                bodyBuffer = writeStream.WriteBuffer;
                                                bodyBufferLength = writeStream.WriteBufferLength;
                                        }
-                                       b = Redirect (result, code);
-                                       if (b && ntlm_auth_state != 0)
-                                               ntlm_auth_state = 0;
+                                       if (b && !unsafe_auth_blah) {
+                                               auth_state.Reset ();
+                                               proxy_auth_state.Reset ();
+                                       }
                                }
 
                                if (resp != null && c >= 300 && c != 304)
@@ -1581,6 +1682,13 @@ namespace System.Net
 
                        throw throwMe;
                }
+
+               internal bool ReuseConnection {
+                       get;
+                       set;
+               }
+
+               internal WebConnection StoredConnection;
        }
 }