Merge pull request #2573 from BrzVlad/fix-conc-memusage
[mono.git] / mcs / class / corlib / System / __ComObject.cs
index 7e439be739c59226131633c688bce0b69b709070..18f0c43ec49761d5097b6a32b97c2910c4031816 100644 (file)
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
 
+#if !FULL_AOT_RUNTIME
 using Mono.Interop;
 using System.Collections;
 using System.Runtime.InteropServices;
 using System.Runtime.CompilerServices;
+using System.Threading;
 
 namespace System
 {
@@ -50,65 +52,84 @@ namespace System
        // many times that obj.GetType().FullName == "System.__ComObject" and
        // Type.GetType("System.__ComObject") may be used.
 
+       [StructLayout (LayoutKind.Sequential)]
        internal class __ComObject : MarshalByRefObject
        {
+#pragma warning disable 169    
                #region Sync with object-internals.h
                IntPtr iunknown;
                IntPtr hash_table;
+               SynchronizationContext synchronization_context;
                #endregion
-
-               [ThreadStatic]
-               static bool coinitialized;
+#pragma warning restore 169
 
                [MethodImplAttribute (MethodImplOptions.InternalCall)]
                internal static extern __ComObject CreateRCW (Type t);
 
-               [MethodImplAttribute (MethodImplOptions.InternalCall)]
-               private static extern void AddInterface (__ComObject co, Type t, IntPtr pItf);
-
-               [MethodImplAttribute (MethodImplOptions.InternalCall)]
-               private static extern IntPtr FindInterface (__ComObject co, Type t);
-
                [MethodImplAttribute (MethodImplOptions.InternalCall)]
                private extern void ReleaseInterfaces ();
 
                ~__ComObject ()
-               {
-                       ReleaseInterfaces ();
+               {       
+                       if (synchronization_context != null)
+                               synchronization_context.Post ((state) => ReleaseInterfaces (), this);
+                       else
+                               ReleaseInterfaces ();                           
                }
 
                public __ComObject ()
                {
-                       // call CoInitialize once per thread
-                       if (!coinitialized) {
-                               CoInitialize (IntPtr.Zero);
-                               coinitialized = true;
-                       }
-
-                       Type t = GetType ();
-                       int hr = CoCreateInstance (GetCLSID (t), IntPtr.Zero, 0x1 | 0x4 | 0x10, IID_IUnknown, out iunknown);
-                       Marshal.ThrowExceptionForHR (hr);
+                       Initialize (GetType ());
                }
 
-               internal __ComObject (Type t)
-               {
-                       // call CoInitialize once per thread
-                       if (!coinitialized) {
-                               CoInitialize (IntPtr.Zero);
-                               coinitialized = true;
-                       }
-
-                       int hr = CoCreateInstance (GetCLSID (t), IntPtr.Zero, 0x1 | 0x4 | 0x10, IID_IUnknown, out iunknown);
-                       Marshal.ThrowExceptionForHR (hr);
+               internal __ComObject (Type t) {
+                       Initialize (t);
                }
 
                internal __ComObject (IntPtr pItf)
                {
+                       InitializeApartmentDetails ();
                        Guid iid = IID_IUnknown;
                        int hr = Marshal.QueryInterface (pItf, ref iid, out iunknown);
                        Marshal.ThrowExceptionForHR (hr);
                }
 
+               internal void Initialize (Type t)
+               {
+                       InitializeApartmentDetails ();
+                       // Guard multiple invocation.
+                       if (iunknown != IntPtr.Zero)
+                               return;
+
+                       System.Runtime.CompilerServices.RuntimeHelpers.RunClassConstructor (t.TypeHandle);
+                       
+                       ObjectCreationDelegate ocd = ExtensibleClassFactory.GetObjectCreationCallback (t);
+                       if (ocd != null) {
+                               iunknown = ocd (IntPtr.Zero);
+                               if (iunknown == IntPtr.Zero)
+                                       throw new COMException (string.Format("ObjectCreationDelegate for type {0} failed to return a valid COM object", t));
+                       }
+                       else {
+                               int hr = CoCreateInstance (GetCLSID (t), IntPtr.Zero, 0x1 | 0x4 | 0x10, IID_IUnknown, out iunknown);
+                               Marshal.ThrowExceptionForHR (hr);
+                       }
+               }
+
+               private void InitializeApartmentDetails ()
+               {
+                       // Only synchronization_context if thread is STA.
+                       if (Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
+                               return;
+                       
+                       synchronization_context = SynchronizationContext.Current;
+
+                       // Check whether the current context is a plain SynchronizationContext object
+                       // and handle this as if no context was set at all.
+                       if (synchronization_context != null &&
+                               synchronization_context.GetType () == typeof(SynchronizationContext))
+                               synchronization_context = null;                 
+               }
+
                private static Guid GetCLSID (Type t)
                {
                        if (t.IsImport)
@@ -124,20 +145,17 @@ namespace System
                        throw new COMException ("Could not find base COM type for type " + t.ToString());
                }
 
-               internal IntPtr GetInterface(Type t)
-               {
+               [MethodImplAttribute (MethodImplOptions.InternalCall)]
+               internal extern IntPtr GetInterfaceInternal (Type t, bool throwException);
+
+               internal IntPtr GetInterface (Type t, bool throwException) {
                        CheckIUnknown ();
-                       IntPtr pItf = FindInterface (this, t);
-                       if (pItf != IntPtr.Zero) {
-                               return pItf;
-                       }
+                       return GetInterfaceInternal (t, throwException);
+               }
 
-                       Guid iid = t.GUID;
-                       IntPtr ppv;
-                       int hr = Marshal.QueryInterface (iunknown, ref iid, out ppv);
-                       Marshal.ThrowExceptionForHR (hr);
-                       AddInterface (this, t, ppv);
-                       return ppv;
+               internal IntPtr GetInterface(Type t)
+               {
+                       return GetInterface (t, true);
                }
 
                private void CheckIUnknown ()
@@ -203,9 +221,6 @@ namespace System
                        return iunknown.ToInt32 ();
                }
 
-               [DllImport ("ole32.dll", CallingConvention = CallingConvention.StdCall)]
-               static extern int CoInitialize (IntPtr pvReserved);
-
                [DllImport ("ole32.dll", CallingConvention = CallingConvention.StdCall, ExactSpelling = true, PreserveSig = true)]
                static extern int CoCreateInstance (
                   [In, MarshalAs (UnmanagedType.LPStruct)] Guid rclsid,
@@ -215,3 +230,4 @@ namespace System
                        out IntPtr pUnk);
        }
 }
+#endif