invokevirtual: get the actual class at run-time
authorBernhard Urban <lewurm@gmail.com>
Wed, 25 Apr 2012 19:17:33 +0000 (21:17 +0200)
committerBernhard Urban <lewurm@gmail.com>
Wed, 25 Apr 2012 21:53:09 +0000 (23:53 +0200)
we don't know the actual class in the CodeGen Monad, so we have to
deduce it at run-time. we use the method-table-ptr for that, since
it has an unique address which we can use to map the actual class.

Makefile
Mate/ClassPool.hs
Mate/MethodPool.hs
Mate/Types.hs
Mate/X86CodeGen.hs
ffi/trap.c
tests/Instance4.java [new file with mode: 0644]

index ceeb809633e338a15e7892e179263a1a1a00d1fd..335313767b402dc43dd49d18818183e062534b3c 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -56,6 +56,8 @@ test: mate $(CLASS_FILES)
        @printf "should be:  0x%08x 0x%08x\n" 0x198 0x22
        ./$< tests/Instance3 | grep mainresult
        @printf "should be:  0x%08x 0x%08x\n" 0x33 0x44
+       ./$< tests/Instance4 | grep mainresult
+       @printf "should be:  0x%08x 0x%08x\n" 0x1337 0x1337
 
 %.class: %.java
        $(JAVAC) $<
index ac8ca411b3ac8d5ab9e720dfdcd7ea8832fc1f02..fd5fc8ebb328bb897e9cf3f15202ef0000701969 100644 (file)
@@ -101,6 +101,10 @@ loadClass path = do
   printf "methodmap: %s @ %s\n" (show methodmap) (toString path)
   printf "mbase: 0x%08x\n" mbase
 
+  virtual_map <- get_virtualmap >>= ptr2virtualmap
+  let virtual_map' = M.insert mbase path virtual_map
+  virtualmap2ptr virtual_map' >>= set_virtualmap
+
   class_map <- get_classmap >>= ptr2classmap
   let new_ci = ClassInfo path cfile staticmap fieldmap methodmap mbase False
   let class_map' = M.insert path new_ci class_map
index 23d8434b62670206a81f492c09caf959e2f0dd7b..e1eb1214323e989ec09f944a9d6bf0252391d119 100644 (file)
@@ -31,45 +31,49 @@ foreign import ccall "dynamic"
    code_void :: FunPtr (IO ()) -> (IO ())
 
 
-foreign export ccall getMethodEntry :: CUInt -> Ptr () -> Ptr () -> IO CUInt
-getMethodEntry :: CUInt -> Ptr () -> Ptr () -> IO CUInt
-getMethodEntry signal_from ptr_mmap ptr_tmap = do
-  mmap <- ptr2mmap ptr_mmap
-  tmap <- ptr2tmap ptr_tmap
+foreign export ccall getMethodEntry :: CUInt -> CUInt -> IO CUInt
+getMethodEntry :: CUInt -> CUInt -> IO CUInt
+getMethodEntry signal_from methodtable = do
+  mmap <- get_methodmap >>= ptr2mmap
+  tmap <- get_trapmap >>= ptr2tmap
+  vmap <- get_virtualmap >>= ptr2virtualmap
 
   let w32_from = fromIntegral signal_from
   let mi = tmap M.! w32_from
