Updates referencesource to .NET 4.7
[mono.git] / mcs / class / referencesource / System.ServiceModel / System / ServiceModel / Channels / RequestChannel.cs
1 //------------------------------------------------------------
2 // Copyright (c) Microsoft Corporation.  All rights reserved.
3 //------------------------------------------------------------
4
5 namespace System.ServiceModel.Channels
6 {
7     using System.Collections.Generic;
8     using System.Diagnostics;
9     using System.Runtime;
10     using System.ServiceModel;
11     using System.ServiceModel.Diagnostics;
12     using System.Threading;
13
/// <summary>
/// Base class for request/reply channels. Every request issued through the channel is tracked
/// in a list so that the channel can abort, fault, or wait for the in-flight requests when it is
/// aborted, faulted, or closed.
/// </summary>
internal abstract class RequestChannel : ChannelBase, IRequestChannel
{
    readonly bool manualAddressing;

    // Tracked in-flight requests. Also used as the lock object for closedEvent/closed updates.
    readonly List<IRequestBase> outstandingRequests = new List<IRequestBase>();
    readonly EndpointAddress to;
    readonly Uri via;

    // Created lazily by CopyPendingRequests(true) when a close-wait has requests to wait for.
    ManualResetEvent closedEvent;

    // Set by FinishClose; once true, closedEvent has been closed and must no longer be Set.
    bool closed;

    /// <summary>
    /// Initializes the channel.
    /// </summary>
    /// <param name="channelFactory">Channel manager passed through to <see cref="ChannelBase"/>.</param>
    /// <param name="to">Remote address exposed as <see cref="RemoteAddress"/>; required unless
    /// <paramref name="manualAddressing"/> is true.</param>
    /// <param name="via">URI exposed as <see cref="Via"/>.</param>
    /// <param name="manualAddressing">When true, <see cref="AddHeadersTo"/> does not stamp
    /// <paramref name="to"/> onto outgoing messages, and <paramref name="to"/> may be null.</param>
    /// <remarks>Throws through ThrowHelperArgumentNull("to") when <paramref name="to"/> is null
    /// and manual addressing is off.</remarks>
    protected RequestChannel(ChannelManagerBase channelFactory, EndpointAddress to, Uri via, bool manualAddressing)
        : base(channelFactory)
    {
        if (!manualAddressing && to == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("to");
        }

        this.manualAddressing = manualAddressing;
        this.to = to;
        this.via = via;
    }

    /// <summary>Gets whether the channel was created with manual addressing.</summary>
    protected bool ManualAddressing
    {
        get
        {
            return this.manualAddressing;
        }
    }

    /// <summary>Gets the remote address supplied at construction (may be null with manual addressing).</summary>
    public EndpointAddress RemoteAddress
    {
        get
        {
            return this.to;
        }
    }

    /// <summary>Gets the Via URI supplied at construction.</summary>
    public Uri Via
    {
        get
        {
            return this.via;
        }
    }

    /// <summary>
    /// Removes every tracked request from the channel and calls <see cref="IRequestBase.Abort"/> on each.
    /// </summary>
    protected void AbortPendingRequests()
    {
        IRequestBase[] requestsToAbort = CopyPendingRequests(false);
        if (requestsToAbort == null)
        {
            return;
        }

        foreach (IRequestBase request in requestsToAbort)
        {
            request.Abort(this);
        }
    }

    /// <summary>
    /// Begins an asynchronous wait for the requests that are in flight at the time of the call.
    /// Requests still pending when <paramref name="timeout"/> elapses (or immediately, when the
    /// timeout is zero) are aborted.
    /// </summary>
    protected IAsyncResult BeginWaitForPendingRequests(TimeSpan timeout, AsyncCallback callback, object state)
    {
        IRequestBase[] pendingRequests = SetupWaitForPendingRequests();
        return new WaitForPendingRequestsAsyncResult(timeout, this, pendingRequests, callback, state);
    }

    /// <summary>Completes a wait started with <see cref="BeginWaitForPendingRequests"/>.</summary>
    protected void EndWaitForPendingRequests(IAsyncResult result)
    {
        WaitForPendingRequestsAsyncResult.End(result);
    }

    // Marks the channel closed and disposes closedEvent at most once. Runs under the same lock as
    // ReleaseRequest, so a concurrent release can never Set an event that has already been closed.
    void FinishClose()
    {
        lock (outstandingRequests)
        {
            if (closed)
            {
                return;
            }

            closed = true;
            if (closedEvent != null)
            {
                this.closedEvent.Close();
            }
        }
    }

    // Snapshots (and untracks) the pending requests, creating closedEvent if there is anything to wait for.
    IRequestBase[] SetupWaitForPendingRequests()
    {
        return this.CopyPendingRequests(true);
    }

    /// <summary>
    /// Synchronously waits for the requests that are in flight at the time of the call; if the wait
    /// times out, every snapshotted request is aborted. Always finishes by closing the channel's
    /// internal close event.
    /// </summary>
    /// <param name="timeout">How long to wait. <see cref="TimeSpan.MaxValue"/> waits indefinitely.</param>
    protected void WaitForPendingRequests(TimeSpan timeout)
    {
        IRequestBase[] pendingRequests = SetupWaitForPendingRequests();
        if (pendingRequests != null)
        {
            // WaitHandle.WaitOne(TimeSpan, bool) throws ArgumentOutOfRangeException for TimeSpan.MaxValue
            // (WCF's "infinite" timeout) and for anything above Int32.MaxValue milliseconds. Convert the
            // same way the async path does: MaxValue -> Timeout.Infinite, oversized values clamped.
            if (!closedEvent.WaitOne(TimeoutHelper.ToMilliseconds(timeout), false))
            {
                foreach (IRequestBase request in pendingRequests)
                {
                    request.Abort(this);
                }
            }
        }

        FinishClose();
    }

    // Atomically copies and clears the tracked requests. Returns null when nothing is tracked.
    // When createEventIfNecessary is true and something was copied, ensures closedEvent exists.
    IRequestBase[] CopyPendingRequests(bool createEventIfNecessary)
    {
        lock (outstandingRequests)
        {
            if (outstandingRequests.Count == 0)
            {
                return null;
            }

            IRequestBase[] requests = outstandingRequests.ToArray();

            // NOTE(review): because the list is cleared here, the next ReleaseRequest observes
            // Count == 0 and sets closedEvent, i.e. waiters are released as soon as ANY one of the
            // snapshotted requests is released, not necessarily all of them. Confirm this is intended.
            outstandingRequests.Clear();

            if (createEventIfNecessary && closedEvent == null)
            {
                closedEvent = new ManualResetEvent(false);
            }

            return requests;
        }
    }

    /// <summary>
    /// Removes every tracked request from the channel and calls <see cref="IRequestBase.Fault"/> on each.
    /// </summary>
    protected void FaultPendingRequests()
    {
        IRequestBase[] requestsToFault = CopyPendingRequests(false);
        if (requestsToFault == null)
        {
            return;
        }

        foreach (IRequestBase request in requestsToFault)
        {
            request.Fault(this);
        }
    }

    /// <summary>
    /// Returns this channel when <typeparamref name="T"/> is <see cref="IRequestChannel"/>;
    /// otherwise defers to <see cref="ChannelBase"/>.
    /// </summary>
    public override T GetProperty<T>()
    {
        if (typeof(T) == typeof(IRequestChannel))
        {
            return (T)(object)this;
        }

        // The base result is returned as-is: a null base result is already default(T).
        return base.GetProperty<T>();
    }

    /// <summary>Aborts all tracked requests.</summary>
    protected override void OnAbort()
    {
        AbortPendingRequests();
    }

    // Notifies the request that it is being released, untracks it, and signals closedEvent once the
    // tracked list is empty (unless the channel has already finished closing).
    void ReleaseRequest(IRequestBase request)
    {
        if (request != null)
        {
            // Synchronization of OnReleaseRequest is the
            // responsibility of the concrete implementation of request.
            request.OnReleaseRequest();
        }

        lock (outstandingRequests)
        {
            // Remove tolerates a request that is no longer tracked (e.g. already snapshotted and
            // cleared by Abort/Fault/close), so no Contains() check is needed.
            outstandingRequests.Remove(request);
            if (outstandingRequests.Count == 0 && !closed && closedEvent != null)
            {
                closedEvent.Set();
            }
        }
    }

    // Starts tracking a request. The state check is repeated under the lock so that no request can be
    // added after a close/abort has already snapshotted the collection.
    void TrackRequest(IRequestBase request)
    {
        lock (outstandingRequests)
        {
            ThrowIfDisposedOrNotOpen(); // make sure that we haven't already snapshot our collection
            outstandingRequests.Add(request);
        }
    }

    /// <summary>Begins a request using <see cref="ChannelBase.DefaultSendTimeout"/>.</summary>
    public IAsyncResult BeginRequest(Message message, AsyncCallback callback, object state)
    {
        return this.BeginRequest(message, this.DefaultSendTimeout, callback, state);
    }

    /// <summary>
    /// Begins an asynchronous request. The returned <see cref="IAsyncResult"/> is the
    /// <see cref="IAsyncRequest"/> created by <see cref="CreateAsyncRequest"/>.
    /// </summary>
    /// <param name="message">Message to send; must not be null.</param>
    /// <param name="timeout">Must not be negative.</param>
    /// <remarks>If starting the send throws, the request is released before the exception propagates.</remarks>
    public IAsyncResult BeginRequest(Message message, TimeSpan timeout, AsyncCallback callback, object state)
    {
        if (message == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");
        }

        if (timeout < TimeSpan.Zero)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                new ArgumentOutOfRangeException("timeout", timeout, SR.GetString(SR.SFxTimeoutOutOfRange0)));
        }

        ThrowIfDisposedOrNotOpen();

        AddHeadersTo(message);
        IAsyncRequest asyncRequest = CreateAsyncRequest(message, callback, state);
        TrackRequest(asyncRequest);

        bool throwing = true;
        try
        {
            asyncRequest.BeginSendRequest(message, timeout);
            throwing = false;
        }
        finally
        {
            if (throwing)
            {
                ReleaseRequest(asyncRequest);
            }
        }

        return asyncRequest;
    }

    /// <summary>Creates the request object used by the synchronous <see cref="Request(Message, TimeSpan)"/>.</summary>
    protected abstract IRequest CreateRequest(Message message);

    /// <summary>Creates the request object used by <see cref="BeginRequest(Message, TimeSpan, AsyncCallback, object)"/>.</summary>
    protected abstract IAsyncRequest CreateAsyncRequest(Message message, AsyncCallback callback, object state);

    /// <summary>
    /// Completes an asynchronous request and returns the reply. The request is released whether or
    /// not completion succeeds.
    /// </summary>
    /// <param name="result">The value returned by BeginRequest; must be an <see cref="IAsyncRequest"/>.</param>
    public Message EndRequest(IAsyncResult result)
    {
        if (result == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("result");
        }

        IAsyncRequest asyncRequest = result as IAsyncRequest;
        if (asyncRequest == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgument("result", SR.GetString(SR.InvalidAsyncResult));
        }

        try
        {
            Message reply = asyncRequest.End();

            if (DiagnosticUtility.ShouldTraceInformation)
            {
                TraceUtility.TraceEvent(TraceEventType.Information, TraceCode.RequestChannelReplyReceived,
                    SR.GetString(SR.TraceCodeRequestChannelReplyReceived), reply);
            }

            return reply;
        }
        finally
        {
            ReleaseRequest(asyncRequest);
        }
    }

    /// <summary>Sends a request using <see cref="ChannelBase.DefaultSendTimeout"/> and returns the reply.</summary>
    public Message Request(Message message)
    {
        return this.Request(message, this.DefaultSendTimeout);
    }

    /// <summary>
    /// Sends a request and synchronously waits for its reply. <paramref name="timeout"/> covers the send
    /// and the wait together. A <see cref="TimeoutException"/> from either phase is rethrown wrapped in
    /// a phase-specific <see cref="TimeoutException"/>. The request is always released on exit.
    /// </summary>
    /// <param name="message">Message to send; must not be null.</param>
    /// <param name="timeout">Total time budget; must not be negative.</param>
    public Message Request(Message message, TimeSpan timeout)
    {
        if (message == null)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull("message");
        }

        if (timeout < TimeSpan.Zero)
        {
            throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(
                new ArgumentOutOfRangeException("timeout", timeout, SR.GetString(SR.SFxTimeoutOutOfRange0)));
        }

        ThrowIfDisposedOrNotOpen();

        AddHeadersTo(message);
        IRequest request = CreateRequest(message);
        TrackRequest(request);
        try
        {
            TimeoutHelper timeoutHelper = new TimeoutHelper(timeout);

            TimeSpan savedTimeout = timeoutHelper.RemainingTime();
            try
            {
                request.SendRequest(message, savedTimeout);
            }
            catch (TimeoutException timeoutException)
            {
                throw TraceUtility.ThrowHelperError(new TimeoutException(SR.GetString(SR.RequestChannelSendTimedOut, savedTimeout),
                    timeoutException), message);
            }

            // Whatever the send did not use is the budget for waiting on the reply.
            savedTimeout = timeoutHelper.RemainingTime();

            Message reply;
            try
            {
                reply = request.WaitForReply(savedTimeout);
            }
            catch (TimeoutException timeoutException)
            {
                throw TraceUtility.ThrowHelperError(new TimeoutException(SR.GetString(SR.RequestChannelWaitForReplyTimedOut, savedTimeout),
                    timeoutException), message);
            }

            if (DiagnosticUtility.ShouldTraceInformation)
            {
                TraceUtility.TraceEvent(TraceEventType.Information, TraceCode.RequestChannelReplyReceived,
                    SR.GetString(SR.TraceCodeRequestChannelReplyReceived), reply);
            }

            return reply;
        }
        finally
        {
            ReleaseRequest(request);
        }
    }

    /// <summary>
    /// Applies the channel's remote address to <paramref name="message"/> unless manual addressing is
    /// enabled or no address was supplied.
    /// </summary>
    protected virtual void AddHeadersTo(Message message)
    {
        if (!manualAddressing && to != null)
        {
            to.ApplyTo(message);
        }
    }

    // Async counterpart of WaitForPendingRequests. Waits on the channel's closedEvent through a thread
    // pool registered wait. Snapshotted requests are aborted if the wait times out, or immediately when
    // the timeout is zero.
    class WaitForPendingRequestsAsyncResult : AsyncResult
    {
        static readonly WaitOrTimerCallback completeWaitCallBack = new WaitOrTimerCallback(OnCompleteWaitCallBack);
        readonly IRequestBase[] pendingRequests;
        readonly RequestChannel requestChannel;
        readonly TimeSpan timeout;
        RegisteredWaitHandle waitHandle;

        public WaitForPendingRequestsAsyncResult(TimeSpan timeout, RequestChannel requestChannel, IRequestBase[] pendingRequests, AsyncCallback callback, object state)
            : base(callback, state)
        {
            this.requestChannel = requestChannel;
            this.pendingRequests = pendingRequests;
            this.timeout = timeout;

            if (this.timeout == TimeSpan.Zero || this.pendingRequests == null)
            {
                // Nothing to wait for, or no time to wait: finish synchronously.
                AbortRequests();
                CleanupEvents();
                Complete(true);
            }
            else
            {
                // NOTE(review): if closedEvent is already signaled, the callback may run on a pool thread
                // before this assignment completes. CleanupEvents would then see waitHandle == null and
                // skip Unregister. Confirm whether that is acceptable.
                this.waitHandle = ThreadPool.RegisterWaitForSingleObject(this.requestChannel.closedEvent, completeWaitCallBack, this, TimeoutHelper.ToMilliseconds(timeout), true);
            }
        }

        // Aborts every snapshotted request (no-op when there were none).
        void AbortRequests()
        {
            if (pendingRequests == null)
            {
                return;
            }

            foreach (IRequestBase request in pendingRequests)
            {
                request.Abort(this.requestChannel);
            }
        }

        // Unregisters the thread-pool wait (if any) and finishes closing the channel's event.
        void CleanupEvents()
        {
            if (requestChannel.closedEvent == null)
            {
                return;
            }

            if (waitHandle != null)
            {
                waitHandle.Unregister(requestChannel.closedEvent);
            }

            requestChannel.FinishClose();
        }

        // Thread-pool callback: aborts the requests on timeout, cleans up, and completes the async
        // result. Any non-fatal exception is transferred to the caller of End instead of crashing
        // the pool thread.
        static void OnCompleteWaitCallBack(object state, bool timedOut)
        {
            WaitForPendingRequestsAsyncResult thisPtr = (WaitForPendingRequestsAsyncResult)state;
            Exception completionException = null;
            try
            {
                if (timedOut)
                {
                    thisPtr.AbortRequests();
                }

                thisPtr.CleanupEvents();
            }
#pragma warning suppress 56500 // Microsoft, transferring exception to another thread
            catch (Exception e)
            {
                if (Fx.IsFatal(e))
                {
                    throw;
                }

                completionException = e;
            }

            thisPtr.Complete(false, completionException);
        }

        public static void End(IAsyncResult result)
        {
            AsyncResult.End<WaitForPendingRequestsAsyncResult>(result);
        }
    }
}
437
/// <summary>
/// Operations that <see cref="RequestChannel"/> performs on every request it tracks,
/// whether the request is synchronous or asynchronous.
/// </summary>
internal interface IRequestBase
{
    /// <summary>Aborts the request on behalf of <paramref name="requestChannel"/>.</summary>
    /// <param name="requestChannel">The channel that is aborting its pending requests.</param>
    void Abort(RequestChannel requestChannel);

