2007-07-22 Nagappan A <anagappan@novell.com>
[mono.git] / mcs / class / Mono.Data.Tds / Mono.Data.Tds.Protocol / TdsComm.cs
1 //
2 // Mono.Data.Tds.Protocol.TdsComm.cs
3 //
4 // Author:
5 //   Tim Coleman (tim@timcoleman.com)
6 //
7 // Copyright (C) 2002 Tim Coleman
8 //
9
10 //
11 // Permission is hereby granted, free of charge, to any person obtaining
12 // a copy of this software and associated documentation files (the
13 // "Software"), to deal in the Software without restriction, including
14 // without limitation the rights to use, copy, modify, merge, publish,
15 // distribute, sublicense, and/or sell copies of the Software, and to
16 // permit persons to whom the Software is furnished to do so, subject to
17 // the following conditions:
18 // 
19 // The above copyright notice and this permission notice shall be
20 // included in all copies or substantial portions of the Software.
21 // 
22 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
26 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
27 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
28 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
29 //
30
31 using System;
32 using System.Net;
33 using System.Net.Sockets;
34 using System.Text;
35 using System.Threading;
36
namespace Mono.Data.Tds.Protocol {
	/// <summary>
	/// Low-level TDS transport: owns the TCP connection to the server, frames
	/// outgoing data into TDS packets (8-byte header + payload) and reassembles
	/// incoming packets into a byte stream read by the protocol layer.
	/// </summary>
	internal sealed class TdsComm
	{
		#region Fields

		NetworkStream stream;
		int packetSize;
		TdsPacketType packetType = TdsPacketType.None;
		Encoding encoder;

		string dataSource;
		int commandTimeout;
		int connectionTimeout;

		byte[] outBuffer;
		int outBufferLength;
		int nextOutBufferIndex = 0;

		byte[] inBuffer;
		// Number of valid payload bytes in inBuffer; 0 means "nothing buffered yet".
		int inBufferLength;
		int inBufferIndex = 0;

		// Every TDS packet starts with an 8-byte header.
		const int headerLength = 8;

		// Receives the header of each incoming packet.
		byte[] tmpBuf = new byte[headerLength];
		// Shared result buffer handed out by GetBytes for small, non-exclusive reads.
		byte[] resBuffer = new byte[256];

		int packetsSent = 0;
		int packetsReceived = 0;

		Socket socket;
		TdsVersion tdsVersion;

		// Only used by the (currently disabled) asynchronous connect path.
		ManualResetEvent connected = new ManualResetEvent (false);

		#endregion // Fields

		#region Constructors

