2008-11-05 Francisco Figueiredo Jr. <francisco@npgsql.org>
[mono.git] / mcs / class / Npgsql / Npgsql / SSPIHandler.cs
1 #if WINDOWS && UNMANAGED
2
3 using System;
4 using System.IO;
5 using System.Runtime.InteropServices;
6 using System.ComponentModel;
7
namespace Npgsql
{
    /// <summary>
    /// A class to handle everything associated with SSPI authentication: it acquires an
    /// outbound "negotiate" credential handle for the current user and drives
    /// InitializeSecurityContext one handshake round at a time.
    /// </summary>
    internal class SSPIHandler : IDisposable
    {
        #region constants and structs

        private const int SECBUFFER_VERSION = 0;
        private const int SECBUFFER_TOKEN = 2;
        private const int SEC_E_OK = 0x00000000;
        private const int SEC_I_CONTINUE_NEEDED = 0x00090312;
        private const int ISC_REQ_ALLOCATE_MEMORY = 0x00000100;
        private const int SECURITY_NETWORK_DREP = 0x00000000;
        private const int SECPKG_CRED_OUTBOUND = 0x00000002;

        /// <summary>
        /// Native CredHandle / CtxtHandle. Both members are ULONG_PTR natively, so they must be
        /// pointer-sized: with 32-bit ints the struct is too small in a 64-bit process and SSPI
        /// would write past it.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        private struct SecHandle
        {
            public IntPtr dwLower;
            public IntPtr dwUpper;
        }

        [StructLayout(LayoutKind.Sequential)]
        private struct SecBuffer
        {
            public int cbBuffer;
            public int BufferType;
            public IntPtr pvBuffer;
        }

        /// <summary>
        /// SecBufferDesc; this class always points pBuffer at exactly one SecBuffer.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        private struct SecBufferDesc
        {
            public int ulVersion;
            public int cBuffers;
            public IntPtr pBuffer;
        }

        #endregion

        #region p/invoke methods

        // ptsExpiry is a TimeStamp (64-bit SECURITY_INTEGER), hence "out long".
        [DllImport("Secur32.dll")]
        private extern static int AcquireCredentialsHandle(
            string pszPrincipal,
            string pszPackage,
            int fCredentialUse,
            IntPtr pvLogonID,
            IntPtr pAuthData,
            IntPtr pGetKeyFn,
            IntPtr pvGetKeyArgument,
            ref SecHandle phCredential,
            out long ptsExpiry
        );

        // Overload used for every round after the first (existing context + server token).
        // pOutput is an in/out descriptor that the caller initializes, hence "ref".
        [DllImport("secur32", CharSet = CharSet.Auto, SetLastError = true)]
        static extern int InitializeSecurityContext(
            ref SecHandle phCredential,
            ref SecHandle phContext,
            string pszTargetName,
            int fContextReq,
            int Reserved1,
            int TargetDataRep,
            ref SecBufferDesc pInput,
            int Reserved2,
            out SecHandle phNewContext,
            ref SecBufferDesc pOutput,
            out int pfContextAttr,
            out long ptsExpiry);

        // Overload used for the first round (no context and no input token yet).
        [DllImport("secur32", CharSet = CharSet.Auto, SetLastError = true)]
        static extern int InitializeSecurityContext(
            ref SecHandle phCredential,
            IntPtr phContext,
            string pszTargetName,
            int fContextReq,
            int Reserved1,
            int TargetDataRep,
            IntPtr pInput,
            int Reserved2,
            out SecHandle phNewContext,
            ref SecBufferDesc pOutput,
            out int pfContextAttr,
            out long ptsExpiry);

        [DllImport("Secur32.dll")]
        private extern static int FreeContextBuffer(
            IntPtr pvContextBuffer
        );

        [DllImport("Secur32.dll")]
        private extern static int FreeCredentialsHandle(
            ref SecHandle phCredential
        );

        [DllImport("Secur32.dll")]
        private extern static int DeleteSecurityContext(
            ref SecHandle phContext
        );

        #endregion

        private bool disposed;
        private readonly string sspitarget;
        private SecHandle sspicred;
        private bool sspicred_set;
        private SecHandle sspictx;
        private bool sspictx_set;

