Contract-based type must ensure that the base type is also contract-based.
[mono.git] / mcs / class / System.Runtime.Serialization / System.Runtime.Serialization / KnownTypeCollection.cs
old mode 100644 (file)
new mode 100755 (executable)
index 513564c..66fb3c5
@@ -30,6 +30,7 @@ using System;
 using System.Collections;
 using System.Collections.Generic;
 using System.Collections.ObjectModel;
+using System.Linq;
 using System.Reflection;
 using System.Xml;
 using System.Xml.Schema;
@@ -73,12 +74,23 @@ namespace System.Runtime.Serialization
          exists (and raises InvalidOperationException if required).
 
 */
+       internal static class TypeExtensions
+       {
+               public static T GetCustomAttribute<T> (this Type type, bool inherit)
+               {
+                       var arr = type.GetCustomAttributes (typeof (T), inherit);
+                       return arr != null && arr.Length == 1 ? (T) arr [0] : default (T);
+               }
+       }
+
        internal sealed class KnownTypeCollection : Collection<Type>
        {
                internal const string MSSimpleNamespace =
                        "http://schemas.microsoft.com/2003/10/Serialization/";
                internal const string MSArraysNamespace =
                        "http://schemas.microsoft.com/2003/10/Serialization/Arrays";
+               internal const string DefaultClrNamespaceBase =
+                       "http://schemas.datacontract.org/2004/07/";
 
                static QName any_type, bool_type,
                        byte_type, date_type, decimal_type, double_type,
@@ -224,7 +236,7 @@ namespace System.Runtime.Serialization
                        case TypeCode.Boolean:
                                return XmlConvert.ToString ((bool) obj);
                        case TypeCode.Byte:
-                               return XmlConvert.ToString ((byte) obj);
+                               return XmlConvert.ToString ((int)((byte) obj));
                        case TypeCode.Char:
                                return XmlConvert.ToString ((uint) (char) obj);
                        case TypeCode.DateTime:
@@ -246,7 +258,7 @@ namespace System.Runtime.Serialization
                        case TypeCode.String:
                                return (string) obj;
                        case TypeCode.UInt16:
-                               return XmlConvert.ToString ((ushort) obj);
+                               return XmlConvert.ToString ((int) (ushort) obj);
                        case TypeCode.UInt32:
                                return XmlConvert.ToString ((uint) obj);
                        case TypeCode.UInt64:
@@ -254,34 +266,54 @@ namespace System.Runtime.Serialization
                        }
                }
 
