2db7a979a0ad5730748325ccdca3b4ad3a1b7ba9
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.Set as S
10 import qualified Data.ByteString.Lazy as B
11 import Control.Monad
12
13 import Foreign
14 import Foreign.C.Types
15
16 import Text.Printf
17
18 import qualified JVM.Assembler as J
19 import JVM.Assembler hiding (Instruction)
20 import JVM.ClassFile
21
22 import Harpy
23 import Harpy.X86Disassembler
24
25 import Mate.BasicBlocks
26 import Mate.Types
27 import Mate.Utilities
28 import Mate.ClassPool
29
30 foreign import ccall "dynamic"
31    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
32
33 foreign import ccall "getaddr"
34   getaddr :: CUInt
35
36 foreign import ccall "getMallocAddr"
37   getMallocAddr :: CUInt
38
39 foreign import ccall "callertrap"
40   callertrap :: IO ()
41
42 foreign import ccall "register_signal"
43   register_signal :: IO ()
44
45 test_01, test_02, test_03 :: IO ()
46 test_01 = do
47   register_signal
48   (entry, end) <- testCase "./tests/Fib" "fib"
49   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
50
51   mapM_ (\x -> do
52     result <- code_int entryFuncPtr x 0
53     let iresult :: Int; iresult = fromIntegral result
54     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
55     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
56     ) $ ([0..10] :: [CInt])
57   printf "patched disasm:\n"
58   Right newdisasm <- disassembleBlock entry end
59   mapM_ (putStrLn . showAtt) newdisasm
60   where
61     fib :: CInt -> Int
62     fib n
63       | n <= 1 = 1
64       | otherwise = (fib (n - 1)) + (fib (n - 2))
65
66
67 test_02 = do
68   (entry,_) <- testCase "./tests/While" "f"
69   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
70   result <- code_int entryFuncPtr 5 4
71   let iresult :: Int; iresult = fromIntegral result
72   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
73   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
74
75   result2 <- code_int entryFuncPtr 4 3
76   let iresult2 :: Int; iresult2 = fromIntegral result2
77   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
78   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
79
80
81 test_03 = do
82   (entry,_) <- testCase "./tests/While" "g"
83   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
84   result <- code_int entryFuncPtr 5 4
85   let iresult :: Int; iresult = fromIntegral result
86   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
87   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
88
89   result2 <- code_int entryFuncPtr 4 3
90   let iresult2 :: Int; iresult2 = fromIntegral result2
91   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
92   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
93
94
95 testCase :: B.ByteString -> B.ByteString -> IO (Ptr Word8, Int)
96 testCase cf method = do
97       cls <- getClassFile cf
98       hmap <- parseMethod cls method
99       printMapBB hmap
100       case hmap of
101         Nothing -> error "sorry, no code generation"
102         Just hmap' -> do
103               let ebb = emitFromBB method cls hmap'
104               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
105               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
106               printf "disasm:\n"
107               mapM_ (putStrLn . showAtt) disasm
108               printf "basicblocks addresses:\n"
109               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
110               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
111               return (entry, end)
112
113 type EntryPoint = Ptr Word8
114 type EntryPointOffset = Int
115 type PatchInfo = (BlockID, EntryPointOffset)
116
117 type BBStarts = M.Map BlockID Int
118
119 type CompileInfo = (EntryPoint, BBStarts, Int, TMap)
120
121
122 emitFromBB :: B.ByteString -> Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
123 emitFromBB method cls hmap =  do
124         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
125         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
126         ep <- getEntryPoint
127         push ebp
128         mov ebp esp
129         -- TODO(bernhard): determine a reasonable value.
130         --                 e.g. (locals used) * 4
131         sub esp (0x60 :: Word32)
132
133         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
134         d <- disassemble
135         end <- getCodeOffset
136         return ((ep, bbstarts, end, calls), d)
137   where
138   getLabel :: BlockID -> [(BlockID, Label)] -> Label
139   getLabel _ [] = error "label not found!"
140   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
141
142   efBB :: (BlockID, BasicBlock) -> TMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (TMap, BBStarts)
143   efBB (bid, bb) calls bbstarts lmap =
144         if M.member bid bbstarts then
145           return (calls, bbstarts)
146         else do
147           bb_offset <- getCodeOffset
148           let bbstarts' = M.insert bid bb_offset bbstarts
149           defineLabel $ getLabel bid lmap
150           cs <- mapM emit' $ code bb
151           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
152           case successor bb of
153             Return -> return (calls', bbstarts')
154             FallThrough t -> do
155               efBB (t, hmap M.! t) calls' bbstarts' lmap
156             OneTarget t -> do
157               efBB (t, hmap M.! t) calls' bbstarts' lmap
158             TwoTarget t1 t2 -> do
159               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
160               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
161     -- TODO(bernhard): also use metainformation
162     -- TODO(bernhard): implement `emit' as function which accepts a list of
163     --                 instructions, so we can use patterns for optimizations
164     where
165     getCurrentOffset :: CodeGen e s (Word32)
166     getCurrentOffset = do
167       ep <- getEntryPoint
168       let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
169       offset <- getCodeOffset
170       return $ w32_ep + (fromIntegral offset)
171
172     emitInvoke :: Word16 -> Bool -> CodeGen e s (Maybe (Word32, TrapInfo))
173     emitInvoke cpidx hasThis = do
174         let l = buildMethodID cls cpidx
175         calladdr <- getCurrentOffset
176         newNamedLabel (show l) >>= defineLabel
177         -- causes SIGILL. in the signal handler we patch it to the acutal call.
178         -- place a nop at the end, therefore the disasm doesn't screw up
179         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
180         -- discard arguments on stack
181         let argcnt = ((if hasThis then 1 else 0) + (methodGetArgsCount cls cpidx)) * 4
182         when (argcnt > 0) (add esp argcnt)
183         -- push result on stack if method has a return value
184         when (methodHaveReturnValue cls cpidx) (push eax)
185         return $ Just $ (calladdr, MI l)
186
187     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, TrapInfo))
188     emit' (INVOKESPECIAL cpidx) = emitInvoke cpidx True
189     emit' (INVOKESTATIC cpidx) = emitInvoke cpidx False
190     emit' (INVOKEVIRTUAL cpidx) = do
191         -- get methodInfo entry
192         let mi@(MethodInfo methodname objname msig@(MethodSignature args _))  = buildMethodID cls cpidx
193         newNamedLabel (show mi) >>= defineLabel
194         -- objref lives somewhere on the argument stack
195         mov eax (Disp ((*4) $ fromIntegral $ length args), esp)
196         -- get method-table-ptr
197         mov eax (Disp 0, eax)
198         -- get method offset
199         let nameAndSig = methodname `B.append` (encode msig)
200         let offset = unsafePerformIO $ getMethodOffset objname nameAndSig
201         -- make actual (indirect) call
202         calladdr <- getCurrentOffset
203         call (Disp offset, eax)
204         -- discard arguments on stack (+4 for "this")
205         let argcnt = 4 + ((methodGetArgsCount cls cpidx) * 4)
206         when (argcnt > 0) (add esp argcnt)
207         -- push result on stack if method has a return value
208         when (methodHaveReturnValue cls cpidx) (push eax)
209         -- note, the "mi" has the wrong class reference here.
210         -- we figure that out at run-time, in the methodpool,
211         -- depending on the method-table-ptr
212         return $ Just $ (calladdr, VI mi)
213     emit' (PUTSTATIC cpidx) = do
214         pop eax
215         trapaddr <- getCurrentOffset
216         mov (Addr 0x00000000) eax -- it's a trap
217         return $ Just $ (trapaddr, SFI $ buildStaticFieldID cls cpidx)
218     emit' (GETSTATIC cpidx) = do
219         trapaddr <- getCurrentOffset
220         mov eax (Addr 0x00000000) -- it's a trap
221         push eax
222         return $ Just $ (trapaddr, SFI $ buildStaticFieldID cls cpidx)
223     emit' insn = emit insn >> return Nothing
224
225     emit :: J.Instruction -> CodeGen e s ()
226     emit POP = do -- print dropped value
227         calladdr <- getCurrentOffset
228         -- '5' is the size of the `call' instruction ( + immediate)
229         let w32_calladdr = 5 + calladdr
230         let trapaddr = (fromIntegral getaddr :: Word32)
231         call (trapaddr - w32_calladdr)
232         add esp (4 :: Word32)
233     emit DUP = push (Disp 0, esp)
234     emit (NEW objidx) = do
235         let objname = buildClassID cls objidx
236         let amount = unsafePerformIO $ getMethodSize objname
237         push (amount :: Word32)
238         calladdr <- getCurrentOffset
239         let w32_calladdr = 5 + calladdr
240         let malloaddr = (fromIntegral getMallocAddr :: Word32)
241         call (malloaddr - w32_calladdr)
242         add esp (4 :: Word32)
243         push eax
244         -- TODO(bernhard): save reference somewhere for GC
245         -- set method table pointer
246         let mtable = unsafePerformIO $ getMethodTable objname
247         mov (Disp 0, eax) mtable
248     emit (CHECKCAST _) = nop -- TODO(bernhard): ...
249     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
250     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
251     emit (ICONST_0) = push (0 :: Word32)
252     emit (ICONST_1) = push (1 :: Word32)
253     emit (ICONST_2) = push (2 :: Word32)
254     emit (ICONST_4) = push (4 :: Word32)
255     emit (ICONST_5) = push (5 :: Word32)
256     emit (ALOAD_ x) = emit (ILOAD_ x)
257     emit (ILOAD_ x) = do
258         push (Disp (cArgs_ x), ebp)
259     emit (ALOAD x) = emit (ILOAD x)
260     emit (ILOAD x) = do
261         push (Disp (cArgs x), ebp)
262     emit (ASTORE_ x) = emit (ISTORE_ x)
263     emit (ISTORE_ x) = do
264         pop eax
265         mov (Disp (cArgs_ x), ebp) eax
266     emit (ASTORE x) = emit (ISTORE x)
267     emit (ISTORE x) = do
268         pop eax
269         mov (Disp (cArgs x), ebp) eax
270
271     emit (GETFIELD x) = do
272         pop eax -- this pointer
273         let (cname, fname) = buildFieldOffset cls x
274         let offset = unsafePerformIO $ getFieldOffset cname fname
275         push (Disp (fromIntegral $ offset * 4), eax) -- get field
276     emit (PUTFIELD x) = do
277         pop ebx -- value to write
278         pop eax -- this pointer
279         let (cname, fname) = buildFieldOffset cls x
280         let offset = unsafePerformIO $ getFieldOffset cname fname
281         mov (Disp (fromIntegral $ offset * 4), eax) ebx -- set field
282
283     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
284     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
285     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
286     emit (IINC x imm) = do
287         add (Disp (cArgs x), ebp) (s8_w32 imm)
288
289     emit (IF_ICMP cond _) = do
290         pop eax -- value2
291         pop ebx -- value1
292         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
293         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
294         let l = getLabel sid lmap
295         case cond of
296           C_EQ -> je  l; C_NE -> jne l
297           C_LT -> jl  l; C_GT -> jg  l
298           C_GE -> jge l; C_LE -> jle l
299
300     emit (IF cond _) = do
301         pop eax -- value1
302         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
303         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
304         let l = getLabel sid lmap
305         case cond of
306           C_EQ -> je  l; C_NE -> jne l
307           C_LT -> jl  l; C_GT -> jg  l
308           C_GE -> jge l; C_LE -> jle l
309
310     emit (GOTO _ ) = do
311         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
312         jmp $ getLabel sid lmap
313
314     emit RETURN = do mov esp ebp; pop ebp; ret
315     emit ARETURN = emit IRETURN
316     emit IRETURN = do
317         pop eax
318         mov esp ebp
319         pop ebp
320         ret
321     emit invalid = error $ "insn not implemented yet: " ++ (show invalid)
322
323   -- for locals we use a different storage
324   cArgs :: Word8 -> Word32
325   cArgs x = if (x' >= thisMethodArgCnt)
326       -- TODO(bernhard): maybe s/(-4)/(-8)/
327       then fromIntegral $ (-4) * (x' - thisMethodArgCnt + 1)
328       else 4 + (thisMethodArgCnt * 4) - (4 * x')
329     where x' = fromIntegral x
330
331   cArgs_ :: IMM -> Word32
332   cArgs_ x = cArgs $ case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3
333
334   thisMethodArgCnt :: Word32
335   thisMethodArgCnt = isNonStatic + (fromIntegral $ length args)
336     where
337     (Just m) = lookupMethod method cls
338     (MethodSignature args _) = methodSignature m
339     isNonStatic = if S.member ACC_STATIC (methodAccessFlags m)
340         then 0
341         else 1 -- one argument for the this pointer
342
343
344   -- sign extension from w8 to w32 (over s8)
345   --   unfortunately, hs-java is using Word8 everywhere (while
346   --   it should be Int8 actually)
347   s8_w32 :: Word8 -> Word32
348   s8_w32 w8 = fromIntegral s8
349     where s8 = (fromIntegral w8) :: Int8