* ObjectReader.cs: Changed signature of ReadObjectGraph, so now it returns the
[mono.git] / mcs / class / corlib / System.Runtime.Serialization.Formatters.Binary / MessageFormatter.cs
1 //
2 // System.Runtime.Serialization.Formatters.Binary.MessageFormatter.cs
3 //
4 // Author: Lluis Sanchez Gual (lluis@ideary.com)
5 //
6 // (C) 2003, Lluis Sanchez Gual
7 //\r
8 \r
9 using System;\r
10 using System.IO;\r
11 using System.Reflection;\r
12 using System.Collections;\r
13 using System.Runtime.Remoting;\r
14 using System.Runtime.Serialization;\r
15 using System.Runtime.Remoting.Messaging;\r
16 \r
namespace System.Runtime.Serialization.Formatters.Binary
{
	// Reads and writes remoting method-call and method-response messages for the
	// binary formatter. Each message starts with a BinaryElement code followed by a
	// few header bytes. Method/type names and "method primitive" values (see
	// IsMethodPrimitive) are written inline. Everything else is packed into an
	// object graph handled by ObjectWriter / ObjectReader, or replaced by a single
	// BinaryElement.End byte when there is nothing to put in the graph.
	internal class MessageFormatter
	{
		// Serializes an IMethodCallMessage.
		//
		// writer            - destination writer.
		// obj               - the message; it is cast to IMethodCallMessage.
		// headers           - passed through to ObjectWriter.WriteObjectGraph.
		// surrogateSelector - used by the ObjectWriter for the object-graph part.
		// context           - streaming context for the ObjectWriter.
		//
		// Layout written:
		//   MethodCall element, MethodFlags byte, 3 zero bytes,
		//   String-tagged method name, String-tagged type name,
		//   [PrimitiveArguments] uint count + one (type code, value) per argument,
		//   then an object graph holding the remaining info, or an End element.
		public static void WriteMethodCall (BinaryWriter writer, object obj, Header[] headers, ISurrogateSelector surrogateSelector, StreamingContext context)
		{
			IMethodCallMessage call = (IMethodCallMessage) obj;
			writer.Write ((byte) BinaryElement.MethodCall);

			MethodFlags methodFlags;
			int infoArraySize = 0;
			object info = null;
			object[] extraProperties = null;

			if (call.LogicalCallContext != null && call.LogicalCallContext.HasInfo)
			{
				methodFlags = MethodFlags.IncludesLogicalCallContext;
				infoArraySize++;
			}
			else
				methodFlags = MethodFlags.ExcludeLogicalCallContext;

			// The signature is only sent when the method is overloaded.
			if (RemotingServices.IsMethodOverloaded (call))
			{
				infoArraySize++;
				methodFlags |= MethodFlags.IncludesSignature;
			}

			// Properties beyond the well-known internal keys travel in the graph
			// as an object[] of DictionaryEntry values.
			if (call.Properties.Count > MethodCallDictionary.InternalKeys.Length)
			{
				extraProperties = GetExtraProperties (call.Properties, MethodCallDictionary.InternalKeys);
				infoArraySize++;
			}

			if (call.ArgCount == 0) {
				methodFlags |= MethodFlags.NoArguments;
			}
			else if (AllTypesArePrimitive (call.Args)) {
				methodFlags |= MethodFlags.PrimitiveArguments;
			}
			else if (infoArraySize == 0) {
				// The args array is the only thing to serialize: it becomes the graph root.
				methodFlags |= MethodFlags.ArgumentsInSimpleArray;
			}
			else {
				// The args array shares the graph with other info: wrap it all in an object[].
				methodFlags |= MethodFlags.ArgumentsInMultiArray;
				infoArraySize++;
			}

			writer.Write ((byte) methodFlags);

			// FIXME: what are the following 3 bytes for?
			// NOTE(review): possibly the upper bytes of a 32-bit MessageFlags field
			// as described in [MS-NRBF] BinaryMethodCall — unconfirmed.
			writer.Write ((byte) 0);
			writer.Write ((byte) 0);
			writer.Write ((byte) 0);

			// Method name
			writer.Write ((byte) BinaryTypeCode.String);
			writer.Write (call.MethodName);

			// Class name
			writer.Write ((byte) BinaryTypeCode.String);
			writer.Write (call.TypeName);

			// Inline arguments. The count and the values come from the same array.
			// ReadMethodCall reads exactly 'count' values, so writing only the
			// in-arguments, as this method previously did, desynchronized the stream
			// for methods with out/ref parameters.
			if ((methodFlags & MethodFlags.PrimitiveArguments) > 0)
				WritePrimitiveValues (writer, call.Args);

			if (infoArraySize > 0)
			{
				// This order must match the read order in ReadMethodCall.
				object[] ainfo = new object [infoArraySize];
				int n = 0;
				if ((methodFlags & MethodFlags.ArgumentsInMultiArray) > 0) ainfo [n++] = call.Args;
				if ((methodFlags & MethodFlags.IncludesSignature) > 0) ainfo [n++] = call.MethodSignature;
				if ((methodFlags & MethodFlags.IncludesLogicalCallContext) > 0) ainfo [n++] = call.LogicalCallContext;
				if (extraProperties != null) ainfo [n++] = extraProperties;
				info = ainfo;
			}
			else if ((methodFlags & MethodFlags.ArgumentsInSimpleArray) > 0)
				info = call.Args;

			WriteInfo (writer, info, headers, surrogateSelector, context);
		}

