[Mono.Unix] Fix crasher in StringToHeap (#5639)
[mono.git] / mcs / class / Mono.Posix / Mono.Unix / UnixMarshal.cs
index 334b6f2a49b42c089ea2b1492109021cbc75628a..7d39fd4c43ffad92bcf2f284a542482b63f0c4eb 100644 (file)
@@ -309,6 +309,9 @@ namespace Mono.Unix {
 
                public static IntPtr StringToHeap (string s, Encoding encoding)
                {
+                       if (s == null)
+                               return IntPtr.Zero;
+
                        return StringToHeap (s, 0, s.Length, encoding);
                }
 
@@ -325,27 +328,42 @@ namespace Mono.Unix {
                        if (encoding == null)
                                throw new ArgumentNullException ("encoding");
 
-                       int min_byte_count = encoding.GetMaxByteCount(1);
-                       char[] copy = s.ToCharArray (index, count);
-                       byte[] marshal = new byte [encoding.GetByteCount (copy) + min_byte_count];
+                       if (index < 0 || count < 0)
+                               throw new ArgumentOutOfRangeException ((index < 0 ? "index" : "count"),
+                                        "Non - negative number required.");
 
-                       int bytes_copied = encoding.GetBytes (copy, 0, copy.Length, marshal, 0);
+                       if (s.Length - index < count)
+                               throw new ArgumentOutOfRangeException ("s", "Index and count must refer to a location within the string.");
 
-                       if (bytes_copied != (marshal.Length-min_byte_count))
-                               throw new NotSupportedException ("encoding.GetBytes() doesn't equal encoding.GetByteCount()!");
+                       int null_terminator_count = encoding.GetMaxByteCount (1);
+                       int length_without_null = encoding.GetByteCount (s);
+                       int marshalLength = checked (length_without_null + null_terminator_count);
 
-                       IntPtr mem = AllocHeap (marshal.Length);
+                       IntPtr mem = AllocHeap (marshalLength);
                        if (mem == IntPtr.Zero)
                                throw new UnixIOException (Native.Errno.ENOMEM);
 
-                       bool copied = false;
-                       try {
-                               Marshal.Copy (marshal, 0, mem, marshal.Length);
-                               copied = true;
-                       }
-                       finally {
-                               if (!copied)
-                                       FreeHeap (mem);
+                       unsafe {
+                               fixed (char* p = s) {
+                                       byte* marshal = (byte*)mem;
+                                       int bytes_copied;
+
+                                       try {
+                                               bytes_copied = encoding.GetBytes (p + index, count, marshal, marshalLength);
+                                       } catch {
+                                               FreeHeap (mem);
+                                               throw;
+                                       }
+
+                                       if (bytes_copied != length_without_null) {
+                                               FreeHeap (mem);
+                                               throw new NotSupportedException ("encoding.GetBytes() doesn't equal encoding.GetByteCount()!");
+                                       }
+
+                                       marshal += length_without_null;
+                                       for (int i = 0; i < null_terminator_count; ++i)
+                                               marshal[i] = 0;
+                               }
                        }
 
                        return mem;