Merge pull request #4621 from alexanderkyte/strdup_env
[mono.git] / mcs / class / System.Data.Linq / src / DbLinq / Data / Linq / DataContext.cs
index bb49b9e35d2d6a7c8fb9a5ae86289d5573c7f003..db68b96777b3473655a9a97785e4452b77f7157f 100644 (file)
@@ -192,12 +192,12 @@ namespace DbLinq.Data.Linq
         /// </summary>\r
         /// <param name="connectionString">specifies file or server connection</param>\r
         [DbLinqToDo]\r
-        public DataContext(string connectionString)\r
+        public DataContext(string fileOrServerOrConnection)\r
         {\r
             Profiler.At("START DataContext(string)");\r
-            IVendor ivendor = GetVendor(ref connectionString);\r
+            IVendor ivendor = GetVendor(ref fileOrServerOrConnection);\r
 \r
-            IDbConnection dbConnection = ivendor.CreateDbConnection(connectionString);\r
+            IDbConnection dbConnection = ivendor.CreateDbConnection(fileOrServerOrConnection);\r
             Init(new DatabaseContext(dbConnection), null, ivendor);\r
 \r
             Profiler.At("END DataContext(string)");\r
@@ -236,18 +236,18 @@ namespace DbLinq.Data.Linq
             System.Text.RegularExpressions.Regex reProvider\r
                 = new System.Text.RegularExpressions.Regex(@"DbLinqProvider=([\w\.]+);?");\r
 \r
-            string assemblyFile = null;\r
+            string assemblyName = null;\r
             string vendor;\r
             if (!reProvider.IsMatch(connectionString))\r
             {\r
                 vendor       = "SqlServer";\r
-                assemblyFile = "DbLinq.SqlServer.dll";\r
+                assemblyName = "DbLinq.SqlServer";\r
             }\r
             else\r
             {\r
                 var match    = reProvider.Match(connectionString);\r
                 vendor       = match.Groups[1].Value;\r
-                assemblyFile = "DbLinq." + vendor + ".dll";\r
+                assemblyName = "DbLinq." + vendor;\r
 \r
                 //plain DbLinq - non MONO: \r
                 //IVendor classes are in DLLs such as "DbLinq.MySql.dll"\r
@@ -268,16 +268,15 @@ namespace DbLinq.Data.Linq
 #if MONO_STRICT\r
                 assembly = typeof (DataContext).Assembly; // System.Data.Linq.dll\r
 #else\r
-                //TODO: check if DLL is already loaded?\r
-                assembly = Assembly.LoadFrom(assemblyFile);\r
+                assembly = Assembly.Load(assemblyName);\r
 #endif\r
             }\r
             catch (Exception e)\r
             {\r
                 throw new ArgumentException(\r
                         string.Format(\r
-                            "Unable to load the `{0}' DbLinq vendor within assembly `{1}'.",\r
-                            assemblyFile, vendor),\r
+                            "Unable to load the `{0}' DbLinq vendor within assembly '{1}.dll'.",\r
+                            assemblyName, vendor),\r
                         "connectionString", e);\r
             }\r
         }\r
@@ -295,7 +294,11 @@ namespace DbLinq.Data.Linq
             _VendorProvider = ObjectFactory.Get<IVendorProvider>();\r
             Vendor = vendor ?? \r
                 (connectionString != null ? GetVendor(ref connectionString) : null) ??\r
+#if MOBILE\r
+                _VendorProvider.FindVendorByProviderType(typeof(DbLinq.Sqlite.SqliteSqlProvider));\r
+#else\r
                 _VendorProvider.FindVendorByProviderType(typeof(SqlClient.Sql2005Provider));\r
+#endif\r
             \r
             DatabaseContext = databaseContext;\r
 \r
@@ -394,57 +397,77 @@ namespace DbLinq.Data.Linq
             if (this.objectTrackingEnabled == false)\r
                 throw new InvalidOperationException("Object tracking is not enabled for the current data context instance.");\r
             using (DatabaseContext.OpenConnection()) //ConnMgr will close connection for us\r