		/// <summary>
		/// Opens a TCP connection to <paramref name="dataSource"/>:<paramref name="port"/>
		/// and allocates input/output buffers of <paramref name="packetSize"/> bytes.
		/// </summary>
		/// <param name="timeout">Connection timeout in seconds (only honoured by the disabled async connect path).</param>
		/// <exception cref="TdsInternalException">The host could not be resolved/reached or refused the connection.</exception>
		[MonoTODO ("Fix when asynchronous socket connect works on Linux.")]
		public TdsComm (string dataSource, int port, int packetSize, int timeout, TdsVersion tdsVersion)
		{
			this.packetSize = packetSize;
			this.tdsVersion = tdsVersion;
			this.dataSource = dataSource;
			this.connectionTimeout = timeout;

			outBuffer = new byte[packetSize];
			inBuffer = new byte[packetSize];

			outBufferLength = packetSize;
			// Nothing has been received yet: the first GetByte/Peek must read a packet.
			inBufferLength = 0;

			try {
				socket = new Socket (AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
#if NET_2_0
				IPHostEntry hostEntry = Dns.GetHostEntry (this.dataSource);
#else
				IPHostEntry hostEntry = Dns.Resolve (this.dataSource);
#endif
				IPEndPoint endPoint = new IPEndPoint (SelectIPv4Address (hostEntry.AddressList), port);

				// Synchronous connect. The asynchronous BeginConnect/ConnectCallback
				// variant (which would honour the timeout) did not work right on
				// Linux, so it is not used for now; see the MonoTODO above.
				socket.Connect (endPoint);

				stream = new NetworkStream (socket);
			} catch (SocketException e) {
				// Do not leak the socket handle when the connection attempt fails.
				if (socket != null)
					socket.Close ();
				socket = null;
				throw new TdsInternalException ("Server does not exist or connection refused.", e);
			}
		}

		// The socket is created for AddressFamily.InterNetwork, so an IPv6 entry
		// cannot be used with it; prefer the first IPv4 address and otherwise fall
		// back to the first entry, as the original code did.
		static IPAddress SelectIPv4Address (IPAddress[] addresses)
		{
			foreach (IPAddress address in addresses) {
				if (address.AddressFamily == AddressFamily.InterNetwork)
					return address;
			}
			return addresses [0];
		}

		#endregion // Constructors

		#region Properties

		public int CommandTimeout {
			get { return commandTimeout; }
			set { commandTimeout = value; }
		}

		/// <summary>Encoding used for non-Unicode strings (pre-TDS 7.0 and padded fields).</summary>
		internal Encoding Encoder {
			get { return encoder; }
			set { encoder = value; }
		}

		public int PacketSize {
			get { return packetSize; }
			set { packetSize = value; }
		}

		#endregion // Properties

		#region Methods

		/// <summary>Returns a new array holding the bytes of <paramref name="toswap"/> in reverse order.</summary>
		public byte[] Swap (byte[] toswap)
		{
			byte[] ret = (byte[]) toswap.Clone ();
			Array.Reverse (ret);
			return ret;
		}

		/// <summary>
		/// Appends a boxed value using the type-specific overload. <c>null</c> and
		/// <see cref="DBNull"/> are written as a single zero byte.
		/// </summary>
		/// <exception cref="InvalidOperationException">The runtime type is not supported.</exception>
		public void Append (object o)
		{
			if (o == null || o == DBNull.Value) {
				Append ((byte) 0);
				return;
			}

			switch (Type.GetTypeCode (o.GetType ())) {
			case TypeCode.Byte:
				Append ((byte) o);
				return;
			case TypeCode.Boolean:
				Append ((bool) o ? (byte) 1 : (byte) 0);
				return;
			case TypeCode.Object:
				if (o is byte[]) {
					Append ((byte[]) o);
					return;
				}
				if (o is Guid) {
					Append (((Guid) o).ToByteArray ());
					return;
				}
				// Other object types used to be silently skipped, leaving the
				// packet out of sync with its metadata; reject them instead.
				break;
			case TypeCode.Int16:
				Append ((short) o);
				return;
			case TypeCode.Int32:
				Append ((int) o);
				return;
			case TypeCode.String:
				Append ((string) o);
				return;
			case TypeCode.Double:
				Append ((double) o);
				return;
			case TypeCode.Single:
				Append ((float) o);
				return;
			case TypeCode.Int64:
				Append ((long) o);
				return;
			case TypeCode.Decimal:
				Append ((decimal) o, 17);
				return;
			case TypeCode.DateTime:
				Append ((DateTime) o, 8);
				return;
			}
			throw new InvalidOperationException (String.Format ("Object Type :{0} , not being appended", o.GetType ()));
		}

		/// <summary>
		/// Appends one byte, flushing the current packet as a non-final segment
		/// when the output buffer is full.
		/// </summary>
		public void Append (byte b)
		{
			if (nextOutBufferIndex == outBufferLength) {
				SendPhysicalPacket (false);
				nextOutBufferIndex = headerLength;
			}
			Store (nextOutBufferIndex, b);
			nextOutBufferIndex++;
		}

		/// <summary>
		/// Appends a date relative to 1900-01-01.
		/// 8 bytes: int day count + int time of day in 1/300 s units.
		/// 4 bytes: ushort day count + short minutes since midnight.
		/// </summary>
		/// <exception cref="ArgumentOutOfRangeException"><paramref name="bytes"/> is neither 4 nor 8.</exception>
		public void Append (DateTime t, int bytes)
		{
			DateTime epoch = new DateTime (1900, 1, 1);

			// Split into whole days and time of day separately: (t - epoch) has
			// negative hour/minute components for dates before 1900.
			int days = (t.Date - epoch).Days;
			TimeSpan time = t.TimeOfDay;

			if (bytes == 8) {
				long ms = (time.Hours * 3600 + time.Minutes * 60 + time.Seconds) * 1000L + (long) time.Milliseconds;
				int ticks = (int) ((ms * 300) / 1000);
				Append (days);
				Append (ticks);
			} else if (bytes == 4) {
				int minutes = time.Hours * 60 + time.Minutes;
				Append ((ushort) days);
				Append ((short) minutes);
			} else {
				throw new ArgumentOutOfRangeException ("bytes", bytes, "Invalid No of bytes");
			}
		}

		/// <summary>Appends every byte of <paramref name="b"/>.</summary>
		public void Append (byte[] b)
		{
			Append (b, b.Length, (byte) 0);
		}

		/// <summary>
		/// Appends exactly <paramref name="len"/> bytes: the first bytes of
		/// <paramref name="b"/>, then <paramref name="pad"/> if <paramref name="b"/> is shorter.
		/// </summary>
		public void Append (byte[] b, int len, byte pad)
		{
			int i = 0;
			for ( ; i < b.Length && i < len; i++)
				Append (b[i]);

			for ( ; i < len; i++)
				Append (pad);
		}

		public void Append (short s)
		{
			AppendLittleEndian (BitConverter.GetBytes (s));
		}

		public void Append (ushort s)
		{
			AppendLittleEndian (BitConverter.GetBytes (s));
		}

		public void Append (int i)
		{
			AppendLittleEndian (BitConverter.GetBytes (i));
		}

		/// <summary>
		/// Appends a string: encoded with <see cref="Encoder"/> before TDS 7.0,
		/// otherwise as two little-endian bytes per UTF-16 code unit.
		/// </summary>
		public void Append (string s)
		{
			if (tdsVersion < TdsVersion.tds70) {
				Append (encoder.GetBytes (s));
				return;
			}

			foreach (char c in s) {
				Append ((byte) (c & 0xff));
				Append ((byte) (c >> 8));
			}
		}

		/// <summary>
		/// Appends <paramref name="s"/> encoded with <see cref="Encoder"/> as a
		/// fixed-width field of <paramref name="len"/> bytes (truncated or padded).
		/// </summary>
		/// <returns>The encoded bytes of <paramref name="s"/> (empty for <c>null</c>).</returns>
		public byte[] Append (string s, int len, byte pad)
		{
			if (s == null) {
				byte[] empty = new byte[0];
				// Still emit the padding so fixed-width fields stay aligned.
				Append (empty, len, pad);
				return empty;
			}

			byte[] result = encoder.GetBytes (s);
			Append (result, len, pad);
			return result;
		}

		public void Append (double value)
		{
			AppendLittleEndian (BitConverter.GetBytes (value));
		}

		public void Append (float value)
		{
			AppendLittleEndian (BitConverter.GetBytes (value));
		}

		public void Append (long l)
		{
			AppendLittleEndian (BitConverter.GetBytes (l));
		}

		/// <summary>
		/// Appends 17 bytes: a sign byte (1 = non-negative, 0 = negative), the
		/// 96-bit unscaled magnitude as three little-endian ints, and a zero int.
		/// The scale is not written here and <paramref name="bytes"/> is ignored.
		/// </summary>
		public void Append (decimal d, int bytes)
		{
			int[] arr = Decimal.GetBits (d);
			// Zero must be sent as positive; "d > 0" used to mark it negative.
			byte sign = (d >= 0 ? (byte) 1 : (byte) 0);
			Append (sign);
			Append (arr[0]);
			Append (arr[1]);
			Append (arr[2]);
			Append ((int) 0);
		}

		// Writes the bytes of a BitConverter result in little-endian order
		// regardless of host endianness. Reverses the (fresh) array in place.
		private void AppendLittleEndian (byte[] bytes)
		{
			if (!BitConverter.IsLittleEndian)
				Array.Reverse (bytes);
			Append (bytes);
		}

		/// <summary>Closes the stream and the underlying socket.</summary>
		public void Close ()
		{
			if (stream != null)
				stream.Close ();
			// NetworkStream(Socket) does not own the socket, so closing the stream
			// alone leaves the socket open.
			if (socket != null)
				socket.Close ();
		}

		private void ConnectCallback (IAsyncResult ar)
		{
			Socket s = (Socket) ar.AsyncState;
			if (Poll (s, connectionTimeout, SelectMode.SelectWrite)) {
				s.EndConnect (ar);
				connected.Set ();
			}
		}

		/// <summary>Returns the next incoming byte, reading a new packet when the current one is exhausted.</summary>
		public byte GetByte ()
		{
			if (inBufferIndex >= inBufferLength) {
				// out of data, read another physical packet.
				GetPhysicalPacket ();
			}
			return inBuffer[inBufferIndex++];
		}

		/// <summary>
		/// Reads <paramref name="len"/> bytes, spanning packets as needed.
		/// Unless <paramref name="exclusiveBuffer"/> is true or len exceeds 16384,
		/// the returned array is a shared buffer that may be longer than len and is
		/// overwritten by the next call.
		/// </summary>
		public byte[] GetBytes (int len, bool exclusiveBuffer)
		{
			byte[] result;

			// Do not keep an internal result buffer larger than 16k.
			// This would unnecessarily use up memory.
			if (exclusiveBuffer || len > 16384) {
				result = new byte[len];
			} else {
				if (resBuffer.Length < len)
					resBuffer = new byte[len];
				result = resBuffer;
			}

			int i = 0;
			while (i < len) {
				if (inBufferIndex >= inBufferLength)
					GetPhysicalPacket ();

				int avail = inBufferLength - inBufferIndex;
				if (avail > len - i)
					avail = len - i;

				Array.Copy (inBuffer, inBufferIndex, result, i, avail);
				i += avail;
				inBufferIndex += avail;
			}

			return result;
		}

		/// <summary>Reads a string of <paramref name="len"/> characters, wide (UTF-16) for TDS 7.0 and later.</summary>
		public string GetString (int len)
		{
			// Mirror Append(string): every version from tds70 on uses wide strings.
			return GetString (len, tdsVersion >= TdsVersion.tds70);
		}

		/// <summary>
		/// Reads <paramref name="len"/> characters: two little-endian bytes each when
		/// <paramref name="wide"/>, otherwise len bytes decoded with <see cref="Encoder"/>.
		/// </summary>
		public string GetString (int len, bool wide)
		{
			if (wide) {
				char[] chars = new char[len];
				for (int i = 0; i < len; ++i) {
					int lo = GetByte ();
					int hi = GetByte ();
					chars[i] = (char) (lo | (hi << 8));
				}
				return new String (chars);
			}

			// GetBytes may return a longer shared buffer; decode only len bytes.
			return encoder.GetString (GetBytes (len, false), 0, len);
		}

		/// <summary>Reads a big-endian (network order) unsigned 16-bit value.</summary>
		public int GetNetShort ()
		{
			int hi = GetByte ();
			int lo = GetByte ();
			return (hi << 8) | lo;
		}

		/// <summary>Reads a little-endian 16-bit integer.</summary>
		public short GetTdsShort ()
		{
			return BitConverter.ToInt16 (GetLittleEndianBytes (2), 0);
		}

		/// <summary>Reads a little-endian 32-bit integer.</summary>
		public int GetTdsInt ()
		{
			return BitConverter.ToInt32 (GetLittleEndianBytes (4), 0);
		}

		/// <summary>Reads a little-endian 64-bit integer.</summary>
		public long GetTdsInt64 ()
		{
			return BitConverter.ToInt64 (GetLittleEndianBytes (8), 0);
		}

		// Reads count wire bytes (little-endian) and returns them in host order
		// so BitConverter can decode them.
		private byte[] GetLittleEndianBytes (int count)
		{
			byte[] input = new byte[count];
			for (int i = 0; i < count; i++)
				input[i] = GetByte ();
			if (!BitConverter.IsLittleEndian)
				Array.Reverse (input);
			return input;
		}

		private void GetPhysicalPacket ()
		{
			int dataLength = GetPhysicalPacketHeader ();
			GetPhysicalPacketData (dataLength);
		}

		// Reads and validates the next packet header; returns its payload length.
		private int GetPhysicalPacketHeader ()
		{
			ReadFully (tmpBuf, 0, headerLength);
			return ParsePacketHeader ();
		}

		// Validates the header held in tmpBuf, makes sure inBuffer can hold the
		// payload and returns the payload length (packet length minus header).
		private int ParsePacketHeader ()
		{
			TdsPacketType type = (TdsPacketType) tmpBuf[0];
			if (type != TdsPacketType.Logon && type != TdsPacketType.Query && type != TdsPacketType.Reply)
				throw new Exception (String.Format ("Unknown packet type {0}", tmpBuf[0]));

			// Bytes 2-3: total packet length including the header, big-endian.
			int len = Ntohs (tmpBuf, 2) - headerLength;
			if (len < 0)
				throw new Exception (String.Format ("Confused by a length of {0}", len));

			if (len > inBuffer.Length)
				inBuffer = new byte[len];

			return len;
		}

		// Reads a packet payload of the given length into inBuffer and resets
		// the read cursor to its start.
		private void GetPhysicalPacketData (int length)
		{
			ReadFully (inBuffer, 0, length);
			packetsReceived++;

			// adjust the bookkeeping info about the incoming buffer
			inBufferLength = length;
			inBufferIndex = 0;
		}

		// Reads exactly count bytes. NetworkStream.Read returns 0 once the peer has
		// closed the connection; without this check the read loops spun forever.
		private void ReadFully (byte[] buffer, int offset, int count)
		{
			while (count > 0) {
				int n = stream.Read (buffer, offset, count);
				if (n <= 0)
					throw new System.IO.IOException ("Connection lost");
				offset += n;
				count -= n;
			}
		}

		// Decodes a big-endian 16-bit value; returned as int because it is unsigned.
		private static int Ntohs (byte[] buf, int offset)
		{
			int lo = buf[offset + 1] & 0xff;
			int hi = (buf[offset] & 0xff) << 8;
			return hi | lo;
		}

		/// <summary>Returns the next incoming byte without consuming it.</summary>
		public byte Peek ()
		{
			// If out of data, read another physical packet.
			if (inBufferIndex >= inBufferLength)
				GetPhysicalPacket ();

			return inBuffer[inBufferIndex];
		}

		/// <summary>Polls the connection socket for up to <paramref name="seconds"/> seconds.</summary>
		public bool Poll (int seconds, SelectMode selectMode)
		{
			return Poll (socket, seconds, selectMode);
		}

		private bool Poll (Socket s, int seconds, SelectMode selectMode)
		{
			// Widen before multiplying: seconds * 1000000 overflows int above 2147 s.
			long uSeconds = (long) seconds * 1000000;

			// Socket.Poll takes an int microsecond timeout, so wait in chunks.
			while (uSeconds > (long) Int32.MaxValue) {
				if (s.Poll (Int32.MaxValue, selectMode))
					return true;
				uSeconds -= Int32.MaxValue;
			}
			return s.Poll ((int) uSeconds, selectMode);
		}

		/// <summary>Resizes the output buffer, keeping as much of its current content as fits.</summary>
		internal void ResizeOutBuf (int newSize)
		{
			if (newSize == outBufferLength)
				return;

			byte[] newBuf = new byte [newSize];
			// Growing used to throw: the copy asked for newSize bytes from the smaller old buffer.
			Array.Copy (outBuffer, 0, newBuf, 0, Math.Min (outBuffer.Length, newSize));
			outBufferLength = newSize;
			outBuffer = newBuf;
		}

		/// <summary>Sends the pending data as the final segment of the current message.</summary>
		public void SendPacket ()
		{
			SendPhysicalPacket (true);
			nextOutBufferIndex = 0;
			packetType = TdsPacketType.None;
		}

		// Fills in the 8-byte header and writes the buffered packet, if it has a
		// payload (or is a Cancel packet, which may be header-only).
		private void SendPhysicalPacket (bool isLastSegment)
		{
			if (nextOutBufferIndex > headerLength || packetType == TdsPacketType.Cancel) {
				Store (0, (byte) packetType);
				Store (1, (byte) (isLastSegment ? 1 : 0));
				Store (2, (short) nextOutBufferIndex);
				Store (4, (byte) 0);
				Store (5, (byte) 0);
				Store (6, (byte) (tdsVersion == TdsVersion.tds70 ? 0x1 : 0x0));
				Store (7, (byte) 0);

				stream.Write (outBuffer, 0, nextOutBufferIndex);
				stream.Flush ();
				packetsSent++;
			}
		}

		/// <summary>Discards the next <paramref name="i"/> incoming bytes.</summary>
		public void Skip (long i)
		{
			while (i > 0) {
				if (inBufferIndex >= inBufferLength)
					GetPhysicalPacket ();

				int n = (int) Math.Min (i, (long) (inBufferLength - inBufferIndex));
				inBufferIndex += n;
				i -= n;
			}
		}

		/// <summary>
		/// Begins a new outgoing message of the given type. Except for Cancel,
		/// any unread input from the previous response is discarded.
		/// </summary>
		public void StartPacket (TdsPacketType type)
		{
			if (type != TdsPacketType.Cancel && inBufferIndex != inBufferLength)
				inBufferIndex = inBufferLength;

			packetType = type;
			nextOutBufferIndex = headerLength;
		}

		private void Store (int index, byte value)
		{
			outBuffer[index] = value;
		}

		// Stores a 16-bit value big-endian (used for the header length field).
		private void Store (int index, short value)
		{
			outBuffer[index] = (byte) ((value >> 8) & 0xff);
			outBuffer[index + 1] = (byte) (value & 0xff);
		}

		#endregion // Methods
#if NET_2_0
		#region Async Methods

		/// <summary>Starts reading the next packet asynchronously.</summary>
		public IAsyncResult BeginReadPacket (AsyncCallback callback, object stateObject)
		{
			TdsAsyncResult ar = new TdsAsyncResult (callback, stateObject);

			stream.BeginRead (tmpBuf, 0, headerLength, new AsyncCallback (OnReadPacketCallback), ar);
			return ar;
		}

		/// <returns>Packet size in bytes</returns>
		public int EndReadPacket (IAsyncResult ar)
		{
			if (!ar.IsCompleted)
				ar.AsyncWaitHandle.WaitOne ();
			return (int) ((TdsAsyncResult) ar).ReturnValue;
		}

		/// <summary>
		/// Completes the header read started by BeginReadPacket, reads the rest of
		/// the packet synchronously and completes the TdsAsyncResult with the
		/// total packet size (header included).
		/// </summary>
		public void OnReadPacketCallback (IAsyncResult socketAsyncResult)
		{
			TdsAsyncResult ar = (TdsAsyncResult) socketAsyncResult.AsyncState;
			int nread = stream.EndRead (socketAsyncResult);

			// Finish a partial header; ReadFully reports a closed connection.
			ReadFully (tmpBuf, nread, headerLength - nread);

			int len = ParsePacketHeader ();
			GetPhysicalPacketData (len);

			ar.ReturnValue = len + headerLength; // packet size
			ar.MarkComplete ();
		}

		#endregion // Async Methods
#endif // NET_2_0
	}
}