0891d546459f947007ccbf12bb8d3faf49a9f0fa
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.ByteString.Lazy as B
10 import Control.Monad
11
12 import Foreign
13 import Foreign.C.Types
14
15 import Text.Printf
16
17 import qualified JVM.Assembler as J
18 import JVM.Assembler hiding (Instruction)
19 import JVM.ClassFile
20
21 import Harpy
22 import Harpy.X86Disassembler
23
24 import Mate.BasicBlocks
25 import Mate.Types
26 import Mate.Utilities
27 import Mate.ClassPool
28
29 foreign import ccall "dynamic"
30    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
31
32 foreign import ccall "getaddr"
33   getaddr :: CUInt
34
35 foreign import ccall "callertrap"
36   callertrap :: IO ()
37
38 foreign import ccall "register_signal"
39   register_signal :: IO ()
40
41 test_01, test_02, test_03 :: IO ()
42 test_01 = do
43   register_signal
44   (entry, end) <- testCase "./tests/Fib" "fib"
45   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
46
47   mapM_ (\x -> do
48     result <- code_int entryFuncPtr x 0
49     let iresult :: Int; iresult = fromIntegral result
50     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
51     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
52     ) $ ([0..10] :: [CInt])
53   printf "patched disasm:\n"
54   Right newdisasm <- disassembleBlock entry end
55   mapM_ (putStrLn . showAtt) newdisasm
56   where
57     fib :: CInt -> Int
58     fib n
59       | n <= 1 = 1
60       | otherwise = (fib (n - 1)) + (fib (n - 2))
61
62
63 test_02 = do
64   (entry,_) <- testCase "./tests/While" "f"
65   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
66   result <- code_int entryFuncPtr 5 4
67   let iresult :: Int; iresult = fromIntegral result
68   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
69   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
70
71   result2 <- code_int entryFuncPtr 4 3
72   let iresult2 :: Int; iresult2 = fromIntegral result2
73   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
74   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
75
76
77 test_03 = do
78   (entry,_) <- testCase "./tests/While" "g"
79   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
80   result <- code_int entryFuncPtr 5 4
81   let iresult :: Int; iresult = fromIntegral result
82   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
83   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
84
85   result2 <- code_int entryFuncPtr 4 3
86   let iresult2 :: Int; iresult2 = fromIntegral result2
87   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
88   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
89
90
91 testCase :: B.ByteString -> B.ByteString -> IO (Ptr Word8, Int)
92 testCase cf method = do
93       cls <- getClassFile cf
94       hmap <- parseMethod cls method
95       printMapBB hmap
96       case hmap of
97         Nothing -> error "sorry, no code generation"
98         Just hmap' -> do
99               let ebb = emitFromBB cls hmap'
100               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
101               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
102               printf "disasm:\n"
103               mapM_ (putStrLn . showAtt) disasm
104               printf "basicblocks addresses:\n"
105               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
106               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
107               return (entry, end)
108
109 type EntryPoint = Ptr Word8
110 type EntryPointOffset = Int
111 type PatchInfo = (BlockID, EntryPointOffset)
112
113 type BBStarts = M.Map BlockID Int
114
115 type CompileInfo = (EntryPoint, BBStarts, Int, TMap)
116
117
118 emitFromBB :: Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
119 emitFromBB cls hmap =  do
120         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
121         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
122         ep <- getEntryPoint
123         push ebp
124         mov ebp esp
125
126         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
127         d <- disassemble
128         end <- getCodeOffset
129         return ((ep, bbstarts, end, calls), d)
130   where
131   getLabel :: BlockID -> [(BlockID, Label)] -> Label
132   getLabel _ [] = error "label not found!"
133   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
134
135   efBB :: (BlockID, BasicBlock) -> TMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (TMap, BBStarts)
136   efBB (bid, bb) calls bbstarts lmap =
137         if M.member bid bbstarts then
138           return (calls, bbstarts)
139         else do
140           bb_offset <- getCodeOffset
141           let bbstarts' = M.insert bid bb_offset bbstarts
142           defineLabel $ getLabel bid lmap
143           cs <- mapM emit' $ code bb
144           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
145           case successor bb of
146             Return -> return (calls', bbstarts')
147             FallThrough t -> do
148               efBB (t, hmap M.! t) calls' bbstarts' lmap
149             OneTarget t -> do
150               efBB (t, hmap M.! t) calls' bbstarts' lmap
151             TwoTarget t1 t2 -> do
152               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
153               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
154     -- TODO(bernhard): also use metainformation
155     -- TODO(bernhard): implement `emit' as function which accepts a list of
156     --                 instructions, so we can use patterns for optimizations
157     where
158     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, TrapInfo))
159     emit' (INVOKESTATIC cpidx) = do
160         ep <- getEntryPoint
161         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
162         let l = buildMethodID cls cpidx
163         calladdr <- getCodeOffset
164         let w32_calladdr = w32_ep + (fromIntegral calladdr) :: Word32
165         newNamedLabel (show l) >>= defineLabel
166         -- causes SIGILL. in the signal handler we patch it to the acutal call.
167         -- place a nop at the end, therefore the disasm doesn't screw up
168         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
169         -- discard arguments on stack
170         let argcnt = (methodGetArgsCount cls cpidx) * 4
171         when (argcnt > 0) (add esp argcnt)
172         -- push result on stack if method has a return value
173         when (methodHaveReturnValue cls cpidx) (push eax)
174         return $ Just $ (w32_calladdr, MI l)
175     emit' (PUTSTATIC cpidx) = do
176         pop eax
177         ep <- getEntryPoint
178         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
179         trapaddr <- getCodeOffset
180         let w32_trapaddr = w32_ep + (fromIntegral trapaddr)
181         mov (Addr 0x00000000) eax -- it's a trap
182         return $ Just $ (w32_trapaddr, SFI $ buildFieldID cls cpidx)
183     emit' (GETSTATIC cpidx) = do
184         ep <- getEntryPoint
185         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
186         trapaddr <- getCodeOffset
187         let w32_trapaddr = w32_ep + (fromIntegral trapaddr)
188         mov eax (Addr 0x00000000) -- it's a trap
189         push eax
190         return $ Just $ (w32_trapaddr, SFI $ buildFieldID cls cpidx)
191     emit' insn = emit insn >> return Nothing
192
193     emit :: J.Instruction -> CodeGen e s ()
194     emit POP = do -- print dropped value
195         ep <- getEntryPoint
196         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
197         -- '5' is the size of the `call' instruction ( + immediate)
198         calladdr <- getCodeOffset
199         let w32_calladdr = 5 + w32_ep + (fromIntegral calladdr) :: Word32
200         let trapaddr = (fromIntegral getaddr :: Word32)
201         call (trapaddr - w32_calladdr)
202         add esp (4 :: Word32)
203     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
204     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
205     emit (ICONST_0) = push (0 :: Word32)
206     emit (ICONST_1) = push (1 :: Word32)
207     emit (ICONST_2) = push (2 :: Word32)
208     emit (ICONST_4) = push (4 :: Word32)
209     emit (ICONST_5) = push (5 :: Word32)
210     emit (ILOAD_ x) = do
211         push (Disp (cArgs_ x), ebp)
212     emit (ISTORE_ x) = do
213         pop eax
214         mov (Disp (cArgs_ x), ebp) eax
215     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
216     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
217     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
218     emit (IINC x imm) = do
219         add (Disp (cArgs x), ebp) (s8_w32 imm)
220
221     emit (IF_ICMP cond _) = do
222         pop eax -- value2
223         pop ebx -- value1
224         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
225         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
226         let l = getLabel sid lmap
227         case cond of
228           C_EQ -> je  l; C_NE -> jne l
229           C_LT -> jl  l; C_GT -> jg  l
230           C_GE -> jge l; C_LE -> jle l
231
232     emit (IF cond _) = do
233         pop eax -- value1
234         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
235         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
236         let l = getLabel sid lmap
237         case cond of
238           C_EQ -> je  l; C_NE -> jne l
239           C_LT -> jl  l; C_GT -> jg  l
240           C_GE -> jge l; C_LE -> jle l
241
242     emit (GOTO _ ) = do
243         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
244         jmp $ getLabel sid lmap
245
246     emit RETURN = do mov esp ebp; pop ebp; ret
247     emit IRETURN = do
248         pop eax
249         mov esp ebp
250         pop ebp
251         ret
252     emit invalid = error $ "insn not implemented yet: " ++ (show invalid)
253
254   cArgs x = (8 + 4 * (fromIntegral x))
255   cArgs_ x = (8 + 4 * case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3)
256
257   -- sign extension from w8 to w32 (over s8)
258   --   unfortunately, hs-java is using Word8 everywhere (while
259   --   it should be Int8 actually)
260   s8_w32 :: Word8 -> Word32
261   s8_w32 w8 = fromIntegral s8
262     where s8 = (fromIntegral w8) :: Int8