-  case mi of
-    (MI mi'@(MethodInfo method cm sig)) -> do
-      case M.lookup mi' mmap of
-        Nothing -> do
-          cls <- getClassFile cm
-          printf "getMethodEntry(from 0x%08x): no method \"%s\" found. compile it\n" w32_from (show mi')
-          mm <- lookupMethodRecursive method [] cls
-          case mm of
-            Just (mm', clsnames, cls') -> do
-                let flags = methodAccessFlags mm'
-                case S.member ACC_NATIVE flags of
-                  False -> do
-                    hmap <- parseMethod cls' method
-                    printMapBB hmap
-                    case hmap of
-                      Just hmap' -> do
-                        entry <- compileBB hmap' (MethodInfo method (thisClass cls') sig)
-                        addMethodRef entry mi' clsnames
-                        return $ fromIntegral entry
-                      Nothing -> error $ (show method) ++ " not found. abort"
-                  True -> do
-                    let symbol = (replace "/" "_" $ toString cm) ++ "__" ++ (toString method) ++ "__" ++ (replace "(" "_" (replace ")" "_" $ toString $ encode sig))
-                    printf "native-call: symbol: %s\n" symbol
-                    nf <- loadNativeFunction symbol
-                    let w32_nf = fromIntegral nf
-                    let mmap' = M.insert mi' w32_nf mmap
-                    mmap2ptr mmap' >>= set_methodmap
-                    return nf
-            Nothing -> error $ (show method) ++ " not found. abort"
-        Just w32 -> return (fromIntegral w32)
-    _ -> error $ "getMethodEntry: no trapInfo. abort"
+  let mi'@(MethodInfo method cm sig) =
+        case mi of
+          (MI x) -> x
+          (VI (MethodInfo methname _ msig)) ->
+              (MethodInfo methname (vmap M.! (fromIntegral methodtable)) msig)
+          _ -> error $ "getMethodEntry: no trapInfo. abort."
+  case M.lookup mi' mmap of
+    Nothing -> do
+      cls <- getClassFile cm
+      printf "getMethodEntry(from 0x%08x): no method \"%s\" found. compile it\n" w32_from (show mi')
+      mm <- lookupMethodRecursive method [] cls
+      case mm of
+        Just (mm', clsnames, cls') -> do
+            let flags = methodAccessFlags mm'
+            case S.member ACC_NATIVE flags of
+              False -> do
+                hmap <- parseMethod cls' method
+                printMapBB hmap
+                case hmap of
+                  Just hmap' -> do
+                    entry <- compileBB hmap' (MethodInfo method (thisClass cls') sig)
+                    addMethodRef entry mi' clsnames
+                    return $ fromIntegral entry
+                  Nothing -> error $ (show method) ++ " not found. abort"
+              True -> do
+                let symbol = (replace "/" "_" $ toString cm) ++ "__" ++ (toString method) ++ "__" ++ (replace "(" "_" (replace ")" "_" $ toString $ encode sig))
+                printf "native-call: symbol: %s\n" symbol
+                nf <- loadNativeFunction symbol
+                let w32_nf = fromIntegral nf
+                let mmap' = M.insert mi' w32_nf mmap
+                mmap2ptr mmap' >>= set_methodmap
+                return nf
+        Nothing -> error $ (show method) ++ " not found. abort"
+    Just w32 -> return (fromIntegral w32)
 
 lookupMethodRecursive :: B.ByteString -> [B.ByteString] -> Class Resolved
                          -> IO (Maybe ((Method Resolved, [B.ByteString], Class Resolved)))
@@ -115,6 +119,7 @@ initMethodPool = do
   mmap2ptr M.empty >>= set_methodmap
   tmap2ptr M.empty >>= set_trapmap
   classmap2ptr M.empty >>= set_classmap
+  virtualmap2ptr M.empty >>= set_virtualmap
 
 
 addMethodRef :: Word32 -> MethodInfo -> [B.ByteString] -> IO ()
index 521a269133c1383a754ec0a547233458c1dd372d..8977143ca1e2b540860e780fe14d17dc86a0518c 100644 (file)
@@ -32,7 +32,10 @@ type MapBB = M.Map BlockID BasicBlock
 -- MethodInfo = relevant information about callee
 type TMap = M.Map Word32 TrapInfo
 
