2002-11-09 Tim Coleman <tim@timcoleman.com>
[mono.git] / mcs / class / Mono.Data.TdsClient / Mono.Data.TdsClient.Internal / TdsComm.cs
index 818753db0086491d02c4f263dbd4fd157f28746e..a199b058ede5ff8a2e40edcf756d8cd20e9c38ea 100644 (file)
@@ -8,6 +8,7 @@
 //
 
 using System;
+using System.Net;
 using System.Net.Sockets;
 using System.Text;
 using System.Threading;
@@ -20,6 +21,11 @@ namespace Mono.Data.TdsClient.Internal {
                NetworkStream stream;
                int packetSize;
                TdsPacketType packetType = TdsPacketType.None;
+               Encoding encoder;
+
+               string dataSource;
+               int commandTimeout;
+               int connectionTimeout;
 
                byte[] outBuffer;
                int outBufferLength;
@@ -37,21 +43,56 @@ namespace Mono.Data.TdsClient.Internal {
                int packetsSent = 0;
                int packetsReceived = 0;
 
+               Socket socket;
                TdsVersion tdsVersion;
+
+               ManualResetEvent connected = new ManualResetEvent (false);
                
                #endregion // Fields
                
                #region Constructors
                
-               public TdsComm (Socket socket, int packetSize, TdsVersion tdsVersion)
+               public TdsComm (string dataSource, int port, int packetSize, int timeout, TdsVersion tdsVersion)
                {
                        this.packetSize = packetSize;
+                       this.tdsVersion = tdsVersion;
+                       this.dataSource = dataSource;
+                       this.connectionTimeout = timeout;
+
+                       outBuffer = new byte[packetSize];
+                       inBuffer = new byte[packetSize];
+
+                       outBufferLength = packetSize;
+                       inBufferLength = packetSize;
+
+                       socket = new Socket (AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+                       IPHostEntry hostEntry = Dns.Resolve (dataSource);
+                       IPEndPoint endPoint;
+                       endPoint = new IPEndPoint (hostEntry.AddressList [0], port);
+
+                       connected.Reset ();
+                       IAsyncResult asyncResult = socket.BeginConnect (endPoint, new AsyncCallback (ConnectCallback), socket);
+
+                       if (timeout > 0 && !connected.WaitOne (new TimeSpan (0, 0, timeout), true))
+                               throw Tds.CreateTimeoutException (dataSource, "Open()");
+                       else if (timeout > 0 && !connected.WaitOne ())
+                               throw Tds.CreateTimeoutException (dataSource, "Open()");
+
                        stream = new NetworkStream (socket);
                }
                
                #endregion // Constructors
                
                #region Properties
+
+               public int CommandTimeout {
+                       get { return commandTimeout; }
+                       set { commandTimeout = value; }
+               }
+
+               internal Encoding Encoder {
+                       set { encoder = value; }
+               }
                
                public int PacketSize {
                        get { return packetSize; }
@@ -61,177 +102,104 @@ namespace Mono.Data.TdsClient.Internal {
                #endregion // Properties
                
                #region Methods
-               
-               public void StartPacket (TdsPacketType type)
-               {
-                       if (type != TdsPacketType.Cancel && inBufferIndex != inBufferLength)
-                       {
-                               // SAfe It's ok to throw this exception so that we will know there
-                               //      is a design flaw somewhere, but we should empty the buffer
-                               //      however. Otherwise the connection will never close (e.g. if
-                               //      SHOWPLAN_ALL is ON, a resultset will be returned by commit
-                               //      or rollback and we will never get rid of it). It's true
-                               //      that we should find a way to actually process these packets
-                               //      but for now, just dump them (we have thrown an exception).
-                               inBufferIndex = inBufferLength;
-                       }
-
-                       // Only one thread at a time can be building an outboudn packet.
-                       // This is primarily a concern with building cancel packets.
-                       //  XXX: as why should more than one thread work with the same tds-stream ??? would be fatal anyway
-
-                       Monitor.Enter (packetType);
-
-                       packetType = type;
-                       nextOutBufferIndex = headerLength;
-               }
 
-               public bool SomeThreadIsBuildingPacket ()
-               {
-                       return packetType != TdsPacketType.None;
-               }
-
-               public void AppendByte (byte b)
+               public void Append (byte b)
                {
                        if (nextOutBufferIndex == outBufferLength) {
-                               // If we have a full physical packet then ship it out to the
-                               // network.
                                SendPhysicalPacket (false);
                                nextOutBufferIndex = headerLength;
                        }
-       
-                       StoreByte( nextOutBufferIndex, b );
+                       Store (nextOutBufferIndex, b);
                        nextOutBufferIndex++;
                }       
                
-               public void AppendBytes (byte[] b)
+               public void Append (byte[] b)
                {
-                       AppendBytes (b, b.Length, (byte) 0);
+                       Append (b, b.Length, (byte) 0);
                }               
 
-               public void AppendBytes (byte[] b, int len, byte pad)
+               public void Append (byte[] b, int len, byte pad)
                {
                        int i = 0;
                        for ( ; i < b.Length && i < len; i++)
-                           AppendByte( b[i] );
+                           Append (b[i]);
 
                        for ( ; i < len; i++)
-                           AppendByte( pad );
+                           Append (pad);
                }       
 
-
-               public void AppendShort (short s)
+               public void Append (short s)
                {
-                       AppendByte ((byte) ((s >> 8) & 0xff));
-                       AppendByte ((byte) ((s >> 0) & 0xff));
+                       if (tdsVersion < TdsVersion.tds70) {
+                               Append ((byte) (((byte) (s >> 8)) & 0xff));
+                               Append ((byte) (((byte) (s >> 0)) & 0xff));
+                       }
+                       else
+                               Append (BitConverter.GetBytes (s));
                }
 
-               public void AppendTdsShort (short s)
+               public void Append (int i)
                {
-                       AppendByte ((byte ) ((s >> 0) & 0xff));
-                       AppendByte ((byte ) ((s >> 8) & 0xff));
+                       if (tdsVersion < TdsVersion.tds70) {
+                               Append ((byte) (((byte) (i >> 24)) & 0xff));
+                               Append ((byte) (((byte) (i >> 16)) & 0xff));
+                               Append ((byte) (((byte) (i >> 8)) & 0xff));
+                               Append ((byte) (((byte) (i >> 0)) & 0xff));
+                       } 
+                       else
+                               Append (BitConverter.GetBytes (i));
                }
 
-               public void AppendFlt8 (double value)
+               public void Append (string s)
                {
-                       long l = BitConverter.DoubleToInt64Bits (value);
-
-                       AppendByte ((byte) ((l >> 0) & 0xff));
-                       AppendByte ((byte) ((l >> 8) & 0xff));
-                       AppendByte ((byte) ((l >> 16) & 0xff));
-                       AppendByte ((byte) ((l >> 24) & 0xff));
-                       AppendByte ((byte) ((l >> 32) & 0xff));
-                       AppendByte ((byte) ((l >> 40) & 0xff));
-                       AppendByte ((byte) ((l >> 48) & 0xff));
-                       AppendByte ((byte) ((l >> 56) & 0xff));
-               }
+                       if (tdsVersion < TdsVersion.tds70) 
+                               Append (encoder.GetBytes (s));
+                       else 
+                               foreach (char c in s)
+                                       Append (BitConverter.GetBytes (c));
+               }       
 
-               public void AppendInt (int i)
+               // Appends with padding
+               public byte[] Append (string s, int len, byte pad)
                {
-                       AppendByte ((byte) ((i >> 24) & 0xff));
-                       AppendByte ((byte) ((i >> 16) & 0xff));
-                       AppendByte ((byte) ((i >> 8) & 0xff));
-                       AppendByte ((byte) ((i >> 0) & 0xff));
-               }
+                       if (s == null)
+                               return new byte[0];
 
-               public void AppendTdsInt (int i)
-               {
-                       AppendByte ((byte) ((i >> 0) & 0xff));
-                       AppendByte ((byte) ((i >> 8) & 0xff));
-                       AppendByte ((byte) ((i >> 16) & 0xff));
-                       AppendByte ((byte) ((i >> 24) & 0xff));
+                       byte[] result = encoder.GetBytes (s);
+                       Append (result, len, pad);
+                       return result;
                }
 
-
-               public void AppendInt64 (long i)
+               public void Append (double value)
                {
-                       AppendByte ((byte) ((i >> 56) & 0xff));
-                       AppendByte ((byte) ((i >> 48) & 0xff));
-                       AppendByte ((byte) ((i >> 40) & 0xff));
-                       AppendByte ((byte) ((i >> 32) & 0xff));
-                       AppendByte ((byte) ((i >> 24) & 0xff));
-                       AppendByte ((byte) ((i >> 16) & 0xff));
-                       AppendByte ((byte) ((i >> 8) & 0xff));
-                       AppendByte ((byte) ((i >> 0) & 0xff));
+                       Append (BitConverter.DoubleToInt64Bits (value));
                }
 
-               public void AppendChars (string s)
+               public void Append (long l)
                {
-                       foreach (char c in s)
-                       {
-                               byte b1 = (byte) (c & 0xFF);
-                               byte b2 = (byte) ((c >> 8) & 0xFF);
-                               AppendByte (b1);
-                               AppendByte (b2);
+                       if (tdsVersion < TdsVersion.tds70) {
+                               Append ((byte) (((byte) (l >> 56)) & 0xff));
+                               Append ((byte) (((byte) (l >> 48)) & 0xff));
+                               Append ((byte) (((byte) (l >> 40)) & 0xff));
+                               Append ((byte) (((byte) (l >> 32)) & 0xff));
+                               Append ((byte) (((byte) (l >> 24)) & 0xff));
+                               Append ((byte) (((byte) (l >> 16)) & 0xff));
+                               Append ((byte) (((byte) (l >> 8)) & 0xff));
+                               Append ((byte) (((byte) (l >> 0)) & 0xff));
+                       }
+                       else {
+                               Append (BitConverter.GetBytes (l));
                        }
-               }       
-
-               public void SendPacket ()
-               {
-                       Monitor.Pulse (packetType);
-                       SendPhysicalPacket (true);
-                       nextOutBufferIndex = 0;
-                       packetType = TdsPacketType.None;
-                       Monitor.Exit (packetType);
-               }
-               
-               private void StoreByte (int index, byte value)
-               {
-                       outBuffer[index] = value;
-               }               
-
-               private void StoreShort (int index, short s)
-               {
-                       outBuffer[index] = (byte) ((s >> 8) & 0xff);
-                       outBuffer[index + 1] = (byte) ((s >> 0) & 0xff);
                }
 
-               private void SendPhysicalPacket (bool isLastSegment)
+               private void ConnectCallback (IAsyncResult ar)
                {
-                       if (nextOutBufferIndex > headerLength || packetType == TdsPacketType.Cancel) {
-                               // packet type
-                               StoreByte (0, (byte) ((byte) packetType & 0xff));
-                               StoreByte (1, isLastSegment ? (byte) 1 : (byte) 0);
-                               StoreShort (2, (short) nextOutBufferIndex );
-                               StoreByte (4, (byte) 0);
-                               StoreByte (5, (byte) 0);
-                               StoreByte (6, (byte) (tdsVersion == TdsVersion.tds70 ? 1 : 0));
-                               StoreByte (7, (byte) 0);
-
-                               stream.Write (outBuffer, 0, nextOutBufferIndex);
-                               packetsSent++;
+                       Socket s = (Socket) ar.AsyncState;
+                       if (Poll (s, connectionTimeout, SelectMode.SelectWrite)) {
+                               socket.EndConnect (ar);
+                               connected.Set ();
                        }
-               }
-               
-               public byte Peek ()
-               {
-                       // If out of data, read another physical packet.
-                       if (inBufferIndex >= inBufferLength)
-                               GetPhysicalPacket ();
-
-                       return inBuffer[inBufferIndex];
-               }
-
+                }
 
                public byte GetByte ()
                {
@@ -280,32 +248,31 @@ namespace Mono.Data.TdsClient.Internal {
 
                public string GetString (int len)
                {
-                       if (tdsVersion == TdsVersion.tds70) {
+                       if (tdsVersion == TdsVersion.tds70) 
+                               return GetString (len, true);
+                       else
+                               return GetString (len, false);
+               }
+
+               public string GetString (int len, bool wide)
+               {
+                       if (wide) {
                                char[] chars = new char[len];
                                for (int i = 0; i < len; ++i) {
-                                       int lo = GetByte () & 0xFF;
-                                       int hi = GetByte () & 0xFF;
+                                       int lo = ((byte) GetByte ()) & 0xFF;
+                                       int hi = ((byte) GetByte ()) & 0xFF;
                                        chars[i] = (char) (lo | ( hi << 8));
                                }
                                return new String (chars);
                        }
                        else {
-                               byte[] result = GetBytes (len, false);
-                               StringBuilder sb = new StringBuilder ();
-                               foreach (byte b in result)
-                                       sb.Append (b);
-                               return sb.ToString ();
+                               byte[] result = new byte[len + 1];
+                               Array.Copy (GetBytes (len, false), result, len);
+                               result[len] = (byte) 0;
+                               return (encoder.GetString (result));
                        }
                }
 
-               public void Skip (int i)
-               {
-                       for ( ; i > 0; i--)
-                               GetByte ();
-               }
-               // skip()
-
-
                public int GetNetShort ()
                {
                        byte[] tmp = new byte[2];
@@ -314,50 +281,45 @@ namespace Mono.Data.TdsClient.Internal {
                        return Ntohs (tmp, 0);
                }
 
-               public int GetTdsShort ()
+               public short GetTdsShort ()
                {
-                       int lo = ((int) GetByte () & 0xff);
-                       int hi = ((int) GetByte () & 0xff) << 8;
-                       return lo | hi;
+                       byte[] input = new byte[2];
+
+                       for (int i = 0; i < 2; i += 1)
+                               input[i] = GetByte ();
+
+                       return (BitConverter.ToInt16 (input, 0));
                }
 
 
                public int GetTdsInt ()
                {
-                       int result;
-
-                       int b1 = ((int) GetByte () & 0xff);
-                       int b2 = ((int) GetByte () & 0xff) << 8;
-                       int b3 = ((int) GetByte () & 0xff) << 16;
-                       int b4 = ((int) GetByte () & 0xff) << 24;
-
-                       result = b4 | b3 | b2 | b1;
-
-                       return result;
+                       byte[] input = new byte[4];
+                       for (int i = 0; i < 4; i += 1)
+                               input[i] = GetByte ();
+                       return (BitConverter.ToInt32 (input, 0));
                }
 
                public long GetTdsInt64 ()
                {
-                       long b1 = ((long) GetByte () & 0xff);
-                       long b2 = ((long) GetByte () & 0xff) << 8;
-                       long b3 = ((long) GetByte () & 0xff) << 16;
-                       long b4 = ((long) GetByte () & 0xff) << 24;
-                       long b5 = ((long) GetByte () & 0xff) << 32;
-                       long b6 = ((long) GetByte () & 0xff) << 40;
-                       long b7 = ((long) GetByte () & 0xff) << 48;
-                       long b8 = ((long) GetByte () & 0xff) << 56;
-                       return b1 | b2 | b3 | b4 | b5 | b6 | b7 | b8;
+                       byte[] input = new byte[8];
+                       for (int i = 0; i < 8; i += 1)
+                               input[i] = GetByte ();
+                       return (BitConverter.ToInt64 (input, 0));
                }
 
                private void GetPhysicalPacket ()
                {
+                       int nread = 0;
+
                        // read the header
-                       for (int nread = 0; nread < 8; ) 
+                       while (nread < 8)
                                nread += stream.Read (tmpBuf, nread, 8 - nread);
 
                        TdsPacketType packetType = (TdsPacketType) tmpBuf[0];
-                       if (packetType != TdsPacketType.Logon && packetType != TdsPacketType.Query && packetType != TdsPacketType.Reply) {
-                               //throw new TdsUnknownPacketType (packetType, tmpBuf);
+                       if (packetType != TdsPacketType.Logon && packetType != TdsPacketType.Query && packetType != TdsPacketType.Reply) 
+                       {
+                               throw new Exception (String.Format ("Unknown packet type {0}", tmpBuf[0]));
                        }
 
                        // figure out how many bytes are remaining in this packet.
@@ -367,12 +329,14 @@ namespace Mono.Data.TdsClient.Internal {
                                inBuffer = new byte[len];
 
                        if (len < 0) {
-                               //throw new TdsException ("Confused by a length of " + len);
+                               throw new Exception (String.Format ("Confused by a length of {0}", len));
                        }
 
                        // now get the data
-                       for (int nread = 0; nread < len; )
+                       nread = 0;
+                       while (nread < len) {
                                nread += stream.Read (inBuffer, nread, len - nread);
+                       }
 
                        packetsReceived++;
 
@@ -389,6 +353,105 @@ namespace Mono.Data.TdsClient.Internal {
                        return hi | lo;
                        // return an int since we really want an _unsigned_
                }               
+
+               public byte Peek ()
+               {
+                       // If out of data, read another physical packet.
+                       if (inBufferIndex >= inBufferLength)
+                               GetPhysicalPacket ();
+
+                       return inBuffer[inBufferIndex];
+               }
+
+               public bool Poll (int seconds, SelectMode selectMode)
+               {
+                       return Poll (socket, seconds, selectMode);
+               }
+
+               private bool Poll (Socket s, int seconds, SelectMode selectMode)
+               {
+                       long uSeconds = seconds * 1000000;
+                       bool bState = false;
+
+                       while (uSeconds > (long) Int32.MaxValue) {
+                               bState = s.Poll (Int32.MaxValue, selectMode);
+                               if (bState) 
+                                       return true;
+                               uSeconds -= Int32.MaxValue;
+                       }
+                       return s.Poll ((int) uSeconds, selectMode);
+               }
+
+               internal void ResizeOutBuf (int newSize)
+               {
+                       if (newSize > outBufferLength) {
+                               byte[] newBuf = new byte [newSize];
+                               Array.Copy (outBuffer, 0, newBuf, 0, outBufferLength);
+                               outBufferLength = newSize;
+                               outBuffer = newBuf;
+                       }
+               }
+
+               public void SendPacket ()
+               {
+                       SendPhysicalPacket (true);
+                       nextOutBufferIndex = 0;
+                       packetType = TdsPacketType.None;
+               }
+               
+               private void SendPhysicalPacket (bool isLastSegment)
+               {
+                       if (nextOutBufferIndex > headerLength || packetType == TdsPacketType.Cancel) {
+                               // packet type
+                               Store (0, (byte) packetType);
+                               Store (1, (byte) (isLastSegment ? 1 : 0));
+                               Store (2, (short) nextOutBufferIndex );
+                               Store (4, (byte) 0);
+                               Store (5, (byte) 0);
+                               Store (6, (byte) (tdsVersion == TdsVersion.tds70 ? 0x1 : 0x0));
+                               Store (7, (byte) 0);
+
+                               stream.Write (outBuffer, 0, nextOutBufferIndex);
+                               stream.Flush ();
+                               packetsSent++;
+                       }
+               }
+               
+               public void Skip (int i)
+               {
+                       for ( ; i > 0; i--)
+                               GetByte ();
+               }
+
+               public void StartPacket (TdsPacketType type)
+               {
+                       if (type != TdsPacketType.Cancel && inBufferIndex != inBufferLength)
+                       {
+                               // SAfe It's ok to throw this exception so that we will know there
+                               //      is a design flaw somewhere, but we should empty the buffer
+                               //      however. Otherwise the connection will never close (e.g. if
+                               //      SHOWPLAN_ALL is ON, a resultset will be returned by commit
+                               //      or rollback and we will never get rid of it). It's true
+                               //      that we should find a way to actually process these packets
+                               //      but for now, just dump them (we have thrown an exception).
+                               inBufferIndex = inBufferLength;
+                       }
+
+                       packetType = type;
+                       nextOutBufferIndex = headerLength;
+               }
+
+               private void Store (int index, byte value)
+               {
+                       outBuffer[index] = value;
+               }               
+
+               private void Store (int index, short value)
+               {
+                       outBuffer[index] = (byte) (((byte) (value >> 8)) & 0xff);
+                       outBuffer[index + 1] = (byte) (((byte) (value >> 0)) & 0xff);
+               }
+
                #endregion // Methods
        }