-            using (IDatabaseTransaction transaction = DatabaseContext.Transaction())\r
             {\r
-                var queryContext = new QueryContext(this);\r
-\r
-                // There's no sense in updating an entity when it's going to \r
-                // be deleted in the current transaction, so do deletes first.\r
-                foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll().ToList())\r
+                if (Transaction != null)\r
+                    SubmitChangesImpl(failureMode);\r
+                else\r
                 {\r
-                    switch (entityTrack.EntityState)\r
+                    using (IDbTransaction transaction = DatabaseContext.CreateTransaction())\r
                     {\r
-                        case EntityState.ToDelete:\r
-                            var deleteQuery = QueryBuilder.GetDeleteQuery(entityTrack.Entity, queryContext);\r
-                            QueryRunner.Delete(entityTrack.Entity, deleteQuery);\r
-\r
-                            UnregisterDelete(entityTrack.Entity);\r
-                            AllTrackedEntities.RegisterToDelete(entityTrack.Entity);\r
-                            AllTrackedEntities.RegisterDeleted(entityTrack.Entity);\r
-                            break;\r
-                        default:\r
-                            // ignore.\r
-                            break;\r
+                        try\r
+                        {\r
+                            Transaction = (DbTransaction) transaction;\r
+                            SubmitChangesImpl(failureMode);\r
+                            // TODO: handle conflicts (which can only occur when concurrency mode is implemented)\r
+                            transaction.Commit();\r
+                        }\r
+                        finally\r
+                        {\r
+                            Transaction = null;\r
+                        }\r
                     }\r
                 }\r
-                foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll()\r
-                        .Concat(AllTrackedEntities.EnumerateAll())\r
-                        .ToList())\r
+            }\r
+        }\r
+\r
+        void SubmitChangesImpl(ConflictMode failureMode)\r
+        {\r
+            var queryContext = new QueryContext(this);\r
+\r
+            // There's no sense in updating an entity when it's going to \r
+            // be deleted in the current transaction, so do deletes first.\r
+            foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll().ToList())\r
+            {\r
+                switch (entityTrack.EntityState)\r
                 {\r
-                    switch (entityTrack.EntityState)\r
-                    {\r
-                        case EntityState.ToInsert:\r
-                            foreach (var toInsert in GetReferencedObjects(entityTrack.Entity))\r
-                            {\r
-                                InsertEntity(toInsert, queryContext);\r
-                            }\r
-                            break;\r
-                        case EntityState.ToWatch:\r
-                            foreach (var toUpdate in GetReferencedObjects(entityTrack.Entity))\r
-                            {\r
-                                UpdateEntity(toUpdate, queryContext);\r
-                            }\r
-                            break;\r
-                        default:\r
-                            throw new ArgumentOutOfRangeException();\r
-                    }\r
+                    case EntityState.ToDelete:\r
+                        var deleteQuery = QueryBuilder.GetDeleteQuery(entityTrack.Entity, queryContext);\r
+                        QueryRunner.Delete(entityTrack.Entity, deleteQuery);\r
+\r
+                        UnregisterDelete(entityTrack.Entity);\r
+                        AllTrackedEntities.RegisterToDelete(entityTrack.Entity);\r
+                        AllTrackedEntities.RegisterDeleted(entityTrack.Entity);\r
+                        break;\r
+                    default:\r
+                        // ignore.\r
+                        break;\r
+                }\r
+            }\r
+            foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll()\r
+                    .Concat(AllTrackedEntities.EnumerateAll())\r
+                    .ToList())\r
+            {\r
+                switch (entityTrack.EntityState)\r
+                {\r
+                    case EntityState.ToInsert:\r
+                        foreach (var toInsert in GetReferencedObjects(entityTrack.Entity))\r
+                        {\r
+                            InsertEntity(toInsert, queryContext);\r
+                        }\r
+                        break;\r
+                    case EntityState.ToWatch:\r
+                        foreach (var toUpdate in GetReferencedObjects(entityTrack.Entity))\r
+                        {\r
+                            UpdateEntity(toUpdate, queryContext);\r
+                        }\r
+                        break;\r
+                    default:\r
+                        throw new ArgumentOutOfRangeException();\r
                 }\r
-                // TODO: handle conflicts (which can only occur when concurrency mode is implemented)\r
-                transaction.Commit();\r
             }\r
         }\r
 \r
