2010-05-26 Atsushi Enomoto <atsushi@ximian.com>
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel / ClientRuntimeChannel.cs
index fe8ba6f6f1e1877b0c1f39b722698390d2210a1e..65e42ad2b13c2771a7e63c8c8258499b1bac2819 100644 (file)
 //
 using System;
 using System.Reflection;
+using System.Runtime.Serialization;
 using System.ServiceModel.Channels;
 using System.ServiceModel.Description;
 using System.ServiceModel.Dispatcher;
 using System.ServiceModel.Security;
 using System.Threading;
+using System.Xml;
 
-namespace System.ServiceModel
+namespace System.ServiceModel.MonoInternal
 {
-#if TARGET_DOTNET
-       [MonoTODO]
-       public
-#else
-       internal
-#endif
-       class ClientRuntimeChannel
+       // FIXME: This is a quick workaround for bug #571907
+       public class ClientRuntimeChannel
                : CommunicationObject, IClientChannel
        {
                ClientRuntime runtime;
-               ChannelFactory factory;
-               IRequestChannel request_channel;
-               IOutputChannel output_channel;
+               EndpointAddress remote_address;
+               ContractDescription contract;
+               MessageVersion message_version;
+               TimeSpan default_open_timeout, default_close_timeout;
+               IChannel channel;
+               IChannelFactory factory;
+               OperationContext context;
+
+               #region delegates
                readonly ProcessDelegate _processDelegate;
 
                delegate object ProcessDelegate (MethodBase method, string operationName, object [] parameters);
 
-               public ClientRuntimeChannel (ClientRuntime runtime,
-                       ChannelFactory factory)
+               readonly RequestDelegate requestDelegate;
+
+               delegate Message RequestDelegate (Message msg, TimeSpan timeout);
+
+               readonly SendDelegate sendDelegate;
+
+               delegate void SendDelegate (Message msg, TimeSpan timeout);
+               #endregion
+
+               public ClientRuntimeChannel (ServiceEndpoint endpoint,
+                       ChannelFactory channelFactory, EndpointAddress remoteAddress, Uri via)
+                       : this (endpoint.CreateRuntime (), endpoint.Contract, channelFactory.DefaultOpenTimeout, channelFactory.DefaultCloseTimeout, null, channelFactory.OpenedChannelFactory, endpoint.Binding.MessageVersion, remoteAddress, via)
+               {
+               }
+
+               public ClientRuntimeChannel (ClientRuntime runtime, ContractDescription contract, TimeSpan openTimeout, TimeSpan closeTimeout, IChannel contextChannel, IChannelFactory factory, MessageVersion messageVersion, EndpointAddress remoteAddress, Uri via)
                {
                        this.runtime = runtime;
-                       this.factory = factory;
+                       this.remote_address = remoteAddress;
+                       runtime.Via = via;
+                       this.contract = contract;
+                       this.message_version = messageVersion;
+                       default_open_timeout = openTimeout;
+                       default_close_timeout = closeTimeout;
                        _processDelegate = new ProcessDelegate (Process);
+                       requestDelegate = new RequestDelegate (Request);
+                       sendDelegate = new SendDelegate (Send);
 
                        // default values
                        AllowInitializationUI = true;
+                       OperationTimeout = TimeSpan.FromMinutes (1);
+
+                       if (contextChannel != null)
+                               channel = contextChannel;
+                       else {
+                               var method = factory.GetType ().GetMethod ("CreateChannel", new Type [] {typeof (EndpointAddress), typeof (Uri)});
+                               channel = (IChannel) method.Invoke (factory, new object [] {remote_address, Via});
+                               this.factory = factory;
+                       }
+               }
+
+               public ContractDescription Contract {
+                       get { return contract; }
                }
 
                public ClientRuntime Runtime {
                        get { return runtime; }
                }
 
+               IRequestChannel RequestChannel {
+                       get { return channel as IRequestChannel; }
+               }
+
+               IOutputChannel OutputChannel {
+                       get { return channel as IOutputChannel; }
+               }
+
+               internal IDuplexChannel DuplexChannel {
+                       get { return channel as IDuplexChannel; }
+               }
+
                #region IClientChannel
 
                bool did_interactive_initialization;
@@ -106,7 +155,7 @@ namespace System.ServiceModel
 
                        public override bool WaitOne (int millisecondsTimeout)
                        {
-                               return WaitOne (millisecondsTimeout, false);
+                               return WaitHandle.WaitAll (ResultWaitHandles, millisecondsTimeout);
                        }
 
                        WaitHandle [] ResultWaitHandles {
@@ -118,6 +167,7 @@ namespace System.ServiceModel
                                }
                        }
 
+#if !MOONLIGHT
                        public override bool WaitOne (int millisecondsTimeout, bool exitContext)
                        {
                                return WaitHandle.WaitAll (ResultWaitHandles, millisecondsTimeout, exitContext);
@@ -127,6 +177,7 @@ namespace System.ServiceModel
                        {
                                return WaitHandle.WaitAll (ResultWaitHandles, timeout, exitContext);
                        }
+#endif
                }
 
                class DisplayUIAsyncResult : IAsyncResult
@@ -223,119 +274,154 @@ namespace System.ServiceModel
                #region IContextChannel
 
                [MonoTODO]
-               public bool AllowOutputBatching {
-                       get { throw new NotImplementedException (); }
-                       set { throw new NotImplementedException (); }
-               }
+               public bool AllowOutputBatching { get; set; }
 
-               [MonoTODO]
                public IInputSession InputSession {
                        get {
-                               ISessionChannel<IInputSession> ch = request_channel as ISessionChannel<IInputSession>;
-                               ch = ch ?? output_channel as ISessionChannel<IInputSession>;
-                               return ch != null ? ch.Session : null;
+                               ISessionChannel<IInputSession> ch = RequestChannel as ISessionChannel<IInputSession>;
+                               ch = ch ?? OutputChannel as ISessionChannel<IInputSession>;
+                               if (ch != null)
+                                       return ch.Session;
+                               var dch = OutputChannel as ISessionChannel<IDuplexSession>;
+                               return dch != null ? dch.Session : null;
                        }
                }
 
-               [MonoTODO]
                public EndpointAddress LocalAddress {
-                       get { throw new NotImplementedException (); }
+                       get {
+                               var dc = OperationChannel as IDuplexChannel;
+                               return dc != null ? dc.LocalAddress : null;
+                       }
                }
 
                [MonoTODO]
-               public TimeSpan OperationTimeout {
-                       get { throw new NotImplementedException (); }
-                       set { throw new NotImplementedException (); }
-               }
+               public TimeSpan OperationTimeout { get; set; }
 
-               [MonoTODO]
                public IOutputSession OutputSession {
                        get {
-                               ISessionChannel<IOutputSession> ch = request_channel as ISessionChannel<IOutputSession>;
-                               ch = ch ?? output_channel as ISessionChannel<IOutputSession>;
-                               return ch != null ? ch.Session : null;
+                               ISessionChannel<IOutputSession> ch = RequestChannel as ISessionChannel<IOutputSession>;
+                               ch = ch ?? OutputChannel as ISessionChannel<IOutputSession>;
+                               if (ch != null)
+                                       return ch.Session;
+                               var dch = OutputChannel as ISessionChannel<IDuplexSession>;
+                               return dch != null ? dch.Session : null;
                        }
                }
 
-               [MonoTODO]
                public EndpointAddress RemoteAddress {
-                       get { throw new NotImplementedException (); }
+                       get { return RequestChannel != null ? RequestChannel.RemoteAddress : OutputChannel.RemoteAddress; }
                }
 
-               [MonoTODO]
                public string SessionId {
-                       get { throw new NotImplementedException (); }
+                       get { return OutputSession != null ? OutputSession.Id : InputSession != null ? InputSession.Id : null; }
                }
 
                #endregion
 
                // CommunicationObject
                protected internal override TimeSpan DefaultOpenTimeout {
-                       get { return factory.DefaultOpenTimeout; }
+                       get { return default_open_timeout; }
                }
 
                protected internal override TimeSpan DefaultCloseTimeout {
-                       get { return factory.DefaultCloseTimeout; }
+                       get { return default_close_timeout; }
                }
 
                protected override void OnAbort ()
                {
-                       factory.Abort ();
+                       channel.Abort ();
+                       if (factory != null) // ... is it valid?
+                               factory.Abort ();
                }
 
+               Action<TimeSpan> close_delegate;
+
                protected override IAsyncResult OnBeginClose (
                        TimeSpan timeout, AsyncCallback callback, object state)
                {
-                       return factory.BeginClose (timeout, callback, state);
+                       if (close_delegate == null)
+                               close_delegate = new Action<TimeSpan> (OnClose);
+                       return close_delegate.BeginInvoke (timeout, callback, state);
                }
 
                protected override void OnEndClose (IAsyncResult result)
                {
-                       factory.EndClose (result);
+                       close_delegate.EndInvoke (result);
                }
 
                protected override void OnClose (TimeSpan timeout)
                {
-                       factory.Close (timeout);
+                       DateTime start = DateTime.Now;
+                       channel.Close (timeout);
                }
 
+               Action<TimeSpan> open_callback;
+
                protected override IAsyncResult OnBeginOpen (
                        TimeSpan timeout, AsyncCallback callback, object state)
                {
-                       throw new SystemException ("INTERNAL ERROR: this should not be called (or not supported yet)");
+                       if (open_callback == null)
+                               open_callback = new Action<TimeSpan> (OnOpen);
+                       return open_callback.BeginInvoke (timeout, callback, state);
                }
 
                protected override void OnEndOpen (IAsyncResult result)
                {
+                       if (open_callback == null)
+                               throw new InvalidOperationException ("Async open operation has not started");
+                       open_callback.EndInvoke (result);
                }
 
                protected override void OnOpen (TimeSpan timeout)
                {
                        if (runtime.InteractiveChannelInitializers.Count > 0 && !DidInteractiveInitialization)
                                throw new InvalidOperationException ("The client runtime is assigned interactive channel initializers, and in such case DisplayInitializationUI must be called before the channel is opened.");
+                       if (channel.State == CommunicationState.Created)
+                               channel.Open (timeout);
                }
 
                // IChannel
+
+               IChannel OperationChannel {
+                       get { return channel; }
+               }
+
                public T GetProperty<T> () where T : class
                {
-                       return factory.GetProperty<T> ();
+                       return OperationChannel.GetProperty<T> ();
                }
 
                // IExtensibleObject<IContextChannel>
-               [MonoTODO]
+
+               IExtensionCollection<IContextChannel> extensions;
+
                public IExtensionCollection<IContextChannel> Extensions {
-                       get { throw new NotImplementedException (); }
+                       get {
+                               if (extensions == null)
+                                       extensions = new ExtensionCollection<IContextChannel> (this);
+                               return extensions;
+                       }
                }
 
                #region Request/Output processing
 
                public IAsyncResult BeginProcess (MethodBase method, string operationName, object [] parameters, AsyncCallback callback, object asyncState)
                {
+                       if (context != null)
+                               throw new InvalidOperationException ("another operation is in progress");
+                       context = OperationContext.Current;
                        return _processDelegate.BeginInvoke (method, operationName, parameters, callback, asyncState);
                }
 
-               public object EndProcess (IAsyncResult result)
+               public object EndProcess (MethodBase method, string operationName, object [] parameters, IAsyncResult result)
                {
+                       context = null;
+                       if (result == null)
+                               throw new ArgumentNullException ("result");
+                       if (parameters == null)
+                               throw new ArgumentNullException ("parameters");
+                       // FIXME: the method arguments should be verified to be 
+                       // identical to the arguments in the corresponding begin method.
                        return _processDelegate.EndInvoke (result);
                }
 
@@ -344,8 +430,10 @@ namespace System.ServiceModel
                        try {
                                return DoProcess (method, operationName, parameters);
                        } catch (Exception ex) {
+#if MOONLIGHT // just for debugging
                                Console.Write ("Exception in async operation: ");
                                Console.WriteLine (ex);
+#endif
                                throw;
                        }
                }
