2009-05-13 Atsushi Enomoto <atsushi@ximian.com>
[mono.git] / mcs / class / FirebirdSql.Data.Firebird / FirebirdSql.Data.Embedded / XsqldaMarshaler.cs
1 /*
2  *      Firebird ADO.NET Data provider for .NET and Mono 
3  * 
4  *         The contents of this file are subject to the Initial 
5  *         Developer's Public License Version 1.0 (the "License"); 
6  *         you may not use this file except in compliance with the 
7  *         License. You may obtain a copy of the License at 
8  *         http://www.firebirdsql.org/index.php?op=doc&id=idpl
9  *
10  *         Software distributed under the License is distributed on 
11  *         an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, either 
12  *         express or implied. See the License for the specific 
13  *         language governing rights and limitations under the License.
14  * 
15  *      Copyright (c) 2002, 2005 Carlos Guzman Alvarez
16  *      All Rights Reserved.
17  */
18
19 using System;
20 using System.Runtime.InteropServices;
21 using System.Text;
22
23 using FirebirdSql.Data.Common;
24
namespace FirebirdSql.Data.Embedded
{
    /// <summary>
    /// Converts between the managed <c>Descriptor</c> description of a statement's
    /// parameters/columns and the native XSQLDA + XSQLVAR memory block handed to the
    /// Firebird client library.
    /// </summary>
    /// <remarks>
    /// A native block is a single <c>Marshal.AllocHGlobal</c> allocation laid out as
    /// one XSQLDA header followed by <c>sqln</c> consecutive XSQLVAR entries (see
    /// <c>ComputeLength</c>). Each entry may point at its own AllocHGlobal'd
    /// <c>sqldata</c> and <c>sqlind</c> buffers; <c>CleanUpNativeData</c> releases
    /// the whole structure.
    /// </remarks>
    internal sealed class XsqldaMarshaler
    {
        #region Static Fields

        // Created once by the type initializer, which the runtime runs exactly once.
        // The previous check-then-create in GetInstance could build several
        // instances when first called from concurrent threads.
        private static readonly XsqldaMarshaler instance = new XsqldaMarshaler();

        #endregion

        #region Constructors

        private XsqldaMarshaler()
        {
        }

        #endregion

        #region Methods

        /// <summary>
        /// Returns the shared marshaler instance.
        /// </summary>
        public static XsqldaMarshaler GetInstance()
        {
            return XsqldaMarshaler.instance;
        }

        /// <summary>
        /// Frees a native XSQLDA block produced by <c>MarshalManagedToNative</c>,
        /// including the <c>sqldata</c> and <c>sqlind</c> buffers referenced by each
        /// XSQLVAR entry.
        /// </summary>
        /// <param name="pNativeData">
        /// Pointer to the block. It is reset to <see cref="IntPtr.Zero"/> once freed;
        /// a zero pointer is ignored.
        /// </param>
        public void CleanUpNativeData(ref IntPtr pNativeData)
        {
            if (pNativeData == IntPtr.Zero)
            {
                return;
            }

            // The header is read first because its sqln says how many XSQLVAR slots follow.
            XSQLDA xsqlda = (XSQLDA)Marshal.PtrToStructure(pNativeData, typeof(XSQLDA));

            Marshal.DestroyStructure(pNativeData, typeof(XSQLDA));

            for (int i = 0; i < xsqlda.sqln; i++)
            {
                IntPtr pVar = this.GetIntPtr(pNativeData, this.ComputeLength(i));
                XSQLVAR sqlvar = (XSQLVAR)Marshal.PtrToStructure(pVar, typeof(XSQLVAR));

                if (sqlvar.sqldata != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(sqlvar.sqldata);
                }
                if (sqlvar.sqlind != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(sqlvar.sqlind);
                }

                Marshal.DestroyStructure(pVar, typeof(XSQLVAR));
            }

            Marshal.FreeHGlobal(pNativeData);

            pNativeData = IntPtr.Zero;
        }

