In System.Data/System.Data.SqlClient:
[mono.git] / mcs / class / Mono.Data.Tds / Mono.Data.Tds.Protocol / TdsComm.cs
1 //
2 // Mono.Data.Tds.Protocol.TdsComm.cs
3 //
4 // Author:
5 //   Tim Coleman (tim@timcoleman.com)
6 //
7 // Copyright (C) 2002 Tim Coleman
8 //
9
10 //
11 // Permission is hereby granted, free of charge, to any person obtaining
12 // a copy of this software and associated documentation files (the
13 // "Software"), to deal in the Software without restriction, including
14 // without limitation the rights to use, copy, modify, merge, publish,
15 // distribute, sublicense, and/or sell copies of the Software, and to
16 // permit persons to whom the Software is furnished to do so, subject to
17 // the following conditions:
18 // 
19 // The above copyright notice and this permission notice shall be
20 // included in all copies or substantial portions of the Software.
21 // 
22 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
26 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
27 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
28 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
29 //
30
31 using System;
32 using System.Net;
33 using System.Net.Sockets;
34 using System.Text;
35 using System.Threading;
36
37 namespace Mono.Data.Tds.Protocol {
38         internal sealed class TdsComm
39         {
40                 #region Fields
41
42                 NetworkStream stream;
43                 int packetSize;
44                 TdsPacketType packetType = TdsPacketType.None;
45                 Encoding encoder;
46
47                 string dataSource;
48                 int commandTimeout;
49                 int connectionTimeout;
50
51                 byte[] outBuffer;
52                 int outBufferLength;
53                 int nextOutBufferIndex = 0;
54
55                 byte[] inBuffer;
56                 int inBufferLength;
57                 int inBufferIndex = 0;
58
59                 static int headerLength = 8;
60
61                 byte[] tmpBuf = new byte[8];
62                 byte[] resBuffer = new byte[256];
63
64                 int packetsSent = 0;
65                 int packetsReceived = 0;
66
67                 Socket socket;
68                 TdsVersion tdsVersion;
69
70                 ManualResetEvent connected = new ManualResetEvent (false);
71                 
72                 #endregion // Fields
73                 
74                 #region Constructors
75
76                 [MonoTODO ("Fix when asynchronous socket connect works on Linux.")]             
77                 public TdsComm (string dataSource, int port, int packetSize, int timeout, TdsVersion tdsVersion)
78                 {
79                         this.packetSize = packetSize;
80                         this.tdsVersion = tdsVersion;
81                         this.dataSource = dataSource;
82                         this.connectionTimeout = timeout;
83
84                         outBuffer = new byte[packetSize];
85                         inBuffer = new byte[packetSize];
86
87                         outBufferLength = packetSize;
88                         inBufferLength = packetSize;
89
90                         IPEndPoint endPoint;
91                         
92                         try {
93                                 socket = new Socket (AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
94                                 IPHostEntry hostEntry = Dns.Resolve (dataSource);
95                                 endPoint = new IPEndPoint (hostEntry.AddressList [0], port);
96
97                                 // This replaces the code below for now
98                                 socket.Connect (endPoint);
99
100                                 /*
101                                   FIXME: Asynchronous socket connection doesn't work right on linux, so comment 
102                                   this out for now.  This *does* do the right thing on windows
103
104                                   connected.Reset ();
105                                   IAsyncResult asyncResult = socket.BeginConnect (endPoint, new AsyncCallback (ConnectCallback), socket);
106
107                                   if (timeout > 0 && !connected.WaitOne (new TimeSpan (0, 0, timeout), true))
108                                   throw Tds.CreateTimeoutException (dataSource, "Open()");
109                                   else if (timeout > 0 && !connected.WaitOne ())
110                                   throw Tds.CreateTimeoutException (dataSource, "Open()");
111                                 */
112
113                                 stream = new NetworkStream (socket);
114                         } catch (SocketException e) {
115                                 throw new TdsInternalException ("Server does not exist or connection refused.", e);
116                         }
117                 }
118                 
119                 #endregion // Constructors
120                 
121                 #region Properties
122
123                 public int CommandTimeout {
124                         get { return commandTimeout; }
125                         set { commandTimeout = value; }
126                 }
127
128                 internal Encoding Encoder {
129                         get { return encoder; }
130                         set { encoder = value; }
131                 }
132                 
133                 public int PacketSize {
134                         get { return packetSize; }
135                         set { packetSize = value; }
136                 }
137                 
138                 #endregion // Properties
139                 
140                 #region Methods
141
142                 public byte[] Swap(byte[] toswap) {
143                         byte[] ret = new byte[toswap.Length];
144                         for(int i = 0; i < toswap.Length; i++)
145                                 ret [toswap.Length - i - 1] = toswap[i];
146
147                         return ret;
148                 }
149                 public void Append (object o)
150                 {
151                         switch (o.GetType ().ToString ()) {
152                         case "System.Byte":
153                                 Append ((byte) o);
154                                 return;
155                         case "System.Byte[]":
156                                 Append ((byte[]) o);
157                                 return;
158                         case "System.Int16":
159                                 Append ((short) o);
160                                 return;
161                         case "System.Int32":
162                                 Append ((int) o);
163                                 return;
164                         case "System.String":
165                                 Append ((string) o);
166                                 return;
167                         case "System.Double":
168                                 Append ((double) o);
169                                 return;
170                         case "System.Int64":
171                                 Append ((long) o);
172                                 return;
173                         }
174                 }
175
176                 public void Append (byte b)
177                 {
178                         if (nextOutBufferIndex == outBufferLength) {
179                                 SendPhysicalPacket (false);
180                                 nextOutBufferIndex = headerLength;
181                         }
182                         Store (nextOutBufferIndex, b);
183                         nextOutBufferIndex++;
184                 }       
185                 
186                 public void Append (byte[] b)
187                 {
188                         Append (b, b.Length, (byte) 0);
189                 }               
190
191                 public void Append (byte[] b, int len, byte pad)
192                 {
193                         int i = 0;
194                         for ( ; i < b.Length && i < len; i++)
195                             Append (b[i]);
196
197                         for ( ; i < len; i++)
198                             Append (pad);
199                 }       
200
201                 public void Append (short s)
202                 {
203                         if(!BitConverter.IsLittleEndian)
204                                 Append (Swap (BitConverter.GetBytes(s)));
205                         else 
206                                 Append (BitConverter.GetBytes (s));
207                 }
208
209                 public void Append (int i)
210                 {
211                         if(!BitConverter.IsLittleEndian)
212                                 Append (Swap (BitConverter.GetBytes(i)));
213                         else
214                                 Append (BitConverter.GetBytes (i));
215                 }
216
217                 public void Append (string s)
218                 {
219                         if (tdsVersion < TdsVersion.tds70) 
220                                 Append (encoder.GetBytes (s));
221                         else 
222                                 foreach (char c in s)
223                                         if(!BitConverter.IsLittleEndian)
224                                                 Append (Swap (BitConverter.GetBytes (c)));
225                                         else
226                                                 Append (BitConverter.GetBytes (c));
227                 }       
228
229                 // Appends with padding
230                 public byte[] Append (string s, int len, byte pad)
231                 {
232                         if (s == null)
233                                 return new byte[0];
234
235                         byte[] result = encoder.GetBytes (s);
236                         Append (result, len, pad);
237                         return result;
238                 }
239
240                 public void Append (double value)
241                 {
242                         Append (BitConverter.DoubleToInt64Bits (value));
243                 }
244
245                 public void Append (long l)
246                 {
247                         if (tdsVersion < TdsVersion.tds70) {
248                                 Append ((byte) (((byte) (l >> 56)) & 0xff));
249                                 Append ((byte) (((byte) (l >> 48)) & 0xff));
250                                 Append ((byte) (((byte) (l >> 40)) & 0xff));
251                                 Append ((byte) (((byte) (l >> 32)) & 0xff));
252                                 Append ((byte) (((byte) (l >> 24)) & 0xff));
253                                 Append ((byte) (((byte) (l >> 16)) & 0xff));
254                                 Append ((byte) (((byte) (l >> 8)) & 0xff));
255                                 Append ((byte) (((byte) (l >> 0)) & 0xff));
256                         }
257                         else 
258                                 if (!BitConverter.IsLittleEndian)
259                                         Append (Swap (BitConverter.GetBytes (l)));
260                                 else
261                                         Append (BitConverter.GetBytes (l));
262                 }
263
264                 public void Close ()
265                 {
266                         stream.Close ();
267                 }
268
269                 private void ConnectCallback (IAsyncResult ar)
270                 {
271                         Socket s = (Socket) ar.AsyncState;
272                         if (Poll (s, connectionTimeout, SelectMode.SelectWrite)) {
273                                 socket.EndConnect (ar);
274                                 connected.Set ();
275                         }
276                 }
277
278                 public byte GetByte ()
279                 {
280                         byte result;
281
282                         if (inBufferIndex >= inBufferLength) {
283                                 // out of data, read another physical packet.
284                                 GetPhysicalPacket ();
285                         }
286
287                         result = inBuffer[inBufferIndex++];
288                         return result;
289                 }
290
291                 public byte[] GetBytes (int len, bool exclusiveBuffer)
292                 {
293                         byte[] result = null;
294                         int i;
295
296                         // Do not keep an internal result buffer larger than 16k.
297                         // This would unnecessarily use up memory.
298                         if (exclusiveBuffer || len > 16384)
299                                 result = new byte[len];
300                         else
301                         {
302                                 if (resBuffer.Length < len)
303                                         resBuffer = new byte[len];
304                                 result = resBuffer;
305                         }
306
307                         for (i = 0; i<len; )
308                         {
309                                 if (inBufferIndex >= inBufferLength)
310                                         GetPhysicalPacket ();
311
312                                 int avail = inBufferLength - inBufferIndex;
313                                 avail = avail>len-i ? len-i : avail;
314
315                                 System.Array.Copy (inBuffer, inBufferIndex, result, i, avail);
316                                 i += avail;
317                                 inBufferIndex += avail;
318                         }
319
320                         return result;
321                 }
322
323                 public string GetString (int len)
324                 {
325                         if (tdsVersion == TdsVersion.tds70) 
326                                 return GetString (len, true);
327                         else
328                                 return GetString (len, false);
329                 }
330
331                 public string GetString (int len, bool wide)
332                 {
333                         if (wide) {
334                                 char[] chars = new char[len];
335                                 for (int i = 0; i < len; ++i) {
336                                         int lo = ((byte) GetByte ()) & 0xFF;
337                                         int hi = ((byte) GetByte ()) & 0xFF;
338                                         chars[i] = (char) (lo | ( hi << 8));
339                                 }
340                                 return new String (chars);
341                         }
342                         else {
343                                 byte[] result = new byte[len];
344                                 Array.Copy (GetBytes (len, false), result, len);
345                                 return (encoder.GetString (result));
346                         }
347                 }
348
349                 public int GetNetShort ()
350                 {
351                         byte[] tmp = new byte[2];
352                         tmp[0] = GetByte ();
353                         tmp[1] = GetByte ();
354                         return Ntohs (tmp, 0);
355                 }
356
357                 public short GetTdsShort ()
358                 {
359                         byte[] input = new byte[2];
360
361                         for (int i = 0; i < 2; i += 1)
362                                 input[i] = GetByte ();
363                         if(!BitConverter.IsLittleEndian)
364                                 return (BitConverter.ToInt16 (Swap (input), 0));
365                         else
366                                 return (BitConverter.ToInt16 (input, 0));
367                 }
368
369
370                 public int GetTdsInt ()
371                 {
372                         byte[] input = new byte[4];
373                         for (int i = 0; i < 4; i += 1)
374                                 input[i] = GetByte ();
375                         if(!BitConverter.IsLittleEndian)
376                                 return (BitConverter.ToInt32 (Swap (input), 0));
377                         else
378                                 return (BitConverter.ToInt32 (input, 0));
379                 }
380
381                 public long GetTdsInt64 ()
382                 {
383                         byte[] input = new byte[8];
384                         for (int i = 0; i < 8; i += 1)
385                                 input[i] = GetByte ();
386                         if(!BitConverter.IsLittleEndian)
387                                 return (BitConverter.ToInt64 (Swap (input), 0));
388                         else
389                                 return (BitConverter.ToInt64 (input, 0));
390                 }
391
392                 private void GetPhysicalPacket ()
393                 {
394                         int dataLength = GetPhysicalPacketHeader ();
395                         GetPhysicalPacketData (dataLength);
396                 }
397
398                 private int GetPhysicalPacketHeader ()
399                 {
400                         int nread = 0;
401                                                 
402                         // read the header
403                         while (nread < 8)
404                                 nread += stream.Read (tmpBuf, nread, 8 - nread);
405
406                         TdsPacketType packetType = (TdsPacketType) tmpBuf[0];
407                         if (packetType != TdsPacketType.Logon && packetType != TdsPacketType.Query && packetType != TdsPacketType.Reply) 
408                         {
409                                 throw new Exception (String.Format ("Unknown packet type {0}", tmpBuf[0]));
410                         }
411
412                         // figure out how many bytes are remaining in this packet.
413                         int len = Ntohs (tmpBuf, 2) - 8;
414
415                         if (len >= inBuffer.Length) 
416                                 inBuffer = new byte[len];
417
418                         if (len < 0) {
419                                 throw new Exception (String.Format ("Confused by a length of {0}", len));
420                         }
421                         
422                         return len;
423
424                 }
425                 
426                 private void GetPhysicalPacketData (int length)
427                 {
428                         // now get the data
429                         int nread = 0;
430                         while (nread < length) {
431                                 nread += stream.Read (inBuffer, nread, length - nread);
432                         }
433
434                         packetsReceived++;
435
436                         // adjust the bookkeeping info about the incoming buffer
437                         inBufferLength = length;
438                         inBufferIndex = 0;
439                 }
440                 
441
442                 private static int Ntohs (byte[] buf, int offset)
443                 {
444                         int lo = ((int) buf[offset + 1] & 0xff);
445                         int hi = (((int) buf[offset] & 0xff ) << 8);
446
447                         return hi | lo;
448                         // return an int since we really want an _unsigned_
449                 }               
450
451                 public byte Peek ()
452                 {
453                         // If out of data, read another physical packet.
454                         if (inBufferIndex >= inBufferLength)
455                                 GetPhysicalPacket ();
456
457                         return inBuffer[inBufferIndex];
458                 }
459
460                 public bool Poll (int seconds, SelectMode selectMode)
461                 {
462                         return Poll (socket, seconds, selectMode);
463                 }
464
465                 private bool Poll (Socket s, int seconds, SelectMode selectMode)
466                 {
467                         long uSeconds = seconds * 1000000;
468                         bool bState = false;
469
470                         while (uSeconds > (long) Int32.MaxValue) {
471                                 bState = s.Poll (Int32.MaxValue, selectMode);
472                                 if (bState) 
473                                         return true;
474                                 uSeconds -= Int32.MaxValue;
475                         }
476                         return s.Poll ((int) uSeconds, selectMode);
477                 }
478
479                 internal void ResizeOutBuf (int newSize)
480                 {
481                         if (newSize > outBufferLength) {
482                                 byte[] newBuf = new byte [newSize];
483                                 Array.Copy (outBuffer, 0, newBuf, 0, outBufferLength);
484                                 outBufferLength = newSize;
485                                 outBuffer = newBuf;
486                         }
487                 }
488
489                 public void SendPacket ()
490                 {
491                         SendPhysicalPacket (true);
492                         nextOutBufferIndex = 0;
493                         packetType = TdsPacketType.None;
494                 }
495                 
496                 private void SendPhysicalPacket (bool isLastSegment)
497                 {
498                         if (nextOutBufferIndex > headerLength || packetType == TdsPacketType.Cancel) {
499                                 // packet type
500                                 Store (0, (byte) packetType);
501                                 Store (1, (byte) (isLastSegment ? 1 : 0));
502                                 Store (2, (short) nextOutBufferIndex );
503                                 Store (4, (byte) 0);
504                                 Store (5, (byte) 0);
505                                 Store (6, (byte) (tdsVersion == TdsVersion.tds70 ? 0x1 : 0x0));
506                                 Store (7, (byte) 0);
507
508                                 stream.Write (outBuffer, 0, nextOutBufferIndex);
509                                 stream.Flush ();
510                                 packetsSent++;
511                         }
512                 }
513                 
514                 public void Skip (int i)
515                 {
516                         for ( ; i > 0; i--)
517                                 GetByte ();
518                 }
519
520                 public void StartPacket (TdsPacketType type)
521                 {
522                         if (type != TdsPacketType.Cancel && inBufferIndex != inBufferLength)
523                                 inBufferIndex = inBufferLength;
524
525                         packetType = type;
526                         nextOutBufferIndex = headerLength;
527                 }
528
529                 private void Store (int index, byte value)
530                 {
531                         outBuffer[index] = value;
532                 }               
533
534                 private void Store (int index, short value)
535                 {
536                         outBuffer[index] = (byte) (((byte) (value >> 8)) & 0xff);
537                         outBuffer[index + 1] = (byte) (((byte) (value >> 0)) & 0xff);
538                 }
539
540                 #endregion // Methods
541 #if NET_2_0
542                 #region Async Methods
543
544                 public IAsyncResult BeginReadPacket (AsyncCallback callback, object stateObject)
545                 {
546                         TdsAsyncResult ar = new TdsAsyncResult (callback, stateObject);
547
548                         stream.BeginRead (tmpBuf, 0, 8, new AsyncCallback(OnReadPacketCallback), ar);
549                         return ar;
550                 }
551                 
552                 /// <returns>Packet size in bytes</returns>
553                 public int EndReadPacket (IAsyncResult ar)
554                 {
555                         if (!ar.IsCompleted)
556                                 ar.AsyncWaitHandle.WaitOne ();
557                         return (int) ((TdsAsyncResult) ar).ReturnValue;
558                 }
559                 
560
561                 public void OnReadPacketCallback (IAsyncResult socketAsyncResult)
562                 {
563                         TdsAsyncResult ar = (TdsAsyncResult) socketAsyncResult.AsyncState;
564                         int nread = stream.EndRead (socketAsyncResult);
565                         
566                         while (nread < 8)
567                                 nread += stream.Read (tmpBuf, nread, 8 - nread);
568
569                         TdsPacketType packetType = (TdsPacketType) tmpBuf[0];
570                         if (packetType != TdsPacketType.Logon && packetType != TdsPacketType.Query && packetType != TdsPacketType.Reply) 
571                         {
572                                 throw new Exception (String.Format ("Unknown packet type {0}", tmpBuf[0]));
573                         }
574
575                         // figure out how many bytes are remaining in this packet.
576                         int len = Ntohs (tmpBuf, 2) - 8;
577
578                         if (len >= inBuffer.Length) 
579                                 inBuffer = new byte[len];
580
581                         if (len < 0) {
582                                 throw new Exception (String.Format ("Confused by a length of {0}", len));
583                         }
584
585                         GetPhysicalPacketData (len);
586                         int value = len + 8;
587                         ar.ReturnValue = ((object)value); // packet size
588                         ar.MarkComplete ();
589                 }
590                 
591                 #endregion // Async Methods
592 #endif // NET_2_0
593
594         }
595
596 }