91311238e9ed669e4c4c0986c0cfb6e05fc62b3e
[mono.git] / mcs / class / referencesource / System / net / System / Net / SecureProtocols / _NegoStream.cs
1 /*++
2 Copyright (c) Microsoft Corporation
3
4 Module Name:
5
6     _NegoStream.cs
7
8 Abstract:
9     The class is used to encrypt/decrypt user data based on established
10     security context. Presumably the context belongs to SSPI NEGO or NTLM package.
11
12 Author:
13     Alexei Vopilov    12-Aug-2003
14
15 Revision History:
16     12-Aug-2003 New design that has obsoleted Authenticator class
17     15-Jan-2004 Converted to a partial class; only the internal NegotiateStream implementation goes into this file.
18
19 --*/
20
21 namespace System.Net.Security {
22     using System;
23     using System.IO;
24     using System.Security;
25     using System.Security.Principal;
26     using System.Security.Permissions;
27     using System.Threading;
28
29     //
30     // This is a wrapping stream that does data encryption/decryption based on a successfully authenticated SSPI context.
31     //
32     public partial class NegotiateStream: AuthenticatedStream
33     {
34         private static AsyncCallback _WriteCallback = new AsyncCallback(WriteCallback);
35         private static AsyncProtocolCallback _ReadCallback  = new AsyncProtocolCallback(ReadCallback);
36
37         private int         _NestedWrite;
38         private int         _NestedRead;
39         private byte[]      _ReadHeader;
40
41         // never updated directly, special properties are used
42         private byte[]      _InternalBuffer;
43         private int         _InternalOffset;
44         private int         _InternalBufferCount;
45
46         FixedSizeReader     _FrameReader;
47
48         //
49         // Private implemenation
50         //
51
52         private void InitializeStreamPart()
53         {
54             _ReadHeader = new byte[4];
55             _FrameReader = new FixedSizeReader(InnerStream);
56         }
57
58         //
59         //
60         private byte[] InternalBuffer {
61             get {
62                 return _InternalBuffer;
63             }
64         }
65         //
66         //
67         private int InternalOffset {
68             get {
69                 return _InternalOffset;
70             }
71         }
72         //
73         private int InternalBufferCount {
74             get {
75                 return _InternalBufferCount;
76             }
77         }
78         //
79         //
80         private void DecrementInternalBufferCount(int decrCount)
81         {
82             _InternalOffset += decrCount;
83             _InternalBufferCount -= decrCount;
84         }
85         //
86         //
87         private void EnsureInternalBufferSize(int bytes)
88         {
89             _InternalBufferCount = bytes;
90             _InternalOffset = 0;
91             if (InternalBuffer == null || InternalBuffer.Length < bytes)
92             {
93                 _InternalBuffer = new byte[bytes];
94             }
95         }
96         //
97         private void AdjustInternalBufferOffsetSize(int bytes, int offset)
98         {
99             _InternalBufferCount = bytes;
100             _InternalOffset = offset;
101         }
102         //
103         // Validates user parameteres for all Read/Write methods
104         //
105         private void ValidateParameters(byte[] buffer, int offset, int count)
106         {
107             if (buffer == null)
108                 throw new ArgumentNullException("buffer");
109
110             if (offset < 0)
111                 throw new ArgumentOutOfRangeException("offset");
112
113             if (count < 0)
114                 throw new ArgumentOutOfRangeException("count");
115
116             if (count > buffer.Length-offset)
117                 throw new ArgumentOutOfRangeException("count", SR.GetString(SR.net_offset_plus_count));
118         }
119         //
120         // Combined sync/async write method. For sync requet asyncRequest==null
121         //
122         private void ProcessWrite(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
123         {
124             ValidateParameters(buffer, offset, count);
125
126             if (Interlocked.Exchange(ref _NestedWrite, 1) == 1)
127             {
128                 throw new NotSupportedException(SR.GetString(SR.net_io_invalidnestedcall, (asyncRequest != null? "BeginWrite":"Write"), "write"));
129             }
130
131
132             bool failed = false;
133             try
134             {
135                 StartWriting(buffer, offset, count, asyncRequest);
136             }
137             catch (Exception e)
138             {
139                 failed = true;
140                 if (e is IOException) {
141                     throw;
142                 }
143                 throw new IOException(SR.GetString(SR.net_io_write), e);
144             }
145             finally
146             {
147                 if (asyncRequest == null || failed)
148                 {
149                     _NestedWrite = 0;
150                 }
151             }
152         }
153         //
154         //
155         //
156         private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
157         {
158             // We loop to this method from the callback
159             // If the last chunk was just completed from async callback (count < 0), we complete user request
160             if (count >= 0 )
161             {
162                 byte[] outBuffer = null;
163                 do
164                 {
165                     int chunkBytes = Math.Min(count, NegoState.c_MaxWriteDataSize);
166                     int encryptedBytes;
167
168                     try {
169                         encryptedBytes = _NegoState.EncryptData(buffer, offset, chunkBytes, ref outBuffer);
170                     }
171                     catch (Exception e) {
172                         throw new IOException(SR.GetString(SR.net_io_encrypt), e);
173                     }
174
175                     if (asyncRequest != null)
176                     {
177                         // prepare for the next request
178                         asyncRequest.SetNextRequest(buffer, offset+chunkBytes, count-chunkBytes, null);
179                         IAsyncResult ar = InnerStream.BeginWrite(outBuffer, 0, encryptedBytes, _WriteCallback, asyncRequest);
180                         if (!ar.CompletedSynchronously)
181                         {
182                             return;
183                         }
184                         InnerStream.EndWrite(ar);
185
186                     }
187                     else
188                     {
189                         InnerStream.Write(outBuffer, 0, encryptedBytes);
190                     }
191                     offset += chunkBytes;
192                     count  -= chunkBytes;
193                 } while (count != 0);
194             }
195
196             if (asyncRequest != null) {
197                 asyncRequest.CompleteUser();
198             }
199         }
200         //
201         // Combined sync/async read method. For sync requet asyncRequest==null
202         // There is a little overheader because we need to pass buffer/offset/count used only in sync.
203         // Still the benefit is that we have a common sync/async code path.
204         //
205         private int ProcessRead(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
206         {
207             ValidateParameters(buffer, offset, count);
208
209             if (Interlocked.Exchange(ref _NestedRead, 1) == 1)
210             {
211                 throw new NotSupportedException(SR.GetString(SR.net_io_invalidnestedcall, (asyncRequest!=null? "BeginRead":"Read"), "read"));
212             }
213
214             bool failed = false;
215             try
216             {
217                 if (InternalBufferCount != 0)
218                 {
219                     int copyBytes = InternalBufferCount > count? count:InternalBufferCount;
220                     if (copyBytes != 0)
221                     {
222                         Buffer.BlockCopy(InternalBuffer, InternalOffset, buffer, offset, copyBytes);
223                         DecrementInternalBufferCount(copyBytes);
224                     }
225                     if (asyncRequest != null) {
226                         asyncRequest.CompleteUser((object) copyBytes);
227                     }
228                     return copyBytes;
229                 }
230                 // going into real IO
231                 return StartReading(buffer, offset, count, asyncRequest);
232             }
233             catch (Exception e)
234             {
235                 failed = true;
236                 if (e is IOException) {
237                     throw;
238                 }
239                 throw new IOException(SR.GetString(SR.net_io_read), e);
240             }
241             finally
242             {
243                 // if sync request or exception
244                 if (asyncRequest == null || failed)
245                 {
246                     _NestedRead = 0;
247                 }
248             }
249         }
250         //
251         // To avoid recursion when decrypted 0 bytes this method will loop until decryption resulted at least in 1 byte.
252         //
253         private int StartReading(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
254         {
255             int result;
256             // When we read -1 bytes means we have decrypted 0 bytes, need looping.
257             while ((result = StartFrameHeader(buffer, offset, count, asyncRequest)) == -1) {
258                 ;
259             }
260             return result;
261         }
262
263         //
264         // Need read frame size first
265         //
266         private int StartFrameHeader(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
267         {
268             int readBytes = 0;
269             if (asyncRequest != null)
270             {
271                 asyncRequest.SetNextRequest(_ReadHeader, 0, _ReadHeader.Length, _ReadCallback);
272                 _FrameReader.AsyncReadPacket(asyncRequest);
273                 if (!asyncRequest.MustCompleteSynchronously)
274                 {
275                     return 0;
276                 }
277                 readBytes = asyncRequest.Result;
278             }
279             else
280             {
281                 readBytes = _FrameReader.ReadPacket(_ReadHeader, 0, _ReadHeader.Length);
282             }
283             return StartFrameBody(readBytes, buffer, offset, count, asyncRequest);
284         }
285         //
286         //
287         //
288         private int StartFrameBody(int readBytes, byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
289         {
290             if (readBytes == 0)
291             {
292                 //EOF
293                 if (asyncRequest != null)
294                 {
295                     asyncRequest.CompleteUser((object)0);
296                 }
297                 return 0;
298             }
299             GlobalLog.Assert(readBytes == _ReadHeader.Length, "NegoStream::ProcessHeader()|Frame size must be 4 but received {0} bytes.", readBytes);
300
301             //rpelace readBytes with the body size recovered from the header content
302             readBytes =  _ReadHeader[3];
303             readBytes = (readBytes<<8) | _ReadHeader[2];
304             readBytes = (readBytes<<8) | _ReadHeader[1];
305             readBytes = (readBytes<<8) | _ReadHeader[0];
306
307             //
308             // The body carries 4 bytes for trailer size slot plus trailer, hence <=4 frame size is always an error.
309             // Additionally we'd like to restrice the read frame size to modest 64k
310             //
311             if (readBytes <= 4 || readBytes > NegoState.c_MaxReadFrameSize)
312             {
313                 throw new IOException(SR.GetString(SR.net_frame_read_size));
314             }
315
316             //
317             // Always pass InternalBuffer for SSPI "in place" decryption.
318             // A user buffer can be shared by many threads in that case decryption/integrity check may fail cause of data corruption.
319             //
320             EnsureInternalBufferSize(readBytes);
321             if (asyncRequest != null) //Async
322             {
323                 asyncRequest.SetNextRequest(InternalBuffer, 0, readBytes, _ReadCallback);
324
325                 _FrameReader.AsyncReadPacket(asyncRequest);
326
327                 if (!asyncRequest.MustCompleteSynchronously)
328                 {
329                     return 0;
330                 }
331                 readBytes = asyncRequest.Result;
332             }
333             else //Sync
334             {
335                 readBytes = _FrameReader.ReadPacket(InternalBuffer, 0, readBytes);
336             }
337             return ProcessFrameBody(readBytes, buffer, offset, count, asyncRequest);
338         }
339         //
340         //
341         //
342         private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest)
343         {
344             if (readBytes == 0)
345             {
346                 // We already checked that the frame body is bigger than 0 bytes
347                 // Hence, this is an EOF ... fire.
348                 throw new IOException(SR.GetString(SR.net_io_eof));
349             }
350
351             //Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_
352             int internalOffset;
353             readBytes = _NegoState.DecryptData(InternalBuffer, 0, readBytes, out internalOffset);
354
355             // Decrypted data start from zero offset, the size can be shrinked after decryption
356             AdjustInternalBufferOffsetSize(readBytes, internalOffset);
357
358             if (readBytes == 0 && count != 0)
359             {
360                 //Read again
361                 return -1;
362             }
363
364             if (readBytes > count)
365             {
366                 readBytes = count;
367             }
368             Buffer.BlockCopy(InternalBuffer, InternalOffset, buffer, offset, readBytes);
369
370             // This will adjust both the remaining internal buffer count and the offset
371             DecrementInternalBufferCount(readBytes);
372
373             if (asyncRequest != null)
374             {
375                 asyncRequest.CompleteUser((object)readBytes);
376             }
377
378             return readBytes;
379         }
380         //
381         //
382         //
383         private static void WriteCallback(IAsyncResult transportResult)
384         {
385             if (transportResult.CompletedSynchronously)
386             {
387                 return;
388             }
389             GlobalLog.Assert(transportResult.AsyncState is AsyncProtocolRequest , "NegotiateSteam::WriteCallback|State type is wrong, expected AsyncProtocolRequest.");
390
391             AsyncProtocolRequest asyncRequest = (AsyncProtocolRequest) transportResult.AsyncState;
392
393             try {
394                 NegotiateStream negoStream = (NegotiateStream)asyncRequest.AsyncObject;
395                 negoStream.InnerStream.EndWrite(transportResult);
396                 if (asyncRequest.Count == 0) {
397                     // this was the last chunk
398                     asyncRequest.Count = -1;
399                 }
400                 negoStream.StartWriting(asyncRequest.Buffer, asyncRequest.Offset, asyncRequest.Count, asyncRequest);
401
402             }
403             catch (Exception e) {
404                 if (asyncRequest.IsUserCompleted) {
405                     // This will throw on a worker thread.
406                     throw;
407                 }
408                 asyncRequest.CompleteWithError(e);
409             }
410         }
411         //
412         //
413         private static void ReadCallback(AsyncProtocolRequest asyncRequest)
414         {
415             // Async ONLY completion
416             try
417             {
418                 NegotiateStream negoStream = (NegotiateStream)asyncRequest.AsyncObject;
419                 BufferAsyncResult bufferResult = (BufferAsyncResult) asyncRequest.UserAsyncResult;
420
421                 // This is not a hack, just optimization to avoid an additional callback.
422                 //
423                 if ((object) asyncRequest.Buffer == (object)negoStream._ReadHeader)
424                 {
425                     negoStream.StartFrameBody(asyncRequest.Result, bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest);
426                 }
427                 else
428                 {
429                     if (-1 == negoStream.ProcessFrameBody(asyncRequest.Result, bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest))
430                     {
431                         // in case we decrypted 0 bytes start another reading.
432                         negoStream.StartReading(bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest);
433
434                     }
435                 }
436             }
437             catch (Exception e)
438             {
439                 if (asyncRequest.IsUserCompleted) {
440                     // This will throw on a worker thread.
441                     throw;
442                 }
443                 asyncRequest.CompleteWithError(e);
444             }
445         }
446     }
447 }