-        private static IEnumerable<object> GetReferencedObjects(object value)\r
+        private IEnumerable<object> GetReferencedObjects(object value)\r
         {\r
             var values = new EntitySet<object>();\r
             FillReferencedObjects(value, values);\r
@@ -452,34 +475,35 @@ namespace DbLinq.Data.Linq
         }\r
 \r
         // Breadth-first traversal of an object graph\r
-        private static void FillReferencedObjects(object value, EntitySet<object> values)\r
+        private void FillReferencedObjects(object parent, EntitySet<object> values)\r
         {\r
-            if (value == null)\r
+            if (parent == null)\r
                 return;\r
-            values.Add(value);\r
-            var children = new List<object>();\r
-            foreach (var p in value.GetType().GetProperties())\r
-            {\r
-                var type = p.PropertyType.IsGenericType\r
-                    ? p.PropertyType.GetGenericTypeDefinition()\r
-                    : null;\r
-                if (type != null && p.CanRead && type == typeof(EntitySet<>) &&\r
-                        p.GetGetMethod().GetParameters().Length == 0)\r
-                {\r
-                    var set = p.GetValue(value, null);\r
-                    if (set == null)\r
-                        continue;\r
-                    var hasLoadedOrAssignedValues = p.PropertyType.GetProperty("HasLoadedOrAssignedValues");\r
-                    if (!((bool)hasLoadedOrAssignedValues.GetValue(set, null)))\r
-                        continue;   // execution deferred; ignore.\r
-                    foreach (var o in ((IEnumerable)set))\r
-                        children.Add(o);\r
+            var children = new Queue<object>();\r
+                       children.Enqueue(parent);\r
+                       while (children.Count > 0)\r
+                       {\r
+                object value = children.Dequeue();\r
+                values.Add(value);\r
+                IEnumerable<MetaAssociation> associationList = Mapping.GetMetaType(value.GetType()).Associations.Where(a => !a.IsForeignKey);\r
+                if (associationList.Any())\r
+                           {\r
+                                   foreach (MetaAssociation association in associationList)\r
+                    {\r
+                        var memberData = association.ThisMember;\r
+                        var entitySetValue = memberData.Member.GetMemberValue(value);\r
+\r
+                        if (entitySetValue != null)\r
+                        {\r
+                                                   var hasLoadedOrAssignedValues = entitySetValue.GetType().GetProperty("HasLoadedOrAssignedValues");\r
+                                                   if (!((bool)hasLoadedOrAssignedValues.GetValue(entitySetValue, null)))\r
+                                                           continue;   // execution deferred; ignore.\r
+                                                   foreach (var o in ((IEnumerable)entitySetValue))\r
+                                                           children.Enqueue(o);\r
+                                           }\r
+                    }\r
                 }\r
-            }\r
-            foreach (var c in children)\r
-            {\r
-                FillReferencedObjects(c, values);\r
-            }\r
+                       }\r
         }\r
 \r
         private void InsertEntity(object entity, QueryContext queryContext)\r
@@ -487,7 +511,7 @@ namespace DbLinq.Data.Linq
             var insertQuery = QueryBuilder.GetInsertQuery(entity, queryContext);\r
             QueryRunner.Insert(entity, insertQuery);\r
             Register(entity);\r
-            UpdateReferencedObjects(entity, AutoSync.OnInsert);\r
+            UpdateReferencedObjects(entity);\r
             MoveToAllTrackedEntities(entity, true);\r
         }\r
 \r
@@ -502,26 +526,28 @@ namespace DbLinq.Data.Linq
                 QueryRunner.Update(entity, updateQuery, modifiedMembers);\r
 \r
                 RegisterUpdateAgain(entity);\r
-                UpdateReferencedObjects(entity, AutoSync.OnUpdate);\r
+                UpdateReferencedObjects(entity);\r
                 MoveToAllTrackedEntities(entity, false);\r
             }\r
         }\r
 \r
