Merge branch 'master' of github.com:mono/mono
[mono.git] / mcs / class / Mono.CodeContracts / Mono.CodeContracts.Rewrite / PerformRewrite.cs
1 //
2 // PerformRewrite.cs
3 //
4 // Authors:
5 //      Chris Bacon (chrisbacon76@gmail.com)
6 //
7 // Copyright (C) 2010 Chris Bacon
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 // 
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 // 
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28
29 using System;
30 using System.Collections.Generic;
31 using System.Linq;
32 using System.Text;
33 using Mono.Cecil;
34 using Mono.Cecil.Cil;
35 using Mono.CodeContracts.Rewrite.Ast;
36 using Mono.CodeContracts.Rewrite.AstVisitors;
37
namespace Mono.CodeContracts.Rewrite {
	/// <summary>
	/// Runs the contract-transforming visitor over every method body of an assembly, module by
	/// module. When <c>options.Level</c> &gt;= 2 it also copies the rewritten <c>Requires</c>
	/// checks of the root method of an override chain into each method that overrides it.
	/// </summary>
	class PerformRewrite {

		/// <param name="sym">Symbol writer that receives each transformed method body; may be <c>null</c>.</param>
		/// <param name="options">Rewriter options; rewritten contracts are only inserted when <c>options.Level</c> &gt;= 2.</param>
		public PerformRewrite (ISymbolWriter sym, RewriterOptions options)
		{
			this.sym = sym;
			this.options = options;
		}

		private readonly ISymbolWriter sym;
		private readonly RewriterOptions options;
		// Methods already processed, mapped to the visitor that transformed them.
		// The value is null for methods without a body (e.g. abstract methods).
		private readonly Dictionary<MethodDefinition, TransformContractsVisitor> rewrittenMethods = new Dictionary<MethodDefinition, TransformContractsVisitor> ();

		/// <summary>
		/// Rewrites every method of every type in each module of <paramref name="assembly"/>.
		/// </summary>
		public void Rewrite (AssemblyDefinition assembly)
		{
			foreach (ModuleDefinition module in assembly.Modules) {
				ContractsRuntime contractsRuntime = new ContractsRuntime (module, this.options);

				var allMethods =
					from type in module.Types.Cast<TypeDefinition> ()
					from method in type.Methods.Cast<MethodDefinition> ()
					select method;

				// Materialize the method list up front; presumably rewriting can add members to
				// the module, which must not disturb this enumeration.
				foreach (MethodDefinition method in allMethods.ToArray ()) {
					this.RewriteMethod (module, method, contractsRuntime);
				}
			}
		}

		/// <summary>
		/// Rewrites the contracts of <paramref name="method"/>. It first processes the method
		/// it overrides (if that method is in the same module), so that the root of the
		/// override chain has been transformed and its preconditions can be inherited.
		/// Each method is processed at most once. Methods that had any Requires rewritten
		/// or inherited are printed to the console.
		/// </summary>
		private void RewriteMethod (ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)
		{
			if (this.rewrittenMethods.ContainsKey (method)) {
				return;
			}
			var overridden = this.GetOverriddenMethod (method);
			// Only recurse into methods of the module being rewritten. A base method in a
			// referenced assembly (e.g. System.Object.ToString) must not have its body
			// rewritten or written to this module's symbol file.
			if (overridden != null && overridden.DeclaringType.Module == module) {
				this.RewriteMethod (module, overridden, contractsRuntime);
			}
			bool anyRewrites = false;
			var baseMethod = this.GetBaseOverriddenMethod (method);
			// Contract inheritance: copy the Requires of the root of the override chain.
			// Only insert re-written contracts if level >= 2, and only when there is a body
			// to insert into (abstract overrides have none).
			if (baseMethod != method && method.HasBody && this.options.Level >= 2) {
				TransformContractsVisitor vOverriddenTransform;
				// The root is missing from the map when it lives in another module, and its
				// visitor is null when it has no body (abstract). Either way there is nothing to inherit.
				if (this.rewrittenMethods.TryGetValue (baseMethod, out vOverriddenTransform) && vOverriddenTransform != null) {
					// Every RewriteIL call without a 'remove' expression inserts before the
					// current first instruction. Insert in reverse so the inherited checks end
					// up in their declared order, because a later precondition may rely on an
					// earlier one.
					foreach (var inheritedRequires in Enumerable.Reverse (vOverriddenTransform.ContractRequiresInfo)) {
						this.RewriteIL (method.Body, null, null, inheritedRequires.RewrittenExpr);
						anyRewrites = true;
					}
				}
			}

			TransformContractsVisitor vTransform = null;
			if (method.HasBody) {
				vTransform = this.TransformContracts (module, method, contractsRuntime);
				if (this.sym != null) {
					this.sym.Write (method.Body);
				}
				if (vTransform.ContractRequiresInfo.Any ()) {
					anyRewrites = true;
				}
			}
			this.rewrittenMethods.Add (method, vTransform);

			if (anyRewrites) {
				Console.WriteLine (method);
			}
		}

