[WebConnection] Fix race condition between Close and BeginWrite (#4693)
authorLudovic Henry <ludovic@xamarin.com>
Fri, 14 Apr 2017 18:51:42 +0000 (14:51 -0400)
committerGitHub <noreply@github.com>
Fri, 14 Apr 2017 18:51:42 +0000 (14:51 -0400)
* [WebConnection] Make ReadDone and InitRead instance methods to avoid passing cnc around

* [WebConnection] Inline only call to InitConnection

* [WebConnection] Fix race condition between Close and BeginWrite

mcs/class/System/System.Net/WebConnection.cs
mcs/class/System/System.Net/WebConnectionStream.cs

index 0a36e21004943427b937884ef9993d703c7a2c9c..563ce725fd9a5a4b1bacf9e0a30ccf4b6fabd50d 100644 (file)
@@ -59,7 +59,6 @@ namespace System.Net
                WaitCallback initConn;
                bool keepAlive;
                byte [] buffer;
-               static AsyncCallback readDoneDelegate = new AsyncCallback (ReadDone);
                EventHandler abortHandler;
                AbortHelper abortHelper;
                internal WebConnectionData Data;
@@ -100,11 +99,6 @@ namespace System.Net
                        this.sPoint = sPoint;
                        buffer = new byte [4096];
                        Data = new WebConnectionData ();
-                       initConn = new WaitCallback (state => {
-                               try {
-                                       InitConnection (state);
-                               } catch {}
-                               });
                        queue = wcs.Group.Queue;
                        abortHelper = new AbortHelper ();
                        abortHelper.Connection = this;
@@ -460,13 +454,12 @@ namespace System.Net
                        }
                }
                
-               static void ReadDone (IAsyncResult result)
+               void ReadDone (IAsyncResult result)
                {
-                       WebConnection cnc = (WebConnection)result.AsyncState;
-                       WebConnectionData data = cnc.Data;
-                       Stream ns = cnc.nstream;
+                       WebConnectionData data = Data;
+                       Stream ns = nstream;
                        if (ns == null) {
-                               cnc.Close (true);
+                               Close (true);
                                return;
                        }
 
@@ -479,84 +472,84 @@ namespace System.Net
                                if (e.InnerException is ObjectDisposedException)
                                        return;
 
-                               cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone1");
+                               HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone1");
                                return;
                        }
 
                        if (nread == 0) {
-                               cnc.HandleError (WebExceptionStatus.ReceiveFailure, null, "ReadDone2");
+                               HandleError (WebExceptionStatus.ReceiveFailure, null, "ReadDone2");
                                return;
                        }
 
                        if (nread < 0) {
-                               cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, null, "ReadDone3");
+                               HandleError (WebExceptionStatus.ServerProtocolViolation, null, "ReadDone3");
                                return;
                        }
 
                        int pos = -1;
-                       nread += cnc.position;
+                       nread += position;
                        if (data.ReadState == ReadState.None) { 
                                Exception exc = null;
                                try {
-                                       pos = GetResponse (data, cnc.sPoint, cnc.buffer, nread);
+                                       pos = GetResponse (data, sPoint, buffer, nread);
                                } catch (Exception e) {
                                        exc = e;
                                }
 
                                if (exc != null || pos == -1) {
-                                       cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, exc, "ReadDone4");
+                                       HandleError (WebExceptionStatus.ServerProtocolViolation, exc, "ReadDone4");
                                        return;
                                }
                        }
 
                        if (data.ReadState == ReadState.Aborted) {
-                               cnc.HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone");
+                               HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone");
                                return;
                        }
 
                        if (data.ReadState != ReadState.Content) {
                                int est = nread * 2;
-                               int max = (est < cnc.buffer.Length) ? cnc.buffer.Length : est;
+                               int max = (est < buffer.Length) ? buffer.Length : est;
                                byte [] newBuffer = new byte [max];
-                               Buffer.BlockCopy (cnc.buffer, 0, newBuffer, 0, nread);
-                               cnc.buffer = newBuffer;
-                               cnc.position = nread;
+                               Buffer.BlockCopy (buffer, 0, newBuffer, 0, nread);
+                               buffer = newBuffer;
+                               position = nread;
                                data.ReadState = ReadState.None;
-                               InitRead (cnc);
+                               InitRead ();
                                return;
                        }
 
-                       cnc.position = 0;
+                       position = 0;
 
-                       WebConnectionStream stream = new WebConnectionStream (cnc, data);
+                       WebConnectionStream stream = new WebConnectionStream (this, data);
                        bool expect_content = ExpectContent (data.StatusCode, data.request.Method);
                        string tencoding = null;
                        if (expect_content)
                                tencoding = data.Headers ["Transfer-Encoding"];
 
