a735b2cba4677006d6a81871c4942c385590c509
[mono.git] / mcs / class / referencesource / System.Data.Entity / System / Data / Mapping / Update / Internal / FunctionUpdateCommand.cs
1 //---------------------------------------------------------------------
2 // <copyright file="FunctionUpdateCommand.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 //
6 // @owner Microsoft
7 // @backupOwner Microsoft
8 //---------------------------------------------------------------------
9
namespace System.Data.Mapping.Update.Internal
{
    using System.Collections.Generic;
    using System.Data.Common;
    using System.Data.Common.Utils;
    using System.Data.EntityClient;
    using System.Data.Metadata.Edm;
    using System.Diagnostics;
    using System.Globalization;
    using System.Linq;
    using System.Data.Spatial;

    /// <summary>
    /// Aggregates information about a modification command delegated to a store function.
    /// </summary>
    /// <remarks>
    /// Besides the store command itself, an instance tracks the parameters that carry key
    /// (identifier) values, the result-set columns that carry server-generated values, and an
    /// optional output parameter reporting the number of rows affected.
    /// </remarks>
    internal sealed class FunctionUpdateCommand : UpdateCommand
    {
        #region Constructors
        /// <summary>
        /// Initializes a new function command and creates its store command from the command
        /// definition that <paramref name="translator"/> generates for <paramref name="functionMapping"/>.
        /// </summary>
        /// <param name="functionMapping">Function mapping metadata.</param>
        /// <param name="translator">Translator used to generate the store command definition.</param>
        /// <param name="stateEntries">State entries handled by this operation; kept for error reporting
        /// and for ordering commands.</param>
        /// <param name="stateEntry">'Root' state entry being handled by this function; its original and
        /// current values are handed to the base command.</param>
        internal FunctionUpdateCommand(StorageModificationFunctionMapping functionMapping, UpdateTranslator translator,
            System.Collections.ObjectModel.ReadOnlyCollection<IEntityStateEntry> stateEntries,
            ExtractedStateEntry stateEntry)
            : base(stateEntry.Original, stateEntry.Current)
        {
            // EntityUtil.CheckArgumentNull rejects null arguments
            EntityUtil.CheckArgumentNull(functionMapping, "functionMapping");
            EntityUtil.CheckArgumentNull(translator, "translator");
            EntityUtil.CheckArgumentNull(stateEntries, "stateEntries");

            // retain the state entries for error reporting (GetStateEntries) and ordering (CompareToType)
            m_stateEntries = stateEntries;

            // create the store command for the mapped function
            DbCommandDefinition commandDefinition = translator.GenerateCommandDefinition(functionMapping);
            m_dbCommand = commandDefinition.CreateCommand();
        }
        #endregion

        #region Fields
        /// <summary>
        /// State entries handled by this command. Returned by <see cref="GetStateEntries"/> for error
        /// reporting; element 0 is treated as the 'main' entry when ordering commands.
        /// </summary>
        private readonly System.Collections.ObjectModel.ReadOnlyCollection<IEntityStateEntry> m_stateEntries;

        /// <summary>
        /// Store command wrapped by this command.
        /// </summary>
        private readonly DbCommand m_dbCommand;

        /// <summary>
        /// Pairs of column names and propagator results, so that reader results can be associated with
        /// the source records for server-generated values. Null until the first result column is added.
        /// </summary>
        private List<KeyValuePair<string, PropagatorResult>> m_resultColumns;

        /// <summary>
        /// Pairs of identifiers (key component proxies) and the parameters holding the actual key values.
        /// Supports propagation of identifier values (fix-up for server-generated keys). Null until the
        /// first identifier parameter is registered.
        /// </summary>
        private List<KeyValuePair<int, DbParameter>> m_inputIdentifiers;

