37c3276d24f6359817b1b7dbc9e685156459e103
[mono.git] / mcs / class / referencesource / System.ServiceModel / System / ServiceModel / Dispatcher / StreamFormatter.cs
1 //-----------------------------------------------------------------------------
2 // Copyright (c) Microsoft Corporation.  All rights reserved.
3 //-----------------------------------------------------------------------------
4
namespace System.ServiceModel.Dispatcher
{
    using System.IO;
    using System.Runtime;
    using System.ServiceModel;
    using System.ServiceModel.Channels;
    using System.ServiceModel.Description;
    using System.ServiceModel.Diagnostics;
    using System.Xml;

    // Formats operation messages whose body is a single System.IO.Stream. The stream is either:
    //   * the return value, when the body has no other parts, or
    //   * the only body part, when there is no (valid) return value.
    // On write, the stream is emitted inside the part element (and the optional wrapper
    // element) through an IStreamProvider. On read, the body is exposed lazily as a
    // MessageBodyStream that decodes the part's base64 text content.
    class StreamFormatter
    {
        string wrapperName;
        string wrapperNS;
        string partName;
        string partNS;
        // Position of the stream in the operation's parameter array, or returnValueIndex
        // when the stream travels as the operation's return value.
        int streamIndex;
        bool isRequest;
        string operationName;
        // Sentinel stored in streamIndex meaning "the stream is the return value".
        const int returnValueIndex = -1;

        // Creates a formatter for the message. Returns null when the body is not a single
        // Stream part. Throws InvalidOperationException (see ValidateAndGetStreamPart) when a
        // Stream-typed part or return value appears in any other arrangement.
        internal static StreamFormatter Create(MessageDescription messageDescription, string operationName, bool isRequest)
        {
            MessagePartDescription streamPart = ValidateAndGetStreamPart(messageDescription, isRequest, operationName);
            if (streamPart == null)
                return null;
            return new StreamFormatter(messageDescription, streamPart, operationName, isRequest);
        }

        StreamFormatter(MessageDescription messageDescription, MessagePartDescription streamPart, string operationName, bool isRequest)
        {
            // Reference comparison: the stream part is the return value only if it is the very
            // same description object as Body.ReturnValue.
            if ((object)streamPart == (object)messageDescription.Body.ReturnValue)
                this.streamIndex = returnValueIndex;
            else
                this.streamIndex = streamPart.Index;
            wrapperName = messageDescription.Body.WrapperName;
            wrapperNS = messageDescription.Body.WrapperNamespace;
            partName = streamPart.Name;
            partNS = streamPart.Namespace;
            this.isRequest = isRequest;
            this.operationName = operationName;
        }

        // Writes the stream synchronously. The output is: optional wrapper start element,
        // part start element, the stream content (via WriteValue), then the matching end elements.
        internal void Serialize(XmlDictionaryWriter writer, object[] parameters, object returnValue)
        {
            Stream streamValue = GetStreamAndWriteStartWrapperIfNecessary(writer, parameters, returnValue);
            writer.WriteValue(new OperationStreamProvider(streamValue));
            WriteEndWrapperIfNecessary(writer);
        }

        // Fetches the stream from the parameters/return value and writes the opening elements.
        // A null stream is rejected through ThrowHelperArgumentNull, using the part name as the
        // argument name, before anything is written.
        Stream GetStreamAndWriteStartWrapperIfNecessary(XmlDictionaryWriter writer, object[] parameters, object returnValue)
        {
            Stream streamValue = GetStreamValue(parameters, returnValue);
            if (streamValue == null)
                throw DiagnosticUtility.ExceptionUtility.ThrowHelperArgumentNull(PartName);
            if (WrapperName != null)
                writer.WriteStartElement(WrapperName, WrapperNamespace);
            writer.WriteStartElement(PartName, PartNamespace);
            return streamValue;
        }

        // Closes the part element and, when a wrapper was written, the wrapper element.
        void WriteEndWrapperIfNecessary(XmlDictionaryWriter writer)
        {
            writer.WriteEndElement();
            if (WrapperName != null)
                writer.WriteEndElement();
        }

