[corlib] Fixes security tests failures
[mono.git] / mcs / class / System.Security / System.Security.Cryptography.Xml / EncryptedXml.cs
index c2a9218be2e2528d0b9050f76d53e48141227178..6dde9199a2ad36d5c5f9d8620f7444ff22760df2 100644 (file)
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
 
-#if NET_2_0
 
 using System.Collections;
 using System.IO;
 using System.Security.Cryptography;
+using System.Security.Cryptography.X509Certificates;
 using System.Security.Policy;
 using System.Text;
 using System.Xml;
@@ -48,11 +48,11 @@ namespace System.Security.Cryptography.Xml {
                public const string XmlEncAES256KeyWrapUrl      = XmlEncNamespaceUrl + "kw-aes256";
                public const string XmlEncAES256Url             = XmlEncNamespaceUrl + "aes256-cbc";
                public const string XmlEncDESUrl                = XmlEncNamespaceUrl + "des-cbc";
-               public const string XmlEncElementContentUrl     = XmlEncNamespaceUrl + "ElementContent";
-               public const string XmlEncElementUrl            = XmlEncNamespaceUrl + "element";
+               public const string XmlEncElementContentUrl     = XmlEncNamespaceUrl + "Content";
+               public const string XmlEncElementUrl            = XmlEncNamespaceUrl + "Element";
                public const string XmlEncEncryptedKeyUrl       = XmlEncNamespaceUrl + "EncryptedKey";
                public const string XmlEncNamespaceUrl          = "http://www.w3.org/2001/04/xmlenc#";
-               public const string XmlEncRSA1_5Url             = XmlEncNamespaceUrl + "rsa-1_5";
+               public const string XmlEncRSA15Url              = XmlEncNamespaceUrl + "rsa-1_5";
                public const string XmlEncRSAOAEPUrl            = XmlEncNamespaceUrl + "rsa-oaep-mgf1p";
                public const string XmlEncSHA256Url             = XmlEncNamespaceUrl + "sha256";
                public const string XmlEncSHA512Url             = XmlEncNamespaceUrl + "sha512";
@@ -61,11 +61,12 @@ namespace System.Security.Cryptography.Xml {
 
                Evidence documentEvidence;
                Encoding encoding = Encoding.UTF8;
-               Hashtable keyNameMapping = new Hashtable ();
+               internal Hashtable keyNameMapping = new Hashtable ();
                CipherMode mode = CipherMode.CBC;
                PaddingMode padding = PaddingMode.ISO10126;
                string recipient;
                XmlResolver resolver;
+               XmlDocument document;
 
                #endregion // Fields
        
@@ -79,11 +80,13 @@ namespace System.Security.Cryptography.Xml {
                [MonoTODO]
                public EncryptedXml (XmlDocument document)
                {
+                       this.document = document;
                }
 
                [MonoTODO]
                public EncryptedXml (XmlDocument document, Evidence evidence)
                {
+                       this.document = document;
                        DocumentEvidence = evidence;
                }
        
@@ -137,49 +140,173 @@ namespace System.Security.Cryptography.Xml {
 
                public byte[] DecryptData (EncryptedData encryptedData, SymmetricAlgorithm symAlg)
                {
-                       return Transform (encryptedData.CipherData.CipherValue, symAlg.CreateDecryptor ());
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
+                       PaddingMode bak = symAlg.Padding;
+                       try {
+                               symAlg.Padding = Padding;
+                               return Transform (encryptedData.CipherData.CipherValue, symAlg.CreateDecryptor (), symAlg.BlockSize / 8, true);
+                       } finally {
+                               symAlg.Padding = bak;
+                       }
                }
 
-               [MonoTODO]
                public void DecryptDocument ()
                {
-                       throw new NotImplementedException ();
+                       XmlNodeList nodes = document.GetElementsByTagName ("EncryptedData", XmlEncNamespaceUrl);
+                       foreach (XmlNode node in nodes) {
+                               EncryptedData encryptedData = new EncryptedData ();
+                               encryptedData.LoadXml ((XmlElement) node);
+                               SymmetricAlgorithm symAlg = GetDecryptionKey (encryptedData, encryptedData.EncryptionMethod.KeyAlgorithm);
+                               ReplaceData ((XmlElement) node, DecryptData (encryptedData, symAlg));
+                       }
                }
 
-               [MonoTODO]
                public virtual byte[] DecryptEncryptedKey (EncryptedKey encryptedKey)
                {
-                       throw new NotImplementedException ();
+                       if (encryptedKey == null)
+                               throw new ArgumentNullException ("encryptedKey");
+
+                       object keyAlg = null;
+                       foreach (KeyInfoClause innerClause in encryptedKey.KeyInfo) {
+                               if (innerClause is KeyInfoName) {
+                                       keyAlg = keyNameMapping [((KeyInfoName) innerClause).Value];
+                                       break;
+                               }
+                       }
+                       switch (encryptedKey.EncryptionMethod.KeyAlgorithm) {
+                       case XmlEncRSA15Url:
+                               return DecryptKey (encryptedKey.CipherData.CipherValue, (RSA) keyAlg, false);
+                       case XmlEncRSAOAEPUrl:
+                               return DecryptKey (encryptedKey.CipherData.CipherValue, (RSA) keyAlg, true);
+                       }
+                       return DecryptKey (encryptedKey.CipherData.CipherValue, (SymmetricAlgorithm) keyAlg);
                }
 
-               [MonoTODO]
                public static byte[] DecryptKey (byte[] keyData, SymmetricAlgorithm symAlg)
                {
+                       if (keyData == null)
+                               throw new ArgumentNullException ("keyData");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
                        if (symAlg is TripleDES)
                                return SymmetricKeyWrap.TripleDESKeyWrapDecrypt (symAlg.Key, keyData);
                        if (symAlg is Rijndael)
-                               return SymmetricKeyWrap.TripleDESKeyWrapDecrypt (symAlg.Key, keyData);
-
+                               return SymmetricKeyWrap.AESKeyWrapDecrypt (symAlg.Key, keyData);
                        throw new CryptographicException ("The specified cryptographic transform is not supported.");
                }
 
-               [MonoTODO]
+               [MonoTODO ("Test this.")]
                public static byte[] DecryptKey (byte[] keyData, RSA rsa, bool fOAEP)
+               {
+                       AsymmetricKeyExchangeDeformatter deformatter = null;
+                       if (fOAEP) 
+                               deformatter = new RSAOAEPKeyExchangeDeformatter (rsa);
+                       else
+                               deformatter = new RSAPKCS1KeyExchangeDeformatter (rsa);
+                       return deformatter.DecryptKeyExchange (keyData);
+               }
+
+               public EncryptedData Encrypt (XmlElement inputElement, string keyName)
+               {
+                       // There are two keys of note here.
+                       // 1) KeyAlg: the key-encryption-key is used to wrap a key.  The keyName
+                       //    parameter will give us the KEK.
+                       // 2) SymAlg: A 256-bit AES key will be generated to encrypt the contents.
+                       //    This key will be wrapped using the KEK.
+
+                       SymmetricAlgorithm symAlg = SymmetricAlgorithm.Create ("Rijndael");
+                       symAlg.KeySize = 256;
+                       symAlg.GenerateKey ();
+                       symAlg.GenerateIV ();
+
+                       EncryptedData encryptedData = new EncryptedData ();
+                       EncryptedKey encryptedKey = new EncryptedKey();
+
+                       object keyAlg = keyNameMapping [keyName];
+
+                       encryptedKey.EncryptionMethod = new EncryptionMethod (GetKeyWrapAlgorithmUri (keyAlg));
+
+                       if (keyAlg is RSA)
+                               encryptedKey.CipherData = new CipherData (EncryptKey (symAlg.Key, (RSA) keyAlg, false));
+                       else
+                               encryptedKey.CipherData = new CipherData (EncryptKey (symAlg.Key, (SymmetricAlgorithm) keyAlg));
+
+                       encryptedKey.KeyInfo = new KeyInfo();
+                       encryptedKey.KeyInfo.AddClause (new KeyInfoName (keyName));
+                       
+                       encryptedData.Type = XmlEncElementUrl;
+                       encryptedData.EncryptionMethod = new EncryptionMethod (GetAlgorithmUri (symAlg));
+                       encryptedData.KeyInfo = new KeyInfo ();
+                       encryptedData.KeyInfo.AddClause (new KeyInfoEncryptedKey (encryptedKey));
+                       encryptedData.CipherData = new CipherData (EncryptData (inputElement, symAlg, false));
+
+                       return encryptedData;
+               }
+               
+               [MonoTODO]
+               public EncryptedData Encrypt (XmlElement inputElement, X509Certificate2 certificate)
                {
                        throw new NotImplementedException ();
                }
 
+               public byte[] EncryptData (byte[] plainText, SymmetricAlgorithm symAlg)
+               {
+                       if (plainText == null)
+                               throw new ArgumentNullException ("plainText");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
+                       PaddingMode bak = symAlg.Padding;
+                       try {
+                               symAlg.Padding = Padding;
+                               return EncryptDataCore (plainText, symAlg);
+                       } finally {
+                               symAlg.Padding = bak;
+                       }
+               }
+
+               byte[] EncryptDataCore (byte[] plainText, SymmetricAlgorithm symAlg)
+               {
+                       // Write the symmetric algorithm IV and ciphertext together.
+                       // We use a memory stream to accomplish this.
+                       MemoryStream stream = new MemoryStream ();
+                       BinaryWriter writer = new BinaryWriter (stream);
+
+                       writer.Write (symAlg.IV);
+                       writer.Write (Transform (plainText, symAlg.CreateEncryptor ()));
+                       writer.Flush ();
+
+                       byte [] output = stream.ToArray ();
+
+                       writer.Close ();
+                       stream.Close ();
+
+                       return output;
+               }
+
                public byte[] EncryptData (XmlElement inputElement, SymmetricAlgorithm symAlg, bool content)
                {
+                       if (inputElement == null)
+                               throw new ArgumentNullException ("inputElement");
+
                        if (content)
-                               return Transform (Encoding.GetBytes (inputElement.InnerXml), symAlg.CreateEncryptor ());
+                               return EncryptData (Encoding.GetBytes (inputElement.InnerXml), symAlg);
                        else
-                               return Transform (Encoding.GetBytes (inputElement.OuterXml), symAlg.CreateEncryptor ());
+                               return EncryptData (Encoding.GetBytes (inputElement.OuterXml), symAlg);
                }
 
-               [MonoTODO ("Do we need to support more algorithms?")]
                public static byte[] EncryptKey (byte[] keyData, SymmetricAlgorithm symAlg)
                {
+                       if (keyData == null)
+                               throw new ArgumentNullException ("keyData");
+                       if (symAlg == null)
+                               throw new ArgumentNullException ("symAlg");
+
                        if (symAlg is TripleDES)
                                return SymmetricKeyWrap.TripleDESKeyWrapEncrypt (symAlg.Key, keyData);
                        if (symAlg is Rijndael)
@@ -188,10 +315,15 @@ namespace System.Security.Cryptography.Xml {
                        throw new CryptographicException ("The specified cryptographic transform is not supported.");
                }
 
-               [MonoTODO ("Not sure what this is for.")]
+               [MonoTODO ("Test this.")]
                public static byte[] EncryptKey (byte[] keyData, RSA rsa, bool fOAEP)
                {
-                       throw new NotImplementedException ();
+                       AsymmetricKeyExchangeFormatter formatter = null;
+                       if (fOAEP) 
+                               formatter = new RSAOAEPKeyExchangeFormatter (rsa);
+                       else
+                               formatter = new RSAPKCS1KeyExchangeFormatter (rsa);
+                       return formatter.CreateKeyExchange (keyData);
                }
 
                private static SymmetricAlgorithm GetAlgorithm (string symAlgUri)
@@ -200,14 +332,17 @@ namespace System.Security.Cryptography.Xml {
 
                        switch (symAlgUri) {
                        case XmlEncAES128Url:
+                       case XmlEncAES128KeyWrapUrl:
                                symAlg = SymmetricAlgorithm.Create ("Rijndael");
                                symAlg.KeySize = 128;
                                break;
                        case XmlEncAES192Url:
+                       case XmlEncAES192KeyWrapUrl:
                                symAlg = SymmetricAlgorithm.Create ("Rijndael");
                                symAlg.KeySize = 192;
                                break;
                        case XmlEncAES256Url:
+                       case XmlEncAES256KeyWrapUrl:
                                symAlg = SymmetricAlgorithm.Create ("Rijndael");
                                symAlg.KeySize = 256;
                                break;
@@ -215,33 +350,93 @@ namespace System.Security.Cryptography.Xml {
                                symAlg = SymmetricAlgorithm.Create ("DES");
                                break;
                        case XmlEncTripleDESUrl:
+                       case XmlEncTripleDESKeyWrapUrl:
                                symAlg = SymmetricAlgorithm.Create ("TripleDES");
                                break;
                        default:
-                               throw new ArgumentException ("symAlgUri");
+                               throw new CryptographicException ("symAlgUri");
                        }
 
                        return symAlg;
                }
 
-               [MonoTODO]
+               private static string GetAlgorithmUri (SymmetricAlgorithm symAlg)
+               {
+                       if (symAlg is Rijndael)
+                       {
+                               switch (symAlg.KeySize) {
+                               case 128:
+                                       return XmlEncAES128Url;
+                               case 192:
+                                       return XmlEncAES192Url;
+                               case 256:
+                                       return XmlEncAES256Url;
+                               }
+                       }
+                       else if (symAlg is DES)
+                               return XmlEncDESUrl;
+                       else if (symAlg is TripleDES)
+                               return XmlEncTripleDESUrl;
+
+                       throw new ArgumentException ("symAlg");
+               }
+
+               private static string GetKeyWrapAlgorithmUri (object keyAlg)
+               {
+                       if (keyAlg is Rijndael)
+                       {
+                               switch (((Rijndael) keyAlg).KeySize) {
+                               case 128:
+                                       return XmlEncAES128KeyWrapUrl;
+                               case 192:
+                                       return XmlEncAES192KeyWrapUrl;
+                               case 256:
+                                       return XmlEncAES256KeyWrapUrl;
+                               }
+                       }
+                       else if (keyAlg is RSA) 
+                               return XmlEncRSA15Url;
+                       else if (keyAlg is TripleDES)
+                               return XmlEncTripleDESKeyWrapUrl;
+
+                       throw new ArgumentException ("keyAlg");
+               }
+
                public virtual byte[] GetDecryptionIV (EncryptedData encryptedData, string symAlgUri)
                {
-                       SymmetricAlgorithm symAlg = GetAlgorithm (symAlgUri);
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
 
-                       throw new NotImplementedException ();
+                       SymmetricAlgorithm symAlg = GetAlgorithm (symAlgUri);
+                       byte[] iv = new Byte [symAlg.BlockSize / 8];
+                       Buffer.BlockCopy (encryptedData.CipherData.CipherValue, 0, iv, 0, iv.Length);
+                       return iv;
                }
 
-               [MonoTODO]
                public virtual SymmetricAlgorithm GetDecryptionKey (EncryptedData encryptedData, string symAlgUri)
                {
-                       SymmetricAlgorithm symAlg = GetAlgorithm (symAlgUri);
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+                       if (symAlgUri == null)
+                               return null;
 
-                       throw new NotImplementedException ();
+                       SymmetricAlgorithm symAlg = GetAlgorithm (symAlgUri);
+                       symAlg.IV = GetDecryptionIV (encryptedData, encryptedData.EncryptionMethod.KeyAlgorithm);
+                       KeyInfo keyInfo = encryptedData.KeyInfo;
+                       foreach (KeyInfoClause clause in keyInfo) {
+                               if (clause is KeyInfoEncryptedKey) {
+                                       symAlg.Key = DecryptEncryptedKey (((KeyInfoEncryptedKey) clause).EncryptedKey);
+                                       break;
+                               }
+                       }
+                       return symAlg;
                }
 
                public virtual XmlElement GetIdElement (XmlDocument document, string idValue)
                {
+                       if ((document == null) || (idValue == null))
+                               return null;
+
                         // this works only if there's a DTD or XSD available to define the ID
                        XmlElement xel = document.GetElementById (idValue);
                        if (xel == null) {
@@ -251,25 +446,62 @@ namespace System.Security.Cryptography.Xml {
                        return xel;
                }
 
-               [MonoTODO]
+               public void ReplaceData (XmlElement inputElement, byte[] decryptedData)
+               {
+                       if (inputElement == null)
+                               throw new ArgumentNullException ("inputElement");
+                       if (decryptedData == null)
+                               throw new ArgumentNullException ("decryptedData");
+
+                       XmlDocument ownerDocument = inputElement.OwnerDocument;
+                       XmlTextReader reader = new XmlTextReader (new StringReader (Encoding.GetString (decryptedData, 0, decryptedData.Length)));
+                       reader.MoveToContent ();
+                       XmlNode node = ownerDocument.ReadNode (reader);
+                       inputElement.ParentNode.ReplaceChild (node, inputElement);
+               }
+
                public static void ReplaceElement (XmlElement inputElement, EncryptedData encryptedData, bool content)
                {
-                       throw new NotImplementedException ();
+                       if (inputElement == null)
+                               throw new ArgumentNullException ("inputElement");
+                       if (encryptedData == null)
+                               throw new ArgumentNullException ("encryptedData");
+
+                       XmlDocument ownerDocument = inputElement.OwnerDocument;
+                       inputElement.ParentNode.ReplaceChild (encryptedData.GetXml (ownerDocument), inputElement);
                }
 
                private byte[] Transform (byte[] data, ICryptoTransform transform)
+               {
+                       return Transform (data, transform, 0, false);
+               }
+
+               private byte[] Transform (byte[] data, ICryptoTransform transform, int blockOctetCount, bool trimPadding)
                {
                        MemoryStream output = new MemoryStream ();
                        CryptoStream crypto = new CryptoStream (output, transform, CryptoStreamMode.Write);
                        crypto.Write (data, 0, data.Length);
+
+                       crypto.FlushFinalBlock ();
+
+                       // strip padding (see xmlenc spec 5.2)
+                       int trimSize = 0;
+                       if (trimPadding)
+                               trimSize = output.GetBuffer () [output.Length - 1];
+                       // It should not happen, but somehow .NET allows such cipher 
+                       // data as if there were no padding.
+                       if (trimSize > blockOctetCount)
+                               trimSize = 0;
+                       byte[] result = new byte [output.Length - blockOctetCount - trimSize];
+                       Array.Copy (output.GetBuffer (), blockOctetCount, result, 0, result.Length);
+
                        crypto.Close ();
                        output.Close ();
 
-                       return output.ToArray ();
+                       return result;
                }
 
                #endregion // Methods
        }
 }
 
-#endif