Skip to content

Commit

Permalink
Proper static dispatch for traits
Browse files Browse the repository at this point in the history
  • Loading branch information
rockofox committed Dec 1, 2024
1 parent 0e232ed commit 0a0aab4
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 53 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ cabal.project.local~
*.wasm
result
*.tix
examples/scratch.in
64 changes: 28 additions & 36 deletions examples/scratch.in
Original file line number Diff line number Diff line change
@@ -1,41 +1,33 @@
# # let split ([]: String sep: String) => [String] = [""]
# # let split ((c:cs): String sep: String) => [String] = do
# # let rest = split cs, sep
# # if c == sep then do
# # "" : (split cs, sep)
# # else do
# # (c : (head rest)) : (tail rest)
# # end
# # end
# trait Number
# impl Number for Int
# impl Number for Float
### println (maybe (-1), (bind Some{value: 5}, \x -> x + 1))
### println (maybe (-1), (bind None{}, \x -> x + 1))
##
###let x => Optional = return 12
##let xxx => Optional = Some{value: 12}
###println x
##println (bind Some{value: 5}, \x -> x + 1)
##println (bind xxx, \x -> x + 1)
##println (sequence Some{value: 5}, Some{value: 6})
#
# # let add<N: Number> (a: N b: N) => N = a + b
#bla :: IO
#bla = (println 1) >> (println 2)
#
# let split<T> ([]: [T] sep: T) => [T] = [""]
#
# let printNumList<N: Number> (l: [N] n: N) => IO = do
#
# end
printNumbersAndSum :: IO
printNumbersAndSum _ = (println "x") >> (return 2)

main :: IO
main _ = do
println printNumbersAndSum
end
# let main => IO = println printNumbersAndSum
#
# let xxx (a: Int) => Int = a
#trait Number
#impl Number for Int
#impl Number for Float
#
#let add<N: Number> (a: N b: N) => N = do
# a + b
#end
#
# struct Child = (name: String, age: Int) satisfies (it.age < 18)
# let a+ (a: Int b: Int) => Int = a
# let main => IO = do
# println 1 a+ 2
# # println Child {name: "John", age: 13}
# # println Child {name: "Abram", age: 20}
# end

# `+` :: Int -> Int -> Int
# struct Cat = (name: String, age: Int)
# # let + (a: Cat b: Cat) => Cat = Cat {name: a.name : b.name, age: a.age + b.age}
# # let add (a: Int b: Int) => Int = a + b
# # let add (a: Float b: Float) => Float = a - b
# add :: Int -> Int -> Int
# add a b = a + b
# let main => IO = println add 3.0, 4.0
2+5
#let main => IO = do
# println add 1.2, 2.2
#end
55 changes: 39 additions & 16 deletions lib/BytecodeCompiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import Data.String
import Data.Text (isPrefixOf, splitOn)
import Data.Text qualified
import Data.Text qualified as T
import Debug.Trace
import Foreign (nullPtr, ptrToWordPtr)
import Foreign.C.Types ()
import GHC.Generics (Generic)
Expand Down Expand Up @@ -58,7 +57,7 @@ data CompilerState a = CompilerState
, impls :: [Parser.Expr]
, currentContext :: String -- TODO: Nested contexts
, externals :: [External]
, currentExpectedReturnType :: Parser.Type
, functionsByTrait :: [(String, String, String, Parser.Expr)]
}
deriving (Show)

Expand All @@ -82,7 +81,7 @@ initCompilerState prog =
[]
"__outside"
[]
Parser.Any
[]

allocId :: StateT (CompilerState a) IO Int
allocId = do
Expand Down Expand Up @@ -234,6 +233,12 @@ methodsForStruct structName = do
let methods = concatMap Parser.methods impls''
return $ map Parser.name methods

findBaseDecInTraits :: String -> StateT (CompilerState a) IO (Maybe Parser.Expr)
findBaseDecInTraits funcName = do
traits' <- gets traits
let baseDecs = map (find (\y -> Parser.name y == funcName) . Parser.methods) traits'
return $ firstJust id baseDecs

typeToData :: Parser.Type -> VM.Data
typeToData (Parser.StructT "Int") = VM.DInt 0
typeToData (Parser.StructT "Float") = VM.DFloat 0
Expand All @@ -251,6 +256,10 @@ typeToString :: Parser.Type -> String
typeToString (Parser.StructT x) = x
typeToString x = show x

