classloading: load classfile on demand
[mate.git] / Mate / X86CodeGen.hs
1 {-# LANGUAGE OverloadedStrings #-}
2 {-# LANGUAGE ForeignFunctionInterface #-}
3 module Mate.X86CodeGen where
4
5 import Data.Binary
6 import Data.Int
7 import Data.Maybe
8 import qualified Data.Map as M
9 import qualified Data.ByteString.Lazy as B
10 import Control.Monad
11
12 import Foreign
13 import Foreign.C.Types
14
15 import Text.Printf
16
17 import qualified JVM.Assembler as J
18 import JVM.Assembler hiding (Instruction)
19 import JVM.ClassFile
20 import JVM.Converter
21
22 import Harpy
23 import Harpy.X86Disassembler
24
25 import Mate.BasicBlocks
26 import Mate.Utilities
27
28 foreign import ccall "dynamic"
29    code_int :: FunPtr (CInt -> CInt -> IO CInt) -> (CInt -> CInt -> IO CInt)
30
31 foreign import ccall "getaddr"
32   getaddr :: CUInt
33
34 foreign import ccall "callertrap"
35   callertrap :: IO ()
36
37 foreign import ccall "register_signal"
38   register_signal :: IO ()
39
40 foreign import ccall "get_cmap"
41   get_cmap :: IO (Ptr ())
42
43 foreign import ccall "set_cmap"
44   set_cmap :: Ptr () -> IO ()
45
46 test_01, test_02, test_03 :: IO ()
47 test_01 = do
48   register_signal
49   (entry, end) <- testCase "./tests/Fib.class" "fib"
50   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
51
52   mapM_ (\x -> do
53     result <- code_int entryFuncPtr x 0
54     let iresult :: Int; iresult = fromIntegral result
55     let kk :: String; kk = if iresult == (fib x) then "OK" else "FAIL (" ++ (show (fib x)) ++ ")"
56     printf "result of fib(%2d): %3d\t\t%s\n" (fromIntegral x :: Int) iresult kk
57     ) $ ([0..10] :: [CInt])
58   printf "patched disasm:\n"
59   Right newdisasm <- disassembleBlock entry end
60   mapM_ (putStrLn . showAtt) newdisasm
61   where
62     fib :: CInt -> Int
63     fib n
64       | n <= 1 = 1
65       | otherwise = (fib (n - 1)) + (fib (n - 2))
66
67
68 test_02 = do
69   (entry,_) <- testCase "./tests/While.class" "f"
70   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
71   result <- code_int entryFuncPtr 5 4
72   let iresult :: Int; iresult = fromIntegral result
73   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
74   printf "result of f(5,4): %3d\t\t%s\n" iresult kk
75
76   result2 <- code_int entryFuncPtr 4 3
77   let iresult2 :: Int; iresult2 = fromIntegral result2
78   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
79   printf "result of f(4,3): %3d\t\t%s\n" iresult2 kk2
80
81
82 test_03 = do
83   (entry,_) <- testCase "./tests/While.class" "g"
84   let entryFuncPtr = ((castPtrToFunPtr entry) :: FunPtr (CInt -> CInt -> IO CInt))
85   result <- code_int entryFuncPtr 5 4
86   let iresult :: Int; iresult = fromIntegral result
87   let kk :: String; kk = if iresult == 15 then "OK" else "FAIL"
88   printf "result of g(5,4): %3d\t\t%s\n" iresult kk
89
90   result2 <- code_int entryFuncPtr 4 3
91   let iresult2 :: Int; iresult2 = fromIntegral result2
92   let kk2 :: String; kk2 = if iresult2 == 10 then "OK" else "FAIL"
93   printf "result of g(4,3): %3d\t\t%s\n" iresult2 kk2
94
95
96 testCase :: String -> B.ByteString -> IO (Ptr Word8, Int)
97 testCase cf method = do
98       cls <- parseClassFile cf
99       hmap <- parseMethod cls method
100       printMapBB hmap
101       case hmap of
102         Nothing -> error "sorry, no code generation"
103         Just hmap' -> do
104               let ebb = emitFromBB cls hmap'
105               (_, Right ((entry, bbstarts, end, _), disasm)) <- runCodeGen ebb () ()
106               let int_entry = ((fromIntegral $ ptrToIntPtr entry) :: Int)
107               printf "disasm:\n"
108               mapM_ (putStrLn . showAtt) disasm
109               printf "basicblocks addresses:\n"
110               let b = map (\(x,y) -> (x,y + int_entry)) $ M.toList bbstarts
111               mapM_ (\(x,y) -> printf "\tBasicBlock %2d starts at 0x%08x\n" x y) b
112               return (entry, end)
113
114 type EntryPoint = Ptr Word8
115 type EntryPointOffset = Int
116 type PatchInfo = (BlockID, EntryPointOffset)
117
118 type BBStarts = M.Map BlockID Int
119
120 type CompileInfo = (EntryPoint, BBStarts, Int, CMap)
121
122 -- Word32 = point of method call in generated code
123 -- MethodInfo = relevant information about callee
124 type CMap = M.Map Word32 MethodInfo
125
126
127 emitFromBB :: Class Resolved -> MapBB -> CodeGen e s (CompileInfo, [Instruction])
128 emitFromBB cls hmap =  do
129         llmap <- sequence [newNamedLabel ("bb_" ++ show x) | (x,_) <- M.toList hmap]
130         let lmap = zip (Prelude.fst $ unzip $ M.toList hmap) llmap
131         ep <- getEntryPoint
132         push ebp
133         mov ebp esp
134
135         (calls, bbstarts) <- efBB (0,(hmap M.! 0)) M.empty M.empty lmap
136         d <- disassemble
137         end <- getCodeOffset
138         return ((ep, bbstarts, end, calls), d)
139   where
140   getLabel :: BlockID -> [(BlockID, Label)] -> Label
141   getLabel _ [] = error "label not found!"
142   getLabel i ((x,l):xs) = if i==x then l else getLabel i xs
143
144   efBB :: (BlockID, BasicBlock) -> CMap -> BBStarts -> [(BlockID, Label)] -> CodeGen e s (CMap, BBStarts)
145   efBB (bid, bb) calls bbstarts lmap =
146         if M.member bid bbstarts then
147           return (calls, bbstarts)
148         else do
149           bb_offset <- getCodeOffset
150           let bbstarts' = M.insert bid bb_offset bbstarts
151           defineLabel $ getLabel bid lmap
152           cs <- mapM emit' $ code bb
153           let calls' = calls `M.union` (M.fromList $ catMaybes cs)
154           case successor bb of
155             Return -> return (calls', bbstarts')
156             FallThrough t -> do
157               efBB (t, hmap M.! t) calls' bbstarts' lmap
158             OneTarget t -> do
159               efBB (t, hmap M.! t) calls' bbstarts' lmap
160             TwoTarget t1 t2 -> do
161               (calls'', bbstarts'') <- efBB (t1, hmap M.! t1) calls' bbstarts' lmap
162               efBB (t2, hmap M.! t2) calls'' bbstarts'' lmap
163     -- TODO(bernhard): also use metainformation
164     -- TODO(bernhard): implement `emit' as function which accepts a list of
165     --                 instructions, so we can use patterns for optimizations
166     where
167     emit' :: J.Instruction -> CodeGen e s (Maybe (Word32, MethodInfo))
168     emit' (INVOKESTATIC cpidx) = do
169         ep <- getEntryPoint
170         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
171         let l = buildMethodID cls cpidx
172         calladdr <- getCodeOffset
173         let w32_calladdr = w32_ep + (fromIntegral calladdr) :: Word32
174         newNamedLabel (show l) >>= defineLabel
175         -- causes SIGILL. in the signal handler we patch it to the acutal call.
176         -- place a nop at the end, therefore the disasm doesn't screw up
177         emit32 (0xffff9090 :: Word32) >> emit8 (0x90 :: Word8)
178         -- discard arguments on stack
179         let argcnt = (methodGetArgsCount cls cpidx) * 4
180         when (argcnt > 0) (add esp argcnt)
181         -- push result on stack if method has a return value
182         when (methodHaveReturnValue cls cpidx) (push eax)
183         return $ Just $ (w32_calladdr, l)
184     emit' insn = emit insn >> return Nothing
185
186     emit :: J.Instruction -> CodeGen e s ()
187     emit POP = do -- print dropped value
188         ep <- getEntryPoint
189         let w32_ep = (fromIntegral $ ptrToIntPtr ep) :: Word32
190         -- '5' is the size of the `call' instruction ( + immediate)
191         calladdr <- getCodeOffset
192         let w32_calladdr = 5 + w32_ep + (fromIntegral calladdr) :: Word32
193         let trapaddr = (fromIntegral getaddr :: Word32)
194         call (trapaddr - w32_calladdr)
195         add esp (4 :: Word32)
196     emit (BIPUSH val) = push ((fromIntegral val) :: Word32)
197     emit (SIPUSH val) = push ((fromIntegral $ ((fromIntegral val) :: Int16)) :: Word32)
198     emit (ICONST_0) = push (0 :: Word32)
199     emit (ICONST_1) = push (1 :: Word32)
200     emit (ICONST_2) = push (2 :: Word32)
201     emit (ICONST_4) = push (4 :: Word32)
202     emit (ICONST_5) = push (5 :: Word32)
203     emit (ILOAD_ x) = do
204         push (Disp (cArgs_ x), ebp)
205     emit (ISTORE_ x) = do
206         pop eax
207         mov (Disp (cArgs_ x), ebp) eax
208     emit IADD = do pop ebx; pop eax; add eax ebx; push eax
209     emit ISUB = do pop ebx; pop eax; sub eax ebx; push eax
210     emit IMUL = do pop ebx; pop eax; mul ebx; push eax
211     emit (IINC x imm) = do
212         add (Disp (cArgs x), ebp) (s8_w32 imm)
213
214     emit (IF_ICMP cond _) = do
215         pop eax -- value2
216         pop ebx -- value1
217         cmp ebx eax -- intel syntax is swapped (TODO(bernhard): test that plz)
218         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
219         let l = getLabel sid lmap
220         case cond of
221           C_EQ -> je  l; C_NE -> jne l
222           C_LT -> jl  l; C_GT -> jg  l
223           C_GE -> jge l; C_LE -> jle l
224
225     emit (IF cond _) = do
226         pop eax -- value1
227         cmp eax (0 :: Word32) -- TODO(bernhard): test that plz
228         let sid = case successor bb of TwoTarget _ t -> t; _ -> error "bad"
229         let l = getLabel sid lmap
230         case cond of
231           C_EQ -> je  l; C_NE -> jne l
232           C_LT -> jl  l; C_GT -> jg  l
233           C_GE -> jge l; C_LE -> jle l
234
235     emit (GOTO _ ) = do
236         let sid = case successor bb of OneTarget t -> t; _ -> error "bad"
237         jmp $ getLabel sid lmap
238
239     emit RETURN = do mov esp ebp; pop ebp; ret
240     emit IRETURN = do
241         pop eax
242         mov esp ebp
243         pop ebp
244         ret
245     emit _ = do cmovbe eax eax -- dummy
246
247   cArgs x = (8 + 4 * (fromIntegral x))
248   cArgs_ x = (8 + 4 * case x of I0 -> 0; I1 -> 1; I2 -> 2; I3 -> 3)
249
250   -- sign extension from w8 to w32 (over s8)
251   --   unfortunately, hs-java is using Word8 everywhere (while
252   --   it should be Int8 actually)
253   s8_w32 :: Word8 -> Word32
254   s8_w32 w8 = fromIntegral s8
255     where s8 = (fromIntegral w8) :: Int8