codegen: factor offset calculation
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.ByteString.Lazy as B
10 import Control.Monad
11
12 import Foreign
13 import Foreign.C.Types
14
15 import Text.Printf
16
17 import qualified JVM.Assembler as J
18 import JVM.Assembler hiding (Instruction)
19 import JVM.ClassFile
20
21 import Harpy
22 import Harpy.X86Disassembler
23
24 import Mate.BasicBlocks
25 import Mate.Types
26 import Mate.Utilities
27 import Mate.ClassPool
28
29 foreign import ccall "dynamic"
30    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
31
32 foreign import ccall "getaddr"
33   getaddr :: CUInt
34
35 foreign import ccall "callertrap"
36   callertrap :: IO ()
37
38 foreign import ccall "register_signal"
39   register_signal :: IO ()
40
41 test_01, test_02, test_03 :: IO ()
42 test_01 = do
43   register_signal
44   (entry, end) <- testCase "./tests/Fib" "fib"
45   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
46
47   mapM_ (\x -> do
48     result <- code_int entryFuncPtr x 0
49     let iresult :: Int; iresult = fromIntegral result
50     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
51     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
52     ) $ ([0..10] :: [CInt])
53   printf "patched disasm:\n"
54   Right newdisasm <- disassembleBlock entry end
55   mapM_ (putStrLn . showAtt) newdisasm
56   where
57     fib :: CInt -> Int
58     fib n
59       | n <= 1 = 1
60       | otherwise = (fib (n - 1)) + (fib (n - 2))
61
62
63 test_02 = do
64   (entry,_) <- testCase "./tests/While" "f"
65   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
66   result <- code_int entryFuncPtr 5 4
67   let iresult :: Int; iresult = fromIntegral result
68   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
69   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
70
71   result2 <- code_int entryFuncPtr 4 3
72   let iresult2 :: Int; iresult2 = fromIntegral result2
73   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
74   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
75
76
77 test_03 = do
78   (entry,_) <- testCase "./tests/While" "g"
79   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
80   result <- code_int entryFuncPtr 5 4
81   let iresult :: Int; iresult = fromIntegral result
82   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
83   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
84
85   result2 <- code_int entryFuncPtr 4 3
86   let iresult2 :: Int; iresult2 = fromIntegral result2
87   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
88   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
89
90
91 testCase :: B.ByteString -> B.ByteString -> IO (Ptr Word8, Int)
92 testCase cf method = do
93       cls <- getClassFile cf
94       hmap <- parseMethod cls method
95       printMapBB hmap
96       case hmap of
97         Nothing -> error "sorry, no code generation"
98         Just hmap' -> do
99               let ebb = emitFromBB cls hmap'
100               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
101               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
102               printf "disasm:\n"
103               mapM_ (putStrLn . showAtt) disasm
104               printf "basicblocks addresses:\n"
105               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
106               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
107               return (entry, end)
108
109 type EntryPoint = Ptr Word8
110 type EntryPointOffset = Int
111 type PatchInfo = (BlockID, EntryPointOffset)
112
113 type BBStarts = M.Map BlockID Int
114
115 type CompileInfo = (EntryPoint, BBStarts, Int, TMap)
116
117
118 emitFromBB :: Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
119 emitFromBB cls hmap =  do
120         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
121         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
122         ep <- getEntryPoint
123         push ebp
124         mov ebp esp
125
126         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
127         d <- disassemble
128         end <- getCodeOffset
129         return ((ep, bbstarts, end, calls), d)
130   where
131   getLabel :: BlockID -> [(BlockID, Label)] -> Label
132   getLabel _ [] = error "label not found!"
133   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
134
135   efBB :: (BlockID, BasicBlock) -> TMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (TMap, BBStarts)
136   efBB (bid, bb) calls bbstarts lmap =
137         if M.member bid bbstarts then
138           return (calls, bbstarts)
139         else do
140           bb_offset <- getCodeOffset
141           let bbstarts' = M.insert bid bb_offset bbstarts
142           defineLabel $ getLabel bid lmap
143           cs <- mapM emit' $ code bb
144           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
145           case successor bb of
146             Return -> return (calls', bbstarts')
147             FallThrough t -> do
148               efBB (t, hmap M.! t) calls' bbstarts' lmap
149             OneTarget t -> do
150               efBB (t, hmap M.! t) calls' bbstarts' lmap
151             TwoTarget t1 t2 -> do
152               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
153               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
154     -- TODO(bernhard): also use metainformation
155     -- TODO(bernhard): implement `emit' as function which accepts a list of
156     --                 instructions, so we can use patterns for optimizations
157     where
158     getCurrentOffset :: CodeGen e s (Word32)
159     getCurrentOffset = do
160       ep <- getEntryPoint
161       let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
162       offset <- getCodeOffset
163       return $ w32_ep + (fromIntegral offset)
164
165     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, TrapInfo))
166     emit' (INVOKESTATIC cpidx) = do
167         let l = buildMethodID cls cpidx
168         calladdr <- getCurrentOffset
169         newNamedLabel (show l) >>= defineLabel
170         -- causes SIGILL. in the signal handler we patch it to the acutal call.
171         -- place a nop at the end, therefore the disasm doesn't screw up
172         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
173         -- discard arguments on stack
174         let argcnt = (methodGetArgsCount cls cpidx) * 4
175         when (argcnt > 0) (add esp argcnt)
176         -- push result on stack if method has a return value
177         when (methodHaveReturnValue cls cpidx) (push eax)
178         return $ Just $ (calladdr, MI l)
179     emit' (PUTSTATIC cpidx) = do
180         pop eax
181         trapaddr <- getCurrentOffset
182         mov (Addr 0x00000000) eax -- it's a trap
183         return $ Just $ (trapaddr, SFI $ buildFieldID cls cpidx)
184     emit' (GETSTATIC cpidx) = do
185         trapaddr <- getCurrentOffset
186         mov eax (Addr 0x00000000) -- it's a trap
187         push eax
188         return $ Just $ (trapaddr, SFI $ buildFieldID cls cpidx)
189     emit' insn = emit insn >> return Nothing
190
191     emit :: J.Instruction -> CodeGen e s ()
192     emit POP = do -- print dropped value
193         calladdr <- getCurrentOffset
194         -- '5' is the size of the `call' instruction ( + immediate)
195         let w32_calladdr = 5 + calladdr
196         let trapaddr = (fromIntegral getaddr :: Word32)
197         call (trapaddr - w32_calladdr)
198         add esp (4 :: Word32)
199     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
200     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
201     emit (ICONST_0) = push (0 :: Word32)
202     emit (ICONST_1) = push (1 :: Word32)
203     emit (ICONST_2) = push (2 :: Word32)
204     emit (ICONST_4) = push (4 :: Word32)
205     emit (ICONST_5) = push (5 :: Word32)
206     emit (ILOAD_ x) = do
207         push (Disp (cArgs_ x), ebp)
208     emit (ISTORE_ x) = do
209         pop eax
210         mov (Disp (cArgs_ x), ebp) eax
211     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
212     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
213     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
214     emit (IINC x imm) = do
215         add (Disp (cArgs x), ebp) (s8_w32 imm)
216
217     emit (IF_ICMP cond _) = do
218         pop eax -- value2
219         pop ebx -- value1
220         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
221         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
222         let l = getLabel sid lmap
223         case cond of
224           C_EQ -> je  l; C_NE -> jne l
225           C_LT -> jl  l; C_GT -> jg  l
226           C_GE -> jge l; C_LE -> jle l
227
228     emit (IF cond _) = do
229         pop eax -- value1
230         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
231         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
232         let l = getLabel sid lmap
233         case cond of
234           C_EQ -> je  l; C_NE -> jne l
235           C_LT -> jl  l; C_GT -> jg  l
236           C_GE -> jge l; C_LE -> jle l
237
238     emit (GOTO _ ) = do
239         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
240         jmp $ getLabel sid lmap
241
242     emit RETURN = do mov esp ebp; pop ebp; ret
243     emit IRETURN = do
244         pop eax
245         mov esp ebp
246         pop ebp
247         ret
248     emit invalid = error $ "insn not implemented yet: " ++ (show invalid)
249
250   cArgs x = (8 + 4 * (fromIntegral x))
251   cArgs_ x = (8 + 4 * case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3)
252
253   -- sign extension from w8 to w32 (over s8)
254   --   unfortunately, hs-java is using Word8 everywhere (while
255   --   it should be Int8 actually)
256   s8_w32 :: Word8 -> Word32
257   s8_w32 w8 = fromIntegral s8
258     where s8 = (fromIntegral w8) :: Int8