Merge pull request #900 from Blewzman/FixAggregateExceptionGetBaseException
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel.Channels / HttpRequestChannel.cs
index c92c3b07d9364929690183dfce17ae66f1fdcf90..6ebafbd99da921404d65d3adb895336911dfc45b 100644 (file)
@@ -1,5 +1,5 @@
 //
-// HttpRequestChannel.cs
+// HttpRequestChannel.cs 
 //
 // Author:
 //     Atsushi Enomoto <atsushi@ximian.com>
@@ -41,7 +41,7 @@ namespace System.ServiceModel.Channels
        {
                HttpChannelFactory<IRequestChannel> source;
 
-               WebRequest web_request;
+               List<WebRequest> web_requests = new List<WebRequest> ();
 
                // Constructor
 
@@ -56,6 +56,15 @@ namespace System.ServiceModel.Channels
                        get { return source.MessageEncoder; }
                }
 
+#if NET_2_1
+               public override T GetProperty<T> ()
+               {
+                       if (typeof (T) == typeof (IHttpCookieContainerManager))
+                               return source.GetProperty<T> ();
+                       return base.GetProperty<T> ();
+               }
+#endif
+
                // Request
 
                public override Message Request (Message message, TimeSpan timeout)
@@ -76,17 +85,18 @@ namespace System.ServiceModel.Channels
                                        destination = Via ?? RemoteAddress.Uri;
                        }
 
-                       web_request = HttpWebRequest.Create (destination);
+                       var web_request = (HttpWebRequest) HttpWebRequest.Create (destination);
+                       web_requests.Add (web_request);
+                       result.WebRequest = web_request;
                        web_request.Method = "POST";
                        web_request.ContentType = Encoder.ContentType;
-
-#if NET_2_1
+#if NET_2_1 || NET_4_0
+                       HttpWebRequest hwr = (web_request as HttpWebRequest);
                        var cmgr = source.GetProperty<IHttpCookieContainerManager> ();
                        if (cmgr != null)
-                               ((HttpWebRequest) web_request).CookieContainer = cmgr.CookieContainer;
+                               hwr.CookieContainer = cmgr.CookieContainer;
 #endif
 
-#if !MOONLIGHT // until we support NetworkCredential like SL4 will do.
                        // client authentication (while SL3 has NetworkCredential class, it is not implemented yet. So, it is non-SL only.)
                        var httpbe = (HttpTransportBindingElement) source.Transport;
                        string authType = null;
@@ -116,7 +126,6 @@ namespace System.ServiceModel.Channels
                                // FIXME: it is said required in SL4, but it blocks full WCF.
                                //web_request.UseDefaultCredentials = false;
                        }
-#endif
 
 #if !NET_2_1 // FIXME: implement this to not depend on Timeout property
                        web_request.Timeout = (int) timeout.TotalMilliseconds;
@@ -133,17 +142,64 @@ namespace System.ServiceModel.Channels
 
                        // apply HttpRequestMessageProperty if exists.
                        bool suppressEntityBody = false;
