New tests.
[mono.git] / mcs / class / System.Web.Mvc2 / System.Web.Mvc / ControllerActionInvoker.cs
1 /* ****************************************************************************\r
2  *\r
3  * Copyright (c) Microsoft Corporation. All rights reserved.\r
4  *\r
5  * This software is subject to the Microsoft Public License (Ms-PL). \r
6  * A copy of the license can be found in the license.htm file included \r
7  * in this distribution.\r
8  *\r
9  * You must not remove this notice, or any other, from this software.\r
10  *\r
11  * ***************************************************************************/\r
12 \r
13 namespace System.Web.Mvc {\r
14     using System;\r
15     using System.Collections.Generic;\r
16     using System.Diagnostics.CodeAnalysis;\r
17     using System.Globalization;\r
18     using System.Linq;\r
19     using System.Threading;\r
20     using System.Web;\r
21     using System.Web.Mvc.Resources;\r
22 \r
23     public class ControllerActionInvoker : IActionInvoker {\r
24 \r
25         private readonly static ControllerDescriptorCache _staticDescriptorCache = new ControllerDescriptorCache();\r
26 \r
27         private ModelBinderDictionary _binders;\r
28         private ControllerDescriptorCache _instanceDescriptorCache;\r
29 \r
30         [SuppressMessage("Microsoft.Usage", "CA2227:CollectionPropertiesShouldBeReadOnly",\r
31             Justification = "Property is settable so that the dictionary can be provided for unit testing purposes.")]\r
32         protected internal ModelBinderDictionary Binders {\r
33             get {\r
34                 if (_binders == null) {\r
35                     _binders = ModelBinders.Binders;\r
36                 }\r
37                 return _binders;\r
38             }\r
39             set {\r
40                 _binders = value;\r
41             }\r
42         }\r
43 \r
44         internal ControllerDescriptorCache DescriptorCache {\r
45             get {\r
46                 if (_instanceDescriptorCache == null) {\r
47                     _instanceDescriptorCache = _staticDescriptorCache;\r
48                 }\r
49                 return _instanceDescriptorCache;\r
50             }\r
51             set {\r
52                 _instanceDescriptorCache = value;\r
53             }\r
54         }\r
55 \r
56         private static void AddControllerToFilterList<TFilter>(ControllerBase controller, IList<TFilter> filterList) where TFilter : class {\r
57             TFilter controllerAsFilter = controller as TFilter;\r
58             if (controllerAsFilter != null) {\r
59                 filterList.Insert(0, controllerAsFilter);\r
60             }\r
61         }\r
62 \r
63         protected virtual ActionResult CreateActionResult(ControllerContext controllerContext, ActionDescriptor actionDescriptor, object actionReturnValue) {\r
64             if (actionReturnValue == null) {\r
65                 return new EmptyResult();\r
66             }\r
67 \r
68             ActionResult actionResult = (actionReturnValue as ActionResult) ??\r
69                 new ContentResult { Content = Convert.ToString(actionReturnValue, CultureInfo.InvariantCulture) };\r
70             return actionResult;\r
71         }\r
72 \r
73         protected virtual ControllerDescriptor GetControllerDescriptor(ControllerContext controllerContext) {\r
74             Type controllerType = controllerContext.Controller.GetType();\r
75             ControllerDescriptor controllerDescriptor = DescriptorCache.GetDescriptor(controllerType, () => new ReflectedControllerDescriptor(controllerType));\r
76             return controllerDescriptor;\r
77         }\r
78 \r
79         protected virtual ActionDescriptor FindAction(ControllerContext controllerContext, ControllerDescriptor controllerDescriptor, string actionName) {\r
80             ActionDescriptor actionDescriptor = controllerDescriptor.FindAction(controllerContext, actionName);\r
81             return actionDescriptor;\r
82         }\r
83 \r
84         protected virtual FilterInfo GetFilters(ControllerContext controllerContext, ActionDescriptor actionDescriptor) {\r
85             FilterInfo filters = actionDescriptor.GetFilters();\r
86 \r
87             // if the current controller implements one of the filter interfaces, it should be added to the list at position 0\r
88             ControllerBase controller = controllerContext.Controller;\r
89             AddControllerToFilterList(controller, filters.ActionFilters);\r
90             AddControllerToFilterList(controller, filters.ResultFilters);\r
91             AddControllerToFilterList(controller, filters.AuthorizationFilters);\r
92             AddControllerToFilterList(controller, filters.ExceptionFilters);\r
93 \r
94             return filters;\r
95         }\r
96 \r
97         private IModelBinder GetModelBinder(ParameterDescriptor parameterDescriptor) {\r
98             // look on the parameter itself, then look in the global table\r
99             return parameterDescriptor.BindingInfo.Binder ?? Binders.GetBinder(parameterDescriptor.ParameterType);\r
100         }\r
101 \r
102         protected virtual object GetParameterValue(ControllerContext controllerContext, ParameterDescriptor parameterDescriptor) {\r
103             // collect all of the necessary binding properties\r
104             Type parameterType = parameterDescriptor.ParameterType;\r
105             IModelBinder binder = GetModelBinder(parameterDescriptor);\r
106             IValueProvider valueProvider = controllerContext.Controller.ValueProvider;\r
107             string parameterName = parameterDescriptor.BindingInfo.Prefix ?? parameterDescriptor.ParameterName;\r
108             Predicate<string> propertyFilter = GetPropertyFilter(parameterDescriptor);\r
109 \r
110             // finally, call into the binder\r
111             ModelBindingContext bindingContext = new ModelBindingContext() {\r
112                 FallbackToEmptyPrefix = (parameterDescriptor.BindingInfo.Prefix == null), // only fall back if prefix not specified\r
113                 ModelMetadata = ModelMetadataProviders.Current.GetMetadataForType(null, parameterType),\r
114                 ModelName = parameterName,\r
115                 ModelState = controllerContext.Controller.ViewData.ModelState,\r
116                 PropertyFilter = propertyFilter,\r
117                 ValueProvider = valueProvider\r
118             };\r
119 \r
120             object result = binder.BindModel(controllerContext, bindingContext);\r
121             return result ?? parameterDescriptor.DefaultValue;\r
122         }\r
123 \r
124         protected virtual IDictionary<string, object> GetParameterValues(ControllerContext controllerContext, ActionDescriptor actionDescriptor) {\r
125             Dictionary<string, object> parametersDict = new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase);\r
126             ParameterDescriptor[] parameterDescriptors = actionDescriptor.GetParameters();\r
127 \r
128             foreach (ParameterDescriptor parameterDescriptor in parameterDescriptors) {\r
129                 parametersDict[parameterDescriptor.ParameterName] = GetParameterValue(controllerContext, parameterDescriptor);\r
130             }\r
131             return parametersDict;\r
132         }\r
133 \r
134         private static Predicate<string> GetPropertyFilter(ParameterDescriptor parameterDescriptor) {\r
135             ParameterBindingInfo bindingInfo = parameterDescriptor.BindingInfo;\r
136             return propertyName => BindAttribute.IsPropertyAllowed(propertyName, bindingInfo.Include.ToArray(), bindingInfo.Exclude.ToArray());\r
137         }\r
138 \r
139         public virtual bool InvokeAction(ControllerContext controllerContext, string actionName) {\r
140             if (controllerContext == null) {\r
141                 throw new ArgumentNullException("controllerContext");\r
142             }\r
143             if (String.IsNullOrEmpty(actionName)) {\r
144                 throw new ArgumentException(MvcResources.Common_NullOrEmpty, "actionName");\r
145             }\r
146 \r
147             ControllerDescriptor controllerDescriptor = GetControllerDescriptor(controllerContext);\r
148             ActionDescriptor actionDescriptor = FindAction(controllerContext, controllerDescriptor, actionName);\r
149             if (actionDescriptor != null) {\r
150                 FilterInfo filterInfo = GetFilters(controllerContext, actionDescriptor);\r
151 \r
152                 try {\r
153                     AuthorizationContext authContext = InvokeAuthorizationFilters(controllerContext, filterInfo.AuthorizationFilters, actionDescriptor);\r
154                     if (authContext.Result != null) {\r
155                         // the auth filter signaled that we should let it short-circuit the request\r
156                         InvokeActionResult(controllerContext, authContext.Result);\r
157                     }\r
158                     else {\r
159                         if (controllerContext.Controller.ValidateRequest) {\r
160                             ValidateRequest(controllerContext);\r
161                         }\r
162 \r
163                         IDictionary<string, object> parameters = GetParameterValues(controllerContext, actionDescriptor);\r
164                         ActionExecutedContext postActionContext = InvokeActionMethodWithFilters(controllerContext, filterInfo.ActionFilters, actionDescriptor, parameters);\r
165                         InvokeActionResultWithFilters(controllerContext, filterInfo.ResultFilters, postActionContext.Result);\r
166                     }\r
167                 }\r
168                 catch (ThreadAbortException) {\r
169                     // This type of exception occurs as a result of Response.Redirect(), but we special-case so that\r
170                     // the filters don't see this as an error.\r
171                     throw;\r
172                 }\r
173                 catch (Exception ex) {\r
174                     // something blew up, so execute the exception filters\r
175                     ExceptionContext exceptionContext = InvokeExceptionFilters(controllerContext, filterInfo.ExceptionFilters, ex);\r
176                     if (!exceptionContext.ExceptionHandled) {\r
177                         throw;\r
178                     }\r
179                     InvokeActionResult(controllerContext, exceptionContext.Result);\r
180                 }\r
181 \r
182                 return true;\r
183             }\r
184 \r
185             // notify controller that no method matched\r
186             return false;\r
187         }\r
188 \r
189         protected virtual ActionResult InvokeActionMethod(ControllerContext controllerContext, ActionDescriptor actionDescriptor, IDictionary<string, object> parameters) {\r
190             object returnValue = actionDescriptor.Execute(controllerContext, parameters);\r
191             ActionResult result = CreateActionResult(controllerContext, actionDescriptor, returnValue);\r
192             return result;\r
193         }\r
194 \r
195         internal static ActionExecutedContext InvokeActionMethodFilter(IActionFilter filter, ActionExecutingContext preContext, Func<ActionExecutedContext> continuation) {\r
196             filter.OnActionExecuting(preContext);\r
197             if (preContext.Result != null) {\r
198                 return new ActionExecutedContext(preContext, preContext.ActionDescriptor, true /* canceled */, null /* exception */) {\r
199                     Result = preContext.Result\r
200                 };\r
201             }\r
202 \r
203             bool wasError = false;\r
204             ActionExecutedContext postContext = null;\r
205             try {\r
206                 postContext = continuation();\r
207             }\r
208             catch (ThreadAbortException) {\r
209                 // This type of exception occurs as a result of Response.Redirect(), but we special-case so that\r
210                 // the filters don't see this as an error.\r
211                 postContext = new ActionExecutedContext(preContext, preContext.ActionDescriptor, false /* canceled */, null /* exception */);\r
212                 filter.OnActionExecuted(postContext);\r
213                 throw;\r
214             }\r
215             catch (Exception ex) {\r
216                 wasError = true;\r
217                 postContext = new ActionExecutedContext(preContext, preContext.ActionDescriptor, false /* canceled */, ex);\r
218                 filter.OnActionExecuted(postContext);\r
219                 if (!postContext.ExceptionHandled) {\r
220                     throw;\r
221                 }\r
222             }\r
223             if (!wasError) {\r
224                 filter.OnActionExecuted(postContext);\r
225             }\r
226             return postContext;\r
227         }\r
228 \r
229         protected virtual ActionExecutedContext InvokeActionMethodWithFilters(ControllerContext controllerContext, IList<IActionFilter> filters, ActionDescriptor actionDescriptor, IDictionary<string, object> parameters) {\r
230             ActionExecutingContext preContext = new ActionExecutingContext(controllerContext, actionDescriptor, parameters);\r
231             Func<ActionExecutedContext> continuation = () =>\r
232                 new ActionExecutedContext(controllerContext, actionDescriptor, false /* canceled */, null /* exception */) {\r
233                     Result = InvokeActionMethod(controllerContext, actionDescriptor, parameters)\r
234                 };\r
235 \r
236             // need to reverse the filter list because the continuations are built up backward\r
237             Func<ActionExecutedContext> thunk = filters.Reverse().Aggregate(continuation,\r
238                 (next, filter) => () => InvokeActionMethodFilter(filter, preContext, next));\r
239             return thunk();\r
240         }\r
241 \r
242         protected virtual void InvokeActionResult(ControllerContext controllerContext, ActionResult actionResult) {\r
243             actionResult.ExecuteResult(controllerContext);\r
244         }\r
245 \r
246         internal static ResultExecutedContext InvokeActionResultFilter(IResultFilter filter, ResultExecutingContext preContext, Func<ResultExecutedContext> continuation) {\r
247             filter.OnResultExecuting(preContext);\r
248             if (preContext.Cancel) {\r
249                 return new ResultExecutedContext(preContext, preContext.Result, true /* canceled */, null /* exception */);\r
250             }\r
251 \r
252             bool wasError = false;\r
253             ResultExecutedContext postContext = null;\r
254             try {\r
255                 postContext = continuation();\r
256             }\r
257             catch (ThreadAbortException) {\r
258                 // This type of exception occurs as a result of Response.Redirect(), but we special-case so that\r
259                 // the filters don't see this as an error.\r
260                 postContext = new ResultExecutedContext(preContext, preContext.Result, false /* canceled */, null /* exception */);\r
261                 filter.OnResultExecuted(postContext);\r
262                 throw;\r
263             }\r
264             catch (Exception ex) {\r
265                 wasError = true;\r
266                 postContext = new ResultExecutedContext(preContext, preContext.Result, false /* canceled */, ex);\r
267                 filter.OnResultExecuted(postContext);\r
268                 if (!postContext.ExceptionHandled) {\r
269                     throw;\r
270                 }\r
271             }\r
272             if (!wasError) {\r
273                 filter.OnResultExecuted(postContext);\r
274             }\r
275             return postContext;\r
276         }\r
277 \r
278         protected virtual ResultExecutedContext InvokeActionResultWithFilters(ControllerContext controllerContext, IList<IResultFilter> filters, ActionResult actionResult) {\r
279             ResultExecutingContext preContext = new ResultExecutingContext(controllerContext, actionResult);\r
280             Func<ResultExecutedContext> continuation = delegate {\r
281                 InvokeActionResult(controllerContext, actionResult);\r
282                 return new ResultExecutedContext(controllerContext, actionResult, false /* canceled */, null /* exception */);\r
283             };\r
284 \r
285             // need to reverse the filter list because the continuations are built up backward\r
286             Func<ResultExecutedContext> thunk = filters.Reverse().Aggregate(continuation,\r
287                 (next, filter) => () => InvokeActionResultFilter(filter, preContext, next));\r
288             return thunk();\r
289         }\r
290 \r
291         protected virtual AuthorizationContext InvokeAuthorizationFilters(ControllerContext controllerContext, IList<IAuthorizationFilter> filters, ActionDescriptor actionDescriptor) {\r
292             AuthorizationContext context = new AuthorizationContext(controllerContext, actionDescriptor);\r
293             foreach (IAuthorizationFilter filter in filters) {\r
294                 filter.OnAuthorization(context);\r
295                 // short-circuit evaluation\r
296                 if (context.Result != null) {\r
297                     break;\r
298                 }\r
299             }\r
300 \r
301             return context;\r
302         }\r
303 \r
304         protected virtual ExceptionContext InvokeExceptionFilters(ControllerContext controllerContext, IList<IExceptionFilter> filters, Exception exception) {\r
305             ExceptionContext context = new ExceptionContext(controllerContext, exception);\r
306             foreach (IExceptionFilter filter in filters) {\r
307                 filter.OnException(context);\r
308             }\r
309 \r
310             return context;\r
311         }\r
312 \r
313         [SuppressMessage("Microsoft.Performance", "CA1804:RemoveUnusedLocals", MessageId = "rawUrl",\r
314             Justification = "We only care about the property getter's side effects, not the returned value.")]\r
315         internal static void ValidateRequest(ControllerContext controllerContext) {\r
316             if (controllerContext.IsChildAction) {\r
317                 return;\r
318             }\r
319 \r
320             // DevDiv 214040: Enable Request Validation by default for all controller requests\r
321             // \r
322             // Note that we grab the Request's RawUrl to force it to be validated. Calling ValidateInput()\r
323             // doesn't actually validate anything. It just sets flags indicating that on the next usage of\r
324             // certain inputs that they should be validated. We special case RawUrl because the URL has already\r
325             // been consumed by routing and thus might contain dangerous data. By forcing the RawUrl to be\r
326             // re-read we're making sure that it gets validated by ASP.NET.\r
327 \r
328             controllerContext.HttpContext.Request.ValidateInput();\r
329             string rawUrl = controllerContext.HttpContext.Request.RawUrl;\r
330         }\r
331 \r
332     }\r
333 }\r