330b89fba3ae889af45cbbed2bb9b9dbdcca08e5
[mono.git] / mcs / class / referencesource / System.Data / System / Data / ProviderBase / DbConnectionPoolIdentity.cs
1 //------------------------------------------------------------------------------
2 // <copyright file="DbConnectionPoolIdentity.cs" company="Microsoft">
3 //      Copyright (c) Microsoft Corporation.  All rights reserved.
4 // </copyright>
5 // <owner current="true" primary="true">[....]</owner>
6 //------------------------------------------------------------------------------
7
8 namespace System.Data.ProviderBase {
9
10     using System;
11     using System.Collections;
12     using System.Data.Common;
13     using System.Diagnostics;
14     using System.Globalization;
15     using System.Runtime.CompilerServices;
16     using System.Runtime.InteropServices;
17     using System.Security;
18     using System.Security.Permissions;
19     using System.Security.Principal;
20     using System.Threading;
21     using System.Runtime.Versioning;
22
/// <summary>
/// Captures a Windows security identity for connection pooling: the current user's SID
/// (as a string), whether the token is restricted, and whether the token is a member of the
/// well-known Network SID. Instances compare by value on those three fields.
/// NOTE(review): presumably used to keep pooled connections separated per Windows identity —
/// confirm against the callers in the pooling code.
/// </summary>
[Serializable] // Serializable so SqlDependencyProcessDispatcher can marshall cross domain to SqlDependency.
sealed internal class DbConnectionPoolIdentity {
    // HRESULT_FROM_WIN32(ERROR_NO_IMPERSONATION_TOKEN): CheckTokenMembership fails with this when
    // the token is not an impersonation token; IntegratedSecurityError tolerates it for that call.
    private const int E_NotImpersonationToken      = unchecked((int)0x8007051D);

    // Caller codes passed to IntegratedSecurityError so a debugger shows which Win32 call failed.
    private const int Win32_CheckTokenMembership   = 1;
    private const int Win32_GetTokenInformation_1  = 2;
    private const int Win32_GetTokenInformation_2  = 3;
    private const int Win32_ConvertSidToStringSidW = 4;
    private const int Win32_CreateWellKnownSid     = 5;

    // TOKEN_INFORMATION_CLASS.TokenUser: GetTokenInformation fills a TOKEN_USER structure whose
    // first field is the pointer to the user's SID.
    private const int TokenInformationClass_TokenUser = 1;

    // Returned by GetCurrent on non-NT platforms; Equals treats it as matching any value.
    static public  readonly DbConnectionPoolIdentity NoIdentity = new DbConnectionPoolIdentity(String.Empty, false, true);

    // Binary form of the well-known Network SID used with CheckTokenMembership (null on non-NT platforms).
    static private readonly byte[]                   NetworkSid = (ADP.IsWindowsNT ? CreateWellKnownSid(WellKnownSidType.NetworkSid) : null);

    // Most recently returned identity; GetCurrent reuses it when SID and flags match. Accessed without
    // locking: reference reads/writes are atomic, and a stale value only costs one extra allocation.
    static private DbConnectionPoolIdentity _lastIdentity = null;

    private readonly string _sidString;
    private readonly bool   _isRestricted;
    private readonly bool   _isNetwork;
    private readonly int    _hashCode;   // Cached hash of _sidString (0 when null).

    private DbConnectionPoolIdentity (string sidString, bool isRestricted, bool isNetwork) {
        _sidString = sidString;
        _isRestricted = isRestricted;
        _isNetwork = isNetwork;
        _hashCode = sidString == null ? 0 : sidString.GetHashCode();
    }

    /// <summary>True when the token was reported restricted by Win32NativeMethods.IsTokenRestrictedWrapper.</summary>
    internal bool IsRestricted {
        get { return _isRestricted; }
    }

    /// <summary>True when the token is a member of the well-known Network SID (true for <see cref="NoIdentity"/>).</summary>
    internal bool IsNetwork {
        get { return _isNetwork; }
    }

    /// <summary>
    /// Returns the binary form of the given well-known SID, in a buffer of
    /// SecurityIdentifier.MaxBinaryLength bytes. Raises through IntegratedSecurityError
    /// if the native CreateWellKnownSid call fails.
    /// </summary>
    static private byte[] CreateWellKnownSid(WellKnownSidType sidType) {
        // Passing an array as big as it can ever be is a small price to pay for
        // not having to P/Invoke twice (once to get the buffer, once to get the data)

        uint length = ( uint )SecurityIdentifier.MaxBinaryLength;
        byte[] resultSid = new byte[ length ];

        // NOTE - We copied this code from System.Security.Principal.Win32.CreateWellKnownSid...

        if ( 0 == UnsafeNativeMethods.CreateWellKnownSid(( int )sidType, null, resultSid, ref length )) {
            IntegratedSecurityError(Win32_CreateWellKnownSid);
        }
        return resultSid;
    }

    /// <summary>
    /// Value equality on SID string, restricted flag and network flag.
    /// <see cref="NoIdentity"/> (returned on non-NT platforms, where the user cannot be
    /// validated) compares equal to any value. Any value that is not a
    /// DbConnectionPoolIdentity, including null, compares unequal instead of throwing.
    /// </summary>
    override public bool Equals(object value) {
        if ((this == NoIdentity) || (this == value)) {
            return true;
        }

        // Equals(object) must not throw: a direct cast here would raise InvalidCastException
        // for values of any other type.
        DbConnectionPoolIdentity that = value as DbConnectionPoolIdentity;
        return (null != that)
            && (this._sidString == that._sidString)
            && (this._isRestricted == that._isRestricted)
            && (this._isNetwork == that._isNetwork);
    }

    /// <summary>Returns WindowsIdentity.GetCurrent() under an asserted ControlPrincipal permission.</summary>
    [SecurityPermission(SecurityAction.Assert, Flags=SecurityPermissionFlag.ControlPrincipal)]
    static internal WindowsIdentity GetCurrentWindowsIdentity() {
        return WindowsIdentity.GetCurrent();
    }

    /// <summary>
    /// Returns the identity's token handle under an asserted UnmanagedCode permission.
    /// The handle is owned by <paramref name="identity"/>; it is only valid while the identity is alive and undisposed.
    /// </summary>
    [SecurityPermission(SecurityAction.Assert, Flags=SecurityPermissionFlag.UnmanagedCode)]
    static private IntPtr GetWindowsIdentityToken(WindowsIdentity identity) {
        return identity.Token;
    }

    /// <summary>
    /// Builds the pool identity for the current Windows identity.
    /// Returns <see cref="NoIdentity"/> on non-NT platforms. Otherwise reads the token's user SID,
    /// restricted flag and Network-SID membership. When they match the most recently returned
    /// instance, that instance is reused.
    /// </summary>
    /// <exception cref="OutOfMemoryException">LocalAlloc failed for the token information buffer.</exception>
    /// <remarks>
    /// Win32 failures are raised through IntegratedSecurityError (Marshal.ThrowExceptionForHR).
    /// A null string buffer from ConvertSidToStringSidW raises ADP.InternalError.
    /// </remarks>
    [ResourceExposure(ResourceScope.None)] // SxS: this method does not create named objects
    [ResourceConsumption(ResourceScope.Process, ResourceScope.Process)]
    static internal DbConnectionPoolIdentity GetCurrent() {

        // DEVNOTE: GetTokenInfo and EqualSID do not work on 9x.  WindowsIdentity does not
        //          work either on 9x.  In fact, after checking with native there is no way
        //          to validate the user on 9x, so simply don't.  It is a known issue in
        //          native, and we will handle this the same way.

        if (!ADP.IsWindowsNT) {
            return NoIdentity;
        }

        // The using block disposes the WindowsIdentity on every path, including when one of the
        // native calls below throws. It also keeps the identity, and therefore its token, alive
        // until every native call that uses the token has completed.
        using (WindowsIdentity identity = GetCurrentWindowsIdentity()) {
            IntPtr token        = GetWindowsIdentityToken(identity); // Free'd by WindowsIdentity.
            uint   bufferLength = 2048;           // Suggested default given by Greg Fee.
            uint   lengthNeeded = 0;

            IntPtr tokenStruct     = IntPtr.Zero;
            IntPtr sidStringBuffer = IntPtr.Zero;
            bool   isNetwork;

            // Win32NativeMethods.IsTokenRestrictedWrapper will raise exception if the native call fails
            bool   isRestricted = Win32NativeMethods.IsTokenRestrictedWrapper(token);

            DbConnectionPoolIdentity current = null;

            RuntimeHelpers.PrepareConstrainedRegions();
            try {
                if (!UnsafeNativeMethods.CheckTokenMembership(token, NetworkSid, out isNetwork)) {
                    // will always fail with 0x8007051D if token is not an impersonation token
                    IntegratedSecurityError(Win32_CheckTokenMembership);
                }

                RuntimeHelpers.PrepareConstrainedRegions();
                try { } finally {
                    // Allocating the buffer and assigning it to tokenStruct must happen together.
                    // Doing both inside a finally block keeps a thread abort from separating them.
                    tokenStruct = SafeNativeMethods.LocalAlloc(DbBuffer.LMEM_FIXED, (IntPtr)bufferLength);
                }
                if (IntPtr.Zero == tokenStruct) {
                    throw new OutOfMemoryException();
                }
                if (!UnsafeNativeMethods.GetTokenInformation(token, TokenInformationClass_TokenUser, tokenStruct, bufferLength, ref lengthNeeded)) {
                    if (lengthNeeded > bufferLength) {
                        // Buffer too small: retry once with the size the first call reported.
                        bufferLength = lengthNeeded;

                        RuntimeHelpers.PrepareConstrainedRegions();
                        try { } finally {
                            // Freeing the old buffer, clearing tokenStruct, and allocating and
                            // assigning the new buffer must all happen together.
                            SafeNativeMethods.LocalFree(tokenStruct);
                            tokenStruct = IntPtr.Zero; // protect against LocalAlloc throwing an exception
                            tokenStruct = SafeNativeMethods.LocalAlloc(DbBuffer.LMEM_FIXED, (IntPtr)bufferLength);
                        }
                        if (IntPtr.Zero == tokenStruct) {
                            throw new OutOfMemoryException();
                        }

                        if (!UnsafeNativeMethods.GetTokenInformation(token, TokenInformationClass_TokenUser, tokenStruct, bufferLength, ref lengthNeeded)) {
                            IntegratedSecurityError(Win32_GetTokenInformation_1);
                        }
                    }
                    else {
                        IntegratedSecurityError(Win32_GetTokenInformation_2);
                    }
                }

                // TOKEN_USER.User.Sid is the first field of the returned structure.
                IntPtr sid = Marshal.ReadIntPtr(tokenStruct, 0);

                if (!UnsafeNativeMethods.ConvertSidToStringSidW(sid, out sidStringBuffer)) {
                    IntegratedSecurityError(Win32_ConvertSidToStringSidW);
                }

                if (IntPtr.Zero == sidStringBuffer) {
                    throw ADP.InternalError(ADP.InternalErrorCode.ConvertSidToStringSidWReturnedNull);
                }

                string sidString = Marshal.PtrToStringUni(sidStringBuffer);

                // Reuse the cached instance when nothing changed, to avoid an allocation per call.
                DbConnectionPoolIdentity lastIdentity = _lastIdentity;
                if ((lastIdentity != null) && (lastIdentity._sidString == sidString) && (lastIdentity._isRestricted == isRestricted) && (lastIdentity._isNetwork == isNetwork)) {
                    current = lastIdentity;
                }
                else {
                    current = new DbConnectionPoolIdentity(sidString, isRestricted, isNetwork);
                }
            }
            finally {
                // Native buffers are released with LocalFree (via SafeNativeMethods) because
                // Marshal.FreeHGlobal does not have a ReliabilityContract.
                if (IntPtr.Zero != tokenStruct) {
                    SafeNativeMethods.LocalFree(tokenStruct);
                    tokenStruct = IntPtr.Zero;
                }
                if (IntPtr.Zero != sidStringBuffer) {
                    SafeNativeMethods.LocalFree(sidStringBuffer);
                    sidStringBuffer = IntPtr.Zero;
                }
            }
            _lastIdentity = current;
            return current;
        }
    }

    /// <summary>
    /// Returns the cached hash of the SID string (0 when the SID string is null).
    /// NOTE(review): NoIdentity.Equals returns true for any value even though the hashes differ.
    /// Hash-based lookups that mix NoIdentity with real identities will not treat them as equal.
    /// </summary>
    override public int GetHashCode() {
        return _hashCode;
    }

    /// <summary>
    /// Throws the exception for the calling thread's last Win32 error (as an HRESULT).
    /// The exception is the case where CheckTokenMembership failed with E_NotImpersonationToken,
    /// which is tolerated.
    /// NOTE(review): assumes the failing P/Invoke declarations use SetLastError = true — not visible here.
    /// </summary>
    /// <param name="caller">One of the Win32_* caller codes identifying the failed native call.</param>
    static private void IntegratedSecurityError(int caller) {
        // passing 1,2,3,4,5 instead of true/false so that with a debugger
        // we could determine more easily which Win32 method call failed
        int lastError = Marshal.GetHRForLastWin32Error();
        if ((Win32_CheckTokenMembership != caller) || (E_NotImpersonationToken != lastError)) {
            Marshal.ThrowExceptionForHR(lastError); // will only throw if (hresult < 0)
        }
    }

}
210 }
211