Merge pull request #900 from Blewzman/FixAggregateExceptionGetBaseException
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel.Channels / HttpRequestChannel.cs
index 42fb54e2279f47da959c8462af0a39b6dca2c77a..6ebafbd99da921404d65d3adb895336911dfc45b 100644 (file)
@@ -1,5 +1,5 @@
 //
-// HttpRequestChannel.cs
+// HttpRequestChannel.cs 
 //
 // Author:
 //     Atsushi Enomoto <atsushi@ximian.com>
@@ -41,7 +41,7 @@ namespace System.ServiceModel.Channels
        {
                HttpChannelFactory<IRequestChannel> source;
 
-               WebRequest web_request;
+               List<WebRequest> web_requests = new List<WebRequest> ();
 
                // Constructor
 
@@ -56,6 +56,15 @@ namespace System.ServiceModel.Channels
                        get { return source.MessageEncoder; }
                }
 
+#if NET_2_1
+               public override T GetProperty<T> ()
+               {
+                       if (typeof (T) == typeof (IHttpCookieContainerManager))
+                               return source.GetProperty<T> ();
+                       return base.GetProperty<T> ();
+               }
+#endif
+
                // Request
 
                public override Message Request (Message message, TimeSpan timeout)
@@ -76,17 +85,18 @@ namespace System.ServiceModel.Channels
                                        destination = Via ?? RemoteAddress.Uri;
                        }
 
-                       web_request = HttpWebRequest.Create (destination);
+                       var web_request = (HttpWebRequest) HttpWebRequest.Create (destination);
+                       web_requests.Add (web_request);
+                       result.WebRequest = web_request;
                        web_request.Method = "POST";
                        web_request.ContentType = Encoder.ContentType;
-
-#if NET_2_1
+#if NET_2_1 || NET_4_0
+                       HttpWebRequest hwr = (web_request as HttpWebRequest);
                        var cmgr = source.GetProperty<IHttpCookieContainerManager> ();
                        if (cmgr != null)
-                               ((HttpWebRequest) web_request).CookieContainer = cmgr.CookieContainer;
+                               hwr.CookieContainer = cmgr.CookieContainer;
 #endif
 
-#if !MOONLIGHT // until we support NetworkCredential like SL4 will do.
                        // client authentication (while SL3 has NetworkCredential class, it is not implemented yet. So, it is non-SL only.)
                        var httpbe = (HttpTransportBindingElement) source.Transport;
                        string authType = null;
@@ -116,7 +126,6 @@ namespace System.ServiceModel.Channels
                                // FIXME: it is said required in SL4, but it blocks full WCF.
                                //web_request.UseDefaultCredentials = false;
                        }
-#endif
 
 #if !NET_2_1 // FIXME: implement this to not depend on Timeout property
                        web_request.Timeout = (int) timeout.TotalMilliseconds;
@@ -133,17 +142,64 @@ namespace System.ServiceModel.Channels
 
                        // apply HttpRequestMessageProperty if exists.
                        bool suppressEntityBody = false;
