Merge pull request #495 from nicolas-raoul/fix-for-issue2907-with-no-formatting-changes
[mono.git] / mcs / class / System.Runtime.Serialization / System.Runtime.Serialization / KnownTypeCollection.cs
index 946e5aa5af5e7c1c9ebb9b553d69bfd79152ff83..43b846aaad72480e334680b124f71208eb7a3f83 100755 (executable)
@@ -74,14 +74,16 @@ namespace System.Runtime.Serialization
          exists (and raises InvalidOperationException if required).
 
 */
+
        internal static class TypeExtensions
        {
+#if !NET_4_5
                public static T GetCustomAttribute<T> (this MemberInfo type, bool inherit)
                {
                        var arr = type.GetCustomAttributes (typeof (T), inherit);
                        return arr != null && arr.Length == 1 ? (T) arr [0] : default (T);
                }
-
+#endif
                public static IEnumerable<Type> GetInterfacesOrSelfInterface (this Type type)
                {
                        if (type.IsInterface)
@@ -89,6 +91,20 @@ namespace System.Runtime.Serialization
                        foreach (var t in type.GetInterfaces ())
                                yield return t;
                }
+
+               public static bool ImplementsInterface (this Type type, Type iface)
+               {
+                       foreach (var t in type.GetInterfacesOrSelfInterface ()) {
+                               if (t == iface)
+                                       return true;
+                       }
+
+                       var baseType = type.BaseType;
+                       if (baseType != null)
+                               return baseType.ImplementsInterface (iface);
+                       
+                       return false;
+               }
        }
 
        internal sealed class KnownTypeCollection : Collection<Type>
@@ -571,8 +587,10 @@ namespace System.Runtime.Serialization
                                return qname;
 
                        Type element = GetCollectionElementType (type);
-                       if (element != null)
-                               return GetCollectionQName (element);
+                       if (element != null) {
+                               if (type.IsInterface || IsCustomCollectionType (type, element))
+                                       return GetCollectionQName (element);
+                       }
 
                        if (GetAttribute<SerializableAttribute> (type) != null)
                                return GetSerializableQName (type);
@@ -775,7 +793,7 @@ namespace System.Runtime.Serialization
                                if (i.IsGenericType && i.GetGenericTypeDefinition ().Equals (typeof (IEnumerable<>)))
                                        return i.GetGenericArguments () [0];
                        foreach (Type i in ifaces)
-                               if (i == typeof (IList))
+                               if (i == typeof (IEnumerable))
                                        return typeof (object);
                        return null;
                }
@@ -794,7 +812,9 @@ namespace System.Runtime.Serialization
 
                        Type element = GetCollectionElementType (type);
                        if (element == null)
-                               throw new InvalidOperationException (String.Format ("Type '{0}' is marked as collection contract, but it is not a collection", type));
+                               throw new InvalidDataContractException (String.Format ("Type '{0}' is marked as collection contract, but it is not a collection", type));
+                       if (type.GetMethod ("Add", new Type[] { element }) == null)
+                               throw new InvalidDataContractException (String.Format ("Type '{0}' is marked as collection contract, but missing a public \"Add\" method", type));
 
                        TryRegister (element); // must be registered before the name conflict check.
 
@@ -820,6 +840,17 @@ namespace System.Runtime.Serialization
 
                        TryRegister (element);
 
+                       /*
+                        * To qualify as a custom collection type, a type must have
+                        * a public parameterless constructor and an "Add" method
+                        * with the correct parameter type in addition to implementing
+                        * one of the collection interfaces.
+                        * 
+                        */
+
+                       if (!type.IsArray && type.IsClass && !IsCustomCollectionType (type, element))
+                               return null;
+
                        QName qname = GetCollectionQName (element);
 
                        var map = FindUserMap (qname, element);
@@ -836,6 +867,86 @@ namespace System.Runtime.Serialization
                        return ret;
                }
 
+               static bool IsCustomCollectionType (Type type, Type elementType)
+               {
+                       if (!type.IsClass)
+                               return false;
+                       if (type.GetConstructor (new Type [0]) == null)
+                               return false;
+                       if (type.GetMethod ("Add", new Type[] { elementType }) == null)
+                               return false;
+
+                       return true;
+               }
+
+               internal static bool IsInterchangeableCollectionType (Type contractType, Type graphType,
+                                                                     out QName collectionQName)
+               {
+                       collectionQName = null;
+                       if (GetAttribute<CollectionDataContractAttribute> (contractType) != null)
+                               return false;
+
+                       var type = contractType;
+                       if (type.IsGenericType)
+                               type = type.GetGenericTypeDefinition ();
+
+                       var elementType = GetCollectionElementType (contractType);
+                       if (elementType == null)
+                               return false;
+                       
+                       if (contractType.IsArray) {
+                               if (!graphType.IsArray || !elementType.Equals (graphType.GetElementType ()))
+                                       throw new InvalidCastException (String.Format ("Type '{0}' cannot be converted into '{1}'.", graphType.GetElementType (), elementType));
+                       } else if (!contractType.IsInterface) {
+                               if (GetAttribute<SerializableAttribute> (contractType) == null)
+                                       return false;
+
+                               var graphElementType = GetCollectionElementType (graphType);
+                               if (elementType != graphElementType)
+                                       return false;
+
+                               if (!IsCustomCollectionType (contractType, elementType))
+                                       return false;
+                       } else if (type.Equals (typeof (IEnumerable)) || type.Equals (typeof (IList)) ||
+                                  type.Equals (typeof (ICollection))) {
+                               if (!graphType.ImplementsInterface (contractType))
+                                       return false;
+                       } else if (type.Equals (typeof (IEnumerable<>)) || type.Equals (typeof (IList<>)) ||
+                                  type.Equals (typeof (ICollection<>))) {
+                               var graphElementType = GetCollectionElementType (graphType);
+                               if (graphElementType != elementType)
+                                       throw new InvalidCastException (String.Format (
+                                               "Cannot convert type '{0}' into '{1}'.", graphType, contractType));
+
+                               if (!graphType.ImplementsInterface (contractType))
+                                       return false;
+                       } else {
+                               return false;
+                       }
+
+                       collectionQName = GetCollectionQName (elementType);
+                       return true;
+               }
+
+               static bool ImplementsInterface (Type type, Type iface)
+               {
+                       foreach (var i in type.GetInterfacesOrSelfInterface ())
+                               if (iface == i)
+                                       return true;
+                                       
+                       return false;
+               }
+
+
+               static bool TypeImplementsIEnumerable (Type type)
+               {
+                       foreach (var iface in type.GetInterfacesOrSelfInterface ())
+                               if (iface == typeof (IEnumerable) || (iface.IsGenericType && iface.GetGenericTypeDefinition () == typeof (IEnumerable<>)))
+                                       return true;
+                       
+                       return false;
+               }
+
                static bool TypeImplementsIDictionary (Type type)
                {
                        foreach (var iface in type.GetInterfacesOrSelfInterface ())