Merge pull request #498 from Unroll-Me/master
[mono.git] / mcs / class / System.Core / System.Linq.Expressions / EmitContext.cs
index 221571dff419032f071ec8951a19c3fd0fd46f8b..a43b87f728bf0f2ad672b30c2e2b2edea1a2b808 100644 (file)
@@ -27,6 +27,7 @@
 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 //
 
+#if !FULL_AOT_RUNTIME
 using System;
 using System.Collections.ObjectModel;
 using System.Collections.Generic;
@@ -40,8 +41,106 @@ namespace System.Linq.Expressions {
 
        class CompilationContext {
 
+               class ParameterReplacer : ExpressionTransformer {
+
+                       CompilationContext context;
+                       ExecutionScope scope;
+                       object [] locals;
+
+                       public ParameterReplacer (CompilationContext context, ExecutionScope scope, object [] locals)
+                       {
+                               this.context = context;
+                               this.scope = scope;
+                               this.locals = locals;
+                       }
+
+                       protected override Expression VisitParameter (ParameterExpression parameter)
+                       {
+                               var scope = this.scope;
+                               var locals = this.locals;
+
+                               while (scope != null) {
+                                       int position = IndexOfHoistedLocal (scope, parameter);
+                                       if (position != -1)
+                                               return ReadHoistedLocalFromArray (locals, position);
+
+                                       locals = scope.Locals;
+                                       scope = scope.Parent;
+                               }
+
+                               return parameter;
+                       }
+
+                       Expression ReadHoistedLocalFromArray (object [] locals, int position)
+                       {
+                               return Expression.Field (
+                                       Expression.Convert (
+                                               Expression.ArrayIndex (
+                                                       Expression.Constant (locals),
+                                                       Expression.Constant (position)),
+                                               locals [position].GetType ()),
+                                       "Value");
+                       }
+
+                       int IndexOfHoistedLocal (ExecutionScope scope, ParameterExpression parameter)
+                       {
+                               return context.units [scope.compilation_unit].IndexOfHoistedLocal (parameter);
+                       }
+               }
+
+               class HoistedVariableDetector : ExpressionVisitor {
+
+                       Dictionary<ParameterExpression, LambdaExpression> parameter_to_lambda =
+                               new Dictionary<ParameterExpression, LambdaExpression> ();
+
+                       Dictionary<LambdaExpression, List<ParameterExpression>> hoisted_map;
+
+                       LambdaExpression lambda;
+
+                       public Dictionary<LambdaExpression, List<ParameterExpression>> Process (LambdaExpression lambda)
+                       {
+                               Visit (lambda);
+                               return hoisted_map;
+                       }
+
+                       protected override void VisitLambda (LambdaExpression lambda)
+                       {
+                               this.lambda = lambda;
+                               foreach (var parameter in lambda.Parameters)
+                                       parameter_to_lambda [parameter] = lambda;
+                               base.VisitLambda (lambda);
+                       }
+
+                       protected override void VisitParameter (ParameterExpression parameter)
+                       {
+                               if (lambda.Parameters.Contains (parameter))
+                                       return;
+
+                               Hoist (parameter);
+                       }
+
+                       void Hoist (ParameterExpression parameter)
+                       {
+                               LambdaExpression lambda;
+                               if (!parameter_to_lambda.TryGetValue (parameter, out lambda))
+                                       return;
+
+                               if (hoisted_map == null)
+                                       hoisted_map = new Dictionary<LambdaExpression, List<ParameterExpression>> ();
+
+                               List<ParameterExpression> hoisted;
+                               if (!hoisted_map.TryGetValue (lambda, out hoisted)) {
+                                       hoisted = new List<ParameterExpression> ();
+                                       hoisted_map [lambda] = hoisted;
+                               }
+
+                               hoisted.Add (parameter);
+                       }
+               }
+
                List<object> globals = new List<object> ();
                List<EmitContext> units = new List<EmitContext> ();
+               Dictionary<LambdaExpression, List<ParameterExpression>> hoisted_map;
 
                public int AddGlobal (object global)
                {
@@ -61,12 +160,44 @@ namespace System.Linq.Expressions {
 
                public int AddCompilationUnit (LambdaExpression lambda)
                {
-                       var context = new EmitContext (this, lambda);
+                       DetectHoistedVariables (lambda);
+                       return AddCompilationUnit (null, lambda);
+               }
+
+               public int AddCompilationUnit (EmitContext parent, LambdaExpression lambda)
+               {
+                       var context = new EmitContext (this, parent, lambda);
                        var unit = AddItemToList (context, units);
                        context.Emit ();
                        return unit;
                }
 
+               void DetectHoistedVariables (LambdaExpression lambda)
+               {
+                       hoisted_map = new HoistedVariableDetector ().Process (lambda);
+               }
+
+               public List<ParameterExpression> GetHoistedLocals (LambdaExpression lambda)
+               {
+                       if (hoisted_map == null)
+                               return null;
+
+                       List<ParameterExpression> hoisted;
+                       hoisted_map.TryGetValue (lambda, out hoisted);
+                       return hoisted;
+               }
+
+               public object [] CreateHoistedLocals (int unit)
+               {
+                       var hoisted = GetHoistedLocals (units [unit].Lambda);
+                       return new object [hoisted == null ? 0 : hoisted.Count];
+               }
+
+               public Expression IsolateExpression (ExecutionScope scope, object [] locals, Expression expression)
+               {
+                       return new ParameterReplacer (this, scope, locals).Transform (expression);
+               }
+
                public Delegate CreateDelegate ()
                {
                        return CreateDelegate (0, new ExecutionScope (this));
@@ -80,29 +211,49 @@ namespace System.Linq.Expressions {
 
        class EmitContext {
 
-               LambdaExpression owner;
                CompilationContext context;
+               EmitContext parent;
+               LambdaExpression lambda;
                DynamicMethod method;
+               LocalBuilder hoisted_store;
+               List<ParameterExpression> hoisted;
+
+               public readonly ILGenerator ig;
 
-               public ILGenerator ig;
+               public bool HasHoistedLocals {
+                       get { return hoisted != null && hoisted.Count > 0; }
+               }
+
+               public LambdaExpression Lambda {
+                       get { return lambda; }
+               }
 
-               public EmitContext (CompilationContext context, LambdaExpression lambda)
+               public EmitContext (CompilationContext context, EmitContext parent, LambdaExpression lambda)
                {
                        this.context = context;
-                       this.owner = lambda;
+                       this.parent = parent;
+                       this.lambda = lambda;
+                       this.hoisted = context.GetHoistedLocals (lambda);
 
-                       method = new DynamicMethod ("lambda_method", owner.GetReturnType (),
-                               CreateParameterTypes (owner.Parameters), typeof (ExecutionScope), true);
+                       method = new DynamicMethod (
+                               "lambda_method",
+                               lambda.GetReturnType (),
+                               CreateParameterTypes (lambda.Parameters),
+                               typeof (ExecutionScope),
+                               true);
 
                        ig = method.GetILGenerator ();
                }
 
                public void Emit ()
                {
-                       owner.EmitBody (this);
+                       if (HasHoistedLocals)
+                               EmitStoreHoistedLocals ();
+
+                       lambda.EmitBody (this);
                }
 
-               static Type [] CreateParameterTypes (ReadOnlyCollection<ParameterExpression> parameters)
+               static Type [] CreateParameterTypes (IList<ParameterExpression> parameters)
                {
                        var types = new Type [parameters.Count + 1];
                        types [0] = typeof (ExecutionScope);
@@ -113,18 +264,20 @@ namespace System.Linq.Expressions {
                        return types;
                }
 
-               public int GetParameterPosition (ParameterExpression p)
+               public bool IsLocalParameter (ParameterExpression parameter, ref int position)
                {
-                       int position = owner.Parameters.IndexOf (p);
-                       if (position == -1)
-                               throw new InvalidOperationException ("Parameter not in scope");
+                       position = lambda.Parameters.IndexOf (parameter);
+                       if (position > -1) {
+                               position++;
+                               return true;
+                       }
 
-                       return position + 1; // + 1 because 0 is the ExecutionScope
+                       return false;
                }
 
                public Delegate CreateDelegate (ExecutionScope scope)
                {
-                       return method.CreateDelegate (owner.Type, scope);
+                       return method.CreateDelegate (lambda.Type, scope);
                }
 
                public void Emit (Expression expression)
@@ -146,8 +299,25 @@ namespace System.Linq.Expressions {
                        ig.Emit (OpCodes.Ldloca, EmitStored (expression));
                }
 
+               public void EmitLoadEnum (Expression expression)
+               {
+                       expression.Emit (this);
+                       ig.Emit (OpCodes.Box, expression.Type);
+               }
+
+               public void EmitLoadEnum (LocalBuilder local)
+               {
+                       ig.Emit (OpCodes.Ldloc, local);
+                       ig.Emit (OpCodes.Box, local.LocalType);
+               }
+
                public void EmitLoadSubject (Expression expression)
                {
+                       if (expression.Type.IsEnum) {
+                               EmitLoadEnum (expression);
+                               return;
+                       }
+
                        if (expression.Type.IsValueType) {
                                EmitLoadAddress (expression);
                                return;
@@ -158,6 +328,11 @@ namespace System.Linq.Expressions {
 
                public void EmitLoadSubject (LocalBuilder local)
                {
+                       if (local.LocalType.IsEnum) {
+                               EmitLoadEnum (local);
+                               return;
+                       }
+
                        if (local.LocalType.IsValueType) {
                                EmitLoadAddress (local);
                                return;
@@ -176,7 +351,7 @@ namespace System.Linq.Expressions {
                        ig.Emit (OpCodes.Ldloc, local);
                }
 
-               public void EmitCall (LocalBuilder local, ReadOnlyCollection<Expression> arguments, MethodInfo method)
+               public void EmitCall (LocalBuilder local, IList<Expression> arguments, MethodInfo method)
                {
                        EmitLoadSubject (local);
                        EmitArguments (method, arguments);
@@ -197,7 +372,7 @@ namespace System.Linq.Expressions {
                        EmitCall (method);
                }
 
-               public void EmitCall (Expression expression, ReadOnlyCollection<Expression> arguments, MethodInfo method)
+               public void EmitCall (Expression expression, IList<Expression> arguments, MethodInfo method)
                {
                        if (!method.IsStatic)
                                EmitLoadSubject (expression);
@@ -206,7 +381,7 @@ namespace System.Linq.Expressions {
                        EmitCall (method);
                }
 
-               void EmitArguments (MethodInfo method, ReadOnlyCollection<Expression> arguments)
+               void EmitArguments (MethodInfo method, IList<Expression> arguments)
                {
                        var parameters = method.GetParameters ();
 
@@ -302,15 +477,25 @@ namespace System.Linq.Expressions {
                        EmitReadGlobal (global, global.GetType ());
                }
 
-               public void EmitReadGlobal (object global, Type type)
+               public void EmitLoadGlobals ()
                {
                        EmitScope ();
 
                        ig.Emit (OpCodes.Ldfld, typeof (ExecutionScope).GetField ("Globals"));
+               }
+
+               public void EmitReadGlobal (object global, Type type)
+               {
+                       EmitLoadGlobals ();
 
                        ig.Emit (OpCodes.Ldc_I4, AddGlobal (global, type));
                        ig.Emit (OpCodes.Ldelem, typeof (object));
 
+                       EmitLoadStrongBoxValue (type);
+               }
+
+               public void EmitLoadStrongBoxValue (Type type)
+               {
                        var strongbox = type.MakeStrongBoxType ();
 
                        ig.Emit (OpCodes.Isinst, strongbox);
@@ -327,16 +512,92 @@ namespace System.Linq.Expressions {
                        EmitScope ();
 
                        ig.Emit (OpCodes.Ldc_I4, AddChildContext (lambda));
-                       ig.Emit (OpCodes.Ldnull);
+                       if (hoisted_store != null)
+                               ig.Emit (OpCodes.Ldloc, hoisted_store);
+                       else
+                               ig.Emit (OpCodes.Ldnull);
 
                        ig.Emit (OpCodes.Callvirt, typeof (ExecutionScope).GetMethod ("CreateDelegate"));
 
                        ig.Emit (OpCodes.Castclass, lambda.Type);
                }
 
+               void EmitStoreHoistedLocals ()
+               {
+                       EmitHoistedLocalsStore ();
+                       for (int i = 0; i < hoisted.Count; i++)
+                               EmitStoreHoistedLocal (i, hoisted [i]);
+               }
+
+               void EmitStoreHoistedLocal (int position, ParameterExpression parameter)
+               {
+                       ig.Emit (OpCodes.Ldloc, hoisted_store);
+                       ig.Emit (OpCodes.Ldc_I4, position);
+                       parameter.Emit (this);
+                       EmitCreateStrongBox (parameter.Type);
+                       ig.Emit (OpCodes.Stelem, typeof (object));
+               }
+
+               public void EmitLoadHoistedLocalsStore ()
+               {
+                       ig.Emit (OpCodes.Ldloc, hoisted_store);
+               }
+
+               void EmitCreateStrongBox (Type type)
+               {
+                       ig.Emit (OpCodes.Newobj, type.MakeStrongBoxType ().GetConstructor (new [] { type }));
+               }
+
+               void EmitHoistedLocalsStore ()
+               {
+                       EmitScope ();
+                       hoisted_store = ig.DeclareLocal (typeof (object []));
+                       ig.Emit (OpCodes.Callvirt, typeof (ExecutionScope).GetMethod ("CreateHoistedLocals"));
+                       ig.Emit (OpCodes.Stloc, hoisted_store);
+               }
+
+               public void EmitLoadLocals ()
+               {
+                       ig.Emit (OpCodes.Ldfld, typeof (ExecutionScope).GetField ("Locals"));
+               }
+
+               public void EmitParentScope ()
+               {
+                       ig.Emit (OpCodes.Ldfld, typeof (ExecutionScope).GetField ("Parent"));
+               }
+
+               public void EmitIsolateExpression ()
+               {
+                       ig.Emit (OpCodes.Callvirt, typeof (ExecutionScope).GetMethod ("IsolateExpression"));
+               }
+
+               public int IndexOfHoistedLocal (ParameterExpression parameter)
+               {
+                       if (!HasHoistedLocals)
+                               return -1;
+
+                       return hoisted.IndexOf (parameter);
+               }
+
+               public bool IsHoistedLocal (ParameterExpression parameter, ref int level, ref int position)
+               {
+                       if (parent == null)
+                               return false;
+
+                       if (parent.hoisted != null) {
+                               position = parent.hoisted.IndexOf (parameter);
+                               if (position > -1)
+                                       return true;
+                       }
+
+                       level++;
+
+                       return parent.IsHoistedLocal (parameter, ref level, ref position);
+               }
+
                int AddChildContext (LambdaExpression lambda)
                {
-                       return context.AddCompilationUnit (lambda);
+                       return context.AddCompilationUnit (this, lambda);
                }
 
                static object CreateStrongBox (object value, Type type)
@@ -346,3 +607,4 @@ namespace System.Linq.Expressions {
                }
        }
 }
+#endif