        /// <summary>
        /// Acquires an outbound "negotiate" credential handle and stores the SPN
        /// "<paramref name="krbsrvname"/>/<paramref name="pghost"/>" as the target for later rounds.
        /// </summary>
        /// <param name="pghost">Server host name; must not be null.</param>
        /// <param name="krbsrvname">Service name part of the SPN; null is treated as an empty string.</param>
        /// <exception cref="ArgumentNullException"><paramref name="pghost"/> is null.</exception>
        /// <exception cref="Win32Exception">AcquireCredentialsHandle returned a failure status.</exception>
        public SSPIHandler(string pghost, string krbsrvname)
        {
            if (pghost == null)
                throw new ArgumentNullException("pghost");
            if (krbsrvname == null)
                krbsrvname = String.Empty;
            sspitarget = String.Format("{0}/{1}", krbsrvname, pghost);

            long expire;
            int status = AcquireCredentialsHandle(
                "",
                "negotiate",
                SECPKG_CRED_OUTBOUND,
                IntPtr.Zero,
                IntPtr.Zero,
                IntPtr.Zero,
                IntPtr.Zero,
                ref sspicred,
                out expire
            );
            if (status != SEC_E_OK)
            {
                // SSPI reports failures through its return value, not through the thread's
                // last Win32 error, so the status itself is what must be described.
                throw new Win32Exception(status);
            }
            sspicred_set = true;
        }

        /// <summary>
        /// Performs one round of the SSPI handshake and returns the output token converted to a
        /// string with ASCII encoding.
        /// </summary>
        /// <param name="authData">
        /// Token received from the server; may be null only on the first call.
        /// </param>
        /// <returns>The output token as a string, or String.Empty when SSPI produced no token.</returns>
        /// <remarks>
        /// NOTE(review): Encoding.ASCII replaces every byte above 0x7F with '?', so binary tokens
        /// do not survive this conversion. Callers able to send raw bytes should use
        /// <see cref="ContinueBytes"/> instead.
        /// </remarks>
        public string Continue(byte[] authData)
        {
            // GetString on an empty array yields String.Empty, matching the "no token" result.
            return System.Text.Encoding.ASCII.GetString(ContinueBytes(authData));
        }

        /// <summary>
        /// Performs one round of the SSPI handshake and returns the raw output token bytes.
        /// </summary>
        /// <param name="authData">
        /// Token received from the server; may be null only on the first call (it is ignored then).
        /// </param>
        /// <returns>The output token, or an empty array when SSPI produced no token.</returns>
        /// <exception cref="ObjectDisposedException">The handler has been disposed.</exception>
        /// <exception cref="InvalidOperationException">
        /// <paramref name="authData"/> is null after the first call, or SSPI returned an unexpected
        /// number of output buffers.
        /// </exception>
        /// <exception cref="Win32Exception">
        /// InitializeSecurityContext returned neither SEC_E_OK nor SEC_I_CONTINUE_NEEDED.
        /// </exception>
        public byte[] ContinueBytes(byte[] authData)
        {
            if (disposed)
                throw new ObjectDisposedException(GetType().Name);
            if (authData == null && sspictx_set)
                throw new InvalidOperationException("The authData parameter can only be null at the first call to continue!");

            SecBuffer outBuffer;
            outBuffer.cbBuffer = 0;
            outBuffer.BufferType = SECBUFFER_TOKEN;
            outBuffer.pvBuffer = IntPtr.Zero;

            // Keep our own copy of the allocation so it is freed even if the descriptor changes.
            IntPtr outBufferMemory = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(SecBuffer)));
            SecBufferDesc outbuf;
            outbuf.ulVersion = SECBUFFER_VERSION;
            outbuf.cBuffers = 1;
            outbuf.pBuffer = outBufferMemory;

            try
            {
                Marshal.StructureToPtr(outBuffer, outBufferMemory, false);

                SecHandle newContext;
                int status;
                if (sspictx_set)
                    status = ContinueContext(authData, ref outbuf, out newContext);
                else
                    status = StartContext(ref outbuf, out newContext);

                // SSPI updated the unmanaged copy of the SecBuffer, not the managed struct above.
                // Read it back before any check can throw, so that the token SSPI allocated
                // (ISC_REQ_ALLOCATE_MEMORY) is always released in the finally block.
                outBuffer = (SecBuffer)Marshal.PtrToStructure(outBufferMemory, typeof(SecBuffer));

                if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED)
                    throw new Win32Exception(status);

