refactor: style, fun, hlint, ...
[mate.git] / Mate / X86TrapHandling.hs
index 882a541f1eefb7b5fe89eed22285d7a2d4cf524e..1761d6b1047671eecad02d487d8de3b3c9456401 100644 (file)
@@ -2,7 +2,10 @@
 {-# LANGUAGE OverloadedStrings #-}
 {-# LANGUAGE ForeignFunctionInterface #-}
 #include "debug.h"
-module Mate.X86TrapHandling where
+module Mate.X86TrapHandling (
+  mateHandler,
+  register_signal
+  ) where
 
 import qualified Data.Map as M
 
@@ -16,20 +19,22 @@ import Mate.ClassPool
 foreign import ccall "register_signal"
   register_signal :: IO ()
 
+data TrapType =
+    StaticMethodCall
+  | StaticFieldAccess
+  | VirtualMethodCall Bool
+  | InterfaceMethodCall Bool
 
-getTrapType :: CUInt -> CUInt -> IO CUInt
-getTrapType signal_from from2 = do
-  tmap <- getTrapMap
+getTrapType :: TrapMap -> CUInt -> CUInt -> TrapType
+getTrapType tmap signal_from from2 =
   case M.lookup (fromIntegral signal_from) tmap of
-    (Just (StaticMethod _)) -> return 0
-    (Just (StaticField _)) -> return 2
+    (Just (StaticMethod _)) -> StaticMethodCall
+    (Just (StaticField _)) -> StaticFieldAccess
     (Just _) -> error "getTrapMap: doesn't happen"
     -- maybe we've a hit on the second `from' value
     Nothing -> case M.lookup (fromIntegral from2) tmap of
-      (Just (VirtualMethod True _)) -> return 1
-      (Just (VirtualMethod False _)) -> return 5
-      (Just (InterfaceMethod True _)) -> return 4
-      (Just (InterfaceMethod False _)) -> return 8
+      (Just (VirtualMethod imm8 _)) -> VirtualMethodCall imm8
+      (Just (InterfaceMethod imm8 _)) -> InterfaceMethodCall imm8
       (Just _) -> error "getTrapType: abort #1 :-("
       Nothing -> error $ "getTrapType: abort #2 :-(" ++ show signal_from ++ ", " ++ show from2 ++ ", " ++ show tmap
 
@@ -37,15 +42,12 @@ foreign export ccall mateHandler :: CUInt -> CUInt -> CUInt -> CUInt -> IO CUInt
 mateHandler :: CUInt -> CUInt -> CUInt -> CUInt -> IO CUInt
 mateHandler eip eax ebx esp = do
   callerAddr <- callerAddrFromStack esp
-  blah <- getTrapType eip callerAddr
-  case blah of
-    0 -> staticCallHandler eip
-    1 -> invokeHandler eax eax esp True
-    5 -> invokeHandler eax eax esp False
-    4 -> invokeHandler eax ebx esp True
-    8 -> invokeHandler eax ebx esp False
-    2 -> staticFieldHandler eip
-    x -> error $ "wtf: " ++ show x
+  tmap <- getTrapMap
+  case getTrapType tmap eip callerAddr of
+    StaticMethodCall  -> staticCallHandler eip
+    StaticFieldAccess -> staticFieldHandler eip
+    VirtualMethodCall imm8   -> invokeHandler eax eax esp imm8
+    InterfaceMethodCall imm8 -> invokeHandler eax ebx esp imm8
 
 staticCallHandler :: CUInt -> IO CUInt
 staticCallHandler eip = do
@@ -57,8 +59,8 @@ staticCallHandler eip = do
   -- in order to produce a SIGILL signal. we also do a safety
   -- check here, if we're really the "owner" of this signal.
   checkMe <- peek imm_ptr
-  case checkMe == 0x90ffff90 of
-    True -> do
+  if checkMe == 0x90ffff90 then
+    do
       entryAddr <- getMethodEntry eip 0
       poke insn_ptr 0xe8 -- call opcode
       -- it's a relative call, so we have to calculate the offset. why "+ 3"?
@@ -67,18 +69,18 @@ staticCallHandler eip = do
       -- (3) offset is calculated wrt to the beginning of the next insn
       poke imm_ptr (entryAddr - (eip + 3))
       return (eip - 2)
-    False -> error "staticCallHandler: something is wrong here. abort\n"
+    else error "staticCallHandler: something is wrong here. abort\n"
 
 staticFieldHandler :: CUInt -> IO CUInt
 staticFieldHandler eip = do
   -- patch the offset here, first two bytes are part of the insn (opcode + reg)
   let imm_ptr = intPtrToPtr (fromIntegral (eip + 2)) :: Ptr CUInt
   checkMe <- peek imm_ptr
-  case checkMe == 0x00000000 of
-    True -> do
+  if checkMe == 0x00000000 then
+    do
       getStaticFieldAddr eip >>= poke imm_ptr
       return eip
-    False -> error "staticFieldHandler: something is wrong here. abort.\n"
+    else error "staticFieldHandler: something is wrong here. abort.\n"
 
 invokeHandler :: CUInt -> CUInt -> CUInt -> Bool -> IO CUInt
 invokeHandler method_table table2patch esp imm8 = do