		// Serializes an IMethodReturnMessage.
		//
		// writer            - destination writer.
		// obj               - the message; it is cast to IMethodReturnMessage.
		// headers           - passed through to ObjectWriter.WriteObjectGraph.
		// surrogateSelector - used by the ObjectWriter for the object-graph part.
		// context           - streaming context for the ObjectWriter.
		//
		// Layout written:
		//   MethodResponse element, (context | format) flags byte, ReturnTypeTag
		//   byte, 2 zero bytes,
		//   [PrimitiveType return] type code + value,
		//   [PrimitiveArguments] uint count + one (type code, value) per out arg,
		//   then an object graph holding the remaining info, or an End element.
		// An exception response sends just the exception in the graph. It never
		// includes call context or extra properties.
		public static void WriteMethodResponse (BinaryWriter writer, object obj, Header[] headers, ISurrogateSelector surrogateSelector, StreamingContext context)
		{
			IMethodReturnMessage resp = (IMethodReturnMessage) obj;
			writer.Write ((byte) BinaryElement.MethodResponse);

			string[] internalProperties = MethodReturnDictionary.InternalReturnKeys;

			int infoArrayLength = 0;
			object info = null;
			object[] extraProperties = null;

			// Type of return value

			ReturnTypeTag returnTypeTag;

			if (resp.Exception != null) {
				returnTypeTag = ReturnTypeTag.Exception | ReturnTypeTag.Null;
				info = new object[] {resp.Exception};
				internalProperties = MethodReturnDictionary.InternalExceptionKeys;
			}
			else if (resp.ReturnValue == null) {
				returnTypeTag = ReturnTypeTag.Null;
			}
			else if (IsMethodPrimitive (resp.ReturnValue.GetType ())) {
				returnTypeTag = ReturnTypeTag.PrimitiveType;
			}
			else {
				returnTypeTag = ReturnTypeTag.ObjectType;
				infoArrayLength++;
			}

			bool isException = (returnTypeTag & ReturnTypeTag.Exception) != 0;

			// Message flags

			MethodFlags contextFlag;
			MethodFlags formatFlag;

			if ((resp.LogicalCallContext != null) && resp.LogicalCallContext.HasInfo && !isException)
			{
				contextFlag = MethodFlags.IncludesLogicalCallContext;
				infoArrayLength++;
			}
			else
				contextFlag = MethodFlags.ExcludeLogicalCallContext;

			if (resp.Properties.Count > internalProperties.Length && !isException)
			{
				extraProperties = GetExtraProperties (resp.Properties, internalProperties);
				infoArrayLength++;
			}

			if (resp.OutArgCount == 0)
				formatFlag = MethodFlags.NoArguments;
			else if (AllTypesArePrimitive (resp.OutArgs))
				formatFlag = MethodFlags.PrimitiveArguments;
			else if (infoArrayLength == 0)
				formatFlag = MethodFlags.ArgumentsInSimpleArray;
			else {
				formatFlag = MethodFlags.ArgumentsInMultiArray;
				infoArrayLength++;
			}

			writer.Write ((byte) (contextFlag | formatFlag));
			writer.Write ((byte) returnTypeTag);

			// FIXME: what are the following 2 bytes for?
			writer.Write ((byte) 0);
			writer.Write ((byte) 0);

			// Inline return value and arguments

			if (returnTypeTag == ReturnTypeTag.PrimitiveType)
			{
				writer.Write (BinaryCommon.GetTypeCode (resp.ReturnValue.GetType ()));
				ObjectWriter.WritePrimitiveValue (writer, resp.ReturnValue);
			}

			if (formatFlag == MethodFlags.PrimitiveArguments)
				WritePrimitiveValues (writer, resp.OutArgs);

			if (infoArrayLength > 0)
			{
				// This order must match the read order in ReadMethodResponse.
				object[] infoArray = new object [infoArrayLength];
				int n = 0;

				if (formatFlag == MethodFlags.ArgumentsInMultiArray)
					infoArray [n++] = resp.OutArgs;

				if (returnTypeTag == ReturnTypeTag.ObjectType)
					infoArray [n++] = resp.ReturnValue;

				if (contextFlag == MethodFlags.IncludesLogicalCallContext)
					infoArray [n++] = resp.LogicalCallContext;

				if (extraProperties != null)
					infoArray [n++] = extraProperties;

				info = infoArray;
			}
			else if ((formatFlag & MethodFlags.ArgumentsInSimpleArray) > 0 && !isException)
			{
				// For an exception response, 'info' already holds the exception.
				// The reader expects the exception at msgInfo[0], so it must not be
				// replaced by the out-args array.
				info = resp.OutArgs;
			}

			WriteInfo (writer, info, headers, surrogateSelector, context);
		}

