2009-05-15 Jb Evain <jbevain@novell.com>
[mono.git] / mcs / tools / tuner / Mono.Tuner / InjectSecurityAttributes.cs
1 //
2 // InjectSecurityAttributes.cs
3 //
4 // Author:
5 //   Jb Evain (jbevain@novell.com)
6 //
7 // (C) 2009 Novell, Inc.
8 //
9 // Permission is hereby granted, free of charge, to any person obtaining
10 // a copy of this software and associated documentation files (the
11 // "Software"), to deal in the Software without restriction, including
12 // without limitation the rights to use, copy, modify, merge, publish,
13 // distribute, sublicense, and/or sell copies of the Software, and to
14 // permit persons to whom the Software is furnished to do so, subject to
15 // the following conditions:
16 //
17 // The above copyright notice and this permission notice shall be
18 // included in all copies or substantial portions of the Software.
19 //
20 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23 // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
24 // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
25 // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
26 // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 //
28
29 using System;
30 using System.Collections;
31 using System.IO;
32 using System.Text;
33
34 using Mono.Linker;
35 using Mono.Linker.Steps;
36
37 using Mono.Cecil;
38
namespace Mono.Tuner {

	// Linker step that replaces the [SecurityCritical] and [SecuritySafeCritical]
	// attributes of linked assemblies with the ones listed in per-assembly data
	// files.
	//
	// Processing is gated on the "secattrs" parameter (see ConditionToProcess),
	// whose value is the folder holding the data files. For each assembly whose
	// action is Link, "<folder>/<assembly name>.secattr" is read; when it does
	// not exist a warning is printed and the assembly is skipped. All existing
	// [SecurityCritical], [SecuritySafeCritical] and CAS security declarations
	// are first stripped from the main module's types, constructors and
	// methods, then the attributes listed in the file are added back.
	//
	// Each data line has the form "<marker>: <target>":
	//   * a marker starting with "SC" means [SecurityCritical], one starting
	//     with "SSC" means [SecuritySafeCritical];
	//   * a marker ending with 'T' targets a type (by full name), one ending
	//     with 'M' targets a method, written in the format built by GetFullName.
	// Lines shorter than 6 characters or without ": " are ignored; any other
	// unrecognized marker throws an ArgumentException.
	public class InjectSecurityAttributes : BaseStep {

		enum TargetKind {
			Type,
			Method,
		}

		enum AttributeType {
			Critical,
			SafeCritical,
		}

		const string _safe_critical = "System.Security.SecuritySafeCriticalAttribute";
		const string _critical = "System.Security.SecurityCriticalAttribute";

		// name of the linker parameter holding the data folder
		const string sec_attr_folder = "secattrs";

		// assembly currently being processed
		AssemblyDefinition _assembly;

		// lazily resolved parameterless constructors of the two attribute types
		MethodDefinition _safe_critical_ctor;
		MethodDefinition _critical_ctor;

		// folder containing the .secattr data files
		string data_folder;

		protected override bool ConditionToProcess ()
		{
			if (!Context.HasParameter (sec_attr_folder))
				return false;

			data_folder = Context.GetParameter (sec_attr_folder);
			return true;
		}

		protected override void ProcessAssembly (AssemblyDefinition assembly)
		{
			if (Annotations.GetAction (assembly) != AssemblyAction.Link)
				return;

			string secattr_file = Path.Combine (
				data_folder,
				assembly.Name.Name + ".secattr");

			if (!File.Exists (secattr_file)) {
				Console.Error.WriteLine ("Warning: file '{0}' not found, skipping.", secattr_file);
				return;
			}

			_assembly = assembly;

			// remove existing [SecurityCritical] and [SecuritySafeCritical]
			RemoveSecurityAttributes ();

			// add [SecurityCritical] and [SecuritySafeCritical] from the data file
			ProcessSecurityAttributeFile (secattr_file);
		}

