Updates referencesource to .NET 4.7
[mono.git] / mcs / class / referencesource / System.Data.Entity / System / Data / Common / CommandTrees / Internal / ExpressionCopier.cs
1 //---------------------------------------------------------------------
2 // <copyright file="DbExpressionRebinder.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //
6 // @owner  Microsoft
7 // @backupOwner Microsoft
8 //---------------------------------------------------------------------
9
namespace System.Data.Common.CommandTrees.Internal
{
    using System.Collections.Generic;
    using System.Data.Common;
    using System.Data.Common.CommandTrees;
    using System.Data.Common.CommandTrees.ExpressionBuilder;
    using System.Data.Common.EntitySql;
    using System.Data.Metadata.Edm;
    using System.Diagnostics;
    using System.Linq;

    /// <summary>
    /// Ensures that all metadata in a given expression tree is from the specified metadata workspace,
    /// potentially rebinding and rebuilding the expressions to appropriate replacement metadata where necessary.
    /// </summary>
    /// <remarks>
    /// Entity sets, functions, types and members are looked up in the target workspace by name.
    /// Any item that cannot be resolved there causes an error raised through <c>EntityUtil.Argument</c>.
    /// </remarks>
    internal class DbExpressionRebinder : DefaultExpressionVisitor
    {
        /// <summary>The target workspace that every rebound metadata item is resolved against.</summary>
        private readonly MetadataWorkspace _metadata;

        /// <summary>
        /// A <c>ModelPerspective</c> over the target workspace. It is used to resolve functions
        /// (and function imports) that are not in <c>DataSpace.SSpace</c>.
        /// </summary>
        private readonly Perspective _perspective;

        /// <summary>
        /// Creates a rebinder that resolves all metadata against <paramref name="targetWorkspace"/>.
        /// </summary>
        /// <param name="targetWorkspace">The workspace to bind to. It must not be null; this is only asserted in debug builds.</param>
        protected DbExpressionRebinder(MetadataWorkspace targetWorkspace)
        {
            Debug.Assert(targetWorkspace != null, "Metadata workspace is null");
            _metadata = targetWorkspace;
            _perspective = new ModelPerspective(targetWorkspace);
        }

        /// <summary>
        /// Returns a version of <paramref name="expression"/> in which all metadata comes from <paramref name="targetWorkspace"/>.
        /// </summary>
        /// <param name="expression">The expression tree to rebind. It must not be null; this is only asserted in debug builds.</param>
        /// <param name="targetWorkspace">The workspace the result should reference.</param>
        /// <returns>The rebound expression. This may be the same instance when nothing needed to change.</returns>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")]
        internal static DbExpression BindToWorkspace(DbExpression expression, MetadataWorkspace targetWorkspace)
        {
            Debug.Assert(expression != null, "expression is null");

            DbExpressionRebinder copier = new DbExpressionRebinder(targetWorkspace);
            return copier.VisitExpression(expression);
        }

        /// <summary>
        /// Maps an entity set (or other extent) to the extent with the same container name and set name
        /// in the target workspace.
        /// </summary>
        /// <remarks>
        /// The match must also be the same kind of extent: EntitySet to EntitySet, AssociationSet to AssociationSet.
        /// </remarks>
        protected override EntitySetBase VisitEntitySet(EntitySetBase entitySet)
        {
            EntityContainer container;
            if (!_metadata.TryGetEntityContainer(entitySet.EntityContainer.Name, entitySet.EntityContainer.DataSpace, out container))
            {
                throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_EntityContainerNotFound(entitySet.EntityContainer.Name));
            }

            EntitySetBase extent = null;
            if (container.BaseEntitySets.TryGetValue(entitySet.Name, /*ignoreCase:*/ false, out extent) &&
                extent != null &&
                entitySet.BuiltInTypeKind == extent.BuiltInTypeKind) // EntitySet -> EntitySet, AssociationSet -> AssociationSet, etc
            {
                return extent;
            }

            throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_EntitySetNotFound(entitySet.EntityContainer.Name, entitySet.Name));
        }

        /// <summary>
        /// Maps a function to the overload in the target workspace that takes the rebound parameter types.
        /// </summary>
        /// <remarks>
        /// Store (SSpace) functions are looked up directly in the workspace. All other functions are resolved
        /// through the model perspective using overload resolution. An ambiguous or missing match is an error.
        /// </remarks>
        protected override EdmFunction VisitFunction(EdmFunction function)
        {
            // Parameter types must be rebound first, because overload lookup is performed by parameter type.
            List<TypeUsage> paramTypes = new List<TypeUsage>(function.Parameters.Count);
            foreach (FunctionParameter funcParam in function.Parameters)
            {
                paramTypes.Add(this.VisitTypeUsage(funcParam.TypeUsage));
            }

            if (DataSpace.SSpace == function.DataSpace)
            {
                EdmFunction foundFunc = null;
                if (_metadata.TryGetFunction(function.Name,
                                             function.NamespaceName,
                                             paramTypes.ToArray(),
                                             false /* ignoreCase */,
                                             function.DataSpace,
                                             out foundFunc) &&
                    foundFunc != null)
                {
                    return foundFunc;
                }
            }
            else
            {
                // Find the function or function import.
                IList<EdmFunction> candidateFunctions;
                if (_perspective.TryGetFunctionByName(function.NamespaceName, function.Name, /*ignoreCase:*/ false, out candidateFunctions))
                {
                    Debug.Assert(null != candidateFunctions && candidateFunctions.Count > 0, "Perspective.TryGetFunctionByName returned true with null/empty function result list");

                    bool isAmbiguous;
                    EdmFunction retFunc = FunctionOverloadResolver.ResolveFunctionOverloads(candidateFunctions, paramTypes, /*isGroupAggregateFunction:*/ false, out isAmbiguous);
                    if (!isAmbiguous &&
                        retFunc != null)
                    {
                        return retFunc;
                    }
                }
            }

            throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_FunctionNotFound(TypeHelpers.GetFullName(function)));
        }

        /// <summary>
        /// Maps a type to its equivalent in the target workspace.
        /// </summary>
        /// <remarks>
        /// Ref, collection and row types are rebuilt from their rebound components. A structural type is
        /// returned as the same instance when none of its components changed. Every other type is looked
        /// up by name, namespace and data space.
        /// </remarks>
        protected override EdmType VisitType(EdmType type)
        {
            switch (type.BuiltInTypeKind)
            {
                case BuiltInTypeKind.RefType:
                    return this.RebindRefType((RefType)type);

                case BuiltInTypeKind.CollectionType:
                    return this.RebindCollectionType((CollectionType)type);

                case BuiltInTypeKind.RowType:
                    return this.RebindRowType((RowType)type);

                default:
                    return this.RebindNamedType(type);
            }
        }

        /// <summary>Rebinds the element type of a ref type, creating a new ref type only if the element type changed.</summary>
        private EdmType RebindRefType(RefType refType)
        {
            // The target workspace might resolve the element's name to a type of another kind. In that case,
            // report the missing type instead of failing with an opaque InvalidCastException.
            EntityType mappedEntityType = this.VisitType(refType.ElementType) as EntityType;
            if (null == mappedEntityType)
            {
                throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_TypeNotFound(TypeHelpers.GetFullName(refType.ElementType)));
            }

            if (object.ReferenceEquals(refType.ElementType, mappedEntityType))
            {
                return refType;
            }

            return new RefType(mappedEntityType);
        }

        /// <summary>Rebinds the element type usage of a collection type, creating a new collection type only if it changed.</summary>
        private EdmType RebindCollectionType(CollectionType collectionType)
        {
            TypeUsage mappedElementType = this.VisitTypeUsage(collectionType.TypeUsage);
            if (object.ReferenceEquals(collectionType.TypeUsage, mappedElementType))
            {
                return collectionType;
            }

            return new CollectionType(mappedElementType);
        }

        /// <summary>Rebinds each property type of a row type, creating a new row type only if at least one property type changed.</summary>
        private EdmType RebindRowType(RowType rowType)
        {
            // The (name, type) list is only built the first time a property type changes.
            // This avoids the allocation in the common case where nothing needs rebinding.
            List<KeyValuePair<string, TypeUsage>> mappedPropInfo = null;
            for (int idx = 0; idx < rowType.Properties.Count; idx++)
            {
                EdmProperty originalProp = rowType.Properties[idx];
                TypeUsage mappedPropType = this.VisitTypeUsage(originalProp.TypeUsage);
                if (!object.ReferenceEquals(originalProp.TypeUsage, mappedPropType))
                {
                    if (mappedPropInfo == null)
                    {
                        mappedPropInfo = new List<KeyValuePair<string, TypeUsage>>(
                                            rowType.Properties.Select(
                                                prop => new KeyValuePair<string, TypeUsage>(prop.Name, prop.TypeUsage)
                                            ));
                    }
                    mappedPropInfo[idx] = new KeyValuePair<string, TypeUsage>(originalProp.Name, mappedPropType);
                }
            }

            if (mappedPropInfo == null)
            {
                return rowType;
            }

            IEnumerable<EdmProperty> mappedProps = mappedPropInfo.Select(propInfo => new EdmProperty(propInfo.Key, propInfo.Value));
            return new RowType(mappedProps, rowType.InitializerMetadata);
        }

        /// <summary>Looks up a non-structural type (entity, complex, primitive, ...) in the target workspace by its full name and data space.</summary>
        private EdmType RebindNamedType(EdmType type)
        {
            EdmType retType;
            if (!_metadata.TryGetType(type.Name, type.NamespaceName, type.DataSpace, out retType) ||
                null == retType)
            {
                throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_TypeNotFound(TypeHelpers.GetFullName(type)));
            }

            return retType;
        }

        /// <summary>
        /// Maps a type usage to one that references the rebound type and carries the original facets.
        /// </summary>
        protected override TypeUsage VisitTypeUsage(TypeUsage type)
        {
            //
            // If the target metadata workspace contains the same type instances, then the type does not
            // need to be 'mapped' and the same TypeUsage instance may be returned. This can happen if the
            // target workspace and the workspace of the source Command Tree are using the same ItemCollection.
            //
            EdmType retEdmType = this.VisitType(type.EdmType);
            if (object.ReferenceEquals(retEdmType, type.EdmType))
            {
                return type;
            }

            //
            // The type has already been mapped above. Copy the original facets onto the new type usage
            // so that it carries the same facet values as the original.
            //
            return TypeUsage.Create(retEdmType, type.Facets.ToArray());
        }

        /// <summary>
        /// Looks up a member named <paramref name="memberName"/> (case-sensitive) on the structural result
        /// type of <paramref name="instance"/>.
        /// </summary>
        /// <typeparam name="TMember">The kind of member expected, such as EdmProperty or NavigationProperty.</typeparam>
        /// <param name="instance">The expression whose result type declares the member.</param>
        /// <param name="memberName">The member name to look up.</param>
        /// <param name="member">The member found, or null.</param>
        /// <returns>
        /// True if the result type is structural and declares a member with that name of type
        /// <typeparamref name="TMember"/>; otherwise false.
        /// </returns>
        private bool TryGetMember<TMember>(DbExpression instance, string memberName, out TMember member) where TMember : EdmMember
        {
            member = null;
            StructuralType declType = instance.ResultType.EdmType as StructuralType;
            if (declType != null)
            {
                EdmMember foundMember = null;
                if (declType.Members.TryGetValue(memberName, /*ignoreCase:*/ false, out foundMember))
                {
                    member = foundMember as TMember;
                }
            }

            return (member != null);
        }

        /// <summary>
        /// Rebinds a property expression. When its instance expression changes, the property is looked up
        /// again by name on the new instance's type.
        /// </summary>
        /// <param name="expression">The property expression to visit. It must not be null.</param>
        /// <returns>The original expression if its instance did not change; otherwise a new property expression.</returns>
        public override DbExpression Visit(DbPropertyExpression expression)
        {
            EntityUtil.CheckArgumentNull(expression, "expression");

            DbExpression newInstance = this.VisitExpression(expression.Instance);
            if (object.ReferenceEquals(expression.Instance, newInstance))
            {
                // Nothing was rebound underneath this expression, so the original can be reused as-is.
                return expression;
            }

            // Each member kind has its own DbExpressionBuilder.Property overload and its own error message.
            if (Helper.IsRelationshipEndMember(expression.Property))
            {
                RelationshipEndMember endMember;
                if (!TryGetMember(newInstance, expression.Property.Name, out endMember))
                {
                    throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_EndNotFound(expression.Property.Name, TypeHelpers.GetFullName(newInstance.ResultType.EdmType)));
                }
                return DbExpressionBuilder.Property(newInstance, endMember);
            }

            if (Helper.IsNavigationProperty(expression.Property))
            {
                NavigationProperty navProp;
                if (!TryGetMember(newInstance, expression.Property.Name, out navProp))
                {
                    throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_NavPropertyNotFound(expression.Property.Name, TypeHelpers.GetFullName(newInstance.ResultType.EdmType)));
                }
                return DbExpressionBuilder.Property(newInstance, navProp);
            }

            EdmProperty prop;
            if (!TryGetMember(newInstance, expression.Property.Name, out prop))
            {
                throw EntityUtil.Argument(System.Data.Entity.Strings.Cqt_Copier_PropertyNotFound(expression.Property.Name, TypeHelpers.GetFullName(newInstance.ResultType.EdmType)));
            }
            return DbExpressionBuilder.Property(newInstance, prop);
        }
    }
}