@@ -355,6 +443,10 @@ namespace System.ServiceModel
                        if (AllowInitializationUI)
                                DisplayInitializationUI ();
                        OperationDescription od = SelectOperation (method, operationName, parameters);
+
+                       if (State != CommunicationState.Opened)
+                               Open ();
+
                        if (!od.IsOneWay)
                                return Request (od, parameters);
                        else {
@@ -370,136 +462,56 @@ namespace System.ServiceModel
                                operation = Runtime.OperationSelector.SelectOperation (method, parameters);
                        else
                                operation = operationName;
-                       OperationDescription od = factory.Endpoint.Contract.Operations.Find (operation);
+                       OperationDescription od = contract.Operations.Find (operation);
                        if (od == null)
                                throw new Exception (String.Format ("OperationDescription for operation '{0}' was not found in its internally-generated contract.", operation));
                        return od;
                }
 
-               BindingParameterCollection CreateBindingParameters ()
-               {
-                       BindingParameterCollection pl =
-                               new BindingParameterCollection ();
-
-                       ContractDescription cd = factory.Endpoint.Contract;
-#if !NET_2_1
-                       pl.Add (ChannelProtectionRequirements.CreateFromContract (cd));
-
-                       foreach (IEndpointBehavior behavior in factory.Endpoint.Behaviors)
-                               behavior.AddBindingParameters (factory.Endpoint, pl);
-#endif
-
-                       return pl;
-               }
-
-               void SetupOutputChannel ()
-               {
-                       if (output_channel != null)
-                               return;
-                       BindingParameterCollection pl =
-                               CreateBindingParameters ();
-                       bool session = false;
-                       switch (factory.Endpoint.Contract.SessionMode) {
-                       case SessionMode.Required:
-                               session = factory.Endpoint.Binding.CanBuildChannelFactory<IOutputSessionChannel> (pl);
-                               if (!session)
-                                       throw new InvalidOperationException ("The contract requires session support, but the binding does not support it.");
-                               break;
-                       case SessionMode.Allowed:
-                               session = !factory.Endpoint.Binding.CanBuildChannelFactory<IOutputChannel> (pl);
-                               break;
-                       }
-
-                       EndpointAddress address = factory.Endpoint.Address;
-                       Uri via = Runtime.Via;
-
-                       if (session) {
-                               IChannelFactory<IOutputSessionChannel> f =
-                                       factory.Endpoint.Binding.BuildChannelFactory<IOutputSessionChannel> (pl);
-                               f.Open ();
-                               output_channel = f.CreateChannel (address, via);
-                       } else {
-                               IChannelFactory<IOutputChannel> f =
-                                       factory.Endpoint.Binding.BuildChannelFactory<IOutputChannel> (pl);
-                               f.Open ();
-                               output_channel = f.CreateChannel (address, via);
-                       }
-
-                       output_channel.Open ();
-               }
-
-               void SetupRequestChannel ()
-               {
-                       if (request_channel != null)
-                               return;
-
-                       BindingParameterCollection pl =
-                               CreateBindingParameters ();
-                       bool session = false;
-                       switch (factory.Endpoint.Contract.SessionMode) {
-                       case SessionMode.Required:
-                               session = factory.Endpoint.Binding.CanBuildChannelFactory<IRequestSessionChannel> (pl);
-                               if (!session)
-                                       throw new InvalidOperationException ("The contract requires session support, but the binding does not support it.");
-                               break;
-                       case SessionMode.Allowed:
-                               session = !factory.Endpoint.Binding.CanBuildChannelFactory<IRequestChannel> (pl);
-                               break;
-                       }
-
-                       EndpointAddress address = factory.Endpoint.Address;
-                       Uri via = Runtime.Via;
-
-                       if (session) {
-                               IChannelFactory<IRequestSessionChannel> f =
-                                       factory.Endpoint.Binding.BuildChannelFactory<IRequestSessionChannel> (pl);
-                               f.Open ();
-                               request_channel = f.CreateChannel (address, via);
-                       } else {
-                               IChannelFactory<IRequestChannel> f =
-                                       factory.Endpoint.Binding.BuildChannelFactory<IRequestChannel> (pl);
-                               f.Open ();
-                               request_channel = f.CreateChannel (address, via);
-                       }
-
-                       request_channel.Open ();
-               }
-
                void Output (OperationDescription od, object [] parameters)
                {
-                       SetupOutputChannel ();
-
                        ClientOperation op = runtime.Operations [od.Name];
-                       Output (CreateRequest (op, parameters));
+                       Send (CreateRequest (op, parameters, false), OperationTimeout);
                }
 
                object Request (OperationDescription od, object [] parameters)
                {
-                       SetupRequestChannel ();
-
                        ClientOperation op = runtime.Operations [od.Name];
                        object [] inspections = new object [runtime.MessageInspectors.Count];
-                       Message req = CreateRequest (op, parameters);
+                       Message req = CreateRequest (op, parameters, true);
 
                        for (int i = 0; i < inspections.Length; i++)
                                inspections [i] = runtime.MessageInspectors [i].BeforeSendRequest (ref req, this);
 
-                       Message res = Request (req);
+                       Message res = Request (req, OperationTimeout);
                        if (res.IsFault) {
-                               MessageFault fault = MessageFault.CreateFault (res, runtime.MaxFaultSize);
-                               if (fault.HasDetail && fault is MessageFault.SimpleMessageFault) {
-                                       MessageFault.SimpleMessageFault simpleFault = fault as MessageFault.SimpleMessageFault;
-                                       object detail = simpleFault.Detail;
-                                       Type t = detail.GetType ();
-                                       Type faultType = typeof (FaultException<>).MakeGenericType (t);
-                                       object [] constructorParams = new object [] { detail, fault.Reason, fault.Code, fault.Actor };
-                                       FaultException fe = (FaultException) Activator.CreateInstance (faultType, constructorParams);
-                                       throw fe;
-                               }
-                               else {
-                                       // given a MessageFault, it is hard to figure out the type of the embedded detail
-                                       throw new FaultException(fault);
+                               var resb = res.CreateBufferedCopy (runtime.MaxFaultSize);
+                               MessageFault fault = MessageFault.CreateFault (resb.CreateMessage (), runtime.MaxFaultSize);
+                               var conv = OperationChannel.GetProperty<FaultConverter> () ?? FaultConverter.GetDefaultFaultConverter (res.Version);
+                               Exception ex;
+                               if (!conv.TryCreateException (resb.CreateMessage (), fault, out ex)) {
+                                       if (fault.HasDetail) {
+                                               Type detailType = typeof (ExceptionDetail);
+                                               var freader = fault.GetReaderAtDetailContents ();
+                                               DataContractSerializer ds = null;
+#if !NET_2_1
+                                               foreach (var fci in op.FaultContractInfos)
+                                                       if (res.Headers.Action == fci.Action || fci.Serializer.IsStartObject (freader)) {
+                                                               detailType = fci.Detail;
+                                                               ds = fci.Serializer;
+                                                               break;
+                                                       }
+#endif
+                                               if (ds == null)
+                                                       ds = new DataContractSerializer (detailType);
+                                               var detail = ds.ReadObject (freader);
+                                               ex = (Exception) Activator.CreateInstance (typeof (FaultException<>).MakeGenericType (detailType), new object [] {detail, fault.Reason, fault.Code, res.Headers.Action});
+                                       }
+
+                                       if (ex == null)
+                                               ex = new FaultException (fault);
                                }
+                               throw ex;
                        }
 
                        for (int i = 0; i < inspections.Length; i++)
@@ -511,27 +523,89 @@ namespace System.ServiceModel
                                return res;
                }
 
-               Message Request (Message msg)
+               #region Message-based Request() and Send()
+               // They are internal for ClientBase<T>.ChannelBase use.
+               internal Message Request (Message msg, TimeSpan timeout)
                {
-                       return request_channel.Request (msg, factory.Endpoint.Binding.SendTimeout);
+                       if (RequestChannel != null)
+                               return RequestChannel.Request (msg, timeout);
+                       else {
+                               DateTime startTime = DateTime.Now;
+                               OutputChannel.Send (msg, timeout);
+                               return ((IDuplexChannel) OutputChannel).Receive (timeout - (DateTime.Now - startTime));
+                       }
+               }
+
+               internal IAsyncResult BeginRequest (Message msg, TimeSpan timeout, AsyncCallback callback, object state)
+               {
+                       return requestDelegate.BeginInvoke (msg, timeout, callback, state);
+               }
+
+               internal Message EndRequest (IAsyncResult result)
+               {
+                       return requestDelegate.EndInvoke (result);
+               }
+
+               internal void Send (Message msg, TimeSpan timeout)
+               {
+                       OutputChannel.Send (msg, timeout);
                }
 
-               void Output (Message msg)
+               internal IAsyncResult BeginSend (Message msg, TimeSpan timeout, AsyncCallback callback, object state)
                {
-                       output_channel.Send (msg, factory.Endpoint.Binding.SendTimeout);
+                       return sendDelegate.BeginInvoke (msg, timeout, callback, state);
                }
 
-               Message CreateRequest (ClientOperation op, object [] parameters)
+               internal void EndSend (IAsyncResult result)
                {
-                       MessageVersion version = factory.Endpoint.Binding.MessageVersion;
+                       sendDelegate.EndInvoke (result);
+               }
+               #endregion
+
+               Message CreateRequest (ClientOperation op, object [] parameters, bool isOutputChannel)
+               {
+                       MessageVersion version = message_version;
                        if (version == null)
                                version = MessageVersion.Default;
 
+                       Message msg;
                        if (op.SerializeRequest)
-                               return op.GetFormatter ().SerializeRequest (
+                               msg = op.GetFormatter ().SerializeRequest (
                                        version, parameters);
-                       else
-                               return (Message) parameters [0];
+                       else {
+                               if (parameters.Length != 1)
+                                       throw new ArgumentException (String.Format ("Argument parameters does not match the expected input. It should contain only a Message, but has {0} parameters", parameters.Length));
+                               if (!(parameters [0] is Message))
+                                       throw new ArgumentException (String.Format ("Argument should be only a Message, but has {0}", parameters [0] != null ? parameters [0].GetType ().FullName : "null"));
+                               msg = (Message) parameters [0];
+                       }
+
+                       context = context ?? OperationContext.Current;
+                       if (context != null) {
+                               // CopyHeadersFrom does not work here (brings duplicates -> error)
+                               foreach (var mh in context.OutgoingMessageHeaders) {
+                                       int x = msg.Headers.FindHeader (mh.Name, mh.Namespace, mh.Actor);
+                                       if (x >= 0)
+                                               msg.Headers.RemoveAt (x);
+                                       msg.Headers.Add ((MessageHeader) mh);
+                               }
+                               msg.Properties.CopyProperties (context.OutgoingMessageProperties);
+                       }
+
+                       if (OutputSession != null)
+                               msg.Headers.MessageId = new UniqueId (OutputSession.Id);
+                       msg.Properties.AllowOutputBatching = AllowOutputBatching;
+
+                       if (msg.Version.Addressing.Equals (AddressingVersion.WSAddressing10)) {
+                               if (msg.Headers.MessageId == null)
+                                       msg.Headers.MessageId = new UniqueId ();
+                               if (msg.Headers.ReplyTo == null && !isOutputChannel)
+                                       msg.Headers.ReplyTo = new EndpointAddress (Constants.WsaAnonymousUri);
+                               if (msg.Headers.To == null)
+                                       msg.Headers.To = RemoteAddress.Uri;
+                       }
+
+                       return msg;
                }
 
                #endregion