		// Strips the security attributes and CAS declarations from every type
		// of the main module, and from those types' constructors and methods.
		void RemoveSecurityAttributes ()
		{
			foreach (TypeDefinition type in _assembly.MainModule.Types) {
				RemoveSecurityAttributes (type);

				if (type.HasConstructors)
					foreach (MethodDefinition ctor in type.Constructors)
						RemoveSecurityAttributes (ctor);

				if (type.HasMethods)
					foreach (MethodDefinition method in type.Methods)
						RemoveSecurityAttributes (method);
			}
		}

		// Clears the CAS security declarations of 'provider'. A null provider
		// (the caller passes null when the object is not an IHasSecurity) is
		// ignored.
		static void RemoveSecurityDeclarations (IHasSecurity provider)
		{
			if (provider == null)
				return;

			if (!provider.HasSecurityDeclarations)
				return;

			provider.SecurityDeclarations.Clear ();
		}

		// Removes every [SecurityCritical] / [SecuritySafeCritical] attribute
		// from 'provider', together with its CAS security declarations.
		static void RemoveSecurityAttributes (ICustomAttributeProvider provider)
		{
			RemoveSecurityDeclarations (provider as IHasSecurity);

			if (!provider.HasCustomAttributes)
				return;

			CustomAttributeCollection attributes = provider.CustomAttributes;
			AttributeType ignored;
			for (int i = 0; i < attributes.Count; i++)
				if (TryGetSecurityAttributeType (attributes [i], out ignored))
					attributes.RemoveAt (i--); // re-examine the entry shifted into slot i
		}

		// Classifies 'attribute' by the *full* name of its declaring type.
		// Returns true and sets 'type' for [SecurityCritical] and
		// [SecuritySafeCritical]; returns false for any other attribute.
		// Shared by the removal and lookup paths so both agree on what a
		// security attribute is (comparing the short Name against the
		// full-name constants never matches).
		static bool TryGetSecurityAttributeType (CustomAttribute attribute, out AttributeType type)
		{
			switch (attribute.Constructor.DeclaringType.FullName) {
			case _critical:
				type = AttributeType.Critical;
				return true;
			case _safe_critical:
				type = AttributeType.SafeCritical;
				return true;
			default:
				type = AttributeType.Critical;
				return false;
			}
		}

		// Applies every line of the data file 'file' in order.
		void ProcessSecurityAttributeFile (string file)
		{
			using (StreamReader reader = File.OpenText (file)) {
				string line;
				while ((line = reader.ReadLine ()) != null)
					ProcessLine (line);
			}
		}

		// Parses one "<marker>: <target>" line and applies it. Lines too short
		// to be meaningful, or lacking the ": " separator, are skipped.
		void ProcessLine (string line)
		{
			if (line == null || line.Length < 6)
				return;

			int sep = line.IndexOf (": ", StringComparison.Ordinal);
			if (sep == -1)
				return;

			string marker = line.Substring (0, sep);
			string target = line.Substring (sep + 2);

			ProcessSecurityAttributeEntry (
				DecomposeAttributeType (marker),
				DecomposeTargetKind (marker),
				target);
		}

		// Maps the marker prefix to the attribute to add.
		// Throws ArgumentException for an unknown prefix.
		static AttributeType DecomposeAttributeType (string marker)
		{
			// "SSC..." does not start with "SC", so the two tests are disjoint
			if (marker.StartsWith ("SC", StringComparison.Ordinal))
				return AttributeType.Critical;
			if (marker.StartsWith ("SSC", StringComparison.Ordinal))
				return AttributeType.SafeCritical;

			throw new ArgumentException (
				string.Format ("Unknown security attribute marker '{0}'.", marker), "marker");
		}

		// Maps the last character of the marker to the kind of target.
		// Throws ArgumentException for an unknown or empty marker.
		static TargetKind DecomposeTargetKind (string marker)
		{
			if (marker.Length > 0) {
				switch (marker [marker.Length - 1]) {
				case 'T':
					return TargetKind.Type;
				case 'M':
					return TargetKind.Method;
				}
			}

			throw new ArgumentException (
				string.Format ("Unknown target kind in marker '{0}'.", marker), "marker");
		}

