From: Ludovic Henry Date: Fri, 14 Apr 2017 18:51:42 +0000 (-0400) Subject: [WebConnection] Fix race condition between Close and BeginWrite (#4693) X-Git-Url: http://wien.tomnetworks.com/gitweb/?p=mono.git;a=commitdiff_plain;h=52b00f29220b431fa0270d812668014bc92b2f7e [WebConnection] Fix race condition between Close and BeginWrite (#4693) * [WebConnection] Make ReadDone and InitRead instance methods to avoid passing cnc around * [WebConnection] Inline only call to InitConnection * [WebConnection] Fix race condition between Close and BeginWrite --- diff --git a/mcs/class/System/System.Net/WebConnection.cs b/mcs/class/System/System.Net/WebConnection.cs index 0a36e210049..563ce725fd9 100644 --- a/mcs/class/System/System.Net/WebConnection.cs +++ b/mcs/class/System/System.Net/WebConnection.cs @@ -59,7 +59,6 @@ namespace System.Net WaitCallback initConn; bool keepAlive; byte [] buffer; - static AsyncCallback readDoneDelegate = new AsyncCallback (ReadDone); EventHandler abortHandler; AbortHelper abortHelper; internal WebConnectionData Data; @@ -100,11 +99,6 @@ namespace System.Net this.sPoint = sPoint; buffer = new byte [4096]; Data = new WebConnectionData (); - initConn = new WaitCallback (state => { - try { - InitConnection (state); - } catch {} - }); queue = wcs.Group.Queue; abortHelper = new AbortHelper (); abortHelper.Connection = this; @@ -460,13 +454,12 @@ namespace System.Net } } - static void ReadDone (IAsyncResult result) + void ReadDone (IAsyncResult result) { - WebConnection cnc = (WebConnection)result.AsyncState; - WebConnectionData data = cnc.Data; - Stream ns = cnc.nstream; + WebConnectionData data = Data; + Stream ns = nstream; if (ns == null) { - cnc.Close (true); + Close (true); return; } @@ -479,84 +472,84 @@ namespace System.Net if (e.InnerException is ObjectDisposedException) return; - cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone1"); + HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone1"); return; } if (nread == 0) { - cnc.HandleError (WebExceptionStatus.ReceiveFailure, null, "ReadDone2"); + HandleError (WebExceptionStatus.ReceiveFailure, null, "ReadDone2"); return; } if (nread < 0) { - cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, null, "ReadDone3"); + HandleError (WebExceptionStatus.ServerProtocolViolation, null, "ReadDone3"); return; } int pos = -1; - nread += cnc.position; + nread += position; if (data.ReadState == ReadState.None) { Exception exc = null; try { - pos = GetResponse (data, cnc.sPoint, cnc.buffer, nread); + pos = GetResponse (data, sPoint, buffer, nread); } catch (Exception e) { exc = e; } if (exc != null || pos == -1) { - cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, exc, "ReadDone4"); + HandleError (WebExceptionStatus.ServerProtocolViolation, exc, "ReadDone4"); return; } } if (data.ReadState == ReadState.Aborted) { - cnc.HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone"); + HandleError (WebExceptionStatus.RequestCanceled, null, "ReadDone"); return; } if (data.ReadState != ReadState.Content) { int est = nread * 2; - int max = (est < cnc.buffer.Length) ? cnc.buffer.Length : est; + int max = (est < buffer.Length) ? buffer.Length : est; byte [] newBuffer = new byte [max]; - Buffer.BlockCopy (cnc.buffer, 0, newBuffer, 0, nread); - cnc.buffer = newBuffer; - cnc.position = nread; + Buffer.BlockCopy (buffer, 0, newBuffer, 0, nread); + buffer = newBuffer; + position = nread; data.ReadState = ReadState.None; - InitRead (cnc); + InitRead (); return; } - cnc.position = 0; + position = 0; - WebConnectionStream stream = new WebConnectionStream (cnc, data); + WebConnectionStream stream = new WebConnectionStream (this, data); bool expect_content = ExpectContent (data.StatusCode, data.request.Method); string tencoding = null; if (expect_content) tencoding = data.Headers ["Transfer-Encoding"]; - cnc.chunkedRead = (tencoding != null && tencoding.IndexOf ("chunked", StringComparison.OrdinalIgnoreCase) != -1); - if (!cnc.chunkedRead) { - stream.ReadBuffer = cnc.buffer; + chunkedRead = (tencoding != null && tencoding.IndexOf ("chunked", StringComparison.OrdinalIgnoreCase) != -1); + if (!chunkedRead) { + stream.ReadBuffer = buffer; stream.ReadBufferOffset = pos; stream.ReadBufferSize = nread; try { stream.CheckResponseInBuffer (); } catch (Exception e) { - cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7"); + HandleError (WebExceptionStatus.ReceiveFailure, e, "ReadDone7"); } - } else if (cnc.chunkStream == null) { + } else if (chunkStream == null) { try { - cnc.chunkStream = new ChunkStream (cnc.buffer, pos, nread, data.Headers); + chunkStream = new ChunkStream (buffer, pos, nread, data.Headers); } catch (Exception e) { - cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone5"); + HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone5"); return; } } else { - cnc.chunkStream.ResetBuffer (); + chunkStream.ResetBuffer (); try { - cnc.chunkStream.Write (cnc.buffer, pos, nread); + chunkStream.Write (buffer, pos, nread); } catch (Exception e) { - cnc.HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone6"); + HandleError (WebExceptionStatus.ServerProtocolViolation, e, "ReadDone6"); return; } } @@ -576,16 +569,15 @@ namespace System.Net return (statusCode >= 200 && statusCode != 204 && statusCode != 304); } - internal static void InitRead (object state) + internal void InitRead () { - WebConnection cnc = (WebConnection) state; - Stream ns = cnc.nstream; + Stream ns = nstream; try { - int size = cnc.buffer.Length - cnc.position; - ns.BeginRead (cnc.buffer, cnc.position, size, readDoneDelegate, cnc); + int size = buffer.Length - position; + ns.BeginRead (buffer, position, size, ReadDone, null); } catch (Exception e) { - cnc.HandleError (WebExceptionStatus.ReceiveFailure, e, "InitRead"); + HandleError (WebExceptionStatus.ReceiveFailure, e, "InitRead"); } } @@ -709,9 +701,8 @@ namespace System.Net return -1; } - void InitConnection (object state) + void InitConnection (HttpWebRequest request) { - HttpWebRequest request = (HttpWebRequest) state; request.WebConnection = this; if (request.ReuseConnection) request.StoredConnection = this; @@ -773,7 +764,7 @@ namespace System.Net lock (this) { if (state.TrySetBusy ()) { status = WebExceptionStatus.Success; - ThreadPool.QueueUserWorkItem (initConn, request); + ThreadPool.QueueUserWorkItem (o => { try { InitConnection ((HttpWebRequest) o); } catch {} }, request); } else { lock (queue) { #if MONOTOUCH @@ -1016,6 +1007,18 @@ namespace System.Net IAsyncResult result = null; try { result = s.BeginWrite (buffer, offset, size, cb, state); + } catch (ObjectDisposedException) { + lock (this) { + if (Data.request != request) + return null; + } + throw; + } catch (IOException e) { + SocketException se = e.InnerException as SocketException; + if (se != null && se.SocketErrorCode == SocketError.NotConnected) { + return null; + } + throw; } catch (Exception) { status = WebExceptionStatus.SendFailure; throw; diff --git a/mcs/class/System/System.Net/WebConnectionStream.cs b/mcs/class/System/System.Net/WebConnectionStream.cs index e01d5306adb..96a2b3bf42f 100644 --- a/mcs/class/System/System.Net/WebConnectionStream.cs +++ b/mcs/class/System/System.Net/WebConnectionStream.cs @@ -459,7 +459,7 @@ namespace System.Net result.SetCompleted (false, 0); if (!initRead) { initRead = true; - WebConnection.InitRead (cnc); + cnc.InitRead (); } } catch (Exception e) { KillBuffer (); @@ -666,7 +666,7 @@ namespace System.Net cnc.EndWrite (request, true, r); if (!initRead) { initRead = true; - WebConnection.InitRead (cnc); + cnc.InitRead (); } var cl = request.ContentLength; if (!sendChunked && cl == 0) @@ -730,7 +730,7 @@ namespace System.Net if (!initRead) { initRead = true; - WebConnection.InitRead (cnc); + cnc.InitRead (); } if (length == 0) { @@ -800,7 +800,7 @@ namespace System.Net complete_request_written = true; if (!initRead) { initRead = true; - WebConnection.InitRead (cnc); + cnc.InitRead (); } return; }