[coop] Temporarily restore MonoThreadInfo when TLS destructor runs. Fixes #43099
[mono.git] / mcs / class / referencesource / System.ServiceModel / System / ServiceModel / Channels / WebSocketHelper.cs
1 // <copyright>
2 // Copyright (c) Microsoft Corporation.  All rights reserved.
3 // </copyright>
4
5 namespace System.ServiceModel.Channels
6 {
7     using System;
8     using System.Collections.Generic;
9     using System.Globalization;
10     using System.Linq;
11     using System.Net;
12     using System.Net.WebSockets;
13     using System.Runtime;
14     using System.Runtime.InteropServices;
15     using System.Security.Cryptography;
16     using System.Text;
17     using System.Threading;
18     using System.Threading.Tasks;
19
20     static class WebSocketHelper
21     {
22         internal const int OperationNotStarted = 0;
23         internal const int OperationFinished = 1;
24
25         internal const string SecWebSocketKey = "Sec-WebSocket-Key";
26         internal const string SecWebSocketVersion = "Sec-WebSocket-Version";
27         internal const string SecWebSocketProtocol = "Sec-WebSocket-Protocol";
28         internal const string SecWebSocketAccept = "Sec-WebSocket-Accept";
29         internal const string MaxPendingConnectionsString = "MaxPendingConnections";
30         internal const string WebSocketTransportSettingsString = "WebSocketTransportSettings";
31
32         internal const string CloseOperation = "CloseOperation";
33         internal const string SendOperation = "SendOperation";
34         internal const string ReceiveOperation = "ReceiveOperation";
35
36         internal static readonly char[] ProtocolSeparators = new char[] { ',' };
37
38         const string WebSocketKeyPostString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
39
40         const string SchemeWs = "ws";
41         const string SchemeWss = "wss";
42
43         static readonly int PropertyBufferSize = ((2 * Marshal.SizeOf(typeof(uint))) + Marshal.SizeOf(typeof(bool))) + IntPtr.Size;
44         static readonly HashSet<char> InvalidSeparatorSet = new HashSet<char>(new char[] { '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ' });
45         static string currentWebSocketVersion;
46
47         internal static string ComputeAcceptHeader(string webSocketKey)
48         {
49             Fx.Assert(webSocketKey != null, "webSocketKey should not be null.");
50             using (SHA1 sha = SHA1.Create())
51             {
52                 string fullString = webSocketKey + WebSocketHelper.WebSocketKeyPostString;
53                 byte[] bytes = Encoding.UTF8.GetBytes(fullString);
54                 return Convert.ToBase64String(sha.ComputeHash(bytes));
55             }
56         }
57
58         internal static int ComputeClientBufferSize(long maxReceivedMessageSize)
59         {
60             return ComputeInternalBufferSize(maxReceivedMessageSize, false);
61         }
62
63         internal static int ComputeServerBufferSize(long maxReceivedMessageSize)
64         {
65             return ComputeInternalBufferSize(maxReceivedMessageSize, true);
66         }
67
68         internal static int GetReceiveBufferSize(long maxReceivedMessageSize)
69         {
70             int effectiveMaxReceiveBufferSize = maxReceivedMessageSize <= WebSocketDefaults.BufferSize ? (int)maxReceivedMessageSize : WebSocketDefaults.BufferSize;
71             return Math.Max(WebSocketDefaults.MinReceiveBufferSize, effectiveMaxReceiveBufferSize);
72         }
73
74         internal static bool UseWebSocketTransport(WebSocketTransportUsage transportUsage, bool isContractDuplex)
75         {
76             return transportUsage == WebSocketTransportUsage.Always
77                 || (transportUsage == WebSocketTransportUsage.WhenDuplex && isContractDuplex);
78         }
79
80         internal static Uri GetWebSocketUri(Uri httpUri)
81         {
82             Fx.Assert(httpUri != null, "RemoteAddress.Uri should not be null.");
83             UriBuilder builder = new UriBuilder(httpUri);
84
85             if (Uri.UriSchemeHttp.Equals(httpUri.Scheme, StringComparison.OrdinalIgnoreCase))
86             {
87                 builder.Scheme = SchemeWs;
88             }
89             else
90             {
91                 Fx.Assert(
92                     Uri.UriSchemeHttps.Equals(httpUri.Scheme, StringComparison.OrdinalIgnoreCase),
93                     "httpUri.Scheme should be http or https.");
94                 builder.Scheme = SchemeWss;
95             }
96
97             return builder.Uri;
98         }
99
100         internal static bool IsWebSocketUri(Uri uri)
101         {
102             return uri != null && 
103                 (WebSocketHelper.SchemeWs.Equals(uri.Scheme, StringComparison.OrdinalIgnoreCase) ||
104                  WebSocketHelper.SchemeWss.Equals(uri.Scheme, StringComparison.OrdinalIgnoreCase));
105         }
106
107         internal static Uri NormalizeWsSchemeWithHttpScheme(Uri uri)
108         {
109             Fx.Assert(uri != null, "RemoteAddress.Uri should not be null.");
110             if (!IsWebSocketUri(uri))
111             {
112                 return uri;
113             }
114
115             UriBuilder builder = new UriBuilder(uri);
116
117             switch (uri.Scheme.ToLowerInvariant())
118             {
119                 case SchemeWs:
120                     builder.Scheme = Uri.UriSchemeHttp;
121                     break;
122                 case SchemeWss:
123                     builder.Scheme = Uri.UriSchemeHttps;
124                     break;
125                 default:
126                     break;
127             }
128
129             return builder.Uri;
130         }
131
132         internal static bool TryParseSubProtocol(string subProtocolValue, out List<string> subProtocolList)
133         {
134             subProtocolList = new List<string>();
135             if (subProtocolValue != null)
136             {
137                 string[] parsedTokens = subProtocolValue.Split(ProtocolSeparators, StringSplitOptions.RemoveEmptyEntries);
138
139                 string invalidChar;
140                 for (int i = 0; i < parsedTokens.Length; i++)
141                 {
142                     string token = parsedTokens[i];
143                     if (!string.IsNullOrWhiteSpace(token))
144                     {
145                         token = token.Trim();
146                         if (!IsSubProtocolInvalid(token, out invalidChar))
147                         {
148                             // Note that we could be adding a duplicate to this list. According to the specification the header should not include
149                             // duplicates but we aim to be "robust in what we receive" so we will allow it. The matching code that consumes this list
150                             // will take the first match so duplicates will not affect the outcome of the negotiation process.
151                             subProtocolList.Add(token);
152                         }
153                         else
154                         {
155                             FxTrace.Exception.AsWarning(new WebException(
156                                 SR.GetString(SR.WebSocketInvalidProtocolInvalidCharInProtocolString, token, invalidChar)));
157                             return false;
158                         }
159                     }
160                 }
161             }
162
163             return true;
164         }
165
166         internal static bool IsSubProtocolInvalid(string protocol, out string invalidChar)
167         {
168             Fx.Assert(protocol != null, "protocol should not be null");
169             char[] chars = protocol.ToCharArray();
170             for (int i = 0; i < chars.Length; i++)
171             {
172                 char ch = chars[i];
173                 if (ch < 0x21 || ch > 0x7e)
174                 {
175                     invalidChar = string.Format(CultureInfo.InvariantCulture, "[{0}]", (int)ch);
176                     return true;
177                 }
178
179                 if (InvalidSeparatorSet.Contains(ch))
180                 {
181                     invalidChar = ch.ToString();
182                     return true;
183                 }
184             }
185
186             invalidChar = null;
187             return false;
188         }
189
190         internal static string GetCurrentVersion()
191         {
192             if (currentWebSocketVersion == null)
193             {
194                 WebSocket.RegisterPrefixes();
195                 HttpWebRequest request = (HttpWebRequest)HttpWebRequest.Create("ws://localhost");
196                 string version = request.Headers[WebSocketHelper.SecWebSocketVersion];
197                 Fx.Assert(version != null, "version should not be null.");
198                 currentWebSocketVersion = version.Trim();
199             }
200
201             return currentWebSocketVersion;
202         }
203
204         internal static WebSocketTransportSettings GetRuntimeWebSocketSettings(WebSocketTransportSettings settings)
205         {
206             WebSocketTransportSettings runtimeSettings = settings.Clone();
207             if (runtimeSettings.MaxPendingConnections == WebSocketDefaults.DefaultMaxPendingConnections)
208             {
209                 runtimeSettings.MaxPendingConnections = WebSocketDefaults.MaxPendingConnectionsCpuCount;
210             }
211
212             return runtimeSettings;
213         }
214         
215         internal static bool OSSupportsWebSockets()
216         {
217             return OSEnvironmentHelper.IsAtLeast(OSVersion.Win8);
218         }
219
220         [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule,
221                     Justification = "The exceptions thrown here are already wrapped.")]
222         internal static void ThrowCorrectException(Exception ex)
223         {
224             throw ConvertAndTraceException(ex);
225         }
226
227         [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, FxCop.Rule.WrapExceptionsRule,
228                     Justification = "The exceptions thrown here are already wrapped.")]
229         internal static void ThrowCorrectException(Exception ex, TimeSpan timeout, string operation)
230         {
231             throw ConvertAndTraceException(ex, timeout, operation);
232         }
233
234         internal static Exception ConvertAndTraceException(Exception ex)
235         {
236             return ConvertAndTraceException(
237                     ex,
238                     TimeSpan.MinValue, // this is a dummy since operation type is null, so the timespan value won't be used
239                     null);
240         }
241
242         [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, "Reliability103:ThrowWrappedExceptionsRule",
243                     Justification = "The exceptions wrapped here will be thrown out later.")]
244         internal static Exception ConvertAndTraceException(Exception ex, TimeSpan timeout, string operation)
245         {
246             ObjectDisposedException objectDisposedException = ex as ObjectDisposedException;
247             if (objectDisposedException != null)
248             {
249                 CommunicationObjectAbortedException communicationObjectAbortedException = new CommunicationObjectAbortedException(ex.Message, ex);
250                 FxTrace.Exception.AsWarning(communicationObjectAbortedException);
251                 return communicationObjectAbortedException;
252             }
253
254             AggregateException aggregationException = ex as AggregateException;
255             if (aggregationException != null)
256             {
257                 Exception exception = FxTrace.Exception.AsError<OperationCanceledException>(aggregationException);
258                 OperationCanceledException operationCanceledException = exception as OperationCanceledException;
259                 if (operationCanceledException != null)
260                 {
261                     TimeoutException timeoutException = GetTimeoutException(exception, timeout, operation);
262                     FxTrace.Exception.AsWarning(timeoutException);
263                     return timeoutException;
264                 }
265                 else
266                 {
267                     Exception communicationException = ConvertAggregateExceptionToCommunicationException(aggregationException);
268                     if (communicationException is CommunicationObjectAbortedException)
269                     {
270                         FxTrace.Exception.AsWarning(communicationException);
271                         return communicationException;
272                     }
273                     else
274                     {
275                         return FxTrace.Exception.AsError(communicationException);
276                     }
277                 }
278             }
279
280             WebSocketException webSocketException = ex as WebSocketException;
281             if (webSocketException != null)
282             {
283                 switch (webSocketException.WebSocketErrorCode)
284                 {
285                     case WebSocketError.InvalidMessageType:
286                     case WebSocketError.UnsupportedProtocol:
287                     case WebSocketError.UnsupportedVersion:
288                         ex = new ProtocolException(ex.Message, ex);
289                         break;
290                     default:
291                         ex = new CommunicationException(ex.Message, ex);
292                         break;
293                 }
294             }
295
296             return FxTrace.Exception.AsError(ex);
297         }
298
299         [System.Diagnostics.CodeAnalysis.SuppressMessage(FxCop.Category.ReliabilityBasic, "Reliability103",
300                             Justification = "The exceptions will be wrapped by the callers.")]
301         internal static Exception ConvertAggregateExceptionToCommunicationException(AggregateException ex)
302         {
303             Exception exception = FxTrace.Exception.AsError<WebSocketException>(ex);
304             WebSocketException webSocketException = exception as WebSocketException;
305             if (webSocketException != null && webSocketException.InnerException != null)
306             {
307                 HttpListenerException httpListenerException = webSocketException.InnerException as HttpListenerException;
308                 if (httpListenerException != null)
309                 {
310                     return HttpChannelUtilities.CreateCommunicationException(httpListenerException);
311                 }
312             }
313
314             ObjectDisposedException objectDisposedException = exception as ObjectDisposedException;
315             if (objectDisposedException != null)
316             {
317                 return new CommunicationObjectAbortedException(exception.Message, exception);
318             }
319
320             return new CommunicationException(exception.Message, exception);
321         }
322
323         internal static void ThrowExceptionOnTaskFailure(Task task, TimeSpan timeout, string operation)
324         {
325             if (task.IsFaulted)
326             {
327                 throw FxTrace.Exception.AsError<CommunicationException>(task.Exception);
328             }
329             else if (task.IsCanceled)
330             {
331                 throw FxTrace.Exception.AsError(GetTimeoutException(null, timeout, operation));
332             }
333         }
334
335         internal static TimeoutException GetTimeoutException(Exception innerException, TimeSpan timeout, string operation)
336         {
337             string errorMsg = string.Empty;
338             if (operation != null)
339             {
340                 switch (operation)
341                 {
342                     case WebSocketHelper.CloseOperation:
343                         errorMsg = SR.GetString(SR.CloseTimedOut, timeout);
344                         break;
345                     case WebSocketHelper.SendOperation:
346                         errorMsg = SR.GetString(SR.WebSocketSendTimedOut, timeout);
347                         break;
348                     case WebSocketHelper.ReceiveOperation:
349                         errorMsg = SR.GetString(SR.WebSocketReceiveTimedOut, timeout);
350                         break;
351                     default:
352                         errorMsg = SR.GetString(SR.WebSocketOperationTimedOut, operation, timeout);
353                         break;
354                 }
355             }
356
357             return innerException == null ? new TimeoutException(errorMsg) : new TimeoutException(errorMsg, innerException);
358         }
359
360         private static int ComputeInternalBufferSize(long maxReceivedMessageSize, bool isServerBuffer)
361         {
362             const int NativeOverheadBufferSize = 144;
363             /* LAYOUT:
364             | Native buffer              | PayloadReceiveBuffer | PropertyBuffer |
365             | RBS + SBS + 144            | RBS                  | PBS            |
366             | Only WSPC may modify       | Only WebSocketBase may modify         | 
367
368              *RBS = ReceiveBufferSize, *SBS = SendBufferSize
369              *PBS = PropertyBufferSize (32-bit: 16, 64 bit: 20 bytes) */
370
371             int nativeSendBufferSize = isServerBuffer ? WebSocketDefaults.MinSendBufferSize : WebSocketDefaults.BufferSize;
372             return (2 * GetReceiveBufferSize(maxReceivedMessageSize)) + nativeSendBufferSize + NativeOverheadBufferSize + PropertyBufferSize;
373         }
374     }
375 }