From 0a0aab48d6fada8b504c8f601e3de9f779093cb2 Mon Sep 17 00:00:00 2001 From: rockofox Date: Sun, 1 Dec 2024 19:44:25 +0100 Subject: [PATCH] Proper static dispatch for traits --- .gitignore | 1 + examples/scratch.in | 64 ++++++++++++++++++---------------------- lib/BytecodeCompiler.hs | 55 ++++++++++++++++++++++++---------- lib/Parser.hs | 3 +- tests/IntegrationSpec.hs | 12 ++++++++ tests/ParserSpec.hs | 20 +++++++++++++ wasm_reactor/.envrc | 1 + 7 files changed, 103 insertions(+), 53 deletions(-) create mode 100644 wasm_reactor/.envrc diff --git a/.gitignore b/.gitignore index d042c78..79b8e15 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ cabal.project.local~ *.wasm result *.tix +examples/scratch.in \ No newline at end of file diff --git a/examples/scratch.in b/examples/scratch.in index e52100a..92d477d 100644 --- a/examples/scratch.in +++ b/examples/scratch.in @@ -1,41 +1,33 @@ -# # let split ([]: String sep: String) => [String] = [""] -# # let split ((c:cs): String sep: String) => [String] = do -# # let rest = split cs, sep -# # if c == sep then do -# # "" : (split cs, sep) -# # else do -# # (c : (head rest)) : (tail rest) -# # end -# # end -# trait Number -# impl Number for Int -# impl Number for Float +### println (maybe (-1), (bind Some{value: 5}, \x -> x + 1)) +### println (maybe (-1), (bind None{}, \x -> x + 1)) +## +###let x => Optional = return 12 +##let xxx => Optional = Some{value: 12} +###println x +##println (bind Some{value: 5}, \x -> x + 1) +##println (bind xxx, \x -> x + 1) +##println (sequence Some{value: 5}, Some{value: 6}) # -# # let add (a: N b: N) => N = a + b +#bla :: IO +#bla = (println 1) >> (println 2) # -# let split ([]: [T] sep: T) => [T] = [""] -# -# let printNumList (l: [N] n: N) => IO = do -# -# end +printNumbersAndSum :: IO +printNumbersAndSum _ = (println "x") >> (return 2) + +main :: IO +main _ = do + println printNumbersAndSum +end +# let main => IO = println printNumbersAndSum # -# let xxx (a: Int) => Int = a +#trait Number +#impl Number for Int +#impl Number for Float # +#let add (a: N b: N) => N = do +# a + b +#end # -# struct Child = (name: String, age: Int) satisfies (it.age < 18) -# let a+ (a: Int b: Int) => Int = a -# let main => IO = do -# println 1 a+ 2 -# # println Child {name: "John", age: 13} -# # println Child {name: "Abram", age: 20} -# end - -# `+` :: Int -> Int -> Int -# struct Cat = (name: String, age: Int) -# # let + (a: Cat b: Cat) => Cat = Cat {name: a.name : b.name, age: a.age + b.age} -# # let add (a: Int b: Int) => Int = a + b -# # let add (a: Float b: Float) => Float = a - b -# add :: Int -> Int -> Int -# add a b = a + b -# let main => IO = println add 3.0, 4.0 -2+5 +#let main => IO = do +# println add 1.2, 2.2 +#end diff --git a/lib/BytecodeCompiler.hs b/lib/BytecodeCompiler.hs index fcfa6d8..5963366 100644 --- a/lib/BytecodeCompiler.hs +++ b/lib/BytecodeCompiler.hs @@ -15,7 +15,6 @@ import Data.String import Data.Text (isPrefixOf, splitOn) import Data.Text qualified import Data.Text qualified as T -import Debug.Trace import Foreign (nullPtr, ptrToWordPtr) import Foreign.C.Types () import GHC.Generics (Generic) @@ -58,7 +57,7 @@ data CompilerState a = CompilerState , impls :: [Parser.Expr] , currentContext :: String -- TODO: Nested contexts , externals :: [External] - , currentExpectedReturnType :: Parser.Type + , functionsByTrait :: [(String, String, String, Parser.Expr)] } deriving (Show) @@ -82,7 +81,7 @@ initCompilerState prog = [] "__outside" [] - Parser.Any + [] allocId :: StateT (CompilerState a) IO Int allocId = do @@ -234,6 +233,12 @@ methodsForStruct structName = do let methods = concatMap Parser.methods impls'' return $ map Parser.name methods +findBaseDecInTraits :: String -> StateT (CompilerState a) IO (Maybe Parser.Expr) +findBaseDecInTraits funcName = do + traits' <- gets traits + let baseDecs = map (find (\y -> Parser.name y == funcName) . Parser.methods) traits' + return $ firstJust id baseDecs + typeToData :: Parser.Type -> VM.Data typeToData (Parser.StructT "Int") = VM.DInt 0 typeToData (Parser.StructT "Float") = VM.DFloat 0 @@ -251,6 +256,10 @@ typeToString :: Parser.Type -> String typeToString (Parser.StructT x) = x typeToString x = show x +typesEqual :: Parser.Type -> Parser.Type -> Bool +typesEqual (Parser.StructT x) (Parser.StructT y) = x == y +typesEqual x y = x == y + compileExpr :: Parser.Expr -> Parser.Type -> StateT (CompilerState a) IO [Instruction] compileExpr (Parser.Add x y) _ = compileExpr (Parser.FuncCall "+" [x, y] zeroPosition) Parser.Any >>= doBinOp x y . last compileExpr (Parser.Sub x y) _ = compileExpr (Parser.FuncCall "-" [x, y] zeroPosition) Parser.Any >>= doBinOp x y . last @@ -301,6 +310,7 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do funcDecs' <- gets funcDecs curCon <- gets currentContext externals' <- gets externals + fbt <- gets functionsByTrait let contexts = map (T.pack . intercalate "@") (inits (Data.List.Split.splitOn "@" curCon)) -- Find the function in any context using firstM let contextFunctions = firstJust (\context -> findFunction (Data.Text.unpack context ++ "@" ++ funcName) functions' argTypes) contexts @@ -313,14 +323,12 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do Nothing -> case findFunction funcName functions' argTypes of (Just f) -> f Nothing -> Function{baseName = unmangleFunctionName funcName, funame = funcName, function = [], types = [], context = "__outside"} - -- traceM $ "Looking for " ++ funcName ++ " with types " ++ show argTypes ++ " in " ++ show implsForExpectedTypePrefixes - -- when (funcName == "return") $ traceM $ "Expected type: " ++ show expectedType - -- when (funcName == "return") $ traceM $ "Find " ++ show implsForExpectedTypePrefixes ++ " in " ++ (show $ map (\x -> x.baseName) functions') - -- Find in impls - let funcDec = case find (\(Function{baseName}) -> baseName `elem` implsForExpectedTypePrefixes) functions' of - Just funf -> do - -- traceM $ "Found " ++ baseName funf ++ " in impls" - Just Parser.FuncDec{Parser.name = baseName funf, Parser.types = funf.types, Parser.generics = []} + let funcDec = case find (\(Parser.FuncDec name' _ _) -> name' == baseName fun) funcDecs' of + Just fd -> do + case find (\(_, n, _, newDec) -> n == funcName && take (length args) (Parser.types newDec) == argTypes) fbt of + Just (_, _, fqn, newDec) -> do + Just Parser.FuncDec{Parser.name = fqn, Parser.types = Parser.types newDec, Parser.generics = []} + Nothing -> Just fd Nothing -> find (\(Parser.FuncDec name' _ _) -> name' == baseName fun) funcDecs' let external = find (\x -> x.name == funcName) externals' -- If the funcName starts with curCon@, it's a local function @@ -339,9 +347,9 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do case funcDec of (Just fd) -> do if length args == length (Parser.types fd) - 1 - then concatMapM (\arg -> typeOf arg >>= compileExpr arg) args >>= \args' -> return (args' ++ [callWay]) + then concatMapM (uncurry compileExpr) (zip args fd.types) >>= \args' -> return (args' ++ [callWay]) else - concatMapM (\arg -> typeOf arg >>= compileExpr arg) args >>= \args' -> + concatMapM (uncurry compileExpr) (zip args fd.types) >>= \args' -> return $ args' ++ [PushPf (funame fun) (length args')] @@ -357,7 +365,7 @@ compileExpr (Parser.FuncCall funcName args _) expectedType = do compileExpr fd@(Parser.FuncDec{}) _ = do modify (\s -> s{funcDecs = fd : funcDecs s}) return [] -- Function declarations are only used for compilation -compileExpr (Parser.FuncDef origName args body) expectedType = do +compileExpr (Parser.FuncDef origName args body) _ = do curCon <- gets currentContext funs <- gets functions let previousContext = curCon @@ -439,7 +447,7 @@ compileExpr (Parser.FuncDef origName args body) expectedType = do compileParameter Parser.Placeholder _ = return [] compileParameter x funcName = do nextFunName <- ((funcName ++ "#") ++) . show . (+ 1) <$> allocId - x' <- compileExpr x Parser.Any + x' <- compileExpr x Parser.Unknown return $ [Dup] ++ x' ++ [Eq, Jf nextFunName] compileExpr (Parser.ParenApply x y _) _ = do fun <- compileExpr x Parser.Any @@ -553,10 +561,25 @@ compileExpr (Parser.Trait name methods) _ = do mapM_ (`compileExpr` Parser.Any) methods' return [] compileExpr (Parser.Impl name for methods) _ = do - let methods' = map (\(Parser.FuncDef name' args body) -> Parser.FuncDef (name ++ "." ++ for ++ "::" ++ name') args body) methods + methods' <- + mapM + ( \(Parser.FuncDef name' args body) -> do + let fullyQualifiedName = name ++ "." ++ for ++ "::" ++ name' + trait <- gets traits >>= \traits' -> return $ fromJust $ find (\x -> Parser.name x == name) traits' + let dec = fromJust $ find (\x -> Parser.name x == name') (Parser.methods trait) + let newDec = Parser.FuncDec{name = fullyQualifiedName, types = unself dec.types for, generics = dec.generics} + _ <- compileExpr newDec Parser.Any + modify (\s -> s{functionsByTrait = (for, name', fullyQualifiedName, newDec) : functionsByTrait s}) + return $ Parser.FuncDef fullyQualifiedName args body + ) + methods + -- gets functionsByTrait >>= traceShowM modify (\s -> s{impls = Parser.Impl name for methods : impls s}) mapM_ (`compileExpr` Parser.Any) methods' return [] + where + unself :: [Parser.Type] -> String -> [Parser.Type] + unself types self = map (\case Parser.Self -> Parser.StructT self; x -> x) types compileExpr (Parser.Lambda args body) _ = do fId <- allocId curCon <- gets currentContext diff --git a/lib/Parser.hs b/lib/Parser.hs index be1741f..bb7754f 100644 --- a/lib/Parser.hs +++ b/lib/Parser.hs @@ -248,9 +248,10 @@ charLit = lexeme (char '\'' *> L.charLiteral <* char '\'') <|> lexeme (L.decimal funcDec :: Parser Expr funcDec = do name <- (identifier <|> gravis) "function name" + generics <- fromMaybe [] <$> optional generic "function generics" symbol "::" argTypes <- sepBy1 validType (symbol "->") "function arguments" - return $ FuncDec name argTypes [] + return $ FuncDec name argTypes generics defArg :: Parser Expr defArg = try structLit <|> var <|> parens listPattern <|> array <|> placeholder <|> IntLit <$> integer <|> StringLit <$> stringLit diff --git a/tests/IntegrationSpec.hs b/tests/IntegrationSpec.hs index b3cb75c..483e299 100644 --- a/tests/IntegrationSpec.hs +++ b/tests/IntegrationSpec.hs @@ -332,6 +332,18 @@ spec = do end |] `shouldReturn` "hello" + it "Static dispatch" $ do + compileAndRun + [r| + printNumbersAndSum :: IO + printNumbersAndSum _ = (println "x") >> (return 2) + + main :: IO + main _ = do + println printNumbersAndSum + end + |] + `shouldReturn` "x\nIO{__traits: [Monad,__field_inner], inner: 2}\n" describe "Lambdas" $ do it "Can use a lambda in the map function" $ do compileAndRun diff --git a/tests/ParserSpec.hs b/tests/ParserSpec.hs index 904f6b1..c9a053c 100644 --- a/tests/ParserSpec.hs +++ b/tests/ParserSpec.hs @@ -138,3 +138,23 @@ spec = do parseProgram "bottles (i)-1" parserCompilerFlags `shouldBe` Right (Program [FuncCall "bottles" [Sub (Var "i" anyPosition) (IntLit 1)] anyPosition]) + + describe "Generics" $ do + it "Should parse let generics" $ do + parseProgram + [r| + let add (a: N b: N) => N = do + a + b + end + |] + parserCompilerFlags + `shouldBe` Right + (Program [Function{def = [FuncDef{name = "add", args = [Var "a" anyPosition, Var "b" anyPosition], body = DoBlock [Add (Var "a" anyPosition) (Var "b" anyPosition)]}], dec = FuncDec{name = "add", types = [StructT "N", StructT "N", StructT "N"], generics = [GenericExpr "N" (Just $ StructT "Number")]}}]) + it "Should parse classic decleration generics" $ do + parseProgram + [r| + add :: N -> N -> N + |] + parserCompilerFlags + `shouldBe` Right + (Program [FuncDec{name = "add", types = [StructT "N", StructT "N", StructT "N"], generics = [GenericExpr "N" (Just $ StructT "Number")]}]) diff --git a/wasm_reactor/.envrc b/wasm_reactor/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/wasm_reactor/.envrc @@ -0,0 +1 @@ +use flake