2002-10-24 Miguel de Icaza <miguel@ximian.com>
[mono.git] / mcs / class / Mono.Data.Tds / Mono.Data.Tds.Protocol / TdsComm.cs
1 //
2 // Mono.Data.TdsClient.Internal.TdsComm.cs
3 //
4 // Author:
5 //   Tim Coleman (tim@timcoleman.com)
6 //
7 // Copyright (C) 2002 Tim Coleman
8 //
9
10 using System;
11 using System.Net.Sockets;
12 using System.Text;
13 using System.Threading;
14
namespace Mono.Data.TdsClient.Internal {
	/// <summary>
	/// Packet-level transport for the TDS protocol over a socket.
	/// Outgoing data is accumulated in <c>outBuffer</c> and written as physical
	/// packets, each prefixed with an 8-byte header. Incoming physical packets
	/// are read into <c>inBuffer</c> one at a time, and the Get* methods consume
	/// them as a continuous byte stream.
	/// </summary>
	internal sealed class TdsComm
	{
		#region Fields

		NetworkStream stream;
		int packetSize;
		// Type of the logical packet currently being built; None when idle.
		TdsPacketType packetType = TdsPacketType.None;
		Encoding encoder;

		byte[] outBuffer;
		int outBufferLength;
		int nextOutBufferIndex = 0;

		byte[] inBuffer;
		int inBufferLength;
		int inBufferIndex = 0;

		// Size of the header written before / read from every physical packet.
		const int headerLength = 8;

		// Scratch buffer for incoming packet headers.
		byte[] tmpBuf = new byte[headerLength];
		// Reusable result buffer for small non-exclusive GetBytes calls.
		byte[] resBuffer = new byte[256];

		// Counters are only incremented within this class.
		int packetsSent = 0;
		int packetsReceived = 0;

		TdsVersion tdsVersion;

		#endregion // Fields

		#region Constructors

		/// <summary>
		/// Wraps <paramref name="socket"/> in a <see cref="NetworkStream"/> and
		/// allocates in/out buffers of <paramref name="packetSize"/> bytes.
		/// </summary>
		/// <param name="encoder">Encoding used for non-UTF-16 string data.</param>
		/// <param name="socket">Connected socket to the server.</param>
		/// <param name="packetSize">Initial physical packet / buffer size in bytes.</param>
		/// <param name="tdsVersion">Protocol version; selects byte order and string encoding.</param>
		public TdsComm (Encoding encoder, Socket socket, int packetSize, TdsVersion tdsVersion)
		{
			this.encoder = encoder;
			this.packetSize = packetSize;
			this.tdsVersion = tdsVersion;

			outBuffer = new byte[packetSize];
			inBuffer = new byte[packetSize];

			outBufferLength = packetSize;
			// Starts "full" so the first StartPacket discards the (empty) input
			// buffer and the first read fetches a real packet.
			inBufferLength = packetSize;
			stream = new NetworkStream (socket);
		}

		#endregion // Constructors

		#region Properties

		/// <summary>
		/// Negotiated packet size. Setting this only records the value; it does
		/// not resize the buffers (see <see cref="ResizeOutBuf"/>).
		/// </summary>
		public int PacketSize {
			get { return packetSize; }
			set { packetSize = value; }
		}

		#endregion // Properties

		#region Methods

		/// <summary>
		/// Grows the outgoing buffer to <paramref name="newSize"/> bytes,
		/// preserving its contents. Never shrinks the buffer.
		/// </summary>
		internal void ResizeOutBuf (int newSize)
		{
			if (newSize > outBufferLength) {
				byte[] newBuf = new byte[newSize];
				Array.Copy (outBuffer, 0, newBuf, 0, outBufferLength);
				outBufferLength = newSize;
				outBuffer = newBuf;
			}
		}

