3e46e43dd30a1ce787d5269bb09f2c54d8240abf
[mono.git] / mcs / class / referencesource / System.Data.Entity / System / Data / Mapping / Update / Internal / DynamicUpdateCommand.cs
1 //---------------------------------------------------------------------
2 // <copyright file="DynamicUpdateCommand.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //
6 // @owner Microsoft
7 // @backupOwner Microsoft
8 //---------------------------------------------------------------------
9
10
11 using System.Collections.Generic;
12 using System.Data.Common.CommandTrees;
13 using System.Data.Metadata.Edm;
14 using System.Data.Common;
15 using System.Data.EntityClient;
16 using System.Diagnostics;
17 using System.Data.Common.Utils;
18 using System.Linq;
19 using System.Data.Common.CommandTrees.ExpressionBuilder;
20 using System.Data.Spatial;
21
22 namespace System.Data.Mapping.Update.Internal
23 {
24     internal sealed class DynamicUpdateCommand : UpdateCommand
25     {
26         private readonly ModificationOperator m_operator;
27         private readonly TableChangeProcessor m_processor;
28         private readonly List<KeyValuePair<int, DbSetClause>> m_inputIdentifiers;
29         private readonly Dictionary<int, string> m_outputIdentifiers;
30         private readonly DbModificationCommandTree m_modificationCommandTree;
31
32
33         internal DynamicUpdateCommand(TableChangeProcessor processor, UpdateTranslator translator, ModificationOperator op,
34             PropagatorResult originalValues, PropagatorResult currentValues, DbModificationCommandTree tree,
35             Dictionary<int, string> outputIdentifiers)
36             : base(originalValues, currentValues)
37         {
38             m_processor = EntityUtil.CheckArgumentNull(processor, "processor");
39             m_operator = op;
40             m_modificationCommandTree = EntityUtil.CheckArgumentNull(tree, "commandTree");
41             m_outputIdentifiers = outputIdentifiers; // may be null (not all commands have output identifiers)
42
43             // initialize identifier information (supports lateral propagation of server gen values)
44             if (ModificationOperator.Insert == op || ModificationOperator.Update == op)
45             {
46                 const int capacity = 2; // "average" number of identifiers per row
47                 m_inputIdentifiers = new List<KeyValuePair<int ,DbSetClause>>(capacity);
48
49                 foreach (KeyValuePair<EdmMember, PropagatorResult> member in
50                     Helper.PairEnumerations(TypeHelpers.GetAllStructuralMembers(this.CurrentValues.StructuralType),
51                                              this.CurrentValues.GetMemberValues()))
52                 {
53                     DbSetClause setter;
54                     int identifier = member.Value.Identifier;
55
56                     if (PropagatorResult.NullIdentifier != identifier &&
57                         TryGetSetterExpression(tree, member.Key, op, out setter)) // can find corresponding setter
58                     {
59                         foreach (int principal in translator.KeyManager.GetPrincipals(identifier))
60                         {
61                             m_inputIdentifiers.Add(new KeyValuePair<int, DbSetClause>(principal, setter));
62                         }
63                     }
64                 }
65             }
66         }
67
68         // effects: try to find setter expression for the given member
69         // requires: command tree must be an insert or update tree (since other DML trees hnabve 
70         private static bool TryGetSetterExpression(DbModificationCommandTree tree, EdmMember member, ModificationOperator op, out DbSetClause setter)
71         {
72             Debug.Assert(op == ModificationOperator.Insert || op == ModificationOperator.Update, "only inserts and updates have setters");
73             IEnumerable<DbModificationClause> clauses;
74             if (ModificationOperator.Insert == op)
75             {
76                 clauses = ((DbInsertCommandTree)tree).SetClauses;
77             }
78             else
79             {
80                 clauses = ((DbUpdateCommandTree)tree).SetClauses;
81             }
82             foreach (DbSetClause setClause in clauses)
83             {
84                 // check if this is the correct setter
85                 if (((DbPropertyExpression)setClause.Property).Property.EdmEquals(member))
86                 {
87                     setter = setClause;
88                     return true;
89                 }
90             }
91
92             // no match found
93             setter = null;
94             return false;
95         }
96
97         internal override long Execute(UpdateTranslator translator, EntityConnection connection, Dictionary<int, object> identifierValues, List<KeyValuePair<PropagatorResult, object>> generatedValues)
98         {
99             // Compile command
100             using (DbCommand command = this.CreateCommand(translator, identifierValues))
101             {
102                 // configure command to use the connection and transaction for this session
103                 command.Transaction = ((null != connection.CurrentTransaction) ? connection.CurrentTransaction.StoreTransaction : null);
104                 command.Connection = connection.StoreConnection;
105                 if (translator.CommandTimeout.HasValue)
106                 {
107                     command.CommandTimeout = translator.CommandTimeout.Value;
108                 }
109
110                 // Execute the query
111                 int rowsAffected;
112                 if (m_modificationCommandTree.HasReader)
113                 {
114                     // retrieve server gen results
115                     rowsAffected = 0;
116                     using (DbDataReader reader = command.ExecuteReader(CommandBehavior.SequentialAccess))
117                     {
118                         if (reader.Read())
119                         {
120                             rowsAffected++;
121
122                             IBaseList<EdmMember> members = TypeHelpers.GetAllStructuralMembers(this.CurrentValues.StructuralType);
123
124                             for (int ordinal = 0; ordinal < reader.FieldCount; ordinal++)
125                             {
126                                 // column name of result corresponds to column name of table
127                                 string columnName = reader.GetName(ordinal);
128                                 EdmMember member = members[columnName];
129                                 object value;
130                                 if (Helper.IsSpatialType(member.TypeUsage) && !reader.IsDBNull(ordinal))
131                                 {
132                                     value = SpatialHelpers.GetSpatialValue(translator.MetadataWorkspace, reader, member.TypeUsage, ordinal);
133                                 }
134                                 else
135                                 {
136                                     value = reader.GetValue(ordinal);
137                                 }
138
139                                 // retrieve result which includes the context for back-propagation
140                                 int columnOrdinal = members.IndexOf(member);
141                                 PropagatorResult result = this.CurrentValues.GetMemberValue(columnOrdinal);
142
143                                 // register for back-propagation
144                                 generatedValues.Add(new KeyValuePair<PropagatorResult, object>(result, value));
145
146                                 // register identifier if it exists
147                                 int identifier = result.Identifier;
148                                 if (PropagatorResult.NullIdentifier != identifier)
149                                 {
150                                     identifierValues.Add(identifier, value);
151                                 }
152                             }
153                         }
154
155                         // Consume the current reader (and subsequent result sets) so that any errors
156                         // executing the command can be intercepted
157                         CommandHelper.ConsumeReader(reader);
158                     }
159                 }
160                 else
161                 {
162                     rowsAffected = command.ExecuteNonQuery();
163                 }
164
165                 return rowsAffected;
166             }
167         }
168
169         /// <summary>
170         /// Gets DB command definition encapsulating store logic for this command.
171         /// </summary>
172         private DbCommand CreateCommand(UpdateTranslator translator, Dictionary<int, object> identifierValues)
173         {
174             DbModificationCommandTree commandTree = m_modificationCommandTree;
175
176             // check if any server gen identifiers need to be set
177             if (null != m_inputIdentifiers)
178             {
179                 Dictionary<DbSetClause, DbSetClause> modifiedClauses = new Dictionary<DbSetClause, DbSetClause>();
180                 for (int idx = 0; idx < m_inputIdentifiers.Count; idx++)
181                 {
182                     KeyValuePair<int, DbSetClause> inputIdentifier = m_inputIdentifiers[idx];
183
184                     object value;
185                     if (identifierValues.TryGetValue(inputIdentifier.Key, out value))
186                     {
187                         // reset the value of the identifier
188                         DbSetClause newClause = new DbSetClause(inputIdentifier.Value.Property, DbExpressionBuilder.Constant(value));
189                         modifiedClauses[inputIdentifier.Value] = newClause;
190                         m_inputIdentifiers[idx] = new KeyValuePair<int, DbSetClause>(inputIdentifier.Key, newClause);
191                     }
192                 }
193                 commandTree = RebuildCommandTree(commandTree, modifiedClauses);
194             }
195
196             return translator.CreateCommand(commandTree);
197         }
198
199         private DbModificationCommandTree RebuildCommandTree(DbModificationCommandTree originalTree, Dictionary<DbSetClause, DbSetClause> clauseMappings)
200         {
201             if (clauseMappings.Count == 0)
202             {
203                 return originalTree;
204             }
205
206             DbModificationCommandTree result;
207             Debug.Assert(originalTree.CommandTreeKind == DbCommandTreeKind.Insert || originalTree.CommandTreeKind == DbCommandTreeKind.Update, "Set clauses specified for a modification tree that is not an update or insert tree?");
208             if (originalTree.CommandTreeKind == DbCommandTreeKind.Insert)
209             {
210                 DbInsertCommandTree insertTree = (DbInsertCommandTree)originalTree;
211                 result = new DbInsertCommandTree(insertTree.MetadataWorkspace, insertTree.DataSpace, 
212                     insertTree.Target, ReplaceClauses(insertTree.SetClauses, clauseMappings).AsReadOnly(), insertTree.Returning);
213             }
214             else
215             {
216                 DbUpdateCommandTree updateTree = (DbUpdateCommandTree)originalTree;
217                 result = new DbUpdateCommandTree(updateTree.MetadataWorkspace, updateTree.DataSpace,
218                     updateTree.Target, updateTree.Predicate, ReplaceClauses(updateTree.SetClauses, clauseMappings).AsReadOnly(), updateTree.Returning);
219             }
220
221             return result;
222         }
223
224         /// <summary>
225         /// Creates a new list of modification clauses with the specified remapped clauses replaced.
226         /// </summary>
227         private List<DbModificationClause> ReplaceClauses(IList<DbModificationClause> originalClauses, Dictionary<DbSetClause, DbSetClause> mappings)
228         {
229             List<DbModificationClause> result = new List<DbModificationClause>(originalClauses.Count);
230             for (int idx = 0; idx < originalClauses.Count; idx++)
231             {
232                 DbSetClause replacementClause;
233                 if (mappings.TryGetValue((DbSetClause)originalClauses[idx], out replacementClause))
234                 {
235                     result.Add(replacementClause);
236                 }
237                 else
238                 {
239                     result.Add(originalClauses[idx]);
240                 }
241             }
242             return result;
243         }
244
245         internal ModificationOperator Operator { get { return m_operator; } }
246
247         internal override EntitySet Table { get { return this.m_processor.Table; } }
248
249         internal override IEnumerable<int> InputIdentifiers 
250         { 
251             get 
252             {
253                 if (null == m_inputIdentifiers)
254                 {
255                     yield break;
256                 }
257                 else
258                 {
259                     foreach (KeyValuePair<int, DbSetClause> inputIdentifier in m_inputIdentifiers)
260                     {
261                         yield return inputIdentifier.Key;
262                     }
263                 }
264             } 
265         }
266
267         internal override IEnumerable<int> OutputIdentifiers 
268         { 
269             get 
270             { 
271                 if (null == m_outputIdentifiers)
272                 {
273                     return Enumerable.Empty<int>();
274                 }
275                 return m_outputIdentifiers.Keys; 
276             } 
277         }
278
279         internal override UpdateCommandKind Kind
280         {
281             get { return UpdateCommandKind.Dynamic; }
282         }
283
284         internal override IList<IEntityStateEntry> GetStateEntries(UpdateTranslator translator)
285         {
286             List<IEntityStateEntry> stateEntries = new List<IEntityStateEntry>(2);
287             if (null != this.OriginalValues)
288             {
289                 foreach (IEntityStateEntry stateEntry in SourceInterpreter.GetAllStateEntries(
290                     this.OriginalValues, translator, this.Table))
291                 {
292                     stateEntries.Add(stateEntry);
293                 }
294             }
295
296             if (null != this.CurrentValues)
297             {
298                 foreach (IEntityStateEntry stateEntry in SourceInterpreter.GetAllStateEntries(
299                     this.CurrentValues, translator, this.Table))
300                 {
301                     stateEntries.Add(stateEntry);
302                 }
303             }
304             return stateEntries;
305         }
306
307         internal override int CompareToType(UpdateCommand otherCommand)
308         {
309             Debug.Assert(!object.ReferenceEquals(this, otherCommand), "caller is supposed to ensure otherCommand is different reference");
310
311             DynamicUpdateCommand other = (DynamicUpdateCommand)otherCommand;
312
313             // order by operation type
314             int result = (int)this.Operator - (int)other.Operator;
315             if (0 != result) { return result; }
316
317             // order by Container.Table
318             result = StringComparer.Ordinal.Compare(this.m_processor.Table.Name, other.m_processor.Table.Name);
319             if (0 != result) { return result; }
320             result = StringComparer.Ordinal.Compare(this.m_processor.Table.EntityContainer.Name, other.m_processor.Table.EntityContainer.Name);
321             if (0 != result) { return result; }
322             
323             // order by table key
324             PropagatorResult thisResult = (this.Operator == ModificationOperator.Delete ? this.OriginalValues : this.CurrentValues);
325             PropagatorResult otherResult = (other.Operator == ModificationOperator.Delete ? other.OriginalValues : other.CurrentValues);
326             for (int i = 0; i < m_processor.KeyOrdinals.Length; i++)
327             {
328                 int keyOrdinal = m_processor.KeyOrdinals[i];
329                 object thisValue = thisResult.GetMemberValue(keyOrdinal).GetSimpleValue();
330                 object otherValue = otherResult.GetMemberValue(keyOrdinal).GetSimpleValue();
331                 result = ByValueComparer.Default.Compare(thisValue, otherValue);
332                 if (0 != result) { return result; }
333             }
334
335             // If the result is still zero, it means key values are all the same. Switch to synthetic identifiers
336             // to differentiate.
337             for (int i = 0; i < m_processor.KeyOrdinals.Length; i++)
338             {
339                 int keyOrdinal = m_processor.KeyOrdinals[i];
340                 int thisValue = thisResult.GetMemberValue(keyOrdinal).Identifier;
341                 int otherValue = otherResult.GetMemberValue(keyOrdinal).Identifier;
342                 result = thisValue - otherValue;
343                 if (0 != result) { return result; }
344             }
345
346             return result;
347         }
348     }
349 }