Merge pull request #4453 from lambdageek/bug-49721
[mono.git] / mcs / class / System.ServiceModel / System.ServiceModel.Description / ServiceContractGenerator.cs
index ef24544793ce179c23df88c79195868797200f23..8590c4538c532deab96263dff10a91815823c294 100644 (file)
@@ -58,35 +58,76 @@ namespace System.ServiceModel.Description
                        = new Dictionary<string,string> ();
                Dictionary<ContractDescription,Type> referenced_types
                        = new Dictionary<ContractDescription,Type> ();
+               Dictionary<ContractDescription,ContractCacheEntry> generated_contracts
+                       = new Dictionary<ContractDescription,ContractCacheEntry> ();
                ServiceContractGenerationOptions options;
-               Dictionary<QName, QName> imported_names = null;
+               Dictionary<QName, QName> imported_names
+                       = new Dictionary<QName, QName> ();
                ServiceContractGenerationContext contract_context;
                List<OPair> operation_contexts = new List<OPair> ();
 
-               XsdDataContractImporter xsd_data_importer;
+               XsdDataContractImporter data_contract_importer;
+               XmlSerializerMessageContractImporterInternal xml_serialization_importer;
+
+               class ContractCacheEntry {
+                       public ContractDescription Contract {
+                               get;
+                               private set;
+                       }
+
+                       public string ConfigurationName {
+                               get;
+                               private set;
+                       }
+
+                       public CodeTypeDeclaration TypeDeclaration {
+                               get;
+                               private set;
+                       }
+
+                       public bool GeneratedContractType {
+                               get; set;
+                       }
+
+                       public CodeTypeReference GetReference ()
+                       {
+                               return reference;
+                       }
+
+                       public ContractCacheEntry (ContractDescription cd, string config,
+                                                  CodeTypeDeclaration tdecl)
+                       {
+                               Contract = cd;
+                               ConfigurationName = config;
+                               TypeDeclaration = tdecl;
+                               reference = new CodeTypeReference (tdecl.Name);
+                       }
+
+                       readonly CodeTypeReference reference;
+               }
 
                public ServiceContractGenerator ()
                        : this (null, null)
                {
                }
 
-               public ServiceContractGenerator (CodeCompileUnit ccu)
-                       : this (ccu, null)
+               public ServiceContractGenerator (CodeCompileUnit targetCompileUnit)
+                       : this (targetCompileUnit, null)
                {
                }
 
-               public ServiceContractGenerator (ConfigurationType config)
-                       : this (null, config)
+               public ServiceContractGenerator (ConfigurationType targetConfig)
+                       : this (null, targetConfig)
                {
                }
 
-               public ServiceContractGenerator (CodeCompileUnit ccu, ConfigurationType config)
+               public ServiceContractGenerator (CodeCompileUnit targetCompileUnit, ConfigurationType targetConfig)
                {
-                       if (ccu == null)
+                       if (targetCompileUnit == null)
                                this.ccu = new CodeCompileUnit ();
                        else
-                               this.ccu = ccu;
-                       this.config = config;
+                               this.ccu = targetCompileUnit;
+                       this.config = targetConfig;
                        Options |= ServiceContractGenerationOptions.ChannelInterface | 
                                ServiceContractGenerationOptions.ClientClass;
                }
@@ -124,12 +165,26 @@ namespace System.ServiceModel.Description
                        get { return ccu; }
                }
 
-               [MonoTODO]
                public void GenerateBinding (Binding binding,
                        out string bindingSectionName,
                        out string configurationName)
                {
-                       throw new NotImplementedException ();
+                       if (config == null)
+                               throw new InvalidOperationException ();
+
+                       var element = ConfigUtil.FindCollectionElement (binding, config);
+                       if (element == null)
+                               throw new InvalidOperationException ();
+
+                       bindingSectionName = element.BindingName;
+
+                       int idx = 0;
+                       configurationName = binding.Name;
+                       while (element.ContainsKey (configurationName))
+                               configurationName = binding.Name + (++idx);
+
+                       if (!element.TryAdd (configurationName, binding, config))
+                               throw new InvalidOperationException ();
                }
 
                #region Service Contract Type
@@ -142,8 +197,9 @@ namespace System.ServiceModel.Description
                        ContractDescription contractDescription)
                {
                        CodeNamespace cns = GetNamespace (contractDescription.Namespace);
-                       imported_names = new Dictionary<QName, QName> ();
-                       var ret = ExportInterface (contractDescription, cns);
+                       var cache = ExportInterface_internal (contractDescription, cns);
+                       if (cache.GeneratedContractType)
+                               return cache.GetReference ();
 
                        // FIXME: handle duplex callback
 
@@ -153,8 +209,10 @@ namespace System.ServiceModel.Description
                        if ((Options & ServiceContractGenerationOptions.ClientClass) != 0)
                                GenerateProxyClass (contractDescription, cns);
 
-                       if (xsd_data_importer != null)
-                               MergeCompileUnit (xsd_data_importer.CodeCompileUnit, ccu);
+                       if (data_contract_importer != null)
+                               MergeCompileUnit (data_contract_importer.CodeCompileUnit, ccu);
+                       if (xml_serialization_importer != null)
+                               MergeCompileUnit (xml_serialization_importer.CodeCompileUnit, ccu);
 
                        // Process extensions. Class first, then methods.
                        // (built-in ones must present before processing class extensions).
@@ -166,18 +224,26 @@ namespace System.ServiceModel.Description
                        foreach (var opair in operation_contexts)
                                opair.Key.GenerateOperation (opair.Value);
 
-                       return ret;
+                       cache.GeneratedContractType = true;
+                       return cache.GetReference ();
                }
 