		// Resolves 'target' and adds the requested attribute to it; targets
		// that cannot be resolved in the current assembly are ignored.
		void ProcessSecurityAttributeEntry (AttributeType type, TargetKind kind, string target)
		{
			ICustomAttributeProvider provider = GetTarget (kind, target);
			if (provider == null)
				return;

			switch (type) {
			case AttributeType.Critical:
				AddCriticalAttribute (provider);
				break;
			case AttributeType.SafeCritical:
				AddSafeCriticalAttribute (provider);
				break;
			}
		}

		void AddCriticalAttribute (ICustomAttributeProvider provider)
		{
			// a [SecurityCritical] replaces a [SecuritySafeCritical]
			if (HasSecurityAttribute (provider, AttributeType.SafeCritical))
				RemoveSecurityAttributes (provider);

			AddSecurityAttribute (provider, AttributeType.Critical);
		}

		void AddSafeCriticalAttribute (ICustomAttributeProvider provider)
		{
			// a [SecuritySafeCritical] is ignored if a [SecurityCritical] is present
			if (HasSecurityAttribute (provider, AttributeType.Critical))
				return;

			AddSecurityAttribute (provider, AttributeType.SafeCritical);
		}

		// Adds an attribute of the given kind unless one is already present.
		void AddSecurityAttribute (ICustomAttributeProvider provider, AttributeType type)
		{
			if (HasSecurityAttribute (provider, type))
				return;

			CustomAttributeCollection attributes = provider.CustomAttributes;
			switch (type) {
			case AttributeType.Critical:
				attributes.Add (CreateCriticalAttribute ());
				break;
			case AttributeType.SafeCritical:
				attributes.Add (CreateSafeCriticalAttribute ());
				break;
			}
		}

		// True when 'provider' carries a security attribute of kind 'type'.
		static bool HasSecurityAttribute (ICustomAttributeProvider provider, AttributeType type)
		{
			if (!provider.HasCustomAttributes)
				return false;

			foreach (CustomAttribute attribute in provider.CustomAttributes) {
				AttributeType found;
				if (TryGetSecurityAttributeType (attribute, out found) && found == type)
					return true;
			}

			return false;
		}

		ICustomAttributeProvider GetTarget (TargetKind kind, string target)
		{
			switch (kind) {
			case TargetKind.Type:
				return GetType (target);
			case TargetKind.Method:
				return GetMethod (target);
			default:
				throw new ArgumentException (
					string.Format ("Unknown target kind '{0}'.", kind), "kind");
			}
		}

		// Looks up a type of the main module by full name (null if absent).
		TypeDefinition GetType (string fullname)
		{
			return _assembly.MainModule.Types [fullname];
		}

		// Resolves a method from a signature in the GetFullName format:
		//   "<return type> <declaring type>::<name>[<generic params>](<parameters>)"
		// Returns null when the declaring type or the method cannot be found.
		// Throws ArgumentException when the signature is malformed.
		MethodDefinition GetMethod (string signature)
		{
			int pos = signature.IndexOf (' ');
			if (pos == -1)
				throw new ArgumentException (
					string.Format ("Malformed method signature '{0}'.", signature), "signature");

			string tmp = signature.Substring (pos + 1);

			pos = tmp.IndexOf ("::", StringComparison.Ordinal);
			if (pos == -1)
				throw new ArgumentException (
					string.Format ("Malformed method signature '{0}'.", signature), "signature");

			string type_name = tmp.Substring (0, pos);

			// also rejects a '(' located before the end of "::", which would
			// otherwise make the Substring below fail with an opaque error
			int parpos = tmp.IndexOf ('(');
			if (parpos < pos + 2)
				throw new ArgumentException (
					string.Format ("Malformed method signature '{0}'.", signature), "signature");

			string method_name = tmp.Substring (pos + 2, parpos - pos - 2);

			TypeDefinition type = GetType (type_name);
			if (type == null)
				return null;

			// .ctor and .cctor
			if (method_name.StartsWith (".c", StringComparison.Ordinal))
				return GetMethod (type.Constructors, signature);

			// GetFullName appends "<T,...>" after the name of generic methods,
			// but MethodDefinition.Name does not include it, so a by-name lookup
			// would never find them: compare the signature against every method.
			if (method_name.EndsWith (">", StringComparison.Ordinal))
				return GetMethod (type.Methods, signature);

			return GetMethod (type.Methods.GetMethod (method_name), signature);
		}