		/// <summary>
		/// Begins building a new logical packet of the given type.
		/// </summary>
		public void StartPacket (TdsPacketType type)
		{
			if (type != TdsPacketType.Cancel && inBufferIndex != inBufferLength) {
				// Unread response data is still buffered from a previous
				// exchange. This indicates a design flaw elsewhere, but the data
				// must be dropped or the connection can never be reused/closed
				// (e.g. with SHOWPLAN_ALL ON, COMMIT/ROLLBACK return a result
				// set nobody consumes). Ideally those packets would be processed.
				// NOTE: no exception is thrown here, and only the currently
				// buffered physical packet is discarded; any further packets of
				// that response still on the wire are not drained.
				inBufferIndex = inBufferLength;
			}

			packetType = type;
			nextOutBufferIndex = headerLength;
		}

		/// <summary>
		/// True between <see cref="StartPacket"/> and <see cref="SendPacket"/>.
		/// No synchronization is performed by this class.
		/// </summary>
		public bool SomeThreadIsBuildingPacket ()
		{
			return packetType != TdsPacketType.None;
		}

		/// <summary>
		/// Appends one byte to the current packet. When the out buffer is full,
		/// it is sent as a non-final physical packet first.
		/// </summary>
		public void Append (byte b)
		{
			if (nextOutBufferIndex == outBufferLength) {
				SendPhysicalPacket (false);
				nextOutBufferIndex = headerLength;
			}
			StoreByte (nextOutBufferIndex, b);
			nextOutBufferIndex++;
		}

		/// <summary>Appends every byte of <paramref name="b"/>.</summary>
		public void Append (byte[] b)
		{
			Append (b, b.Length, (byte) 0);
		}

		/// <summary>
		/// Appends exactly <paramref name="len"/> bytes: the first bytes of
		/// <paramref name="b"/>, right-padded with <paramref name="pad"/> if
		/// <paramref name="b"/> is shorter (truncated if longer).
		/// </summary>
		public void Append (byte[] b, int len, byte pad)
		{
			int i = 0;
			for ( ; i < b.Length && i < len; i++)
				Append (b[i]);

			for ( ; i < len; i++)
				Append (pad);
		}

		/// <summary>
		/// Appends a 16-bit value: big-endian before TDS 7.0, little-endian from 7.0 on.
		/// </summary>
		public void Append (short s)
		{
			if (tdsVersion < TdsVersion.tds70)
				AppendBigEndian (s, 2);
			else
				AppendLittleEndian (s, 2);
		}

		/// <summary>
		/// Appends a 32-bit value: big-endian before TDS 7.0, little-endian from 7.0 on.
		/// </summary>
		public void Append (int i)
		{
			if (tdsVersion < TdsVersion.tds70)
				AppendBigEndian (i, 4);
			else
				AppendLittleEndian (i, 4);
		}

		/// <summary>
		/// Appends a string: encoded with <c>encoder</c> before TDS 7.0, as
		/// UTF-16LE code units from 7.0 on.
		/// </summary>
		public void Append (string s)
		{
			if (tdsVersion < TdsVersion.tds70) {
				Append (encoder.GetBytes (s));
				return;
			}

			foreach (char c in s)
				AppendLittleEndian (c, 2);
		}

		/// <summary>
		/// Appends <paramref name="s"/> encoded with <c>encoder</c>, padded or
		/// truncated to <paramref name="len"/> bytes.
		/// </summary>
		/// <returns>The full encoded bytes of <paramref name="s"/> (before padding/truncation).</returns>
		public byte[] Append (string s, int len, byte pad)
		{
			byte[] result = encoder.GetBytes (s);
			Append (result, len, pad);
			return result;
		}

		/// <summary>Appends the IEEE 754 bit pattern of <paramref name="value"/> as a 64-bit value.</summary>
		public void Append (double value)
		{
			Append (BitConverter.DoubleToInt64Bits (value));
		}

		/// <summary>
		/// Appends a 64-bit value: big-endian before TDS 7.0, little-endian from 7.0 on.
		/// </summary>
		public void Append (long l)
		{
			if (tdsVersion < TdsVersion.tds70)
				AppendBigEndian (l, 8);
			else
				AppendLittleEndian (l, 8);
		}

