Merge pull request #463 from strawd/concurrent-requests
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel.Description / ServiceContractGenerator.cs
1 //
2 // ServiceContractGenerator.cs
3 //
4 // Author:
5 //      Atsushi Enomoto <atsushi@ximian.com>
6 //
7 // Copyright (C) 2005 Novell, Inc.  http://www.novell.com
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 // 
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 // 
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28 using System;
29 using System.CodeDom;
30 using System.Collections.Generic;
31 using System.Collections.ObjectModel;
32 using System.ComponentModel;
33 using System.Configuration;
34 using System.Linq;
35 using System.Reflection;
36 using System.Runtime.Serialization;
37 using System.ServiceModel;
38 using System.ServiceModel.Channels;
39 using System.ServiceModel.Configuration;
40 using System.Threading;
41 using System.Xml.Schema;
42 using System.Xml.Serialization;
43
44 using ConfigurationType = System.Configuration.Configuration;
45 using QName = System.Xml.XmlQualifiedName;
46 using OPair = System.Collections.Generic.KeyValuePair<System.ServiceModel.Description.IOperationContractGenerationExtension,System.ServiceModel.Description.OperationContractGenerationContext>;
47
48 namespace System.ServiceModel.Description
49 {
50         public class ServiceContractGenerator
51         {
52                 CodeCompileUnit ccu;
53                 ConfigurationType config;
54                 CodeIdentifiers identifiers = new CodeIdentifiers ();
55                 Collection<MetadataConversionError> errors
56                         = new Collection<MetadataConversionError> ();
57                 Dictionary<string,string> nsmappings
58                         = new Dictionary<string,string> ();
59                 Dictionary<ContractDescription,Type> referenced_types
60                         = new Dictionary<ContractDescription,Type> ();
61                 Dictionary<ContractDescription,ContractCacheEntry> generated_contracts
62                         = new Dictionary<ContractDescription,ContractCacheEntry> ();
63                 ServiceContractGenerationOptions options;
64                 Dictionary<QName, QName> imported_names
65                         = new Dictionary<QName, QName> ();
66                 ServiceContractGenerationContext contract_context;
67                 List<OPair> operation_contexts = new List<OPair> ();
68
69                 XsdDataContractImporter data_contract_importer;
70                 XmlSerializerMessageContractImporterInternal xml_serialization_importer;
71
72                 class ContractCacheEntry {
73                         public ContractDescription Contract {
74                                 get;
75                                 private set;
76                         }
77
78                         public string ConfigurationName {
79                                 get;
80                                 private set;
81                         }
82
83                         public CodeTypeDeclaration TypeDeclaration {
84                                 get;
85                                 private set;
86                         }
87
88                         public bool GeneratedContractType {
89                                 get; set;
90                         }
91
92                         public CodeTypeReference GetReference ()
93                         {
94                                 return reference;
95                         }
96
97                         public ContractCacheEntry (ContractDescription cd, string config,
98                                                    CodeTypeDeclaration tdecl)
99                         {
100                                 Contract = cd;
101                                 ConfigurationName = config;
102                                 TypeDeclaration = tdecl;
103                                 reference = new CodeTypeReference (tdecl.Name);
104                         }
105
106                         readonly CodeTypeReference reference;
107                 }
108
109                 public ServiceContractGenerator ()
110                         : this (null, null)
111                 {
112                 }
113
114                 public ServiceContractGenerator (CodeCompileUnit ccu)
115                         : this (ccu, null)
116                 {
117                 }
118
119                 public ServiceContractGenerator (ConfigurationType config)
120                         : this (null, config)
121                 {
122                 }
123
124                 public ServiceContractGenerator (CodeCompileUnit ccu, ConfigurationType config)
125                 {
126                         if (ccu == null)
127                                 this.ccu = new CodeCompileUnit ();
128                         else
129                                 this.ccu = ccu;
130                         this.config = config;
131                         Options |= ServiceContractGenerationOptions.ChannelInterface | 
132                                 ServiceContractGenerationOptions.ClientClass;
133                 }
134
135                 public ConfigurationType Configuration {
136                         get { return config; }
137                 }
138
139                 public Collection<MetadataConversionError> Errors {
140                         get { return errors; }
141                 }
142
143                 public Dictionary<string,string> NamespaceMappings {
144                         get { return nsmappings; }
145                 }
146
147                 public ServiceContractGenerationOptions Options {
148                         get { return options; }
149                         set { options = value; }
150                 }
151
152                 bool GenerateAsync {
153                         get { return GenerateEventBasedAsync || (options & ServiceContractGenerationOptions.AsynchronousMethods) != 0; }
154                 }
155
156                 bool GenerateEventBasedAsync {
157                         get { return (options & ServiceContractGenerationOptions.EventBasedAsynchronousMethods) != 0; }
158                 }
159
160                 public Dictionary<ContractDescription,Type> ReferencedTypes {
161                         get { return referenced_types; }
162                 }
163
164                 public CodeCompileUnit TargetCompileUnit {
165                         get { return ccu; }
166                 }
167
168                 public void GenerateBinding (Binding binding,
169                         out string bindingSectionName,
170                         out string configurationName)
171                 {
172                         if (config == null)
173                                 throw new InvalidOperationException ();
174
175                         var element = ConfigUtil.FindCollectionElement (binding, config);
176                         if (element == null)
177                                 throw new InvalidOperationException ();
178
179                         bindingSectionName = element.BindingName;
180
181                         int idx = 0;
182                         configurationName = binding.Name;
183                         while (element.ContainsKey (configurationName))
184                                 configurationName = binding.Name + (++idx);
185
186                         if (!element.TryAdd (configurationName, binding, config))
187                                 throw new InvalidOperationException ();
188                 }
189
190                 #region Service Contract Type
191
192                 // Those implementation classes are very likely to be split
193                 // into different classes.
194
195                 [MonoTODO]
196                 public CodeTypeReference GenerateServiceContractType (
197                         ContractDescription contractDescription)
198                 {
199                         CodeNamespace cns = GetNamespace (contractDescription.Namespace);
200                         var cache = ExportInterface_internal (contractDescription, cns);
201                         if (cache.GeneratedContractType)
202                                 return cache.GetReference ();
203
204                         // FIXME: handle duplex callback
205
206                         if ((Options & ServiceContractGenerationOptions.ChannelInterface) != 0)
207                                 GenerateChannelInterface (contractDescription, cns);
208
209                         if ((Options & ServiceContractGenerationOptions.ClientClass) != 0)
210                                 GenerateProxyClass (contractDescription, cns);
211
212                         if (data_contract_importer != null)
213                                 MergeCompileUnit (data_contract_importer.CodeCompileUnit, ccu);
214                         if (xml_serialization_importer != null)
215                                 MergeCompileUnit (xml_serialization_importer.CodeCompileUnit, ccu);
216
217                         // Process extensions. Class first, then methods.
218                         // (built-in ones must present before processing class extensions).
219                         foreach (var cb in contractDescription.Behaviors) {
220                                 var gex = cb as IServiceContractGenerationExtension;
221                                 if (gex != null)
222                                         gex.GenerateContract (contract_context);
223                         }
224                         foreach (var opair in operation_contexts)
225                                 opair.Key.GenerateOperation (opair.Value);
226
227                         cache.GeneratedContractType = true;
228                         return cache.GetReference ();
229                 }
230
231                 CodeNamespace GetNamespace (string contractNs)
232                 {
233                         if (contractNs == null)
234                                 contractNs = String.Empty;
235                         string csharpNs;
236                         if (nsmappings.ContainsKey (contractNs))
237                                 csharpNs = nsmappings [contractNs];
238                         else if (nsmappings.ContainsKey ("*"))
239                                 csharpNs = nsmappings ["*"];
240                         else
241                                 csharpNs = string.Empty;
242                         foreach (CodeNamespace cns in ccu.Namespaces)
243                                 if (cns.Name == csharpNs)
244                                         return cns;
245                         CodeNamespace ncns = new CodeNamespace ();
246                         ncns.Name = csharpNs;
247                         ccu.Namespaces.Add (ncns);
248                         return ncns;
249                 }
250
251                 CodeTypeDeclaration GetTypeDeclaration (CodeNamespace cns, string name)
252                 {
253                         foreach (CodeTypeDeclaration type in cns.Types)
254                                 if (type.Name == name)
255                                         return type;
256                         return null;
257                 }
258
259                 void GenerateProxyClass (ContractDescription cd, CodeNamespace cns)
260                 {
261                         string name = cd.Name + "Client";
262                         if (name [0] == 'I')
263                                 name = name.Substring (1);
264                         name = identifiers.AddUnique (name, null);
265                         CodeTypeDeclaration type = GetTypeDeclaration (cns, name);
266                         if (type != null)
267                                 return; // already imported
268                         CodeTypeReference clientBase = new CodeTypeReference (typeof (ClientBase<>));
269                         clientBase.TypeArguments.Add (new CodeTypeReference (cd.Name));
270                         type = new CodeTypeDeclaration (name);
271                         cns.Types.Add (type);
272                         type.TypeAttributes = TypeAttributes.Public;
273                         type.BaseTypes.Add (clientBase);
274                         type.BaseTypes.Add (new CodeTypeReference (cd.Name));
275
276                         // .ctor()
277                         CodeConstructor ctor = new CodeConstructor ();
278                         ctor.Attributes = MemberAttributes.Public;
279                         type.Members.Add (ctor);
280
281                         // .ctor(string endpointConfigurationName)
282                         ctor = new CodeConstructor ();
283                         ctor.Attributes = MemberAttributes.Public;
284                         ctor.Parameters.Add (
285                                 new CodeParameterDeclarationExpression (
286                                         new CodeTypeReference (typeof (string)), "endpointConfigurationName"));
287                         ctor.BaseConstructorArgs.Add (
288                                 new CodeArgumentReferenceExpression ("endpointConfigurationName"));
289                         type.Members.Add (ctor);
290
291                         // .ctor(string endpointConfigurationName, string remoteAddress)
292                         ctor = new CodeConstructor ();
293                         ctor.Attributes = MemberAttributes.Public;
294                         ctor.Parameters.Add (
295                                 new CodeParameterDeclarationExpression (
296                                         new CodeTypeReference (typeof (string)), "endpointConfigurationName"));
297                         ctor.Parameters.Add (
298                                 new CodeParameterDeclarationExpression (
299                                         new CodeTypeReference (typeof (string)), "remoteAddress"));
300                         ctor.BaseConstructorArgs.Add (
301                                 new CodeArgumentReferenceExpression ("endpointConfigurationName"));
302                         ctor.BaseConstructorArgs.Add (
303                                 new CodeArgumentReferenceExpression ("remoteAddress"));
304                         type.Members.Add (ctor);
305
306                         // .ctor(string endpointConfigurationName, EndpointAddress remoteAddress)
307                         ctor = new CodeConstructor ();
308                         ctor.Attributes = MemberAttributes.Public;
309                         ctor.Parameters.Add (
310                                 new CodeParameterDeclarationExpression (
311                                         new CodeTypeReference (typeof (string)), "endpointConfigurationName"));
312                         ctor.Parameters.Add (
313                                 new CodeParameterDeclarationExpression (
314                                         new CodeTypeReference (typeof (EndpointAddress)), "remoteAddress"));
315                         ctor.BaseConstructorArgs.Add (
316                                 new CodeArgumentReferenceExpression ("endpointConfigurationName"));
317                         ctor.BaseConstructorArgs.Add (
318                                 new CodeArgumentReferenceExpression ("remoteAddress"));
319                         type.Members.Add (ctor);
320
321                         // .ctor(Binding,EndpointAddress)
322                         ctor = new CodeConstructor ();
323                         ctor.Attributes = MemberAttributes.Public;
324                         ctor.Parameters.Add (
325                                 new CodeParameterDeclarationExpression (
326                                         new CodeTypeReference (typeof (Binding)), "binding"));
327                         ctor.Parameters.Add (
328                                 new CodeParameterDeclarationExpression (
329                                         new CodeTypeReference (typeof (EndpointAddress)), "endpoint"));
330                         ctor.BaseConstructorArgs.Add (
331                                 new CodeArgumentReferenceExpression ("binding"));
332                         ctor.BaseConstructorArgs.Add (
333                                 new CodeArgumentReferenceExpression ("endpoint"));
334                         type.Members.Add (ctor);
335
336                         // service contract methods
337                         AddImplementationClientMethods (type, cd);
338
339                         if (GenerateEventBasedAsync)
340                                 foreach (var od in cd.Operations)
341                                         GenerateEventBasedAsyncSupport (type, od, cns);
342                 }
343
344                 void GenerateChannelInterface (ContractDescription cd, CodeNamespace cns)
345                 {
346                         string name = cd.Name + "Channel";
347                         name = identifiers.AddUnique (name, null);
348                         CodeTypeDeclaration type = GetTypeDeclaration (cns, name);
349                         if (type != null)
350                                 return;
351
352                         type = new CodeTypeDeclaration ();
353                         type.Name = name;
354                         type.TypeAttributes = TypeAttributes.Interface | TypeAttributes.Public;
355                         cns.Types.Add (type);
356                         
357                         type.BaseTypes.Add (ExportInterface (cd, cns));
358                         type.BaseTypes.Add (new CodeTypeReference (typeof (System.ServiceModel.IClientChannel)));
359                 }
360
361                 CodeTypeReference ExportInterface (ContractDescription cd, CodeNamespace cns)
362                 {
363                         var cache = ExportInterface_internal (cd, cns);
364                         return cache.GetReference ();
365                 }
366
367                 ContractCacheEntry ExportInterface_internal (ContractDescription cd, CodeNamespace cns)
368                 {
369                         if (generated_contracts.ContainsKey (cd))
370                                 return generated_contracts [cd];
371
372                         var type = new CodeTypeDeclaration ();
373                         type.TypeAttributes = TypeAttributes.Interface;
374                         type.TypeAttributes |= TypeAttributes.Public;
375                         cns.Types.Add (type);
376                         type.Name = identifiers.AddUnique (cd.Name, null);
377
378                         var configName = type.Name;
379                         CodeAttributeDeclaration ad = 
380                                 new CodeAttributeDeclaration (
381                                         new CodeTypeReference (
382                                         typeof (ServiceContractAttribute)));
383                         ad.Arguments.Add (new CodeAttributeArgument ("Namespace", new CodePrimitiveExpression (cd.Namespace)));
384                         ad.Arguments.Add (new CodeAttributeArgument ("ConfigurationName", new CodePrimitiveExpression (configName)));
385                         type.CustomAttributes.Add (ad);
386                         contract_context = new ServiceContractGenerationContext (this, cd, type);
387                         
388                         AddOperationMethods (type, cd);
389
390                         var cache = new ContractCacheEntry (cd, configName, type);
391                         generated_contracts.Add (cd, cache);
392                         return cache;
393                 }
394
395                 void AddBeginAsyncArgs (CodeMemberMethod cm)
396                 {
397                         var acb = new CodeParameterDeclarationExpression (new CodeTypeReference (typeof (AsyncCallback)), "asyncCallback");
398                         cm.Parameters.Add (acb);
399                         var us = new CodeParameterDeclarationExpression (new CodeTypeReference (typeof (object)), "userState");
400                         cm.Parameters.Add (us);
401                 }
402
403                 void AddOperationMethods (CodeTypeDeclaration type, ContractDescription cd)
404                 {
405                         foreach (OperationDescription od in cd.Operations) {
406                                 CodeMemberMethod syncMethod = null, beginMethod = null, endMethod = null;
407
408                                 CodeTypeReference returnTypeFromMessageContract = null;
409                                 syncMethod = GenerateOperationMethod (type, cd, od, false, out returnTypeFromMessageContract);
410                                 type.Members.Add (syncMethod);
411
412                                 if (GenerateAsync) {
413                                         beginMethod = GenerateOperationMethod (type, cd, od, true, out returnTypeFromMessageContract);
414                                         type.Members.Add (beginMethod);
415
416                                         var cm = new CodeMemberMethod ();
417                                         type.Members.Add (cm);
418                                         cm.Name = "End" + od.Name;
419                                         endMethod = cm;
420
421                                         var res = new CodeParameterDeclarationExpression (new CodeTypeReference (typeof (IAsyncResult)), "result");
422                                         cm.Parameters.Add (res);
423
424                                         if (od.SyncMethod != null) // FIXME: it depends on sync method!
425                                                 cm.ReturnType = new CodeTypeReference (od.SyncMethod.ReturnType);
426                                         else
427                                                 cm.ReturnType = returnTypeFromMessageContract;
428                                 }
429
430                                 foreach (var ob in od.Behaviors) {
431                                         var gex = ob as IOperationContractGenerationExtension;
432                                         if (gex != null)
433                                                 operation_contexts.Add (new OPair (gex, new OperationContractGenerationContext (this, contract_context, od, type, syncMethod, beginMethod, endMethod)));
434                                 }
435                         }
436                 }
437
438                 CodeMemberMethod GenerateOperationMethod (CodeTypeDeclaration type, ContractDescription cd, OperationDescription od, bool async, out CodeTypeReference returnType)
439                 {
440                         CodeMemberMethod cm = new CodeMemberMethod ();
441
442                         if (od.Behaviors.Find<XmlSerializerMappingBehavior> () != null)
443                                 cm.CustomAttributes.Add (new CodeAttributeDeclaration (new CodeTypeReference (typeof (XmlSerializerFormatAttribute))));
444
445                         if (async)
446                                 cm.Name = "Begin" + od.Name;
447                         else
448                                 cm.Name = od.Name;
449
450                         if (od.SyncMethod != null) {
451                                 ExportParameters (cm, od.SyncMethod.GetParameters ());
452                                 if (async) {
453                                         AddBeginAsyncArgs (cm);
454                                         cm.ReturnType = new CodeTypeReference (typeof (IAsyncResult));
455                                 }
456                                 else
457                                         cm.ReturnType = new CodeTypeReference (od.SyncMethod.ReturnType);
458                                 returnType = new CodeTypeReference (od.SyncMethod.ReturnType);
459                         } else {
460                                 ExportMessages (od.Messages, cm, false);
461                                 returnType = cm.ReturnType;
462                                 if (async) {
463                                         AddBeginAsyncArgs (cm);
464                                         cm.ReturnType = new CodeTypeReference (typeof (IAsyncResult));
465                                 }
466                         }
467
468                         // [OperationContract (Action = "...", ReplyAction = "..")]
469                         var ad = new CodeAttributeDeclaration (new CodeTypeReference (typeof (OperationContractAttribute)));
470                         foreach (MessageDescription md in od.Messages) {
471                                 if (md.Direction == MessageDirection.Input)
472                                         ad.Arguments.Add (new CodeAttributeArgument ("Action", new CodePrimitiveExpression (md.Action)));
473                                 else
474                                         ad.Arguments.Add (new CodeAttributeArgument ("ReplyAction", new CodePrimitiveExpression (md.Action)));
475                         }
476                         if (async)
477                                 ad.Arguments.Add (new CodeAttributeArgument ("AsyncPattern", new CodePrimitiveExpression (true)));
478                         cm.CustomAttributes.Add (ad);
479
480                         return cm;
481                 }
482
483                 void ExportParameters (CodeMemberMethod method, ParameterInfo [] parameters)
484                 {
485                         foreach (ParameterInfo pi in parameters)
486                                 method.Parameters.Add (
487                                         new CodeParameterDeclarationExpression (
488                                                 new CodeTypeReference (pi.ParameterType),
489                                                 pi.Name));
490                 }
491
492                 void AddImplementationClientMethods (CodeTypeDeclaration type, ContractDescription cd)
493                 {
494                         foreach (OperationDescription od in cd.Operations) {
495                                 CodeMemberMethod cm;
496                                 CodeTypeReference returnTypeFromMessageContract = null;
497                                 cm = GenerateImplementationClientMethod (type, cd, od, false, out returnTypeFromMessageContract);
498                                 type.Members.Add (cm);
499
500                                 if (!GenerateAsync)
501                                         continue;
502
503                                 cm = GenerateImplementationClientMethod (type, cd, od, true, out returnTypeFromMessageContract);
504                                 type.Members.Add (cm);
505
506                                 // EndXxx() implementation
507
508                                 cm = new CodeMemberMethod ();
509                                 cm.Attributes = MemberAttributes.Public 
510                                                 | MemberAttributes.Final;
511                                 type.Members.Add (cm);
512                                 cm.Name = "End" + od.Name;
513
514                                 var res = new CodeParameterDeclarationExpression (new CodeTypeReference (typeof (IAsyncResult)), "result");
515                                 cm.Parameters.Add (res);
516
517                                 if (od.SyncMethod != null) // FIXME: it depends on sync method!
518                                         cm.ReturnType = new CodeTypeReference (od.SyncMethod.ReturnType);
519                                 else
520                                         cm.ReturnType = returnTypeFromMessageContract;
521
522                                 string resultArgName = "result";
523                                 if (od.EndMethod != null)
524                                         resultArgName = od.EndMethod.GetParameters () [0].Name;
525
526                                 var call = new CodeMethodInvokeExpression (
527                                         new CodePropertyReferenceExpression (
528                                                 new CodeBaseReferenceExpression (),
529                                                 "Channel"),
530                                         cm.Name,
531                                         new CodeArgumentReferenceExpression (resultArgName));
532
533                                 if (cm.ReturnType.BaseType == "System.Void")
534                                         cm.Statements.Add (new CodeExpressionStatement (call));
535                                 else
536                                         cm.Statements.Add (new CodeMethodReturnStatement (call));
537                         }
538                 }
539
540                 CodeMemberMethod GenerateImplementationClientMethod (CodeTypeDeclaration type, ContractDescription cd, OperationDescription od, bool async, out CodeTypeReference returnTypeFromMessageContract)
541                 {
542                         CodeMemberMethod cm = new CodeMemberMethod ();
543                         if (async)
544                                 cm.Name = "Begin" + od.Name;
545                         else
546                                 cm.Name = od.Name;
547                         cm.Attributes = MemberAttributes.Public | MemberAttributes.Final;
548                         returnTypeFromMessageContract = null;
549
550                         List<CodeExpression> args = new List<CodeExpression> ();
551                         if (od.SyncMethod != null) {
552                                 ParameterInfo [] pars = od.SyncMethod.GetParameters ();
553                                 ExportParameters (cm, pars);
554                                 cm.ReturnType = new CodeTypeReference (od.SyncMethod.ReturnType);
555                                 int i = 0;
556                                 foreach (ParameterInfo pi in pars)
557                                         args.Add (new CodeArgumentReferenceExpression (pi.Name));
558                         } else {
559                                 args.AddRange (ExportMessages (od.Messages, cm, true));
560                                 returnTypeFromMessageContract = cm.ReturnType;
561                                 if (async) {
562                                         AddBeginAsyncArgs (cm);
563                                         cm.ReturnType = new CodeTypeReference (typeof (IAsyncResult));
564                                 }
565                         }
566                         if (async) {
567                                 args.Add (new CodeArgumentReferenceExpression ("asyncCallback"));
568                                 args.Add (new CodeArgumentReferenceExpression ("userState"));
569                         }
570
571                         CodeExpression call = new CodeMethodInvokeExpression (
572                                 new CodePropertyReferenceExpression (
573                                         new CodeBaseReferenceExpression (),
574                                         "Channel"),
575                                 cm.Name,
576                                 args.ToArray ());
577
578                         if (cm.ReturnType.BaseType == "System.Void")
579                                 cm.Statements.Add (new CodeExpressionStatement (call));
580                         else
581                                 cm.Statements.Add (new CodeMethodReturnStatement (call));
582                         return cm;
583                 }
584
585                 CodeMemberMethod FindByName (CodeTypeDeclaration type, string name)
586                 {
587                         foreach (var m in type.Members) {
588                                 var method = m as CodeMemberMethod;
589                                 if (method != null && method.Name == name)
590                                         return method;
591                         }
592                         return null;
593                 }
594
595                 void GenerateEventBasedAsyncSupport (CodeTypeDeclaration type, OperationDescription od, CodeNamespace cns)
596                 {
597                         var method = FindByName (type, od.Name) ?? FindByName (type, "Begin" + od.Name);
598                         var endMethod = method.Name == od.Name ? null : FindByName (type, "End" + od.Name);
599                         bool methodAsync = method.Name.StartsWith ("Begin", StringComparison.Ordinal);
600                         var resultType = endMethod != null ? endMethod.ReturnType : method.ReturnType;
601
602                         var thisExpr = new CodeThisReferenceExpression ();
603                         var baseExpr = new CodeBaseReferenceExpression ();
604                         var nullExpr = new CodePrimitiveExpression (null);
605                         var asyncResultType = new CodeTypeReference (typeof (IAsyncResult));
606
607                         // OnBeginXxx() implementation
608                         var cm = new CodeMemberMethod () {
609                                 Name = "OnBegin" + od.Name,
610                                 Attributes = MemberAttributes.Private | MemberAttributes.Final,
611                                 ReturnType = asyncResultType
612                                 };
613                         type.Members.Add (cm);
614
615                         AddMethodParam (cm, typeof (object []), "args");
616                         AddMethodParam (cm, typeof (AsyncCallback), "asyncCallback");
617                         AddMethodParam (cm, typeof (object), "userState");
618
619                         var call = new CodeMethodInvokeExpression (
620                                 thisExpr,
621                                 "Begin" + od.Name);
622                         for (int idx = 0; idx < method.Parameters.Count - (methodAsync ? 2 : 0); idx++) {
623                                 var p = method.Parameters [idx];
624                                 cm.Statements.Add (new CodeVariableDeclarationStatement (p.Type, p.Name, new CodeCastExpression (p.Type, new CodeArrayIndexerExpression (new CodeArgumentReferenceExpression ("args"), new CodePrimitiveExpression (idx)))));
625                                 call.Parameters.Add (new CodeVariableReferenceExpression (p.Name));
626                         }
627                         call.Parameters.Add (new CodeArgumentReferenceExpression ("asyncCallback"));
628                         call.Parameters.Add (new CodeArgumentReferenceExpression ("userState"));
629                         cm.Statements.Add (new CodeMethodReturnStatement (call));
630
631                         // OnEndXxx() implementation
632                         cm = new CodeMemberMethod () {
633                                 Name = "OnEnd" + od.Name,
634                                 Attributes = MemberAttributes.Private | MemberAttributes.Final,
635                                 ReturnType = new CodeTypeReference (typeof (object [])) };
636                         type.Members.Add (cm);
637
638                         AddMethodParam (cm, typeof (IAsyncResult), "result");
639
640                         var outArgRefs = new List<CodeVariableReferenceExpression> ();
641
642                         for (int idx = 0; idx < method.Parameters.Count; idx++) {
643                                 var p = method.Parameters [idx];
644                                 if (p.Direction != FieldDirection.In) {
645                                         cm.Statements.Add (new CodeVariableDeclarationStatement (p.Type, p.Name));
646                                         outArgRefs.Add (new CodeVariableReferenceExpression (p.Name)); // FIXME: should this work? They need "out" or "ref" modifiers.
647                                 }
648                         }
649
650                         call = new CodeMethodInvokeExpression (
651                                 thisExpr,
652                                 "End" + od.Name,
653                                 new CodeArgumentReferenceExpression ("result"));
654                         call.Parameters.AddRange (outArgRefs.Cast<CodeExpression> ().ToArray ()); // questionable
655
656                         var retCreate = new CodeArrayCreateExpression (typeof (object));
657                         if (resultType.BaseType == "System.Void")
658                                 cm.Statements.Add (call);
659                         else {
660                                 cm.Statements.Add (new CodeVariableDeclarationStatement (typeof (object), "__ret", call));
661                                 retCreate.Initializers.Add (new CodeVariableReferenceExpression ("__ret"));
662                         }
663                         foreach (var outArgRef in outArgRefs)
664                                 retCreate.Initializers.Add (new CodeVariableReferenceExpression (outArgRef.VariableName));
665
666                         cm.Statements.Add (new CodeMethodReturnStatement (retCreate));
667
668                         // OnXxxCompleted() implementation
669                         cm = new CodeMemberMethod () {
670                                 Name = "On" + od.Name + "Completed",
671                                 Attributes = MemberAttributes.Private | MemberAttributes.Final };
672                         type.Members.Add (cm);
673
674                         AddMethodParam (cm, typeof (object), "state");
675
676                         string argsname = identifiers.AddUnique (od.Name + "CompletedEventArgs", null);
677                         var iaargs = new CodeTypeReference ("InvokeAsyncCompletedEventArgs"); // avoid messy System.Type instance for generic nested type :|
678                         var iaref = new CodeVariableReferenceExpression ("args");
679                         var methodEventArgs = new CodeObjectCreateExpression (new CodeTypeReference (argsname),
680                                 new CodePropertyReferenceExpression (iaref, "Results"),
681                                 new CodePropertyReferenceExpression (iaref, "Error"),
682                                 new CodePropertyReferenceExpression (iaref, "Cancelled"),
683                                 new CodePropertyReferenceExpression (iaref, "UserState"));
684                         cm.Statements.Add (new CodeConditionStatement (
685                                 new CodeBinaryOperatorExpression (
686                                         new CodeEventReferenceExpression (thisExpr, od.Name + "Completed"), CodeBinaryOperatorType.IdentityInequality, nullExpr),
687                                 new CodeVariableDeclarationStatement (iaargs, "args", new CodeCastExpression (iaargs, new CodeArgumentReferenceExpression ("state"))),
688                                 new CodeExpressionStatement (new CodeMethodInvokeExpression (thisExpr, od.Name + "Completed", thisExpr, methodEventArgs))));
689
690                         // delegate fields
691                         type.Members.Add (new CodeMemberField (new CodeTypeReference ("BeginOperationDelegate"), "onBegin" + od.Name + "Delegate"));
692                         type.Members.Add (new CodeMemberField (new CodeTypeReference ("EndOperationDelegate"), "onEnd" + od.Name + "Delegate"));
693                         type.Members.Add (new CodeMemberField (new CodeTypeReference (typeof (SendOrPostCallback)), "on" + od.Name + "CompletedDelegate"));
694
695                         // XxxCompletedEventArgs class
696                         var argsType = new CodeTypeDeclaration (argsname);
697                         argsType.BaseTypes.Add (new CodeTypeReference (typeof (AsyncCompletedEventArgs)));
698                         cns.Types.Add (argsType);
699
700                         var argsCtor = new CodeConstructor () {
701                                 Attributes = MemberAttributes.Public | MemberAttributes.Final };
702                         argsCtor.Parameters.Add (new CodeParameterDeclarationExpression (typeof (object []), "results"));
703                         argsCtor.Parameters.Add (new CodeParameterDeclarationExpression (typeof (Exception), "error"));
704                         argsCtor.Parameters.Add (new CodeParameterDeclarationExpression (typeof (bool), "cancelled"));
705                         argsCtor.Parameters.Add (new CodeParameterDeclarationExpression (typeof (object), "userState"));
706                         argsCtor.BaseConstructorArgs.Add (new CodeArgumentReferenceExpression ("error"));
707                         argsCtor.BaseConstructorArgs.Add (new CodeArgumentReferenceExpression ("cancelled"));
708                         argsCtor.BaseConstructorArgs.Add (new CodeArgumentReferenceExpression ("userState"));
709                         var resultsField = new CodeFieldReferenceExpression (thisExpr, "results");
710                         argsCtor.Statements.Add (new CodeAssignStatement (resultsField, new CodeArgumentReferenceExpression ("results")));
711                         argsType.Members.Add (argsCtor);
712
713                         argsType.Members.Add (new CodeMemberField (typeof (object []), "results"));
714
715                         if (resultType.BaseType != "System.Void") {
716                                 var resultProp = new CodeMemberProperty {
717                                         Name = "Result",
718                                         Type = resultType,
719                                         Attributes = MemberAttributes.Public | MemberAttributes.Final };
720                                 resultProp.GetStatements.Add (new CodeMethodReturnStatement (new CodeCastExpression (resultProp.Type, new CodeArrayIndexerExpression (resultsField, new CodePrimitiveExpression (0)))));
721                                 argsType.Members.Add (resultProp);
722                         }
723
724                         // event field
725                         var handlerType = new CodeTypeReference (typeof (EventHandler<>));
726                         handlerType.TypeArguments.Add (new CodeTypeReference (argsType.Name));
727                         type.Members.Add (new CodeMemberEvent () {
728                                 Name = od.Name + "Completed",
729                                 Type = handlerType,
730                                 Attributes = MemberAttributes.Public | MemberAttributes.Final });
731
732                         // XxxAsync() implementations
733                         bool hasAsync = false;
734                         foreach (int __x in Enumerable.Range (0, 2)) {
735                                 cm = new CodeMemberMethod ();
736                                 type.Members.Add (cm);
737                                 cm.Name = od.Name + "Async";
738                                 cm.Attributes = MemberAttributes.Public 
739                                                 | MemberAttributes.Final;
740
741                                 var inArgs = new List<CodeParameterDeclarationExpression > ();
742
743                                 for (int idx = 0; idx < method.Parameters.Count - (methodAsync ? 2 : 0); idx++) {
744                                         var pd = method.Parameters [idx];
745                                         inArgs.Add (pd);
746                                         cm.Parameters.Add (pd);
747                                 }
748
749                                 // First one is overload without asyncState arg.
750                                 if (!hasAsync) {
751                                         call = new CodeMethodInvokeExpression (thisExpr, cm.Name, inArgs.ConvertAll<CodeExpression> (decl => new CodeArgumentReferenceExpression (decl.Name)).ToArray ());
752                                         call.Parameters.Add (nullExpr);
753                                         cm.Statements.Add (new CodeExpressionStatement (call));
754                                         hasAsync = true;
755                                         continue;
756                                 }
757
758                                 // Second one is the primary one.
759
760                                 cm.Parameters.Add (new CodeParameterDeclarationExpression (typeof (object), "userState"));
761
762                                 // if (onBeginBarOperDelegate == null) onBeginBarOperDelegate = new BeginOperationDelegate (OnBeginBarOper);
763                                 // if (onEndBarOperDelegate == null) onEndBarOperDelegate = new EndOperationDelegate (OnEndBarOper);
764                                 // if (onBarOperCompletedDelegate == null) onBarOperCompletedDelegate = new BeginOperationDelegate (OnBarOperCompleted);
765                                 var beginOperDelegateRef = new CodeFieldReferenceExpression (thisExpr, "onBegin" + od.Name + "Delegate");
766                                 var endOperDelegateRef = new CodeFieldReferenceExpression (thisExpr, "onEnd" + od.Name + "Delegate");
767                                 var operCompletedDelegateRef = new CodeFieldReferenceExpression (thisExpr, "on" + od.Name + "CompletedDelegate");
768
769                                 var ifstmt = new CodeConditionStatement (
770                                         new CodeBinaryOperatorExpression (beginOperDelegateRef, CodeBinaryOperatorType.IdentityEquality, nullExpr),
771                                         new CodeAssignStatement (beginOperDelegateRef, new CodeDelegateCreateExpression (new CodeTypeReference ("BeginOperationDelegate"), thisExpr, "OnBegin" + od.Name)));
772                                 cm.Statements.Add (ifstmt);
773                                 ifstmt = new CodeConditionStatement (
774                                         new CodeBinaryOperatorExpression (endOperDelegateRef, CodeBinaryOperatorType.IdentityEquality, nullExpr),
775                                         new CodeAssignStatement (endOperDelegateRef, new CodeDelegateCreateExpression (new CodeTypeReference ("EndOperationDelegate"), thisExpr, "OnEnd" + od.Name)));
776                                 cm.Statements.Add (ifstmt);
777                                 ifstmt = new CodeConditionStatement (
778                                         new CodeBinaryOperatorExpression (operCompletedDelegateRef, CodeBinaryOperatorType.IdentityEquality, nullExpr),
779                                         new CodeAssignStatement (operCompletedDelegateRef, new CodeDelegateCreateExpression (new CodeTypeReference (typeof (SendOrPostCallback)), thisExpr, "On" + od.Name + "Completed")));
780                                 cm.Statements.Add (ifstmt);
781
782                                 // InvokeAsync (onBeginBarOperDelegate, inValues, onEndBarOperDelegate, onBarOperCompletedDelegate, userState);
783
784                                 inArgs.Add (new CodeParameterDeclarationExpression (typeof (object), "userState"));
785
786                                 var args = new List<CodeExpression> ();
787                                 args.Add (beginOperDelegateRef);
788                                 args.Add (new CodeArrayCreateExpression (typeof (object), inArgs.ConvertAll<CodeExpression> (decl => new CodeArgumentReferenceExpression (decl.Name)).ToArray ()));
789                                 args.Add (endOperDelegateRef);
790                                 args.Add (new CodeFieldReferenceExpression (thisExpr, "on" + od.Name + "CompletedDelegate"));
791                                 args.Add (new CodeArgumentReferenceExpression ("userState"));
792                                 call = new CodeMethodInvokeExpression (baseExpr, "InvokeAsync", args.ToArray ());
793                                 cm.Statements.Add (new CodeExpressionStatement (call));
794                         }
795                 }
796
797                 void AddMethodParam (CodeMemberMethod cm, Type type, string name)
798                 {
799                         cm.Parameters.Add (new CodeParameterDeclarationExpression (new CodeTypeReference (type), name));
800                 }
801
802                 const string ms_arrays_ns = "http://schemas.microsoft.com/2003/10/Serialization/Arrays";
803
804                 private CodeExpression[] ExportMessages (MessageDescriptionCollection messages, CodeMemberMethod method, bool return_args)
805                 {
806                         CodeExpression [] args = null;
807                         foreach (MessageDescription md in messages) {
808                                 if (md.Direction == MessageDirection.Output) {
809                                         if (md.Body.ReturnValue != null) {
810                                                 ExportDataContract (md.Body.ReturnValue);
811                                                 method.ReturnType = md.Body.ReturnValue.CodeTypeReference;
812                                         }
813                                         continue;
814                                 }
815
816                                 if (return_args)
817                                         args = new CodeExpression [md.Body.Parts.Count];
818
819                                 MessagePartDescriptionCollection parts = md.Body.Parts;
820                                 for (int i = 0; i < parts.Count; i++) {
821                                         ExportDataContract (parts [i]);
822
823                                         method.Parameters.Add (
824                                                 new CodeParameterDeclarationExpression (
825                                                         parts [i].CodeTypeReference,
826                                                         parts [i].Name));
827
828                                         if (return_args)
829                                                 args [i] = new CodeArgumentReferenceExpression (parts [i].Name);
830                                 }
831                         }
832
833                         return args;
834                 }
835
836                 #endregion
837
838                 public CodeTypeReference GenerateServiceEndpoint (
839                         ServiceEndpoint endpoint,
840                         out ChannelEndpointElement channelElement)
841                 {
842                         if (config == null)
843                                 throw new InvalidOperationException ();
844
845                         var cd = endpoint.Contract;
846                         var cns = GetNamespace (cd.Namespace);
847                         var cache = ExportInterface_internal (cd, cns);
848
849                         string bindingSectionName, configurationName;
850                         GenerateBinding (endpoint.Binding, out bindingSectionName, out configurationName);
851
852                         channelElement = new ChannelEndpointElement ();
853                         channelElement.Binding = bindingSectionName;
854                         channelElement.BindingConfiguration = configurationName;
855                         channelElement.Name = configurationName;
856                         channelElement.Contract = cache.ConfigurationName;
857                         channelElement.Address = endpoint.Address.Uri;
858
859                         var section = (ClientSection)config.GetSection ("system.serviceModel/client");
860                         section.Endpoints.Add (channelElement);
861
862                         return cache.GetReference ();
863                 }
864
865                 void MergeCompileUnit (CodeCompileUnit from, CodeCompileUnit to)
866                 {
867                         if (from == to)
868                                 return;
869                         foreach (CodeNamespace fns in from.Namespaces) {
870                                 bool merged = false;
871                                 foreach (CodeNamespace tns in to.Namespaces)
872                                         if (fns.Name == tns.Name) {
873                                                 // namespaces are merged.
874                                                 MergeNamespace (fns, tns);
875                                                 merged = true;
876                                                 break;
877                                         }
878                                 if (!merged)
879                                         to.Namespaces.Add (fns);
880                         }
881                 }
882
883                 // existing type is skipped.
884                 void MergeNamespace (CodeNamespace from, CodeNamespace to)
885                 {
886                         foreach (CodeTypeDeclaration ftd in from.Types) {
887                                 bool skip = false;
888                                 foreach (CodeTypeDeclaration ttd in to.Types)
889                                         if (ftd.Name == ttd.Name) {
890                                                 skip = true;
891                                                 break;
892                                         }
893                                 if (!skip)
894                                         to.Types.Add (ftd);
895                         }
896                 }
897
898                 private void ExportDataContract (MessagePartDescription md)
899                 {
900                         if (data_contract_importer == null)
901                                 data_contract_importer = md.DataContractImporter;
902                         else if (md.DataContractImporter != null && data_contract_importer != md.DataContractImporter)
903                                 throw new Exception ("INTERNAL ERROR: should not happen");
904                         if (xml_serialization_importer == null)
905                                 xml_serialization_importer = md.XmlSerializationImporter;
906                         else if (md.XmlSerializationImporter != null && xml_serialization_importer != md.XmlSerializationImporter)
907                                 throw new Exception ("INTERNAL ERROR: should not happen");
908                 }
909                 
910                 private string GetXmlNamespace (CodeTypeDeclaration type)
911                 {
912                         foreach (CodeAttributeDeclaration attr in type.CustomAttributes) {
913                                 if (attr.Name == "System.Xml.Serialization.XmlTypeAttribute" ||
914                                         attr.Name == "System.Xml.Serialization.XmlRootAttribute") {
915
916                                         foreach (CodeAttributeArgument arg in attr.Arguments)
917                                                 if (arg.Name == "Namespace")
918                                                         return ((CodePrimitiveExpression)arg.Value).Value as string;
919
920                                         //Could not find Namespace arg!
921                                         return null;    
922                                 }
923                         }
924                         
925                         return null;
926                 }
927
928
929         }
930 }