Merge pull request #495 from nicolas-raoul/fix-for-issue2907-with-no-formatting-changes
[mono.git] / mcs / class / System / System.Net / WebConnection.cs
index 33f8dc54cac5f5c56ca2570552321916bd5f8605..4922fb3291c34bbeac6e02b87b7e9720e0995a86 100644 (file)
@@ -46,7 +46,8 @@ namespace System.Net
                None,
                Status,
                Headers,
-               Content
+               Content,
+               Aborted
        }
 
        class WebConnection
@@ -62,19 +63,27 @@ namespace System.Net
                static AsyncCallback readDoneDelegate = new AsyncCallback (ReadDone);
                EventHandler abortHandler;
                AbortHelper abortHelper;
-               ReadState readState;
                internal WebConnectionData Data;
                bool chunkedRead;
                ChunkStream chunkStream;
                Queue queue;
                bool reused;
                int position;
-               bool busy;
-               HttpWebRequest priority_request;
+               bool busy;              
+               HttpWebRequest priority_request;                
                NetworkCredential ntlm_credentials;
                bool ntlm_authenticated;
                bool unsafe_sharing;
 
+               enum NtlmAuthState
+               {
+                       None,
+                       Challenge,
+                       Response
+               }
+               NtlmAuthState connect_ntlm_auth_state;
+               HttpWebRequest connect_request;
+
                bool ssl;
                bool certsAvailable;
                Exception connect_exception;
