strings: put every String from the constantpool in a Map
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.Set as S
10 import qualified Data.ByteString.Lazy as B
11 import Control.Monad
12
13 import Foreign
14 import Foreign.C.Types
15
16 import Text.Printf
17
18 import qualified JVM.Assembler as J
19 import JVM.Assembler hiding (Instruction)
20 import JVM.ClassFile
21
22 import Harpy
23 import Harpy.X86Disassembler
24
25 import Mate.BasicBlocks
26 import Mate.Types
27 import Mate.Utilities
28 import Mate.ClassPool
29 import Mate.Strings
30
31
32 foreign import ccall "dynamic"
33    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
34
35 foreign import ccall "getaddr"
36   getaddr :: CUInt
37
38 foreign import ccall "getMallocAddr"
39   getMallocAddr :: CUInt
40
41 foreign import ccall "callertrap"
42   callertrap :: IO ()
43
44 foreign import ccall "register_signal"
45   register_signal :: IO ()
46
47 test_01, test_02, test_03 :: IO ()
48 test_01 = do
49   register_signal
50   (entry, end) <- testCase "./tests/Fib" "fib"
51   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
52
53   mapM_ (\x -> do
54     result <- code_int entryFuncPtr x 0
55     let iresult :: Int; iresult = fromIntegral result
56     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
57     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
58     ) $ ([0..10] :: [CInt])
59   printf "patched disasm:\n"
60   Right newdisasm <- disassembleBlock entry end
61   mapM_ (putStrLn . showAtt) newdisasm
62   where
63     fib :: CInt -> Int
64     fib n
65       | n <= 1 = 1
66       | otherwise = (fib (n - 1)) + (fib (n - 2))
67
68
69 test_02 = do
70   (entry,_) <- testCase "./tests/While" "f"
71   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
72   result <- code_int entryFuncPtr 5 4
73   let iresult :: Int; iresult = fromIntegral result
74   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
75   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
76
77   result2 <- code_int entryFuncPtr 4 3
78   let iresult2 :: Int; iresult2 = fromIntegral result2
79   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
80   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
81
82
83 test_03 = do
84   (entry,_) <- testCase "./tests/While" "g"
85   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
86   result <- code_int entryFuncPtr 5 4
87   let iresult :: Int; iresult = fromIntegral result
88   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
89   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
90
91   result2 <- code_int entryFuncPtr 4 3
92   let iresult2 :: Int; iresult2 = fromIntegral result2
93   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
94   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
95
96
97 testCase :: B.ByteString -> B.ByteString -> IO (Ptr Word8, Int)
98 testCase cf method = do
99       cls <- getClassFile cf
100       hmap <- parseMethod cls method
101       printMapBB hmap
102       case hmap of
103         Nothing -> error "sorry, no code generation"
104         Just hmap' -> do
105               let ebb = emitFromBB method cls hmap'
106               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
107               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
108               printf "disasm:\n"
109               mapM_ (putStrLn . showAtt) disasm
110               printf "basicblocks addresses:\n"
111               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
112               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
113               return (entry, end)
114
115 type EntryPoint = Ptr Word8
116 type EntryPointOffset = Int
117 type PatchInfo = (BlockID, EntryPointOffset)
118
119 type BBStarts = M.Map BlockID Int
120
121 type CompileInfo = (EntryPoint, BBStarts, Int, TMap)
122
123
124 emitFromBB :: B.ByteString -> Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
125 emitFromBB method cls hmap =  do
126         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
127         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
128         ep <- getEntryPoint
129         push ebp
130         mov ebp esp
131         -- TODO(bernhard): determine a reasonable value.
132         --                 e.g. (locals used) * 4
133         sub esp (0x60 :: Word32)
134
135         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
136         d <- disassemble
137         end <- getCodeOffset
138         return ((ep, bbstarts, end, calls), d)
139   where
140   getLabel :: BlockID -> [(BlockID, Label)] -> Label
141   getLabel _ [] = error "label not found!"
142   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
143
144   efBB :: (BlockID, BasicBlock) -> TMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (TMap, BBStarts)
145   efBB (bid, bb) calls bbstarts lmap =
146         if M.member bid bbstarts then
147           return (calls, bbstarts)
148         else do
149           bb_offset <- getCodeOffset
150           let bbstarts' = M.insert bid bb_offset bbstarts
151           defineLabel $ getLabel bid lmap
152           cs <- mapM emit' $ code bb
153           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
154           case successor bb of
155             Return -> return (calls', bbstarts')
156             FallThrough t -> do
157               efBB (t, hmap M.! t) calls' bbstarts' lmap
158             OneTarget t -> do
159               efBB (t, hmap M.! t) calls' bbstarts' lmap
160             TwoTarget t1 t2 -> do
161               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
162               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
163     -- TODO(bernhard): also use metainformation
164     -- TODO(bernhard): implement `emit' as function which accepts a list of
165     --                 instructions, so we can use patterns for optimizations
166     where
167     getCurrentOffset :: CodeGen e s (Word32)
168     getCurrentOffset = do
169       ep <- getEntryPoint
170       let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
171       offset <- getCodeOffset
172       return $ w32_ep + (fromIntegral offset)
173
174     emitInvoke :: Word16 -> Bool -> CodeGen e s (Maybe (Word32, TrapInfo))
175     emitInvoke cpidx hasThis = do
176         let l = buildMethodID cls cpidx
177         calladdr <- getCurrentOffset
178         newNamedLabel (show l) >>= defineLabel
179         -- causes SIGILL. in the signal handler we patch it to the acutal call.
180         -- place a nop at the end, therefore the disasm doesn't screw up
181         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
182         -- discard arguments on stack
183         let argcnt = ((if hasThis then 1 else 0) + (methodGetArgsCount cls cpidx)) * 4
184         when (argcnt > 0) (add esp argcnt)
185         -- push result on stack if method has a return value
186         when (methodHaveReturnValue cls cpidx) (push eax)
187         return $ Just $ (calladdr, MI l)
188
189     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, TrapInfo))
190     emit' (INVOKESPECIAL cpidx) = emitInvoke cpidx True
191     emit' (INVOKESTATIC cpidx) = emitInvoke cpidx False
192     emit' (INVOKEVIRTUAL cpidx) = do
193         -- get methodInfo entry
194         let mi@(MethodInfo methodname objname msig@(MethodSignature args _))  = buildMethodID cls cpidx
195         newNamedLabel (show mi) >>= defineLabel
196         -- objref lives somewhere on the argument stack
197         mov eax (Disp ((*4) $ fromIntegral $ length args), esp)
198         -- get method-table-ptr
199         mov eax (Disp 0, eax)
200         -- get method offset
201         let nameAndSig = methodname `B.append` (encode msig)
202         let offset = unsafePerformIO $ getMethodOffset objname nameAndSig
203         -- make actual (indirect) call
204         calladdr <- getCurrentOffset
205         call (Disp offset, eax)
206         -- discard arguments on stack (+4 for "this")
207         let argcnt = 4 + ((methodGetArgsCount cls cpidx) * 4)
208         when (argcnt > 0) (add esp argcnt)
209         -- push result on stack if method has a return value
210         when (methodHaveReturnValue cls cpidx) (push eax)
211         -- note, the "mi" has the wrong class reference here.
212         -- we figure that out at run-time, in the methodpool,
213         -- depending on the method-table-ptr
214         return $ Just $ (calladdr, VI mi)
215     emit' (PUTSTATIC cpidx) = do
216         pop eax
217         trapaddr <- getCurrentOffset
218         mov (Addr 0x00000000) eax -- it's a trap
219         return $ Just $ (trapaddr, SFI $ buildStaticFieldID cls cpidx)
220     emit' (GETSTATIC cpidx) = do
221         trapaddr <- getCurrentOffset
222         mov eax (Addr 0x00000000) -- it's a trap
223         push eax
224         return $ Just $ (trapaddr, SFI $ buildStaticFieldID cls cpidx)
225     emit' insn = emit insn >> return Nothing
226
227     emit :: J.Instruction -> CodeGen e s ()
228     emit POP = do -- print dropped value
229         calladdr <- getCurrentOffset
230         -- '5' is the size of the `call' instruction ( + immediate)
231         let w32_calladdr = 5 + calladdr
232         let trapaddr = (fromIntegral getaddr :: Word32)
233         call (trapaddr - w32_calladdr)
234         add esp (4 :: Word32)
235     emit DUP = push (Disp 0, esp)
236     emit (NEW objidx) = do
237         let objname = buildClassID cls objidx
238         let amount = unsafePerformIO $ getMethodSize objname
239         push (amount :: Word32)
240         calladdr <- getCurrentOffset
241         let w32_calladdr = 5 + calladdr
242         let malloaddr = (fromIntegral getMallocAddr :: Word32)
243         call (malloaddr - w32_calladdr)
244         add esp (4 :: Word32)
245         push eax
246         -- TODO(bernhard): save reference somewhere for GC
247         -- set method table pointer
248         let mtable = unsafePerformIO $ getMethodTable objname
249         mov (Disp 0, eax) mtable
250     emit (CHECKCAST _) = nop -- TODO(bernhard): ...
251     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
252     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
253     emit (ICONST_0) = push (0 :: Word32)
254     emit (ICONST_1) = push (1 :: Word32)
255     emit (ICONST_2) = push (2 :: Word32)
256     emit (ICONST_4) = push (4 :: Word32)
257     emit (ICONST_5) = push (5 :: Word32)
258     emit (ALOAD_ x) = emit (ILOAD_ x)
259     emit (ILOAD_ x) = do
260         push (Disp (cArgs_ x), ebp)
261     emit (ALOAD x) = emit (ILOAD x)
262     emit (ILOAD x) = do
263         push (Disp (cArgs x), ebp)
264     emit (ASTORE_ x) = emit (ISTORE_ x)
265     emit (ISTORE_ x) = do
266         pop eax
267         mov (Disp (cArgs_ x), ebp) eax
268     emit (ASTORE x) = emit (ISTORE x)
269     emit (ISTORE x) = do
270         pop eax
271         mov (Disp (cArgs x), ebp) eax
272
273     emit (LDC1 x) = emit (LDC2 $ fromIntegral x)
274     emit (LDC2 x) = do
275         let value = case (constsPool cls) M.! x of
276                       (CString s) -> unsafePerformIO $ getUniqueStringAddr s
277                       _ -> error $ "LDCI... missing impl."
278         push value
279     emit (GETFIELD x) = do
280         pop eax -- this pointer
281         let (cname, fname) = buildFieldOffset cls x
282         let offset = unsafePerformIO $ getFieldOffset cname fname
283         push (Disp (fromIntegral $ offset * 4), eax) -- get field
284     emit (PUTFIELD x) = do
285         pop ebx -- value to write
286         pop eax -- this pointer
287         let (cname, fname) = buildFieldOffset cls x
288         let offset = unsafePerformIO $ getFieldOffset cname fname
289         mov (Disp (fromIntegral $ offset * 4), eax) ebx -- set field
290
291     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
292     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
293     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
294     emit (IINC x imm) = do
295         add (Disp (cArgs x), ebp) (s8_w32 imm)
296
297     emit (IF_ICMP cond _) = do
298         pop eax -- value2
299         pop ebx -- value1
300         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
301         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
302         let l = getLabel sid lmap
303         case cond of
304           C_EQ -> je  l; C_NE -> jne l
305           C_LT -> jl  l; C_GT -> jg  l
306           C_GE -> jge l; C_LE -> jle l
307
308     emit (IF cond _) = do
309         pop eax -- value1
310         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
311         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
312         let l = getLabel sid lmap
313         case cond of
314           C_EQ -> je  l; C_NE -> jne l
315           C_LT -> jl  l; C_GT -> jg  l
316           C_GE -> jge l; C_LE -> jle l
317
318     emit (GOTO _ ) = do
319         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
320         jmp $ getLabel sid lmap
321
322     emit RETURN = do mov esp ebp; pop ebp; ret
323     emit ARETURN = emit IRETURN
324     emit IRETURN = do
325         pop eax
326         mov esp ebp
327         pop ebp
328         ret
329     emit invalid = error $ "insn not implemented yet: " ++ (show invalid)
330
331   -- for locals we use a different storage
332   cArgs :: Word8 -> Word32
333   cArgs x = if (x' >= thisMethodArgCnt)
334       -- TODO(bernhard): maybe s/(-4)/(-8)/
335       then fromIntegral $ (-4) * (x' - thisMethodArgCnt + 1)
336       else 4 + (thisMethodArgCnt * 4) - (4 * x')
337     where x' = fromIntegral x
338
339   cArgs_ :: IMM -> Word32
340   cArgs_ x = cArgs $ case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3
341
342   thisMethodArgCnt :: Word32
343   thisMethodArgCnt = isNonStatic + (fromIntegral $ length args)
344     where
345     (Just m) = lookupMethod method cls
346     (MethodSignature args _) = methodSignature m
347     isNonStatic = if S.member ACC_STATIC (methodAccessFlags m)
348         then 0
349         else 1 -- one argument for the this pointer
350
351
352   -- sign extension from w8 to w32 (over s8)
353   --   unfortunately, hs-java is using Word8 everywhere (while
354   --   it should be Int8 actually)
355   s8_w32 :: Word8 -> Word32
356   s8_w32 w8 = fromIntegral s8
357     where s8 = (fromIntegral w8) :: Int8