4c772c2aa74480e7b792809c9e9366717197a184
[mono.git] / mcs / class / System.Data / System.Data.SqlClient / SqlBulkCopy.cs
1 //
2 // System.Data.SqlClient.SqlBulkCopy.cs
3 //
4 // Author:
5 //   Nagappan A (anagappan@novell.com)
6 //
7 // (C) Novell, Inc 2007
8
9 //
10 // Copyright (C) 2007 Novell, Inc (http://www.novell.com)
11 //
12 // Permission is hereby granted, free of charge, to any person obtaining
13 // a copy of this software and associated documentation files (the
14 // "Software"), to deal in the Software without restriction, including
15 // without limitation the rights to use, copy, modify, merge, publish,
16 // distribute, sublicense, and/or sell copies of the Software, and to
17 // permit persons to whom the Software is furnished to do so, subject to
18 // the following conditions:
19 // 
20 // The above copyright notice and this permission notice shall be
21 // included in all copies or substantial portions of the Software.
22 // 
23 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
27 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
28 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
29 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
30 //
31 #if NET_2_0
32
33 using System;
34 using System.Data;
35 using System.Data.Common;
36 using Mono.Data.Tds;
37 using Mono.Data.Tds.Protocol;
38
39 namespace System.Data.SqlClient {
40         /// <summary>Efficient way to bulk load SQL Server table with several data rows at once</summary>
41         public sealed class SqlBulkCopy : IDisposable 
42         {
43                 #region Constants
44                 private const string transConflictMessage = "Must not specify SqlBulkCopyOptions.UseInternalTransaction " +
45                         "and pass an external Transaction at the same time.";
46                 #endregion
47                 
48                 #region Fields
49
50                 private int _batchSize = 0;
51                 private int _notifyAfter = 0;
52                 private int _bulkCopyTimeout = 0;
53                 private SqlBulkCopyColumnMappingCollection _columnMappingCollection = new SqlBulkCopyColumnMappingCollection ();
54                 private string _destinationTableName = null;
55                 private bool ordinalMapping = false;
56                 private bool sqlRowsCopied = false;
57                 private bool identityInsert = false;
58                 private bool isLocalConnection = false;
59                 private SqlConnection connection;
60                 private SqlTransaction externalTransaction;
61                 private SqlBulkCopyOptions copyOptions = SqlBulkCopyOptions.Default;
62
63                 #endregion
64
65                 #region Constructors
66                 public SqlBulkCopy (SqlConnection connection)
67                 {
68                         if (connection == null) {
69                                 throw new ArgumentNullException("connection");
70                         }
71                         
72                         this.connection = connection;
73                 }
74
75                 public SqlBulkCopy (string connectionString)
76                 {
77                         if (connectionString == null) {
78                                 throw new ArgumentNullException("connectionString");
79                         }
80                         
81                         this.connection = new SqlConnection (connectionString);
82                         isLocalConnection = true;
83                 }
84
85                 [MonoTODO]
86                 public SqlBulkCopy (string connectionString, SqlBulkCopyOptions copyOptions)
87                 {
88                         if (connectionString == null) {
89                                 throw new ArgumentNullException ("connectionString");
90                         }
91                         
92                         this.connection = new SqlConnection (connectionString);
93                         isLocalConnection = true;
94                         
95                         switch (copyOptions) {
96                         case SqlBulkCopyOptions.Default:
97                                 this.copyOptions = copyOptions;
98                                 break;
99                                 
100                         default:
101                                 throw new NotImplementedException ("We don't know how to process non-default copyOptions.");
102                         }
103                 }
104
105                 [MonoTODO]
106                 public SqlBulkCopy (SqlConnection connection, SqlBulkCopyOptions copyOptions, SqlTransaction externalTransaction)
107                 {
108                         if (connection == null) {
109                                 throw new ArgumentNullException ("connection");
110                         }
111                         
112                         this.connection = connection;
113                         this.copyOptions = copyOptions;
114                         
115                         if ((copyOptions & SqlBulkCopyOptions.UseInternalTransaction) == SqlBulkCopyOptions.UseInternalTransaction) {
116                                 if (externalTransaction != null)
117                                         throw new ArgumentException (transConflictMessage);
118                         }
119                         else
120                                 this.externalTransaction = externalTransaction;
121                         
122                         switch (copyOptions) {
123                         case SqlBulkCopyOptions.Default:
124                                 this.copyOptions = copyOptions;
125                                 break;
126                                 
127                         default:
128                                 throw new NotImplementedException ("We don't know how to process non-default copyOptions.");
129                         }
130                 }
131
132                 #endregion
133
134                 #region Properties
135
136                 public int BatchSize {
137                         get { return _batchSize; }
138                         set { _batchSize = value; }
139                 }
140
141                 public int BulkCopyTimeout {
142                         get { return _bulkCopyTimeout; }
143                         set { _bulkCopyTimeout = value; }
144                 }
145
146                 public SqlBulkCopyColumnMappingCollection ColumnMappings  {
147                         get { return _columnMappingCollection; }
148                 }
149
150                 public string DestinationTableName {
151                         get { return _destinationTableName; }
152                         set { _destinationTableName = value; }
153                 }
154
155                 public int NotifyAfter {
156                         get { return _notifyAfter; }
157                         set {
158                                 if (value < 0)
159                                         throw new ArgumentOutOfRangeException ("NotifyAfter should be greater than or equal to 0");
160                                 _notifyAfter = value;
161                         }
162                 }
163
164                 #endregion
165
166                 #region Methods
167
168                 public void Close ()
169                 {
170                         if (sqlRowsCopied == true) {
171                                 throw new InvalidOperationException ("Close should not be called from SqlRowsCopied event");
172                         }
173                         if (connection == null || connection.State == ConnectionState.Closed) {
174                                 return;
175                         }
176                         connection.Close ();
177                 }
178
179                 private DataTable [] GetColumnMetaData ()
180                 {
181                         DataTable [] columnMetaDataTables = new DataTable [2];
182                         SqlCommand cmd = new SqlCommand ("select @@trancount; " +
183                                                          "set fmtonly on select * from " +
184                                                          DestinationTableName + " set fmtonly off;" +
185                                                          "exec sp_tablecollations_90 '" +
186                                                          DestinationTableName + "'",
187                                                          connection);
188                         SqlDataReader reader = cmd.ExecuteReader ();
189                         int i = 0; // Skipping 1st result
190                         do {
191                                   if (i == 1) {
192                                         columnMetaDataTables [i - 1] = reader.GetSchemaTable ();
193                                   } else if (i == 2) {
194                                         SqlDataAdapter adapter = new SqlDataAdapter ();
195                                         adapter.MissingSchemaAction = MissingSchemaAction.AddWithKey;
196                                         columnMetaDataTables [i - 1] = new DataTable ();
197                                         adapter.FillInternal (columnMetaDataTables [i - 1], reader);
198                                 }
199                                 i++;
200                         } while (reader.IsClosed == false && reader.NextResult());
201                         reader.Close ();
202                         return columnMetaDataTables;
203                 }
204
205                 private string GenerateColumnMetaData (SqlCommand tmpCmd, DataTable colMetaData, DataTable tableCollations)
206                 {
207                         bool flag = false;
208                         string statement = "";
209                         int i = 0;
210                         foreach (DataRow row in colMetaData.Rows) {
211                                 flag = false;
212                                 foreach (DataColumn col in colMetaData.Columns) { // FIXME: This line not required, remove later
213                                         object value = null;
214                                         if (_columnMappingCollection.Count > 0) {
215                                                 if (ordinalMapping) {
216                                                         foreach (SqlBulkCopyColumnMapping mapping
217                                                                  in _columnMappingCollection) {
218                                                                 if (mapping.DestinationOrdinal == i) {
219                                                                         flag = true;
220                                                                         break;
221                                                                 }
222                                                         }
223                                                 } else {
224                                                         foreach (SqlBulkCopyColumnMapping mapping
225                                                                  in _columnMappingCollection) {
226                                                                 if (mapping.DestinationColumn == (string) row ["ColumnName"]) {
227                                                                         flag = true;
228                                                                         break;
229                                                                 }
230                                                         }
231                                                 }
232                                                 if (flag == false)
233                                                         break;
234                                         }
235                                         if ((bool)row ["IsReadOnly"]) {
236                                                 if (ordinalMapping)
237                                                         value = false;
238                                                 else
239                                                         break;
240                                         }
241                                         SqlParameter param = new SqlParameter ((string) row ["ColumnName"],
242                                                                                ((SqlDbType) row ["ProviderType"]));
243                                         param.Value = value;
244                                         if ((int)row ["ColumnSize"] != -1) {
245                                                 param.Size = (int) row ["ColumnSize"];
246                                         }
247                                         tmpCmd.Parameters.Add (param);
248                                         break;
249                                 }
250                                 i++;
251                         }
252                         flag = false;
253                         bool insertSt = false;
254                         foreach (DataRow row in colMetaData.Rows) {
255                                 if (_columnMappingCollection.Count > 0) {
256                                         i = 0;
257                                         insertSt = false;
258                                         foreach (SqlParameter param in tmpCmd.Parameters) {
259                                                 if (ordinalMapping) {
260                                                         foreach (SqlBulkCopyColumnMapping mapping
261                                                                  in _columnMappingCollection) {
262                                                                 if (mapping.DestinationOrdinal == i && param.Value == null) {
263                                                                         insertSt = true;
264                                                                 }
265                                                         }
266                                                 } else {
267                                                         foreach (SqlBulkCopyColumnMapping mapping
268                                                                  in _columnMappingCollection) {
269                                                                 if (mapping.DestinationColumn == param.ParameterName &&
270                                                                     (string)row ["ColumnName"] == param.ParameterName) {
271                                                                         insertSt = true;
272                                                                         param.Value = null;
273                                                                 }
274                                                         }
275                                                 }
276                                                 i++;
277                                                 if (insertSt == true)
278                                                         break;
279                                         }
280                                         if (insertSt == false)
281                                                 continue;
282                                 }
283                                 if ((bool)row ["IsReadOnly"]) {
284                                         continue;
285                                 }
286                                 string columnInfo = "";
287                                 if ((int)row ["ColumnSize"] != -1) {
288                                         columnInfo = string.Format ("{0}({1})",
289                                                                     (SqlDbType) row ["ProviderType"],
290                                                                     row ["ColumnSize"]);
291                                 } else {
292                                         columnInfo = string.Format ("{0}", (SqlDbType) row ["ProviderType"]);
293                                 }
294                                 if (flag)
295                                         statement += ", ";
296                                 string columnName = (string) row ["ColumnName"];
297                                 statement += string.Format ("[{0}] {1}", columnName, columnInfo);
298                                 if (flag == false)
299                                         flag = true;
300                                 if (tableCollations != null) {
301                                         foreach (DataRow collationRow in tableCollations.Rows) {
302                                                 if ((string)collationRow ["name"] == columnName) {
303                                                         statement += string.Format (" COLLATE {0}", collationRow ["collation"]);
304                                                         break;
305                                                 }
306                                         }
307                                 }
308                         }
309                         return statement;
310                 }
311
312                 private void ValidateColumnMapping (DataTable table, DataTable tableCollations)
313                 {
314                         foreach (SqlBulkCopyColumnMapping _columnMapping in _columnMappingCollection) {
315                                 if (ordinalMapping == false &&
316                                     (_columnMapping.DestinationColumn == String.Empty ||
317                                      _columnMapping.SourceColumn == String.Empty))
318                                         throw new InvalidOperationException ("Mappings must be either all null or ordinal");
319                                 if (ordinalMapping &&
320                                     (_columnMapping.DestinationOrdinal == -1 ||
321                                      _columnMapping.SourceOrdinal == -1))
322                                         throw new InvalidOperationException ("Mappings must be either all null or ordinal");
323                                 bool flag = false;
324                                 if (ordinalMapping == false) {
325                                         foreach (DataRow row in tableCollations.Rows) {
326                                                 if ((string)row ["name"] == _columnMapping.DestinationColumn) {
327                                                         flag = true;
328                                                         break;
329                                                 }
330                                         }
331                                         if (flag == false)
332                                                 throw new InvalidOperationException ("ColumnMapping does not match");
333                                         flag = false;
334                                         foreach (DataColumn col in table.Columns) {
335                                                 if (col.ColumnName == _columnMapping.SourceColumn) {
336                                                         flag = true;
337                                                         break;
338                                                 }
339                                         }
340                                         if (flag == false)
341                                                 throw new InvalidOperationException ("ColumnName " +
342                                                                                      _columnMapping.SourceColumn +
343                                                                                      " does not match");
344                                 } else {
345                                         if (_columnMapping.DestinationOrdinal >= tableCollations.Rows.Count)
346                                                 throw new InvalidOperationException ("ColumnMapping does not match");
347                                 }
348                         }
349                 }
350
351                 private void BulkCopyToServer (DataTable table, DataRowState state)
352                 {
353                         if (connection == null || connection.State == ConnectionState.Closed)
354                                 throw new InvalidOperationException ("This method should not be called on a closed connection");
355                         if (_destinationTableName == null)
356                                 throw new ArgumentNullException ("DestinationTableName");
357                         if (identityInsert) {
358                                 SqlCommand cmd = new SqlCommand ("set identity_insert " +
359                                                                  table.TableName + " on",
360                                                                  connection);
361                                 cmd.ExecuteScalar ();
362                         }
363                         DataTable [] columnMetaDataTables = GetColumnMetaData ();
364                         DataTable colMetaData = columnMetaDataTables [0];
365                         DataTable tableCollations = columnMetaDataTables [1];
366
367                         if (_columnMappingCollection.Count > 0) {
368                                 if (_columnMappingCollection [0].SourceOrdinal != -1)
369                                         ordinalMapping = true;
370                                 ValidateColumnMapping (table, tableCollations);
371                         }
372
373                         SqlCommand tmpCmd = new SqlCommand ();
374                         TdsBulkCopy blkCopy = new TdsBulkCopy ((Tds)connection.Tds);
375                         if (((Tds)connection.Tds).TdsVersion >= TdsVersion.tds70) {
376                                 string statement = "insert bulk " + DestinationTableName + " (";
377                                 statement += GenerateColumnMetaData (tmpCmd, colMetaData, tableCollations);
378                                 statement += ")";
379                                 blkCopy.SendColumnMetaData (statement);
380                         }
381                         blkCopy.BulkCopyStart (tmpCmd.Parameters.MetaParameters);
382                         long noRowsCopied = 0;
383                         foreach (DataRow row in table.Rows) {
384                                 if (row.RowState == DataRowState.Deleted)
385                                         continue; // Don't copy the row that's in deleted state
386                                 if (state != 0 && row.RowState != state)
387                                         continue;
388                                 bool isNewRow = true;
389                                 int i = 0;
390                                 foreach (SqlParameter param in tmpCmd.Parameters) {
391                                         int size = 0;
392                                         object rowToCopy = null;
393                                         if (_columnMappingCollection.Count > 0) {
394                                                 if (ordinalMapping) {
395                                                         foreach (SqlBulkCopyColumnMapping mapping
396                                                                  in _columnMappingCollection) {
397                                                                 if (mapping.DestinationOrdinal == i && param.Value == null) {
398                                                                         rowToCopy = row [mapping.SourceOrdinal];
399                                                                         SqlParameter parameter = new SqlParameter (mapping.SourceOrdinal.ToString (),
400                                                                                                                    rowToCopy);
401                                                                         if (param.MetaParameter.TypeName != parameter.MetaParameter.TypeName) {
402                                                                                 parameter.SqlDbType = param.SqlDbType;
403                                                                                 rowToCopy = parameter.Value = parameter.ConvertToFrameworkType (rowToCopy);
404                                                                         }
405                                                                         string colType = string.Format ("{0}", parameter.MetaParameter.TypeName);
406                                                                         if (colType == "nvarchar") {
407                                                                                 if (row [i] != null) {
408                                                                                         size = ((string) parameter.Value).Length;
409                                                                                         size <<= 1;
410                                                                                 }
411                                                                         } else {
412                                                                                 size = parameter.Size;
413                                                                         }
414                                                                         break;
415                                                                 }
416                                                         }
417                                                 } else {
418                                                         foreach (SqlBulkCopyColumnMapping mapping
419                                                                  in _columnMappingCollection) {
420                                                                 if (mapping.DestinationColumn == param.ParameterName) {
421                                                                         rowToCopy = row [mapping.SourceColumn];
422                                                                         SqlParameter parameter = new SqlParameter (mapping.SourceColumn, rowToCopy);
423                                                                         if (param.MetaParameter.TypeName != parameter.MetaParameter.TypeName) {
424                                                                                 parameter.SqlDbType = param.SqlDbType;
425                                                                                 rowToCopy = parameter.Value = parameter.ConvertToFrameworkType (rowToCopy);
426                                                                         }
427                                                                         string colType = string.Format ("{0}", parameter.MetaParameter.TypeName);
428                                                                         if (colType == "nvarchar") {
429                                                                                 if (row [mapping.SourceColumn] != null) {
430                                                                                         size = ((string) rowToCopy).Length;
431                                                                                         size <<= 1;
432                                                                                 }
433                                                                         } else {
434                                                                                 size = parameter.Size;
435                                                                         }
436                                                                         break;
437                                                                 }
438                                                         }
439                                                 }
440                                                 i++;
441                                         } else {
442                                                 rowToCopy = row [param.ParameterName];
443                                                 string colType = param.MetaParameter.TypeName;
444                                                 /*
445                                                   If column type is SqlDbType.NVarChar the size of parameter is multiplied by 2
446                                                   FIXME: Need to check for other types
447                                                 */
448                                                 if (colType == "nvarchar") {
449                                                         size = ((string) row [param.ParameterName]).Length;
450                                                         size <<= 1;
451                                                 } else {
452                                                         size = param.Size;
453                                                 }
454                                         }
455                                         if (rowToCopy == null)
456                                                 continue;
457                                         blkCopy.BulkCopyData (rowToCopy, size, isNewRow);
458                                         if (isNewRow)
459                                                 isNewRow = false;
460                                 } // foreach (SqlParameter)
461                                 if (_notifyAfter > 0) {
462                                         noRowsCopied ++;
463                                         if (noRowsCopied >= _notifyAfter) {
464                                                 RowsCopied (noRowsCopied);
465                                                 noRowsCopied = 0;
466                                         }
467                                 }
468                         } // foreach (DataRow)
469                         blkCopy.BulkCopyEnd ();
470                 }
471
472                 public void WriteToServer (DataRow [] rows)
473                 {
474                         if (rows == null)
475                                 throw new ArgumentNullException ("rows");
476                         DataTable table = new DataTable (rows [0].Table.TableName);
477                         foreach (DataColumn col in rows [0].Table.Columns) {
478                                 DataColumn tmpCol = new DataColumn (col.ColumnName, col.DataType);
479                                 table.Columns.Add (tmpCol);
480                         }
481                         foreach (DataRow row in rows) {
482                                 DataRow tmpRow = table.NewRow ();
483                                 for (int i = 0; i < table.Columns.Count; i++) {
484                                         tmpRow [i] = row [i];
485                                 }
486                                 table.Rows.Add (tmpRow);
487                         }
488                         BulkCopyToServer (table, 0);
489                 }
490
491                 public void WriteToServer (DataTable table)
492                 {
493                         BulkCopyToServer (table, 0);
494                 }
495
496                 public void WriteToServer (IDataReader reader)
497                 {
498                         DataTable table = new DataTable ();
499                         SqlDataAdapter adapter = new SqlDataAdapter ();
500                         adapter.FillInternal (table, reader);
501                         BulkCopyToServer (table, 0);
502                 }
503
504                 public void WriteToServer (DataTable table, DataRowState rowState)
505                 {
506                         BulkCopyToServer (table, rowState);
507                 }
508
509                 private void RowsCopied (long rowsCopied)
510                 {
511                         SqlRowsCopiedEventArgs e = new SqlRowsCopiedEventArgs (rowsCopied);
512                         if (null != SqlRowsCopied) {
513                                 SqlRowsCopied (this, e);
514                         }
515                 }
516
517                 #endregion
518
519                 #region Events
520
521                 public event SqlRowsCopiedEventHandler SqlRowsCopied;
522
523                 #endregion
524
525                 void IDisposable.Dispose ()
526                 {
527                         //throw new NotImplementedException ();
528                         if (isLocalConnection) {
529                                 Close ();
530                                 connection = null;
531                         }
532                 }
533
534         }
535 }
536
537 #endif