Merge pull request #3389 from lambdageek/bug-43099
[mono.git] / mcs / class / referencesource / System.ServiceModel / System / ServiceModel / Description / WsdlHelper.cs
1 // <copyright>
2 //     Copyright (c) Microsoft Corporation.  All rights reserved.
3 // </copyright>
4
5 namespace System.ServiceModel.Description
6 {
7     using System.Collections.Generic;
8     using System.Diagnostics.CodeAnalysis;
9     using System.Globalization;
10     using System.IO;
11     using System.Linq;
12     using System.Runtime;
13     using System.Text;
14     using System.Xml;
15     using System.Xml.Schema;
16     using System.Xml.Serialization;
17     using WsdlNS = System.Web.Services.Description;
18
19     internal static class WsdlHelper
20     {
21         public static WsdlNS.ServiceDescription GetSingleWsdl(MetadataSet metadataSet)
22         {
23             if (metadataSet.MetadataSections.Count < 1)
24             {
25                 return null;
26             }
27
28             List<WsdlNS.ServiceDescription> wsdls = new List<WsdlNS.ServiceDescription>();
29             List<XmlSchema> xsds = new List<XmlSchema>();
30
31             foreach (MetadataSection section in metadataSet.MetadataSections)
32             {
33                 if (section.Metadata is WsdlNS.ServiceDescription)
34                 {
35                     wsdls.Add((WsdlNS.ServiceDescription)section.Metadata);
36                 }
37
38                 if (section.Metadata is XmlSchema)
39                 {
40                     xsds.Add((XmlSchema)section.Metadata);
41                 }
42             }
43
44             VerifyContractNamespace(wsdls);
45             WsdlNS.ServiceDescription singleWsdl = GetSingleWsdl(CopyServiceDescriptionCollection(wsdls));
46
47             // Inline XML schemas
48             foreach (XmlSchema schema in xsds)
49             {
50                 XmlSchema newSchema = CloneXsd(schema);
51                 RemoveSchemaLocations(newSchema);
52                 singleWsdl.Types.Schemas.Add(newSchema);
53             }
54
55             return singleWsdl;
56         }
57
58         private static void RemoveSchemaLocations(XmlSchema schema)
59         {
60             foreach (XmlSchemaObject schemaObject in schema.Includes)
61             {
62                 XmlSchemaExternal external = schemaObject as XmlSchemaExternal;
63                 if (external != null)
64                 {
65                     external.SchemaLocation = null;
66                 }
67             }
68         }
69
70         private static WsdlNS.ServiceDescription GetSingleWsdl(List<WsdlNS.ServiceDescription> wsdls)
71         {
72             // Use WSDL that has the contracts as the base for single WSDL
73             WsdlNS.ServiceDescription singleWsdl = wsdls.First(wsdl => wsdl.PortTypes.Count > 0);
74             if (singleWsdl == null)
75             {
76                 singleWsdl = new WsdlNS.ServiceDescription();
77             }
78             else
79             {
80                 singleWsdl.Types.Schemas.Clear();
81                 singleWsdl.Imports.Clear();
82             }
83
84             Dictionary<XmlQualifiedName, XmlQualifiedName> bindingReferenceChanges = new Dictionary<XmlQualifiedName, XmlQualifiedName>();
85             foreach (WsdlNS.ServiceDescription wsdl in wsdls)
86             {
87                 if (wsdl != singleWsdl)
88                 {
89                     MergeWsdl(singleWsdl, wsdl, bindingReferenceChanges);
90                 }              
91             }
92
93             EnsureSingleNamespace(singleWsdl, bindingReferenceChanges);
94             return singleWsdl;
95         }
96
97         private static List<WsdlNS.ServiceDescription> CopyServiceDescriptionCollection(List<WsdlNS.ServiceDescription> wsdls)
98         {
99             List<WsdlNS.ServiceDescription> newWsdls = new List<WsdlNS.ServiceDescription>();
100             foreach (WsdlNS.ServiceDescription wsdl in wsdls)
101             {
102                 newWsdls.Add(CloneWsdl(wsdl));
103             }
104
105             return newWsdls;
106         }
107
108         private static void MergeWsdl(WsdlNS.ServiceDescription singleWsdl, WsdlNS.ServiceDescription wsdl, Dictionary<XmlQualifiedName, XmlQualifiedName> bindingReferenceChanges)
109         {
110             if (wsdl.Services.Count > 0)
111             {
112                 singleWsdl.Name = wsdl.Name;
113             }
114
115             foreach (WsdlNS.Binding binding in wsdl.Bindings)
116             {
117                 string uniqueBindingName = NamingHelper.GetUniqueName(binding.Name, WsdlHelper.IsBindingNameUsed, singleWsdl.Bindings);
118                 if (binding.Name != uniqueBindingName)
119                 {
120                     bindingReferenceChanges.Add(
121                         new XmlQualifiedName(binding.Name, binding.ServiceDescription.TargetNamespace),
122                         new XmlQualifiedName(uniqueBindingName, singleWsdl.TargetNamespace));
123                     UpdatePolicyKeys(binding, uniqueBindingName, wsdl);
124                     binding.Name = uniqueBindingName;
125                 }
126
127                 singleWsdl.Bindings.Add(binding);
128             }
129
130             foreach (object extension in wsdl.Extensions)
131             {
132                 singleWsdl.Extensions.Add(extension);
133             }
134
135             foreach (WsdlNS.Message message in wsdl.Messages)
136             {
137                 singleWsdl.Messages.Add(message);
138             }
139
140             foreach (WsdlNS.Service service in wsdl.Services)
141             {
142                 singleWsdl.Services.Add(service);
143             }
144
145             foreach (string warning in wsdl.ValidationWarnings)
146             {
147                 singleWsdl.ValidationWarnings.Add(warning);
148             }
149         }
150
151         private static void UpdatePolicyKeys(WsdlNS.Binding binding, string newBindingName, WsdlNS.ServiceDescription wsdl)
152         {
153             string oldBindingName = binding.Name;
154
155             // policy
156             IEnumerable<XmlElement> bindingPolicies = FindAllElements(wsdl.Extensions, MetadataStrings.WSPolicy.Elements.Policy);
157             string policyIdStringPrefixFormat = "{0}_";
158             foreach (XmlElement policyElement in bindingPolicies)
159             {
160                 XmlNode policyId = policyElement.Attributes.GetNamedItem(MetadataStrings.Wsu.Attributes.Id, MetadataStrings.Wsu.NamespaceUri);
161                 string policyIdString = policyId.Value;
162                 string policyIdStringWithOldBindingName = string.Format(CultureInfo.InvariantCulture, policyIdStringPrefixFormat, oldBindingName);
163                 string policyIdStringWithNewBindingName = string.Format(CultureInfo.InvariantCulture, policyIdStringPrefixFormat, newBindingName);
164                 if (policyId != null && policyIdString != null && policyIdString.StartsWith(policyIdStringWithOldBindingName, StringComparison.Ordinal))
165                 {
166                     policyId.Value = policyIdStringWithNewBindingName + policyIdString.Substring(policyIdStringWithOldBindingName.Length);
167                 }
168             }
169
170             // policy reference
171             UpdatePolicyReference(binding.Extensions, oldBindingName, newBindingName);
172             foreach (WsdlNS.OperationBinding operationBinding in binding.Operations)
173             {
174                 UpdatePolicyReference(operationBinding.Extensions, oldBindingName, newBindingName);
175                 if (operationBinding.Input != null)
176                 {
177                     UpdatePolicyReference(operationBinding.Input.Extensions, oldBindingName, newBindingName);
178                 }
179
180                 if (operationBinding.Output != null)
181                 {
182                     UpdatePolicyReference(operationBinding.Output.Extensions, oldBindingName, newBindingName);
183                 }
184
185                 foreach (WsdlNS.FaultBinding fault in operationBinding.Faults)
186                 {
187                     UpdatePolicyReference(fault.Extensions, oldBindingName, newBindingName);
188                 }
189             }
190         }
191
192         private static void UpdatePolicyReference(WsdlNS.ServiceDescriptionFormatExtensionCollection extensions, string oldBindingName, string newBindingName)
193         {
194             IEnumerable<XmlElement> bindingPolicyReferences = FindAllElements(extensions, MetadataStrings.WSPolicy.Elements.PolicyReference);
195             string policyReferencePrefixFormat = "#{0}_";
196             foreach (XmlElement policyReferenceElement in bindingPolicyReferences)
197             {
198                 XmlNode policyReference = policyReferenceElement.Attributes.GetNamedItem(MetadataStrings.WSPolicy.Attributes.URI);
199                 string policyReferenceValue = policyReference.Value;
200                 string policyReferenceValueWithOldBindingName = string.Format(CultureInfo.InvariantCulture, policyReferencePrefixFormat, oldBindingName);
201                 string policyReferenceValueWithNewBindingName = string.Format(CultureInfo.InvariantCulture, policyReferencePrefixFormat, newBindingName);
202                 if (policyReference != null && policyReferenceValue != null && policyReferenceValue.StartsWith(policyReferenceValueWithOldBindingName, StringComparison.Ordinal))
203                 {
204                     policyReference.Value = policyReferenceValueWithNewBindingName + policyReference.Value.Substring(policyReferenceValueWithOldBindingName.Length);
205                 }
206             }
207         }
208
209         private static IEnumerable<XmlElement> FindAllElements(WsdlNS.ServiceDescriptionFormatExtensionCollection extensions, string elementName)
210         {
211             List<XmlElement> policyReferences = new List<XmlElement>();
212             for (int i = 0; i < extensions.Count; i++)
213             {
214                 XmlElement element = extensions[i] as XmlElement;
215                 if (element != null && element.LocalName == elementName)
216                 {
217                     policyReferences.Add(element);
218                 }
219             }
220
221             return policyReferences;
222         }
223
224         private static void VerifyContractNamespace(List<WsdlNS.ServiceDescription> wsdls)
225         {
226             IEnumerable<WsdlNS.ServiceDescription> contractWsdls = wsdls.Where(serviceDescription => serviceDescription.PortTypes.Count > 0);
227             if (contractWsdls.Count() > 1)
228             {
229                 IEnumerable<string> namespaces = contractWsdls.Select<WsdlNS.ServiceDescription, string>(wsdl => wsdl.TargetNamespace);
230                 string contractNamespaces = string.Join(", ", namespaces);
231                 throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new NotSupportedException(SR.GetString(SR.SingleWsdlNotGenerated, contractNamespaces)));
232             }
233         }
234
235         private static void EnsureSingleNamespace(WsdlNS.ServiceDescription wsdl, Dictionary<XmlQualifiedName, XmlQualifiedName> bindingReferenceChanges)
236         {
237             string targetNamespace = wsdl.TargetNamespace;
238             foreach (WsdlNS.Binding binding in wsdl.Bindings)
239             {
240                 if (binding.Type.Namespace != targetNamespace)
241                 {
242                     binding.Type = new XmlQualifiedName(binding.Type.Name, targetNamespace);
243                 }
244             }
245
246             foreach (WsdlNS.PortType portType in wsdl.PortTypes)
247             {
248                 foreach (WsdlNS.Operation operation in portType.Operations)
249                 {
250                     WsdlNS.OperationInput messageInput = operation.Messages.Input;
251                     if (messageInput != null && messageInput.Message.Namespace != targetNamespace)
252                     {
253                         messageInput.Message = new XmlQualifiedName(messageInput.Message.Name, targetNamespace);
254                     }
255
256                     WsdlNS.OperationOutput messageOutput = operation.Messages.Output;
257                     if (messageOutput != null && messageOutput.Message.Namespace != targetNamespace)
258                     {
259                         messageOutput.Message = new XmlQualifiedName(messageOutput.Message.Name, targetNamespace);
260                     }
261
262                     foreach (WsdlNS.OperationFault fault in operation.Faults)
263                     {
264                         if (fault.Message.Namespace != targetNamespace)
265                         {
266                             fault.Message = new XmlQualifiedName(fault.Message.Name, targetNamespace);
267                         }
268                     }
269                 }
270             }
271
272             foreach (WsdlNS.Service service in wsdl.Services)
273             {
274                 foreach (WsdlNS.Port port in service.Ports)
275                 {
276                     XmlQualifiedName newPortBinding;
277                     if (bindingReferenceChanges.TryGetValue(port.Binding, out newPortBinding))
278                     {
279                         port.Binding = newPortBinding;
280                     }
281                     else if (port.Binding.Namespace != targetNamespace)
282                     {
283                         port.Binding = new XmlQualifiedName(port.Binding.Name, targetNamespace);
284                     }
285                 }
286             }
287         }
288
289         private static bool IsBindingNameUsed(string name, object collection)
290         {
291             WsdlNS.BindingCollection bindings = (WsdlNS.BindingCollection)collection;
292             foreach (WsdlNS.Binding binding in bindings)
293             {
294                 if (binding.Name == name)
295                 {
296                     return true;
297                 }
298             }
299
300             return false;
301         }
302
303         private static WsdlNS.ServiceDescription CloneWsdl(WsdlNS.ServiceDescription originalWsdl)
304         {
305             Fx.Assert(originalWsdl != null, "originalWsdl must not be null");
306             WsdlNS.ServiceDescription newWsdl;
307             using (MemoryStream memoryStream = new MemoryStream())
308             {
309                 originalWsdl.Write(memoryStream);
310                 memoryStream.Seek(0, SeekOrigin.Begin);
311                 newWsdl = WsdlNS.ServiceDescription.Read(memoryStream);
312             }
313
314             return newWsdl;
315         }
316
317         [SuppressMessage("Microsoft.Security.Xml", "CA3054:DoNotAllowDtdOnXmlTextReader")]
318         [SuppressMessage("Microsoft.Security.Xml", "CA3069:ReviewDtdProcessingAssignment", Justification = "This is trusted server code from the application only. We should allow the customer add dtd.")]
319         private static XmlSchema CloneXsd(XmlSchema originalXsd)
320         {
321             Fx.Assert(originalXsd != null, "originalXsd must not be null");
322             XmlSchema newXsd;
323             using (MemoryStream memoryStream = new MemoryStream())
324             {
325                 originalXsd.Write(memoryStream);
326                 memoryStream.Seek(0, SeekOrigin.Begin);
327                 newXsd = XmlSchema.Read(new XmlTextReader(memoryStream) { DtdProcessing = DtdProcessing.Parse }, null);
328             }
329
330             return newXsd;
331         }
332     }
333 }