        /// <summary>
        /// Builds a native XSQLDA block describing <paramref name="descriptor"/>,
        /// including copies of the field values and null indicators.
        /// </summary>
        /// <param name="charset">Charset used to encode field, relation, owner and alias names.</param>
        /// <param name="descriptor">Managed field descriptions and values.</param>
        /// <returns>
        /// Pointer to the new block. The caller owns it and must release it with
        /// <c>CleanUpNativeData</c>.
        /// </returns>
        /// <remarks>
        /// If any step fails, every native buffer allocated so far is freed before the
        /// exception propagates. Previously those buffers leaked.
        /// </remarks>
        public IntPtr MarshalManagedToNative(Charset charset, Descriptor descriptor)
        {
            // Set up XSQLDA structure
            XSQLDA xsqlda = new XSQLDA();

            xsqlda.version = descriptor.Version;
            xsqlda.sqln = descriptor.Count;
            xsqlda.sqld = descriptor.ActualCount;

            XSQLVAR[] xsqlvar = new XSQLVAR[descriptor.Count];

            // Number of entries created so far. Only these can own native buffers
            // that must be released on failure.
            int created = 0;

            try
            {
                for (int i = 0; i < xsqlvar.Length; i++)
                {
                    DbField field = descriptor[i];

                    // Create a new XSQLVAR structure and fill it in place, so that any
                    // pointer stored in it is visible to the cleanup path.
                    xsqlvar[i] = new XSQLVAR();
                    created++;

                    xsqlvar[i].sqltype = field.DataType;
                    xsqlvar[i].sqlscale = field.NumericScale;
                    xsqlvar[i].sqlsubtype = field.SubType;
                    xsqlvar[i].sqllen = field.Length;

                    // Copy the encoded value into its own native buffer
                    byte[] buffer = this.GetBytes(field);
                    if (buffer.Length > 0)
                    {
                        xsqlvar[i].sqldata = Marshal.AllocHGlobal(buffer.Length);
                        Marshal.Copy(buffer, 0, xsqlvar[i].sqldata, buffer.Length);
                    }

                    // Native buffer for the 16-bit null indicator
                    xsqlvar[i].sqlind = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(Int16)));
                    Marshal.WriteInt16(xsqlvar[i].sqlind, field.NullFlag);

                    // Each *_length is set to the encoded name length, not the
                    // fixed buffer size.
                    short length;

                    xsqlvar[i].sqlname = this.GetStringBuffer(charset, field.Name, out length);
                    xsqlvar[i].sqlname_length = length;

                    xsqlvar[i].relname = this.GetStringBuffer(charset, field.Relation, out length);
                    xsqlvar[i].relname_length = length;

                    xsqlvar[i].ownername = this.GetStringBuffer(charset, field.Owner, out length);
                    xsqlvar[i].ownername_length = length;

                    xsqlvar[i].aliasname = this.GetStringBuffer(charset, field.Alias, out length);
                    xsqlvar[i].aliasname_length = length;
                }

