[corlib] Fixes security tests failures
[mono.git] / mcs / class / System.Security / System.Security.Cryptography.Xml / EncryptedXml.cs
index d02ebef4a739ac5de84d9ef57c9a7d63b686e052..6dde9199a2ad36d5c5f9d8620f7444ff22760df2 100644 (file)
@@ -27,7 +27,6 @@
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
 
-#if NET_2_0
 
 using System.Collections;
 using System.IO;
@@ -53,7 +52,7 @@ namespace System.Security.Cryptography.Xml {
                public const string XmlEncElementUrl            = XmlEncNamespaceUrl + "Element";
                public const string XmlEncEncryptedKeyUrl       = XmlEncNamespaceUrl + "EncryptedKey";
                public const string XmlEncNamespaceUrl          = "http://www.w3.org/2001/04/xmlenc#";
-               public const string XmlEncRSA1_5Url             = XmlEncNamespaceUrl + "rsa-1_5";
+               public const string XmlEncRSA15Url              = XmlEncNamespaceUrl + "rsa-1_5";
                public const string XmlEncRSAOAEPUrl            = XmlEncNamespaceUrl + "rsa-oaep-mgf1p";
                public const string XmlEncSHA256Url             = XmlEncNamespaceUrl + "sha256";
                public const string XmlEncSHA512Url             = XmlEncNamespaceUrl + "sha512";
@@ -141,7 +140,18 @@ namespace System.Security.Cryptography.Xml {
 
                public byte[] DecryptData (EncryptedData encryptedData, SymmetricAlgorithm symAlg)
                {
-                       return Transform (encryptedData.CipherData.CipherValue, symAlg.CreateDecryptor (), symAlg.BlockSize / 8);
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
+                       PaddingMode bak = symAlg.Padding;
+                       try {
+                               symAlg.Padding = Padding;
+                               return Transform (encryptedData.CipherData.CipherValue, symAlg.CreateDecryptor (), symAlg.BlockSize / 8, true);
+                       } finally {
+                               symAlg.Padding = bak;
+                       }
                }
 
                public void DecryptDocument ()
@@ -157,6 +167,9 @@ namespace System.Security.Cryptography.Xml {
 
                public virtual byte[] DecryptEncryptedKey (EncryptedKey encryptedKey)
                {
+                       if (encryptedKey == null)
+                               throw new ArgumentNullException ("encryptedKey");
+
                        object keyAlg = null;
                        foreach (KeyInfoClause innerClause in encryptedKey.KeyInfo) {
                                if (innerClause is KeyInfoName) {
@@ -165,7 +178,7 @@ namespace System.Security.Cryptography.Xml {
                                }
                        }
                        switch (encryptedKey.EncryptionMethod.KeyAlgorithm) {
-                       case XmlEncRSA1_5Url:
+                       case XmlEncRSA15Url:
                                return DecryptKey (encryptedKey.CipherData.CipherValue, (RSA) keyAlg, false);
                        case XmlEncRSAOAEPUrl:
                                return DecryptKey (encryptedKey.CipherData.CipherValue, (RSA) keyAlg, true);
@@ -175,6 +188,11 @@ namespace System.Security.Cryptography.Xml {
 
                public static byte[] DecryptKey (byte[] keyData, SymmetricAlgorithm symAlg)
                {
+                       if (keyData == null)
+                               throw new ArgumentNullException ("keyData");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
                        if (symAlg is TripleDES)
                                return SymmetricKeyWrap.TripleDESKeyWrapDecrypt (symAlg.Key, keyData);
                        if (symAlg is Rijndael)
@@ -206,12 +224,18 @@ namespace System.Security.Cryptography.Xml {
                        symAlg.GenerateKey ();
                        symAlg.GenerateIV ();
 
-                       SymmetricAlgorithm keyAlg = (SymmetricAlgorithm) keyNameMapping [keyName];
                        EncryptedData encryptedData = new EncryptedData ();
-
                        EncryptedKey encryptedKey = new EncryptedKey();
+
+                       object keyAlg = keyNameMapping [keyName];
+
                        encryptedKey.EncryptionMethod = new EncryptionMethod (GetKeyWrapAlgorithmUri (keyAlg));
-                       encryptedKey.CipherData = new CipherData (EncryptKey (symAlg.Key, keyAlg));
+
+                       if (keyAlg is RSA)
+                               encryptedKey.CipherData = new CipherData (EncryptKey (symAlg.Key, (RSA) keyAlg, false));
+                       else
+                               encryptedKey.CipherData = new CipherData (EncryptKey (symAlg.Key, (SymmetricAlgorithm) keyAlg));
+
                        encryptedKey.KeyInfo = new KeyInfo();
                        encryptedKey.KeyInfo.AddClause (new KeyInfoName (keyName));
                        
@@ -225,12 +249,28 @@ namespace System.Security.Cryptography.Xml {
                }
                
                [MonoTODO]
-               public EncryptedData Encrypt (XmlElement inputElement, X509CertificateEx certificate)
+               public EncryptedData Encrypt (XmlElement inputElement, X509Certificate2 certificate)
                {
                        throw new NotImplementedException ();
                }
 
                public byte[] EncryptData (byte[] plainText, SymmetricAlgorithm symAlg)
+               {
+                       if (plainText == null)
+                               throw new ArgumentNullException ("plainText");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
+                       PaddingMode bak = symAlg.Padding;
+                       try {
+                               symAlg.Padding = Padding;
+                               return EncryptDataCore (plainText, symAlg);
+                       } finally {
+                               symAlg.Padding = bak;
+                       }
+               }
+
+               byte[] EncryptDataCore (byte[] plainText, SymmetricAlgorithm symAlg)
                {
                        // Write the symmetric algorithm IV and ciphertext together.
                        // We use a memory stream to accomplish this.
@@ -251,6 +291,9 @@ namespace System.Security.Cryptography.Xml {
 
                public byte[] EncryptData (XmlElement inputElement, SymmetricAlgorithm symAlg, bool content)
                {
+                       if (inputElement == null)
+                               throw new ArgumentNullException ("inputElement");
+
                        if (content)
                                return EncryptData (Encoding.GetBytes (inputElement.InnerXml), symAlg);
                        else
@@ -259,6 +302,11 @@ namespace System.Security.Cryptography.Xml {
 
                public static byte[] EncryptKey (byte[] keyData, SymmetricAlgorithm symAlg)
                {
+                       if (keyData == null)
+                               throw new ArgumentNullException ("keyData");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
                        if (symAlg is TripleDES)
                                return SymmetricKeyWrap.TripleDESKeyWrapEncrypt (symAlg.Key, keyData);
                        if (symAlg is Rijndael)
@@ -306,7 +354,7 @@ namespace System.Security.Cryptography.Xml {
                                symAlg = SymmetricAlgorithm.Create ("TripleDES");
                                break;
                        default:
-                               throw new ArgumentException ("symAlgUri");
+                               throw new CryptographicException ("symAlgUri");
                        }
 
                        return symAlg;
@@ -333,11 +381,11 @@ namespace System.Security.Cryptography.Xml {
                        throw new ArgumentException ("symAlg");
                }
 
-               private static string GetKeyWrapAlgorithmUri (SymmetricAlgorithm symAlg)
+               private static string GetKeyWrapAlgorithmUri (object keyAlg)
                {
-                       if (symAlg is Rijndael)
+                       if (keyAlg is Rijndael)
                        {
-                               switch (symAlg.KeySize) {
+                               switch (((Rijndael) keyAlg).KeySize) {
                                case 128:
                                        return XmlEncAES128KeyWrapUrl;
                                case 192:
@@ -346,14 +394,19 @@ namespace System.Security.Cryptography.Xml {
                                        return XmlEncAES256KeyWrapUrl;
                                }
                        }
-                       else if (symAlg is TripleDES)
+                       else if (keyAlg is RSA) 
+                               return XmlEncRSA15Url;
+                       else if (keyAlg is TripleDES)
                                return XmlEncTripleDESKeyWrapUrl;
 
-                       throw new ArgumentException ("symAlg");
+                       throw new ArgumentException ("keyAlg");
                }
 
                public virtual byte[] GetDecryptionIV (EncryptedData encryptedData, string symAlgUri)
                {
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+
                        SymmetricAlgorithm symAlg = GetAlgorithm (symAlgUri);
                        byte[] iv = new Byte [symAlg.BlockSize / 8];
                        Buffer.BlockCopy (encryptedData.CipherData.CipherValue, 0, iv, 0, iv.Length);
@@ -362,6 +415,11 @@ namespace System.Security.Cryptography.Xml {
 
                public virtual SymmetricAlgorithm GetDecryptionKey (EncryptedData encryptedData, string symAlgUri)
                {
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+                       if (symAlgUri == null)
+                               return null;
+
                        SymmetricAlgorithm symAlg = GetAlgorithm (symAlgUri);
                        symAlg.IV = GetDecryptionIV (encryptedData, encryptedData.EncryptionMethod.KeyAlgorithm);
                        KeyInfo keyInfo = encryptedData.KeyInfo;
@@ -376,6 +434,9 @@ namespace System.Security.Cryptography.Xml {
 
                public virtual XmlElement GetIdElement (XmlDocument document, string idValue)
                {
+                       if ((document == null) || (idValue == null))
+                               return null;
+
                         // this works only if there's a DTD or XSD available to define the ID
                        XmlElement xel = document.GetElementById (idValue);
                        if (xel == null) {
@@ -387,6 +448,11 @@ namespace System.Security.Cryptography.Xml {
 
                public void ReplaceData (XmlElement inputElement, byte[] decryptedData)
                {
+                       if (inputElement == null)
+                               throw new ArgumentNullException ("inputElement");
+                       if (decryptedData == null)
+                               throw new ArgumentNullException ("decryptedData");
+
                        XmlDocument ownerDocument = inputElement.OwnerDocument;
                        XmlTextReader reader = new XmlTextReader (new StringReader (Encoding.GetString (decryptedData, 0, decryptedData.Length)));
                        reader.MoveToContent ();
@@ -396,24 +462,38 @@ namespace System.Security.Cryptography.Xml {
 
                public static void ReplaceElement (XmlElement inputElement, EncryptedData encryptedData, bool content)
                {
+                       if (inputElement == null)
+                               throw new ArgumentNullException ("inputElement");
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+
                        XmlDocument ownerDocument = inputElement.OwnerDocument;
                        inputElement.ParentNode.ReplaceChild (encryptedData.GetXml (ownerDocument), inputElement);
                }
 
                private byte[] Transform (byte[] data, ICryptoTransform transform)
                {
-                       return Transform (data, transform, 0);
+                       return Transform (data, transform, 0, false);
                }
 
-               private byte[] Transform (byte[] data, ICryptoTransform transform, int startIndex)
+               private byte[] Transform (byte[] data, ICryptoTransform transform, int blockOctetCount, bool trimPadding)
                {
                        MemoryStream output = new MemoryStream ();
                        CryptoStream crypto = new CryptoStream (output, transform, CryptoStreamMode.Write);
-                       crypto.Write (data, startIndex, data.Length - startIndex);
+                       crypto.Write (data, 0, data.Length);
 
                        crypto.FlushFinalBlock ();
 
-                       byte[] result = output.ToArray ();
+                       // strip padding (see xmlenc spec 5.2)
+                       int trimSize = 0;
+                       if (trimPadding)
+                               trimSize = output.GetBuffer () [output.Length - 1];
+                       // It should not happen, but somehow .NET allows such cipher 
+                       // data as if there were no padding.
+                       if (trimSize > blockOctetCount)
+                               trimSize = 0;
+                       byte[] result = new byte [output.Length - blockOctetCount - trimSize];
+                       Array.Copy (output.GetBuffer (), blockOctetCount, result, 0, result.Length);
 
                        crypto.Close ();
                        output.Close ();
@@ -425,4 +505,3 @@ namespace System.Security.Cryptography.Xml {
        }
 }
 
-#endif