codegen: patch method calls on-demand via traps
authorBernhard Urban <lewurm@gmail.com>
Sun, 8 Apr 2012 18:21:56 +0000 (20:21 +0200)
committerBernhard Urban <lewurm@gmail.com>
Sun, 8 Apr 2012 18:26:10 +0000 (20:26 +0200)
we can determine the source of an invalid memory access via unix signal
handling. to do so, we write
 > mov (Addr 0) eax    ; 0x8905 0000 0000
which tries to access memory at address 0. upon first execution of this
instruction the signalhandler is called. there, we replace it with
 > nop                 ; 0x90
 > call <target>       ; 0xe8 YYYY YYYY  ; Y = target

at the moment, this just works with Fib.fib() (or other recursive methods)
as the <target> address is more or less hardcoded.

several TODOs:
- determine address of target method in a different way
- after a call, we have to throw away arguments of the call.
  this is hardcoded now.

Mate/X86CodeGen.hs
tests/Fib.java
trap.c

index b9f314fe08b6017c8fe00fd3eac34b95460b024b..7f79fe233c98f18d7c8cbf1742418aae261746e0 100644 (file)
@@ -26,14 +26,38 @@ import Mate.BasicBlocks
 foreign import ccall "dynamic"
    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
 
+foreign import ccall "getaddr"
+  getaddr :: CUInt
+
+foreign import ccall "callertrap"
+  callertrap :: IO ()
+
+foreign import ccall "register_signal"
+  register_signal :: IO ()
+
 test_01, test_02, test_03 :: IO ()
 test_01 = do
-  _ <- testCase "./tests/Fib.class" "fib"
-  return ()
+  register_signal
+  (entry, end) <- testCase "./tests/Fib.class" "fib"
+  let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
+
+  mapM_ (\(x,entryFuncPtr) -> do
+    result <- code_int entryFuncPtr (fromIntegral x) (fromIntegral 0)
+    let iresult :: Int; iresult = fromIntegral result
+    let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
+    printf "result of fib(%2d): %3d\t\t%s\n" x iresult kk
+    ) $ zip ([0..10] :: [Int]) (repeat entryFuncPtr)
+  printf "patched disasm:\n"
+  Right newdisasm <- disassembleBlock entry end
+  mapM_ (putStrLn . showAtt) newdisasm
+  where
+    fib n
+      | n <= 1 = 1
+      | otherwise = (fib (n - 1)) + (fib (n - 2))
 
 
 test_02 = do
-  entry <- testCase "./tests/While.class" "f"
+  (entry,_) <- testCase "./tests/While.class" "f"
   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
   result <- code_int entryFuncPtr (fromIntegral 5) (fromIntegral 4)
   let iresult :: Int; iresult = fromIntegral result
@@ -47,7 +71,7 @@ test_02 = do
 
 
 test_03 = do
-  entry <- testCase "./tests/While.class" "g"
+  (entry,_) <- testCase "./tests/While.class" "g"
   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
   result <- code_int entryFuncPtr (fromIntegral 5) (fromIntegral 4)
   let iresult :: Int; iresult = fromIntegral result
@@ -60,7 +84,7 @@ test_03 = do
   printf "result of g(4,3): %3d\t\t%s\n" iresult kk
 
 
-testCase :: String -> B.ByteString -> IO (Ptr Word8)
+testCase :: String -> B.ByteString -> IO (Ptr Word8, Int)
 testCase cf method = do
       hmap <- parseMethod cf method
       printMapBB hmap
@@ -68,14 +92,14 @@ testCase cf method = do
         Nothing -> error "sorry, no code generation"
         Just hmap -> do
               let ebb = emitFromBB hmap
-              (_, Right ((entry, bbstarts), disasm)) <- runCodeGen ebb () ()
+              (_, Right ((entry, bbstarts, end), disasm)) <- runCodeGen ebb () ()
               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
               printf "disasm:\n"
               mapM_ (putStrLn . showAtt) disasm
               printf "basicblocks addresses:\n"
               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
-              return entry
+              return (entry, end)
 
 type EntryPoint = Ptr Word8
 type EntryPointOffset = Int
@@ -83,7 +107,7 @@ type PatchInfo = (BlockID, EntryPointOffset)
 
 type BBStarts = M.Map BlockID Int
 
-type CompileInfo = (EntryPoint, BBStarts)
+type CompileInfo = (EntryPoint, BBStarts, Int)
 
 emitFromBB :: MapBB -> CodeGen e s (CompileInfo, [Instruction])
 emitFromBB hmap =  do
@@ -92,12 +116,21 @@ emitFromBB hmap =  do
         ep <- getEntryPoint
         push ebp
         mov ebp esp
