fields: use offsets from ClassInfo in codegen
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.Set as S
10 import qualified Data.ByteString.Lazy as B
11 import Control.Monad
12
13 import Foreign
14 import Foreign.C.Types
15
16 import Text.Printf
17
18 import qualified JVM.Assembler as J
19 import JVM.Assembler hiding (Instruction)
20 import JVM.ClassFile
21
22 import Harpy
23 import Harpy.X86Disassembler
24
25 import Mate.BasicBlocks
26 import Mate.Types
27 import Mate.Utilities
28 import Mate.ClassPool
29
30 foreign import ccall "dynamic"
31    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
32
33 foreign import ccall "getaddr"
34   getaddr :: CUInt
35
36 foreign import ccall "getMallocAddr"
37   getMallocAddr :: CUInt
38
39 foreign import ccall "callertrap"
40   callertrap :: IO ()
41
42 foreign import ccall "register_signal"
43   register_signal :: IO ()
44
45 test_01, test_02, test_03 :: IO ()
46 test_01 = do
47   register_signal
48   (entry, end) <- testCase "./tests/Fib" "fib"
49   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
50
51   mapM_ (\x -> do
52     result <- code_int entryFuncPtr x 0
53     let iresult :: Int; iresult = fromIntegral result
54     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
55     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
56     ) $ ([0..10] :: [CInt])
57   printf "patched disasm:\n"
58   Right newdisasm <- disassembleBlock entry end
59   mapM_ (putStrLn . showAtt) newdisasm
60   where
61     fib :: CInt -> Int
62     fib n
63       | n <= 1 = 1
64       | otherwise = (fib (n - 1)) + (fib (n - 2))
65
66
67 test_02 = do
68   (entry,_) <- testCase "./tests/While" "f"
69   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
70   result <- code_int entryFuncPtr 5 4
71   let iresult :: Int; iresult = fromIntegral result
72   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
73   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
74
75   result2 <- code_int entryFuncPtr 4 3
76   let iresult2 :: Int; iresult2 = fromIntegral result2
77   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
78   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
79
80
81 test_03 = do
82   (entry,_) <- testCase "./tests/While" "g"
83   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
84   result <- code_int entryFuncPtr 5 4
85   let iresult :: Int; iresult = fromIntegral result
86   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
87   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
88
89   result2 <- code_int entryFuncPtr 4 3
90   let iresult2 :: Int; iresult2 = fromIntegral result2
91   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
92   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
93
94
95 testCase :: B.ByteString -> B.ByteString -> IO (Ptr Word8, Int)
96 testCase cf method = do
97       cls <- getClassFile cf
98       hmap <- parseMethod cls method
99       printMapBB hmap
100       case hmap of
101         Nothing -> error "sorry, no code generation"
102         Just hmap' -> do
103               let ebb = emitFromBB method cls hmap'
104               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
105               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
106               printf "disasm:\n"
107               mapM_ (putStrLn . showAtt) disasm
108               printf "basicblocks addresses:\n"
109               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
110               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
111               return (entry, end)
112
113 type EntryPoint = Ptr Word8
114 type EntryPointOffset = Int
115 type PatchInfo = (BlockID, EntryPointOffset)
116
117 type BBStarts = M.Map BlockID Int
118
119 type CompileInfo = (EntryPoint, BBStarts, Int, TMap)
120
121
122 emitFromBB :: B.ByteString -> Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
123 emitFromBB method cls hmap =  do
124         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
125         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
126         ep <- getEntryPoint
127         push ebp
128         mov ebp esp
129         -- TODO(bernhard): determine a reasonable value.
130         --                 e.g. (locals used) * 4
131         sub esp (0x60 :: Word32)
132
133         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
134         d <- disassemble
135         end <- getCodeOffset
136         return ((ep, bbstarts, end, calls), d)
137   where
138   getLabel :: BlockID -> [(BlockID, Label)] -> Label
139   getLabel _ [] = error "label not found!"
140   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
141
142   efBB :: (BlockID, BasicBlock) -> TMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (TMap, BBStarts)
143   efBB (bid, bb) calls bbstarts lmap =
144         if M.member bid bbstarts then
145           return (calls, bbstarts)
146         else do
147           bb_offset <- getCodeOffset
148           let bbstarts' = M.insert bid bb_offset bbstarts
149           defineLabel $ getLabel bid lmap
150           cs <- mapM emit' $ code bb
151           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
152           case successor bb of
153             Return -> return (calls', bbstarts')
154             FallThrough t -> do
155               efBB (t, hmap M.! t) calls' bbstarts' lmap
156             OneTarget t -> do
157               efBB (t, hmap M.! t) calls' bbstarts' lmap
158             TwoTarget t1 t2 -> do
159               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
160               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
161     -- TODO(bernhard): also use metainformation
162     -- TODO(bernhard): implement `emit' as function which accepts a list of
163     --                 instructions, so we can use patterns for optimizations
164     where
165     getCurrentOffset :: CodeGen e s (Word32)
166     getCurrentOffset = do
167       ep <- getEntryPoint
168       let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
169       offset <- getCodeOffset
170       return $ w32_ep + (fromIntegral offset)
171
172     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, TrapInfo))
173     emit' (INVOKESPECIAL cpidx) = emit' (INVOKESTATIC cpidx)
174     emit' (INVOKESTATIC cpidx) = do
175         let l = buildMethodID cls cpidx
176         calladdr <- getCurrentOffset
177         newNamedLabel (show l) >>= defineLabel
178         -- causes SIGILL. in the signal handler we patch it to the acutal call.
179         -- place a nop at the end, therefore the disasm doesn't screw up
180         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
181         -- discard arguments on stack
182         let argcnt = (methodGetArgsCount cls cpidx) * 4
183         when (argcnt > 0) (add esp argcnt)
184         -- push result on stack if method has a return value
185         when (methodHaveReturnValue cls cpidx) (push eax)
186         return $ Just $ (calladdr, MI l)
187     emit' (PUTSTATIC cpidx) = do
188         pop eax
189         trapaddr <- getCurrentOffset
190         mov (Addr 0x00000000) eax -- it's a trap
191         return $ Just $ (trapaddr, SFI $ buildStaticFieldID cls cpidx)
192     emit' (GETSTATIC cpidx) = do
193         trapaddr <- getCurrentOffset
194         mov eax (Addr 0x00000000) -- it's a trap
195         push eax
196         return $ Just $ (trapaddr, SFI $ buildStaticFieldID cls cpidx)
197     emit' insn = emit insn >> return Nothing
198
199     emit :: J.Instruction -> CodeGen e s ()
200     emit POP = do -- print dropped value
201         calladdr <- getCurrentOffset
202         -- '5' is the size of the `call' instruction ( + immediate)
203         let w32_calladdr = 5 + calladdr
204         let trapaddr = (fromIntegral getaddr :: Word32)
205         call (trapaddr - w32_calladdr)
206         add esp (4 :: Word32)
207     emit DUP = pop (Disp 0, esp)
208     emit (NEW objidx) = do
209         -- TODO(bernhard): determine right amount...
210         let amount = 0x20
211         push (amount :: Word32)
212         calladdr <- getCurrentOffset
213         let w32_calladdr = 5 + calladdr
214         let malloaddr = (fromIntegral getMallocAddr :: Word32)
215         call (malloaddr - w32_calladdr)
216         add esp (4 :: Word32)
217         push eax
218         -- TODO(bernhard): save reference somewhere for GC
219     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
220     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
221     emit (ICONST_0) = push (0 :: Word32)
222     emit (ICONST_1) = push (1 :: Word32)
223     emit (ICONST_2) = push (2 :: Word32)
224     emit (ICONST_4) = push (4 :: Word32)
225     emit (ICONST_5) = push (5 :: Word32)
226     emit (ALOAD_ x) = emit (ILOAD_ x)
227     emit (ILOAD_ x) = do
228         push (Disp (cArgs_ x), ebp)
229     emit (ALOAD x) = emit (ILOAD x)
230     emit (ILOAD x) = do
231         push (Disp (cArgs x), ebp)
232     emit (ASTORE_ x) = emit (ISTORE_ x)
233     emit (ISTORE_ x) = do
234         pop eax
235         mov (Disp (cArgs_ x), ebp) eax
236     emit (ASTORE x) = emit (ISTORE x)
237     emit (ISTORE x) = do
238         pop eax
239         mov (Disp (cArgs x), ebp) eax
240
241     emit (GETFIELD x) = do
242         pop eax -- this pointer
243         let (cname, fname) = buildFieldOffset cls x
244         let offset = unsafePerformIO $ getFieldOffset cname fname
245         push (Disp (fromIntegral $ offset * 4), eax) -- get field
246     emit (PUTFIELD x) = do
247         pop ebx -- value to write
248         pop eax -- this pointer
249         let (cname, fname) = buildFieldOffset cls x
250         let offset = unsafePerformIO $ getFieldOffset cname fname
251         mov (Disp (fromIntegral $ offset * 4), eax) ebx -- set field
252
253     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
254     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
255     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
256     emit (IINC x imm) = do
257         add (Disp (cArgs x), ebp) (s8_w32 imm)
258
259     emit (IF_ICMP cond _) = do
260         pop eax -- value2
261         pop ebx -- value1
262         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
263         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
264         let l = getLabel sid lmap
265         case cond of
266           C_EQ -> je  l; C_NE -> jne l
267           C_LT -> jl  l; C_GT -> jg  l
268           C_GE -> jge l; C_LE -> jle l
269
270     emit (IF cond _) = do
271         pop eax -- value1
272         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
273         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
274         let l = getLabel sid lmap
275         case cond of
276           C_EQ -> je  l; C_NE -> jne l
277           C_LT -> jl  l; C_GT -> jg  l
278           C_GE -> jge l; C_LE -> jle l
279
280     emit (GOTO _ ) = do
281         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
282         jmp $ getLabel sid lmap
283
284     emit RETURN = do mov esp ebp; pop ebp; ret
285     emit IRETURN = do
286         pop eax
287         mov esp ebp
288         pop ebp
289         ret
290     emit invalid = error $ "insn not implemented yet: " ++ (show invalid)
291
292   -- for locals we use a different storage
293   cArgs :: Word8 -> Word32
294   cArgs x = if (x' >= thisMethodArgCnt)
295       -- TODO(bernhard): maybe s/(-4)/(-8)/
296       then fromIntegral $ (-4) * (x' - thisMethodArgCnt + 1)
297       else 8 + (4 * x')
298     where x' = fromIntegral x
299
300   cArgs_ :: IMM -> Word32
301   cArgs_ x = cArgs $ case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3
302
303   thisMethodArgCnt :: Word32
304   thisMethodArgCnt = isNonStatic + (fromIntegral $ length args)
305     where
306     (Just m) = lookupMethod method cls
307     (MethodSignature args _) = methodSignature m
308     isNonStatic = if S.member ACC_STATIC (methodAccessFlags m)
309         then 0
310         else 1 -- one argument for the this pointer
311
312
313   -- sign extension from w8 to w32 (over s8)
314   --   unfortunately, hs-java is using Word8 everywhere (while
315   --   it should be Int8 actually)
316   s8_w32 :: Word8 -> Word32
317   s8_w32 w8 = fromIntegral s8
318     where s8 = (fromIntegral w8) :: Int8