Skip to content

Commit 3ae5030

Browse files
committed
ffi: support returning structs
1 parent 59b415a commit 3ae5030

File tree

3 files changed

+76
-15
lines changed

3 files changed

+76
-15
lines changed

lib/AST.hs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module AST where
22

33
import Data.Binary qualified
4+
import Data.Map qualified
45
import GHC.Generics (Generic)
56
import VM qualified
67

@@ -213,7 +214,7 @@ typeOf (FuncCall{}) = Any -- error "Cannot infer type of variable"
213214
typeOf Placeholder = Any
214215
typeOf (Var{}) = Any -- error "Cannot infer type of variable"
215216
typeOf (Let _ _) = error "Cannot infer type of let"
216-
typeOf (If{}) = error "Cannot infer type of if"
217+
typeOf (If _ b _) = typeOf b
217218
typeOf (FuncDef{}) = error "Cannot infer type of function definition"
218219
typeOf (FuncDec _ _) = error "Cannot infer type of function declaration"
219220
typeOf (Function _ _) = error "Cannot infer type of modern function"
@@ -253,6 +254,8 @@ typeToData String = VM.DString ""
253254
typeToData (StructT "IO") = VM.DNone -- Hmmm...
254255
typeToData Char = VM.DChar ' '
255256
typeToData CPtr = VM.DCPtr 0
257+
typeToData StructT{} = VM.DMap Data.Map.empty
258+
typeToData Any = VM.DNone
256259
typeToData x = error $ "Cannot convert type " ++ show x ++ " to data"
257260

258261
-- typeOf x = error $ "Cannot infer type of " ++ show x

lib/BytecodeCompiler.hs

+16-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import AST qualified as Parser.Type (Type (Unknown))
66
import Control.Monad (when, (>=>))
77
import Control.Monad.Loops (firstM)
88
import Control.Monad.State (MonadIO (liftIO), StateT, gets, modify)
9+
import Data.Bifunctor (second)
910
import Data.Functor ((<&>))
1011
import Data.List (elemIndex, find, inits, intercalate, isInfixOf, tails)
1112
import Data.List.Split qualified
13+
import Data.Map qualified
1214
import Data.Maybe (fromJust, isJust, isNothing)
1315
import Data.Text (isPrefixOf, splitOn)
1416
import Data.Text qualified
@@ -181,6 +183,15 @@ typeOf (Parser.Var varName _) = do
181183
return a
182184
typeOf x = return $ Parser.typeOf x
183185

186+
getStructFields :: String -> StateT (CompilerState a) IO [(String, Parser.Type)]
187+
getStructFields structName = do
188+
structDecs' <- gets structDecs
189+
let structDec = fromJust $ find (\x -> Parser.name x == structName) structDecs'
190+
let fields = case structDec of
191+
Parser.Struct _ fields' -> fields'
192+
_ -> error "Not a struct"
193+
return fields
194+
184195
constructFQName :: String -> [Parser.Expr] -> StateT (CompilerState a) IO String
185196
constructFQName "main" _ = return "main"
186197
constructFQName funcName args = mapM (typeOf >=> return . show) args <&> \x -> funcName ++ ":" ++ intercalate "," x
@@ -246,7 +257,11 @@ compileExpr (Parser.FuncCall funcName args _) = do
246257

247258
case external of
248259
Just (External _ ereturnType _ from) -> do
249-
let retT = typeToData ereturnType
260+
retT <- case ereturnType of
261+
Parser.StructT name -> do
262+
fields <- getStructFields name
263+
return $ DMap $ Data.Map.fromList $ map (second typeToData) fields
264+
_ -> return $ typeToData ereturnType
250265
args' <- concatMapM compileExpr (reverse args)
251266
return $ [Push retT] ++ args' ++ [CallFFI funcName from (length args)]
252267
Nothing ->

lib/VM.hs

+56-13
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,25 @@ import GHC.Generics (Generic)
2424
import System.Random (randomIO)
2525
#ifdef FFI
2626
import Ffi
27-
import Foreign.LibFFI.Base (newStorableStructArgRet, newStructCType)
27+
import Foreign.LibFFI.Base (newStorableStructArgRet, newStructCType, sizeAndAlignmentOfCType)
2828
import Foreign.LibFFI.FFITypes
2929
import Foreign.LibFFI
3030
#endif
3131

32+
import Control.Monad.Reader
3233
import Data.Binary qualified (get, put)
3334
import Data.Char (chr, ord)
35+
import Data.IORef (IORef)
3436
import Data.Text.Internal.Unsafe.Char
3537
import Foreign.C (CDouble, newCString)
3638
import Foreign.C.String (castCharToCChar)
37-
import Foreign.C.Types (CChar, CFloat, CInt)
39+
import Foreign.C.Types (CChar, CFloat, CInt, CSChar, CUChar)
40+
import Foreign.LibFFI.Internal (CType)
3841
import Foreign.Marshal.Array
3942
import Foreign.Ptr
4043
import Foreign.Storable
4144
import GHC.IO (unsafePerformIO)
45+
import GHC.IORef
4246
import GHC.Int qualified as Ghc.Int
4347