-        private void UpdateReferencedObjects(object root, AutoSync sync)\r
+        private void UpdateReferencedObjects(object root)\r
         {\r
             var metaType = Mapping.GetMetaType(root.GetType());\r
             foreach (var assoc in metaType.Associations)\r
             {\r
                 var memberData = assoc.ThisMember;\r
-                if (memberData.Association.ThisKey.Any(m => m.AutoSync != sync))\r
-                    continue;\r
+                               //This is not correct - AutoSyncing applies to auto-updating columns, such as a TimeStamp, not to foreign key associations, which is always automatically synched\r
+                               //Confirmed against default .NET l2sql - association columns are always set, even if AutoSync==AutoSync.Never\r
+                               //if (memberData.Association.ThisKey.Any(m => (m.AutoSync != AutoSync.Always) && (m.AutoSync != sync)))\r
+                //    continue;\r
                 var oks = memberData.Association.OtherKey.Select(m => m.StorageMember).ToList();\r
                 if (oks.Count == 0)\r
                     continue;\r
                 var pks = memberData.Association.ThisKey\r
                     .Select(m => m.StorageMember.GetMemberValue(root))\r
                     .ToList();\r
-                if (pks.Count != pks.Count)\r
+                if (pks.Count != oks.Count)\r
                     throw new InvalidOperationException(\r
                         string.Format("Count of primary keys ({0}) doesn't match count of other keys ({1}).",\r
                             pks.Count, oks.Count));\r
@@ -701,14 +727,27 @@ namespace DbLinq.Data.Linq
                                        //it would be interesting surround the above query with a .Take(1) expression for performance.\r
                                }\r
 \r
+                               // If no separate Storage is specified, use the member directly\r
+                               MemberInfo storage = memberData.StorageMember;\r
+                               if (storage == null)\r
+                                       storage = memberData.Member;\r
+\r
+                                // Check that the storage is a field or a writable property\r
+                               if (!(storage is FieldInfo) && !(storage is PropertyInfo && ((PropertyInfo)storage).CanWrite)) {\r
+                                       throw new InvalidOperationException(String.Format(\r
+                                               "Member {0}.{1} is not a field nor a writable property",\r
+                                               storage.DeclaringType, storage.Name));\r
+                               }\r
+\r
+                               Type storageType = storage.GetMemberType();\r
 \r
-                               FieldInfo entityRefField = (FieldInfo)memberData.StorageMember;\r
                                object entityRefValue = null;\r
                                if (query != null)\r
-                                       entityRefValue = Activator.CreateInstance(entityRefField.FieldType, query);\r
+                                       entityRefValue = Activator.CreateInstance(storageType, query);\r
                                else\r
-                                       entityRefValue = Activator.CreateInstance(entityRefField.FieldType);\r
-                               entityRefField.SetValue(entity, entityRefValue);\r
+                                       entityRefValue = Activator.CreateInstance(storageType);\r
+\r
+                               storage.SetMemberValue(entity, entityRefValue);\r
                        }\r
                }\r
 \r
@@ -791,18 +830,17 @@ namespace DbLinq.Data.Linq
             }\r
         }\r
 \r
-        private object GetOtherTableQuery(Expression predicate, ParameterExpression parameter, Type otherTableType, IQueryable otherTable)\r
+               private static MethodInfo _WhereMethod;\r
+        internal object GetOtherTableQuery(Expression predicate, ParameterExpression parameter, Type otherTableType, IQueryable otherTable)\r
         {\r
+            if (_WhereMethod == null)\r
+                System.Threading.Interlocked.CompareExchange (ref _WhereMethod, typeof(Queryable).GetMethods().First(m => m.Name == "Where"), null);\r
+\r
             //predicate: other.EmployeeID== "WARTH"\r
             Expression lambdaPredicate = Expression.Lambda(predicate, parameter);\r
             //lambdaPredicate: other=>other.EmployeeID== "WARTH"\r
 \r
-            var whereMethod = typeof(Queryable)\r
-                              .GetMethods().First(m => m.Name == "Where")\r
-                              .MakeGenericMethod(otherTableType);\r
-\r
-\r
-            Expression call = Expression.Call(whereMethod, otherTable.Expression, lambdaPredicate);\r
+                       Expression call = Expression.Call(_WhereMethod.MakeGenericMethod(otherTableType), otherTable.Expression, lambdaPredicate);\r
             //Table[EmployeesTerritories].Where(other=>other.employeeID="WARTH")\r
 \r
             return otherTable.Provider.CreateQuery(call);\r
@@ -821,12 +859,7 @@ namespace DbLinq.Data.Linq
             CurrentTransactionEntities.RegisterToInsert(entity);\r
         }\r
 \r