-#if !NET_2_1
                        string pname = HttpRequestMessageProperty.Name;
                        if (message.Properties.ContainsKey (pname)) {
                                HttpRequestMessageProperty hp = (HttpRequestMessageProperty) message.Properties [pname];
-                               web_request.Headers.Clear ();
-                               web_request.Headers.Add (hp.Headers);
+                               foreach (var key in hp.Headers.AllKeys) {
+                                       if (WebHeaderCollection.IsRestricted (key)) { // do not ignore this. WebHeaderCollection rejects restricted ones.
+                                               // FIXME: huh, there should be any better way to do such stupid conversion.
+                                               switch (key) {
+                                               case "Accept":
+                                                       web_request.Accept = hp.Headers [key];
+                                                       break;
+                                               case "Connection":
+                                                       web_request.Connection = hp.Headers [key];
+                                                       break;
+                                               //case "ContentLength":
+                                               //      web_request.ContentLength = hp.Headers [key];
+                                               //      break;
+                                               case "ContentType":
+                                                       web_request.ContentType = hp.Headers [key];
+                                                       break;
+                                               //case "Date":
+                                               //      web_request.Date = hp.Headers [key];
+                                               //      break;
+                                               case "Expect":
+                                                       web_request.Expect = hp.Headers [key];
+                                                       break;
+#if NET_4_0
+                                               case "Host":
+                                                       web_request.Host = hp.Headers [key];
+                                                       break;
+#endif
+                                               //case "If-Modified-Since":
+                                               //      web_request.IfModifiedSince = hp.Headers [key];
+                                               //      break;
+                                               case "Referer":
+                                                       web_request.Referer = hp.Headers [key];
+                                                       break;
+                                               case "Transfer-Encoding":
+                                                       web_request.TransferEncoding = hp.Headers [key];
+                                                       break;
+                                               case "User-Agent":
+                                                       web_request.UserAgent = hp.Headers [key];
+                                                       break;
+                                               }
+                                       }
+                                       else
+                                               web_request.Headers [key] = hp.Headers [key];
+                               }
                                web_request.Method = hp.Method;
                                // FIXME: do we have to handle hp.QueryString ?
                                if (hp.SuppressEntityBody)
                                        suppressEntityBody = true;
                        }
+#if !NET_2_1
+                       if (source.ClientCredentials != null) {
+                               var cred = source.ClientCredentials;
+                               if ((cred.ClientCertificate != null) && (cred.ClientCertificate.Certificate != null))
+                                       ((HttpWebRequest)web_request).ClientCertificates.Add (cred.ClientCertificate.Certificate);
+                       }
 #endif
 
                        if (!suppressEntityBody && String.Compare (web_request.Method, "GET", StringComparison.OrdinalIgnoreCase) != 0) {
@@ -163,6 +219,18 @@ namespace System.ServiceModel.Channels
                                                using (Stream s = web_request.EndGetRequestStream (r))
                                                        s.Write (buffer.GetBuffer (), 0, (int) buffer.Length);
                                                web_request.BeginGetResponse (GotResponse, result);
+                                       } catch (WebException ex) {
+                                               switch (ex.Status) {
+#if !NET_2_1
+                                               case WebExceptionStatus.NameResolutionFailure:
+#endif
+                                               case WebExceptionStatus.ConnectFailure:
+                                                       result.Complete (new EndpointNotFoundException (new EndpointNotFoundException ().Message, ex));
+                                                       break;
+                                               default:
+                                                       result.Complete (ex);
+                                                       break;
+                                               }
                                        } catch (Exception ex) {
                                                result.Complete (ex);
                                        }
@@ -180,7 +248,7 @@ namespace System.ServiceModel.Channels
                        WebResponse res;
                        Stream resstr;
                        try {
-                               res = web_request.EndGetResponse (result);
+                               res = channelResult.WebRequest.EndGetResponse (result);
                                resstr = res.GetResponseStream ();
                        } catch (WebException we) {
                                res = we.Response;
@@ -188,6 +256,29 @@ namespace System.ServiceModel.Channels
                                        channelResult.Complete (we);
                                        return;
                                }
