trap: give disasm some nop's, so it shows the label
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.ByteString.Lazy as B
10 import Control.Monad
11
12 import Foreign
13 import Foreign.C.Types
14
15 import Text.Printf
16
17 import qualified JVM.Assembler as J
18 import JVM.Assembler hiding (Instruction)
19 import JVM.ClassFile
20 import JVM.Converter
21
22 import Harpy
23 import Harpy.X86Disassembler
24
25 import Mate.BasicBlocks
26 import Mate.Utilities
27
28 foreign import ccall "dynamic"
29    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
30
31 foreign import ccall "getaddr"
32   getaddr :: CUInt
33
34 foreign import ccall "callertrap"
35   callertrap :: IO ()
36
37 foreign import ccall "register_signal"
38   register_signal :: IO ()
39
40 foreign import ccall "get_cmap"
41   get_cmap :: IO (Ptr ())
42
43 foreign import ccall "set_cmap"
44   set_cmap :: Ptr () -> IO ()
45
46 test_01, test_02, test_03 :: IO ()
47 test_01 = do
48   register_signal
49   (entry, end) <- testCase "./tests/Fib.class" "fib"
50   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
51
52   mapM_ (\x -> do
53     result <- code_int entryFuncPtr x 0
54     let iresult :: Int; iresult = fromIntegral result
55     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
56     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
57     ) $ ([0..10] :: [CInt])
58   printf "patched disasm:\n"
59   Right newdisasm <- disassembleBlock entry end
60   mapM_ (putStrLn . showAtt) newdisasm
61   where
62     fib :: CInt -> Int
63     fib n
64       | n <= 1 = 1
65       | otherwise = (fib (n - 1)) + (fib (n - 2))
66
67
68 test_02 = do
69   (entry,_) <- testCase "./tests/While.class" "f"
70   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
71   result <- code_int entryFuncPtr 5 4
72   let iresult :: Int; iresult = fromIntegral result
73   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
74   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
75
76   result2 <- code_int entryFuncPtr 4 3
77   let iresult2 :: Int; iresult2 = fromIntegral result2
78   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
79   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
80
81
82 test_03 = do
83   (entry,_) <- testCase "./tests/While.class" "g"
84   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
85   result <- code_int entryFuncPtr 5 4
86   let iresult :: Int; iresult = fromIntegral result
87   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
88   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
89
90   result2 <- code_int entryFuncPtr 4 3
91   let iresult2 :: Int; iresult2 = fromIntegral result2
92   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
93   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
94
95
96 testCase :: String -> B.ByteString -> IO (Ptr Word8, Int)
97 testCase cf method = do
98       cls <- parseClassFile cf
99       hmap <- parseMethod cls method
100       printMapBB hmap
101       case hmap of
102         Nothing -> error "sorry, no code generation"
103         Just hmap' -> do
104               let ebb = emitFromBB cls hmap'
105               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
106               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
107               printf "disasm:\n"
108               mapM_ (putStrLn . showAtt) disasm
109               printf "basicblocks addresses:\n"
110               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
111               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
112               return (entry, end)
113
114 type EntryPoint = Ptr Word8
115 type EntryPointOffset = Int
116 type PatchInfo = (BlockID, EntryPointOffset)
117
118 type BBStarts = M.Map BlockID Int
119
120 type CompileInfo = (EntryPoint, BBStarts, Int, CMap)
121
122 -- B.ByteString: encoded name: <Class>.<methodname><signature>
123 -- Class Resolved: classfile
124 -- Word16: index of invoke-instruction
125 type MethodInfo = (B.ByteString, Class Resolved, Word16)
126
127 -- Word32 = point of method call in generated code
128 -- MethodInfo = relevant information about callee
129 type CMap = M.Map Word32 MethodInfo
130
131
132 emitFromBB :: Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
133 emitFromBB cls hmap =  do
134         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
135         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
136         ep <- getEntryPoint
137         push ebp
138         mov ebp esp
139
140         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
141         d <- disassemble
142         end <- getCodeOffset
143         return ((ep, bbstarts, end, calls), d)
144   where
145   getLabel :: BlockID -> [(BlockID, Label)] -> Label
146   getLabel _ [] = error "label not found!"
147   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
148
149   efBB :: (BlockID, BasicBlock) -> CMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (CMap, BBStarts)
150   efBB (bid, bb) calls bbstarts lmap =
151         if M.member bid bbstarts then
152           return (calls, bbstarts)
153         else do
154           bb_offset <- getCodeOffset
155           let bbstarts' = M.insert bid bb_offset bbstarts
156           defineLabel $ getLabel bid lmap
157           cs <- mapM emit' $ code bb
158           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
159           case successor bb of
160             Return -> return (calls', bbstarts')
161             FallThrough t -> do
162               efBB (t, hmap M.! t) calls' bbstarts' lmap
163             OneTarget t -> do
164               efBB (t, hmap M.! t) calls' bbstarts' lmap
165             TwoTarget t1 t2 -> do
166               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
167               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
168     -- TODO(bernhard): also use metainformation
169     -- TODO(bernhard): implement `emit' as function which accepts a list of
170     --                 instructions, so we can use patterns for optimizations
171     where
172     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, MethodInfo))
173     emit' (INVOKESTATIC cpidx) = do
174         ep <- getEntryPoint
175         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
176         let l = buildMethodID cls cpidx
177         calladdr <- getCodeOffset
178         let w32_calladdr = w32_ep + (fromIntegral calladdr) :: Word32
179         newNamedLabel (toString l) >>= defineLabel
180         -- causes SIGILL. in the signal handler we patch it to the acutal call.
181         -- place a nop at the end, therefore the disasm doesn't screw up
182         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
183         -- discard arguments on stack
184         let argcnt = (methodGetArgsCount cls cpidx) * 4
185         when (argcnt > 0) (add esp argcnt)
186         -- push result on stack if method has a return value
187         when (methodHaveReturnValue cls cpidx) (push eax)
188         return $ Just $ (w32_calladdr, (l, cls, cpidx))
189     emit' insn = emit insn >> return Nothing
190
191     emit :: J.Instruction -> CodeGen e s ()
192     emit POP = do -- print dropped value
193         ep <- getEntryPoint
194         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
195         -- '5' is the size of the `call' instruction ( + immediate)
196         calladdr <- getCodeOffset
197         let w32_calladdr = 5 + w32_ep + (fromIntegral calladdr) :: Word32
198         let trapaddr = (fromIntegral getaddr :: Word32)
199         call (trapaddr - w32_calladdr)
200         add esp (4 :: Word32)
201     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
202     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
203     emit (ICONST_0) = push (0 :: Word32)
204     emit (ICONST_1) = push (1 :: Word32)
205     emit (ICONST_2) = push (2 :: Word32)
206     emit (ICONST_4) = push (4 :: Word32)
207     emit (ICONST_5) = push (5 :: Word32)
208     emit (ILOAD_ x) = do
209         push (Disp (cArgs_ x), ebp)
210     emit (ISTORE_ x) = do
211         pop eax
212         mov (Disp (cArgs_ x), ebp) eax
213     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
214     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
215     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
216     emit (IINC x imm) = do
217         add (Disp (cArgs x), ebp) (s8_w32 imm)
218
219     emit (IF_ICMP cond _) = do
220         pop eax -- value2
221         pop ebx -- value1
222         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
223         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
224         let l = getLabel sid lmap
225         case cond of
226           C_EQ -> je  l; C_NE -> jne l
227           C_LT -> jl  l; C_GT -> jg  l
228           C_GE -> jge l; C_LE -> jle l
229
230     emit (IF cond _) = do
231         pop eax -- value1
232         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
233         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
234         let l = getLabel sid lmap
235         case cond of
236           C_EQ -> je  l; C_NE -> jne l
237           C_LT -> jl  l; C_GT -> jg  l
238           C_GE -> jge l; C_LE -> jle l
239
240     emit (GOTO _ ) = do
241         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
242         jmp $ getLabel sid lmap
243
244     emit RETURN = do mov esp ebp; pop ebp; ret
245     emit IRETURN = do
246         pop eax
247         mov esp ebp
248         pop ebp
249         ret
250     emit _ = do cmovbe eax eax -- dummy
251
252   cArgs x = (8 + 4 * (fromIntegral x))
253   cArgs_ x = (8 + 4 * case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3)
254
255   -- sign extension from w8 to w32 (over s8)
256   --   unfortunately, hs-java is using Word8 everywhere (while
257   --   it should be Int8 actually)
258   s8_w32 :: Word8 -> Word32
259   s8_w32 w8 = fromIntegral s8
260     where s8 = (fromIntegral w8) :: Int8