-               CodeNamespace GetNamespace (string ns)
+               CodeNamespace GetNamespace (string contractNs)
                {
-                       if (ns == null)
-                               ns = String.Empty;
+                       if (contractNs == null)
+                               contractNs = String.Empty;
+                       string csharpNs;
+                       if (nsmappings.ContainsKey (contractNs))
+                               csharpNs = nsmappings [contractNs];
+                       else if (nsmappings.ContainsKey ("*"))
+                               csharpNs = nsmappings ["*"];
+                       else
+                               csharpNs = string.Empty;
                        foreach (CodeNamespace cns in ccu.Namespaces)
-                               if (cns.Name == ns)
+                               if (cns.Name == csharpNs)
                                        return cns;
                        CodeNamespace ncns = new CodeNamespace ();
-                       //ncns.Name = ns;
+                       ncns.Name = csharpNs;
                        ccu.Namespaces.Add (ncns);
                        return ncns;
                }
@@ -294,25 +360,36 @@ namespace System.ServiceModel.Description
 
                CodeTypeReference ExportInterface (ContractDescription cd, CodeNamespace cns)
                {
-                       CodeTypeDeclaration type = GetTypeDeclaration (cns, cd.Name);
-                       if (type != null)
-                               return new CodeTypeReference (type.Name);
-                       type = new CodeTypeDeclaration ();
+                       var cache = ExportInterface_internal (cd, cns);
+                       return cache.GetReference ();
+               }
+
+               ContractCacheEntry ExportInterface_internal (ContractDescription cd, CodeNamespace cns)
+               {
+                       if (generated_contracts.ContainsKey (cd))
+                               return generated_contracts [cd];
+
+                       var type = new CodeTypeDeclaration ();
                        type.TypeAttributes = TypeAttributes.Interface;
                        type.TypeAttributes |= TypeAttributes.Public;
                        cns.Types.Add (type);
                        type.Name = identifiers.AddUnique (cd.Name, null);
+
+                       var configName = type.Name;
                        CodeAttributeDeclaration ad = 
                                new CodeAttributeDeclaration (
                                        new CodeTypeReference (
-                                               typeof (ServiceContractAttribute)));
+                                       typeof (ServiceContractAttribute)));
                        ad.Arguments.Add (new CodeAttributeArgument ("Namespace", new CodePrimitiveExpression (cd.Namespace)));
+                       ad.Arguments.Add (new CodeAttributeArgument ("ConfigurationName", new CodePrimitiveExpression (configName)));
                        type.CustomAttributes.Add (ad);
                        contract_context = new ServiceContractGenerationContext (this, cd, type);
-
+                       
                        AddOperationMethods (type, cd);
 
-                       return new CodeTypeReference (type.Name);
+                       var cache = new ContractCacheEntry (cd, configName, type);
+                       generated_contracts.Add (cd, cache);
+                       return cache;
                }
 
                void AddBeginAsyncArgs (CodeMemberMethod cm)
@@ -362,6 +439,9 @@ namespace System.ServiceModel.Description
                {
                        CodeMemberMethod cm = new CodeMemberMethod ();
 
+                       if (od.Behaviors.Find<XmlSerializerMappingBehavior> () != null)
+                               cm.CustomAttributes.Add (new CodeAttributeDeclaration (new CodeTypeReference (typeof (XmlSerializerFormatAttribute))));
+
                        if (async)
                                cm.Name = "Begin" + od.Name;
                        else
@@ -517,6 +597,7 @@ namespace System.ServiceModel.Description
                        var method = FindByName (type, od.Name) ?? FindByName (type, "Begin" + od.Name);
                        var endMethod = method.Name == od.Name ? null : FindByName (type, "End" + od.Name);
                        bool methodAsync = method.Name.StartsWith ("Begin", StringComparison.Ordinal);
+                       var resultType = endMethod != null ? endMethod.ReturnType : method.ReturnType;
 
                        var thisExpr = new CodeThisReferenceExpression ();
                        var baseExpr = new CodeBaseReferenceExpression ();
@@ -572,9 +653,13 @@ namespace System.ServiceModel.Description
                                new CodeArgumentReferenceExpression ("result"));
                        call.Parameters.AddRange (outArgRefs.Cast<CodeExpression> ().ToArray ()); // questionable
 
