[socket] Fix race on SocketAsyncResult (#4389)
[mono.git] / mcs / class / System / System.Net.Sockets / SocketAsyncResult.cs
index 467cafa139a057ac5efad1bb01863e894a837aa6..2533129e7e5682876c582442c9f229ca7386838a 100644 (file)
@@ -35,19 +35,12 @@ using System.Threading;
 namespace System.Net.Sockets
 {
        [StructLayout (LayoutKind.Sequential)]
-       internal sealed class SocketAsyncResult: IAsyncResult, IThreadPoolWorkItem
+       internal sealed class SocketAsyncResult: IOAsyncResult
        {
-               /* Same structure in the runtime. Keep this in sync with
-                * MonoSocketAsyncResult in metadata/socket-io.h and
-                * ProcessAsyncReader in System.Diagnostics/Process.cs. */
-
                public Socket socket;
-               IntPtr handle;
-               object state;
-               AsyncCallback callback; // used from the runtime
-               WaitHandle wait_handle;
+               public SocketOperation operation;
 
-               Exception delayed_exception;
+               Exception DelayedException;
 
                public EndPoint EndPoint;                 // Connect,ReceiveFrom,SendTo
                public byte [] Buffer;                    // Receive,ReceiveFrom,Send,SendTo
@@ -59,108 +52,33 @@ namespace System.Net.Sockets
                public int Port;                          // Connect
                public IList<ArraySegment<byte>> Buffers; // Receive, Send
                public bool ReuseSocket;                  // Disconnect
+               public int CurrentAddress;                // Connect
 
-               // Return values
-               Socket accept_socket;
-               int total;
+               public Socket AcceptedSocket;
+               public int Total;
 
-               bool completed_synchronously;
-               bool completed;
-               bool is_blocking;
                internal int error;
-               public SocketOperation operation;
-               AsyncResult async_result;
+
                public int EndCalled;
 
-               /* These fields are not in MonoSocketAsyncResult */
-               public SocketAsyncWorker Worker;
-               public int CurrentAddress;                // Connect
+               public IntPtr Handle {
+                       get { return socket != null ? socket.Handle : IntPtr.Zero; }
+               }
 
+               /* Used by SocketAsyncEventArgs */
                public SocketAsyncResult ()
+                       : base ()
                {
                }
 
-               public SocketAsyncResult (Socket socket, object state, AsyncCallback callback, SocketOperation operation)
+               public void Init (Socket socket, AsyncCallback callback, object state, SocketOperation operation)
                {
-                       Init (socket, state, callback, operation, new SocketAsyncWorker (this));
-               }
-
-               public object AsyncState {
-                       get {
-                               return state;
-                       }
-               }
-
-               public WaitHandle AsyncWaitHandle {
-                       get {
-                               lock (this) {
-                                       if (wait_handle == null)
-                                               wait_handle = new ManualResetEvent (completed);
-                               }
-
-                               return wait_handle;
-                       }
-                       set {
-                               wait_handle = value;
-                       }
-               }
-
-               public bool CompletedSynchronously {
-                       get {
-                               return completed_synchronously;
-                       }
-               }
-
-               public bool IsCompleted {
-                       get {
-                               return completed;
-                       }
-                       set {
-                               completed = value;
-                               lock (this) {
-                                       if (wait_handle != null && value)
-                                               ((ManualResetEvent) wait_handle).Set ();
-                               }
-                       }
-               }
-
-               public Socket Socket {
-                       get {
-                               return accept_socket;
-                       }
-               }
-
-               public int Total {
-                       get { return total; }
-                       set { total = value; }
-               }
-
-               public SocketError ErrorCode {
-                       get {
-                               SocketException ex = delayed_exception as SocketException;
-                               if (ex != null)
-                                       return ex.SocketErrorCode;
-
-                               if (error != 0)
-                                       return (SocketError) error;
-
-                               return SocketError.Success;
-                       }
-               }
+                       base.Init (callback, state);
 
-               public void Init (Socket socket, object state, AsyncCallback callback, SocketOperation operation, SocketAsyncWorker worker)
-               {
                        this.socket = socket;
-                       this.is_blocking = socket != null ? socket.is_blocking : true;
-                       this.handle = socket != null ? socket.Handle : IntPtr.Zero;
-                       this.state = state;
-                       this.callback = callback;
                        this.operation = operation;
 
-                       if (wait_handle != null)
-                               ((ManualResetEvent) wait_handle).Reset ();
-
-                       delayed_exception = null;
+                       DelayedException = null;
 
                        EndPoint = null;
                        Buffer = null;
@@ -172,39 +90,41 @@ namespace System.Net.Sockets
                        Port = 0;
                        Buffers = null;
                        ReuseSocket = false;
-                       accept_socket = null;
-                       total = 0;
+                       CurrentAddress = 0;
+
+                       AcceptedSocket = null;
+                       Total = 0;
 
-                       completed_synchronously = false;
-                       completed = false;
-                       is_blocking = false;
                        error = 0;
-                       async_result = null;
+
                        EndCalled = 0;
-                       Worker = worker;
                }
 
-               public void DoMConnectCallback ()
+               public SocketAsyncResult (Socket socket, AsyncCallback callback, object state, SocketOperation operation)
+                       : base (callback, state)
                {
-                       if (callback == null)
-                               return;
-                       ThreadPool.UnsafeQueueUserWorkItem (_ => callback (this), null);
+                       this.socket = socket;
+                       this.operation = operation;
                }
 
-               public void Dispose ()
-               {
-                       Init (null, null, null, 0, Worker);
-                       if (wait_handle != null) {
-                               wait_handle.Close ();
-                               wait_handle = null;
+               public SocketError ErrorCode {
+                       get {
+                               SocketException ex = DelayedException as SocketException;
+                               if (ex != null)
+                                       return ex.SocketErrorCode;
+
+                               if (error != 0)
+                                       return (SocketError) error;
+
+                               return SocketError.Success;
                        }
                }
 
                public void CheckIfThrowDelayedException ()
                {
-                       if (delayed_exception != null) {
+                       if (DelayedException != null) {
                                socket.is_connected = false;
-                               throw delayed_exception;
+                               throw DelayedException;
                        }
 
                        if (error != 0) {
@@ -213,106 +133,87 @@ namespace System.Net.Sockets
                        }
                }
 
-               void CompleteDisposed (object unused)
+               internal override void CompleteDisposed ()
                {
                        Complete ();
                }
 
                public void Complete ()
                {
-                       if (operation != SocketOperation.Receive && socket.is_disposed)
-                               delayed_exception = new ObjectDisposedException (socket.GetType ().ToString ());
+                       if (operation != SocketOperation.Receive && socket.CleanedUp)
+                               DelayedException = new ObjectDisposedException (socket.GetType ().ToString ());
 
                        IsCompleted = true;
 
-                       Queue<SocketAsyncWorker> queue = null;
-                       switch (operation) {
+                       /* It is possible that this.socket is modified by this.Init which has been called by the callback. This
+                        * would lead to inconsistency, as we would for example not release the correct socket.ReadSem or
+                        * socket.WriteSem.
+                        * For example, this can happen with AcceptAsync followed by a ReceiveAsync on the same
+                        * SocketAsyncEventArgs */
+                       Socket completedSocket = socket;
+                       SocketOperation completedOperation = operation;
+
+                       AsyncCallback callback = AsyncCallback;
+                       if (callback != null) {
+                               ThreadPool.UnsafeQueueUserWorkItem (_ => callback (this), null);
+                       }
+
+                       /* Warning: any field on the current SocketAsyncResult might have changed, as the callback might have
+                        * called this.Init */
+
+                       switch (completedOperation) {
                        case SocketOperation.Receive:
                        case SocketOperation.ReceiveFrom:
                        case SocketOperation.ReceiveGeneric:
                        case SocketOperation.Accept:
-                               queue = socket.readQ;
+                               completedSocket.ReadSem.Release ();
                                break;
                        case SocketOperation.Send:
                        case SocketOperation.SendTo:
                        case SocketOperation.SendGeneric:
-                               queue = socket.writeQ;
+                               completedSocket.WriteSem.Release ();
                                break;
                        }
 
-                       if (queue != null) {
-                               lock (queue) {
-                                       /* queue.Count will only be 0 if the socket is closed while receive/send/accept
-                                        * operation(s) are pending and at least one call to this method is waiting
-                                        * on the lock while another one calls CompleteAllOnDispose() */
-                                       if (queue.Count > 0)
-                                               queue.Dequeue (); /* remove ourselves */
-                                       if (queue.Count > 0) {
-                                               if (!socket.is_disposed) {
-                                                       Socket.socket_pool_queue (SocketAsyncWorker.Dispatcher, (queue.Peek ()).result);
-                                               } else {
-                                                       /* CompleteAllOnDispose */
-                                                       SocketAsyncWorker [] workers = queue.ToArray ();
-                                                       for (int i = 0; i < workers.Length; i++)
-                                                               ThreadPool.UnsafeQueueUserWorkItem (workers [i].result.CompleteDisposed, null);
-                                                       queue.Clear ();
-                                               }
-                                       }
-                               }
-                       }
-
                        // IMPORTANT: 'callback', if any is scheduled from unmanaged code
                }
 
                public void Complete (bool synch)
                {
-                       this.completed_synchronously = synch;
+                       CompletedSynchronously = synch;
                        Complete ();
                }
 
                public void Complete (int total)
                {
-                       this.total = total;
+                       Total = total;
                        Complete ();
                }
 
                public void Complete (Exception e, bool synch)
                {
-                       this.completed_synchronously = synch;
-                       this.delayed_exception = e;
+                       DelayedException = e;
+                       CompletedSynchronously = synch;
                        Complete ();
                }
 
                public void Complete (Exception e)
                {
-                       this.delayed_exception = e;
+                       DelayedException = e;
                        Complete ();
                }
 
                public void Complete (Socket s)
                {
-                       this.accept_socket = s;
+                       AcceptedSocket = s;
                        Complete ();
                }
 
                public void Complete (Socket s, int total)
                {
-                       this.accept_socket = s;
-                       this.total = total;
+                       AcceptedSocket = s;
+                       Total = total;
                        Complete ();
                }
-
-               void IThreadPoolWorkItem.ExecuteWorkItem()
-               {
-                       ((IThreadPoolWorkItem) async_result).ExecuteWorkItem ();
-
-                       if (completed && callback != null) {
-                               ThreadPool.UnsafeQueueCustomWorkItem (new AsyncResult (state => callback ((IAsyncResult) state), this, false), false);
-                       }
-               }
-
-               void IThreadPoolWorkItem.MarkAborted(ThreadAbortException tae)
-               {
-               }
        }
 }