Merge pull request #409 from Alkarex/patch-1
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel.Dispatcher / OperationInvokerHandler.cs
1 using System;
2 using System.Collections.Generic;
3 using System.Text;
4 using System.ServiceModel.Channels;
5 using System.ServiceModel;
6 using System.Reflection;
7 using System.Threading;
8
namespace System.ServiceModel.Dispatcher
{
	// Request-processing handler that resolves the DispatchOperation for an
	// incoming message, invokes it on the service instance, and replies with
	// either the serialized result or a fault message.
	internal class OperationInvokerHandler : BaseRequestProcessorHandler
	{
		// Non-null only when the channel given to the constructor is an
		// IDuplexChannel; Reply () then sends through this channel.
		IDuplexChannel duplex;

		public OperationInvokerHandler (IChannel channel)
		{
			duplex = channel as IDuplexChannel;
		}

		// Dispatches one incoming message: selects the operation, invokes it
		// and replies unless the operation is one-way.
		// A TargetInvocationException (which, when the invoker uses reflection,
		// wraps the exception thrown by the service method) is converted into
		// a fault reply and reported to the configured IErrorHandlers; any
		// other exception propagates to the caller.
		// Always returns false. NOTE(review): the meaning of the return value
		// is defined by BaseRequestProcessorHandler, which is not shown here.
		protected override bool ProcessRequest (MessageProcessingContext mrc)
		{
			DispatchRuntime dispatchRuntime = mrc.OperationContext.EndpointDispatcher.DispatchRuntime;
			DispatchOperation operation = GetOperation (mrc.IncomingMessage, dispatchRuntime);
			mrc.Operation = operation;
			try {
				DoProcessRequest (mrc);
				if (!mrc.Operation.IsOneWay)
					Reply (mrc, true);
			} catch (TargetInvocationException ex) {
				// Hand the exception raised by the service code (not the
				// reflection wrapper) to both ProvideFault and HandleError, so
				// every error handler sees the same exception object.
				Exception error = ex.InnerException ?? ex;

				// ChannelDispatcher is null for callback channels (the same case
				// BuildExceptionMessage and ProcessCustomErrorHandlers guard
				// against); treat that as "do not include exception details".
				var cd = dispatchRuntime.ChannelDispatcher;
				bool includeDetails = cd != null && cd.IncludeExceptionDetailInFaults;

				mrc.ReplyMessage = BuildExceptionMessage (mrc, error, includeDetails);
				if (!mrc.Operation.IsOneWay)
					Reply (mrc, true);
				ProcessCustomErrorHandlers (mrc, error);
			}
			return false;
		}

		// Builds the argument array, invokes the operation on the service
		// instance (synchronously or via the Begin/End pair) and passes the
		// outcome to HandleInvokeResult.
		void DoProcessRequest (MessageProcessingContext mrc)
		{
			DispatchOperation operation = mrc.Operation;
			Message req = mrc.IncomingMessage;
			object instance = mrc.InstanceContext.GetServiceInstance (req);
			object [] parameters, outParams;
			BuildInvokeParams (mrc, out parameters);

			object result;
			if (operation.Invoker.IsSynchronous) {
				result = operation.Invoker.Invoke (instance, parameters, out outParams);
			} else {
				// FIXME: the original code passed null callback
				// and null state, which is very wrong :(
				// It is still wrong to pass dummy callback, but
				// wrong code without obvious issues is better
				// than code with an obvious issue.
				// InvokeEnd is called immediately after InvokeBegin, so this
				// thread presumably waits here for the asynchronous operation
				// to finish (depends on the invoker's InvokeEnd - confirm).
				AsyncCallback callback = delegate {};
				var ar = operation.Invoker.InvokeBegin (instance, parameters, callback, null);
				result = operation.Invoker.InvokeEnd (instance, out outParams, ar);
			}
			HandleInvokeResult (mrc, outParams, result);
		}

