modules: move (public) datatypes into a new module
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.ByteString.Lazy as B
10 import Control.Monad
11
12 import Foreign
13 import Foreign.C.Types
14
15 import Text.Printf
16
17 import qualified JVM.Assembler as J
18 import JVM.Assembler hiding (Instruction)
19 import JVM.ClassFile
20 import JVM.Converter
21
22 import Harpy
23 import Harpy.X86Disassembler
24
25 import Mate.BasicBlocks
26 import Mate.Types
27 import Mate.Utilities
28
29 foreign import ccall "dynamic"
30    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
31
32 foreign import ccall "getaddr"
33   getaddr :: CUInt
34
35 foreign import ccall "callertrap"
36   callertrap :: IO ()
37
38 foreign import ccall "register_signal"
39   register_signal :: IO ()
40
41 foreign import ccall "get_cmap"
42   get_cmap :: IO (Ptr ())
43
44 foreign import ccall "set_cmap"
45   set_cmap :: Ptr () -> IO ()
46
47 test_01, test_02, test_03 :: IO ()
48 test_01 = do
49   register_signal
50   (entry, end) <- testCase "./tests/Fib.class" "fib"
51   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
52
53   mapM_ (\x -> do
54     result <- code_int entryFuncPtr x 0
55     let iresult :: Int; iresult = fromIntegral result
56     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
57     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
58     ) $ ([0..10] :: [CInt])
59   printf "patched disasm:\n"
60   Right newdisasm <- disassembleBlock entry end
61   mapM_ (putStrLn . showAtt) newdisasm
62   where
63     fib :: CInt -> Int
64     fib n
65       | n <= 1 = 1
66       | otherwise = (fib (n - 1)) + (fib (n - 2))
67
68
69 test_02 = do
70   (entry,_) <- testCase "./tests/While.class" "f"
71   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
72   result <- code_int entryFuncPtr 5 4
73   let iresult :: Int; iresult = fromIntegral result
74   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
75   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
76
77   result2 <- code_int entryFuncPtr 4 3
78   let iresult2 :: Int; iresult2 = fromIntegral result2
79   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
80   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
81
82
83 test_03 = do
84   (entry,_) <- testCase "./tests/While.class" "g"
85   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
86   result <- code_int entryFuncPtr 5 4
87   let iresult :: Int; iresult = fromIntegral result
88   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
89   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
90
91   result2 <- code_int entryFuncPtr 4 3
92   let iresult2 :: Int; iresult2 = fromIntegral result2
93   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
94   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
95
96
97 testCase :: String -> B.ByteString -> IO (Ptr Word8, Int)
98 testCase cf method = do
99       cls <- parseClassFile cf
100       hmap <- parseMethod cls method
101       printMapBB hmap
102       case hmap of
103         Nothing -> error "sorry, no code generation"
104         Just hmap' -> do
105               let ebb = emitFromBB cls hmap'
106               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
107               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
108               printf "disasm:\n"
109               mapM_ (putStrLn . showAtt) disasm
110               printf "basicblocks addresses:\n"
111               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
112               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
113               return (entry, end)
114
115 type EntryPoint = Ptr Word8
116 type EntryPointOffset = Int
117 type PatchInfo = (BlockID, EntryPointOffset)
118
119 type BBStarts = M.Map BlockID Int
120
121 type CompileInfo = (EntryPoint, BBStarts, Int, CMap)
122
123
124 emitFromBB :: Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
125 emitFromBB cls hmap =  do
126         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
127         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
128         ep <- getEntryPoint
129         push ebp
130         mov ebp esp
131
132         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
133         d <- disassemble
134         end <- getCodeOffset
135         return ((ep, bbstarts, end, calls), d)
136   where
137   getLabel :: BlockID -> [(BlockID, Label)] -> Label
138   getLabel _ [] = error "label not found!"
139   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
140
141   efBB :: (BlockID, BasicBlock) -> CMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (CMap, BBStarts)
142   efBB (bid, bb) calls bbstarts lmap =
143         if M.member bid bbstarts then
144           return (calls, bbstarts)
145         else do
146           bb_offset <- getCodeOffset
147           let bbstarts' = M.insert bid bb_offset bbstarts
148           defineLabel $ getLabel bid lmap
149           cs <- mapM emit' $ code bb
150           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
151           case successor bb of
152             Return -> return (calls', bbstarts')
153             FallThrough t -> do
154               efBB (t, hmap M.! t) calls' bbstarts' lmap
155             OneTarget t -> do
156               efBB (t, hmap M.! t) calls' bbstarts' lmap
157             TwoTarget t1 t2 -> do
158               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
159               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
160     -- TODO(bernhard): also use metainformation
161     -- TODO(bernhard): implement `emit' as function which accepts a list of
162     --                 instructions, so we can use patterns for optimizations
163     where
164     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, MethodInfo))
165     emit' (INVOKESTATIC cpidx) = do
166         ep <- getEntryPoint
167         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
168         let l = buildMethodID cls cpidx
169         calladdr <- getCodeOffset
170         let w32_calladdr = w32_ep + (fromIntegral calladdr) :: Word32
171         newNamedLabel (show l) >>= defineLabel
172         -- causes SIGILL. in the signal handler we patch it to the acutal call.
173         -- place a nop at the end, therefore the disasm doesn't screw up
174         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
175         -- discard arguments on stack
176         let argcnt = (methodGetArgsCount cls cpidx) * 4
177         when (argcnt > 0) (add esp argcnt)
178         -- push result on stack if method has a return value
179         when (methodHaveReturnValue cls cpidx) (push eax)
180         return $ Just $ (w32_calladdr, l)
181     emit' insn = emit insn >> return Nothing
182
183     emit :: J.Instruction -> CodeGen e s ()
184     emit POP = do -- print dropped value
185         ep <- getEntryPoint
186         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
187         -- '5' is the size of the `call' instruction ( + immediate)
188         calladdr <- getCodeOffset
189         let w32_calladdr = 5 + w32_ep + (fromIntegral calladdr) :: Word32
190         let trapaddr = (fromIntegral getaddr :: Word32)
191         call (trapaddr - w32_calladdr)
192         add esp (4 :: Word32)
193     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
194     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
195     emit (ICONST_0) = push (0 :: Word32)
196     emit (ICONST_1) = push (1 :: Word32)
197     emit (ICONST_2) = push (2 :: Word32)
198     emit (ICONST_4) = push (4 :: Word32)
199     emit (ICONST_5) = push (5 :: Word32)
200     emit (ILOAD_ x) = do
201         push (Disp (cArgs_ x), ebp)
202     emit (ISTORE_ x) = do
203         pop eax
204         mov (Disp (cArgs_ x), ebp) eax
205     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
206     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
207     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
208     emit (IINC x imm) = do
209         add (Disp (cArgs x), ebp) (s8_w32 imm)
210
211     emit (IF_ICMP cond _) = do
212         pop eax -- value2
213         pop ebx -- value1
214         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
215         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
216         let l = getLabel sid lmap
217         case cond of
218           C_EQ -> je  l; C_NE -> jne l
219           C_LT -> jl  l; C_GT -> jg  l
220           C_GE -> jge l; C_LE -> jle l
221
222     emit (IF cond _) = do
223         pop eax -- value1
224         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
225         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
226         let l = getLabel sid lmap
227         case cond of
228           C_EQ -> je  l; C_NE -> jne l
229           C_LT -> jl  l; C_GT -> jg  l
230           C_GE -> jge l; C_LE -> jle l
231
232     emit (GOTO _ ) = do
233         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
234         jmp $ getLabel sid lmap
235
236     emit RETURN = do mov esp ebp; pop ebp; ret
237     emit IRETURN = do
238         pop eax
239         mov esp ebp
240         pop ebp
241         ret
242     emit _ = do cmovbe eax eax -- dummy
243
244   cArgs x = (8 + 4 * (fromIntegral x))
245   cArgs_ x = (8 + 4 * case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3)
246
247   -- sign extension from w8 to w32 (over s8)
248   --   unfortunately, hs-java is using Word8 everywhere (while
249   --   it should be Int8 actually)
250   s8_w32 :: Word8 -> Word32
251   s8_w32 w8 = fromIntegral s8
252     where s8 = (fromIntegral w8) :: Int8