        // Asynchronous counterpart of Serialize (APM pattern); complete with EndSerialize.
        internal IAsyncResult BeginSerialize(XmlDictionaryWriter writer, object[] parameters, object returnValue, AsyncCallback callback, object state)
        {
            return new SerializeAsyncResult(this, writer, parameters, returnValue, callback, state);
        }

        public void EndSerialize(IAsyncResult result)
        {
            SerializeAsyncResult.End(result);
        }

        // The start elements are written synchronously in the constructor. The stream value is
        // written with WriteValueAsync, and the end elements are written once that write completes
        // (HandleEndSerialize).
        class SerializeAsyncResult : AsyncResult
        {
            static AsyncCompletion handleEndSerialize = new AsyncCompletion(HandleEndSerialize);

            StreamFormatter streamFormatter;
            XmlDictionaryWriter writer;

            internal SerializeAsyncResult(StreamFormatter streamFormatter, XmlDictionaryWriter writer, object[] parameters, object returnValue,
                AsyncCallback callback, object state)
                : base(callback, state)
            {
                this.streamFormatter = streamFormatter;
                this.writer = writer;

                Stream streamValue = streamFormatter.GetStreamAndWriteStartWrapperIfNecessary(writer, parameters, returnValue);
                // 'this' is passed as the async state so HandleEndSerialize can recover this instance.
                IAsyncResult result = writer.WriteValueAsync(new OperationStreamProvider(streamValue)).AsAsyncResult(PrepareAsyncCompletion(handleEndSerialize), this);
                bool completeSelf = SyncContinue(result);

                // Note:  The current task implementation hard codes the "IAsyncResult.CompletedSynchronously" property to false, so this fast path will never
                // be hit, and we will always hop threads.  CSDMain #210220
                if (completeSelf)
                {
                    Complete(true);
                }
            }

            // Completion callback for the stream write: closes the part/wrapper elements.
            static bool HandleEndSerialize(IAsyncResult result)
            {
                SerializeAsyncResult thisPtr = (SerializeAsyncResult)result.AsyncState;
                thisPtr.streamFormatter.WriteEndWrapperIfNecessary(thisPtr.writer);
                return true;
            }

            public static void End(IAsyncResult result)
            {
                AsyncResult.End<SerializeAsyncResult>(result);
            }
        }

        // Stores a MessageBodyStream over the message body into the stream's slot (parameter or
        // return value). The body is not read here; it is consumed lazily by the stream's Read.
        internal void Deserialize(object[] parameters, ref object retVal, Message message)
        {
            SetStreamValue(parameters, ref retVal, new MessageBodyStream(message, WrapperName, WrapperNamespace, PartName, PartNamespace, isRequest));
        }

        internal string WrapperName
        {
            get { return wrapperName; }
            set { wrapperName = value; }
        }

        internal string WrapperNamespace
        {
            get { return wrapperNS; }
            set { wrapperNS = value; }
        }

        internal string PartName
        {
            get { return partName; }
        }

        internal string PartNamespace
        {
            get { return partNS; }
        }


        // Reads the stream from the return value or from parameters[streamIndex].
        Stream GetStreamValue(object[] parameters, object returnValue)
        {
            if (streamIndex == returnValueIndex)
                return (Stream)returnValue;
            return (Stream)parameters[streamIndex];
        }

        // Writes the stream into the return value or into parameters[streamIndex].
        void SetStreamValue(object[] parameters, ref object returnValue, Stream streamValue)
        {
            if (streamIndex == returnValueIndex)
                returnValue = streamValue;
            else
                parameters[streamIndex] = streamValue;
        }