                if (!sspictx_set)
                {
                    sspictx = newContext;
                    sspictx_set = true;
                }

                if (outbuf.cBuffers > 0)
                {
                    if (outbuf.cBuffers != 1)
                        throw new InvalidOperationException("SSPI returned invalid number of output buffers");

                    if (outBuffer.cbBuffer > 0)
                    {
                        byte[] token = new byte[outBuffer.cbBuffer];
                        Marshal.Copy(outBuffer.pvBuffer, token, 0, token.Length);
                        return token;
                    }
                }
                return new byte[0];
            }
            finally
            {
                if (outBuffer.pvBuffer != IntPtr.Zero)
                    FreeContextBuffer(outBuffer.pvBuffer);
                Marshal.FreeHGlobal(outBufferMemory);
            }
        }

        /// <summary>First handshake round: no existing context and no input token.</summary>
        private int StartContext(ref SecBufferDesc outbuf, out SecHandle newContext)
        {
            int contextAttr;
            long expire;
            return InitializeSecurityContext(
                ref sspicred,
                IntPtr.Zero,
                sspitarget,
                ISC_REQ_ALLOCATE_MEMORY,
                0,
                SECURITY_NETWORK_DREP,
                IntPtr.Zero,
                0,
                out newContext,
                ref outbuf,
                out contextAttr,
                out expire);
        }

        /// <summary>Later handshake rounds: feeds the server's token into the existing context.</summary>
        private int ContinueContext(byte[] authData, ref SecBufferDesc outbuf, out SecHandle newContext)
        {
            SecBuffer inBuffer;
            inBuffer.cbBuffer = authData.Length;
            inBuffer.BufferType = SECBUFFER_TOKEN;
            inBuffer.pvBuffer = IntPtr.Zero;

            SecBufferDesc inbuf;
            inbuf.ulVersion = SECBUFFER_VERSION;
            inbuf.cBuffers = 1;
            inbuf.pBuffer = IntPtr.Zero;

            try
            {
                inBuffer.pvBuffer = Marshal.AllocHGlobal(authData.Length);
                Marshal.Copy(authData, 0, inBuffer.pvBuffer, authData.Length);
                inbuf.pBuffer = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(SecBuffer)));
                Marshal.StructureToPtr(inBuffer, inbuf.pBuffer, false);

                int contextAttr;
                long expire;
                return InitializeSecurityContext(
                    ref sspicred,
                    ref sspictx,
                    sspitarget,
                    ISC_REQ_ALLOCATE_MEMORY,
                    0,
                    SECURITY_NETWORK_DREP,
                    ref inbuf,
                    0,
                    out newContext,
                    ref outbuf,
                    out contextAttr,
                    out expire);
            }
            finally
            {
                if (inBuffer.pvBuffer != IntPtr.Zero)
                    Marshal.FreeHGlobal(inBuffer.pvBuffer);
                if (inbuf.pBuffer != IntPtr.Zero)
                    Marshal.FreeHGlobal(inbuf.pBuffer);
            }
        }

        #region resource cleanup

        /// <summary>
        /// Releases whichever native SSPI handles are currently held. Each flag is cleared
        /// after release, so calling this repeatedly is harmless.
        /// </summary>
        private void FreeHandles()
        {
            if (sspictx_set)
            {
                DeleteSecurityContext(ref sspictx);
                sspictx_set = false;
            }
            // The credential handle exists from construction on, independently of the context.
            if (sspicred_set)
            {
                FreeCredentialsHandle(ref sspicred);
                sspicred_set = false;
            }
        }

        ~SSPIHandler()
        {
            Dispose(false);
        }

        /// <summary>Releases the SSPI credential and context handles.</summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the native handles. They are unmanaged resources, so they are freed on both the
        /// Dispose() path and the finalizer path; this class owns no managed disposables.
        /// </summary>
        protected virtual void Dispose(bool disposing)
        {
            if (disposed)
                return;
            FreeHandles();
            disposed = true;
        }

        #endregion
    }
}
315
316 #endif