		// Deserializes a message written by WriteMethodCall and returns it as a
		// MethodCall. The header read from the object graph (if any) is handed to
		// headerHandler, and its result, when it is a string, becomes the __Uri
		// header.
		// Throws SerializationException when the element code or the name type
		// tags are not the expected ones.
		public static object ReadMethodCall (BinaryReader reader, bool hasHeaders, HeaderHandler headerHandler, ISurrogateSelector surrogateSelector, StreamingContext context, SerializationBinder binder)
		{
			BinaryElement elem = (BinaryElement) reader.ReadByte ();	// The element code
			if (elem != BinaryElement.MethodCall) throw new SerializationException ("Invalid format. Expected BinaryElement.MethodCall, found " +  elem);

			MethodFlags flags = (MethodFlags) reader.ReadByte ();

			// FIXME: find a meaning for those 3 bytes
			reader.ReadByte ();
			reader.ReadByte ();
			reader.ReadByte ();

			if (((BinaryTypeCode) reader.ReadByte ()) != BinaryTypeCode.String) throw new SerializationException ("Invalid format");
			string methodName = reader.ReadString ();

			if (((BinaryTypeCode) reader.ReadByte ()) != BinaryTypeCode.String) throw new SerializationException ("Invalid format");
			string className = reader.ReadString ();

			object[] arguments = null;
			object methodSignature = null;
			object callContext = null;
			object[] extraProperties = null;
			Header[] headers = null;

			if ((flags & MethodFlags.PrimitiveArguments) > 0)
				arguments = ReadPrimitiveValues (reader);

			// NOTE(review): WriteMethodCall also emits an object graph when the only
			// additional info is a set of extra properties (no signature, no call
			// context, args inline or absent). Unless NeedsInfoArrayMask covers that
			// flag combination, this method reads a single byte instead of that graph
			// and the extra properties are lost. Confirm against the MethodFlags
			// definition.
			if ((flags & MethodFlags.NeedsInfoArrayMask) > 0)
			{
				ObjectReader objectReader = new ObjectReader (surrogateSelector, context, binder);

				object graph;
				objectReader.ReadObjectGraph (reader, hasHeaders, out graph, out headers);
				object[] msgInfo = (object[]) graph;

				if ((flags & MethodFlags.ArgumentsInSimpleArray) > 0) {
					arguments = msgInfo;
				}
				else
				{
					// Same order as written by WriteMethodCall.
					int n = 0;
					if ((flags & MethodFlags.ArgumentsInMultiArray) > 0) {
						if (msgInfo.Length > 1) arguments = (object[]) msgInfo [n++];
						else arguments = new object [0];
					}

					if ((flags & MethodFlags.IncludesSignature) > 0)
						methodSignature = msgInfo [n++];

					if ((flags & MethodFlags.IncludesLogicalCallContext) > 0)
						callContext = msgInfo [n++];

					if (n < msgInfo.Length)
						extraProperties = (object[]) msgInfo [n];
				}
			}
			else {
				reader.ReadByte ();	// Reads the stream ender
			}

