* SqlParameter.cs: Modified ConvertToFrameworkType to only perform
[mono.git] / mcs / class / System.Data / System.Data.SqlClient / SqlParameter.cs
index 8282b5c609399b8b42b2f4b980d0bd220247eccd..6774c0ddbebf84410b7ba1ea4ddf6cc9b257b2bb 100644 (file)
@@ -8,13 +8,14 @@
 //   Diego Caravana (diego@toth.it)
 //   Umadevi S (sumadevi@novell.com)
 //   Amit Biswas (amit@amitbiswas.com)
+//   Veerapuram Varadhan (vvaradhan@novell.com)
 //
 // (C) Ximian, Inc. 2002
 // Copyright (C) Tim Coleman, 2002
 //
 
 //
-// Copyright (C) 2004 Novell, Inc (http://www.novell.com)
+// Copyright (C) 2004, 2008, 2009 Novell, Inc (http://www.novell.com)
 //
 // Permission is hereby granted, free of charge, to any person obtaining
 // a copy of this software and associated documentation files (the
@@ -44,8 +45,10 @@ using System.ComponentModel;
 using System.Data;
 using System.Data.Common;
 using System.Data.SqlTypes;
+using System.Globalization;
 using System.Runtime.InteropServices;
 using System.Text;
+using System.Xml;
 
 namespace System.Data.SqlClient {
 #if NET_2_0
@@ -70,13 +73,15 @@ namespace System.Data.SqlClient {
                DataRowVersion sourceVersion;
                SqlCompareOptions compareInfo;
                int localeId;
-               Object sqlValue;
+               Type sqlType;
                bool typeChanged;
 #if NET_2_0
                bool sourceColumnNullMapping;
                string xmlSchemaCollectionDatabase = String.Empty;
                string xmlSchemaCollectionOwningSchema = String.Empty;
                string xmlSchemaCollectionName = String.Empty;
+#else
+               static Hashtable DbTypeMapping;
 #endif
 
                static Hashtable type_mapping;
@@ -85,15 +90,52 @@ namespace System.Data.SqlClient {
 
                #region Constructors
 
+               
                static SqlParameter ()
                {
+                       if (DbTypeMapping == null)
+                               DbTypeMapping = new Hashtable ();
+                       
+                       DbTypeMapping.Add (SqlDbType.BigInt, typeof (long));
+                       DbTypeMapping.Add (SqlDbType.Bit, typeof (bool));
+                       DbTypeMapping.Add (SqlDbType.Char, typeof (string));
+                       DbTypeMapping.Add (SqlDbType.NChar, typeof (string));
+                       DbTypeMapping.Add (SqlDbType.Text, typeof (string));
+                       DbTypeMapping.Add (SqlDbType.NText, typeof (string));
+                       DbTypeMapping.Add (SqlDbType.VarChar, typeof (string));
+                       DbTypeMapping.Add (SqlDbType.NVarChar, typeof (string));
+                       DbTypeMapping.Add (SqlDbType.SmallDateTime, typeof (DateTime));
+                       DbTypeMapping.Add (SqlDbType.DateTime, typeof (DateTime));
+                       DbTypeMapping.Add (SqlDbType.Decimal, typeof (decimal));
+                       DbTypeMapping.Add (SqlDbType.Float, typeof (double));
+                       DbTypeMapping.Add (SqlDbType.Binary, typeof (byte []));
+                       DbTypeMapping.Add (SqlDbType.Image, typeof (byte []));
+                       DbTypeMapping.Add (SqlDbType.Money, typeof (decimal));
+                       DbTypeMapping.Add (SqlDbType.SmallMoney, typeof (decimal));
+                       DbTypeMapping.Add (SqlDbType.VarBinary, typeof (byte []));
+                       DbTypeMapping.Add (SqlDbType.TinyInt, typeof (byte));
+                       DbTypeMapping.Add (SqlDbType.Int, typeof (int));
+                       DbTypeMapping.Add (SqlDbType.Real, typeof (float));
+                       DbTypeMapping.Add (SqlDbType.SmallInt, typeof (short));
+                       DbTypeMapping.Add (SqlDbType.UniqueIdentifier, typeof (Guid));
+                       DbTypeMapping.Add (SqlDbType.Variant, typeof (object));
+#if NET_2_0
+                       DbTypeMapping.Add (SqlDbType.Xml, typeof (string));
+#endif
+
                        type_mapping = new Hashtable ();
+
                        type_mapping.Add (typeof (long), SqlDbType.BigInt);
                        type_mapping.Add (typeof (SqlTypes.SqlInt64), SqlDbType.BigInt);
 
                        type_mapping.Add (typeof (bool), SqlDbType.Bit);
                        type_mapping.Add (typeof (SqlTypes.SqlBoolean), SqlDbType.Bit);
 
+#if NET_2_0
+                       type_mapping.Add (typeof (char), SqlDbType.NVarChar);
+                       type_mapping.Add (typeof (char []), SqlDbType.NVarChar);
+                       type_mapping.Add (typeof (SqlTypes.SqlChars), SqlDbType.NVarChar);
+#endif
                        type_mapping.Add (typeof (string), SqlDbType.NVarChar);
                        type_mapping.Add (typeof (SqlTypes.SqlString), SqlDbType.NVarChar);
 
@@ -108,6 +150,9 @@ namespace System.Data.SqlClient {
 
                        type_mapping.Add (typeof (byte []), SqlDbType.VarBinary);
                        type_mapping.Add (typeof (SqlTypes.SqlBinary), SqlDbType.VarBinary);
+#if NET_2_0
+                       type_mapping.Add (typeof (SqlTypes.SqlBytes), SqlDbType.VarBinary);
+#endif
 
                        type_mapping.Add (typeof (byte), SqlDbType.TinyInt);
                        type_mapping.Add (typeof (SqlTypes.SqlByte), SqlDbType.TinyInt);
@@ -126,9 +171,14 @@ namespace System.Data.SqlClient {
 
                        type_mapping.Add (typeof (SqlTypes.SqlMoney), SqlDbType.Money);
 
+#if NET_2_0
+                       type_mapping.Add (typeof (XmlReader), SqlDbType.Xml);
+                       type_mapping.Add (typeof (SqlTypes.SqlXml), SqlDbType.Xml);
+#endif
+
                        type_mapping.Add (typeof (object), SqlDbType.Variant);
                }
-
+               
                public SqlParameter () 
                        : this (String.Empty, SqlDbType.NVarChar, 0, ParameterDirection.Input, false, 0, 0, String.Empty, DataRowVersion.Current, null)
                {
@@ -170,7 +220,7 @@ namespace System.Data.SqlClient {
                                                              isNullable, precision, 
                                                              scale,
                                                              GetFrameworkValue);
-                       metaParameter.RawValue = value;
+                       metaParameter.RawValue =  value;
                        if (dbType != SqlDbType.Variant) 
                                SqlDbType = dbType;
                        Direction = direction;
@@ -217,7 +267,7 @@ namespace System.Data.SqlClient {
                                break;
                        }
 
-                       SetDbTypeName ((string) dbValues [16]);
+                       SqlDbType = (SqlDbType) FrameworkDbTypeFromName ((string) dbValues [16]);
 
                        if (MetaParameter.IsVariableSizeType) {
                                if (dbValues [10] != DBNull.Value)
@@ -454,7 +504,11 @@ namespace System.Data.SqlClient {
                override
 #endif // NET_2_0
                object Value {
-                       get { return metaParameter.RawValue; }
+                       get {
+                               if (sqlType != null)
+                                       return GetSqlValue (metaParameter.RawValue);
+                               return metaParameter.RawValue;
+                       }
                        set {
                                if (!isTypeSet) {
 #if NET_2_0
@@ -464,6 +518,11 @@ namespace System.Data.SqlClient {
                                                InferSqlType (value);
 #endif
                                }
+
+                               if (value is INullable) {
+                                       sqlType = value.GetType ();
+                                       value = SqlTypeToFrameworkType (value);
+                               }
                                metaParameter.RawValue = value;
                        }
                }
@@ -484,11 +543,10 @@ namespace System.Data.SqlClient {
                [Browsable (false)]
                [DesignerSerializationVisibility (DesignerSerializationVisibility.Hidden)]
                public Object SqlValue {
-                       get { return sqlValue; }
+                       get {
+                               return GetSqlValue (metaParameter.RawValue);
+                       }
                        set {
-                               sqlValue = value;
-                               if (value is INullable)
-                                       value = SqlTypeToFrameworkType (value);
                                Value = value;
                        }
                }
@@ -544,6 +602,84 @@ namespace System.Data.SqlClient {
                        SetSqlDbType ((SqlDbType) t);
                }
 
+               // Returns System.Type corresponding to the underlying SqlDbType
+#if NET_2_0
+               internal override
+#endif
+               Type SystemType {
+                       get {
+                               return (Type) DbTypeMapping [sqlDbType];
+                       }
+               }
+
+#if NET_2_0
+               internal override object FrameworkDbType {
+                       get {
+                               return sqlDbType;
+                       }
+                       
+                       set {
+                               object t;
+                               try {
+                                       t = (DbType) DbTypeFromName ((string)value);
+                                       SetDbType ((DbType)t);
+                               } catch (ArgumentException) {
+                                       t = (SqlDbType)FrameworkDbTypeFromName ((string)value);
+                                       SetSqlDbType ((SqlDbType) t);
+                               }
+                       }
+               }
+
+               DbType DbTypeFromName (string name)
+               {
+                       switch (name.ToLower ()) {
+                               case "ansistring":
+                                       return DbType.AnsiString;
+                               case "ansistringfixedlength":
+                                       return DbType.AnsiStringFixedLength;
+                               case "binary": 
+                                       return DbType.Binary;
+                               case "boolean":
+                                       return DbType.Boolean;
+                               case "byte":
+                                       return DbType.Byte;
+                               case "currency": 
+                                       return DbType.Currency;
+                               case "date":
+                                       return DbType.Date;
+                               case "datetime": 
+                                       return DbType.DateTime;
+                               case "decimal":
+                                       return DbType.Decimal;
+                               case "double": 
+                                       return DbType.Double;
+                               case "guid": 
+                                       return DbType.Guid;
+                               case "int16": 
+                                       return DbType.Int16;
+                               case "int32": 
+                                       return DbType.Int32;
+                               case "int64": 
+                                       return DbType.Int64;
+                               case "object": 
+                                       return DbType.Object;
+                               case "single": 
+                                       return DbType.Single;
+                               case "string": 
+                                       return DbType.String;
+                               case "stringfixedlength": 
+                                       return DbType.StringFixedLength;
+                               case "time": 
+                                       return DbType.Time;
+                               case "xml": 
+                                       return DbType.Xml;
+                               default:
+                                       string exception = String.Format ("No mapping exists from {0} to a known DbType.", name);
+                                       throw new ArgumentException (exception);
+                       }
+               }
+#endif
+
                // When the DbType is set, we also set the SqlDbType, as well as the SQL Server
                // string representation of the type name.  If the DbType is not convertible
                // to an SqlDbType, throw an exception.
@@ -644,89 +780,63 @@ namespace System.Data.SqlClient {
                }
 
                // Used by internal constructor which has a SQL Server typename
-               private void SetDbTypeName (string dbTypeName)
+               private SqlDbType FrameworkDbTypeFromName (string dbTypeName)
                {
                        switch (dbTypeName.ToLower ()) {        
                        case "bigint":
-                               SqlDbType = SqlDbType.BigInt;
-                               break;
+                               return SqlDbType.BigInt;
                        case "binary":
-                               SqlDbType = SqlDbType.Binary;
-                               break;
+                               return SqlDbType.Binary;
                        case "bit":
-                               SqlDbType = SqlDbType.Bit;
-                               break;
+                               return SqlDbType.Bit;
                        case "char":
-                               SqlDbType = SqlDbType.Char;
-                               break;
+                               return SqlDbType.Char;
                        case "datetime":
-                               SqlDbType = SqlDbType.DateTime;
-                               break;
+                               return SqlDbType.DateTime;
                        case "decimal":
-                               SqlDbType = SqlDbType.Decimal;
-                               break;
+                               return SqlDbType.Decimal;
                        case "float":
-                               SqlDbType = SqlDbType.Float;
-                               break;
+                               return SqlDbType.Float;
                        case "image":
-                               SqlDbType = SqlDbType.Image;
-                               break;
+                               return SqlDbType.Image;
                        case "int":
-                               SqlDbType = SqlDbType.Int;
-                               break;
+                               return SqlDbType.Int;
                        case "money":
-                               SqlDbType = SqlDbType.Money;
-                               break;
+                               return SqlDbType.Money;
                        case "nchar":
-                               SqlDbType = SqlDbType.NChar;
-                               break;
+                               return SqlDbType.NChar;
                        case "ntext":
-                               SqlDbType = SqlDbType.NText;
-                               break;
+                               return SqlDbType.NText;
                        case "nvarchar":
-                               SqlDbType = SqlDbType.NVarChar;
-                               break;
+                               return SqlDbType.NVarChar;
                        case "real":
-                               SqlDbType = SqlDbType.Real;
-                               break;
+                               return SqlDbType.Real;
                        case "smalldatetime":
-                               SqlDbType = SqlDbType.SmallDateTime;
-                               break;
+                               return SqlDbType.SmallDateTime;
                        case "smallint":
-                               SqlDbType = SqlDbType.SmallInt;
-                               break;
+                               return SqlDbType.SmallInt;
                        case "smallmoney":
-                               SqlDbType = SqlDbType.SmallMoney;
-                               break;
+                               return SqlDbType.SmallMoney;
                        case "text":
-                               SqlDbType = SqlDbType.Text;
-                               break;
+                               return SqlDbType.Text;
                        case "timestamp":
-                               SqlDbType = SqlDbType.Timestamp;
-                               break;
+                               return SqlDbType.Timestamp;
                        case "tinyint":
-                               SqlDbType = SqlDbType.TinyInt;
-                               break;
+                               return SqlDbType.TinyInt;
                        case "uniqueidentifier":
-                               SqlDbType = SqlDbType.UniqueIdentifier;
-                               break;
+                               return SqlDbType.UniqueIdentifier;
                        case "varbinary":
-                               SqlDbType = SqlDbType.VarBinary;
-                               break;
+                               return SqlDbType.VarBinary;
                        case "varchar":
-                               SqlDbType = SqlDbType.VarChar;
-                               break;
+                               return SqlDbType.VarChar;
                        case "sql_variant":
-                               SqlDbType = SqlDbType.Variant;
-                               break;
-#if NET_2_0                            
+                               return SqlDbType.Variant;
+#if NET_2_0
                        case "xml":
-                               SqlDbType = SqlDbType.Xml;
-                               break;
+                               return SqlDbType.Xml;
 #endif
                        default:
-                               SqlDbType = SqlDbType.Variant;
-                               break;
+                               return SqlDbType.Variant;
                        }
                }
 
@@ -872,12 +982,106 @@ namespace System.Data.SqlClient {
                                tdsValue = null;
                        return tdsValue;
                }
-
+               
+               // TODO: Code copied from SqlDataReader, need a better approach
+               object GetSqlValue (object value)
+               {               
+                       if (value == null)
+                               return value;
+                       switch (sqlDbType) {
+                       case SqlDbType.BigInt:
+                               if (value == DBNull.Value)
+                                       return SqlInt64.Null;
+                               return (SqlInt64) ((long) value);
+                       case SqlDbType.Binary:
+                       case SqlDbType.Image:
+                       case SqlDbType.VarBinary:
+                       case SqlDbType.Timestamp:
+                               if (value == DBNull.Value)
+                                       return SqlBinary.Null;
+                               return (SqlBinary) (byte[]) value;
+                       case SqlDbType.Bit:
+                               if (value == DBNull.Value)
+                                       return SqlBoolean.Null;
+                               return (SqlBoolean) ((bool) value);
+                       case SqlDbType.Char:
+                       case SqlDbType.NChar:
+                       case SqlDbType.NText:
+                       case SqlDbType.NVarChar:
+                       case SqlDbType.Text:
+                       case SqlDbType.VarChar:
+                               if (value == DBNull.Value)
+                                       return SqlString.Null;
+
+                               string str;
+                               Type type = value.GetType ();
+                               if (type == typeof (char))
+                                       str = value.ToString ();
+                               else if (type == typeof (char[]))
+                                       str = new String ((char[])value);
+                               else
+                                       str = ((string)value);
+                                       return (SqlString) str;
+                       case SqlDbType.DateTime:
+                       case SqlDbType.SmallDateTime:
+                               if (value == DBNull.Value)
+                                       return SqlDateTime.Null;
+                               return (SqlDateTime) ((DateTime) value);
+                       case SqlDbType.Decimal:
+                               if (value == DBNull.Value)
+                                       return SqlDecimal.Null;
+                               if (value is TdsBigDecimal)
+                                       return SqlDecimal.FromTdsBigDecimal ((TdsBigDecimal) value);
+                               return (SqlDecimal) ((decimal) value);
+                       case SqlDbType.Float:
+                               if (value == DBNull.Value)
+                                       return SqlDouble.Null;
+                               return (SqlDouble) ((double) value);
+                       case SqlDbType.Int:
+                               if (value == DBNull.Value)
+                                       return SqlInt32.Null;
+                               return (SqlInt32) ((int) value);
+                       case SqlDbType.Money:
+                       case SqlDbType.SmallMoney:
+                               if (value == DBNull.Value)
+                                       return SqlMoney.Null;
+                               return (SqlMoney) ((decimal) value);
+                       case SqlDbType.Real:
+                               if (value == DBNull.Value)
+                                       return SqlSingle.Null;
+                               return (SqlSingle) ((float) value);
+                       case SqlDbType.UniqueIdentifier:
+                               if (value == DBNull.Value)
+                                       return SqlGuid.Null;
+                               return (SqlGuid) ((Guid) value);
+                       case SqlDbType.SmallInt:
+                               if (value == DBNull.Value)
+                                       return SqlInt16.Null;
+                               return (SqlInt16) ((short) value);
+                       case SqlDbType.TinyInt:
+                               if (value == DBNull.Value)
+                                       return SqlByte.Null;
+                               return (SqlByte) ((byte) value);
+#if NET_2_0
+                       case SqlDbType.Xml:
+                               if (value == DBNull.Value)
+                                       return SqlXml.Null;
+                               return (SqlXml) value;
+#endif
+                       default:
+                               throw new NotImplementedException ("Type '" + sqlDbType + "' not implemented.");
+                       }
+               }
+               
                object SqlTypeToFrameworkType (object value)
                {
-                       if (!(value is INullable)) // if the value is not SqlType
+                       INullable nullable = value as INullable;
+                       if (nullable == null)
                                return ConvertToFrameworkType (value);
 
+                       if (nullable.IsNull)
+                               return DBNull.Value;
+
                        Type type = value.GetType ();
                        // Map to .net type, as Mono TDS respects only types from .net
 
@@ -904,7 +1108,16 @@ namespace System.Data.SqlClient {
                        if (typeof (SqlBinary) == type) {
                                return ((SqlBinary) value).Value;
                        }
+                       
+#if NET_2_0
+                       if (typeof (SqlBytes) == type) {
+                               return ((SqlBytes) value).Value;
+                       }
 
+                       if (typeof (SqlChars) == type) {
+                               return ((SqlChars) value).Value;
+                       }
+#endif
                        if (typeof (SqlBoolean) == type) {
                                return ((SqlBoolean) value).Value;
                        }
@@ -929,10 +1142,6 @@ namespace System.Data.SqlClient {
                                return ((SqlMoney) value).Value;
                        }
 
-                       if (typeof (SqlMoney) == type) {
-                               return ((SqlMoney) value).Value;
-                       }
-
                        if (typeof (SqlSingle) == type) {
                                return ((SqlSingle) value).Value;
                        }
@@ -944,58 +1153,34 @@ namespace System.Data.SqlClient {
                {
                        if (value == null || value == DBNull.Value)
                                return value;
-                       
-                       if (value is string && ((string)value).Length == 0)
-                               return DBNull.Value;
-                       
-                       switch (sqlDbType)  {
-                       case SqlDbType.BigInt :
-                               return Convert.ChangeType (value, typeof (Int64));
-                       case SqlDbType.Binary:
-                       case SqlDbType.Image:
-                       case SqlDbType.VarBinary:
-                               if (value is byte[])
-                                       return value;
-                               break;
-                       case SqlDbType.Bit:
-                               return Convert.ChangeType (value, typeof (bool));
-                       case SqlDbType.Int:
-                               return Convert.ChangeType (value, typeof (Int32));
-                       case SqlDbType.SmallInt :
-                               return Convert.ChangeType (value, typeof (Int16));
-                       case SqlDbType.TinyInt :
-                               return Convert.ChangeType (value, typeof (byte));
-                       case SqlDbType.Float:
-                               return Convert.ChangeType (value, typeof (Double));
-                       case SqlDbType.Real:
-                               return Convert.ChangeType (value, typeof (Single));
-                       case SqlDbType.Decimal:
-                               return Convert.ChangeType (value, typeof (Decimal));
-                       case SqlDbType.Money:
-                       case SqlDbType.SmallMoney:
-                               {
-                                       Decimal val = (Decimal)Convert.ChangeType (value, typeof (Decimal));
-                                       return Decimal.Round(val, 4);
-                               }
-                       case SqlDbType.DateTime:
-                       case SqlDbType.SmallDateTime:
-                               return Convert.ChangeType (value, typeof (DateTime));
-                       case SqlDbType.VarChar:
-                       case SqlDbType.NVarChar:
-                       case SqlDbType.Char:
-                       case SqlDbType.NChar:
-                       case SqlDbType.Text:
-                       case SqlDbType.NText:
-#if NET_2_0
-                       case SqlDbType.Xml:
-#endif
-                               return Convert.ChangeType (value,  typeof (string));
-                       case SqlDbType.UniqueIdentifier:
-                               return Convert.ChangeType (value,  typeof (Guid));
-                       case SqlDbType.Variant:
+                       if (sqlDbType == SqlDbType.Variant)
                                return metaParameter.Value;
+
+                       Type frameworkType = SystemType;
+                       if (frameworkType == null)
+                               throw new NotImplementedException ("Type Not Supported : " + sqlDbType.ToString());
+
+                       Type valueType = value.GetType ();
+                       if (valueType == frameworkType)
+                               return value;
+
+                       object sqlvalue = null;
+
+                       try {
+                               sqlvalue = Convert.ChangeType (value, frameworkType);
+                               switch (sqlDbType) {
+                               case SqlDbType.Money:
+                               case SqlDbType.SmallMoney:
+                                       sqlvalue = Decimal.Round ((decimal) sqlvalue, 4);
+                                       break;
+                               }
+                       } catch (FormatException ex) {
+                               throw new FormatException (string.Format (CultureInfo.InvariantCulture,
+                                       "Parameter value could not be converted from {0} to {1}.",
+                                       valueType.Name, frameworkType.Name), ex);
                        }
-                       throw new  NotImplementedException ("Type Not Supported : " + sqlDbType.ToString());
+
+                       return sqlvalue;
                }
 
 #if NET_2_0