        /// <summary>
        /// Map from identifiers (key component proxies) to the result columns producing their actual
        /// values. Supports propagation of identifier values (fix-up for server-generated keys). Null
        /// until the first output identifier is registered.
        /// </summary>
        private Dictionary<int, string> m_outputIdentifiers;

        /// <summary>
        /// The rows-affected output parameter for the stored procedure. May be null.
        /// </summary>
        private DbParameter m_rowsAffectedParameter;
        #endregion

        #region Properties
        /// <summary>
        /// Identifiers whose values this command consumes through its parameters. Evaluated lazily;
        /// yields nothing when no identifier parameters have been registered. An identifier is yielded
        /// once per registration, so duplicates are possible.
        /// </summary>
        internal override IEnumerable<int> InputIdentifiers
        {
            get
            {
                if (null == m_inputIdentifiers)
                {
                    yield break;
                }
                else
                {
                    foreach (KeyValuePair<int, DbParameter> inputIdentifier in m_inputIdentifiers)
                    {
                        yield return inputIdentifier.Key;
                    }
                }
            }
        }

        /// <summary>
        /// Identifiers whose values this command produces through result-set columns; empty when none
        /// have been registered.
        /// </summary>
        internal override IEnumerable<int> OutputIdentifiers
        {
            get
            {
                if (null == m_outputIdentifiers)
                {
                    return Enumerable.Empty<int>();
                }
                return m_outputIdentifiers.Keys;
            }
        }

        /// <summary>
        /// Always <see cref="UpdateCommandKind.Function"/>.
        /// </summary>
        internal override UpdateCommandKind Kind
        {
            get { return UpdateCommandKind.Function; }
        }
        #endregion

        #region Methods
        /// <summary>
        /// Gets state entries contributing to this function. Supports error reporting.
        /// </summary>
        /// <param name="translator">Unused; part of the base-class contract.</param>
        /// <returns>The state entries supplied at construction.</returns>
        internal override IList<IEntityStateEntry> GetStateEntries(UpdateTranslator translator)
        {
            return m_stateEntries;
        }

        /// <summary>
        /// Sets the value of the existing store parameter named by <paramref name="parameterBinding"/>
        /// from <paramref name="result"/>. If the result carries an identifier, the parameter is also
        /// recorded against each of that identifier's principals so that <see cref="Execute"/> can
        /// substitute values produced by other commands.
        /// </summary>
        /// <param name="result">Propagator result supplying the value (and possibly an identifier).</param>
        /// <param name="parameterBinding">Binding naming the function parameter to populate.</param>
        /// <param name="translator">Translator providing the key manager and parameter-value conversion.</param>
        internal void SetParameterValue(PropagatorResult result, StorageModificationFunctionParameterBinding parameterBinding, UpdateTranslator translator)
        {
            // retrieve the DbParameter created with the command
            DbParameter parameter = this.m_dbCommand.Parameters[parameterBinding.Parameter.Name];
            TypeUsage parameterType = parameterBinding.Parameter.TypeUsage;
            object parameterValue = translator.KeyManager.GetPrincipalValue(result);
            translator.SetParameterValue(parameter, parameterType, parameterValue);

            // if the parameter corresponds to an identifier (key component), remember this fact in case
            // it's important for dependency ordering (e.g., output the identifier before creating it)
            int identifier = result.Identifier;
            if (PropagatorResult.NullIdentifier != identifier)
            {
                const int initialSize = 2; // expect on average fewer than two input identifiers per command
                if (null == m_inputIdentifiers)
                {
                    m_inputIdentifiers = new List<KeyValuePair<int, DbParameter>>(initialSize);
                }
                foreach (int principal in translator.KeyManager.GetPrincipals(identifier))
                {
                    m_inputIdentifiers.Add(new KeyValuePair<int, DbParameter>(principal, parameter));
                }
            }
        }