			if (arguments == null) arguments = new object [0];

			string uri = null;
			if (headerHandler != null)
				uri = headerHandler (headers) as string;

			Header[] methodInfo = new Header[] {
				new Header ("__MethodName", methodName),
				new Header ("__MethodSignature", methodSignature),
				new Header ("__TypeName", className),
				new Header ("__Args", arguments),
				new Header ("__CallContext", callContext),
				new Header ("__Uri", uri)
			};

			MethodCall call = new MethodCall (methodInfo);

			if (extraProperties != null) {
				foreach (DictionaryEntry entry in extraProperties)
					call.Properties [(string) entry.Key] = entry.Value;
			}

			return call;
		}

		// Deserializes a message written by WriteMethodResponse and returns it as a
		// ReturnMessage bound to methodCallMessage. If the response carries an
		// exception, the returned message wraps that exception. Headers read from
		// the object graph (if any) are passed to headerHandler, and its result is
		// ignored.
		// Throws SerializationException when the element code is not MethodResponse.
		public static object ReadMethodResponse (BinaryReader reader, bool hasHeaders, HeaderHandler headerHandler, IMethodCallMessage methodCallMessage, ISurrogateSelector surrogateSelector, StreamingContext context, SerializationBinder binder)
		{
			BinaryElement elem = (BinaryElement) reader.ReadByte ();	// The element code
			if (elem != BinaryElement.MethodResponse) throw new SerializationException ("Invalid format. Expected BinaryElement.MethodResponse, found " +  elem);

			MethodFlags flags = (MethodFlags) reader.ReadByte ();
			ReturnTypeTag typeTag = (ReturnTypeTag) reader.ReadByte ();
			bool hasContextInfo = (flags & MethodFlags.IncludesLogicalCallContext) > 0;

			// FIXME: find a meaning for those 2 bytes
			reader.ReadByte ();
			reader.ReadByte ();

			object returnValue = null;
			object[] outArgs = null;
			LogicalCallContext callContext = null;
			Exception exception = null;
			object[] extraProperties = null;
			Header[] headers = null;

			if ((typeTag & ReturnTypeTag.PrimitiveType) > 0)
			{
				Type type = BinaryCommon.GetTypeFromCode (reader.ReadByte ());
				returnValue = ObjectReader.ReadPrimitiveTypeValue (reader, type);
			}

			if ((flags & MethodFlags.PrimitiveArguments) > 0)
				outArgs = ReadPrimitiveValues (reader);

			// NOTE(review): WriteMethodResponse also emits an object graph when the
			// only additional info is a set of extra properties (Null/primitive
			// return, no call context, out args inline or absent). None of the
			// conditions below holds in that case, so a single byte is read instead
			// of the graph and the extra properties are lost.
			if (hasContextInfo || (typeTag & ReturnTypeTag.ObjectType) > 0 ||
				(typeTag & ReturnTypeTag.Exception) > 0 ||
				(flags & MethodFlags.ArgumentsInSimpleArray) > 0 ||
				(flags & MethodFlags.ArgumentsInMultiArray) > 0)
			{
				// There are objects that need to be deserialized using an ObjectReader

				ObjectReader objectReader = new ObjectReader (surrogateSelector, context, binder);
				object graph;
				objectReader.ReadObjectGraph (reader, hasHeaders, out graph, out headers);
				object[] msgInfo = (object[]) graph;

				if ((typeTag & ReturnTypeTag.Exception) > 0) {
					exception = (Exception) msgInfo [0];
				}
				else if ((flags & MethodFlags.NoArguments) > 0 || (flags & MethodFlags.PrimitiveArguments) > 0) {
					int n = 0;
					if ((typeTag & ReturnTypeTag.ObjectType) > 0) returnValue = msgInfo [n++];
					if (hasContextInfo) callContext = (LogicalCallContext) msgInfo [n++];
					if (n < msgInfo.Length) extraProperties = (object[]) msgInfo [n];
				}
				else if ((flags & MethodFlags.ArgumentsInSimpleArray) > 0) {
					outArgs = msgInfo;
				}
				else {
					// ArgumentsInMultiArray: same order as written by WriteMethodResponse.
					int n = 0;
					outArgs = (object[]) msgInfo [n++];
					if ((typeTag & ReturnTypeTag.ObjectType) > 0) returnValue = msgInfo [n++];
					if (hasContextInfo) callContext = (LogicalCallContext) msgInfo [n++];
					if (n < msgInfo.Length) extraProperties = (object[]) msgInfo [n];
				}
			}
			else {
				reader.ReadByte ();	// Reads the stream ender
			}

