From 54a2170d22bb853afa42d87eeeffd8b633efcd36 Mon Sep 17 00:00:00 2001 From: Bernhard Urban Date: Wed, 25 Apr 2012 21:17:33 +0200 Subject: [PATCH] invokevirtual: get the actual class at run-time We don't know the actual class in the CodeGen Monad, so we have to deduce it at run-time. We use the method-table-ptr for that, since it has a unique address which we can map to the actual class. --- Makefile | 2 ++ Mate/ClassPool.hs | 4 +++ Mate/MethodPool.hs | 77 +++++++++++++++++++++++--------------------- Mate/Types.hs | 24 +++++++++++++- Mate/X86CodeGen.hs | 8 +++-- ffi/trap.c | 16 +++++---- tests/Instance4.java | 19 +++++++++++ 7 files changed, 105 insertions(+), 45 deletions(-) create mode 100644 tests/Instance4.java diff --git a/Makefile b/Makefile index ceeb809..3353137 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,8 @@ test: mate $(CLASS_FILES) @printf "should be: 0x%08x 0x%08x\n" 0x198 0x22 ./$< tests/Instance3 | grep mainresult @printf "should be: 0x%08x 0x%08x\n" 0x33 0x44 + ./$< tests/Instance4 | grep mainresult + @printf "should be: 0x%08x 0x%08x\n" 0x1337 0x1337 %.class: %.java $(JAVAC) $< diff --git a/Mate/ClassPool.hs b/Mate/ClassPool.hs index ac8ca41..fd5fc8e 100644 --- a/Mate/ClassPool.hs +++ b/Mate/ClassPool.hs @@ -101,6 +101,10 @@ loadClass path = do printf "methodmap: %s @ %s\n" (show methodmap) (toString path) printf "mbase: 0x%08x\n" mbase + virtual_map <- get_virtualmap >>= ptr2virtualmap + let virtual_map' = M.insert mbase path virtual_map + virtualmap2ptr virtual_map' >>= set_virtualmap + class_map <- get_classmap >>= ptr2classmap let new_ci = ClassInfo path cfile staticmap fieldmap methodmap mbase False let class_map' = M.insert path new_ci class_map diff --git a/Mate/MethodPool.hs b/Mate/MethodPool.hs index 23d8434..e1eb121 100644 --- a/Mate/MethodPool.hs +++ b/Mate/MethodPool.hs @@ -31,45 +31,49 @@ foreign import ccall "dynamic" code_void :: FunPtr (IO ()) -> (IO ()) -foreign export ccall getMethodEntry :: CUInt -> Ptr () -> Ptr () -> 
IO CUInt -getMethodEntry :: CUInt -> Ptr () -> Ptr () -> IO CUInt -getMethodEntry signal_from ptr_mmap ptr_tmap = do - mmap <- ptr2mmap ptr_mmap - tmap <- ptr2tmap ptr_tmap +foreign export ccall getMethodEntry :: CUInt -> CUInt -> IO CUInt +getMethodEntry :: CUInt -> CUInt -> IO CUInt +getMethodEntry signal_from methodtable = do + mmap <- get_methodmap >>= ptr2mmap + tmap <- get_trapmap >>= ptr2tmap + vmap <- get_virtualmap >>= ptr2virtualmap let w32_from = fromIntegral signal_from let mi = tmap M.! w32_from - case mi of - (MI mi'@(MethodInfo method cm sig)) -> do - case M.lookup mi' mmap of - Nothing -> do - cls <- getClassFile cm - printf "getMethodEntry(from 0x%08x): no method \"%s\" found. compile it\n" w32_from (show mi') - mm <- lookupMethodRecursive method [] cls - case mm of - Just (mm', clsnames, cls') -> do - let flags = methodAccessFlags mm' - case S.member ACC_NATIVE flags of - False -> do - hmap <- parseMethod cls' method - printMapBB hmap - case hmap of - Just hmap' -> do - entry <- compileBB hmap' (MethodInfo method (thisClass cls') sig) - addMethodRef entry mi' clsnames - return $ fromIntegral entry - Nothing -> error $ (show method) ++ " not found. abort" - True -> do - let symbol = (replace "/" "_" $ toString cm) ++ "__" ++ (toString method) ++ "__" ++ (replace "(" "_" (replace ")" "_" $ toString $ encode sig)) - printf "native-call: symbol: %s\n" symbol - nf <- loadNativeFunction symbol - let w32_nf = fromIntegral nf - let mmap' = M.insert mi' w32_nf mmap - mmap2ptr mmap' >>= set_methodmap - return nf - Nothing -> error $ (show method) ++ " not found. abort" - Just w32 -> return (fromIntegral w32) - _ -> error $ "getMethodEntry: no trapInfo. abort" + let mi'@(MethodInfo method cm sig) = + case mi of + (MI x) -> x + (VI (MethodInfo methname _ msig)) -> + (MethodInfo methname (vmap M.! (fromIntegral methodtable)) msig) + _ -> error $ "getMethodEntry: no trapInfo. abort." 
+ case M.lookup mi' mmap of + Nothing -> do + cls <- getClassFile cm + printf "getMethodEntry(from 0x%08x): no method \"%s\" found. compile it\n" w32_from (show mi') + mm <- lookupMethodRecursive method [] cls + case mm of + Just (mm', clsnames, cls') -> do + let flags = methodAccessFlags mm' + case S.member ACC_NATIVE flags of + False -> do + hmap <- parseMethod cls' method + printMapBB hmap + case hmap of + Just hmap' -> do + entry <- compileBB hmap' (MethodInfo method (thisClass cls') sig) + addMethodRef entry mi' clsnames + return $ fromIntegral entry + Nothing -> error $ (show method) ++ " not found. abort" + True -> do + let symbol = (replace "/" "_" $ toString cm) ++ "__" ++ (toString method) ++ "__" ++ (replace "(" "_" (replace ")" "_" $ toString $ encode sig)) + printf "native-call: symbol: %s\n" symbol + nf <- loadNativeFunction symbol + let w32_nf = fromIntegral nf + let mmap' = M.insert mi' w32_nf mmap + mmap2ptr mmap' >>= set_methodmap + return nf + Nothing -> error $ (show method) ++ " not found. 
abort" + Just w32 -> return (fromIntegral w32) lookupMethodRecursive :: B.ByteString -> [B.ByteString] -> Class Resolved -> IO (Maybe ((Method Resolved, [B.ByteString], Class Resolved))) @@ -115,6 +119,7 @@ initMethodPool = do mmap2ptr M.empty >>= set_methodmap tmap2ptr M.empty >>= set_trapmap classmap2ptr M.empty >>= set_classmap + virtualmap2ptr M.empty >>= set_virtualmap addMethodRef :: Word32 -> MethodInfo -> [B.ByteString] -> IO () diff --git a/Mate/Types.hs b/Mate/Types.hs index 521a269..8977143 100644 --- a/Mate/Types.hs +++ b/Mate/Types.hs @@ -32,7 +32,10 @@ type MapBB = M.Map BlockID BasicBlock -- MethodInfo = relevant information about callee type TMap = M.Map Word32 TrapInfo -data TrapInfo = MI MethodInfo | SFI StaticFieldInfo +data TrapInfo = + MI MethodInfo | + VI MethodInfo | -- for virtual calls + SFI StaticFieldInfo data StaticFieldInfo = StaticFieldInfo { sfiClassName :: B.ByteString, @@ -46,6 +49,11 @@ type ClassMap = M.Map B.ByteString ClassInfo type FieldMap = M.Map B.ByteString Int32 +-- map "methodtable addr" to "classname" +-- we need that to identify the actual type +-- on the invokevirtual insn +type VirtualMap = M.Map Word32 B.ByteString + data ClassInfo = ClassInfo { clName :: B.ByteString, clFile :: Class Resolved, @@ -109,6 +117,12 @@ foreign import ccall "get_classmap" foreign import ccall "set_classmap" set_classmap :: Ptr () -> IO () +foreign import ccall "get_virtualmap" + get_virtualmap :: IO (Ptr ()) + +foreign import ccall "set_virtualmap" + set_virtualmap :: Ptr () -> IO () + -- TODO(bernhard): make some typeclass magic 'n stuff mmap2ptr :: MMap -> IO (Ptr ()) mmap2ptr mmap = do @@ -133,3 +147,11 @@ classmap2ptr cmap = do ptr2classmap :: Ptr () -> IO ClassMap ptr2classmap vmap = deRefStablePtr $ ((castPtrToStablePtr vmap) :: StablePtr cmap) + +virtualmap2ptr :: VirtualMap -> IO (Ptr ()) +virtualmap2ptr cmap = do + ptr_cmap <- newStablePtr cmap + return $ castStablePtrToPtr ptr_cmap + +ptr2virtualmap :: Ptr () -> IO VirtualMap 
+ptr2virtualmap vmap = deRefStablePtr $ ((castPtrToStablePtr vmap) :: StablePtr cmap) diff --git a/Mate/X86CodeGen.hs b/Mate/X86CodeGen.hs index 7ddcb6a..2c4e70e 100644 --- a/Mate/X86CodeGen.hs +++ b/Mate/X86CodeGen.hs @@ -193,7 +193,7 @@ emitFromBB method cls hmap = do newNamedLabel (show mi) >>= defineLabel -- objref lives somewhere on the argument stack mov eax (Disp ((*4) $ fromIntegral $ length args), esp) - -- get methodtable ref + -- get method-table-ptr mov eax (Disp 0, eax) -- get method offset let nameAndSig = methodname `B.append` (encode msig) @@ -206,7 +206,10 @@ emitFromBB method cls hmap = do when (argcnt > 0) (add esp argcnt) -- push result on stack if method has a return value when (methodHaveReturnValue cls cpidx) (push eax) - return $ Just $ (calladdr, MI mi) + -- note, the "mi" has the wrong class reference here. + -- we figure that out at run-time, in the methodpool, + -- depending on the method-table-ptr + return $ Just $ (calladdr, VI mi) emit' (PUTSTATIC cpidx) = do pop eax trapaddr <- getCurrentOffset @@ -242,6 +245,7 @@ emitFromBB method cls hmap = do -- set method table pointer let mtable = unsafePerformIO $ getMethodTable objname mov (Disp 0, eax) mtable + emit (CHECKCAST _) = nop -- TODO(bernhard): ... 
emit (BIPUSH val) = push ((fromIntegral val) :: Word32) emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32) emit (ICONST_0) = push (0 :: Word32) diff --git a/ffi/trap.c b/ffi/trap.c index 0a76609..228b2c3 100644 --- a/ffi/trap.c +++ b/ffi/trap.c @@ -19,7 +19,7 @@ #include -unsigned int getMethodEntry(unsigned int, void *, void *); +unsigned int getMethodEntry(unsigned int, unsigned int); unsigned int getStaticFieldAddr(unsigned int, void*); #define NEW_MAP(prefix) \ @@ -38,6 +38,7 @@ unsigned int getStaticFieldAddr(unsigned int, void*); NEW_MAP(method) NEW_MAP(trap) NEW_MAP(class) +NEW_MAP(virtual) void mainresult(unsigned int a) @@ -55,7 +56,7 @@ void callertrap(int nSignal, siginfo_t *info, void *ctx) printf("callertrap: something is wrong here. abort\n"); exit(0); } - unsigned int patchme = getMethodEntry(from, method_map, trap_map); + unsigned int patchme = getMethodEntry(from, 0); unsigned char *insn = (unsigned char *) from; *insn = 0xe8; // call opcode @@ -71,18 +72,21 @@ void staticfieldtrap(int nSignal, siginfo_t *info, void *ctx) /* TODO(bernhard): more generic and cleaner please... 
*/ mcontext_t *mctx = &((ucontext_t *) ctx)->uc_mcontext; unsigned int from = (unsigned int) mctx->gregs[REG_EIP]; - if (from == 0) { // invokevirtual - unsigned int eax = (unsigned int) mctx->gregs[REG_EAX]; + if (from < 0x10000) { // invokevirtual + if (from > 0) { + printf("from: 0x%08x but should be 0 :-(\n", from); + } + unsigned int method_table_ptr = (unsigned int) mctx->gregs[REG_EAX]; unsigned int *esp = (unsigned int *) mctx->gregs[REG_ESP]; /* get actual eip from stack storage */ unsigned int from = (*esp) - 3; unsigned char offset = *((unsigned char *) (*esp) - 1); /* method entry to patch */ - unsigned int *to_patch = (unsigned int*) (eax + offset); + unsigned int *to_patch = (unsigned int*) (method_table_ptr + offset); printf("invokevirtual by 0x%08x with offset 0x%08x\n", from, offset); printf(" to_patch: 0x%08x\n", (unsigned int) to_patch); printf("*to_patch: 0x%08x\n", *to_patch); - *to_patch = getMethodEntry(from, method_map, trap_map); + *to_patch = getMethodEntry(from, method_table_ptr); mctx->gregs[REG_EIP] = *to_patch; printf("*to_patch: 0x%08x\n", *to_patch); } else { diff --git a/tests/Instance4.java b/tests/Instance4.java new file mode 100644 index 0000000..0ecd257 --- /dev/null +++ b/tests/Instance4.java @@ -0,0 +1,19 @@ +package tests; + +public class Instance4 extends Instance2 { + public Instance4() { + x = 0x11; + y = 0x22; + } + + public static void main(String []args) { + Instance2 a = new Instance4(); + a.getX(); // 0x1337 + Instance4 b = (Instance4) a; + b.getX(); // 0x1337; + } + + public int getX() { + return 0x1337; + } +} -- 2.25.1