75ddb26261836b352166235f279e60873aa1b37b
[mono.git] / mcs / class / referencesource / System.Data.Entity / System / Data / Common / CommandTrees / Internal / Validator.cs
1 //---------------------------------------------------------------------
2 // <copyright file="Validator.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //
6 // @owner  [....]
7 // @backupOwner [....]
8 //---------------------------------------------------------------------
9
10 namespace System.Data.Common.CommandTrees.Internal
11 {
12     using System;
13     using System.Collections.Generic;
14     using System.Data.Entity;
15     using System.Data.Metadata.Edm;
16     using System.Diagnostics;
17     using System.Linq;
18
19     internal sealed class DbExpressionValidator : DbExpressionRebinder
20     {
21         private readonly DataSpace requiredSpace;
22         private readonly DataSpace[] allowedMetadataSpaces;
23         private readonly DataSpace[] allowedFunctionSpaces;
24         private readonly Dictionary<string, DbParameterReferenceExpression> paramMappings = new Dictionary<string, DbParameterReferenceExpression>();
25         private readonly Stack<Dictionary<string, TypeUsage>> variableScopes = new Stack<Dictionary<string, TypeUsage>>();
26
27         private string expressionArgumentName;
28
29         internal DbExpressionValidator(MetadataWorkspace metadata, DataSpace expectedDataSpace)
30             : base(metadata)
31         {
32             this.requiredSpace = expectedDataSpace;
33             this.allowedFunctionSpaces = new[] { DataSpace.CSpace, DataSpace.SSpace };
34             if (expectedDataSpace == DataSpace.SSpace)
35             {
36                 this.allowedMetadataSpaces = new[] { DataSpace.SSpace, DataSpace.CSpace };
37             }
38             else
39             {
40                 this.allowedMetadataSpaces = new[] { DataSpace.CSpace };   
41             }
42         }
43
44         internal Dictionary<string, DbParameterReferenceExpression> Parameters { get { return this.paramMappings; } }
45
46         internal void ValidateExpression(DbExpression expression, string argumentName)
47         {
48             Debug.Assert(expression != null, "Ensure expression is non-null before calling ValidateExpression");
49             this.expressionArgumentName = argumentName;
50             this.VisitExpression(expression);
51             this.expressionArgumentName = null;
52             Debug.Assert(this.variableScopes.Count == 0, "Variable scope stack left in inconsistent state");
53         }
54
55         protected override EntitySetBase VisitEntitySet(EntitySetBase entitySet)
56         {
57             return ValidateMetadata(entitySet, base.VisitEntitySet, es => es.EntityContainer.DataSpace, this.allowedMetadataSpaces);
58         }
59
60         protected override EdmFunction VisitFunction(EdmFunction function)
61         {
62             // Functions from the current space and S-Space are allowed
63             return ValidateMetadata(function, base.VisitFunction, func => func.DataSpace, this.allowedFunctionSpaces);
64         }
65
66         protected override EdmType VisitType(EdmType type)
67         {
68             return ValidateMetadata(type, base.VisitType, et => et.DataSpace, this.allowedMetadataSpaces);
69         }
70
71         protected override TypeUsage VisitTypeUsage(TypeUsage type)
72         {
73             return ValidateMetadata(type, base.VisitTypeUsage, tu => tu.EdmType.DataSpace, this.allowedMetadataSpaces);
74         }
75
76         protected override void OnEnterScope(IEnumerable<DbVariableReferenceExpression> scopeVariables)
77         {
78             var newScope = scopeVariables.ToDictionary(var => var.VariableName, var => var.ResultType, StringComparer.Ordinal);
79             this.variableScopes.Push(newScope);
80         }
81
82         protected override void OnExitScope()
83         {
84             this.variableScopes.Pop();
85         }
86
87         public override DbExpression Visit(DbVariableReferenceExpression expression)
88         {
89             DbExpression result = base.Visit(expression);
90             if(result.ExpressionKind == DbExpressionKind.VariableReference)
91             {
92                 DbVariableReferenceExpression varRef = (DbVariableReferenceExpression)result;
93                 TypeUsage foundType = null;
94                 foreach(Dictionary<string, TypeUsage> scope in this.variableScopes)
95                 {
96                     if(scope.TryGetValue(varRef.VariableName, out foundType))
97                     {
98                         break;
99                     }
100                 }
101                 
102                 if(foundType == null)
103                 {
104                     ThrowInvalid(System.Data.Entity.Strings.Cqt_Validator_VarRefInvalid(varRef.VariableName));
105                 }
106                                 
107                 // SQLBUDT#545720: Equivalence is not a sufficient check (consider row types) - equality is required.
108                 if (!TypeSemantics.IsEqual(varRef.ResultType, foundType))
109                 {
110                     ThrowInvalid(System.Data.Entity.Strings.Cqt_Validator_VarRefTypeMismatch(varRef.VariableName));
111                 }
112             }
113
114             return result;
115         }
116
117         public override DbExpression Visit(DbParameterReferenceExpression expression)
118         {
119             DbExpression result = base.Visit(expression);
120             if (result.ExpressionKind == DbExpressionKind.ParameterReference)
121             {
122                 DbParameterReferenceExpression paramRef = result as DbParameterReferenceExpression;
123
124                 DbParameterReferenceExpression foundParam;
125                 if (this.paramMappings.TryGetValue(paramRef.ParameterName, out foundParam))
126                 {
127                     // SQLBUDT#545720: Equivalence is not a sufficient check (consider row types for TVPs) - equality is required.
128                     if (!TypeSemantics.IsEqual(paramRef.ResultType, foundParam.ResultType))
129                     {
130                         ThrowInvalid(Strings.Cqt_Validator_InvalidIncompatibleParameterReferences(paramRef.ParameterName));
131                     }
132                 }
133                 else
134                 {
135                     this.paramMappings.Add(paramRef.ParameterName, paramRef);
136                 }
137             }
138             return result;
139         }
140
141         private TMetadata ValidateMetadata<TMetadata>(TMetadata metadata, Func<TMetadata, TMetadata> map, Func<TMetadata, DataSpace> getDataSpace, DataSpace[] allowedSpaces)
142         {
143             TMetadata result = map(metadata);
144             if (!object.ReferenceEquals(metadata, result))
145             {
146                 ThrowInvalidMetadata(metadata);
147             }
148
149             DataSpace resultSpace = getDataSpace(result);
150             if (!allowedSpaces.Any(ds => ds == resultSpace))
151             {
152                 ThrowInvalidSpace(metadata);
153             }
154             return result;
155         }
156                 
157         private void ThrowInvalidMetadata<TMetadata>(TMetadata invalid)
158         {
159             ThrowInvalid(Strings.Cqt_Validator_InvalidOtherWorkspaceMetadata(typeof(TMetadata).Name));
160         }
161
162         private void ThrowInvalidSpace<TMetadata>(TMetadata invalid)
163         {
164             ThrowInvalid(Strings.Cqt_Validator_InvalidIncorrectDataSpaceMetadata(typeof(TMetadata).Name, Enum.GetName(typeof(DataSpace), this.requiredSpace)));
165         }
166
167         private void ThrowInvalid(string message)
168         {
169             throw EntityUtil.Argument(message, this.expressionArgumentName);
170         }
171     }
172 }