+
+        -- TODO(bernhard): remove me. just for PoC here
+        ep <- getEntryPoint
+        let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
+        push w32_ep
+        -- '5' is the size of the `call' instruction ( + immediate)
+        calladdr <- getCodeOffset
+        let w32_calladdr = 5 + w32_ep + (fromIntegral calladdr) :: Word32
+        let trapaddr = (fromIntegral getaddr :: Word32)
+        call (trapaddr - w32_calladdr)
+
         bbstarts <- efBB (0,(hmap M.! 0)) M.empty lmap
-        mov esp ebp
-        pop ebp
-        ret
         d <- disassemble
-        return ((ep, bbstarts), d)
+        end <- getCodeOffset
+        return ((ep, bbstarts, end), d)
   where
   getLabel :: BlockID -> [(BlockID, Label)] -> Label
   getLabel _ [] = error "label not found!"
@@ -139,7 +172,7 @@ emitFromBB hmap =  do
     emit (IF_ICMP cond _) = do
         pop eax -- value2
         pop ebx -- value1
-        cmp eax ebx -- intel syntax is swapped (TODO(bernhard): test that plz)
+        cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
         let sid = case successor bb of TwoTarget _ t -> t
         let l = getLabel sid lmap
         case cond of
@@ -160,8 +193,20 @@ emitFromBB hmap =  do
     emit (GOTO _ ) = do
         let sid = case successor bb of OneTarget t -> t
         jmp $ getLabel sid lmap
-
-    emit IRETURN = do pop eax
+    emit (INVOKESTATIC x) = do
+        -- TODO(bernhard): get and save information about this call
+        -- TODO(bernhard): better try SIGILL instead of SIGSEGV?
+        mov (Addr 0) eax
+        -- discard arguments (TODO(bernhard): don't hardcode it)
+        add esp (4 :: Word32)
+        -- push result on stack (TODO(bernhard): if any)
+        push eax
+
+    emit IRETURN = do
+        pop eax
+        mov esp ebp
+        pop ebp
+        ret
     emit _ = do cmovbe eax eax -- dummy
 
   cArgs x = (8 + 4 * (fromIntegral x))
index b91671634269ef3e2f4a45ea60b9740123d10997..0ae8017c7c1e3eb729477172342c60e48fb38290 100644 (file)
@@ -8,6 +8,7 @@ public class Fib
 
        public static void main(String[] args)
        {
-               fib(10);
+               for (int i = 0; i < 10; i++)
+                       System.out.println(i + ": " + fib(i));
        }
 }
diff --git a/trap.c b/trap.c
index a1e5ddc8696df244d61b3628aee5631cb04d63f6..73b4256da5933fd094405d9bda8b853e03a05380 100644 (file)
--- a/trap.c
+++ b/trap.c
@@ -1,17 +1,50 @@
 #include <stdio.h>
+#include <stdlib.h>
+#include <signal.h>
+#include <asm/ucontext.h>
 
-void callertrap(void)
+unsigned int patchme = 0;
+void print_foo(unsigned int addr)
 {
-       char buf[5];
-       unsigned int *ptr = (unsigned int) (buf + 1);
+       // printf("\n\nprint foo: 0x%08x\n", addr);
+       patchme = addr;
+}
+
+void callertrap(int nSignal, siginfo_t *info, void *ctx)
+{
+       struct ucontext *uctx = (struct ucontext *) ctx;
+
+       printf("callertrap(mctx)  by 0x%08x\n", uctx->uc_mcontext.eip);
+       // printf("callertrap(addr)  by 0x%08x\n", info->si_addr);
+       // printf("callertrap(*esp)  by 0x%08x\n", * (unsigned int *) uctx->uc_mcontext.esp);
 
-       printf("callertrap by 0x%08x\n", *(ptr + 4));
-       /* TODO:
-        * call magic haskell function
-        * with environment information */
+       unsigned int *to_patch = (unsigned int *) (uctx->uc_mcontext.eip + 2);
+       unsigned char *insn = (unsigned int *) (uctx->uc_mcontext.eip);
+       *insn = 0x90; // nop
+       insn++;
+       *insn = 0xe8; // call
+       printf(" to_patch: 0x%08x\n", to_patch);
+       printf("*to_patch: 0x%08x\n", *to_patch);
+       if (*to_patch != 0x00000000) {
+               printf("something is wrong here. abort\n");
+               exit(0);
+       }
+       *to_patch = (unsigned int) patchme - ((unsigned int) insn + 5);
+       printf("*to_patch: 0x%08x\n", *to_patch);
+       uctx->uc_mcontext.eip = insn;
+       // while (1) ;
+}
+
+void register_signal(void)
+{
+       struct sigaction segvaction;
+       segvaction.sa_sigaction = callertrap;
+       sigemptyset(&segvaction.sa_mask);
+       segvaction.sa_flags = SA_SIGINFO | SA_RESTART;
+       sigaction(SIGSEGV, &segvaction, NULL);
 }
 
 unsigned int getaddr(void)
 {
-       return (unsigned int) callertrap;
+       return (unsigned int) print_foo;
 }