			if (headerHandler != null)
				headerHandler (headers);

			if (exception != null)
				return new ReturnMessage (exception, methodCallMessage);

			int argCount = (outArgs != null) ? outArgs.Length : 0;
			ReturnMessage message = new ReturnMessage (returnValue, outArgs, argCount, callContext, methodCallMessage);

			if (extraProperties != null) {
				foreach (DictionaryEntry entry in extraProperties)
					message.Properties [(string) entry.Key] = entry.Value;
			}

			return message;
		}

		// Writes a uint element count followed by one entry per element.
		// A non-null element is written as its type code plus the value; a null
		// element is written as a lone BinaryTypeCode.Null.
		// Callers only pass arrays that passed AllTypesArePrimitive.
		static void WritePrimitiveValues (BinaryWriter writer, object[] values)
		{
			writer.Write ((uint) values.Length);
			foreach (object val in values)
			{
				if (val != null) {
					writer.Write (BinaryCommon.GetTypeCode (val.GetType ()));
					ObjectWriter.WritePrimitiveValue (writer, val);
				}
				else
					writer.Write ((byte) BinaryTypeCode.Null);
			}
		}

		// Reads an inline value list: a uint count followed by one
		// (type code, value) entry per element.
		static object[] ReadPrimitiveValues (BinaryReader reader)
		{
			uint count = reader.ReadUInt32 ();
			object[] values = new object [count];
			for (int n = 0; n < count; n++)
			{
				Type type = BinaryCommon.GetTypeFromCode (reader.ReadByte ());
				values [n] = ObjectReader.ReadPrimitiveTypeValue (reader, type);
			}
			return values;
		}

		// Writes the object-graph part of a message, or a lone End element when
		// info is null (nothing to serialize).
		static void WriteInfo (BinaryWriter writer, object info, Header[] headers, ISurrogateSelector surrogateSelector, StreamingContext context)
		{
			if (info != null)
			{
				ObjectWriter objectWriter = new ObjectWriter (surrogateSelector, context);
				objectWriter.WriteObjectGraph (writer, info, headers);
			}
			else
				writer.Write ((byte) BinaryElement.End);
		}

		// True when every non-null element's type satisfies IsMethodPrimitive.
		// Returns true for an empty array.
		private static bool AllTypesArePrimitive (object[] objects)
		{
			foreach (object ob in objects)
			{
				if (ob != null && !IsMethodPrimitive (ob.GetType ()))
					return false;
			}
			return true;
		}

		// When serializing methods, strings, DateTime and Decimal are treated as
		// primitive types in addition to the CLR primitives.
		public static bool IsMethodPrimitive (Type type)
		{
			return type.IsPrimitive || type == typeof (string) || type == typeof (DateTime) || type == typeof (Decimal);
		}

		// Returns the dictionary entries whose keys are not in internalKeys, as an
		// object[] of boxed DictionaryEntry values.
		// The result is sized from the entries actually found. It therefore does
		// not assume that every internal key is present in 'properties'. The
		// previous presized array could overflow, or leave null slots that the
		// readers cannot unbox to DictionaryEntry.
		static object[] GetExtraProperties (IDictionary properties, string[] internalKeys)
		{
			ArrayList extra = new ArrayList ();
			IDictionaryEnumerator e = properties.GetEnumerator ();
			while (e.MoveNext ())
			{
				if (!IsInternalKey ((string) e.Entry.Key, internalKeys))
					extra.Add (e.Entry);
			}
			return extra.ToArray ();
		}

		// True when key is one of internalKeys (ordinal string equality).
		static bool IsInternalKey (string key, string[] internalKeys)
		{
			foreach (string ikey in internalKeys)
				if (key == ikey) return true;
			return false;
		}
	}
}