		// Sends the reply for mrc through the duplex channel when this handler
		// was created for one, otherwise through MessageProcessingContext.Reply.
		void Reply (MessageProcessingContext mrc, bool useTimeout)
		{
			if (duplex != null)
				mrc.Reply (duplex, useTimeout);
			else
				mrc.Reply (useTimeout);
		}

		// Resolves the DispatchOperation for an incoming message.
		// With an OperationSelector configured, matches on the operation name
		// it returns; otherwise matches on the message's Action header.
		// Falls back to DispatchRuntime.UnhandledDispatchOperation when no
		// operation matches.
		DispatchOperation GetOperation (Message input, DispatchRuntime dispatchRuntime)
		{
			if (dispatchRuntime.OperationSelector != null) {
				string name = dispatchRuntime.OperationSelector.SelectOperation (ref input);
				foreach (DispatchOperation d in dispatchRuntime.Operations)
					if (d.Name == name)
						return d;
			} else {
				string action = input.Headers.Action;
				foreach (DispatchOperation d in dispatchRuntime.Operations)
					if (d.Action == action)
						return d;
			}
			return dispatchRuntime.UnhandledDispatchOperation;
		}

		// Raises the AfterInvoke event and, for request/reply operations,
		// builds the reply message from the invocation result: serialized via
		// the operation's formatter when SerializeReply is set, otherwise the
		// result is used as the Message directly. Outgoing headers/properties
		// from the OperationContext are copied onto the reply, and RelatesTo
		// defaults to the request's MessageId.
		// Throws InvalidOperationException when no reply message was produced.
		void HandleInvokeResult (MessageProcessingContext mrc, object [] outputs, object result)
		{
			DispatchOperation operation = mrc.Operation;
			mrc.EventsHandler.AfterInvoke (operation);

			if (operation.IsOneWay)
				return;

			Message res;
			if (operation.SerializeReply)
				res = operation.Formatter.SerializeReply (
					mrc.OperationContext.IncomingMessageVersion, outputs, result);
			else
				res = (Message) result;

			// Previously a null reply surfaced as a NullReferenceException on
			// res.Headers; report it explicitly instead.
			if (res == null)
				throw new InvalidOperationException (String.Format ("DispatchOperation '{0}' did not produce a reply message.", operation.Name));

			res.Headers.CopyHeadersFrom (mrc.OperationContext.OutgoingMessageHeaders);
			res.Properties.CopyProperties (mrc.OperationContext.OutgoingMessageProperties);
			if (res.Headers.RelatesTo == null)
				res.Headers.RelatesTo = mrc.OperationContext.IncomingMessageHeaders.MessageId;
			mrc.ReplyMessage = res;
		}

		// Validates the operation, then produces the invoker's argument array:
		// deserialized from the request when DeserializeRequest is set,
		// otherwise a single-element array holding the raw request Message.
		// Raises the BeforeInvoke event last.
		void BuildInvokeParams (MessageProcessingContext mrc, out object [] parameters)
		{
			DispatchOperation operation = mrc.Operation;
			EnsureValid (operation);

			if (operation.DeserializeRequest) {
				parameters = operation.Invoker.AllocateInputs ();
				operation.Formatter.DeserializeRequest (mrc.IncomingMessage, parameters);
			} else
				parameters = new object [] { mrc.IncomingMessage };

			mrc.EventsHandler.BeforeInvoke (operation);
		}

		// Calls HandleError on every registered IErrorHandler (all of them;
		// `|=` does not short-circuit) and runs the session-shutdown handlers
		// if any of them returned true.
		// NOTE(review): the WCF IErrorHandler.HandleError contract documents a
		// true return as "do not abort the session"; here true triggers the
		// shutdown path. Confirm the intended semantics before changing it.
		void ProcessCustomErrorHandlers (MessageProcessingContext mrc, Exception ex)
		{
			var dr = mrc.OperationContext.EndpointDispatcher.DispatchRuntime;
			bool shutdown = false;
			if (dr.ChannelDispatcher != null) // non-callback channel
				foreach (var eh in dr.ChannelDispatcher.ErrorHandlers)
					shutdown |= eh.HandleError (ex);
			if (shutdown)
				ProcessSessionErrorShutdown (mrc);
		}