    /// <summary>Faults the request on behalf of <paramref name="requestChannel"/>.</summary>
    /// <param name="requestChannel">The channel that is faulting its pending requests.</param>
    void Fault(RequestChannel requestChannel);

    /// <summary>
    /// Invoked by the owning channel when it releases the request. Implementations are responsible
    /// for synchronizing this call themselves.
    /// </summary>
    void OnReleaseRequest();
}
444
/// <summary>
/// A request driven by the synchronous <c>RequestChannel.Request</c> path: send, then block for the reply.
/// </summary>
internal interface IRequest : IRequestBase
{
    /// <summary>Sends <paramref name="message"/> within <paramref name="timeout"/>.</summary>
    /// <param name="message">The request message.</param>
    /// <param name="timeout">Time budget for the send phase.</param>
    void SendRequest(Message message, TimeSpan timeout);

    /// <summary>Waits up to <paramref name="timeout"/> for the reply and returns it.</summary>
    /// <param name="timeout">Time budget for the receive phase.</param>
    Message WaitForReply(TimeSpan timeout);
}
450
/// <summary>
/// An asynchronous request. The object itself serves as the <see cref="IAsyncResult"/> handed back
/// from <c>RequestChannel.BeginRequest</c>.
/// </summary>
internal interface IAsyncRequest : IAsyncResult, IRequestBase
{
    /// <summary>Starts sending <paramref name="message"/> within <paramref name="timeout"/>.</summary>
    /// <param name="message">The request message.</param>
    /// <param name="timeout">Time budget for the operation.</param>
    void BeginSendRequest(Message message, TimeSpan timeout);

    /// <summary>Completes the operation and returns the reply message.</summary>
    Message End();
}
456 }