Merge pull request #3769 from evincarofautumn/fix-verify-before-allocs
[mono.git] / mcs / class / System / Mono.Net.Security / ChainValidationHelper.cs
index 63a781dbdf5a7d37628af8efb85473e6692d486e..324192d6af182a898272432a394351099def6971 100644 (file)
@@ -34,9 +34,6 @@
 #if MONO_SECURITY_ALIAS
 extern alias MonoSecurity;
 #endif
-#if MONO_X509_ALIAS
-extern alias PrebuiltSystem;
-#endif
 
 #if MONO_SECURITY_ALIAS
 using MonoSecurity::Mono.Security.Interface;
@@ -47,13 +44,6 @@ using Mono.Security.Interface;
 using MSX = Mono.Security.X509;
 using Mono.Security.X509.Extensions;
 #endif
-#if MONO_X509_ALIAS
-using XX509CertificateCollection = PrebuiltSystem::System.Security.Cryptography.X509Certificates.X509CertificateCollection;
-using XX509Chain = PrebuiltSystem::System.Security.Cryptography.X509Certificates.X509Chain;
-#else
-using XX509CertificateCollection = System.Security.Cryptography.X509Certificates.X509CertificateCollection;
-using XX509Chain = System.Security.Cryptography.X509Certificates.X509Chain;
-#endif
 
 using System;
 using System.Net;