@@ -99,7 +108,6 @@ namespace System.Net
                {
                        this.sPoint = sPoint;
                        buffer = new byte [4096];
-                       readState = ReadState.None;
                        Data = new WebConnectionData ();
                        initConn = new WaitCallback (state => {
                                try {
@@ -247,7 +255,8 @@ namespace System.Net
                        }
                }
 
-               bool CreateTunnel (HttpWebRequest request, Stream stream, out byte [] buffer)
+               bool CreateTunnel (HttpWebRequest request, Uri connectUri,
+                                  Stream stream, out byte[] buffer)
                {
                        StringBuilder sb = new StringBuilder ();
                        sb.Append ("CONNECT ");
@@ -262,21 +271,44 @@ namespace System.Net
 
                        sb.Append ("\r\nHost: ");
                        sb.Append (request.Address.Authority);
-                       string challenge = Data.Challenge;
+
+                       bool ntlm = false;
+                       var challenge = Data.Challenge;
                        Data.Challenge = null;
-                       bool have_auth = (request.Headers ["Proxy-Authorization"] != null);
+                       var auth_header = request.Headers ["Proxy-Authorization"];
+                       bool have_auth = auth_header != null;
                        if (have_auth) {
                                sb.Append ("\r\nProxy-Authorization: ");
-                               sb.Append (request.Headers ["Proxy-Authorization"]);
+                               sb.Append (auth_header);
+                               ntlm = auth_header.ToUpper ().Contains ("NTLM");
                        } else if (challenge != null && Data.StatusCode == 407) {
-                               have_auth = true;
                                ICredentials creds = request.Proxy.Credentials;
-                               Authorization auth = AuthenticationManager.Authenticate (challenge, request, creds);
-                               if (auth != null) {
+                               have_auth = true;
+
+                               if (connect_request == null) {
+                                       // create a CONNECT request to use with Authenticate
+                                       connect_request = (HttpWebRequest)WebRequest.Create (
+                                               connectUri.Scheme + "://" + connectUri.Host + ":" + connectUri.Port + "/");
+                                       connect_request.Method = "CONNECT";
+                                       connect_request.Credentials = creds;
+                               }
+
+                               for (int i = 0; i < challenge.Length; i++) {
+                                       var auth = AuthenticationManager.Authenticate (challenge [i], connect_request, creds);
+                                       if (auth == null)
+                                               continue;
+                                       ntlm = (auth.Module.AuthenticationType == "NTLM");
                                        sb.Append ("\r\nProxy-Authorization: ");
                                        sb.Append (auth.Message);
+                                       break;
                                }
                        }
+
+                       if (ntlm) {
+                               sb.Append ("\r\nProxy-Connection: keep-alive");
+                               connect_ntlm_auth_state++;
+                       }
+
                        sb.Append ("\r\n\r\n");
 
                        Data.StatusCode = 0;
@@ -284,10 +316,19 @@ namespace System.Net
                        stream.Write (connectBytes, 0, connectBytes.Length);
 
                        int status;
-                       WebHeaderCollection result = ReadHeaders (request, stream, out buffer, out status);
-                       if (!have_auth && result != null && status == 407) { // Needs proxy auth
+                       WebHeaderCollection result = ReadHeaders (stream, out buffer, out status);
+                       if ((!have_auth || connect_ntlm_auth_state == NtlmAuthState.Challenge) &&
+                           result != null && status == 407) { // Needs proxy auth
+                               var connectionHeader = result ["Connection"];
+                               if (socket != null && !string.IsNullOrEmpty (connectionHeader) &&
+                                   connectionHeader.ToLower() == "close") {
+                                       // The server is requesting that this connection be closed
+                                       socket.Close();
+                                       socket = null;
+                               }
+
                                Data.StatusCode = status;
-                               Data.Challenge = result ["Proxy-Authenticate"];
+                               Data.Challenge = result.GetValues_internal ("Proxy-Authenticate", false);
                                return false;
                        } else if (status != 200) {
                                string msg = String.Format ("The remote server returned a {0} status code.", status);
@@ -298,7 +339,7 @@ namespace System.Net
                        return (result != null);
                }
 
-               WebHeaderCollection ReadHeaders (HttpWebRequest request, Stream stream, out byte [] retBuffer, out int status)
+               WebHeaderCollection ReadHeaders (Stream stream, out byte [] retBuffer, out int status)
                {
                        retBuffer = null;
                        status = 200;
@@ -321,10 +362,25 @@ namespace System.Net
                                headers = new WebHeaderCollection ();
                                while (ReadLine (ms.GetBuffer (), ref start, (int) ms.Length, ref str)) {
                                        if (str == null) {
-                                               if (ms.Length - start > 0) {
-                                                       retBuffer = new byte [ms.Length - start];
-                                                       Buffer.BlockCopy (ms.GetBuffer (), start, retBuffer, 0, retBuffer.Length);
+                                               int contentLen = 0;
+                                               try     {
+                                                       contentLen = int.Parse(headers["Content-Length"]);
+                                               }
+                                               catch {
+                                                       contentLen = 0;
                                                }
+
+                                               if (ms.Length - start - contentLen > 0) {
+                                                       // we've read more data than the response header and conents,
+                                                       // give back extra data to the caller
+                                                       retBuffer = new byte[ms.Length - start - contentLen];
+                                                       Buffer.BlockCopy(ms.GetBuffer(), start + contentLen, retBuffer, 0, retBuffer.Length);
+                                               }
+                                               else {
+                                                       // haven't read in some or all of the contents for the response, do so now
+                                                       FlushContents(stream, contentLen - (int)(ms.Length - start));
+                                               }
+
                                                return headers;
                                        }
 
@@ -345,6 +401,20 @@ namespace System.Net
                        }
                }
 
+               void FlushContents(Stream stream, int contentLength)
+               {
+                       while (contentLength > 0) {
+                               byte[] contentBuffer = new byte[contentLength];
+                               int bytesRead = stream.Read(contentBuffer, 0, contentLength);
+                               if (bytesRead > 0) {
+                                       contentLength -= bytesRead;
+                               }
+                               else {
+                                       break;
+                               }
+                       }
+               }
+
                bool CreateStream (HttpWebRequest request)
                {
                        try {
@@ -356,7 +426,7 @@ namespace System.Net
                                        if (!reused || nstream == null || nstream.GetType () != sslStream) {
                                                byte [] buffer = null;
                                                if (sPoint.UseConnect) {
-                                                       bool ok = CreateTunnel (request, serverStream, out buffer);
+                                                       bool ok = CreateTunnel (request, sPoint.Address, serverStream, out buffer);
                                                        if (!ok)
                                                                return false;
                                                }
@@ -422,7 +492,7 @@ namespace System.Net
                
                static void ReadDone (IAsyncResult result)
                {
-                       WebConnection cnc = (WebConnection) result.AsyncState;
+                       WebConnection cnc = (WebConnection)result.AsyncState;
                        WebConnectionData data = cnc.Data;
                        Stream ns = cnc.nstream;
                        if (ns == null) {
@@ -455,10 +525,10 @@ namespace System.Net
 
                        int pos = -1;
                        nread += cnc.position;
-                       if (cnc.readState == ReadState.None) { 
+                       if (data.ReadState == ReadState.None) { 
                                Exception exc = null;
                                try {
-                                       pos = cnc.GetResponse (cnc.buffer, nread);
+                                       pos = GetResponse (data, cnc.sPoint, cnc.buffer, nread);
                                } catch (Exception e) {
                                        exc = e;
                                }
@@ -469,14 +539,19 @@ namespace System.Net
                                }
                        }
 
-                       if (cnc.readState != ReadState.Content) {
+                       if (data.ReadState == ReadState.Aborted) {
+                               cnc.HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone");
+                               return;
+                       }
+
+                       if (data.ReadState != ReadState.Content) {
                                int est = nread * 2;
                                int max = (est < cnc.buffer.Length) ? cnc.buffer.Length : est;
                                byte [] newBuffer = new byte [max];
                                Buffer.BlockCopy (cnc.buffer, 0, newBuffer, 0, nread);
                                cnc.buffer = newBuffer;
                                cnc.position = nread;
-                               cnc.readState = ReadState.None;
+                               data.ReadState = ReadState.None;
                                InitRead (cnc);
                                return;
                        }
@@ -494,7 +569,11 @@ namespace System.Net
                                stream.ReadBuffer = cnc.buffer;
                                stream.ReadBufferOffset = pos;
                                stream.ReadBufferSize = nread;
-                               stream.CheckResponseInBuffer ();
+                               try {
+                                       stream.CheckResponseInBuffer ();
+                               } catch (Exception e) {
+                                       cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7");
+                               }
                        } else if (cnc.chunkStream == null) {
                                try {
                                        cnc.chunkStream = new ChunkStream (cnc.buffer, pos, nread, data.Headers);
@@ -549,7 +628,8 @@ namespace System.Net
                        }
                }
                
-               int GetResponse (byte [] buffer, int max)
+               static int GetResponse (WebConnectionData data, ServicePoint sPoint,
+                                       byte [] buffer, int max)
                {
                        int pos = 0;
                        string line = null;
@@ -557,7 +637,10 @@ namespace System.Net
                        bool isContinue = false;
                        bool emptyFirstLine = false;
                        do {
-                               if (readState == ReadState.None) {
+                               if (data.ReadState == ReadState.Aborted)
+                                       return -1;
+
+                               if (data.ReadState == ReadState.None) {
                                        lineok = ReadLine (buffer, ref pos, max, ref line);
                                        if (!lineok)
                                                return 0;
@@ -567,35 +650,34 @@ namespace System.Net
                                                continue;
                                        }
                                        emptyFirstLine = false;
-
-                                       readState = ReadState.Status;
+                                       data.ReadState = ReadState.Status;
 
                                        string [] parts = line.Split (' ');
                                        if (parts.Length < 2)
                                                return -1;
 
                                        if (String.Compare (parts [0], "HTTP/1.1", true) == 0) {
-                                               Data.Version = HttpVersion.Version11;
+                                               data.Version = HttpVersion.Version11;
                                                sPoint.SetVersion (HttpVersion.Version11);
                                        } else {
-                                               Data.Version = HttpVersion.Version10;
+                                               data.Version = HttpVersion.Version10;
                                                sPoint.SetVersion (HttpVersion.Version10);
                                        }
 
-                                       Data.StatusCode = (int) UInt32.Parse (parts [1]);
+                                       data.StatusCode = (int) UInt32.Parse (parts [1]);
                                        if (parts.Length >= 3)
-                                               Data.StatusDescription = String.Join (" ", parts, 2, parts.Length - 2);
+                                               data.StatusDescription = String.Join (" ", parts, 2, parts.Length - 2);
                                        else
-                                               Data.StatusDescription = "";
+                                               data.StatusDescription = "";
 
                                        if (pos >= max)
                                                return pos;
                                }
 
                                emptyFirstLine = false;
-                               if (readState == ReadState.Status) {
-                                       readState = ReadState.Headers;
-                                       Data.Headers = new WebHeaderCollection ();
+                               if (data.ReadState == ReadState.Status) {
+                                       data.ReadState = ReadState.Headers;
+                                       data.Headers = new WebHeaderCollection ();
                                        ArrayList headers = new ArrayList ();
                                        bool finished = false;
                                        while (!finished) {
@@ -624,25 +706,25 @@ namespace System.Net
                                                return 0;
 
                                        foreach (string s in headers)
-                                               Data.Headers.SetInternal (s);
+                                               data.Headers.SetInternal (s);
 
-                                       if (Data.StatusCode == (int) HttpStatusCode.Continue) {
+                                       if (data.StatusCode == (int) HttpStatusCode.Continue) {
                                                sPoint.SendContinue = true;
                                                if (pos >= max)
                                                        return pos;
 
-                                               if (Data.request.ExpectContinue) {
-                                                       Data.request.DoContinueDelegate (Data.StatusCode, Data.Headers);
+                                               if (data.request.ExpectContinue) {
+                                                       data.request.DoContinueDelegate (data.StatusCode, data.Headers);
                                                        // Prevent double calls when getting the
                                                        // headers in several packets.
-                                                       Data.request.ExpectContinue = false;
+                                                       data.request.ExpectContinue = false;
                                                }
 
-                                               readState = ReadState.None;
+                                               data.ReadState = ReadState.None;
                                                isContinue = true;
                                        }
                                        else {
-                                               readState = ReadState.Content;
+                                               data.ReadState = ReadState.Content;
                                                return pos;
                                        }
                                }
@@ -689,10 +771,13 @@ namespace System.Net
                                return;
                        }
 
-                       readState = ReadState.None;
                        request.SetWriteStream (new WebConnectionStream (this, request));
                }
-               
+
+#if MONOTOUCH
+               static bool warned_about_queue = false;
+#endif
+
                internal EventHandler SendRequest (HttpWebRequest request)
                {
                        if (request.Aborted)
@@ -705,6 +790,12 @@ namespace System.Net
                                        ThreadPool.QueueUserWorkItem (initConn, request);
                                } else {
                                        lock (queue) {
+#if MONOTOUCH
+                                               if (!warned_about_queue) {
+                                                       warned_about_queue = true;
+                                                       Console.WriteLine ("WARNING: An HttpWebRequest was added to the ConnectionGroup queue because the connection limit was reached.");
+                                               }
+#endif
                                                queue.Enqueue (request);
                                        }
                                }
@@ -717,8 +808,7 @@ namespace System.Net
                {
                        lock (queue) {
                                if (queue.Count > 0) {
-                                       Data.request = (HttpWebRequest) queue.Dequeue ();
-                                       SendRequest (Data.request);
+                                       SendRequest ((HttpWebRequest) queue.Dequeue ());
                                }
                        }
                }
@@ -1086,10 +1176,18 @@ namespace System.Net
 
                                if (ntlm_authenticated)
                                        ResetNtlm ();
+                               if (Data != null) {
+                                       lock (Data) {
+                                               Data.ReadState = ReadState.Aborted;
+                                       }
+                               }
                                busy = false;
                                Data = new WebConnectionData ();
                                if (sendNext)
                                        SendNext ();
+                               
+                               connect_request = null;
+                               connect_ntlm_auth_state = NtlmAuthState.None;
                        }
                }