Merge branch 'cecil-light'
[mono.git] / mcs / tools / tuner / Mono.Tuner / InjectSecurityAttributes.cs
1 //
2 // InjectSecurityAttributes.cs
3 //
4 // Author:
5 //   Jb Evain (jbevain@novell.com)
6 //
7 // (C) 2009 Novell, Inc.
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 //
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 //
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28
29 using System;
30 using System.Collections;
31 using System.IO;
32 using System.Linq;
33 using System.Text;
34
35 using Mono.Linker;
36 using Mono.Linker.Steps;
37
38 using Mono.Cecil;
39
40 namespace Mono.Tuner {
41
42         public class InjectSecurityAttributes : BaseStep {
43
44                 enum TargetKind {
45                         Type,
46                         Method,
47                 }
48
49                 protected enum AttributeType {
50                         Critical,
51                         SafeCritical,
52                 }
53
54                 const string _safe_critical = "System.Security.SecuritySafeCriticalAttribute";
55                 const string _critical = "System.Security.SecurityCriticalAttribute";
56
57                 const string sec_attr_folder = "secattrs";
58
59                 protected AssemblyDefinition _assembly;
60
61                 MethodDefinition _safe_critical_ctor;
62                 MethodDefinition _critical_ctor;
63
64                 string data_folder;
65
66                 protected override bool ConditionToProcess ()
67                 {
68                         if (!Context.HasParameter (sec_attr_folder)) {
69                                 Console.Error.WriteLine ("Warning: no secattrs folder specified.");
70                                 return false;
71                         }
72
73                         data_folder = Context.GetParameter (sec_attr_folder);
74                         return true;
75                 }
76
77                 protected override void ProcessAssembly (AssemblyDefinition assembly)
78                 {
79                         if (Annotations.GetAction (assembly) != AssemblyAction.Link)
80                                 return;
81
82                         string secattr_file = Path.Combine (
83                                 data_folder,
84                                 assembly.Name.Name + ".secattr");
85
86                         if (!File.Exists (secattr_file)) {
87                                 Console.Error.WriteLine ("Warning: file '{0}' not found, skipping.", secattr_file);
88                                 return;
89                         }
90
91                         _assembly = assembly;
92
93                         // remove existing [SecurityCritical] and [SecuritySafeCritical]
94                         RemoveSecurityAttributes ();
95
96                         // add [SecurityCritical] and [SecuritySafeCritical] from the data file
97                         ProcessSecurityAttributeFile (secattr_file);
98                 }
99
100                 protected void RemoveSecurityAttributes ()
101                 {
102                         foreach (TypeDefinition type in _assembly.MainModule.Types) {
103                                 RemoveSecurityAttributes (type);
104
105                                 if (type.HasMethods)
106                                         foreach (MethodDefinition method in type.Methods)
107                                                 RemoveSecurityAttributes (method);
108                         }
109                 }
110
111                 static void RemoveSecurityDeclarations (ISecurityDeclarationProvider provider)
112                 {
113                         // also remove already existing CAS security declarations
114
115                         if (provider == null)
116                                 return;
117
118                         if (!provider.HasSecurityDeclarations)
119                                 return;
120
121                         provider.SecurityDeclarations.Clear ();
122                 }
123
124                 static void RemoveSecurityAttributes (ICustomAttributeProvider provider)
125                 {
126                         RemoveSecurityDeclarations (provider as ISecurityDeclarationProvider);
127
128                         if (!provider.HasCustomAttributes)
129                                 return;
130
131                         var attributes = provider.CustomAttributes;
132                         for (int i = 0; i < attributes.Count; i++) {
133                                 CustomAttribute attribute = attributes [i];
134                                 switch (attribute.Constructor.DeclaringType.FullName) {
135                                 case _safe_critical:
136                                 case _critical:
137                                         attributes.RemoveAt (i--);
138                                         break;
139                                 }
140                         }
141                 }
142
143                 void ProcessSecurityAttributeFile (string file)
144                 {
145                         using (StreamReader reader = File.OpenText (file)) {
146                                 string line;
147                                 while ((line = reader.ReadLine ()) != null)
148                                         ProcessLine (line);
149                         }
150                 }
151
152                 void ProcessLine (string line)
153                 {
154                         if (line == null || line.Length < 6)
155                                 return;
156
157                         int sep = line.IndexOf (": ");
158                         if (sep == -1)
159                                 return;
160
161                         string marker = line.Substring (0, sep);
162                         string target = line.Substring (sep + 2);
163
164                         ProcessSecurityAttributeEntry (
165                                 DecomposeAttributeType (marker),
166                                 DecomposeTargetKind (marker),
167                                 target);
168                 }
169
170                 static AttributeType DecomposeAttributeType (string marker)
171                 {
172                         if (marker.StartsWith ("SC"))
173                                 return AttributeType.Critical;
174                         else if (marker.StartsWith ("SSC"))
175                                 return AttributeType.SafeCritical;
176                         else
177                                 throw new ArgumentException ();
178                 }
179
180                 static TargetKind DecomposeTargetKind (string marker)
181                 {
182                         switch (marker [marker.Length - 1]) {
183                         case 'T':
184                                 return TargetKind.Type;
185                         case 'M':
186                                 return TargetKind.Method;
187                         default:
188                                 throw new ArgumentException ();
189                         }
190                 }
191
192                 void ProcessSecurityAttributeEntry (AttributeType type, TargetKind kind, string target)
193                 {
194                         ICustomAttributeProvider provider = GetTarget (kind, target);
195                         if (provider == null)
196                                 return;
197
198                         switch (type) {
199                         case AttributeType.Critical:
200                                 AddCriticalAttribute (provider);
201                                 break;
202                         case AttributeType.SafeCritical:
203                                 AddSafeCriticalAttribute (provider);
204                                 break;
205                         }
206                 }
207
208                 protected void AddCriticalAttribute (ICustomAttributeProvider provider)
209                 {
210                         // a [SecurityCritical] replaces a [SecuritySafeCritical]
211                         if (HasSecurityAttribute (provider, AttributeType.SafeCritical))
212                                 RemoveSecurityAttributes (provider);
213
214                         AddSecurityAttribute (provider, AttributeType.Critical);
215                 }
216
217                 void AddSafeCriticalAttribute (ICustomAttributeProvider provider)
218                 {
219                         // a [SecuritySafeCritical] is ignored if a [SecurityCritical] is present
220                         if (HasSecurityAttribute (provider, AttributeType.Critical))
221                                 return;
222
223                         AddSecurityAttribute (provider, AttributeType.SafeCritical);
224                 }
225
226                 void AddSecurityAttribute (ICustomAttributeProvider provider, AttributeType type)
227                 {
228                         if (HasSecurityAttribute (provider, type))
229                                 return;
230
231                         var attributes = provider.CustomAttributes;
232                         switch (type) {
233                         case AttributeType.Critical:
234                                 attributes.Add (CreateCriticalAttribute ());
235                                 break;
236                         case AttributeType.SafeCritical:
237                                 attributes.Add (CreateSafeCriticalAttribute ());
238                                 break;
239                         }
240                 }
241
242                 protected static bool HasSecurityAttribute (ICustomAttributeProvider provider, AttributeType type)
243                 {
244                         if (!provider.HasCustomAttributes)
245                                 return false;
246
247                         foreach (CustomAttribute attribute in provider.CustomAttributes) {
248                                 switch (attribute.Constructor.DeclaringType.Name) {
249                                 case _critical:
250                                         if (type == AttributeType.Critical)
251                                                 return true;
252
253                                         break;
254                                 case _safe_critical:
255                                         if (type == AttributeType.SafeCritical)
256                                                 return true;
257
258                                         break;
259                                 }
260                         }
261
262                         return false;
263                 }
264
265                 ICustomAttributeProvider GetTarget (TargetKind kind, string target)
266                 {
267                         switch (kind) {
268                         case TargetKind.Type:
269                                 return GetType (target);
270                         case TargetKind.Method:
271                                 return GetMethod (target);
272                         default:
273                                 throw new ArgumentException ();
274                         }
275                 }
276
277                 TypeDefinition GetType (string fullname)
278                 {
279                         return _assembly.MainModule.GetType (fullname);
280                 }
281
282                 MethodDefinition GetMethod (string signature)
283                 {
284                         int pos = signature.IndexOf (" ");
285                         if (pos == -1)
286                                 throw new ArgumentException ();
287
288                         string tmp = signature.Substring (pos + 1);
289
290                         pos = tmp.IndexOf ("::");
291                         if (pos == -1)
292                                 throw new ArgumentException ();
293
294                         string type_name = tmp.Substring (0, pos);
295
296                         int parpos = tmp.IndexOf ("(");
297                         if (parpos == -1)
298                                 throw new ArgumentException ();
299
300                         string method_name = tmp.Substring (pos + 2, parpos - pos - 2);
301
302                         TypeDefinition type = GetType (type_name);
303                         if (type == null)
304                                 return null;
305
306                         return GetMethod (type.Methods, signature);
307                 }
308
309                 static MethodDefinition GetMethod (IEnumerable methods, string signature)
310                 {
311                         foreach (MethodDefinition method in methods)
312                                 if (GetFullName (method) == signature)
313                                         return method;
314
315                         return null;
316                 }
317
318                 static string GetFullName (MethodReference method)
319                 {
320                         var sentinel = method.Parameters.FirstOrDefault (p => p.ParameterType.IsSentinel);
321                         var sentinel_pos = -1;
322                         if (sentinel != null)
323                                 sentinel_pos = method.Parameters.IndexOf (sentinel);
324
325                         StringBuilder sb = new StringBuilder ();
326                         sb.Append (method.ReturnType.FullName);
327                         sb.Append (" ");
328                         sb.Append (method.DeclaringType.FullName);
329                         sb.Append ("::");
330                         sb.Append (method.Name);
331                         if (method.HasGenericParameters) {
332                                 sb.Append ("<");
333                                 for (int i = 0; i < method.GenericParameters.Count; i++ ) {
334                                         if (i > 0)
335                                                 sb.Append (",");
336                                         sb.Append (method.GenericParameters [i].Name);
337                                 }
338                                 sb.Append (">");
339                         }
340                         sb.Append ("(");
341                         if (method.HasParameters) {
342                                 for (int i = 0; i < method.Parameters.Count; i++) {
343                                         if (i > 0)
344                                                 sb.Append (",");
345
346                                         if (i == sentinel_pos)
347                                                 sb.Append ("...,");
348
349                                         sb.Append (method.Parameters [i].ParameterType.FullName);
350                                 }
351                         }
352                         sb.Append (")");
353                         return sb.ToString ();
354                 }
355
356                 static MethodDefinition GetDefaultConstructor (TypeDefinition type)
357                 {
358                         foreach (MethodDefinition ctor in type.Methods.Where (m => m.IsConstructor))
359                                 if (ctor.Parameters.Count == 0)
360                                         return ctor;
361
362                         return null;
363                 }
364
365                 MethodDefinition GetSafeCriticalCtor ()
366                 {
367                         if (_safe_critical_ctor != null)
368                                 return _safe_critical_ctor;
369
370                         TypeDefinition safe_critical_type = Context.GetType (_safe_critical);
371                         if (safe_critical_type == null)
372                                 throw new InvalidOperationException (String.Format ("{0} type not found", _safe_critical));
373
374                         _safe_critical_ctor = GetDefaultConstructor (safe_critical_type);
375                         return _safe_critical_ctor;
376                 }
377
378                 MethodDefinition GetCriticalCtor ()
379                 {
380                         if (_critical_ctor != null)
381                                 return _critical_ctor;
382
383                         TypeDefinition critical_type = Context.GetType (_critical);
384                         if (critical_type == null)
385                                 throw new InvalidOperationException (String.Format ("{0} type not found", _critical));
386
387                         _critical_ctor = GetDefaultConstructor (critical_type);
388                         return _critical_ctor;
389                 }
390
391                 MethodReference Import (MethodDefinition method)
392                 {
393                         return _assembly.MainModule.Import (method);
394                 }
395
396                 CustomAttribute CreateSafeCriticalAttribute ()
397                 {
398                         return new CustomAttribute (Import (GetSafeCriticalCtor ()));
399                 }
400
401                 CustomAttribute CreateCriticalAttribute ()
402                 {
403                         return new CustomAttribute (Import (GetCriticalCtor ()));
404                 }
405         }
406 }