		// Appends the low byteCount bytes of value, most significant first.
		// Masking before the cast keeps this safe under /checked.
		private void AppendBigEndian (long value, int byteCount)
		{
			for (int shift = 8 * (byteCount - 1); shift >= 0; shift -= 8)
				Append ((byte) ((value >> shift) & 0xff));
		}

		// Appends the low byteCount bytes of value, least significant first.
		// Explicit, so output does not depend on host byte order (unlike BitConverter).
		private void AppendLittleEndian (long value, int byteCount)
		{
			for (int shift = 0; shift < 8 * byteCount; shift += 8)
				Append ((byte) ((value >> shift) & 0xff));
		}

		/// <summary>
		/// Sends the remaining buffered data as the final physical packet of the
		/// current logical packet and marks the connection as idle.
		/// </summary>
		public void SendPacket ()
		{
			SendPhysicalPacket (true);
			nextOutBufferIndex = 0;
			packetType = TdsPacketType.None;
		}

		private void StoreByte (int index, byte value)
		{
			outBuffer[index] = value;
		}

		// Stores the low 16 bits of value big-endian (network order), as the
		// TDS header length field requires.
		private void StoreShort (int index, int value)
		{
			outBuffer[index] = (byte) ((value >> 8) & 0xff);
			outBuffer[index + 1] = (byte) (value & 0xff);
		}

		// Fills in the 8-byte header and writes the buffered packet to the
		// stream. Header-only packets are suppressed unless this is a Cancel.
		private void SendPhysicalPacket (bool isLastSegment)
		{
			if (nextOutBufferIndex > headerLength || packetType == TdsPacketType.Cancel) {
				StoreByte (0, (byte) packetType);
				StoreByte (1, (byte) (isLastSegment ? 1 : 0));
				// Total length of this physical packet, header included.
				StoreShort (2, nextOutBufferIndex);
				StoreByte (4, (byte) 0);
				StoreByte (5, (byte) 0);
				StoreByte (6, (byte) (tdsVersion == TdsVersion.tds70 ? 0x1 : 0x0));
				StoreByte (7, (byte) 0);

				stream.Write (outBuffer, 0, nextOutBufferIndex);
				stream.Flush ();
				packetsSent++;
			}
		}

		// Reads physical packets until at least one unread byte is buffered.
		// Loops so that header-only (zero-length) packets are skipped instead of
		// exposing stale buffer contents.
		private void EnsureData ()
		{
			while (inBufferIndex >= inBufferLength)
				GetPhysicalPacket ();
		}

		/// <summary>Returns the next incoming byte without consuming it.</summary>
		public byte Peek ()
		{
			EnsureData ();
			return inBuffer[inBufferIndex];
		}

		/// <summary>Consumes and returns the next incoming byte.</summary>
		public byte GetByte ()
		{
			EnsureData ();
			return inBuffer[inBufferIndex++];
		}

		/// <summary>
		/// Consumes <paramref name="len"/> incoming bytes, spanning physical
		/// packets as needed.
		/// </summary>
		/// <param name="len">Number of bytes to read.</param>
		/// <param name="exclusiveBuffer">
		/// When false and <paramref name="len"/> is at most 16K, the returned
		/// array is a shared internal buffer that may be longer than
		/// <paramref name="len"/> and is overwritten by the next such call.
		/// </param>
		public byte[] GetBytes (int len, bool exclusiveBuffer)
		{
			byte[] result;

			// Do not keep an internal result buffer larger than 16k;
			// that would needlessly pin memory.
			if (exclusiveBuffer || len > 16384) {
				result = new byte[len];
			} else {
				if (resBuffer.Length < len)
					resBuffer = new byte[len];
				result = resBuffer;
			}

			int copied = 0;
			while (copied < len) {
				EnsureData ();

				int avail = Math.Min (inBufferLength - inBufferIndex, len - copied);
				Array.Copy (inBuffer, inBufferIndex, result, copied, avail);
				copied += avail;
				inBufferIndex += avail;
			}

			return result;
		}