        /// <summary>
        /// Records the existing store parameter that reports the number of rows affected. A null
        /// <paramref name="rowsAffectedParameter"/> is ignored.
        /// </summary>
        /// <param name="rowsAffectedParameter">Out or InOut function parameter, or null.</param>
        internal void RegisterRowsAffectedParameter(FunctionParameter rowsAffectedParameter)
        {
            if (null != rowsAffectedParameter)
            {
                Debug.Assert(rowsAffectedParameter.Mode == ParameterMode.Out || rowsAffectedParameter.Mode == ParameterMode.InOut,
                    "when loading mapping metadata, we check that the parameter is an out parameter");
                m_rowsAffectedParameter = m_dbCommand.Parameters[rowsAffectedParameter.Name];
            }
        }

        /// <summary>
        /// Adds a result column binding from a column name (in the function's result set) to a
        /// propagator result, which contains the context necessary to back-propagate the value.
        /// If the result is an identifier, the column is also registered as the producer of that
        /// identifier's value.
        /// </summary>
        /// <param name="translator">Translator providing the key manager.</param>
        /// <param name="columnName">Name of the column in the function's result set.</param>
        /// <param name="result">Propagator result receiving the column value.</param>
        /// <remarks>
        /// Throws (Update_GeneratedDependent) when the result's identifier has principals in the key
        /// manager, i.e. when a result column would supply the value of a dependent key member.
        /// </remarks>
        internal void AddResultColumn(UpdateTranslator translator, String columnName, PropagatorResult result)
        {
            const int initializeSize = 2; // expect on average fewer than two result columns per command
            if (null == m_resultColumns)
            {
                m_resultColumns = new List<KeyValuePair<string, PropagatorResult>>(initializeSize);
            }
            m_resultColumns.Add(new KeyValuePair<string, PropagatorResult>(columnName, result));

            int identifier = result.Identifier;
            if (PropagatorResult.NullIdentifier != identifier)
            {
                if (translator.KeyManager.HasPrincipals(identifier))
                {
                    throw EntityUtil.InvalidOperation(System.Data.Entity.Strings.Update_GeneratedDependent(columnName));
                }

                // register output identifier to enable fix-up and dependency tracking
                AddOutputIdentifier(columnName, identifier);
            }
        }

        /// <summary>
        /// Indicates that a column in the command result set (<paramref name="columnName"/>) produces the
        /// value for a key component (<paramref name="identifier"/>). A later registration for the same
        /// identifier replaces the earlier column name.
        /// </summary>
        private void AddOutputIdentifier(String columnName, int identifier)
        {
            const int initialSize = 2; // expect on average fewer than two identifier outputs per command
            if (null == m_outputIdentifiers)
            {
                m_outputIdentifiers = new Dictionary<int, string>(initialSize);
            }
            m_outputIdentifiers[identifier] = columnName;
        }