-               internal static bool IsPrimitiveType (QName qname)
+               // FIXME: xsd types and ms serialization types should be differentiated.
+               internal static Type GetPrimitiveTypeFromName (string name)
                {
-                       /* FIXME: qname.Namespace ? */
-                       switch (qname.Name) {
+                       switch (name) {
                        case "anyURI":
+                               return typeof (Uri);
                        case "boolean":
+                               return typeof (bool);
                        case "base64Binary":
+                               return typeof (byte []);
                        case "dateTime":
+                               return typeof (DateTime);
                        case "duration":
+                               return typeof (TimeSpan);
                        case "QName":
+                               return typeof (QName);
                        case "decimal":
+                               return typeof (decimal);
                        case "double":
+                               return typeof (double);
                        case "float":
+                               return typeof (float);
                        case "byte":
+                               return typeof (sbyte);
                        case "short":
+                               return typeof (short);
                        case "int":
+                               return typeof (int);
                        case "long":
+                               return typeof (long);
                        case "unsignedByte":
+                               return typeof (byte);
                        case "unsignedShort":
+                               return typeof (ushort);
                        case "unsignedInt":
+                               return typeof (uint);
                        case "unsignedLong":
+                               return typeof (ulong);
                        case "string":
+                               return typeof (string);
                        case "anyType":
+                               return typeof (object);
                        case "guid":
+                               return typeof (Guid);
                        case "char":
-                               return true;
+                               return typeof (char);
                        default:
-                               return false;
+                               return null;
                        }
                }
 
@@ -354,7 +386,7 @@ namespace System.Runtime.Serialization
 
                protected override void InsertItem (int index, Type type)
                {
-                       if (TryRegister (type))
+                       if (!Contains (type) && TryRegister (type))
                                base.InsertItem (index, type);
                }
 
@@ -375,21 +407,27 @@ namespace System.Runtime.Serialization
 
                protected override void SetItem (int index, Type type)
                {
-                       if (index == Count)
-                               InsertItem (index, type);
-                       else {
+                       if (index != Count)
                                RemoveItem (index);
-                               if (TryRegister (type))
-                                       base.InsertItem (index - 1, type);
-                       }
+                       if (TryRegister (type))
+                               base.InsertItem (index - 1, type);
                }
 
                internal SerializationMap FindUserMap (QName qname)
                {
-                       for (int i = 0; i < contracts.Count; i++)
-                               if (qname == contracts [i].XmlName)
-                                       return contracts [i];
-                       return null;
+                       return contracts.FirstOrDefault (c => c.XmlName == qname);
+               }
+
+               internal Type GetSerializedType (Type type)
+               {
+                       Type element = GetCollectionElementType (type);
+                       if (element == null)
+                               return type;
+                       QName name = GetQName (type);
+                       var map = FindUserMap (name);
+                       if (map != null)
+                               return map.RuntimeType;
+                       return type;
                }
 
                internal SerializationMap FindUserMap (Type type)
@@ -402,13 +440,17 @@ namespace System.Runtime.Serialization
 
                internal QName GetQName (Type type)
                {
-                       if (IsPrimitiveNotEnum (type))
-                               return GetPrimitiveTypeName (type);
-
                        SerializationMap map = FindUserMap (type);
                        if (map != null)
                                // already mapped.
                                return map.XmlName; 
+                       return GetStaticQName (type);
+               }
+
+               public static QName GetStaticQName (Type type)
+               {
+                       if (IsPrimitiveNotEnum (type))
+                               return GetPrimitiveTypeName (type);
 
                        if (type.IsEnum)
                                return GetEnumQName (type);
@@ -422,64 +464,89 @@ namespace System.Runtime.Serialization
                                //need name of the type..
                                return GetSerializableQName (type);
 
+                       qname = GetCollectionContractQName (type);
+                       if (qname != null)
+                               return qname;
+
                        Type element = GetCollectionElementType (type);
                        if (element != null)
                                return GetCollectionQName (element);
 
-                       if (type.GetCustomAttributes (typeof (SerializableAttribute), false).Length == 1)
+                       if (GetAttribute<SerializableAttribute> (type) != null)
                                return GetSerializableQName (type);
 
-                       // FIXME: it needs in-depth check.
-                       return QName.Empty;
+                       // default type map - still uses GetContractQName().
+                       return GetContractQName (type, null, null);
                }
-               
-               private QName GetContractQName (Type type)
+
+               internal static QName GetContractQName (Type type)
                {
-                       object [] atts = type.GetCustomAttributes (
-                               typeof (DataContractAttribute), false);
-                       if (atts.Length == 0)
-                               return null;
+                       var a = GetAttribute<DataContractAttribute> (type);
+                       return a == null ? null : GetContractQName (type, a.Name, a.Namespace);
+               }
 
-                       string name = ((DataContractAttribute) atts [0]).Name;
-                       if (name == null)
-                               // FIXME: there could be decent ways to get 
-                               // the same result...
-                               name = type.Namespace == null || type.Namespace.Length == 0 ? type.Name : type.FullName.Substring (type.Namespace.Length + 1).Replace ('+', '.');
+               static QName GetCollectionContractQName (Type type)
+               {
+                       var a = GetAttribute<CollectionDataContractAttribute> (type);
+                       return a == null ? null : GetContractQName (type, a.Name, a.Namespace);
+               }
 
-                       string ns = ((DataContractAttribute) atts [0]).Namespace;
+               static QName GetContractQName (Type type, string name, string ns)
+               {
+                       if (name == null)
+                               name = GetDefaultName (type);
                        if (ns == null)
-                               ns = XmlObjectSerializer.DefaultNamespaceBase + type.Namespace;
-
+                               ns = GetDefaultNamespace (type);
                        return new QName (name, ns);
                }
 
-               private QName GetEnumQName (Type type)
+               static QName GetEnumQName (Type type)
                {
                        string name = null, ns = null;
 
                        if (!type.IsEnum)
                                return null;
 
-                       object [] atts = type.GetCustomAttributes (
-                               typeof (DataContractAttribute), false);
+                       var dca = GetAttribute<DataContractAttribute> (type);
 
-                       if (atts.Length != 0) {
-                               ns = ((DataContractAttribute) atts [0]).Namespace;
-                               name = ((DataContractAttribute) atts [0]).Name;
+                       if (dca != null) {
+                               ns = dca.Namespace;
+                               name = dca.Name;
                        }
 
                        if (ns == null)
-                               ns = XmlObjectSerializer.DefaultNamespaceBase + type.Namespace;
+                               ns = GetDefaultNamespace (type);
 
                        if (name == null)
-                               name = type.Namespace == null || type.Namespace.Length == 0 ? type.Name : type.FullName.Substring (type.Namespace.Length + 1).Replace ('+', '.');
+                               name = type.Namespace == null ? type.Name : type.FullName.Substring (type.Namespace.Length + 1).Replace ('+', '.');
 
                        return new QName (name, ns);
                }
 
-               private QName GetCollectionQName (Type element)
+               internal static string GetDefaultName (Type type)
+               {
+                       // FIXME: there could be decent ways to get 
+                       // the same result...
+                       string name = type.Namespace == null || type.Namespace.Length == 0 ? type.Name : type.FullName.Substring (type.Namespace.Length + 1).Replace ('+', '.');
+                       if (type.IsGenericType) {
+                               name = name.Substring (0, name.IndexOf ('`')) + "Of";
+                               foreach (var t in type.GetGenericArguments ())
+                                       name += t.Name; // FIXME: check namespaces too
+                       }
+                       return name;
+               }
+
+               internal static string GetDefaultNamespace (Type type)
+               {
+                       foreach (ContractNamespaceAttribute a in type.Assembly.GetCustomAttributes (typeof (ContractNamespaceAttribute), true))
+                               if (a.ClrNamespace == type.Namespace)
+                                       return a.ContractNamespace;
+                       return DefaultClrNamespaceBase + type.Namespace;
+               }
+
+               static QName GetCollectionQName (Type element)
                {
-                       QName eqname = GetQName (element);
+                       QName eqname = GetStaticQName (element);
                        
                        string ns = eqname.Namespace;
                        if (eqname.Namespace == MSSimpleNamespace)
@@ -491,22 +558,40 @@ namespace System.Runtime.Serialization
                                ns);
                }
 
-               private QName GetSerializableQName (Type type)
+               static QName GetSerializableQName (Type type)
                {
+#if !NET_2_1
+                       // First, check XmlSchemaProviderAttribute and try GetSchema() to see if it returns a schema in the expected format.
+                       var xpa = type.GetCustomAttribute<XmlSchemaProviderAttribute> (true);
+                       if (xpa != null) {
+                               var mi = type.GetMethod (xpa.MethodName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
+                               if (mi != null) {
+                                       try {
+                                               var xss = new XmlSchemaSet ();
+                                               return (XmlQualifiedName) mi.Invoke (null, new object [] {xss});
+                                       } catch {
+                                               // ignore.
+                                       }
+                               }
+                       }
+#endif
+
                        string xmlName = type.Name;
-                       string xmlNamespace = XmlObjectSerializer.DefaultNamespaceBase + type.Namespace;
-                       object [] xmlRootAttributes = type.GetCustomAttributes (typeof (XmlRootAttribute), false);
-                       if (xmlRootAttributes.Length > 1)
-                               throw new Exception ("Only one XmlRoot namespace allowed on type " + type.Name);
-                       if (xmlRootAttributes.Length == 1) {
-                               XmlRootAttribute rootAttribute = (XmlRootAttribute) xmlRootAttributes [0];
-                               xmlName = rootAttribute.ElementName;
-                               xmlNamespace = rootAttribute.Namespace;
+                       if (type.IsGenericType) {
+                               xmlName = xmlName.Substring (0, xmlName.IndexOf ('`')) + "Of";
+                               foreach (var t in type.GetGenericArguments ())
+                                       xmlName += GetStaticQName (t).Name; // FIXME: check namespaces too
+                       }
+                       string xmlNamespace = GetDefaultNamespace (type);
+                       var x = GetAttribute<XmlRootAttribute> (type);
+                       if (x != null) {
+                               xmlName = x.ElementName;
+                               xmlNamespace = x.Namespace;
                        }
-                       return new QName (XmlConvert.EncodeLocalName (xmlName), xmlNamespace);
+                       return new QName (XmlConvert.EncodeLocalName (xmlName), xmlNamespace);
                }
 
-               internal bool IsPrimitiveNotEnum (Type type)
+               static bool IsPrimitiveNotEnum (Type type)
                {
                        if (type.IsEnum)
                                return false;
@@ -532,62 +617,127 @@ namespace System.Runtime.Serialization
                        if (RegisterEnum (type) != null)
                                return true;
 
+                       if (RegisterDictionary (type) != null)
+                               return true;
+
+                       if (RegisterCollectionContract (type) != null)
+                               return true;
+
                        if (RegisterContract (type) != null)
                                return true;
 
                        if (RegisterIXmlSerializable (type) != null)
                                return true;
-                       
-                       Type element = GetCollectionElementType (type);
-                       if (element != null) {
-                               TryRegister (element);
-                               RegisterCollection (type, element);
+
+                       if (RegisterCollection (type) != null)
                                return true;
-                       }
 
-                       if (type.GetCustomAttributes (typeof (SerializableAttribute), false).Length == 1) {
+                       if (GetAttribute<SerializableAttribute> (type) != null) {
                                RegisterSerializable (type);
                                return true;
                        }
 
-                       throw new InvalidDataContractException (String.Format ("Type {0} has neither Serializable nor DataContract attributes.", type));
+                       RegisterDefaultTypeMap (type);
+                       return true;
                }
 
-               static readonly Type genericIEnumerable =
-                       typeof (IEnumerable<object>).GetGenericTypeDefinition ();
-
-               internal static Type GetCollectionElementType (Type type)
+               static Type GetCollectionElementType (Type type)
                {
                        if (type.IsArray)
                                return type.GetElementType ();
 
                        Type [] ifaces = type.GetInterfaces ();
-                       foreach (Type iface in ifaces) {
-                               Type t = iface;
-                               Type gt = t.IsGenericType ? 
-                                       t.GetGenericTypeDefinition () : null;
-                               if (gt == genericIEnumerable)
-                                       return t.GetGenericArguments () [0];
-                               foreach (Type i in ifaces)
-                                       if (i == typeof (IEnumerable))
-                                               return typeof (object);
-                       }
+                       foreach (Type i in ifaces)
+                               if (i.IsGenericType && i.GetGenericTypeDefinition ().Equals (typeof (ICollection<>)))
+                                       return i.GetGenericArguments () [0];
+                       foreach (Type i in ifaces)
+                               if (i == typeof (IList))
+                                       return typeof (object);
                        return null;
                }
 
-               private CollectionTypeMap RegisterCollection (Type type, Type element)
+               internal static T GetAttribute<T> (ICustomAttributeProvider ap) where T : Attribute
                {
-                       QName qname = GetCollectionQName (element);
+                       object [] atts = ap.GetCustomAttributes (typeof (T), false);
+                       return atts.Length == 0 ? null : (T) atts [0];
+               }
+
+               private CollectionContractTypeMap RegisterCollectionContract (Type type)
+               {
+                       var cdca = GetAttribute<CollectionDataContractAttribute> (type);
+                       if (cdca == null)
+                               return null;
+
+                       Type element = GetCollectionElementType (type);
+                       if (element == null)
+                               throw new InvalidOperationException (String.Format ("Type '{0}' is marked as collection contract, but it is not a collection", type));
 
+                       TryRegister (element); // must be registered before the name conflict check.
+
+                       QName qname = GetCollectionContractQName (type);
+                       CheckStandardQName (qname);
                        if (FindUserMap (qname) != null)
                                throw new InvalidOperationException (String.Format ("Failed to add type {0} to known type collection. There already is a registered type for XML name {1}", type, qname));
 
+                       var ret = new CollectionContractTypeMap (type, cdca, element, qname, this);
+                       contracts.Add (ret);
+                       return ret;
+               }
+
+               private CollectionTypeMap RegisterCollection (Type type)
+               {
+                       Type element = GetCollectionElementType (type);
+                       if (element == null)
+                               return null;
+
+                       TryRegister (element);
+
+                       QName qname = GetCollectionQName (element);
+
+                       var map = FindUserMap (qname);
+                       if (map != null) {
+                               var cmap = map as CollectionTypeMap;
+                               if (cmap == null || cmap.RuntimeType != type)
+                                       throw new InvalidOperationException (String.Format ("Failed to add type {0} to known type collection. There already is a registered type for XML name {1}", type, qname));
+                               return cmap;
+                       }
+
                        CollectionTypeMap ret =
                                new CollectionTypeMap (type, element, qname, this);
                        contracts.Add (ret);
                        return ret;
                }
 
+               static bool TypeImplementsIDictionary (Type type)
+               {
+                       foreach (var iface in type.GetInterfaces ())
+                               if (iface == typeof (IDictionary) || (iface.IsGenericType && iface.GetGenericTypeDefinition () == typeof (IDictionary<,>)))
+                                       return true;
+
+                       return false;
+               }
+
+               // it also supports contract-based dictionary.
+               private DictionaryTypeMap RegisterDictionary (Type type)
+               {
+                       if (!TypeImplementsIDictionary (type))
+                               return null;
+
+                       var cdca = GetAttribute<CollectionDataContractAttribute> (type);
+
+                       DictionaryTypeMap ret =
+                               new DictionaryTypeMap (type, cdca, this);
+
+                       if (FindUserMap (ret.XmlName) != null)
+                               throw new InvalidOperationException (String.Format ("Failed to add type {0} to known type collection. There already is a registered type for XML name {1}", type, ret.XmlName));
+                       contracts.Add (ret);
+
+                       TryRegister (ret.KeyType);
+                       TryRegister (ret.ValueType);
+
+                       return ret;
+               }
+
                private SerializationMap RegisterSerializable (Type type)
                {
                        QName qname = GetSerializableQName (type);
@@ -595,9 +745,9 @@ namespace System.Runtime.Serialization
                        if (FindUserMap (qname) != null)
                                throw new InvalidOperationException (String.Format ("There is already a registered type for XML name {0}", qname));
 
-                       SharedTypeMap ret =
-                               new SharedTypeMap (type, qname, this);
+                       SharedTypeMap ret = new SharedTypeMap (type, qname, this);
                        contracts.Add (ret);
+                       ret.Initialize ();
                        return ret;
                }
 
@@ -617,12 +767,8 @@ namespace System.Runtime.Serialization
                        return ret;
                }
 
-               private SharedContractMap RegisterContract (Type type)
+               void CheckStandardQName (QName qname)
                {
-                       QName qname = GetContractQName (type);
-                       if (qname == null)
-                               return null;
-
                        switch (qname.Namespace) {
                        case XmlSchema.Namespace:
                        case XmlSchema.InstanceNamespace:
@@ -631,12 +777,41 @@ namespace System.Runtime.Serialization
                                throw new InvalidOperationException (String.Format ("Namespace {0} is reserved and cannot be used for user serialization", qname.Namespace));
                        }
 
+               }
+
+               private SharedContractMap RegisterContract (Type type)
+               {
+                       QName qname = GetContractQName (type);
+                       if (qname == null)
+                               return null;
+                       CheckStandardQName (qname);
                        if (FindUserMap (qname) != null)
                                throw new InvalidOperationException (String.Format ("There is already a registered type for XML name {0}", qname));
 
-                       SharedContractMap ret =
-                               new SharedContractMap (type, qname, this);
+                       SharedContractMap ret = new SharedContractMap (type, qname, this);
+                       contracts.Add (ret);
+                       ret.Initialize ();
+
+                       if (type.BaseType != typeof (object)) {
+                               TryRegister (type.BaseType);
+                               if (!FindUserMap (type.BaseType).IsContractAllowedType)
+                                       throw new InvalidDataContractException (String.Format ("To be serializable by data contract, type '{0}' cannot inherit from non-contract and non-Serializable type '{1}'", type, type.BaseType));
+                       }
+
+                       object [] attrs = type.GetCustomAttributes (typeof (KnownTypeAttribute), true);
+                       for (int i = 0; i < attrs.Length; i++) {
+                               KnownTypeAttribute kt = (KnownTypeAttribute) attrs [i];
+                               TryRegister (kt.Type);
+                       }
+
+                       return ret;
+               }
+
+               DefaultTypeMap RegisterDefaultTypeMap (Type type)
+               {
+                       DefaultTypeMap ret = new DefaultTypeMap (type, this);
                        contracts.Add (ret);
+                       ret.Initialize ();
                        return ret;
                }