-                       cm.Statements.Add (new CodeVariableDeclarationStatement (typeof (object), "__ret", call));
                        var retCreate = new CodeArrayCreateExpression (typeof (object));
-                       retCreate.Initializers.Add (new CodeVariableReferenceExpression ("__ret"));
+                       if (resultType.BaseType == "System.Void")
+                               cm.Statements.Add (call);
+                       else {
+                               cm.Statements.Add (new CodeVariableDeclarationStatement (typeof (object), "__ret", call));
+                               retCreate.Initializers.Add (new CodeVariableReferenceExpression ("__ret"));
+                       }
                        foreach (var outArgRef in outArgRefs)
                                retCreate.Initializers.Add (new CodeVariableReferenceExpression (outArgRef.VariableName));
 
@@ -627,12 +712,14 @@ namespace System.ServiceModel.Description
 
                        argsType.Members.Add (new CodeMemberField (typeof (object []), "results"));
 
-                       var resultProp = new CodeMemberProperty {
-                               Name = "Result",
-                               Type = endMethod != null ? endMethod.ReturnType : method.ReturnType,
-                               Attributes = MemberAttributes.Public | MemberAttributes.Final };
-                       resultProp.GetStatements.Add (new CodeMethodReturnStatement (new CodeCastExpression (resultProp.Type, new CodeArrayIndexerExpression (resultsField, new CodePrimitiveExpression (0)))));
-                       argsType.Members.Add (resultProp);
+                       if (resultType.BaseType != "System.Void") {
+                               var resultProp = new CodeMemberProperty {
+                                       Name = "Result",
+                                       Type = resultType,
+                                       Attributes = MemberAttributes.Public | MemberAttributes.Final };
+                               resultProp.GetStatements.Add (new CodeMethodReturnStatement (new CodeCastExpression (resultProp.Type, new CodeArrayIndexerExpression (resultsField, new CodePrimitiveExpression (0)))));
+                               argsType.Members.Add (resultProp);
+                       }
 
                        // event field
                        var handlerType = new CodeTypeReference (typeof (EventHandler<>));
@@ -714,13 +801,6 @@ namespace System.ServiceModel.Description
 
                const string ms_arrays_ns = "http://schemas.microsoft.com/2003/10/Serialization/Arrays";
 
-               string GetCodeTypeName (QName mappedTypeName)
-               {
-                       if (mappedTypeName.Namespace == ms_arrays_ns)
-                               return DataContractSerializerMessageContractImporter.GetCLRTypeName (mappedTypeName.Name.Substring ("ArrayOf".Length)) + "[]";
-                       return mappedTypeName.Name;
-               }
-
                private CodeExpression[] ExportMessages (MessageDescriptionCollection messages, CodeMemberMethod method, bool return_args)
                {
                        CodeExpression [] args = null;
@@ -755,12 +835,31 @@ namespace System.ServiceModel.Description
 
                #endregion
 
-               [MonoTODO]
                public CodeTypeReference GenerateServiceEndpoint (
                        ServiceEndpoint endpoint,
                        out ChannelEndpointElement channelElement)
                {
-                       throw new NotImplementedException ();
+                       if (config == null)
+                               throw new InvalidOperationException ();
+
+                       var cd = endpoint.Contract;
+                       var cns = GetNamespace (cd.Namespace);
+                       var cache = ExportInterface_internal (cd, cns);
+
+                       string bindingSectionName, configurationName;
+                       GenerateBinding (endpoint.Binding, out bindingSectionName, out configurationName);
+
+                       channelElement = new ChannelEndpointElement ();
+                       channelElement.Binding = bindingSectionName;
+                       channelElement.BindingConfiguration = configurationName;
+                       channelElement.Name = configurationName;
+                       channelElement.Contract = cache.ConfigurationName;
+                       channelElement.Address = endpoint.Address.Uri;
+
+                       var section = (ClientSection)config.GetSection ("system.serviceModel/client");
+                       section.Endpoints.Add (channelElement);
+
+                       return cache.GetReference ();
                }
 
                void MergeCompileUnit (CodeCompileUnit from, CodeCompileUnit to)
@@ -798,8 +897,14 @@ namespace System.ServiceModel.Description
 
                private void ExportDataContract (MessagePartDescription md)
                {
-                       if (xsd_data_importer == null)
-                               xsd_data_importer = md.Importer;
+                       if (data_contract_importer == null)
+                               data_contract_importer = md.DataContractImporter;
+                       else if (md.DataContractImporter != null && data_contract_importer != md.DataContractImporter)
+                               throw new Exception ("INTERNAL ERROR: should not happen");
+                       if (xml_serialization_importer == null)
+                               xml_serialization_importer = md.XmlSerializationImporter;
+                       else if (md.XmlSerializationImporter != null && xml_serialization_importer != md.XmlSerializationImporter)
+                               throw new Exception ("INTERNAL ERROR: should not happen");
                }
                
                private string GetXmlNamespace (CodeTypeDeclaration type)