		/// <summary>
		/// Reads a string of <paramref name="len"/> characters (UTF-16LE, TDS 7.0
		/// and later, matching <see cref="Append(string)"/>) or
		/// <paramref name="len"/> bytes decoded with <c>encoder</c> (earlier versions).
		/// </summary>
		public string GetString (int len)
		{
			if (tdsVersion >= TdsVersion.tds70) {
				char[] chars = new char[len];
				for (int i = 0; i < len; ++i)
					chars[i] = (char) ReadLittleEndian (2);
				return new String (chars);
			}

			// GetBytes may hand back the shared resBuffer, which can be longer
			// than len, so decode exactly len bytes. (Previously a trailing 0
			// byte was decoded too, leaving a '\0' at the end of every string.)
			byte[] bytes = GetBytes (len, false);
			return encoder.GetString (bytes, 0, len);
		}

		/// <summary>Consumes and discards <paramref name="i"/> incoming bytes.</summary>
		public void Skip (int i)
		{
			for ( ; i > 0; i--)
				GetByte ();
		}

		/// <summary>Reads an unsigned 16-bit big-endian (network order) value.</summary>
		public int GetNetShort ()
		{
			int hi = GetByte ();
			int lo = GetByte ();
			return (hi << 8) | lo;
		}

		/// <summary>Reads a signed 16-bit little-endian value.</summary>
		public short GetTdsShort ()
		{
			return unchecked ((short) ReadLittleEndian (2));
		}

		/// <summary>Reads a signed 32-bit little-endian value.</summary>
		public int GetTdsInt ()
		{
			return unchecked ((int) ReadLittleEndian (4));
		}

		/// <summary>Reads a signed 64-bit little-endian value.</summary>
		public long GetTdsInt64 ()
		{
			return ReadLittleEndian (8);
		}

		// Assembles byteCount incoming bytes, least significant first. Explicit
		// byte order, so results do not depend on the host (unlike BitConverter).
		private long ReadLittleEndian (int byteCount)
		{
			long value = 0;
			for (int shift = 0; shift < 8 * byteCount; shift += 8)
				value |= (long) GetByte () << shift;
			return value;
		}

		// Reads one physical packet (header + payload) into inBuffer and resets
		// the read position.
		private void GetPhysicalPacket ()
		{
			ReadFully (tmpBuf, headerLength);

			TdsPacketType incomingType = (TdsPacketType) tmpBuf[0];
			if (incomingType != TdsPacketType.Logon && incomingType != TdsPacketType.Query && incomingType != TdsPacketType.Reply) {
				throw new TdsException (String.Format ("Unknown packet type {0}", tmpBuf[0]));
			}

			// Header length field counts the header itself; strip it.
			int len = Ntohs (tmpBuf, 2) - headerLength;
			if (len < 0) {
				throw new TdsException (String.Format ("Confused by a length of {0}", len));
			}

			if (len > inBuffer.Length)
				inBuffer = new byte[len];

			ReadFully (inBuffer, len);
			packetsReceived++;

			inBufferLength = len;
			inBufferIndex = 0;
		}

		// Reads exactly count bytes into buffer[0..count). Stream.Read returns 0
		// when the peer closes the connection; without this check the caller
		// would spin forever.
		private void ReadFully (byte[] buffer, int count)
		{
			int nread = 0;
			while (nread < count) {
				int n = stream.Read (buffer, nread, count - nread);
				if (n <= 0)
					throw new TdsException ("Connection closed by server while reading a packet");
				nread += n;
			}
		}

		// Decodes an unsigned 16-bit big-endian value at buf[offset]. Returns an
		// int because the value is unsigned.
		private static int Ntohs (byte[] buf, int offset)
		{
			int hi = (buf[offset] & 0xff) << 8;
			int lo = buf[offset + 1] & 0xff;
			return hi | lo;
		}

		#endregion // Methods
	}
}