		// Notifies every IInputSessionShutdown handler that the channel faulted.
		// Does nothing unless the channel has an input session and is an
		// IDuplexContextChannel.
		void ProcessSessionErrorShutdown (MessageProcessingContext mrc)
		{
			var dr = mrc.OperationContext.EndpointDispatcher.DispatchRuntime;
			var session = mrc.OperationContext.Channel.InputSession;
			var dcc = mrc.OperationContext.Channel as IDuplexContextChannel;
			if (session == null || dcc == null)
				return;
			foreach (var h in dr.InputSessionShutdownHandlers)
				h.ChannelFaulted (dcc);
		}

		// Returns true when `type` or one of its base types is a closed
		// FaultException<T>; `arg` then receives T (otherwise null).
		bool IsGenericFaultException (Type type, out Type arg)
		{
			for (; type != null; type = type.BaseType) {
				if (!type.IsGenericType)
					continue;
				if (!type.GetGenericTypeDefinition ().Equals (typeof (FaultException<>)))
					continue;
				arg = type.GetGenericArguments () [0];
				return true;
			}

			arg = null;
			return false;
		}

		// Builds the fault message sent in response to an exception.
		// Precedence:
		//   1. a message supplied by an IErrorHandler.ProvideFault implementation;
		//   2. a FaultException<T> whose T matches one of the operation's
		//      declared fault contracts (sent with that contract's Action);
		//   3. an "InternalServiceFault", carrying ExceptionDetail only when
		//      includeDetailsInFault is true.
		Message BuildExceptionMessage (MessageProcessingContext mrc, Exception ex, bool includeDetailsInFault)
		{
			var dr = mrc.OperationContext.EndpointDispatcher.DispatchRuntime;
			var cd = dr.ChannelDispatcher;
			Message msg = null;
			if (cd != null) // non-callback channel
				foreach (var eh in cd.ErrorHandlers)
					eh.ProvideFault (ex, cd.MessageVersion, ref msg);
			if (msg != null)
				return msg;

			var req = mrc.IncomingMessage;

			Type gft;
			var fe = ex as FaultException;
			if (fe != null && IsGenericFaultException (fe.GetType (), out gft)) {
				foreach (var fci in mrc.Operation.FaultContractInfos) {
					if (fci.Detail == gft)
						return Message.CreateMessage (req.Version, fe.CreateMessageFault (), fci.Action);
				}
			}

			// FIXME: set correct name
			FaultCode fc = new FaultCode (
				"InternalServiceFault",
				req.Version.Addressing.Namespace);

			if (includeDetailsInFault) {
				return Message.CreateMessage (req.Version, fc, ex.Message, new ExceptionDetail (ex), req.Headers.Action);
			}

			string faultString =
				@"The server was unable to process the request due to an internal error.  The server may be able to return exception details (it depends on the server settings).";
			return Message.CreateMessage (req.Version, fc, faultString, req.Headers.Action);
		}

		// Throws InvalidOperationException when the operation has no Invoker,
		// or needs a Formatter (DeserializeRequest or SerializeReply is set)
		// but has none.
		void EnsureValid (DispatchOperation operation)
		{
			if (operation.Invoker == null)
				throw new InvalidOperationException (String.Format ("DispatchOperation '{0}' for contract '{1}' requires Invoker.", operation.Name, operation.Parent.EndpointDispatcher.ContractName));
			if ((operation.DeserializeRequest || operation.SerializeReply) && operation.Formatter == null)
				throw new InvalidOperationException ("The DispatchOperation '" + operation.Name + "' requires Formatter, since DeserializeRequest and SerializeReply are not both false.");
		}
	}
}