        /// <summary>
        /// Executes the function command using the connection (and its current transaction, if any).
        /// </summary>
        /// <param name="translator">Translator supplying command timeout and metadata.</param>
        /// <param name="connection">Entity connection whose store connection/transaction are used.</param>
        /// <param name="identifierValues">Values already produced for identifiers by other commands; input
        /// identifier parameters are overwritten from it, and identifiers produced here are added to it.</param>
        /// <param name="generatedValues">Receives every server-generated (result column) value paired with
        /// its propagator result for back-propagation.</param>
        /// <returns>
        /// The rows-affected output parameter value when one is registered (DBNull maps to 0); otherwise
        /// 1 or 0 depending on whether a result row was read when result columns exist, or the
        /// ExecuteNonQuery count when they do not.
        /// </returns>
        internal override long Execute(UpdateTranslator translator, EntityConnection connection, Dictionary<int, object> identifierValues,
            List<KeyValuePair<PropagatorResult, object>> generatedValues)
        {
            // configure command to use the connection and transaction for this session
            m_dbCommand.Transaction = ((null != connection.CurrentTransaction) ? connection.CurrentTransaction.StoreTransaction : null);
            m_dbCommand.Connection = connection.StoreConnection;
            if (translator.CommandTimeout.HasValue)
            {
                m_dbCommand.CommandTimeout = translator.CommandTimeout.Value;
            }

            // set all identifier inputs (to support propagation of identifier values across relationship
            // boundaries)
            if (null != m_inputIdentifiers)
            {
                foreach (KeyValuePair<int, DbParameter> inputIdentifier in m_inputIdentifiers)
                {
                    object value;
                    if (identifierValues.TryGetValue(inputIdentifier.Key, out value))
                    {
                        // set the actual value for the identifier if it has been produced by some
                        // other command
                        inputIdentifier.Value.Value = value;
                    }
                }
            }

            // Execute the query
            long rowsAffected;
            if (null != m_resultColumns)
            {
                // If there are result columns, read the server gen results
                rowsAffected = 0;
                IBaseList<EdmMember> members = TypeHelpers.GetAllStructuralMembers(this.CurrentValues.StructuralType);
                using (DbDataReader reader = m_dbCommand.ExecuteReader(CommandBehavior.SequentialAccess))
                {
                    // Retrieve only the first row from the first result set
                    if (reader.Read())
                    {
                        rowsAffected++;

                        foreach (var resultColumn in m_resultColumns
                            .Select(r => new KeyValuePair<int, PropagatorResult>(GetColumnOrdinal(translator, reader, r.Key), r.Value))
                            .OrderBy(r => r.Key)) // order by column ordinal to avoid breaking SequentialAccess readers
                        {
                            int columnOrdinal = resultColumn.Key;
                            TypeUsage columnType = members[resultColumn.Value.RecordOrdinal].TypeUsage;
                            object value;

                            if (Helper.IsSpatialType(columnType) && !reader.IsDBNull(columnOrdinal))
                            {
                                value = SpatialHelpers.GetSpatialValue(translator.MetadataWorkspace, reader, columnType, columnOrdinal);
                            }
                            else
                            {
                                value = reader.GetValue(columnOrdinal);
                            }

                            // register for back-propagation
                            PropagatorResult result = resultColumn.Value;
                            generatedValues.Add(new KeyValuePair<PropagatorResult, object>(result, value));

                            // register identifier if it exists
                            int identifier = result.Identifier;
                            if (PropagatorResult.NullIdentifier != identifier)
                            {
                                identifierValues.Add(identifier, value);
                            }
                        }
                    }

                    // Consume the current reader (and subsequent result sets) so that any errors
                    // executing the function can be intercepted
                    CommandHelper.ConsumeReader(reader);
                }
            }
            else
            {
                rowsAffected = m_dbCommand.ExecuteNonQuery();
            }

            // if an explicit rows affected parameter exists, use this value instead
            if (null != m_rowsAffectedParameter)
            {
                // by design, negative row counts indicate failure iff. an explicit rows
                // affected parameter is used
                if (DBNull.Value.Equals(m_rowsAffectedParameter.Value))
                {
                    rowsAffected = 0;
                }
                else
                {
                    try
                    {
                        rowsAffected = Convert.ToInt64(m_rowsAffectedParameter.Value, CultureInfo.InvariantCulture);
                    }
                    catch (Exception e)
                    {
                        if (UpdateTranslator.RequiresContext(e))
                        {
                            // wrap the exception; report the type the conversion above actually targets (Int64)
                            throw EntityUtil.Update(System.Data.Entity.Strings.Update_UnableToConvertRowsAffectedParameterToInt32(
                                m_rowsAffectedParameter.ParameterName, typeof(long).FullName), e, this.GetStateEntries(translator));
                        }
                        throw;
                    }
                }
            }

            return rowsAffected;
        }