-                       cnc.chunkedRead = (tencoding != null && tencoding.IndexOf ("chunked", StringComparison.OrdinalIgnoreCase) != -1);
-                       if (!cnc.chunkedRead) {
-                               stream.ReadBuffer = cnc.buffer;
+                       chunkedRead = (tencoding != null && tencoding.IndexOf ("chunked", StringComparison.OrdinalIgnoreCase) != -1);
+                       if (!chunkedRead) {
+                               stream.ReadBuffer = buffer;
                                stream.ReadBufferOffset = pos;
                                stream.ReadBufferSize = nread;
                                try {
                                        stream.CheckResponseInBuffer ();
                                } catch (Exception e) {
-                                       cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7");
+                                       HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7");
                                }
-                       } else if (cnc.chunkStream == null) {
+                       } else if (chunkStream == null) {
                                try {
-                                       cnc.chunkStream = new ChunkStream (cnc.buffer, pos, nread, data.Headers);
+                                       chunkStream = new ChunkStream (buffer, pos, nread, data.Headers);
                                } catch (Exception e) {
-                                       cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone5");
+                                       HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone5");
                                        return;
                                }
                        } else {
-                               cnc.chunkStream.ResetBuffer ();
+                               chunkStream.ResetBuffer ();
                                try {
-                                       cnc.chunkStream.Write (cnc.buffer, pos, nread);
+                                       chunkStream.Write (buffer, pos, nread);
                                } catch (Exception e) {
-                                       cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone6");
+                                       HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone6");
                                        return;
                                }
                        }
@@ -576,16 +569,15 @@ namespace System.Net
                        return (statusCode >= 200 && statusCode != 204 && statusCode != 304);
                }
 
-               internal static void InitRead (object state)
+               internal void InitRead ()
                {
-                       WebConnection cnc = (WebConnection) state;
-                       Stream ns = cnc.nstream;
+                       Stream ns = nstream;
 
                        try {
-                               int size = cnc.buffer.Length - cnc.position;
-                               ns.BeginRead (cnc.buffer, cnc.position, size, readDoneDelegate, cnc);
+                               int size = buffer.Length - position;
+                               ns.BeginRead (buffer, position, size, ReadDone, null);
                        } catch (Exception e) {
-                               cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "InitRead");
+                               HandleError (WebExceptionStatus.ReceiveFailure, e, "InitRead");
                        }
                }
                
@@ -709,9 +701,8 @@ namespace System.Net
                        return -1;
                }
                
-               void InitConnection (object state)
+               void InitConnection (HttpWebRequest request)
                {
-                       HttpWebRequest request = (HttpWebRequest) state;
                        request.WebConnection = this;
                        if (request.ReuseConnection)
                                request.StoredConnection = this;
@@ -773,7 +764,7 @@ namespace System.Net
                        lock (this) {
                                if (state.TrySetBusy ()) {
                                        status = WebExceptionStatus.Success;
-                                       ThreadPool.QueueUserWorkItem (initConn, request);
+                                       ThreadPool.QueueUserWorkItem (o => { try { InitConnection ((HttpWebRequest) o); } catch {} }, request);
                                } else {
                                        lock (queue) {
 #if MONOTOUCH
@@ -1016,6 +1007,18 @@ namespace System.Net
                        IAsyncResult result = null;
                        try {
                                result = s.BeginWrite (buffer, offset, size, cb, state);
+                       } catch (ObjectDisposedException) {
+                               lock (this) {
+                                       if (Data.request != request)
+                                               return null;
+                               }
+                               throw;
+                       } catch (IOException e) {
+                               SocketException se = e.InnerException as SocketException;
+                               if (se != null && se.SocketErrorCode == SocketError.NotConnected) {
+                                       return null;
+                               }
+                               throw;
                        } catch (Exception) {
                                status = WebExceptionStatus.SendFailure;
                                throw;
index e01d5306adb07fd60c06664de78c6f186bab8ccd..96a2b3bf42fc4cc506ad4e6fd056e2c6168d36b6 100644 (file)
@@ -459,7 +459,7 @@ namespace System.Net
                                result.SetCompleted (false, 0);
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
                        } catch (Exception e) {
                                KillBuffer ();
@@ -666,7 +666,7 @@ namespace System.Net
                                        cnc.EndWrite (request, true, r);
                                        if (!initRead) {
                                                initRead = true;
-                                               WebConnection.InitRead (cnc);
+                                               cnc.InitRead ();
                                        }
                                        var cl = request.ContentLength;
                                        if (!sendChunked && cl == 0)
@@ -730,7 +730,7 @@ namespace System.Net
 
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
 
                                if (length == 0) {
@@ -800,7 +800,7 @@ namespace System.Net
                                complete_request_written = true;
                                if (!initRead) {
                                        initRead = true;
-                                       WebConnection.InitRead (cnc);
+                                       cnc.InitRead ();
                                }
                                return;
                        }