		// Returns the first method of 'methods' whose GetFullName equals
		// 'signature', or null.
		static MethodDefinition GetMethod (IEnumerable methods, string signature)
		{
			foreach (MethodDefinition method in methods)
				if (GetFullName (method) == signature)
					return method;

			return null;
		}

		// Builds "<return type> <declaring type>::<name>[<generic params>](<parameters>)",
		// marking the vararg sentinel position with "...".
		static string GetFullName (MethodReference method)
		{
			int sentinel = method.GetSentinel ();

			StringBuilder sb = new StringBuilder ();
			sb.Append (method.ReturnType.ReturnType.FullName);
			sb.Append (" ");
			sb.Append (method.DeclaringType.FullName);
			sb.Append ("::");
			sb.Append (method.Name);
			if (method.HasGenericParameters) {
				sb.Append ("<");
				for (int i = 0; i < method.GenericParameters.Count; i++ ) {
					if (i > 0)
						sb.Append (",");
					sb.Append (method.GenericParameters [i].Name);
				}
				sb.Append (">");
			}
			sb.Append ("(");
			if (method.HasParameters) {
				for (int i = 0; i < method.Parameters.Count; i++) {
					if (i > 0)
						sb.Append (",");

					if (i == sentinel)
						sb.Append ("...,");

					sb.Append (method.Parameters [i].ParameterType.FullName);
				}
			}
			sb.Append (")");
			return sb.ToString ();
		}

		// Returns the parameterless *instance* constructor of 'type', or null
		// when 'type' is null or has none. Static constructors (.cctor) also
		// take no parameters and must not be picked.
		static MethodDefinition GetDefaultConstructor (TypeDefinition type)
		{
			if (type == null)
				return null;

			foreach (MethodDefinition ctor in type.Constructors)
				if (!ctor.IsStatic && ctor.Parameters.Count == 0)
					return ctor;

			return null;
		}

		// Resolves the attribute type 'fullname' through the linker context
		// and returns its parameterless constructor. Fails with a descriptive
		// error instead of a null reference when it cannot be found.
		MethodDefinition GetAttributeConstructor (string fullname)
		{
			MethodDefinition ctor = GetDefaultConstructor (Context.GetType (fullname));
			if (ctor == null)
				throw new InvalidOperationException (
					string.Format ("Could not find a parameterless constructor for '{0}'.", fullname));

			return ctor;
		}

		MethodDefinition GetSafeCriticalCtor ()
		{
			if (_safe_critical_ctor == null)
				_safe_critical_ctor = GetAttributeConstructor (_safe_critical);

			return _safe_critical_ctor;
		}

		MethodDefinition GetCriticalCtor ()
		{
			if (_critical_ctor == null)
				_critical_ctor = GetAttributeConstructor (_critical);

			return _critical_ctor;
		}

		// Imports 'method' into the module being processed.
		MethodReference Import (MethodDefinition method)
		{
			return _assembly.MainModule.Import (method);
		}

		CustomAttribute CreateSafeCriticalAttribute ()
		{
			return new CustomAttribute (Import (GetSafeCriticalCtor ()));
		}

		CustomAttribute CreateCriticalAttribute ()
		{
			return new CustomAttribute (Import (GetCriticalCtor ()));
		}
	}
}