New test.
[mono.git] / mcs / class / corlib / System.Threading / WaitHandle.cs
index 7f135bfc14bbb50279d6658dd661686140e78d71..9081ece31d9c559e21cc1c3fc0461bb3e15a566a 100644 (file)
@@ -41,7 +41,7 @@ namespace System.Threading
                [MethodImplAttribute(MethodImplOptions.InternalCall)]
                private static extern bool WaitAll_internal(WaitHandle[] handles, int ms, bool exitContext);
                
-               static void CheckArray (WaitHandle [] handles)
+               static void CheckArray (WaitHandle [] handles, bool waitAll)
                {
                        if (handles == null)
                                throw new ArgumentNullException ("waitHandles");
@@ -50,10 +50,7 @@ namespace System.Threading
                        if (length > 64)
                                throw new NotSupportedException ("Too many handles");
 
-                       MethodInfo entryPoint = Assembly.GetEntryAssembly ().EntryPoint;
-                       if (length > 1 &&
-                           (Thread.CurrentThread.ApartmentState == ApartmentState.STA ||
-                            entryPoint.GetCustomAttributes (typeof (STAThreadAttribute), false).Length == 1))
+                       if (waitAll && length > 1 && IsSTAThread)
                                throw new NotSupportedException ("WaitAll for multiple handles is not allowed on an STA thread.");
                        
                        foreach (WaitHandle w in handles) {
@@ -64,16 +61,33 @@ namespace System.Threading
                                        throw new ArgumentException ("null element found", "waitHandle");
                        }
                }
+
+               static bool IsSTAThread {
+                       get {
+                               bool isSTA = Thread.CurrentThread.ApartmentState ==
+                                       ApartmentState.STA;
+
+                               // FIXME: remove this check after Thread.ApartmentState
+                               // has been properly implemented.
+                               if (!isSTA) {
+                                       Assembly asm = Assembly.GetEntryAssembly ();
+                                       if (asm != null)
+                                               isSTA = asm.EntryPoint.GetCustomAttributes (typeof (STAThreadAttribute), false).Length > 0;
+                               }
+
+                               return isSTA;
+                       }
+               }
                
                public static bool WaitAll(WaitHandle[] waitHandles)
                {
-                       CheckArray (waitHandles);
+                       CheckArray (waitHandles, true);
                        return(WaitAll_internal(waitHandles, Timeout.Infinite, false));
                }
 
                public static bool WaitAll(WaitHandle[] waitHandles, int millisecondsTimeout, bool exitContext)
                {
-                       CheckArray (waitHandles);
+                       CheckArray (waitHandles, true);
                        try {
                                if (exitContext) SynchronizationAttribute.ExitContext ();
                                return(WaitAll_internal(waitHandles, millisecondsTimeout, false));
@@ -87,7 +101,7 @@ namespace System.Threading
                                           TimeSpan timeout,
                                           bool exitContext)
                {
-                       CheckArray (waitHandles);
+                       CheckArray (waitHandles, true);
                        long ms = (long) timeout.TotalMilliseconds;
                        
                        if (ms < -1 || ms > Int32.MaxValue)
@@ -108,7 +122,7 @@ namespace System.Threading
                // LAMESPEC: Doesn't specify how to signal failures
                public static int WaitAny(WaitHandle[] waitHandles)
                {
-                       CheckArray (waitHandles);
+                       CheckArray (waitHandles, false);
                        return(WaitAny_internal(waitHandles, Timeout.Infinite, false));
                }
 
@@ -116,7 +130,7 @@ namespace System.Threading
                                          int millisecondsTimeout,
                                          bool exitContext)
                {
-                       CheckArray (waitHandles);
+                       CheckArray (waitHandles, false);
                        try {
                                if (exitContext) SynchronizationAttribute.ExitContext ();
                                return(WaitAny_internal(waitHandles, millisecondsTimeout, exitContext));
@@ -129,7 +143,7 @@ namespace System.Threading
                public static int WaitAny(WaitHandle[] waitHandles,
                                          TimeSpan timeout, bool exitContext)
                {
-                       CheckArray (waitHandles);
+                       CheckArray (waitHandles, false);
                        long ms = (long) timeout.TotalMilliseconds;
                        
                        if (ms < -1 || ms > Int32.MaxValue)