4448
data Instruction
@@ -298,9 +302,11 @@ showDebugInfo = do
298302
vm <- get
299303
-- Show the current instruction, two values from the stack, and the result
300304
let inst = program vm V.! pc vm
301-
let stack' = stack vm
302-
let stack'' = if length stack' > 1 then take 2 stack' else stack'
303-
return $ show (pc vm) ++ "\t" ++ show inst ++ "\t" ++ show stack'' ++ "\t" ++ show (safeHead $ callStack vm)
305+
-- let stack' = stack vm
306+
-- let stack'' = if length stack' > 1 then take 2 stack' else stack'
307+
-- return $ show (pc vm) ++ "\t" ++ show inst ++ "\t" ++ show stack'' ++ "\t" ++ show (safeHead $ callStack vm)
308+
let layers = length (callStack vm) - 1
309+
return $ replicate layers ' ' ++ show (pc vm) ++ " " ++ show inst -- ++ " (" ++ show (stack vm) ++ ")"
304310

305311
run' :: Program -> StateT VM IO ()
306312
run' program = do
@@ -450,27 +456,56 @@ instance IntoData (Ptr ()) where
450456
clearMap :: Data.Map String Data -> Data.Map String Data
451457
clearMap = Data.Map.delete "__name" . Data.Map.delete "__traits"
452458

459+
460+
globalStructType :: IORef [Ptr CType]
461+
{-# NOINLINE globalStructType #-}
462+
globalStructType = unsafePerformIO $ newIORef []
463+
453464
-- Only used for structs. I really don't like this
454465
instance Storable Data where
455466
sizeOf _ = 4
456467
alignment _ = 4
457-
peek _ = error "peek"
468+
peek ptr = do
469+
types <- readIORef globalStructType
470+
sizesAndAlignments <- mapM sizeAndAlignmentOfCType types
471+
let typesAndOffsets = zip types (scanl (\(o,_) (s,a) -> (o+s,a)) (0,0) sizesAndAlignments)
472+
b <- mapM mapData typesAndOffsets
473+
return $ DList b
474+
where
475+
mapData :: (Ptr CType, (Int, Int)) -> IO Data
476+
mapData tsa = do
477+
let (t,(o,_)) = tsa
478+
case () of _
479+
| t == ffi_type_float -> do
480+
return $ DFloat (realToFrac (unsafePerformIO $ peekByteOff ptr o :: CFloat))
481+
| t == ffi_type_sint32 -> do
482+
return $ DInt (fromIntegral (unsafePerformIO $ peekByteOff ptr o :: CInt))
483+
| t == ffi_type_double -> do
484+
return $ DDouble (realToFrac (unsafePerformIO $ peekByteOff ptr o :: CDouble))
485+
| t == ffi_type_pointer -> do
486+
return $ DCPtr (unsafePerformIO $ peekByteOff ptr o :: WordPtr)
487+
| t == ffi_type_uchar -> do
488+
return $ DChar (toEnum (fromEnum (unsafePerformIO $ peekByteOff ptr o :: CUChar)))
489+
| t == ffi_type_void -> do
490+
return DNone
491+
| otherwise -> error $ "mapData: Invalid type " ++ show t
492+
493+
458494
poke ptr (DInt x) = poke (castPtr ptr) x
459495
poke ptr (DFloat x) = poke (castPtr ptr) x
460496
poke ptr (DString x) = poke (castPtr ptr) (unsafePerformIO $ newCString x)
461497
poke ptr (DChar x) = poke (castPtr ptr) x
462498
poke ptr (DMap x) = do
463499
let values = reverse $ Data.Map.elems (clearMap x)
464-
let sizes = map sizeOf' values
500+
let sizes = map sizeOfC values
465501
mapM_ (\(i, v) -> pokeByteOff ptr (sum $ take i sizes) v) (zip [0..] values)
466-
where
467-
-- Normal sizeof crashes and I don't know why
468-
sizeOf' :: Data -> Int
469-
sizeOf' (DChar _) = 1
470-
sizeOf' (DDouble _) = 8
471-
sizeOf' _ = 4
472502
poke _ x = error $ "unsupported poke " ++ show x
473503

504+
sizeOfC :: Data -> Int
505+
sizeOfC (DChar _) = 1
506+
sizeOfC (DDouble _) = 8
507+
sizeOfC _ = 4
508+
474509
dataCType (DInt _) = ffi_type_sint32
475510
dataCType (DFloat _) = ffi_type_float
476511
dataCType (DString _) = ffi_type_pointer
@@ -579,6 +614,14 @@ runInstruction (CallFFI name from numArgs) = do
579614
let retType = ret retT :: RetType ()
580615
_ <- liftIO $ callFFI fun retType ffiArgs'
581616
return ()
617+
DMap x -> do
618+
let keys = reverse (Data.Map.keys (clearMap x))
619+
let types = map dataCType (Data.Map.elems (clearMap x))
620+
liftIO $ writeIORef globalStructType types
621+
(_,retType,_) <- liftIO (newStorableStructArgRet types :: IO (Data -> Arg, RetType Data, IO ()))
622+
(DList values) <- liftIO $ callFFI fun retType ffiArgs'
623+
let result = DMap $ Data.Map.fromList $ zip keys values
624+
stackPushA result
582625
_ -> error $ "Invalid return type: " ++ show retT
583626
liftIO $ mapM_ (\case Just x -> x; Nothing -> return ()) frees
584627
#else

0 commit comments

Comments
 (0)