Skip to content

Commit

Permalink
Pattern matching on structs
Browse files Browse the repository at this point in the history
  • Loading branch information
rockofox committed Nov 17, 2024
1 parent ad57ccd commit 0e0ff8a
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 59 deletions.
36 changes: 29 additions & 7 deletions lib/BytecodeCompiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ module BytecodeCompiler where

import AST (zeroPosition)
import AST qualified as Parser.Type (Type (Unknown))
import Control.Arrow ((>>>))
import Control.Monad (when, (>=>))
import Control.Monad.Loops (firstM)
import Control.Monad.State (MonadIO (liftIO), StateT, gets, modify)
import Data.Bifunctor (second)
import Data.Function ((&))
import Data.Functor ((<&>))
import Data.List (elemIndex, find, inits, intercalate)
import Data.List.Split qualified
Expand All @@ -17,7 +15,6 @@ import Data.String
import Data.Text (isPrefixOf, splitOn)
import Data.Text qualified
import Data.Text qualified as T
import Debug.Trace
import Foreign (nullPtr, ptrToWordPtr)
import Foreign.C.Types ()
import GHC.Generics (Generic)
Expand Down Expand Up @@ -376,12 +373,37 @@ compileExpr (Parser.FuncDef origName args body) = do
let xToY = map (\(x, index) -> [Dup, Push $ DInt index, Index, x]) paramsWithIndex
let rest = [Push $ DInt (length elements - 1), Push DNone, Slice, last elements']
return $ lengthCheck ++ concat xToY ++ rest
compileParameter (Parser.IntLit x) funcName = do
-- TODO: fix this
compileParameter (Parser.StructLit name fields _) funcName = do
nextFunName <- ((funcName ++ "#") ++) . show . (+ 1) <$> allocId
return [Dup, Push $ DInt $ fromIntegral x, Eq, Jf nextFunName]
let fields' =
concatMap
( \case
(sName, Parser.Var tName _) -> [(sName, tName)]
_ -> []
)
fields
let fieldMappings = concatMap (\(sName, tName) -> [Dup, Access sName, LStore tName]) fields'
let fields'' =
concatMap
( \case
(_, Parser.Var _ _) -> []
(sName, x) -> [(sName, x)]
)
fields
fieldChecks <-
concatMapM
( \(sName, x) -> do
let a = [Dup, Access sName]
b <- compileExpr x
return $ a ++ b ++ [Eq, Jf nextFunName]
)
fields''
return $ [Dup, Push $ DTypeQuery name, TypeEq, Jf nextFunName] ++ fieldMappings ++ ([Pop | null fieldChecks]) ++ fieldChecks
compileParameter Parser.Placeholder _ = return []
compileParameter x _ = error $ show x ++ ": not implemented as a function parameter"
compileParameter x funcName = do
nextFunName <- ((funcName ++ "#") ++) . show . (+ 1) <$> allocId
x' <- compileExpr x
return $ x' ++ [Eq, Jf nextFunName]
compileExpr (Parser.ParenApply x y _) = do
fun <- compileExpr x
args <- concatMapM compileExpr y
Expand Down
80 changes: 40 additions & 40 deletions lib/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ binOpTable =
, [binary "as" Cast]
, [prefix "$" StrictEval]
, [prefix "!" Not]
-- , [prefix "-" UnaryMinus]
, [binary "**" Power, binary "*" Mul, binary "/" Div]
, -- , [prefix "-" UnaryMinus]
[binary "**" Power, binary "*" Mul, binary "/" Div]
, [binary "%" Modulo]
, [binary "+" Add, binary "-" Sub]
, [binary ">>" Then]
Expand Down Expand Up @@ -239,13 +239,13 @@ charLit = lexeme (char '\'' *> L.charLiteral <* char '\'') <|> lexeme (L.decimal

funcDec :: Parser Expr
funcDec = do
name <- identifier <|> gravis
name <- (identifier <|> gravis) <?> "function name"
symbol "::"
argTypes <- sepBy1 validType (symbol "->")
argTypes <- sepBy1 validType (symbol "->") <?> "function arguments"
return $ FuncDec name argTypes []

defArg :: Parser Expr
defArg = var <|> parens listPattern <|> array <|> placeholder <|> IntLit <$> integer
defArg = try structLit <|> var <|> parens listPattern <|> array <|> placeholder <|> IntLit <$> integer <|> StringLit <$> stringLit

funcDef :: Parser Expr
funcDef = do
Expand All @@ -268,47 +268,47 @@ funcCall = do
parenApply :: Parser Expr
parenApply = do
start <- getOffset
paren <- parens expr
paren <- parens expr <?> "parenthesized function expression"
args <- sepBy1 expr (symbol ",") <?> "function arguments"
end <- getOffset
return $ ParenApply paren args (Position (start, end))

letExpr :: Parser Expr
letExpr = do
keyword "let"
name <- identifier <|> gravis
name <- identifier <|> gravis <?> "variable name"
symbol "="
value <- recover expr
value <- recover expr <?> "variable value"
state <- get
put state{validLets = name : validLets state}
return $ Let name value

ifExpr :: Parser Expr
ifExpr = do
keyword "if"
cond <- expr
cond <- expr <?> "condition"
keyword "then"
optional newline'
thenExpr <- expr
thenExpr <- expr <?> "then expression"
optional newline'
keyword "else"
optional newline'
elseExpr <- expr
elseExpr <- expr <?> "else expression"
optional newline'
return $ If cond thenExpr elseExpr

doBlock :: Parser Expr
doBlock = do
keyword "do"
newline'
exprs <- lines'
exprs <- lines' <?> "do block"
keyword "end" <|> lookAhead (keyword "else")
return $ DoBlock exprs

generic :: Parser [GenericExpr]
generic = do
symbol "<"
args <- sepBy1 typeAndMaybeConstraint (symbol ",")
args <- sepBy1 typeAndMaybeConstraint (symbol ",") <?> "generic arguments"
symbol ">"
return args
where
Expand All @@ -323,9 +323,9 @@ combinedFunc :: Parser Expr
combinedFunc = do
keyword "let"
name <- identifier <|> gravis <?> "function name"
generics <- fromMaybe [] <$> optional generic
(args, argTypes) <- parens argsAndTypes <|> argsAndTypes
returnType <- optional (symbol "=>" >> validType <?> "return type")
generics <- (fromMaybe [] <$> optional generic) <?> "function generics"
(args, argTypes) <- (parens argsAndTypes <|> argsAndTypes) <?> "function arguments"
returnType <- optional (symbol "=>" >> validType <?> "return type") <?> "return type"
symbol "="
body <- recover expr <?> "function body"
return $ Function [FuncDef name args body] (FuncDec name (argTypes ++ [fromMaybe Any returnType]) generics)
Expand All @@ -342,7 +342,7 @@ import_ :: Parser Expr
import_ = do
keyword "import"
qualified <- optional (keyword "qualified" >> return True) <?> "qualified"
objects <- sepBy (extra <|> (symbol "*" >> return "*")) (symbol ",")
objects <- sepBy (extra <|> (symbol "*" >> return "*")) (symbol ",") <?> "import objects"
keyword "from"
from <- many (alphaNumChar <|> char '.' <|> char '@' <|> char '/') <?> "import path"
alias <- optional (keyword " as" >> extra) <?> "import alias"
Expand All @@ -351,25 +351,25 @@ import_ = do
array :: Parser Expr
array = do
symbol "["
elements <- sepBy expr (symbol ",")
elements <- sepBy expr (symbol ",") <?> "array elements"
symbol "]"
return $ ListLit elements

struct :: Parser Expr
struct = do
keyword "struct"
name <- extra
name <- extra <?> "struct name"
symbol "="
fields <- parens $ structField `sepBy` symbol ","
fields <- parens (structField `sepBy` symbol ",") <?> "struct fields"
refinementSrc <- lookAhead $ optional $ do
keyword "satisfies"
parens $ many (noneOf [')'])
parens (many (noneOf [')'])) <?> "refinement source"
refinement <- optional $ do
keyword "satisfies"
parens expr
parens expr <?> "refinement"
is <- optional $ do
keyword "is"
sepBy extra (symbol ",")
sepBy extra (symbol ",") <?> "struct interfaces"
return $ Struct{name = name, fields = fields, refinement = refinement, refinementSrc = fromMaybe "" refinementSrc, is = fromMaybe [] is}
where
structField = do
Expand All @@ -381,14 +381,14 @@ struct = do
structLit :: Parser Expr
structLit = do
start <- getOffset
name <- extra
name <- extra <?> "struct name"
symbol "{"
fields <-
sepBy
( do
fieldName <- identifier
fieldName <- identifier <?> "field name"
symbol ":"
fieldValue <- expr
fieldValue <- expr <?> "field value"
return (fieldName, fieldValue)
)
(symbol ",")
Expand All @@ -398,9 +398,9 @@ structLit = do

arrayAccess :: Parser Expr
arrayAccess = do
a <- choice [var, array]
a <- choice [var, array] <?> "array"
symbol "["
index <- expr
index <- expr <?> "array index"
symbol "]"
return $ ArrayAccess a index

Expand Down Expand Up @@ -468,19 +468,19 @@ target :: Parser Expr
target = do
symbol "__target"
target' <- (symbol "wasm" <|> symbol "c") <?> "target"
Target (unpack target') <$> expr
Target (unpack target') <$> expr <?> "target"

listPattern :: Parser Expr
listPattern = do
elements <- sepBy1 (var <|> placeholder <|> array) (symbol ":")
elements <- sepBy1 (var <|> placeholder <|> array) (symbol ":") <?> "list pattern"
return $ ListPattern elements

lambda :: Parser Expr
lambda = do
symbol "\\"
args <- some (var <|> parens listPattern <|> array)
args <- some (var <|> parens listPattern <|> array) <?> "lambda arguments"
symbol "->"
Lambda args <$> expr
Lambda args <$> expr <?> "lambda body"

verbose :: (Show a) => String -> Parser a -> Parser a
verbose str parser = do
Expand All @@ -503,13 +503,13 @@ typeLiteral = do
trait :: Parser Expr
trait = do
keyword "trait"
name <- identifier
name <- identifier <?> "trait name"
methods <-
( do
symbol "="
keyword "do"
newline'
fds <- funcDec `sepEndBy` newline'
fds <- funcDec `sepEndBy` newline' <?> "trait methods"
keyword "end"
return fds
)
Expand All @@ -519,15 +519,15 @@ trait = do
impl :: Parser Expr
impl = do
keyword "impl"
name <- identifier
name <- identifier <?> "impl name"
symbol "for"
for <- identifier
for <- identifier <?> "impl for"
methods <-
( do
symbol "="
symbol "do"
newline'
fds <- funcDef `sepEndBy` newline'
fds <- funcDef `sepEndBy` newline' <?> "impl methods"
symbol "end"
return fds
)
Expand All @@ -537,18 +537,18 @@ impl = do
external :: Parser Expr
external = do
keyword "external"
from <- stringLit
from <- stringLit <?> "external from"
symbol "="
symbol "do"
newline'
decs <- funcDec `sepEndBy` newline'
decs <- funcDec `sepEndBy` newline' <?> "external declarations"
symbol "end"
return $ External from decs

unaryMinus :: Parser Expr
unaryMinus = parens $ do
symbol "-"
UnaryMinus <$> expr
UnaryMinus <$> expr <?> "negatable expression"

term :: Parser Expr
term =
Expand Down
5 changes: 5 additions & 0 deletions lib/Verifier.hs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ verifyProgram' name source exprs = do
verifyMultiple :: [Expr] -> StateT VerifierState IO [ParseError s e]
verifyMultiple = concatMapM verifyExpr

structLitToBindings :: Expr -> Type -> [VBinding]
structLitToBindings (StructLit _ fields _) _ = map (\case (_, Var name _) -> VBinding{name = name, ttype = Any, args = [], generics = []}; _ -> error "Impossible") fields
structLitToBindings x y = error $ "structLitToBindings called with " ++ show x ++ " and " ++ show y

verifyExpr :: Expr -> StateT VerifierState IO [ParseError s e]
verifyExpr (FuncDef name args body) = do
-- TODO: Position
Expand All @@ -270,6 +274,7 @@ verifyExpr (FuncDef name args body) = do
argToBindings (Var name' _, Fn args' ret) = [VBinding{name = name', args = args', ttype = ret, generics = []}]
argToBindings (Var name' _, ttype') = [VBinding{name = name', args = [], ttype = ttype', generics = []}]
argToBindings (l@ListPattern{}, t) = listPatternToBindings l t
argToBindings (s@StructLit{}, t) = structLitToBindings s t
argToBindings _ = []
Nothing -> return [FancyError 0 (Set.singleton (ErrorFail $ "Function " ++ name ++ " is missing a declaration"))]
verifyExpr (FuncDec name dtypes generics) = do
Expand Down
8 changes: 8 additions & 0 deletions share/std/prelude.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
struct IO = (inner: Any)

trait Optional
struct Some = (value: Any) is Optional
struct None = () is Optional

maybe :: Any -> Optional -> Any
maybe x Some{value: y} = y
maybe x None{} = x

# `+` :: Any -> Any -> Any
# `+` x y = x + y
# `-` :: Any -> Any -> Any
Expand Down
Loading

0 comments on commit 0e0ff8a

Please sign in to comment.