                return this.MarshalManagedToNative(xsqlda, xsqlvar);
            }
            catch
            {
                // No native block has taken ownership of the per-variable buffers yet.
                XsqldaMarshaler.FreeVariableBuffers(xsqlvar, created);
                throw;
            }
        }

        /// <summary>
        /// Copies an XSQLDA header and its XSQLVAR entries into one newly allocated
        /// native block.
        /// </summary>
        /// <param name="xsqlda">Header. Its <c>sqln</c> determines how many XSQLVAR slots are allocated.</param>
        /// <param name="xsqlvar">Entries to store. The array may be shorter than <c>sqln</c> but not longer.</param>
        /// <returns>Pointer to the block. The caller must release it with <c>CleanUpNativeData</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="xsqlvar"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="xsqlvar"/> has more entries than <c>xsqlda.sqln</c>. Without
        /// this check, the extra entries would be written past the end of the allocation.
        /// </exception>
        public IntPtr MarshalManagedToNative(XSQLDA xsqlda, XSQLVAR[] xsqlvar)
        {
            if (xsqlvar == null)
            {
                throw new ArgumentNullException("xsqlvar");
            }
            if (xsqlvar.Length > xsqlda.sqln)
            {
                throw new ArgumentException(
                    "The number of XSQLVAR entries exceeds XSQLDA.sqln.", "xsqlvar");
            }

            int size = this.ComputeLength(xsqlda.sqln);
            IntPtr ptr = Marshal.AllocHGlobal(size);

            try
            {
                // AllocHGlobal does not clear memory. Zero it so that any XSQLVAR slot
                // not written below holds null sqldata/sqlind pointers, which
                // CleanUpNativeData skips instead of freeing garbage.
                Marshal.Copy(new byte[size], 0, ptr, size);

                // fDeleteOld must be false: the memory is freshly allocated and holds
                // no previously marshaled structure that could be destroyed.
                Marshal.StructureToPtr(xsqlda, ptr, false);

                for (int i = 0; i < xsqlvar.Length; i++)
                {
                    int offset = this.ComputeLength(i);
                    Marshal.StructureToPtr(xsqlvar[i], this.GetIntPtr(ptr, offset), false);
                }
            }
            catch
            {
                Marshal.FreeHGlobal(ptr);
                throw;
            }

            return ptr;
        }

        /// <summary>
        /// Reads a native XSQLDA block back into a new managed <c>Descriptor</c>.
        /// </summary>
        /// <param name="charset">Charset used to decode the name buffers.</param>
        /// <param name="pNativeData">Pointer to the XSQLDA header.</param>
        /// <returns>A descriptor with one field per XSQLVAR slot (<c>sqln</c> of them).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="pNativeData"/> is zero.</exception>
        public Descriptor MarshalNativeToManaged(Charset charset, IntPtr pNativeData)
        {
            if (pNativeData == IntPtr.Zero)
            {
                throw new ArgumentNullException("pNativeData");
            }

            // Obtain XSQLDA information
            XSQLDA xsqlda = (XSQLDA)Marshal.PtrToStructure(pNativeData, typeof(XSQLDA));

            Descriptor descriptor = new Descriptor(xsqlda.sqln);
            descriptor.ActualCount = xsqlda.sqld;

            for (int i = 0; i < xsqlda.sqln; i++)
            {
                XSQLVAR xsqlvar = (XSQLVAR)Marshal.PtrToStructure(
                    this.GetIntPtr(pNativeData, this.ComputeLength(i)), typeof(XSQLVAR));

                DbField field = descriptor[i];

                // Map XSQLVAR information to the descriptor field
                field.DataType = xsqlvar.sqltype;
                field.NumericScale = xsqlvar.sqlscale;
                field.SubType = xsqlvar.sqlsubtype;
                field.Length = xsqlvar.sqllen;

                // A missing indicator pointer is treated as "not null" (flag 0)
                if (xsqlvar.sqlind == IntPtr.Zero)
                {
                    field.NullFlag = 0;
                }
                else
                {
                    field.NullFlag = Marshal.ReadInt16(xsqlvar.sqlind);
                }

                // The value is only read when the indicator is not -1
                if (field.NullFlag != -1)
                {
                    field.SetValue(this.GetBytes(xsqlvar));
                }

                field.Name = this.GetString(charset, xsqlvar.sqlname);
                field.Relation = this.GetString(charset, xsqlvar.relname);
                field.Owner = this.GetString(charset, xsqlvar.ownername);
                field.Alias = this.GetString(charset, xsqlvar.aliasname);
            }

            return descriptor;
        }

        #endregion

        #region Private methods

        // Frees the sqldata/sqlind buffers owned by the first 'count' entries of a
        // managed XSQLVAR array that was never transferred to a native block.
        private static void FreeVariableBuffers(XSQLVAR[] xsqlvar, int count)
        {
            for (int i = 0; i < count; i++)
            {
                if (xsqlvar[i].sqldata != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(xsqlvar[i].sqldata);
                    xsqlvar[i].sqldata = IntPtr.Zero;
                }
                if (xsqlvar[i].sqlind != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(xsqlvar[i].sqlind);
                    xsqlvar[i].sqlind = IntPtr.Zero;
                }
            }
        }

        // Returns ptr advanced by offset bytes. ToInt64 keeps the full address on
        // 64-bit platforms, where ToInt32 throws OverflowException for addresses that
        // do not fit in 32 bits.
        private IntPtr GetIntPtr(IntPtr ptr, int offset)
        {
            return new IntPtr(ptr.ToInt64() + offset);
        }

        // Byte size of an XSQLDA header followed by n XSQLVAR entries, which is also
        // the byte offset of XSQLVAR entry n within a native block.
        private int ComputeLength(int n)
        {
            return (Marshal.SizeOf(typeof(XSQLDA)) + n * Marshal.SizeOf(typeof(XSQLVAR)));
        }

        // Copies the raw value bytes out of an XSQLVAR's sqldata buffer. Returns null
        // when sqllen is 0 or sqldata is not set.
        // Throws NotSupportedException for an unrecognized base type; the low bit of
        // sqltype is masked off before the type is matched.
        private byte[] GetBytes(XSQLVAR xsqlvar)
        {
            if (xsqlvar.sqllen == 0 || xsqlvar.sqldata == IntPtr.Zero)
            {
                return null;
            }

            switch (xsqlvar.sqltype & ~1)
            {
                case IscCodes.SQL_VARYING:
                {
                    // VARYING data is a 2-byte length prefix followed by that many bytes
                    short length = Marshal.ReadInt16(xsqlvar.sqldata);

                    byte[] buffer = new byte[length];

                    Marshal.Copy(this.GetIntPtr(xsqlvar.sqldata, 2), buffer, 0, buffer.Length);

                    return buffer;
                }

                case IscCodes.SQL_TEXT:
                case IscCodes.SQL_SHORT:
                case IscCodes.SQL_LONG:
                case IscCodes.SQL_FLOAT:
                case IscCodes.SQL_DOUBLE:
                case IscCodes.SQL_D_FLOAT:
                case IscCodes.SQL_QUAD:
                case IscCodes.SQL_INT64:
                case IscCodes.SQL_BLOB:
                case IscCodes.SQL_ARRAY:
                case IscCodes.SQL_TIMESTAMP:
                case IscCodes.SQL_TYPE_TIME:
                case IscCodes.SQL_TYPE_DATE:
                {
                    // Fixed-size types: exactly sqllen bytes
                    byte[] buffer = new byte[xsqlvar.sqllen];

                    Marshal.Copy(xsqlvar.sqldata, buffer, 0, buffer.Length);

                    return buffer;
                }

                default:
                    throw new NotSupportedException("Unknown data type");
            }
        }

        // Encodes a field's current value into the byte layout written to sqldata.
        // A DBNull value yields a zero-filled buffer of field.Length bytes, or
        // field.Length + 2 bytes for SQL_VARYING.
        private byte[] GetBytes(DbField field)
        {
            if (field.DbValue.IsDBNull())
            {
                int length = field.Length;

                if (field.SqlType == IscCodes.SQL_VARYING)
                {
                    // Add two bytes more for store value length
                    length += 2;
                }

                return new byte[length];
            }

            switch (field.DbDataType)
            {
                case DbDataType.Char:
                {
                    string svalue = field.DbValue.GetString();

                    this.CheckStringLength(field, svalue);

                    // CHAR values are right-padded with spaces (byte 32) up to field.Length
                    byte[] buffer = new byte[field.Length];
                    for (int i = 0; i < buffer.Length; i++)
                    {
                        buffer[i] = 32;
                    }

                    byte[] bytes = field.Charset.GetBytes(svalue);

                    Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length);

                    return buffer;
                }

                case DbDataType.VarChar:
                {
                    string svalue = field.Value.ToString();

                    this.CheckStringLength(field, svalue);

                    byte[] sbuffer = field.Charset.GetBytes(svalue);

                    // Layout: 2-byte length prefix, then the encoded characters
                    byte[] buffer = new byte[field.Length + 2];

                    Buffer.BlockCopy(
                        BitConverter.GetBytes((short)sbuffer.Length), 0, buffer, 0, 2);

                    Buffer.BlockCopy(sbuffer, 0, buffer, 2, sbuffer.Length);

                    return buffer;
                }

                case DbDataType.Numeric:
                case DbDataType.Decimal:
                    return this.GetNumericBytes(field);

                case DbDataType.SmallInt:
                    return BitConverter.GetBytes(field.DbValue.GetInt16());

                case DbDataType.Integer:
                    return BitConverter.GetBytes(field.DbValue.GetInt32());

                case DbDataType.Array:
                case DbDataType.Binary:
                case DbDataType.Text:
                case DbDataType.BigInt:
                    // These are all sent as a 64-bit value taken from GetInt64()
                    return BitConverter.GetBytes(field.DbValue.GetInt64());

                case DbDataType.Float:
                    return BitConverter.GetBytes(field.DbValue.GetFloat());

                case DbDataType.Double:
                    return BitConverter.GetBytes(field.DbValue.GetDouble());

                case DbDataType.Date:
                    return BitConverter.GetBytes(
                        TypeEncoder.EncodeDate(field.DbValue.GetDateTime()));

                case DbDataType.Time:
                    return BitConverter.GetBytes(
                        TypeEncoder.EncodeTime(field.DbValue.GetDateTime()));

                case DbDataType.TimeStamp:
                {
                    DateTime value = field.DbValue.GetDateTime();

                    byte[] date = BitConverter.GetBytes(TypeEncoder.EncodeDate(value));
                    byte[] time = BitConverter.GetBytes(TypeEncoder.EncodeTime(value));

                    // 8 bytes: encoded date at offset 0, encoded time at offset 4
                    byte[] result = new byte[8];

                    Buffer.BlockCopy(date, 0, result, 0, date.Length);
                    Buffer.BlockCopy(time, 0, result, 4, time.Length);

                    return result;
                }

                case DbDataType.Guid:
                    return field.DbValue.GetGuid().ToByteArray();

                default:
                    throw new NotSupportedException("Unknown data type");
            }
        }

        // Throws IscException 335544321 (presumably isc_arith_except, i.e. string
        // truncation; verify) when svalue has more characters than field.CharCount.
        // The check is skipped when field.Length is not a multiple of the charset's
        // BytesPerCharacter.
        private void CheckStringLength(DbField field, string svalue)
        {
            if ((field.Length % field.Charset.BytesPerCharacter) == 0 &&
                svalue.Length > field.CharCount)
            {
                throw new IscException(335544321);
            }
        }

        // Encodes a NUMERIC/DECIMAL value according to its storage SqlType.
        // Returns null for a storage type not listed here.
        private byte[] GetNumericBytes(DbField field)
        {
            decimal value = field.DbValue.GetDecimal();
            object numeric = TypeEncoder.EncodeDecimal(value, field.NumericScale, field.DataType);

            switch (field.SqlType)
            {
                case IscCodes.SQL_SHORT:
                    return BitConverter.GetBytes((short)numeric);

                case IscCodes.SQL_LONG:
                    return BitConverter.GetBytes((int)numeric);

                case IscCodes.SQL_INT64:
                case IscCodes.SQL_QUAD:
                    return BitConverter.GetBytes((long)numeric);

                case IscCodes.SQL_DOUBLE:
                    return BitConverter.GetBytes(field.DbValue.GetDouble());

                default:
                    return null;
            }
        }

        // Encodes value into a zero-padded 32-byte name buffer. 'length' receives the
        // number of encoded bytes, for the matching XSQLVAR *_length field.
        // A null or empty value gives an all-zero buffer and length 0. An encoding
        // longer than 32 bytes makes Buffer.BlockCopy throw ArgumentException.
        private byte[] GetStringBuffer(Charset charset, string value, out short length)
        {
            byte[] buffer = new byte[32];

            if (value == null || value.Length == 0)
            {
                length = 0;
                return buffer;
            }

            byte[] bytes = charset.GetBytes(value);

            Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length);

            length = (short)bytes.Length;

            return buffer;
        }

        // Decodes a fixed-size name buffer, turning NUL padding into spaces and
        // trimming surrounding whitespace.
        private string GetString(Charset charset, byte[] buffer)
        {
            string value = charset.GetString(buffer);

            return value.Replace('\0', ' ').Trim();
        }

        #endregion
    }
}