[runtime] Don't insta-fail when a faulty COM type is encountered. (#5616)
[mono.git] / mono / metadata / cominterop.c
index 86309ef20234d4dd559726f0c9fcfe1a32f08e32..c4bc5b037810292912e73923e02bd50288feb305 100644 (file)
@@ -378,30 +378,39 @@ cominterop_get_method_interface (MonoMethod* method)
                }
        }
 
-       if (!ic) 
-               g_assert (ic);
-       g_assert (MONO_CLASS_IS_INTERFACE (ic));
-
        return ic;
 }
 
+static void
+mono_cominterop_get_interface_missing_error (MonoError* error, MonoMethod* method)
+{
+       mono_error_set_invalid_operation (error, "Method '%s' in ComImport class '%s' must implement an interface method.", method->name, method->klass->name);
+}
+
 /**
  * cominterop_get_com_slot_for_method:
  * @method: a method
+ * @error: set on error
  *
  * Returns: the method's slot in the COM interface vtable
  */
 static int
-cominterop_get_com_slot_for_method (MonoMethod* method)
+cominterop_get_com_slot_for_method (MonoMethod* method, MonoError* error)
 {
        guint32 slot = method->slot;
        MonoClass *ic = method->klass;
 
+       error_init (error);
+
        /* if method is on a class, we need to look up interface method exists on */
        if (!MONO_CLASS_IS_INTERFACE(ic)) {
                int offset = 0;
                int i = 0;
                ic = cominterop_get_method_interface (method);
+               if (!ic || !MONO_CLASS_IS_INTERFACE (ic)) {
+                       mono_cominterop_get_interface_missing_error (error, method);
+                       return -1;
+               }
                offset = mono_class_interface_offset (method->klass, ic);
                g_assert(offset >= 0);
                int mcount = mono_class_get_method_count (ic);
@@ -640,14 +649,30 @@ mono_cominterop_cleanup (void)
 }
 
 void
-mono_mb_emit_cominterop_call (MonoMethodBuilder *mb, MonoMethodSignature *sig, MonoMethod* method)
+mono_mb_emit_cominterop_get_function_pointer (MonoMethodBuilder *mb, MonoMethod *method)
 {
 #ifndef DISABLE_JIT
+       int slot;
+       MonoError error;
        // get function pointer from 1st arg, the COM interface pointer
        mono_mb_emit_ldarg (mb, 0);
-       mono_mb_emit_icon (mb, cominterop_get_com_slot_for_method (method));
-       mono_mb_emit_icall (mb, cominterop_get_function_pointer);
+       slot = cominterop_get_com_slot_for_method (method, &error);
+       if (is_ok (&error)) {
+               mono_mb_emit_icon (mb, slot);
+               mono_mb_emit_icall (mb, cominterop_get_function_pointer);
+               /* Leaves the function pointer on top of the stack */
+       }
+       else {
+               mono_mb_emit_exception_for_error (mb, &error);
+       }
+       mono_error_cleanup (&error);
+#endif
+}
 
+void
+mono_mb_emit_cominterop_call_function_pointer (MonoMethodBuilder *mb, MonoMethodSignature *sig)
+{
+#ifndef DISABLE_JIT
        mono_mb_emit_byte (mb, MONO_CUSTOM_PREFIX);
        mono_mb_emit_byte (mb, CEE_MONO_SAVE_LMF);
        mono_mb_emit_calli (mb, sig);
@@ -656,6 +681,16 @@ mono_mb_emit_cominterop_call (MonoMethodBuilder *mb, MonoMethodSignature *sig, M
 #endif /* DISABLE_JIT */
 }
 
+void
+mono_mb_emit_cominterop_call (MonoMethodBuilder *mb, MonoMethodSignature *sig, MonoMethod* method)
+{
+#ifndef DISABLE_JIT
+       mono_mb_emit_cominterop_get_function_pointer (mb, method);
+
+       mono_mb_emit_cominterop_call_function_pointer (mb, sig);
+#endif /* DISABLE_JIT */
+}
+
 void
 mono_cominterop_emit_ptr_to_object_conv (MonoMethodBuilder *mb, MonoType *type, MonoMarshalConv conv, MonoMarshalSpec *mspec)
 {
@@ -978,6 +1013,18 @@ mono_cominterop_get_native_wrapper (MonoMethod *method)
                        mono_mb_emit_managed_call (mb, ctor, NULL);
                        mono_mb_emit_byte (mb, CEE_RET);
                }
+               else if (method->flags & METHOD_ATTRIBUTE_STATIC) {
+                       /*
+                        * The method's class must implement an interface.
+                        * However, no interfaces are allowed to have static methods.
+                        * Thus, calling it should invariably lead to an exception.
+                        */
+                       MonoError error;
+                       error_init (&error);
+                       mono_cominterop_get_interface_missing_error (&error, method);
+                       mono_mb_emit_exception_for_error (mb, &error);
+                       mono_error_cleanup (&error);
+               }
                else {
                        static MonoMethod * ThrowExceptionForHR = NULL;
                        MonoMethod *adjusted_method;
@@ -1090,7 +1137,7 @@ mono_cominterop_get_invoke (MonoMethod *method)
        for (i = 1; i <= sig->param_count; i++)
                mono_mb_emit_ldarg (mb, i);
 
-       if (method->iflags & METHOD_IMPL_ATTRIBUTE_INTERNAL_CALL) {
+       if ((method->iflags & METHOD_IMPL_ATTRIBUTE_INTERNAL_CALL) || mono_class_is_interface (method->klass)) {
                MonoMethod * native_wrapper = mono_cominterop_get_native_wrapper(method);
                mono_mb_emit_managed_call (mb, native_wrapper, NULL);
        }
@@ -1689,7 +1736,10 @@ guint32
 ves_icall_System_Runtime_InteropServices_Marshal_GetComSlotForMethodInfoInternal (MonoReflectionMethod *m)
 {
 #ifndef DISABLE_COM
-       return cominterop_get_com_slot_for_method (m->method);
+       MonoError error;
+       int slot = cominterop_get_com_slot_for_method (m->method, &error);
+       mono_error_assert_ok (&error);
+       return slot;
 #else
        g_assert_not_reached ();
 #endif
@@ -3664,3 +3714,52 @@ ves_icall_System_Runtime_InteropServices_Marshal_FreeBSTR (gpointer ptr)
 {
        mono_free_bstr (ptr);
 }
+
+void*
+mono_cominterop_get_com_interface (MonoObject *object, MonoClass *ic, MonoError *error)
+{
+       error_init (error);
+
+#ifndef DISABLE_COM
+       if (!object)
+               return NULL;
+
+       if (cominterop_object_is_rcw (object)) {
+               MonoClass *klass = NULL;
+               MonoRealProxy* real_proxy = NULL;
+               if (!object)
+                       return NULL;
+               klass = mono_object_class (object);
+               if (!mono_class_is_transparent_proxy (klass)) {
+                       mono_error_set_invalid_operation (error, "Class is not transparent");
+                       return NULL;
+               }
+
+               real_proxy = ((MonoTransparentProxy*)object)->rp;
+               if (!real_proxy) {
+                       mono_error_set_invalid_operation (error, "RealProxy is null");
+                       return NULL;
+               }
+
+               klass = mono_object_class (real_proxy);
+               if (klass != mono_class_get_interop_proxy_class ()) {
+                       mono_error_set_invalid_operation (error, "Object is not a proxy");
+                       return NULL;
+               }
+
+               if (!((MonoComInteropProxy*)real_proxy)->com_object) {
+                       mono_error_set_invalid_operation (error, "Proxy points to null COM object");
+                       return NULL;
+               }
+
+               void* com_itf = cominterop_get_interface_checked (((MonoComInteropProxy*)real_proxy)->com_object, ic, error);
+               return com_itf;
+       }
+       else {
+               void* ccw_entry = cominterop_get_ccw_checked (object, ic, error);
+               return ccw_entry;
+       }
+#else
+       g_assert_not_reached ();
+#endif
+}