dea2a4b0eb8e2c6c5575e80da79b5be40b28014f
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.ByteString.Lazy as B
10
11 import Foreign
12 import Foreign.C.Types
13
14 import Text.Printf
15
16 import qualified JVM.Assembler as J
17 import JVM.Assembler hiding (Instruction)
18 import JVM.ClassFile
19 import JVM.Converter
20
21 import Harpy
22 import Harpy.X86Disassembler
23
24 import Mate.BasicBlocks
25 import Mate.Utilities
26
27 foreign import ccall "dynamic"
28    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
29
30 foreign import ccall "getaddr"
31   getaddr :: CUInt
32
33 foreign import ccall "callertrap"
34   callertrap :: IO ()
35
36 foreign import ccall "register_signal"
37   register_signal :: IO ()
38
39 foreign import ccall "get_cmap"
40   get_cmap :: IO (Ptr ())
41
42 foreign import ccall "set_cmap"
43   set_cmap :: Ptr () -> IO ()
44
45 test_01, test_02, test_03 :: IO ()
46 test_01 = do
47   register_signal
48   (entry, end) <- testCase "./tests/Fib.class" "fib"
49   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
50
51   mapM_ (\x -> do
52     result <- code_int entryFuncPtr x 0
53     let iresult :: Int; iresult = fromIntegral result
54     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
55     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
56     ) $ ([0..10] :: [CInt])
57   printf "patched disasm:\n"
58   Right newdisasm <- disassembleBlock entry end
59   mapM_ (putStrLn . showAtt) newdisasm
60   where
61     fib :: CInt -> Int
62     fib n
63       | n <= 1 = 1
64       | otherwise = (fib (n - 1)) + (fib (n - 2))
65
66
67 test_02 = do
68   (entry,_) <- testCase "./tests/While.class" "f"
69   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
70   result <- code_int entryFuncPtr 5 4
71   let iresult :: Int; iresult = fromIntegral result
72   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
73   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
74
75   result2 <- code_int entryFuncPtr 4 3
76   let iresult2 :: Int; iresult2 = fromIntegral result2
77   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
78   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
79
80
81 test_03 = do
82   (entry,_) <- testCase "./tests/While.class" "g"
83   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
84   result <- code_int entryFuncPtr 5 4
85   let iresult :: Int; iresult = fromIntegral result
86   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
87   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
88
89   result2 <- code_int entryFuncPtr 4 3
90   let iresult2 :: Int; iresult2 = fromIntegral result2
91   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
92   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
93
94
95 testCase :: String -> B.ByteString -> IO (Ptr Word8, Int)
96 testCase cf method = do
97       cls <- parseClassFile cf
98       hmap <- parseMethod cls method
99       printMapBB hmap
100       case hmap of
101         Nothing -> error "sorry, no code generation"
102         Just hmap' -> do
103               let ebb = emitFromBB cls hmap'
104               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
105               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
106               printf "disasm:\n"
107               mapM_ (putStrLn . showAtt) disasm
108               printf "basicblocks addresses:\n"
109               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
110               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
111               return (entry, end)
112
113 type EntryPoint = Ptr Word8
114 type EntryPointOffset = Int
115 type PatchInfo = (BlockID, EntryPointOffset)
116
117 type BBStarts = M.Map BlockID Int
118
119 type CompileInfo = (EntryPoint, BBStarts, Int, CMap)
120
121 -- B.ByteString: encoded name: <Class>.<methodname><signature>
122 -- Class Resolved: classfile
123 -- Word16: index of invoke-instruction
124 type MethodInfo = (B.ByteString, Class Resolved, Word16)
125
126 -- Word32 = point of method call in generated code
127 -- MethodInfo = relevant information about callee
128 type CMap = M.Map Word32 MethodInfo
129
130
131 emitFromBB :: Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
132 emitFromBB cls hmap =  do
133         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
134         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
135         ep <- getEntryPoint
136         push ebp
137         mov ebp esp
138
139         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
140         d <- disassemble
141         end <- getCodeOffset
142         return ((ep, bbstarts, end, calls), d)
143   where
144   getLabel :: BlockID -> [(BlockID, Label)] -> Label
145   getLabel _ [] = error "label not found!"
146   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
147
148   efBB :: (BlockID, BasicBlock) -> CMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (CMap, BBStarts)
149   efBB (bid, bb) calls bbstarts lmap =
150         if M.member bid bbstarts then
151           return (calls, bbstarts)
152         else do
153           bb_offset <- getCodeOffset
154           let bbstarts' = M.insert bid bb_offset bbstarts
155           defineLabel $ getLabel bid lmap
156           cs <- mapM emit' $ code bb
157           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
158           case successor bb of
159             Return -> return (calls', bbstarts')
160             FallThrough t -> do
161               efBB (t, hmap M.! t) calls' bbstarts' lmap
162             OneTarget t -> do
163               efBB (t, hmap M.! t) calls' bbstarts' lmap
164             TwoTarget t1 t2 -> do
165               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
166               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
167     -- TODO(bernhard): also use metainformation
168     -- TODO(bernhard): implement `emit' as function which accepts a list of
169     --                 instructions, so we can use patterns for optimizations
170     where
171     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, MethodInfo))
172     emit' (INVOKESTATIC cpidx) = do
173         ep <- getEntryPoint
174         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
175         let l = buildMethodID cls cpidx
176         calladdr <- getCodeOffset
177         let w32_calladdr = w32_ep + (fromIntegral calladdr) :: Word32
178         newNamedLabel (toString l) >>= defineLabel
179         -- causes SIGILL. in the signal handler we patch it to the acutal call.
180         -- place a nop at the end, therefore the disasm doesn't screw up
181         emit32 (0xffffffff :: Word32) >> emit8 (0x90 :: Word8)
182         -- discard arguments (TODO(bernhard): don't hardcode it)
183         add esp (4 :: Word32)
184         -- push result on stack (TODO(bernhard): if any)
185         push eax
186         return $ Just $ (w32_calladdr, (l, cls, cpidx))
187     emit' insn = emit insn >> return Nothing
188
189     emit :: J.Instruction -> CodeGen e s ()
190     emit POP = do -- print dropped value
191         ep <- getEntryPoint
192         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
193         -- '5' is the size of the `call' instruction ( + immediate)
194         calladdr <- getCodeOffset
195         let w32_calladdr = 5 + w32_ep + (fromIntegral calladdr) :: Word32
196         let trapaddr = (fromIntegral getaddr :: Word32)
197         call (trapaddr - w32_calladdr)
198         add esp (4 :: Word32)
199     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
200     emit (ICONST_0) = push (0 :: Word32)
201     emit (ICONST_1) = push (1 :: Word32)
202     emit (ICONST_2) = push (2 :: Word32)
203     emit (ICONST_5) = push (5 :: Word32)
204     emit (ILOAD_ x) = do
205         push (Disp (cArgs_ x), ebp)
206     emit (ISTORE_ x) = do
207         pop eax
208         mov (Disp (cArgs_ x), ebp) eax
209     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
210     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
211     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
212     emit (IINC x imm) = do
213         add (Disp (cArgs x), ebp) (s8_w32 imm)
214
215     emit (IF_ICMP cond _) = do
216         pop eax -- value2
217         pop ebx -- value1
218         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
219         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
220         let l = getLabel sid lmap
221         case cond of
222           C_EQ -> je  l; C_NE -> jne l
223           C_LT -> jl  l; C_GT -> jg  l
224           C_GE -> jge l; C_LE -> jle l
225
226     emit (IF cond _) = do
227         pop eax -- value1
228         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
229         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
230         let l = getLabel sid lmap
231         case cond of
232           C_EQ -> je  l; C_NE -> jne l
233           C_LT -> jl  l; C_GT -> jg  l
234           C_GE -> jge l; C_LE -> jle l
235
236     emit (GOTO _ ) = do
237         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
238         jmp $ getLabel sid lmap
239
240     emit RETURN = do mov esp ebp; pop ebp; ret
241     emit IRETURN = do
242         pop eax
243         mov esp ebp
244         pop ebp
245         ret
246     emit _ = do cmovbe eax eax -- dummy
247
248   cArgs x = (8 + 4 * (fromIntegral x))
249   cArgs_ x = (8 + 4 * case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3)
250
251   -- sign extension from w8 to w32 (over s8)
252   --   unfortunately, hs-java is using Word8 everywhere (while
253   --   it should be Int8 actually)
254   s8_w32 :: Word8 -> Word32
255   s8_w32 w8 = fromIntegral s8
256     where s8 = (fromIntegral w8) :: Int8