        // Returns the single Stream part when the body has the supported shape. Returns null
        // when the body contains no Stream at all. Throws InvalidOperationException when a Stream
        // is present in an unsupported shape; the message depends on whether this is a typed
        // message, a request or a response.
        static MessagePartDescription ValidateAndGetStreamPart(MessageDescription messageDescription, bool isRequest, string operationName)
        {
            MessagePartDescription part = GetStreamPart(messageDescription);
            if (part != null)
                return part;
            if (HasStream(messageDescription))
            {
                if (messageDescription.IsTypedMessage)
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.SFxInvalidStreamInTypedMessage, messageDescription.MessageName)));
                else if (isRequest)
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.SFxInvalidStreamInRequest, operationName)));
                else
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new InvalidOperationException(SR.GetString(SR.SFxInvalidStreamInResponse, operationName)));
            }
            return null;
        }

        // True when the return value or any body part is declared exactly as System.IO.Stream.
        private static bool HasStream(MessageDescription messageDescription)
        {
            if (messageDescription.Body.ReturnValue != null && messageDescription.Body.ReturnValue.Type == typeof(Stream))
                return true;
            foreach (MessagePartDescription part in messageDescription.Body.Parts)
            {
                if (part.Type == typeof(Stream))
                    return true;
            }
            return false;
        }

        // Supported shapes:
        //   * a valid Stream return value with no body parts, or
        //   * no valid return value and exactly one body part of type Stream.
        // Returns that part, or null for any other shape.
        static MessagePartDescription GetStreamPart(MessageDescription messageDescription)
        {
            if (OperationFormatter.IsValidReturnValue(messageDescription.Body.ReturnValue))
            {
                if (messageDescription.Body.Parts.Count == 0)
                    if (messageDescription.Body.ReturnValue.Type == typeof(Stream))
                        return messageDescription.Body.ReturnValue;
            }
            else
            {
                if (messageDescription.Body.Parts.Count == 1)
                    if (messageDescription.Body.Parts[0].Type == typeof(Stream))
                        return messageDescription.Body.Parts[0];
            }
            return null;
        }

        internal static bool IsStream(MessageDescription messageDescription)
        {
            return GetStreamPart(messageDescription) != null;
        }

        // Forward-only, read-only stream over the base64 text content of the part element in a
        // message body. The body reader is obtained on the first non-empty Read. Closing the
        // stream closes the underlying message.
        internal class MessageBodyStream : Stream
        {
            Message message;
            XmlDictionaryReader reader;
            long position;
            string wrapperName, wrapperNs;
            string elementName, elementNs;
            bool isRequest;
            internal MessageBodyStream(Message message, string wrapperName, string wrapperNs, string elementName, string elementNs, bool isRequest)
            {
                this.message = message;
                this.position = 0;
                this.wrapperName = wrapperName;
                this.wrapperNs = wrapperNs;
                this.elementName = elementName;
                this.elementNs = elementNs;
                this.isRequest = isRequest;
            }

            // Decodes up to 'count' bytes of the part's base64 content into buffer[offset..].
            // On the first non-empty call, it positions the reader at the body contents and
            // consumes the wrapper start tag (if any) and the part start tag.
            // Returns 0 (end of stream) when:
            //   * the wrapper closes with no part element,
            //   * the part has no text content, or
            //   * the content is used up; in the last two cases the reader is drained.
            // Non-fatal failures while reading are rethrown as IOException.
            public override int Read(byte[] buffer, int offset, int count)
            {
                EnsureStreamIsOpen();
                if (buffer == null)
                    throw TraceUtility.ThrowHelperError(new ArgumentNullException("buffer"), this.message);
                if (offset < 0)
                    throw TraceUtility.ThrowHelperError(new ArgumentOutOfRangeException("offset", offset,
                                                    SR.GetString(SR.ValueMustBeNonNegative)), this.message);
                if (count < 0)
                    throw TraceUtility.ThrowHelperError(new ArgumentOutOfRangeException("count", count,
                                                    SR.GetString(SR.ValueMustBeNonNegative)), this.message);
                if (buffer.Length - offset < count)
                    throw TraceUtility.ThrowHelperError(new ArgumentException(SR.GetString(SR.SFxInvalidStreamOffsetLength, offset + count)), this.message);

                // A zero-length read must leave the stream untouched. Otherwise
                // ReadContentAsBase64 would return 0, the 'bytesRead == 0' path would drain the
                // reader, and all remaining content would be lost.
                if (count == 0)
                    return 0;

                try
                {

                    if (reader == null)
                    {
                        reader = message.GetReaderAtBodyContents();
                        if (wrapperName != null)
                        {
                            reader.MoveToContent();
                            reader.ReadStartElement(wrapperName, wrapperNs);
                        }
                        reader.MoveToContent();
                        // Wrapper closes immediately: there is no part element, so the stream is empty.
                        if (reader.NodeType == XmlNodeType.EndElement)
                        {
                            return 0;
                        }

                        reader.ReadStartElement(elementName, elementNs);
                    }
                    if (reader.MoveToContent() != XmlNodeType.Text)
                    {
                        Exhaust(reader);
                        return 0;
                    }
                    int bytesRead = reader.ReadContentAsBase64(buffer, offset, count);
                    position += bytesRead;
                    if (bytesRead == 0)
                    {
                        Exhaust(reader);
                    }
                    return bytesRead;
                }
                catch (Exception ex)
                {
                    if (Fx.IsFatal(ex))
                        throw;
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new IOException(SR.GetString(SR.SFxStreamIOException), ex));
                }
            }

            // Throws ObjectDisposedException once the underlying message has been closed.
            private void EnsureStreamIsOpen()
            {
                if (message.State == MessageState.Closed)
                    throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new ObjectDisposedException(SR.GetString(
                        isRequest ? SR.SFxStreamRequestMessageClosed : SR.SFxStreamResponseMessageClosed)));
            }

            // Reads the reader to its end, discarding all remaining nodes.
            static void Exhaust(XmlDictionaryReader reader)
            {
                if (reader != null)
                {
                    while (reader.Read())
                    {
                        // drain
                    }
                }
            }

            // Number of decoded bytes returned so far. Setting it is not supported.
            public override long Position
            {
                get
                {
                    EnsureStreamIsOpen();
                    return position;
                }
                set { throw TraceUtility.ThrowHelperError(new NotSupportedException(), message); }
            }

            // Closes the message, then the body reader (if one was opened), then the base stream.
            // The try/finally ensures the reader and the base stream are still released even if
            // closing the message throws.
            public override void Close()
            {
                try
                {
                    message.Close();
                }
                finally
                {
                    if (reader != null)
                    {
                        reader.Close();
                        reader = null;
                    }
                    base.Close();
                }
            }
            public override bool CanRead { get { return message.State != MessageState.Closed; } }
            public override bool CanSeek { get { return false; } }
            public override bool CanWrite { get { return false; } }
            public override long Length
            {
                get
                {
#pragma warning suppress 56503 // Microsoft, not a seekable stream, it is ok to throw NotSupported in this case
                    throw TraceUtility.ThrowHelperError(new NotSupportedException(), this.message);
                }
            }
            // Flushing a read-only stream is valid and has nothing to do. Following the
            // Stream.Flush guidance for non-writable streams, this is a no-op rather than an error.
            public override void Flush() { }
            public override long Seek(long offset, SeekOrigin origin) { throw TraceUtility.ThrowHelperError(new NotSupportedException(), this.message); }
            public override void SetLength(long value) { throw TraceUtility.ThrowHelperError(new NotSupportedException(), this.message); }
            public override void Write(byte[] buffer, int offset, int count) { throw TraceUtility.ThrowHelperError(new NotSupportedException(), this.message); }
        }

        // Hands the operation's stream to XmlDictionaryWriter.WriteValue/WriteValueAsync.
        // ReleaseStream does nothing, so this provider never closes the stream itself.
        class OperationStreamProvider : IStreamProvider
        {
            Stream stream;

            internal OperationStreamProvider(Stream stream)
            {
                this.stream = stream;
            }

            public Stream GetStream()
            {
                return stream;
            }
            public void ReleaseStream(Stream stream)
            {
                //Noop
            }
        }
    }
}