Enhace constants pool handling.
[hs-java.git] / JVM / ClassFile.hs
index 87acae42515ba4f7ae24ba6cd66988678692fe40..6171c1bdecffd3854bc85f5c1c7259dc63cc2f52 100644 (file)
@@ -2,35 +2,67 @@
 -- | This module declares (low-level) data types for Java .class files
 -- structures, and Binary instances to read/write them.
 module JVM.ClassFile
-  (Attribute (..),
+  (-- * About
+   -- $about
+   --
+   -- * Internal class file structures
+   Attribute (..),
    FieldType (..),
+   -- * Signatures
    FieldSignature, MethodSignature (..), ReturnSignature (..),
    ArgumentSignature (..),
+   -- * Stage types
+   File, Direct,
+   -- * Staged structures
    Pool, Link,
    Method (..), Field (..), Class (..),
    Constant (..),
-   Pointers, Resolved,
-   NameType (..),
-   HasSignature (..), HasAttributes (..),
    AccessFlag (..), AccessFlags,
    Attributes (..),
-   className
+   defaultClass,
+   -- * Misc
+   HasSignature (..), HasAttributes (..),
+   NameType (..),
+   fieldNameType, methodNameType,
+   lookupField, lookupMethod,
+   toString,
+   className,
+   apsize, arsize, arlist
   )
   where
 
 import Control.Monad
+import Control.Monad.Trans (lift)
 import Control.Applicative
+import qualified Control.Monad.State as St
 import Data.Binary
 import Data.Binary.IEEE754
 import Data.Binary.Get
 import Data.Binary.Put
 import Data.Char
 import Data.List
+import Data.Default
 import qualified Data.Set as S
 import qualified Data.Map as M
 import qualified Data.ByteString.Lazy as B
 import Codec.Binary.UTF8.String hiding (encode, decode)
 
+-- $about
+--
+-- Java .class file uses constants pool, which stores almost all source-code-level
+-- constants (strings, integer literals etc), and also all identifiers (class,
+-- method, field names etc). All other structures contain indexes of constants in
+-- the pool instead of constants theirself.
+--
+-- It's not convient to use that indexes programmatically. So, .class file is represented
+-- at two stages: File and Direct. At File stage, all data structures contain only indexes,
+-- not constants theirself. When we read a class from a file, we get structure at File stage.
+-- We only can write File stage structure to file.
+--
+-- At Direct stage, structures conain constants, not indexes. Convertion functions (File <-> Direct)
+-- are located in the JVM.Converter module.
+--
+
 -- | Read one-byte Char
 getChar8 :: Get Char
 getChar8 = do
@@ -40,26 +72,58 @@ getChar8 = do
 toString :: B.ByteString -> String
 toString bstr = decodeString $ map (chr . fromIntegral) $ B.unpack bstr
 
-type family Link s a
+-- | File stage
+data File = File
 
-data Pointers = Pointers
+-- | Direct representation stage
+data Direct = Direct
 
-data Resolved = Resolved
+-- | Link to some object
+type family Link stage a
 
-type instance Link Pointers a = Word16
+-- | At File stage, Link contain index of object in the constants pool.
+type instance Link File a = Word16
 
-type instance Link Resolved a = a
+-- | At Direct stage, Link contain object itself.
+type instance Link Direct a = a
 
+-- | Object (class, method, field …) access flags 
 type family AccessFlags stage
 
-type instance AccessFlags Pointers = Word16
+-- | At File stage, access flags are represented as Word16
+type instance AccessFlags File = Word16
+
+-- | At Direct stage, access flags are represented as set of flags.
+type instance AccessFlags Direct = S.Set AccessFlag
+
+-- | Object (class, method, field) attributes
+data family Attributes stage
+
+-- | At File stage, attributes are represented as list of Attribute structures.
+data instance Attributes File = AP {attributesList :: [Attribute]}
+  deriving (Eq, Show)
+
+instance Default (Attributes File) where
+  def = AP []
+
+-- | At Direct stage, attributes are represented as a Map.
+data instance Attributes Direct = AR (M.Map B.ByteString B.ByteString)
+  deriving (Eq, Show)
+
+instance Default (Attributes Direct) where
+  def = AR M.empty
 
-type instance AccessFlags Resolved = S.Set AccessFlag
+-- | Size of attributes set at Direct stage
+arsize :: Attributes Direct -> Int
+arsize (AR m) = M.size m
 
-type family Attributes stage
+-- | Associative list of attributes at Direct stage
+arlist :: Attributes Direct -> [(B.ByteString, B.ByteString)]
+arlist (AR m) = M.assocs m
 
-type instance Attributes Pointers = [Attribute]
-type instance Attributes Resolved = M.Map B.ByteString B.ByteString
+-- | Size of attributes set at File stage
+apsize :: Attributes File -> Int
+apsize (AP list) = length list
 
 -- | Access flags. Used for classess, methods, variables.
 data AccessFlag =
@@ -76,7 +140,9 @@ data AccessFlag =
   | ACC_ABSTRACT          -- ^ 0x0400 
   deriving (Eq, Show, Ord, Enum)
 
-class HasSignature a where
+-- | Fields and methods have signatures.
+class (Binary (Signature a), Show (Signature a), Eq (Signature a))
+    => HasSignature a where
   type Signature a
 
 instance HasSignature Field where
@@ -90,36 +156,37 @@ data NameType a = NameType {
   ntName :: B.ByteString,
   ntSignature :: Signature a }
 
-instance Show (Signature a) => Show (NameType a) where
+instance (HasSignature a) => Show (NameType a) where
   show (NameType n t) = toString n ++ ": " ++ show t
 
-deriving instance Eq (Signature a) => Eq (NameType a)
+deriving instance HasSignature a => Eq (NameType a)
 
-instance (Binary (Signature a)) => Binary (NameType a) where
+instance HasSignature a => Binary (NameType a) where
   put (NameType n t) = putLazyByteString n >> put t
 
   get = NameType <$> get <*> get
 
 -- | Constant pool item
 data Constant stage =
-    CClass B.ByteString
-  | CField {refClass :: Link stage B.ByteString, fieldNameType :: Link stage (NameType Field)}
-  | CMethod {refClass :: Link stage B.ByteString, nameType :: Link stage (NameType Method)}
-  | CIfaceMethod {refClass :: Link stage B.ByteString, nameType :: Link stage (NameType Method)}
+    CClass (Link stage B.ByteString)
+  | CField (Link stage B.ByteString) (Link stage (NameType Field))
+  | CMethod (Link stage B.ByteString) (Link stage (NameType Method))
+  | CIfaceMethod (Link stage B.ByteString) (Link stage (NameType Method))
   | CString (Link stage B.ByteString)
   | CInteger Word32
   | CFloat Float
-  | CLong Integer
+  | CLong Word64
   | CDouble Double
   | CNameType (Link stage B.ByteString) (Link stage B.ByteString)
   | CUTF8 {getString :: B.ByteString}
   | CUnicode {getString :: B.ByteString}
 
-className ::  Constant Resolved -> B.ByteString
+-- | Name of the CClass. Error on any other constant.
+className ::  Constant Direct -> B.ByteString
 className (CClass s) = s
 className x = error $ "Not a class: " ++ show x
 
-instance Show (Constant Resolved) where
+instance Show (Constant Direct) where
   show (CClass name) = "class " ++ toString name
   show (CField cls nt) = "field " ++ toString cls ++ "." ++ show nt
   show (CMethod cls nt) = "method " ++ toString cls ++ "." ++ show nt
@@ -156,17 +223,42 @@ data Class stage = Class {
   classAttributes :: Attributes stage -- ^ Class attributes
   }
 
-deriving instance Eq (Constant Pointers)
-deriving instance Eq (Constant Resolved)
-deriving instance Show (Constant Pointers)
-
-instance Binary (Class Pointers) where
+deriving instance Eq (Class File)
+deriving instance Eq (Class Direct)
+deriving instance Show (Class File)
+deriving instance Show (Class Direct)
+
+deriving instance Eq (Constant File)
+deriving instance Eq (Constant Direct)
+deriving instance Show (Constant File)
+
+-- | Default (empty) class file definition.
+defaultClass :: (Default (AccessFlags stage), Default (Link stage B.ByteString), Default (Attributes stage))
+             => Class stage
+defaultClass = Class {
+  magic = 0xCAFEBABE,
+  minorVersion = 0,
+  majorVersion = 50,
+  constsPoolSize = 0,
+  constsPool = def,
+  accessFlags = def,
+  thisClass = def,
+  superClass = def,
+  interfacesCount = 0,
+  interfaces = [],
+  classFieldsCount = 0,
+  classFields = [],
+  classMethodsCount = 0,
+  classMethods = [],
+  classAttributesCount = 0,
+  classAttributes = def }
+
+instance Binary (Class File) where
   put (Class {..}) = do
     put magic
     put minorVersion
     put majorVersion
-    put constsPoolSize
-    forM_ (M.elems constsPool) put
+    putPool constsPool
     put accessFlags
     put thisClass
     put superClass
@@ -177,28 +269,32 @@ instance Binary (Class Pointers) where
     put classMethodsCount
     forM_ classMethods put
     put classAttributesCount
-    forM_ classAttributes put
+    forM_ (attributesList classAttributes) put
 
   get = do
     magic <- get
+    when (magic /= 0xCAFEBABE) $
+      fail $ "Invalid .class file MAGIC value: " ++ show magic
     minor <- get
     major <- get
-    poolsize <- get
-    pool <- replicateM (fromIntegral poolsize - 1) get
-    af <- get
+    when (major > 50) $
+      fail $ "Too new .class file format: " ++ show major
+    poolsize <- getWord16be
+    pool <- getPool (poolsize - 1)
+    af <-  get
     this <- get
     super <- get
     interfacesCount <- get
     ifaces <- replicateM (fromIntegral interfacesCount) get
-    classFieldsCount <- get
+    classFieldsCount <- getWord16be
     classFields <- replicateM (fromIntegral classFieldsCount) get
     classMethodsCount <- get
     classMethods <- replicateM (fromIntegral classMethodsCount) get
     asCount <- get
     as <- replicateM (fromIntegral $ asCount) get
-    let pool' = M.fromList $ zip [1..] pool
-    return $ Class magic minor major poolsize pool' af this super
-               interfacesCount ifaces classFieldsCount classFields classMethodsCount classMethods asCount as
+    return $ Class magic minor major poolsize pool af this super
+               interfacesCount ifaces classFieldsCount classFields
+               classMethodsCount classMethods asCount (AP as)
 
 -- | Field signature format
 data FieldType =
@@ -363,49 +459,79 @@ whileJust m = do
               return (x: next)
     Nothing -> return []
 
-instance Binary (Constant Pointers) where
-  put (CClass i) = putWord8 7 >> put i
-  put (CField i j) = putWord8 9 >> put i >> put j
-  put (CMethod i j) = putWord8 10 >> put i >> put j
-  put (CIfaceMethod i j) = putWord8 11 >> put i >> put j
-  put (CString i) = putWord8 8 >> put i
-  put (CInteger x) = putWord8 3 >> put x
-  put (CFloat x)   = putWord8 4 >> putFloat32be x
-  put (CLong x)    = putWord8 5 >> put x
-  put (CDouble x)  = putWord8 6 >> putFloat64be x
-  put (CNameType i j) = putWord8 12 >> put i >> put j
-  put (CUTF8 bs) = do
-                   putWord8 1
-                   put (fromIntegral (B.length bs) :: Word16)
-                   putLazyByteString bs
-  put (CUnicode bs) = do
-                   putWord8 2
-                   put (fromIntegral (B.length bs) :: Word16)
-                   putLazyByteString bs
+long (CLong _)   = True
+long (CDouble _) = True
+long _           = False
 
-  get = do
-    !offset <- bytesRead
-    tag <- getWord8
-    case tag of
-      1 -> do
-        l <- get
-        bs <- getLazyByteString (fromIntegral (l :: Word16))
-        return $ CUTF8 bs
-      2 -> do
-        l <- get
-        bs <- getLazyByteString (fromIntegral (l :: Word16))
-        return $ CUnicode bs
-      3  -> CInteger   <$> get
-      4  -> CFloat     <$> getFloat32be
-      5  -> CLong      <$> get
-      6  -> CDouble    <$> getFloat64be
-      7  -> CClass     <$> get
-      8  -> CString    <$> get
-      9  -> CField     <$> get <*> get
-      10 -> CMethod    <$> get <*> get
-      11 -> CIfaceMethod <$> get <*> get
-      12 -> CNameType    <$> get <*> get
-      _  -> fail $ "Unknown constants pool entry tag: " ++ show tag
+putPool :: Pool File -> Put
+putPool pool = do
+    let list = M.elems pool
+        d = length $ filter long list
+    putWord16be $ fromIntegral (M.size pool + d + 1)
+    forM_ list putC
+  where
+    putC (CClass i) = putWord8 7 >> put i
+    putC (CField i j) = putWord8 9 >> put i >> put j
+    putC (CMethod i j) = putWord8 10 >> put i >> put j
+    putC (CIfaceMethod i j) = putWord8 11 >> put i >> put j
+    putC (CString i) = putWord8 8 >> put i
+    putC (CInteger x) = putWord8 3 >> put x
+    putC (CFloat x)   = putWord8 4 >> putFloat32be x
+    putC (CLong x)    = putWord8 5 >> put x
+    putC (CDouble x)  = putWord8 6 >> putFloat64be x
+    putC (CNameType i j) = putWord8 12 >> put i >> put j
+    putC (CUTF8 bs) = do
+                     putWord8 1
+                     put (fromIntegral (B.length bs) :: Word16)
+                     putLazyByteString bs
+    putC (CUnicode bs) = do
+                     putWord8 2
+                     put (fromIntegral (B.length bs) :: Word16)
+                     putLazyByteString bs
+
+getPool :: Word16 -> Get (Pool File)
+getPool n = do
+    items <- St.evalStateT go 1
+    return $ M.fromList items
+  where
+    go :: St.StateT Word16 Get [(Word16, Constant File)]
+    go = do
+      i <- St.get
+      if i > n
+        then return []
+        else do
+          c <- lift getC
+          let i' = if long c
+                      then i+2
+                      else i+1
+          St.put i'
+          next <- go
+          return $ (i,c): next
+
+    getC = do
+      !offset <- bytesRead
+      tag <- getWord8
+      case tag of
+        1 -> do
+          l <- get
+          bs <- getLazyByteString (fromIntegral (l :: Word16))
+          return $ CUTF8 bs
+        2 -> do
+          l <- get
+          bs <- getLazyByteString (fromIntegral (l :: Word16))
+          return $ CUnicode bs
+        3  -> CInteger   <$> get
+        4  -> CFloat     <$> getFloat32be
+        5  -> CLong      <$> get
+        6  -> CDouble    <$> getFloat64be
+        7  -> CClass     <$> get
+        8  -> CString    <$> get
+        9  -> CField     <$> get <*> get
+        10 -> CMethod    <$> get <*> get
+        11 -> CIfaceMethod <$> get <*> get
+        12 -> CNameType    <$> get <*> get
+        _  -> fail $ "Unknown constants pool entry tag: " ++ show tag
+--         _ -> return $ CInteger 0
 
 -- | Class field format
 data Field stage = Field {
@@ -415,47 +541,69 @@ data Field stage = Field {
   fieldAttributesCount :: Word16,
   fieldAttributes :: Attributes stage }
 
-deriving instance Eq (Field Pointers)
-deriving instance Eq (Field Resolved)
-deriving instance Show (Field Pointers)
-deriving instance Show (Field Resolved)
+deriving instance Eq (Field File)
+deriving instance Eq (Field Direct)
+deriving instance Show (Field File)
+deriving instance Show (Field Direct)
+
+lookupField :: B.ByteString -> Class Direct -> Maybe (Field Direct)
+lookupField name cls = look (classFields cls)
+  where
+    look [] = Nothing
+    look (f:fs)
+      | fieldName f == name = Just f
+      | otherwise           = look fs
 
-instance Binary (Field Pointers) where
+fieldNameType :: Field Direct -> NameType Field
+fieldNameType f = NameType (fieldName f) (fieldSignature f)
+
+instance Binary (Field File) where
   put (Field {..}) = do
     put fieldAccessFlags 
     put fieldName
     put fieldSignature
     put fieldAttributesCount
-    forM_ fieldAttributes put
+    forM_ (attributesList fieldAttributes) put
 
   get = do
     af <- get
-    ni <- get
+    ni <- getWord16be
     si <- get
-    n <- get
+    n <- getWord16be
     as <- replicateM (fromIntegral n) get
-    return $ Field af ni si n as
+    return $ Field af ni si n (AP as)
 
 -- | Class method format
 data Method stage = Method {
-  methodAccessFlags :: Attributes stage,
+  methodAccessFlags :: AccessFlags stage,
   methodName :: Link stage B.ByteString,
   methodSignature :: Link stage MethodSignature,
   methodAttributesCount :: Word16,
   methodAttributes :: Attributes stage }
 
-deriving instance Eq (Method Pointers)
-deriving instance Eq (Method Resolved)
-deriving instance Show (Method Pointers)
-deriving instance Show (Method Resolved)
+deriving instance Eq (Method File)
+deriving instance Eq (Method Direct)
+deriving instance Show (Method File)
+deriving instance Show (Method Direct)
+
+methodNameType :: Method Direct -> NameType Method
+methodNameType m = NameType (methodName m) (methodSignature m)
+
+lookupMethod :: B.ByteString -> Class Direct -> Maybe (Method Direct)
+lookupMethod name cls = look (classMethods cls)
+  where
+    look [] = Nothing
+    look (f:fs)
+      | methodName f == name = Just f
+      | otherwise           = look fs
 
-instance Binary (Method Pointers) where
+instance Binary (Method File) where
   put (Method {..}) = do
     put methodAccessFlags
     put methodName
     put methodSignature
     put methodAttributesCount 
-    forM_ methodAttributes put
+    forM_ (attributesList methodAttributes) put
 
   get = do
     offset <- bytesRead
@@ -464,7 +612,12 @@ instance Binary (Method Pointers) where
     si <- get
     n <- get
     as <- replicateM (fromIntegral n) get
-    return $ Method af ni si n as
+    return $ Method {
+               methodAccessFlags = af,
+               methodName = ni,
+               methodSignature = si,
+               methodAttributesCount = n,
+               methodAttributes = AP as }
 
 -- | Any (class/ field/ method/ ...) attribute format.
 -- Some formats specify special formats for @attributeValue@.
@@ -482,7 +635,7 @@ instance Binary Attribute where
 
   get = do
     offset <- bytesRead
-    name <- get
+    name <- getWord16be
     len <- getWord32be
     value <- getLazyByteString (fromIntegral len)
     return $ Attribute name len value