-        /// <summary>\r
-        /// Registers an entity for update\r
-        /// The entity will be updated only if some of its members have changed after the registration\r
-        /// </summary>\r
-        /// <param name="entity"></param>\r
-        internal void RegisterUpdate(object entity)\r
+        private void DoRegisterUpdate(object entity)\r
         {\r
             if (entity == null)\r
                 throw new ArgumentNullException("entity");\r
@@ -843,6 +876,17 @@ namespace DbLinq.Data.Linq
             AllTrackedEntities.RegisterToWatch(entity, identityKey);\r
         }\r
 \r
+        /// <summary>\r
+        /// Registers an entity for update\r
+        /// The entity will be updated only if some of its members have changed after the registration\r
+        /// </summary>\r
+        /// <param name="entity"></param>\r
+        internal void RegisterUpdate(object entity)\r
+        {\r
+            DoRegisterUpdate(entity);\r
+                       MemberModificationHandler.Register(entity, Mapping);\r
+        }\r
+\r
         /// <summary>\r
         /// Registers or re-registers an entity and clears its state\r
         /// </summary>\r
@@ -868,7 +912,7 @@ namespace DbLinq.Data.Linq
         {\r
             if (!this.objectTrackingEnabled)\r
                 return;\r
-            RegisterUpdate(entity);\r
+            DoRegisterUpdate(entity);\r
             MemberModificationHandler.Register(entity, entityOriginalState, Mapping);\r
         }\r
 \r
@@ -950,7 +994,7 @@ namespace DbLinq.Data.Linq
         /// <summary>\r
         /// Execute raw SQL query and return object\r
         /// </summary>\r
-        public IEnumerable<TResult> ExecuteQuery<TResult>(string query, params object[] parameters) where TResult : class, new()\r
+        public IEnumerable<TResult> ExecuteQuery<TResult>(string query, params object[] parameters) where TResult : new()\r
         {\r
             if (query == null)\r
                 throw new ArgumentNullException("query");\r
@@ -959,7 +1003,7 @@ namespace DbLinq.Data.Linq
         }\r
 \r
         private IEnumerable<TResult> CreateExecuteQueryEnumerable<TResult>(string query, object[] parameters)\r
-            where TResult : class, new()\r
+            where TResult : new()\r
         {\r
             foreach (TResult result in ExecuteQuery(typeof(TResult), query, parameters))\r
                 yield return result;\r
@@ -987,7 +1031,10 @@ namespace DbLinq.Data.Linq
                        set { throw new NotImplementedException(); }\r
                }\r
 \r
-        public DbTransaction Transaction { get; set; }\r
+        public DbTransaction Transaction {\r
+            get { return (DbTransaction) DatabaseContext.CurrentTransaction; }\r
+            set { DatabaseContext.CurrentTransaction = value; }\r
+        }\r
 \r
         /// <summary>\r
         /// Runs the given reader and returns columns.\r
@@ -1027,6 +1074,9 @@ namespace DbLinq.Data.Linq
         {\r
             //connection closing should not be done here.\r
             //read: http://msdn2.microsoft.com/en-us/library/bb292288.aspx\r
+\r
+                       //We own the instance of MemberModificationHandler - we must unregister listeners of entities we attached to\r
+                       MemberModificationHandler.UnregisterAll();\r
         }\r
 \r
         [DbLinqToDo]\r