+
+
+                               var hrr2 = (HttpWebResponse) res;
+                               
+                               if ((int) hrr2.StatusCode >= 400 && (int) hrr2.StatusCode < 500) {
+                                       Exception exception = new WebException (
+                                               String.Format ("There was an error on processing web request: Status code {0}({1}): {2}",
+                                                              (int) hrr2.StatusCode, hrr2.StatusCode, hrr2.StatusDescription), null,
+                                               WebExceptionStatus.ProtocolError, hrr2); 
+                                       
+                                       if ((int) hrr2.StatusCode == 404) {
+                                               // Throw the same exception .NET does
+                                               exception = new EndpointNotFoundException (
+                                                       "There was no endpoint listening at {0} that could accept the message. This is often caused by an incorrect address " +
+                                                       "or SOAP action. See InnerException, if present, for more details.",
+                                                       exception);
+                                       }
+                                       
+                                       channelResult.Complete (exception);
+                                       return;
+                               }
+
+
                                try {
                                        // The response might contain SOAP fault. It might not.
                                        resstr = res.GetResponseStream ();
@@ -203,36 +294,39 @@ namespace System.ServiceModel.Channels
                        }
 
                        try {
-                               using (var responseStream = resstr) {
-                                       MemoryStream ms = new MemoryStream ();
-                                       byte [] b = new byte [65536];
-                                       int n = 0;
-
-                                       while (true) {
-                                               n = responseStream.Read (b, 0, 65536);
-                                               if (n == 0)
-                                                       break;
-                                               ms.Write (b, 0, n);
+                               Message ret;
+
+                               // TODO: unit test to make sure an empty response never throws
+                               // an exception at this level
+                               if (hrr.ContentLength == 0) {
+                                       ret = Message.CreateMessage (Encoder.MessageVersion, String.Empty);
+                               } else {
+
+                                       using (var responseStream = resstr) {
+                                               MemoryStream ms = new MemoryStream ();
+                                               byte [] b = new byte [65536];
+                                               int n = 0;
+
+                                               while (true) {
+                                                       n = responseStream.Read (b, 0, 65536);
+                                                       if (n == 0)
+                                                               break;
+                                                       ms.Write (b, 0, n);
+                                               }
+                                               ms.Seek (0, SeekOrigin.Begin);
+
+                                               ret = Encoder.ReadMessage (
+                                                       ms, (int) source.Transport.MaxReceivedMessageSize, res.ContentType);
                                        }
-                                       ms.Seek (0, SeekOrigin.Begin);
-
-                                       Message ret = Encoder.ReadMessage (
-                                               ms, (int) source.Transport.MaxReceivedMessageSize, res.ContentType);
-                                       var rp = new HttpResponseMessageProperty () { StatusCode = hrr.StatusCode, StatusDescription = hrr.StatusDescription };
-                                       foreach (var key in hrr.Headers.AllKeys)
-                                               rp.Headers [key] = hrr.Headers [key];
-                                       ret.Properties.Add (HttpResponseMessageProperty.Name, rp);
-/*
-MessageBuffer buf = ret.CreateBufferedCopy (0x10000);
-ret = buf.CreateMessage ();
-System.Xml.XmlTextWriter w = new System.Xml.XmlTextWriter (Console.Out);
-w.Formatting = System.Xml.Formatting.Indented;
-buf.CreateMessage ().WriteMessage (w);
-w.Close ();
-*/
-                                       channelResult.Response = ret;
-                                       channelResult.Complete ();
                                }
+
+                               var rp = new HttpResponseMessageProperty () { StatusCode = hrr.StatusCode, StatusDescription = hrr.StatusDescription };
+                               foreach (var key in hrr.Headers.AllKeys)
+                                       rp.Headers [key] = hrr.Headers [key];
+                               ret.Properties.Add (HttpResponseMessageProperty.Name, rp);
+
+                               channelResult.Response = ret;
+                               channelResult.Complete ();
                        } catch (Exception ex) {
                                channelResult.Complete (ex);
                        } finally {
@@ -244,7 +338,7 @@ w.Close ();
                {
                        ThrowIfDisposedOrNotOpen ();
 
-                       HttpChannelRequestAsyncResult result = new HttpChannelRequestAsyncResult (message, timeout, callback, state);
+                       HttpChannelRequestAsyncResult result = new HttpChannelRequestAsyncResult (message, timeout, this, callback, state);
                        BeginProcessRequest (result);
                        return result;
                }
@@ -264,28 +358,27 @@ w.Close ();
 
                protected override void OnAbort ()
                {
-                       if (web_request != null)
+                       foreach (var web_request in web_requests.ToArray ())
                                web_request.Abort ();
-                       web_request = null;
+                       web_requests.Clear ();
                }
 
                // Close
 
                protected override void OnClose (TimeSpan timeout)
                {
-                       if (web_request != null)
-                               web_request.Abort ();
-                       web_request = null;
+                       OnAbort ();
                }
 
                protected override IAsyncResult OnBeginClose (TimeSpan timeout, AsyncCallback callback, object state)
                {
-                       throw new NotImplementedException ();
+                       OnAbort ();
+                       return base.OnBeginClose (timeout, callback, state);
                }
 
                protected override void OnEndClose (IAsyncResult result)
                {
-                       throw new NotImplementedException ();
+                       base.OnEndClose (result);
                }
 
                // Open
@@ -294,17 +387,19 @@ w.Close ();
                {
                }
 
+               [MonoTODO ("find out what to do here")]
                protected override IAsyncResult OnBeginOpen (TimeSpan timeout, AsyncCallback callback, object state)
                {
-                       throw new NotImplementedException ();
+                       return base.OnBeginOpen (timeout, callback, state);
                }
 
+               [MonoTODO ("find out what to do here")]
                protected override void OnEndOpen (IAsyncResult result)
                {
-                       throw new NotImplementedException ();
+                       base.OnEndOpen (result);
                }
 
-               class HttpChannelRequestAsyncResult : IAsyncResult
+               class HttpChannelRequestAsyncResult : IAsyncResult, IDisposable
                {
                        public Message Message {
                                get; private set;
@@ -317,24 +412,33 @@ w.Close ();
                        AsyncCallback callback;
                        ManualResetEvent wait;
                        Exception error;
+                       object locker = new object ();
+                       bool is_completed;
+                       HttpRequestChannel owner;
 
-                       public HttpChannelRequestAsyncResult (Message message, TimeSpan timeout, AsyncCallback callback, object state)
+                       public HttpChannelRequestAsyncResult (Message message, TimeSpan timeout, HttpRequestChannel owner, AsyncCallback callback, object state)
                        {
-                               CompletedSynchronously = true;
                                Message = message;
                                Timeout = timeout;
+                               this.owner = owner;
                                this.callback = callback;
                                AsyncState = state;
-
-                               wait = new ManualResetEvent (false);
                        }
 
                        public Message Response {
                                get; set;
                        }
 
+                       public WebRequest WebRequest { get; set; }
+
                        public WaitHandle AsyncWaitHandle {
-                               get { return wait; }
+                               get {
+                                       lock (locker) {
+                                               if (wait == null)
+                                                       wait = new ManualResetEvent (is_completed);
+                                       }
+                                       return wait;
+                               }
                        }
 
                        public object AsyncState {
@@ -355,7 +459,6 @@ w.Close ();
                                error = error ?? ex;
 
                                IsCompleted = true;
-                               wait.Set ();
                                if (callback != null)
                                        callback (this);
                        }
@@ -365,7 +468,15 @@ w.Close ();
                        }
 
                        public bool IsCompleted {
-                               get; private set;
+                               get { return is_completed; }
+                               set {
+                                       is_completed = value;
+                                       lock (locker) {
+                                               if (is_completed && wait != null)
+                                                       wait.Set ();
+                                               Cleanup ();
+                                       }
+                               }
                        }
 
                        public void WaitEnd ()
@@ -376,9 +487,9 @@ w.Close ();
                                        // exception to the Complete () method and allow the result to complete 'normally'.
 #if NET_2_1
                                        // neither Moonlight nor MonoTouch supports contexts (WaitOne default to false)
-                                       bool result = wait.WaitOne (Timeout);
+                                       bool result = AsyncWaitHandle.WaitOne (Timeout);
 #else
-                                       bool result = wait.WaitOne (Timeout, true);
+                                       bool result = AsyncWaitHandle.WaitOne (Timeout, true);
 #endif
                                        if (!result)
                                                throw new TimeoutException ();
@@ -386,6 +497,16 @@ w.Close ();
                                if (error != null)
                                        throw error;
                        }
+                       
+                       public void Dispose ()
+                       {
+                               Cleanup ();
+                       }
+                       
+                       void Cleanup ()
+                       {
+                               owner.web_requests.Remove (WebRequest);
+                       }
                }
        }
 }