Merge pull request #1163 from AerisG222/more_mvc_fixes
[mono.git] / mcs / class / Mono.Data.Tds / Mono.Data.Tds.Protocol / Tds.cs
index 76a810ab80a5f151312757dfbf400143a57710c1..f06932c802620d1fed8f3fcb74fb9d5e0fb1ee26 100644 (file)
@@ -41,7 +41,9 @@ using System.ComponentModel;
 using System.Diagnostics;
 using System.Net.Sockets;
 using System.Globalization;
+using System.Security;
 using System.Text;
+using System.Runtime.InteropServices;
 
 namespace Mono.Data.Tds.Protocol
 {
@@ -64,6 +66,9 @@ namespace Mono.Data.Tds.Protocol
                int databaseMajorVersion;
                CultureInfo locale = CultureInfo.InvariantCulture;
 
+               readonly int lifeTime;
+               readonly DateTime created = DateTime.Now;
+
                string charset;
                string language;
 
@@ -147,7 +152,7 @@ namespace Mono.Data.Tds.Protocol
                        get { return dataSource; }
                }
 
-               public bool IsConnected {
+               public virtual bool IsConnected {
                        get { return connected && comm != null && comm.IsConnected (); }
                        set { connected = value; }
                }
@@ -416,12 +421,23 @@ namespace Mono.Data.Tds.Protocol
                #region Constructors
 
                public Tds (string dataSource, int port, int packetSize, int timeout, TdsVersion tdsVersion)
+                       : this  (dataSource, port, packetSize, timeout, 0, tdsVersion)
+               {
+               }
+
+               public Tds (string dataSource, int port, int packetSize, int timeout, int lifeTime, TdsVersion tdsVersion)
                {
                        this.tdsVersion = tdsVersion;
                        this.packetSize = packetSize;
                        this.dataSource = dataSource;
                        this.columns = new TdsDataColumnCollection ();
+                       this.lifeTime = lifeTime;
 
+                       InitComm (port, timeout);
+               }
+
+               protected virtual void InitComm (int port, int timeout)
+               {
                        comm = new TdsComm (dataSource, port, packetSize, timeout, tdsVersion);
                }
 
@@ -429,6 +445,14 @@ namespace Mono.Data.Tds.Protocol
 
                #region Public Methods
 
+               internal bool Expired {
+                       get {
+                               if (lifeTime == 0)
+                                       return false;
+                               return DateTime.Now > (created + TimeSpan.FromSeconds (lifeTime));
+                       }
+               }
+
                internal protected void InitExec () 
                {
                        // clean up
@@ -477,7 +501,7 @@ namespace Mono.Data.Tds.Protocol
                        return new TdsTimeoutException (0, 0, message, -2, method, dataSource, "Mono TdsClient Data Provider", 0);
                }
 
-               public void Disconnect ()
+               public virtual void Disconnect ()
                {
                        try {
                                comm.StartPacket (TdsPacketType.Logoff);
@@ -1441,13 +1465,12 @@ namespace Mono.Data.Tds.Protocol
                        // 0x0200       Negotiate NTLM
                        // 0x8000       Negotiate Always Sign
 
-                       Type3Message t3 = new Type3Message ();
-                       t3.Challenge = t2.Nonce;
+                       Type3Message t3 = new Type3Message (t2);
                        
                        t3.Domain = this.connectionParms.DefaultDomain;
                        t3.Host = this.connectionParms.Hostname;
                        t3.Username = this.connectionParms.User;
-                       t3.Password = this.connectionParms.Password;
+                       t3.Password = GetPlainPassword(this.connectionParms.Password);
 
                        Comm.StartPacket (TdsPacketType.SspAuth); // 0x11
                        Comm.Append (t3.GetBytes ());
@@ -1898,6 +1921,20 @@ namespace Mono.Data.Tds.Protocol
                        comm.Skip(4);
                }
 
+               public static string GetPlainPassword(SecureString secPass)
+               {
+                       IntPtr plainString = IntPtr.Zero;
+                       try
+                       {
+                               plainString = Marshal.SecureStringToGlobalAllocUnicode(secPass);
+                               return Marshal.PtrToStringUni(plainString);
+                       }
+                       finally
+                       {
+                               Marshal.ZeroFreeGlobalAllocUnicode(plainString);
+                       }
+               }
+
                #endregion // Private Methods
 
 #if NET_2_0