From f82dbecc763818452667ac568da96b7c5dd7cc97 Mon Sep 17 00:00:00 2001 From: Bernhard Urban Date: Wed, 18 Jul 2012 23:12:20 +0200 Subject: [PATCH] refactor: style, fun, hlint, ... warmup for serious stuff (hopefully) --- Mate/BasicBlocks.hs | 10 ++++---- Mate/ClassPool.hs | 22 +++++++---------- Mate/MethodPool.hs | 12 +++++----- Mate/Utilities.hs | 9 +++---- Mate/X86CodeGen.hs | 13 ++++++----- Mate/X86TrapHandling.hs | 52 +++++++++++++++++++++-------------------- 6 files changed, 56 insertions(+), 62 deletions(-) diff --git a/Mate/BasicBlocks.hs b/Mate/BasicBlocks.hs index ce676c2..fbb61f7 100644 --- a/Mate/BasicBlocks.hs +++ b/Mate/BasicBlocks.hs @@ -41,7 +41,7 @@ printMapBB :: Maybe MapBB -> IO () printMapBB Nothing = putStrLn "No BasicBlock" printMapBB (Just hmap) = do putStr "BlockIDs: " - let keys = fst $ unzip $ M.toList hmap -- M.keys + let keys = M.keys hmap mapM_ (putStr . (flip (++)) ", " . show) keys putStrLn "\n\nBasicBlocks:" printMapBB' keys hmap @@ -106,10 +106,10 @@ parseMethod cls method sig = do testCFG :: Maybe (Method Direct) -> Maybe MapBB -testCFG (Just m) = case attrByName m "Code" of - Nothing -> Nothing - Just bytecode -> Just $ buildCFG $ codeInstructions $ decodeMethod bytecode -testCFG _ = Nothing +testCFG m = do + m' <- m + bytecode <- attrByName m' "Code" + return $ buildCFG $ codeInstructions $ decodeMethod bytecode buildCFG :: [Instruction] -> MapBB diff --git a/Mate/ClassPool.hs b/Mate/ClassPool.hs index 8d88ad3..8788e75 100644 --- a/Mate/ClassPool.hs +++ b/Mate/ClassPool.hs @@ -19,6 +19,7 @@ import Data.Word import Data.Binary import qualified Data.Map as M import qualified Data.Set as S +import Data.List import qualified Data.ByteString.Lazy as B import Data.String.Utils import Control.Monad @@ -197,32 +198,25 @@ calculateFields :: Class Direct -> Maybe ClassInfo -> IO (FieldMap, FieldMap) calculateFields cf superclass = do -- TODO(bernhard): correct sizes. int only atm - -- TODO(bernhard): nicer replacement for `myspan' - let (sfields, ifields) = myspan (S.member ACC_STATIC . fieldAccessFlags) (classFields cf) - myspan :: (a -> Bool) -> [a] -> ([a], [a]) - myspan _ [] = ([],[]) - myspan p (x:xs) - | p x = (x:ns, ni) - | otherwise = (ns, x:ni) - where (ns,ni) = myspan p xs + let (sfields, ifields) = partition (S.member ACC_STATIC . fieldAccessFlags) (classFields cf) - staticbase <- mallocClassData $ fromIntegral (length sfields) * 4 - let i_sb = fromIntegral $ ptrToIntPtr staticbase - let sm = zipbase i_sb sfields let sc_sm = getsupermap superclass ciStaticMap + staticbase <- mallocClassData $ fromIntegral (length sfields) * 4 + let sm = zipbase (fromIntegral $ ptrToIntPtr staticbase) sfields -- new fields "overwrite" old ones, if they have the same name - let staticmap = M.fromList sm `M.union` sc_sm + let staticmap = sm `M.union` sc_sm let sc_im = getsupermap superclass ciFieldMap -- "+ 4" for the method table pointer let max_off = (4+) $ fromIntegral $ M.size sc_im * 4 let im = zipbase max_off ifields -- new fields "overwrite" old ones, if they have the same name - let fieldmap = M.fromList im `M.union` sc_im + let fieldmap = im `M.union` sc_im return (staticmap, fieldmap) where - zipbase base = zipWith (\x y -> (fieldName y, x + base)) [0,4..] + zipbase :: Int32 -> [Field Direct] -> FieldMap + zipbase base = foldr (\(x,y) -> M.insert (fieldName y) (x + base)) M.empty . zip [0,4..] -- helper getsupermap :: Maybe ClassInfo -> (ClassInfo -> FieldMap) -> FieldMap diff --git a/Mate/MethodPool.hs b/Mate/MethodPool.hs index 1403eae..7f0779c 100644 --- a/Mate/MethodPool.hs +++ b/Mate/MethodPool.hs @@ -54,7 +54,7 @@ getMethodEntry signal_from methodtable = do -- figured out the problem yet :/ therefore, I have no -- testcase for replaying the situation. -- setTrapMap $ M.delete w32_from tmap - case M.lookup mi' mmap of + entryaddr <- case M.lookup mi' mmap of Nothing -> do cls <- getClassFile cm printfMp "getMethodEntry(from 0x%08x): no method \"%s\" found. compile it\n" w32_from (show mi') @@ -71,8 +71,7 @@ getMethodEntry signal_from methodtable = do symbol = sym1 ++ "__" ++ toString method ++ "__" ++ sym2 printfMp "native-call: symbol: %s\n" symbol nf <- loadNativeFunction symbol - let w32_nf = fromIntegral nf - setMethodMap $ M.insert mi' w32_nf mmap + setMethodMap $ M.insert mi' nf mmap return nf else do hmap <- parseMethod cls' method sig @@ -83,7 +82,8 @@ getMethodEntry signal_from methodtable = do return $ fromIntegral entry Nothing -> error $ show method ++ " not found. abort" Nothing -> error $ show method ++ " not found. abort" - Just w32 -> return (fromIntegral w32) + Just w32 -> return w32 + return $ fromIntegral entryaddr lookupMethodRecursive :: B.ByteString -> MethodSignature -> [B.ByteString] -> Class Direct -> IO (Maybe (Method Direct, [B.ByteString], Class Direct)) @@ -105,7 +105,7 @@ lookupMethodRecursive name sig clsnames cls = foreign import ccall safe "lookupSymbol" c_lookupSymbol :: CString -> IO (Ptr a) -loadNativeFunction :: String -> IO CUInt +loadNativeFunction :: String -> IO Word32 loadNativeFunction sym = do _ <- loadRawObject "ffi/native.o" -- TODO(bernhard): WTF @@ -127,7 +127,7 @@ loadNativeFunction sym = do addMethodRef :: Word32 -> MethodInfo -> [B.ByteString] -> IO () addMethodRef entry (MethodInfo mmname _ msig) clsnames = do mmap <- getMethodMap - let newmap = M.fromList $ map (\x -> (MethodInfo mmname x msig, entry)) clsnames + let newmap = foldr (\i -> M.insert (MethodInfo mmname i msig) entry) M.empty clsnames setMethodMap $ mmap `M.union` newmap diff --git a/Mate/Utilities.hs b/Mate/Utilities.hs index 89020a5..565d4b1 100644 --- a/Mate/Utilities.hs +++ b/Mate/Utilities.hs @@ -5,6 +5,7 @@ module Mate.Utilities where import Data.Word import qualified Data.Map as M import qualified Data.ByteString.Lazy as B +import Data.List import JVM.ClassFile @@ -57,9 +58,5 @@ methodHaveReturnValue cls idx = case ret of (MethodSignature _ ret) = ntSignature nt lookupMethodSig :: B.ByteString -> MethodSignature -> Class Direct -> Maybe (Method Direct) -lookupMethodSig name sig cls = look (classMethods cls) - where - look [] = Nothing - look (f:fs) - | methodName f == name && methodSignature f == sig = Just f - | otherwise = look fs +lookupMethodSig name sig cls = + find (\x -> methodName x == name && methodSignature x == sig) $ classMethods cls diff --git a/Mate/X86CodeGen.hs b/Mate/X86CodeGen.hs index 80def61..e3fa8d6 100644 --- a/Mate/X86CodeGen.hs +++ b/Mate/X86CodeGen.hs @@ -48,8 +48,9 @@ type CompileInfo = (EntryPoint, BBStarts, Int, TrapMap) emitFromBB :: B.ByteString -> MethodSignature -> Class Direct -> MapBB -> CodeGen e s (CompileInfo, [Instruction]) emitFromBB method sig cls hmap = do - llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap] - let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap + let keys = M.keys hmap + llmap <- mapM (newNamedLabel . (++) "bb_" . show) keys + let lmap = zip keys llmap ep <- getEntryPoint push ebp mov ebp esp @@ -147,7 +148,7 @@ emitFromBB method sig cls hmap = do -- note, that "mi" has the wrong class reference here. -- we figure that out at run-time, in the methodpool, -- depending on the method-table-ptr - invokeEpilog cpidx offset (\x -> InterfaceMethod x mi) + invokeEpilog cpidx offset (`InterfaceMethod` mi) emit' (INVOKEVIRTUAL cpidx) = do -- get methodInfo entry let mi@(MethodInfo methodname objname msig@(MethodSignature args _)) = buildMethodID cls cpidx @@ -162,7 +163,7 @@ emitFromBB method sig cls hmap = do -- note, that "mi" has the wrong class reference here. -- we figure that out at run-time, in the methodpool, -- depending on the method-table-ptr - invokeEpilog cpidx offset (\x -> VirtualMethod x mi) + invokeEpilog cpidx offset (`VirtualMethod` mi) emit' (PUTSTATIC cpidx) = do pop eax trapaddr <- getCurrentOffset @@ -240,7 +241,7 @@ emitFromBB method sig cls hmap = do emit (INSTANCEOF _) = do pop eax push (1 :: Word32) - emit ATHROW = do -- TODO(bernhard): ... + emit ATHROW = -- TODO(bernhard): ... emit32 (0xffffffff :: Word32) emit I2C = do pop eax @@ -360,7 +361,7 @@ emitFromBB method sig cls hmap = do thisMethodArgCnt :: Word32 thisMethodArgCnt = isNonStatic + fromIntegral (length args) where - (Just m) = lookupMethodSig method sig cls + m = fromJust $ lookupMethodSig method sig cls (MethodSignature args _) = sig isNonStatic = if S.member ACC_STATIC (methodAccessFlags m) then 0 else 1 -- one argument for the this pointer diff --git a/Mate/X86TrapHandling.hs b/Mate/X86TrapHandling.hs index 882a541..1761d6b 100644 --- a/Mate/X86TrapHandling.hs +++ b/Mate/X86TrapHandling.hs @@ -2,7 +2,10 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ForeignFunctionInterface #-} #include "debug.h" -module Mate.X86TrapHandling where +module Mate.X86TrapHandling ( + mateHandler, + register_signal + ) where import qualified Data.Map as M @@ -16,20 +19,22 @@ import Mate.ClassPool foreign import ccall "register_signal" register_signal :: IO () +data TrapType = + StaticMethodCall + | StaticFieldAccess + | VirtualMethodCall Bool + | InterfaceMethodCall Bool -getTrapType :: CUInt -> CUInt -> IO CUInt -getTrapType signal_from from2 = do - tmap <- getTrapMap +getTrapType :: TrapMap -> CUInt -> CUInt -> TrapType +getTrapType tmap signal_from from2 = case M.lookup (fromIntegral signal_from) tmap of - (Just (StaticMethod _)) -> return 0 - (Just (StaticField _)) -> return 2 + (Just (StaticMethod _)) -> StaticMethodCall + (Just (StaticField _)) -> StaticFieldAccess (Just _) -> error "getTrapMap: doesn't happen" -- maybe we've a hit on the second `from' value Nothing -> case M.lookup (fromIntegral from2) tmap of - (Just (VirtualMethod True _)) -> return 1 - (Just (VirtualMethod False _)) -> return 5 - (Just (InterfaceMethod True _)) -> return 4 - (Just (InterfaceMethod False _)) -> return 8 + (Just (VirtualMethod imm8 _)) -> VirtualMethodCall imm8 + (Just (InterfaceMethod imm8 _)) -> InterfaceMethodCall imm8 (Just _) -> error "getTrapType: abort #1 :-(" Nothing -> error $ "getTrapType: abort #2 :-(" ++ show signal_from ++ ", " ++ show from2 ++ ", " ++ show tmap @@ -37,15 +42,12 @@ foreign export ccall mateHandler :: CUInt -> CUInt -> CUInt -> CUInt -> IO CUInt mateHandler :: CUInt -> CUInt -> CUInt -> CUInt -> IO CUInt mateHandler eip eax ebx esp = do callerAddr <- callerAddrFromStack esp - blah <- getTrapType eip callerAddr - case blah of - 0 -> staticCallHandler eip - 1 -> invokeHandler eax eax esp True - 5 -> invokeHandler eax eax esp False - 4 -> invokeHandler eax ebx esp True - 8 -> invokeHandler eax ebx esp False - 2 -> staticFieldHandler eip - x -> error $ "wtf: " ++ show x + tmap <- getTrapMap + case getTrapType tmap eip callerAddr of + StaticMethodCall -> staticCallHandler eip + StaticFieldAccess -> staticFieldHandler eip + VirtualMethodCall imm8 -> invokeHandler eax eax esp imm8 + InterfaceMethodCall imm8 -> invokeHandler eax ebx esp imm8 staticCallHandler :: CUInt -> IO CUInt staticCallHandler eip = do @@ -57,8 +59,8 @@ staticCallHandler eip = do -- in order to produce a SIGILL signal. we also do a safety -- check here, if we're really the "owner" of this signal. checkMe <- peek imm_ptr - case checkMe == 0x90ffff90 of - True -> do + if checkMe == 0x90ffff90 then + do entryAddr <- getMethodEntry eip 0 poke insn_ptr 0xe8 -- call opcode -- it's a relative call, so we have to calculate the offset. why "+ 3"? @@ -67,18 +69,18 @@ staticCallHandler eip = do -- (3) offset is calculated wrt to the beginning of the next insn poke imm_ptr (entryAddr - (eip + 3)) return (eip - 2) - False -> error "staticCallHandler: something is wrong here. abort\n" + else error "staticCallHandler: something is wrong here. abort\n" staticFieldHandler :: CUInt -> IO CUInt staticFieldHandler eip = do -- patch the offset here, first two bytes are part of the insn (opcode + reg) let imm_ptr = intPtrToPtr (fromIntegral (eip + 2)) :: Ptr CUInt checkMe <- peek imm_ptr - case checkMe == 0x00000000 of - True -> do + if checkMe == 0x00000000 then + do getStaticFieldAddr eip >>= poke imm_ptr return eip - False -> error "staticFieldHandler: something is wrong here. abort.\n" + else error "staticFieldHandler: something is wrong here. abort.\n" invokeHandler :: CUInt -> CUInt -> CUInt -> Bool -> IO CUInt invokeHandler method_table table2patch esp imm8 = do -- 2.25.1