-#if !NET_2_1
                        string pname = HttpRequestMessageProperty.Name;
                        if (message.Properties.ContainsKey (pname)) {
                                HttpRequestMessageProperty hp = (HttpRequestMessageProperty) message.Properties [pname];
-                               web_request.Headers.Clear ();
-                               web_request.Headers.Add (hp.Headers);
+                               foreach (var key in hp.Headers.AllKeys) {
+                                       if (WebHeaderCollection.IsRestricted (key)) { // do not ignore this. WebHeaderCollection rejects restricted ones.
+                                               // FIXME: huh, there should be any better way to do such stupid conversion.
+                                               switch (key) {
+                                               case "Accept":
+                                                       web_request.Accept = hp.Headers [key];
+                                                       break;
+                                               case "Connection":
+                                                       web_request.Connection = hp.Headers [key];
+                                                       break;
+                                               //case "ContentLength":
+                                               //      web_request.ContentLength = hp.Headers [key];
+                                               //      break;
+                                               case "ContentType":
+                                                       web_request.ContentType = hp.Headers [key];
+                                                       break;
+                                               //case "Date":
+                                               //      web_request.Date = hp.Headers [key];
+                                               //      break;
+                                               case "Expect":
+                                                       web_request.Expect = hp.Headers [key];
+                                                       break;
+#if NET_4_0
+                                               case "Host":
+                                                       web_request.Host = hp.Headers [key];
+                                                       break;
+#endif
+                                               //case "If-Modified-Since":
+                                               //      web_request.IfModifiedSince = hp.Headers [key];
+                                               //      break;
+                                               case "Referer":
+                                                       web_request.Referer = hp.Headers [key];
+                                                       break;
+                                               case "Transfer-Encoding":
+                                                       web_request.TransferEncoding = hp.Headers [key];
+                                                       break;
+                                               case "User-Agent":
+                                                       web_request.UserAgent = hp.Headers [key];
+                                                       break;
+                                               }
+                                       }
+                                       else
+                                               web_request.Headers [key] = hp.Headers [key];
+                               }
                                web_request.Method = hp.Method;
                                // FIXME: do we have to handle hp.QueryString ?
                                if (hp.SuppressEntityBody)
                                        suppressEntityBody = true;
                        }
+#if !NET_2_1
+                       if (source.ClientCredentials != null) {
+                               var cred = source.ClientCredentials;
+                               if ((cred.ClientCertificate != null) && (cred.ClientCertificate.Certificate != null))
+                                       ((HttpWebRequest)web_request).ClientCertificates.Add (cred.ClientCertificate.Certificate);
+                       }
 #endif
 
                        if (!suppressEntityBody && String.Compare (web_request.Method, "GET", StringComparison.OrdinalIgnoreCase) != 0) {
@@ -163,6 +219,18 @@ namespace System.ServiceModel.Channels
                                                using (Stream s = web_request.EndGetRequestStream (r))
                                                        s.Write (buffer.GetBuffer (), 0, (int) buffer.Length);
                                                web_request.BeginGetResponse (GotResponse, result);
+                                       } catch (WebException ex) {
+                                               switch (ex.Status) {
+#if !NET_2_1
+                                               case WebExceptionStatus.NameResolutionFailure:
+#endif
+                                               case WebExceptionStatus.ConnectFailure:
+                                                       result.Complete (new EndpointNotFoundException (new EndpointNotFoundException ().Message, ex));
+                                                       break;
+                                               default:
+                                                       result.Complete (ex);
+                                                       break;
+                                               }
                                        } catch (Exception ex) {
                                                result.Complete (ex);
                                        }
@@ -180,7 +248,7 @@ namespace System.ServiceModel.Channels
                        WebResponse res;
                        Stream resstr;
                        try {
-                               res = web_request.EndGetResponse (result);
+                               res = channelResult.WebRequest.EndGetResponse (result);
                                resstr = res.GetResponseStream ();
                        } catch (WebException we) {
                                res = we.Response;
@@ -188,6 +256,29 @@ namespace System.ServiceModel.Channels
                                        channelResult.Complete (we);
                                        return;
                                }