        /// <summary>
        /// Resolves the ordinal of <paramref name="columnName"/> in <paramref name="reader"/>, translating
        /// the reader's IndexOutOfRangeException for a missing column into an update error
        /// (Update_MissingResultColumn) that carries this command's state entries.
        /// </summary>
        private int GetColumnOrdinal(UpdateTranslator translator, DbDataReader reader, string columnName)
        {
            int columnOrdinal;
            try
            {
                columnOrdinal = reader.GetOrdinal(columnName);
            }
            catch (IndexOutOfRangeException)
            {
                throw EntityUtil.Update(System.Data.Entity.Strings.Update_MissingResultColumn(columnName), null,
                    this.GetStateEntries(translator));
            }
            return columnOrdinal;
        }

        /// <summary>
        /// Gets modification operator corresponding to the given entity state.
        /// </summary>
        private static ModificationOperator GetModificationOperator(EntityState state)
        {
            switch (state)
            {
                case EntityState.Modified:
                case EntityState.Unchanged:
                    // unchanged entities correspond to updates (consider the case where
                    // the entity is not being modified but a collocated relationship is)
                    return ModificationOperator.Update;

                case EntityState.Added:
                    return ModificationOperator.Insert;

                case EntityState.Deleted:
                    return ModificationOperator.Delete;

                default:
                    Debug.Fail("unexpected entity state " + state);
                    return default(ModificationOperator);
            }
        }

        /// <summary>
        /// Orders this command relative to another function command. The criteria are applied in
        /// sequence: modification operator of the main state entry, entity set name, entity container
        /// name, number of input identifiers, input parameter values, and finally input identifier
        /// numbers. Only the sign of the result is meaningful.
        /// </summary>
        internal override int CompareToType(UpdateCommand otherCommand)
        {
            Debug.Assert(!object.ReferenceEquals(this, otherCommand), "caller should ensure other command is different");

            FunctionUpdateCommand other = (FunctionUpdateCommand)otherCommand;

            // the first state entry is treated as the 'main' state entry for the command; presumably the
            // caller building stateEntries places the root entry first (TODO confirm)
            IEntityStateEntry thisParent = this.m_stateEntries[0];
            IEntityStateEntry otherParent = other.m_stateEntries[0];

            // order by operator
            int result = (int)GetModificationOperator(thisParent.State) -
                (int)GetModificationOperator(otherParent.State);
            if (0 != result) { return result; }

            // order by entity set
            result = StringComparer.Ordinal.Compare(thisParent.EntitySet.Name, otherParent.EntitySet.Name);
            if (0 != result) { return result; }
            result = StringComparer.Ordinal.Compare(thisParent.EntitySet.EntityContainer.Name, otherParent.EntitySet.EntityContainer.Name);
            if (0 != result) { return result; }

            // order by key values (counts are non-negative, so the subtraction cannot overflow)
            int thisInputIdentifierCount = (null == this.m_inputIdentifiers ? 0 : this.m_inputIdentifiers.Count);
            int otherInputIdentifierCount = (null == other.m_inputIdentifiers ? 0 : other.m_inputIdentifiers.Count);
            result = thisInputIdentifierCount - otherInputIdentifierCount;
            if (0 != result) { return result; }
            for (int i = 0; i < thisInputIdentifierCount; i++)
            {
                DbParameter thisParameter = this.m_inputIdentifiers[i].Value;
                DbParameter otherParameter = other.m_inputIdentifiers[i].Value;
                result = ByValueComparer.Default.Compare(thisParameter.Value, otherParameter.Value);
                if (0 != result) { return result; }
            }

            // If the result is still zero, it means key values are all the same. Switch to synthetic identifiers
            // to differentiate. CompareTo is used rather than subtraction so arbitrary int identifiers cannot overflow.
            for (int i = 0; i < thisInputIdentifierCount; i++)
            {
                int thisIdentifier = this.m_inputIdentifiers[i].Key;
                int otherIdentifier = other.m_inputIdentifiers[i].Key;
                result = thisIdentifier.CompareTo(otherIdentifier);
                if (0 != result) { return result; }
            }

            return result;
        }

        #endregion
    }
}