@@ -74,7 +64,7 @@ namespace Mono.Net.Security
 {
        internal delegate bool ServerCertValidationCallbackWrapper (ServerCertValidationCallback callback, X509Certificate certificate, X509Chain chain, MonoSslPolicyErrors sslPolicyErrors);
 
-       internal class ChainValidationHelper : ICertificateValidator
+       internal class ChainValidationHelper : ICertificateValidator2
        {
                readonly object sender;
                readonly MonoTlsSettings settings;
@@ -85,7 +75,9 @@ namespace Mono.Net.Security
                readonly MonoTlsStream tlsStream;
                readonly HttpWebRequest request;
 
-               internal static ICertificateValidator GetDefaultValidator (MonoTlsProvider provider, MonoTlsSettings settings)
+#pragma warning disable 618
+
+               internal static ICertificateValidator GetInternalValidator (MonoTlsProvider provider, MonoTlsSettings settings)
                {
                        if (settings == null)
                                return new ChainValidationHelper (provider, null, false, null, null);
@@ -94,6 +86,16 @@ namespace Mono.Net.Security
                        return new ChainValidationHelper (provider, settings, false, null, null);
                }
 
+               internal static ICertificateValidator GetDefaultValidator (MonoTlsSettings settings)
+               {
+                       var provider = MonoTlsProviderFactory.GetProvider ();
+                       if (settings == null)
+                               return new ChainValidationHelper (provider, null, false, null, null);
+                       if (settings.CertificateValidator != null)
+                               throw new NotSupportedException ();
+                       return new ChainValidationHelper (provider, settings, false, null, null);
+               }
+
 #region SslStream support
 
                /*
@@ -146,6 +148,8 @@ namespace Mono.Net.Security
                                settings = MonoTlsSettings.CopyDefaultSettings ();
                        if (cloneSettings)
                                settings = settings.CloneWithValidator (this);
+                       if (provider == null)
+                               provider = MonoTlsProviderFactory.GetProvider ();
 
                        this.provider = provider;
                        this.settings = settings;
@@ -180,7 +184,9 @@ namespace Mono.Net.Security
                                certValidationCallback = ServicePointManager.ServerCertValidationCallback;
                }
 
-               static X509Certificate DefaultSelectionCallback (string targetHost, XX509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
+#pragma warning restore 618
+
+               static X509Certificate DefaultSelectionCallback (string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
                {
                        X509Certificate clientCertificate;
                        if (localCertificates == null || localCertificates.Count == 0)
@@ -203,7 +209,7 @@ namespace Mono.Net.Security
                }
 
                public bool SelectClientCertificate (
-                       string targetHost, XX509CertificateCollection localCertificates, X509Certificate remoteCertificate,
+                       string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate,
                        string[] acceptableIssuers, out X509Certificate clientCertificate)
                {
                        if (certSelectionCallback == null) {
@@ -215,7 +221,7 @@ namespace Mono.Net.Security
                }
 
                internal X509Certificate SelectClientCertificate (
-                       string targetHost, XX509CertificateCollection localCertificates, X509Certificate remoteCertificate,
+                       string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate,
                        string[] acceptableIssuers)
                {
                        if (certSelectionCallback == null)
@@ -225,20 +231,39 @@ namespace Mono.Net.Security
 
                internal bool ValidateClientCertificate (X509Certificate certificate, MonoSslPolicyErrors errors)
                {
-                       var certs = new XX509CertificateCollection ();
+                       var certs = new X509CertificateCollection ();
                        certs.Add (new X509Certificate2 (certificate.GetRawCertData ()));
 
-                       var result = ValidateChain (string.Empty, true, certs, (SslPolicyErrors)errors);
+                       var result = ValidateChain (string.Empty, true, certificate, null, certs, (SslPolicyErrors)errors);
                        if (result == null)
                                return false;
 
                        return result.Trusted && !result.UserDenied;
                }
 
-               public ValidationResult ValidateCertificate (string host, bool serverMode, XX509CertificateCollection certs)
+               public ValidationResult ValidateCertificate (string host, bool serverMode, X509CertificateCollection certs)
+               {
+                       try {
+                               X509Certificate leaf;
+                               if (certs != null && certs.Count != 0)
+                                       leaf = certs [0];
+                               else
+                                       leaf = null;
+                               var result = ValidateChain (host, serverMode, leaf, null, certs, 0);
+                               if (tlsStream != null)
+                                       tlsStream.CertificateValidationFailed = result == null || !result.Trusted || result.UserDenied;
+                               return result;
+                       } catch {
+                               if (tlsStream != null)
+                                       tlsStream.CertificateValidationFailed = true;
+                               throw;
+                       }
+               }
+
+               public ValidationResult ValidateCertificate (string host, bool serverMode, X509Certificate leaf, X509Chain chain)
                {
                        try {
-                               var result = ValidateChain (host, serverMode, certs, 0);
+                               var result = ValidateChain (host, serverMode, leaf, chain, null, 0);
                                if (tlsStream != null)
                                        tlsStream.CertificateValidationFailed = result == null || !result.Trusted || result.UserDenied;
                                return result;
@@ -249,7 +274,28 @@ namespace Mono.Net.Security
                        }
                }
 
-               ValidationResult ValidateChain (string host, bool server, XX509CertificateCollection certs, SslPolicyErrors errors)
+               ValidationResult ValidateChain (string host, bool server, X509Certificate leaf,
+                                               X509Chain chain, X509CertificateCollection certs,
+                                               SslPolicyErrors errors)
+               {
+                       var oldChain = chain;
+                       var ownsChain = chain == null;
+                       try {
+                               var result = ValidateChain (host, server, leaf, ref chain, certs, errors);
+                               if (chain != oldChain)
+                                       ownsChain = true;
+
+                               return result;
+                       } finally {
+                               // If ValidateChain() changed the chain, then we need to free it.
+                               if (ownsChain && chain != null)
+                                       chain.Dispose ();
+                       }
+               }
+
+               ValidationResult ValidateChain (string host, bool server, X509Certificate leaf,
+                                               ref X509Chain chain, X509CertificateCollection certs,
+                                               SslPolicyErrors errors)
                {
                        // user_denied is true if the user callback is called and returns false
                        bool user_denied = false;
@@ -257,12 +303,6 @@ namespace Mono.Net.Security
 
                        var hasCallback = certValidationCallback != null || callbackWrapper != null;
 
-                       X509Certificate leaf;
-                       if (certs == null || certs.Count == 0)
-                               leaf = null;
-                       else
-                               leaf = certs [0];
-
                        if (tlsStream != null)
                                request.ServicePoint.UpdateServerCertificate (leaf);
 
@@ -278,10 +318,16 @@ namespace Mono.Net.Security
                                return new ValidationResult (result, user_denied, 0, (MonoSslPolicyErrors)errors);
                        }
 
+                       // Ignore port number when validating certificates.
+                       if (!string.IsNullOrEmpty (host)) {
+                               var pos = host.IndexOf (':');
+                               if (pos > 0)
+                                       host = host.Substring (0, pos);
+                       }
+
                        ICertificatePolicy policy = ServicePointManager.GetLegacyCertificatePolicy ();
 
                        int status11 = 0; // Error code passed to the obsolete ICertificatePolicy callback
-                       X509Chain chain = null;
 
                        bool wantsChain = SystemCertificateValidator.NeedsChain (settings);
                        if (!wantsChain && hasCallback) {
@@ -289,28 +335,19 @@ namespace Mono.Net.Security
                                        wantsChain = true;
                        }
 
-                       if (wantsChain)
-                               chain = SystemCertificateValidator.CreateX509Chain (certs);
+                       var xerrors = (MonoSslPolicyErrors)errors;
+                       result = provider.ValidateCertificate (this, host, server, certs, wantsChain, ref chain, ref xerrors, ref status11);
+                       errors = (SslPolicyErrors)xerrors;
 
-                       bool providerValidated = false;
-                       if (provider != null && provider.HasCustomSystemCertificateValidator) {
-                               var xerrors = (MonoSslPolicyErrors)errors;
-                               var xchain = (XX509Chain)(object)chain;
-                               providerValidated = provider.InvokeSystemCertificateValidator (this, host, server, certs, xchain, out result, ref xerrors, ref status11);
-                               errors = (SslPolicyErrors)xerrors;
+                       if (status11 == 0 && errors != 0) {
+                               // TRUST_E_FAIL
+                               status11 = unchecked ((int)0x800B010B);
                        }
 
-                       if (!providerValidated)
-                               result = SystemCertificateValidator.Evaluate (settings, host, certs, chain, ref errors, ref status11);
-
                        if (policy != null && (!(policy is DefaultCertificatePolicy) || certValidationCallback == null)) {
                                ServicePoint sp = null;
                                if (request != null)
                                        sp = request.ServicePointNoLock;
-                               if (status11 == 0 && errors != 0) {
-                                       // TRUST_E_FAIL
-                                       status11 = unchecked ((int)0x800B010B);
-                               }
 
                                // pre 2.0 callback
                                result = policy.CheckValidationResult (sp, leaf, request, status11);
@@ -327,9 +364,8 @@ namespace Mono.Net.Security
                        return new ValidationResult (result, user_denied, status11, (MonoSslPolicyErrors)errors);
                }
 
-               public bool InvokeSystemValidator (string targetHost, bool serverMode, XX509CertificateCollection certificates, XX509Chain xchain, ref MonoSslPolicyErrors xerrors, ref int status11)
+               bool InvokeSystemValidator (string targetHost, bool serverMode, X509CertificateCollection certificates, X509Chain chain, ref MonoSslPolicyErrors xerrors, ref int status11)
                {
-                       X509Chain chain = (X509Chain)(object)xchain;
                        var errors = (SslPolicyErrors)xerrors;
                        var result = SystemCertificateValidator.Evaluate (settings, targetHost, certificates, chain, ref errors, ref status11);
                        xerrors = (MonoSslPolicyErrors)errors;