		/// <summary>
		/// Decompiles the body of <paramref name="method"/> and runs
		/// <see cref="TransformContractsVisitor"/> over it. For each Requires found, it
		/// removes the original IL. When <c>options.Level</c> &gt;= 2 it also compiles the
		/// rewritten expression in its place.
		/// </summary>
		/// <returns>The visitor, which exposes the Requires it found via <c>ContractRequiresInfo</c>.</returns>
		private TransformContractsVisitor TransformContracts (ModuleDefinition module, MethodDefinition method, ContractsRuntime contractsRuntime)
		{
			var body = method.Body;
			Decompile decompile = new Decompile (module, method);
			var decomp = decompile.Go ();

			TransformContractsVisitor vTransform = new TransformContractsVisitor (module, method, decompile.Instructions, contractsRuntime);
			vTransform.Visit (decomp);

			foreach (var replacement in vTransform.ContractRequiresInfo) {
				// Only insert re-written contracts if level >= 2
				Expr rewritten = this.options.Level >= 2 ? replacement.RewrittenExpr : null;
				this.RewriteIL (body, decompile.Instructions, replacement.OriginalExpr, rewritten);
			}

			return vTransform;
		}

		/// <summary>
		/// Edits the IL of <paramref name="body"/>. If <paramref name="remove"/> is non-null,
		/// it removes the instructions that make up that expression, and new code is emitted
		/// where they were. Otherwise new code is emitted before the current first instruction.
		/// If <paramref name="insert"/> is non-null, it is compiled and inserted at that point.
		/// </summary>
		/// <param name="instructionLookup">Maps decompiled expressions to their IL instructions. Passed as <c>null</c> for inherited contracts, where nothing is removed.</param>
		private void RewriteIL (MethodBody body, Dictionary<Expr,Instruction> instructionLookup, Expr remove, Expr insert)
		{
			var il = body.CilWorker;
			Instruction instInsertBefore;
			if (remove != null) {
				var vInstExtent = new InstructionExtentVisitor (instructionLookup);
				vInstExtent.Visit (remove);
				// Capture the anchor before removing, since removal unlinks these instructions.
				instInsertBefore = vInstExtent.Instructions.Last ().Next;
				foreach (var instRemove in vInstExtent.Instructions) {
					il.Remove (instRemove);
				}
			} else {
				instInsertBefore = body.Instructions [0];
			}
			if (insert != null) {
				var compiler = new CompileVisitor (il, instructionLookup, inst => il.InsertBefore (instInsertBefore, inst));
				compiler.Visit (insert);
			}
		}

		/// <summary>
		/// Finds the virtual method that <paramref name="method"/> directly overrides by
		/// searching up the base-type chain. Returns <c>null</c> if the method is non-virtual
		/// or <c>newslot</c>, or if no matching base method can be resolved.
		/// </summary>
		private MethodDefinition GetOverriddenMethod (MethodDefinition method)
		{
			if (method.IsNewSlot || !method.IsVirtual) {
				return null;
			}
			// An intermediate class does not have to redeclare the method, so keep walking
			// up the hierarchy until some type declares a matching virtual method.
			TypeReference baseType = method.DeclaringType.BaseType;
			while (baseType != null) {
				TypeDefinition baseTypeDef = baseType.Resolve ();
				if (baseTypeDef == null) {
					// The base type could not be resolved; nothing more can be found.
					return null;
				}
				var overridden = FindOverrideCandidate (baseTypeDef, method);
				if (overridden != null) {
					return overridden;
				}
				baseType = baseTypeDef.BaseType;
			}
			return null;
		}

		/// <summary>
		/// Returns a virtual method declared on <paramref name="type"/> with the same name and
		/// parameter count as <paramref name="method"/>, preferring one whose parameter types
		/// match exactly. If there is no exact match, it returns the first name/arity match, so
		/// that overrides of generic base methods (whose parameter types are written in terms
		/// of generic parameters) are still found. Returns <c>null</c> if there is no candidate.
		/// </summary>
		private static MethodDefinition FindOverrideCandidate (TypeDefinition type, MethodDefinition method)
		{
			MethodDefinition fallback = null;
			foreach (MethodDefinition candidate in type.Methods) {
				if (candidate.Name != method.Name || !candidate.IsVirtual || candidate.Parameters.Count != method.Parameters.Count) {
					continue;
				}
				if (ParameterTypesMatch (candidate, method)) {
					return candidate;
				}
				if (fallback == null) {
					fallback = candidate;
				}
			}
			return fallback;
		}

		/// <summary>
		/// True when the parameters of <paramref name="a"/> and <paramref name="b"/> have the
		/// same full type names, position by position. The caller must already have checked
		/// that both have the same number of parameters.
		/// </summary>
		private static bool ParameterTypesMatch (MethodDefinition a, MethodDefinition b)
		{
			for (int i = 0; i < a.Parameters.Count; i++) {
				if (a.Parameters [i].ParameterType.FullName != b.Parameters [i].ParameterType.FullName) {
					return false;
				}
			}
			return true;
		}

		/// <summary>
		/// Follows <see cref="GetOverriddenMethod"/> repeatedly and returns the root of the
		/// override chain. Returns <paramref name="method"/> itself if it overrides nothing.
		/// </summary>
		private MethodDefinition GetBaseOverriddenMethod (MethodDefinition method)
		{
			var overridden = method;
			while (true) {
				var overriddenTemp = this.GetOverriddenMethod (overridden);
				if (overriddenTemp == null) {
					return overridden;
				}
				overridden = overriddenTemp;
			}
		}

	}
}