+
+
+                               var hrr2 = (HttpWebResponse) res;
+                               
+                               if ((int) hrr2.StatusCode >= 400 && (int) hrr2.StatusCode < 500) {
+                                       Exception exception = new WebException (
+                                               String.Format ("There was an error on processing web request: Status code {0}({1}): {2}",
+                                                              (int) hrr2.StatusCode, hrr2.StatusCode, hrr2.StatusDescription), null,
+                                               WebExceptionStatus.ProtocolError, hrr2); 
+                                       
+                                       if ((int) hrr2.StatusCode == 404) {
+                                               // Throw the same exception .NET does
+                                               exception = new EndpointNotFoundException (
+                                                       "There was no endpoint listening at {0} that could accept the message. This is often caused by an incorrect address " +
+                                                       "or SOAP action. See InnerException, if present, for more details.",
+                                                       exception);
+                                       }
+                                       
+                                       channelResult.Complete (exception);
+                                       return;
+                               }
+
+
                                try {
                                        // The response might contain SOAP fault. It might not.
                                        resstr = res.GetResponseStream ();
@@ -208,7 +299,7 @@ namespace System.ServiceModel.Channels
                                // TODO: unit test to make sure an empty response never throws
                                // an exception at this level
                                if (hrr.ContentLength == 0) {
-                                       ret = Message.CreateMessage (MessageVersion.Default, String.Empty);
+                                       ret = Message.CreateMessage (Encoder.MessageVersion, String.Empty);
                                } else {
 
                                        using (var responseStream = resstr) {
@@ -247,7 +338,7 @@ namespace System.ServiceModel.Channels
                {
                        ThrowIfDisposedOrNotOpen ();
 
-                       HttpChannelRequestAsyncResult result = new HttpChannelRequestAsyncResult (message, timeout, callback, state);
+                       HttpChannelRequestAsyncResult result = new HttpChannelRequestAsyncResult (message, timeout, this, callback, state);
                        BeginProcessRequest (result);
                        return result;
                }
@@ -267,25 +358,21 @@ namespace System.ServiceModel.Channels
 
                protected override void OnAbort ()
                {
-                       if (web_request != null)
+                       foreach (var web_request in web_requests.ToArray ())
                                web_request.Abort ();
-                       web_request = null;
+                       web_requests.Clear ();
                }
 
                // Close
 
                protected override void OnClose (TimeSpan timeout)
                {
-                       if (web_request != null)
-                               web_request.Abort ();
-                       web_request = null;
+                       OnAbort ();
                }
 
                protected override IAsyncResult OnBeginClose (TimeSpan timeout, AsyncCallback callback, object state)
                {
-                       if (web_request != null)
-                               web_request.Abort ();
-                       web_request = null;
+                       OnAbort ();
                        return base.OnBeginClose (timeout, callback, state);
                }
 
@@ -312,7 +399,7 @@ namespace System.ServiceModel.Channels
                        base.OnEndOpen (result);
                }
 
-               class HttpChannelRequestAsyncResult : IAsyncResult
+               class HttpChannelRequestAsyncResult : IAsyncResult, IDisposable
                {
                        public Message Message {
                                get; private set;
@@ -327,11 +414,13 @@ namespace System.ServiceModel.Channels
                        Exception error;
                        object locker = new object ();
                        bool is_completed;
+                       HttpRequestChannel owner;
 
-                       public HttpChannelRequestAsyncResult (Message message, TimeSpan timeout, AsyncCallback callback, object state)
+                       public HttpChannelRequestAsyncResult (Message message, TimeSpan timeout, HttpRequestChannel owner, AsyncCallback callback, object state)
                        {
                                Message = message;
                                Timeout = timeout;
+                               this.owner = owner;
                                this.callback = callback;
                                AsyncState = state;
                        }
@@ -340,6 +429,8 @@ namespace System.ServiceModel.Channels
                                get; set;
                        }
 
+                       public WebRequest WebRequest { get; set; }
+
                        public WaitHandle AsyncWaitHandle {
                                get {
                                        lock (locker) {
@@ -383,6 +474,7 @@ namespace System.ServiceModel.Channels
                                        lock (locker) {
                                                if (is_completed && wait != null)
                                                        wait.Set ();
+                                               Cleanup ();
                                        }
                                }
                        }
@@ -405,6 +497,16 @@ namespace System.ServiceModel.Channels
                                if (error != null)
                                        throw error;
                        }
+                       
+                       public void Dispose ()
+                       {
+                               Cleanup ();
+                       }
+                       
+                       void Cleanup ()
+                       {
+                               owner.web_requests.Remove (WebRequest);
+                       }
                }
        }
 }