-data TrapInfo = MI MethodInfo | SFI StaticFieldInfo
+data TrapInfo =
+  MI MethodInfo |
+  VI MethodInfo | -- for virtual calls
+  SFI StaticFieldInfo
 
 data StaticFieldInfo = StaticFieldInfo {
   sfiClassName :: B.ByteString,
@@ -46,6 +49,11 @@ type ClassMap = M.Map B.ByteString ClassInfo
 
 type FieldMap = M.Map B.ByteString Int32
 
+-- map "methodtable addr" to "classname"
+-- we need that to identify the actual type
+-- on the invokevirtual insn
+type VirtualMap = M.Map Word32 B.ByteString
+
 data ClassInfo = ClassInfo {
   clName :: B.ByteString,
   clFile :: Class Resolved,
@@ -109,6 +117,12 @@ foreign import ccall "get_classmap"
 foreign import ccall "set_classmap"
   set_classmap :: Ptr () -> IO ()
 
+foreign import ccall "get_virtualmap"
+  get_virtualmap :: IO (Ptr ())
+
+foreign import ccall "set_virtualmap"
+  set_virtualmap :: Ptr () -> IO ()
+
 -- TODO(bernhard): make some typeclass magic 'n stuff
 mmap2ptr :: MMap -> IO (Ptr ())
 mmap2ptr mmap = do
@@ -133,3 +147,11 @@ classmap2ptr cmap = do
 
 ptr2classmap :: Ptr () -> IO ClassMap
 ptr2classmap vmap = deRefStablePtr $ ((castPtrToStablePtr vmap) :: StablePtr cmap)
+
+virtualmap2ptr :: VirtualMap -> IO (Ptr ())
+virtualmap2ptr cmap = do
+  ptr_cmap <- newStablePtr cmap
+  return $ castStablePtrToPtr ptr_cmap
+
+ptr2virtualmap :: Ptr () -> IO VirtualMap
+ptr2virtualmap vmap = deRefStablePtr $ ((castPtrToStablePtr vmap) :: StablePtr cmap)
index 7ddcb6a8b0cd257068f9477092db51d9f731f803..2c4e70e907289e4c87fd8326e6a66b6802697f64 100644 (file)
@@ -193,7 +193,7 @@ emitFromBB method cls hmap =  do
         newNamedLabel (show mi) >>= defineLabel
         -- objref lives somewhere on the argument stack
         mov eax (Disp ((*4) $ fromIntegral $ length args), esp)
-        -- get methodtable ref
+        -- get method-table-ptr
         mov eax (Disp 0, eax)
         -- get method offset
         let nameAndSig = methodname `B.append` (encode msig)
@@ -206,7 +206,10 @@ emitFromBB method cls hmap =  do
         when (argcnt > 0) (add esp argcnt)
         -- push result on stack if method has a return value
         when (methodHaveReturnValue cls cpidx) (push eax)
-        return $ Just $ (calladdr, MI mi)
+        -- note, the "mi" has the wrong class reference here.
+        -- we figure that out at run-time, in the methodpool,
+        -- depending on the method-table-ptr
+        return $ Just $ (calladdr, VI mi)
     emit' (PUTSTATIC cpidx) = do
         pop eax
         trapaddr <- getCurrentOffset
@@ -242,6 +245,7 @@ emitFromBB method cls hmap =  do
         -- set method table pointer
         let mtable = unsafePerformIO $ getMethodTable objname
         mov (Disp 0, eax) mtable
+    emit (CHECKCAST _) = nop -- TODO(bernhard): ...
     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
     emit (ICONST_0) = push (0 :: Word32)
index 0a766094351c10a19affd234a859497882caff05..228b2c36f3e96d1122f9b7ec74ed631ca2d8b5c3 100644 (file)
@@ -19,7 +19,7 @@
 
 #include <sys/ucontext.h>
 
-unsigned int getMethodEntry(unsigned int, void *, void *);
+unsigned int getMethodEntry(unsigned int, unsigned int);
 unsigned int getStaticFieldAddr(unsigned int, void*);
 
 #define NEW_MAP(prefix) \
@@ -38,6 +38,7 @@ unsigned int getStaticFieldAddr(unsigned int, void*);
 NEW_MAP(method)
 NEW_MAP(trap)
 NEW_MAP(class)
+NEW_MAP(virtual)
 
 
 void mainresult(unsigned int a)
@@ -55,7 +56,7 @@ void callertrap(int nSignal, siginfo_t *info, void *ctx)
                printf("callertrap: something is wrong here. abort\n");
                exit(0);
        }
-       unsigned int patchme = getMethodEntry(from, method_map, trap_map);
+       unsigned int patchme = getMethodEntry(from, 0);
 
        unsigned char *insn = (unsigned char *) from;
        *insn = 0xe8; // call opcode
@@ -71,18 +72,21 @@ void staticfieldtrap(int nSignal, siginfo_t *info, void *ctx)
        /* TODO(bernhard): more generic and cleaner please... */
        mcontext_t *mctx = &((ucontext_t *) ctx)->uc_mcontext;
        unsigned int from = (unsigned int) mctx->gregs[REG_EIP];
-       if (from == 0) { // invokevirtual
-               unsigned int eax = (unsigned int) mctx->gregs[REG_EAX];
+       if (from < 0x10000) { // invokevirtual
+               if (from > 0) {
+                       printf("from: 0x%08x but should be 0 :-(\n", from);
+               }
+               unsigned int method_table_ptr = (unsigned int) mctx->gregs[REG_EAX];
                unsigned int *esp = (unsigned int *) mctx->gregs[REG_ESP];
                /* get actual eip from stack storage */
                unsigned int from = (*esp) - 3;
                unsigned char offset = *((unsigned char *) (*esp) - 1);
                /* method entry to patch */
-               unsigned int *to_patch = (unsigned int*) (eax + offset);
+               unsigned int *to_patch = (unsigned int*) (method_table_ptr + offset);
                printf("invokevirtual by 0x%08x with offset 0x%08x\n", from, offset);
                printf(" to_patch: 0x%08x\n", (unsigned int) to_patch);
                printf("*to_patch: 0x%08x\n", *to_patch);
-               *to_patch = getMethodEntry(from, method_map, trap_map);
+               *to_patch = getMethodEntry(from, method_table_ptr);
                mctx->gregs[REG_EIP] = *to_patch;
                printf("*to_patch: 0x%08x\n", *to_patch);
        } else {
diff --git a/tests/Instance4.java b/tests/Instance4.java
new file mode 100644 (file)
index 0000000..0ecd257
--- /dev/null
@@ -0,0 +1,19 @@
+package tests;
+
+public class Instance4 extends Instance2 {
+       public Instance4() {
+               x = 0x11;
+               y = 0x22;
+       }
+
+       public static void main(String []args) {
+               Instance2 a = new Instance4();
+               a.getX(); // 0x1337
+               Instance4 b = (Instance4) a;
+               b.getX(); // 0x1337;
+       }
+
+       public int getX() {
+               return 0x1337;
+       }
+}