typesEqual :: Parser.Type -> Parser.Type -> Bool
typesEqual (Parser.StructT x) (Parser.StructT y) = x == y
typesEqual x y = x == y

compileExpr :: Parser.Expr -> Parser.Type -> StateT (CompilerState a) IO [Instruction]
compileExpr (Parser.Add x y) _ = compileExpr (Parser.FuncCall "+" [x, y] zeroPosition) Parser.Any >>= doBinOp x y . last
compileExpr (Parser.Sub x y) _ = compileExpr (Parser.FuncCall "-" [x, y] zeroPosition) Parser.Any >>= doBinOp x y . last
Expand Down Expand Up @@ -301,6 +310,7 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do
funcDecs' <- gets funcDecs
curCon <- gets currentContext
externals' <- gets externals
fbt <- gets functionsByTrait
let contexts = map (T.pack . intercalate "@") (inits (Data.List.Split.splitOn "@" curCon))
-- Find the function in any context using firstM
let contextFunctions = firstJust (\context -> findFunction (Data.Text.unpack context ++ "@" ++ funcName) functions' argTypes) contexts
Expand All @@ -313,14 +323,12 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do
Nothing -> case findFunction funcName functions' argTypes of
(Just f) -> f
Nothing -> Function{baseName = unmangleFunctionName funcName, funame = funcName, function = [], types = [], context = "__outside"}
-- traceM $ "Looking for " ++ funcName ++ " with types " ++ show argTypes ++ " in " ++ show implsForExpectedTypePrefixes
-- when (funcName == "return") $ traceM $ "Expected type: " ++ show expectedType
-- when (funcName == "return") $ traceM $ "Find " ++ show implsForExpectedTypePrefixes ++ " in " ++ (show $ map (\x -> x.baseName) functions')
-- Find in impls
let funcDec = case find (\(Function{baseName}) -> baseName `elem` implsForExpectedTypePrefixes) functions' of
Just funf -> do
-- traceM $ "Found " ++ baseName funf ++ " in impls"
Just Parser.FuncDec{Parser.name = baseName funf, Parser.types = funf.types, Parser.generics = []}
let funcDec = case find (\(Parser.FuncDec name' _ _) -> name' == baseName fun) funcDecs' of
Just fd -> do
case find (\(_, n, _, newDec) -> n == funcName && take (length args) (Parser.types newDec) == argTypes) fbt of
Just (_, _, fqn, newDec) -> do
Just Parser.FuncDec{Parser.name = fqn, Parser.types = Parser.types newDec, Parser.generics = []}
Nothing -> Just fd
Nothing -> find (\(Parser.FuncDec name' _ _) -> name' == baseName fun) funcDecs'
let external = find (\x -> x.name == funcName) externals'
-- If the funcName starts with curCon@, it's a local function
Expand All @@ -339,9 +347,9 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do
case funcDec of
(Just fd) -> do
if length args == length (Parser.types fd) - 1
then concatMapM (\arg -> typeOf arg >>= compileExpr arg) args >>= \args' -> return (args' ++ [callWay])
then concatMapM (uncurry compileExpr) (zip args fd.types) >>= \args' -> return (args' ++ [callWay])
else
concatMapM (\arg -> typeOf arg >>= compileExpr arg) args >>= \args' ->
concatMapM (uncurry compileExpr) (zip args fd.types) >>= \args' ->
return $
args'
++ [PushPf (funame fun) (length args')]
Expand All @@ -357,7 +365,7 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do
compileExpr fd@(Parser.FuncDec{}) _ = do
modify (\s -> s{funcDecs = fd : funcDecs s})
return [] -- Function declarations are only used for compilation
compileExpr (Parser.FuncDef origName args body) expectedType = do
compileExpr (Parser.FuncDef origName args body) _ = do
curCon <- gets currentContext
funs <- gets functions
let previousContext = curCon
Expand Down Expand Up @@ -439,7 +447,7 @@ compileExpr (Parser.FuncDef origName args body) expectedType = do
compileParameter Parser.Placeholder _ = return []
compileParameter x funcName = do
nextFunName <- ((funcName ++ "#") ++) . show . (+ 1) <$> allocId
x' <- compileExpr x Parser.Any
x' <- compileExpr x Parser.Unknown
return $ [Dup] ++ x' ++ [Eq, Jf nextFunName]
compileExpr (Parser.ParenApply x y _) _ = do
fun <- compileExpr x Parser.Any
Expand Down Expand Up @@ -553,10 +561,25 @@ compileExpr (Parser.Trait name methods) _ = do
mapM_ (`compileExpr` Parser.Any) methods'
return []
compileExpr (Parser.Impl name for methods) _ = do
let methods' = map (\(Parser.FuncDef name' args body) -> Parser.FuncDef (name ++ "." ++ for ++ "::" ++ name') args body) methods
methods' <-
mapM
( \(Parser.FuncDef name' args body) -> do
let fullyQualifiedName = name ++ "." ++ for ++ "::" ++ name'
trait <- gets traits >>= \traits' -> return $ fromJust $ find (\x -> Parser.name x == name) traits'
let dec = fromJust $ find (\x -> Parser.name x == name') (Parser.methods trait)
let newDec = Parser.FuncDec{name = fullyQualifiedName, types = unself dec.types for, generics = dec.generics}
_ <- compileExpr newDec Parser.Any
modify (\s -> s{functionsByTrait = (for, name', fullyQualifiedName, newDec) : functionsByTrait s})
return $ Parser.FuncDef fullyQualifiedName args body
)
methods
-- gets functionsByTrait >>= traceShowM
modify (\s -> s{impls = Parser.Impl name for methods : impls s})
mapM_ (`compileExpr` Parser.Any) methods'
return []
where
unself :: [Parser.Type] -> String -> [Parser.Type]
unself types self = map (\case Parser.Self -> Parser.StructT self; x -> x) types
compileExpr (Parser.Lambda args body) _ = do
fId <- allocId
curCon <- gets currentContext
Expand Down
3 changes: 2 additions & 1 deletion lib/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ charLit = lexeme (char '\'' *> L.charLiteral <* char '\'') <|> lexeme (L.decimal
funcDec :: Parser Expr
funcDec = do
name <- (identifier <|> gravis) <?> "function name"
generics <- fromMaybe [] <$> optional generic <?> "function generics"
symbol "::"
argTypes <- sepBy1 validType (symbol "->") <?> "function arguments"
return $ FuncDec name argTypes []
return $ FuncDec name argTypes generics

defArg :: Parser Expr
defArg = try structLit <|> var <|> parens listPattern <|> array <|> placeholder <|> IntLit <$> integer <|> StringLit <$> stringLit
Expand Down
12 changes: 12 additions & 0 deletions tests/IntegrationSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,18 @@ spec = do
end
|]
`shouldReturn` "hello"
it "Static dispatch" $ do
compileAndRun
[r|
printNumbersAndSum :: IO
printNumbersAndSum _ = (println "x") >> (return 2)

main :: IO
main _ = do
println printNumbersAndSum
end
|]
`shouldReturn` "x\nIO{__traits: [Monad,__field_inner], inner: 2}\n"
describe "Lambdas" $ do
it "Can use a lambda in the map function" $ do
compileAndRun
Expand Down
20 changes: 20 additions & 0 deletions tests/ParserSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,23 @@ spec = do
parseProgram "bottles (i)-1" parserCompilerFlags
`shouldBe` Right
(Program [FuncCall "bottles" [Sub (Var "i" anyPosition) (IntLit 1)] anyPosition])

describe "Generics" $ do
it "Should parse let generics" $ do
parseProgram
[r|
let add<N: Number> (a: N b: N) => N = do
a + b
end
|]
parserCompilerFlags
`shouldBe` Right
(Program [Function{def = [FuncDef{name = "add", args = [Var "a" anyPosition, Var "b" anyPosition], body = DoBlock [Add (Var "a" anyPosition) (Var "b" anyPosition)]}], dec = FuncDec{name = "add", types = [StructT "N", StructT "N", StructT "N"], generics = [GenericExpr "N" (Just $ StructT "Number")]}}])
it "Should parse classic decleration generics" $ do
parseProgram
[r|
add<N: Number> :: N -> N -> N
|]
parserCompilerFlags
`shouldBe` Right
(Program [FuncDec{name = "add", types = [StructT "N", StructT "N", StructT "N"], generics = [GenericExpr "N" (Just $ StructT "Number")]}])
1 change: 1 addition & 0 deletions wasm_reactor/.envrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
use flake

0 comments on commit 0a0aab4

Please sign in to comment.