From d3962728277d68aed8fc7b58f40c1d45eaea5284 Mon Sep 17 00:00:00 2001 From: matt Date: Fri, 17 Jan 2025 12:18:52 -1000 Subject: [PATCH 01/16] wip --- rscel/examples/dumps_ast.rs | 1 + rscel/examples/explain.rs | 1 + rscel/src/compiler/compiled_prog.rs | 12 + rscel/src/compiler/compiler.rs | 172 +++++++++++--- rscel/src/compiler/grammar.rs | 34 +++ rscel/src/compiler/string_scanner.rs | 2 +- rscel/src/compiler/string_tokenizer.rs | 6 +- rscel/src/compiler/tokens.rs | 2 + rscel/src/interp/mod.rs | 9 + rscel/src/interp/types.rs | 4 + rscel/src/tests/general_tests.rs | 313 +++++++++++++------------ 11 files changed, 365 insertions(+), 191 deletions(-) diff --git a/rscel/examples/dumps_ast.rs b/rscel/examples/dumps_ast.rs index 13d2bf4..12cff2c 100644 --- a/rscel/examples/dumps_ast.rs +++ b/rscel/examples/dumps_ast.rs @@ -39,6 +39,7 @@ impl<'a> AstDumper<'a> { self.dump_or_node(true_clause, depth + 1); self.dump_expr_node(false_clause, depth + 1); } + Expr::Match { .. } => todo!(), } } diff --git a/rscel/examples/explain.rs b/rscel/examples/explain.rs index fa07adf..26108b0 100644 --- a/rscel/examples/explain.rs +++ b/rscel/examples/explain.rs @@ -116,6 +116,7 @@ impl AstDumper { ] .into_iter(), )), + Expr::Match { .. } => todo!(), } } diff --git a/rscel/src/compiler/compiled_prog.rs b/rscel/src/compiler/compiled_prog.rs index fab8525..bfc1854 100644 --- a/rscel/src/compiler/compiled_prog.rs +++ b/rscel/src/compiler/compiled_prog.rs @@ -61,6 +61,10 @@ macro_rules! compile { } impl CompiledProg { + pub fn new(inner: NodeValue, details: ProgramDetails) -> Self { + Self { inner, details } + } + pub fn empty() -> CompiledProg { CompiledProg { inner: NodeValue::Bytecode(CelByteCode::new()), @@ -75,6 +79,10 @@ impl CompiledProg { } } + pub fn details(&self) -> &ProgramDetails { + &self.details + } + pub fn with_bytecode(bytecode: CelByteCode) -> CompiledProg { CompiledProg { inner: NodeValue::Bytecode(bytecode), @@ -293,6 +301,10 @@ impl CompiledProg { self.inner.into_bytecode() } + pub fn into_parts(self) -> (NodeValue, ProgramDetails) { + (self.inner, self.details) + } + pub fn is_const(&self) -> bool { self.inner.is_const() } diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index ef0dde6..87271d1 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::{ ast_node::AstNode, - compiled_prog::CompiledProg, + compiled_prog::{CompiledProg, NodeValue}, grammar::*, source_range::SourceRange, syntax_error::SyntaxError, @@ -11,7 +11,9 @@ use super::{ }; use crate::{ interp::{Interpreter, JmpWhen}, - BindContext, ByteCode, CelError, CelResult, CelValue, CelValueDyn, Program, StringTokenizer, + types::CelByteCode, + BindContext, ByteCode, CelError, CelResult, CelValue, CelValueDyn, Program, ProgramDetails, + StringTokenizer, }; use crate::compile; @@ -50,44 +52,150 @@ impl<'l> CelCompiler<'l> { } fn parse_expression(&mut self) -> CelResult<(CompiledProg, AstNode)> { - let (lhs_node, lhs_ast) = self.parse_conditional_or()?; + if let Some(Token::Match) = self.tokenizer.peek()?.as_token() { + self.tokenizer.next()?; + self.parse_match_expression() + } else { + let (lhs_node, lhs_ast) = self.parse_conditional_or()?; - match self.tokenizer.peek()?.as_token() { - Some(Token::Question) => { - self.tokenizer.next()?; - let (true_clause_node, true_clause_ast) = self.parse_conditional_or()?; - - let next = self.tokenizer.next()?; - if next.as_token() != Some(&Token::Colon) { - return Err(SyntaxError::from_location(self.tokenizer.location()) - .with_message(format!("Unexpected token {:?}, expected COLON", next)) - .into()); + match self.tokenizer.peek()?.as_token() { + Some(Token::Question) => { + self.tokenizer.next()?; + self.parse_turnary_expression(lhs_node, lhs_ast) } + _ => { + let range = lhs_ast.range(); + Ok(( + CompiledProg::from_node(lhs_node), + AstNode::new(Expr::Unary(Box::new(lhs_ast)), range), + )) + } + } + } + } - let (false_clause_node, false_clause_ast) = self.parse_expression()?; + fn parse_turnary_expression( + &mut self, + or_prog: CompiledProg, + or_ast: AstNode, + ) -> CelResult<(CompiledProg, AstNode)> { + let (true_clause_node, true_clause_ast) = self.parse_conditional_or()?; - let range = lhs_ast.range().surrounding(false_clause_ast.range()); + let next = self.tokenizer.next()?; + if next.as_token() != Some(&Token::Colon) { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Unexpected token {:?}, expected COLON", next)) + .into()); + } - Ok(( - lhs_node.into_turnary(true_clause_node, false_clause_node), - AstNode::new( - Expr::Ternary { - condition: Box::new(lhs_ast), - true_clause: Box::new(true_clause_ast), - false_clause: Box::new(false_clause_ast), - }, - range, - ), - )) + let (false_clause_node, false_clause_ast) = self.parse_expression()?; + + let range = or_ast.range().surrounding(false_clause_ast.range()); + + Ok(( + or_prog.into_turnary(true_clause_node, false_clause_node), + AstNode::new( + Expr::Ternary { + condition: Box::new(or_ast), + true_clause: Box::new(true_clause_ast), + false_clause: Box::new(false_clause_ast), + }, + range, + ), + )) + } + + fn parse_match_expression(&mut self) -> CelResult<(CompiledProg, AstNode)> { + let (condition_node, condition_ast) = self.parse_expression()?; + + let mut range = condition_ast.range(); + + let (node_bytecode, mut node_details) = condition_node.into_parts(); + + let mut node_bytecode = node_bytecode.into_bytecode(); + + let next = self.tokenizer.next()?; + if next.as_token() != Some(&Token::LBrace) { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Unexpected token {:?}, expected LBRACE", next)) + .into()); + } + + let mut expressions: Vec> = Vec::new(); + + let mut all_parts = Vec::new(); + + loop { + let lbrace = self.tokenizer.peek()?; + if lbrace.as_token() != Some(&Token::LBrace) { + range = range.surrounding(lbrace.unwrap().loc); + break; } - _ => { - let range = lhs_ast.range(); - Ok(( - CompiledProg::from_node(lhs_node), - AstNode::new(Expr::Unary(Box::new(lhs_ast)), range), - )) + + let case_token = self.tokenizer.next()?; + if case_token.as_token() != Some(&Token::Case) { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Unexpected token {:?}, expected CASE", next)) + .into()); } + let (pattern_prog, pattern_ast) = self.parse_match_pattern()?; + let (pattern_bytecode, pattern_details) = pattern_prog.into_parts(); + let pattern_bytecode: Vec<_> = [ByteCode::Dup] + .into_iter() + .chain(pattern_bytecode.into_bytecode().into_iter()) + .collect(); + + node_details.union_from(pattern_details); + + let pattern_range = pattern_ast.range(); + + let colon_token = self.tokenizer.next()?; + if colon_token.as_token() != Some(&Token::Colon) { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Unexpected token {:?}, expected COLON", next)) + .into()); + } + + let (expr_prog, expr_ast) = self.parse_expression()?; + let (expr_bytecode, expr_details) = expr_prog.into_parts(); + let expr_bytecode: Vec<_> = [ByteCode::Pop] + .into_iter() + .chain(expr_bytecode.into_bytecode().into_iter()) + .collect(); + + node_details.union_from(expr_details); + + let case_range = pattern_range.surrounding(expr_ast.range()); + + all_parts.push((pattern_bytecode, expr_bytecode)); + expressions.push(AstNode::new( + MatchCase { + pattern: pattern_ast, + expr: Box::new(expr_ast), + }, + case_range, + )); } + + let mut pattern_segment = CelByteCode::new(); + let mut expr_segment = CelByteCode::new(); + + for (pattern_bytecode, expr_bytecode) in all_parts.into_iter().rev() {} + + Ok(( + CompiledProg::new(NodeValue::Bytecode(node_bytecode), node_details), + AstNode::new( + Expr::Match { + condition: Box::new(condition_ast), + cases: expressions, + }, + range, + ), + )) + } + + fn parse_match_pattern(&mut self) -> CelResult<(CompiledProg, AstNode)> { + todo!() } fn parse_conditional_or(&mut self) -> CelResult<(CompiledProg, AstNode)> { diff --git a/rscel/src/compiler/grammar.rs b/rscel/src/compiler/grammar.rs index c10b1a7..2876e3a 100644 --- a/rscel/src/compiler/grammar.rs +++ b/rscel/src/compiler/grammar.rs @@ -24,9 +24,43 @@ pub enum Expr { true_clause: Box>, false_clause: Box>, }, + Match { + condition: Box>, + cases: Vec>, + }, Unary(Box>), } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MatchCase { + pub pattern: AstNode, + pub expr: Box>, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MatchPattern { + Type(AstNode), + Any(AstNode), +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MatchTypePattern { + Int, + Uint, + Float, + String, + Bool, + Bytes, + List, + Object, + Null, + Timestamp, + Duration, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MatchAnyPattern; + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ConditionalOr { Binary { diff --git a/rscel/src/compiler/string_scanner.rs b/rscel/src/compiler/string_scanner.rs index 5fa7c5a..9523c22 100644 --- a/rscel/src/compiler/string_scanner.rs +++ b/rscel/src/compiler/string_scanner.rs @@ -24,7 +24,7 @@ impl<'l> StringScanner<'l> { } pub fn peek(&mut self) -> Option { - if let None = self.current { + if self.current.is_none() { self.current = self.collect_next(); } diff --git a/rscel/src/compiler/string_tokenizer.rs b/rscel/src/compiler/string_tokenizer.rs index 6911d2e..bf1d1bf 100644 --- a/rscel/src/compiler/string_tokenizer.rs +++ b/rscel/src/compiler/string_tokenizer.rs @@ -126,6 +126,7 @@ impl<'l> StringTokenizer<'l> { self.parse_keywords_or_ident("b", &[]) } } + 'c' => self.parse_keywords_or_ident("c", &[("case", Token::Case)]), 'f' => { if let Some('\'') = self.scanner.peek() { self.scanner.next(); @@ -138,6 +139,7 @@ impl<'l> StringTokenizer<'l> { } } 'i' => self.parse_keywords_or_ident("i", &[("in", Token::In)]), + 'm' => self.parse_keywords_or_ident("m", &[("match", Token::Match)]), 'n' => self.parse_keywords_or_ident("n", &[("null", Token::Null)]), 'r' => { if let Some('\'') = self.scanner.peek() { @@ -562,7 +564,7 @@ impl<'l> StringTokenizer<'l> { impl Tokenizer for StringTokenizer<'_> { fn peek(&mut self) -> Result, SyntaxError> { - if let None = self.current { + if self.current.is_none() { match self.collect_next_token() { Ok(token) => self.current = token, Err(err) => return Err(err), @@ -572,7 +574,7 @@ impl Tokenizer for StringTokenizer<'_> { } fn next(&mut self) -> Result, SyntaxError> { - if let None = self.current { + if self.current.is_none() { self.collect_next_token() } else { let tmp = std::mem::replace(&mut self.current, None); diff --git a/rscel/src/compiler/tokens.rs b/rscel/src/compiler/tokens.rs index 5c85b98..df75439 100644 --- a/rscel/src/compiler/tokens.rs +++ b/rscel/src/compiler/tokens.rs @@ -31,6 +31,8 @@ pub enum Token { NotEqual, // != In, // 'in' Null, // 'null' + Match, // 'match' + Case, // 'case' BoolLit(bool), // true | false IntLit(u64), // [-+]?[0-9]+ UIntLit(u64), // [0-9]+u diff --git a/rscel/src/interp/mod.rs b/rscel/src/interp/mod.rs index c3ca5e3..4bd0f4c 100644 --- a/rscel/src/interp/mod.rs +++ b/rscel/src/interp/mod.rs @@ -160,6 +160,15 @@ impl<'a> Interpreter<'a> { pc += 1; match &prog[oldpc] { ByteCode::Push(val) => stack.push(val.clone().into()), + ByteCode::Dup => { + let v = stack.pop_val()?; + + stack.push_val(v.clone()); + stack.push_val(v); + } + ByteCode::Pop => { + stack.pop_val()?; + } ByteCode::Or => { let v2 = stack.pop_val()?; let v1 = stack.pop_val()?; diff --git a/rscel/src/interp/types.rs b/rscel/src/interp/types.rs index 08ed349..edc5007 100644 --- a/rscel/src/interp/types.rs +++ b/rscel/src/interp/types.rs @@ -12,6 +12,8 @@ pub enum JmpWhen { #[derive(Clone, PartialEq, Serialize, Deserialize)] pub enum ByteCode { Push(CelValue), + Dup, + Pop, Or, And, Not, @@ -48,6 +50,8 @@ impl fmt::Debug for ByteCode { match self { Push(val) => write!(f, "PUSH {:?}", val), + Dup => write!(f, "DUP"), + Pop => write!(f, "POP"), Or => write!(f, "OR"), And => write!(f, "AND"), Not => write!(f, "NOT"), diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index d140fef..aade850 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -29,170 +29,171 @@ fn test_contains() { let _res = ctx.exec("main", &exec_ctx).unwrap(); } -#[test_case("3+3", 6.into(); "add signed")] -#[test_case("4-3", 1.into(); "sub signed")] -#[test_case("4u + 3u", 7u64.into(); "add unsigned")] -#[test_case("7 % 2", 1.into(); "test mod")] -#[test_case("(4+2) * (6-5)", 6.into(); "test parens")] -#[test_case("4+2*6-5", 11.into(); "test op order")] -#[test_case("4-2+5*2", (12).into(); "test op order 2")] -#[test_case("[1, 2, 3].map(x, x+2)", CelValue::from_list(vec![3.into(), 4.into(), 5.into()]); "test map")] -#[test_case("[1, 2, 3].map(x, x % 2 == 1, x + 1)", CelValue::from_list(vec![2.into(), 4.into()]); "test map 2")] -#[test_case("[1,2,3][1]", 2.into(); "array index")] -#[test_case("{\"foo\": 3}.foo", 3.into(); "obj dot access")] -#[test_case("size([1,2,3,4])", 4u64.into(); "test list size")] -#[test_case("true || false", true.into(); "or")] -#[test_case("true || undefined", true.into(); "or shortcut")] -#[test_case("false && undefined", false.into(); "and shortcut")] -#[test_case("false && true", false.into(); "and falsy")] -#[test_case("true && true", true.into(); "and true")] -#[test_case("[1,2].map(x, x+1).map(x, x*2)", CelValue::from_list(vec![4.into(), 6.into()]); "double map")] -#[test_case("\"hello world\".contains(\"hello\")", true.into(); "test contains")] -#[test_case("\"hello world\".endsWith(\"world\")", true.into(); "test endsWith")] -#[test_case("\"hello world\".startsWith(\"hello\")", true.into(); "test startsWith")] -#[test_case("\"abc123\".matches(\"[a-z]{3}[0-9]{3}\")", true.into(); "test matches method")] -#[test_case("matches('abc123', '[a-z]{3}[0-9]{3}')", true.into(); "test matches function")] -#[test_case("string(1)", "1".into(); "test string")] +#[test_case("3+3", 6; "add signed")] +#[test_case("4-3", 1; "sub signed")] +#[test_case("4u + 3u", 7u64; "add unsigned")] +#[test_case("7 % 2", 1; "test mod")] +#[test_case("(4+2) * (6-5)", 6; "test parens")] +#[test_case("4+2*6-5", 11; "test op order")] +#[test_case("4-2+5*2", 12; "test op order 2")] +#[test_case("[1, 2, 3].map(x, x+2)", vec![3, 4, 5]; "test map")] +#[test_case("[1, 2, 3].map(x, x % 2 == 1, x + 1)", vec![2, 4]; "test map 2")] +#[test_case("[1,2,3][1]", 2; "array index")] +#[test_case("{\"foo\": 3}.foo", 3; "obj dot access")] +#[test_case("size([1,2,3,4])", 4u64; "test list size")] +#[test_case("true || false", true; "or")] +#[test_case("true || undefined", true; "or shortcut")] +#[test_case("false && undefined", false; "and shortcut")] +#[test_case("false && true", false; "and falsy")] +#[test_case("true && true", true; "and true")] +#[test_case("[1,2].map(x, x+1).map(x, x*2)", vec![4, 6]; "double map")] +#[test_case("\"hello world\".contains(\"hello\")", true; "test contains")] +#[test_case("\"hello world\".endsWith(\"world\")", true; "test endsWith")] +#[test_case("\"hello world\".startsWith(\"hello\")", true; "test startsWith")] +#[test_case("\"abc123\".matches(\"[a-z]{3}[0-9]{3}\")", true; "test matches method")] +#[test_case("matches('abc123', '[a-z]{3}[0-9]{3}')", true; "test matches function")] +#[test_case("string(1)", "1"; "test string")] #[test_case("type(1)", CelValue::int_type(); "test type")] -#[test_case("4 > 5", false.into(); "test gt")] -#[test_case("4 < 5", true.into(); "test lt")] -#[test_case("4 >= 4", true.into(); "test ge")] -#[test_case("5 <= 4", false.into(); "test le")] -#[test_case("5 == 5", true.into(); "test eq")] -#[test_case("5 != 5", false.into(); "test ne")] -#[test_case("3 in [1,2,3,4,5]", true.into(); "test in")] -#[test_case(r#"has({"foo": 3}.foo)"#, true.into(); "test has")] -#[test_case("[1,2,3,4].all(x, x < 5)", true.into(); "test all true")] -#[test_case("[1,2,3,4,5].all(x, x < 5)", false.into(); "test all false")] -#[test_case("[1,2,3,4].exists(x, x < 3)", true.into(); "test exists true")] -#[test_case("[1,2,3,4].exists(x, x == 5)", false.into(); "test exists false")] -#[test_case("[1,2,3,4].exists_one(x, x == 4)", true.into(); "test exists one true")] -#[test_case("[1,2,3,4].exists_one(x, x == 5)", false.into(); "test exists one false")] -#[test_case("[1,2,3,4].filter(x, x % 2 == 0)", CelValue::from_list(vec![2.into(), 4.into()]); "test filter")] -#[test_case("abs(-9)", 9.into(); "abs")] -#[test_case("sqrt(9.0)", 3.0.into(); "sqrt")] -#[test_case("pow(2, 2)", 4.into(); "pow")] -#[test_case("pow(2.0, 2)", 4.0.into(); "pow2")] -#[test_case("log(1)", 0.into(); "log")] -#[test_case("log(1u)", 0u64.into(); "log unsigned")] -#[test_case("ceil(2.3)", 3.into(); "ceil")] -#[test_case("floor(2.7)", 2.into(); "floor")] -#[test_case("round(2.2)", 2.into(); "round down")] -#[test_case("round(2.5)", 3.into(); "round up")] -#[test_case("min(1,2,3)", 1.into(); "min")] -#[test_case("max(1,2,3)", 3.into(); "max")] -#[test_case("[1,2,3].reduce(curr, next, curr + next, 0)", 6.into(); "reduce")] -#[test_case("{}", CelValue::from_map(HashMap::new()); "empty object")] -#[test_case("[]", CelValue::from_list(Vec::new()); "empy list")] -#[test_case("has(foo) && foo > 10", false.into(); "has works")] -#[test_case("true ? 4 : 3", 4.into(); "ternary true")] -#[test_case("false ? 4 : 3", 3.into(); "ternary false")] -#[test_case("2 * 4 * 8 * 72 / 144", 32.into(); "long multiply operation")] -#[test_case("2 * 3 + 7", 13.into(); "long mixed operation")] -#[test_case("true && false || true && true", true.into(); "long logic operation")] -#[test_case("2 + 3 - 1", 4.into(); "long add/sub operation")] -#[test_case("-2 + 4", 2.into(); "neg pos addition")] -#[test_case("2 < 3 >= 1", true.into(); "type prop: chained cmp")] -#[test_case("3 * 2 - 1 / 4 * 2", 6.into(); "large op 2")] -#[test_case("true || unbound || unbound", true.into(); "Or short cut")] -#[test_case("true == true || false == true && false", true.into(); "Incorrect equality precedence")] -#[test_case("5 < 10 || 10 < 5 && false", true.into(); "Incorrect less-than precedence")] -#[test_case("true || false && false", true.into(); "Incorrect AND precedence")] -#[test_case("false && true || true", true.into(); "Incorrect OR precedence")] -#[test_case("5 + 5 == 10 || 10 - 5 == 5 && false", true.into(); "Incorrect addition precedence")] -#[test_case("6 / 2 == 3 || 2 * 3 == 6 && false", true.into(); "Incorrect division precedence")] -#[test_case("(true || false) && false", false.into(); "Incorrect parentheses precedence")] -#[test_case("'foo' in 'foot'", true.into(); "in string operator")] -#[test_case("'foot' in 'foo'", false.into(); "in string operator false")] -#[test_case("type(3) == type(3)", true.into(); "type eq")] -#[test_case("type(null) == null_type", true.into(); "null_type eq")] -#[test_case("type(3) == int", true.into(); "int type eq")] -#[test_case("type(3u) == uint", true.into(); "uint type eq")] -#[test_case("type('foo') == string", true.into(); "string type eq")] -#[test_case("type(true) == bool", true.into(); "bool type eq true")] -#[test_case("type(false) == bool", true.into(); "bool type eq false")] -#[test_case("type(3.2) == double", true.into(); "double type eq")] -#[test_case("type(3.2) == float", true.into(); "float type eq")] -#[test_case("type(true) == double", false.into(); "bool type neq")] -#[test_case("type(true) != double", true.into(); "bool type neq 2")] -#[test_case("type([1,2,3]) == type([])", true.into(); "list type neq")] -#[test_case("type({'foo': 3}) == type({})", true.into(); "map type neq")] +#[test_case("4 > 5", false; "test gt")] +#[test_case("4 < 5", true; "test lt")] +#[test_case("4 >= 4", true; "test ge")] +#[test_case("5 <= 4", false; "test le")] +#[test_case("5 == 5", true; "test eq")] +#[test_case("5 != 5", false; "test ne")] +#[test_case("3 in [1,2,3,4,5]", true; "test in")] +#[test_case(r#"has({"foo": 3}.foo)"#, true; "test has")] +#[test_case("[1,2,3,4].all(x, x < 5)", true; "test all true")] +#[test_case("[1,2,3,4,5].all(x, x < 5)", false; "test all false")] +#[test_case("[1,2,3,4].exists(x, x < 3)", true; "test exists true")] +#[test_case("[1,2,3,4].exists(x, x == 5)", false; "test exists false")] +#[test_case("[1,2,3,4].exists_one(x, x == 4)", true; "test exists one true")] +#[test_case("[1,2,3,4].exists_one(x, x == 5)", false; "test exists one false")] +#[test_case("[1,2,3,4].filter(x, x % 2 == 0)", vec![2, 4]; "test filter")] +#[test_case("abs(-9)", 9; "abs")] +#[test_case("sqrt(9.0)", 3.0; "sqrt")] +#[test_case("pow(2, 2)", 4; "pow")] +#[test_case("pow(2.0, 2)", 4.0; "pow2")] +#[test_case("log(1)", 0; "log")] +#[test_case("log(1u)", 0u64; "log unsigned")] +#[test_case("ceil(2.3)", 3; "ceil")] +#[test_case("floor(2.7)", 2; "floor")] +#[test_case("round(2.2)", 2; "round down")] +#[test_case("round(2.5)", 3; "round up")] +#[test_case("min(1,2,3)", 1; "min")] +#[test_case("max(1,2,3)", 3; "max")] +#[test_case("[1,2,3].reduce(curr, next, curr + next, 0)", 6; "reduce")] +#[test_case("{}", HashMap::new(); "empty object")] +#[test_case("[]", Vec::::new(); "empy list")] +#[test_case("has(foo) && foo > 10", false; "has works")] +#[test_case("true ? 4 : 3", 4; "ternary true")] +#[test_case("false ? 4 : 3", 3; "ternary false")] +#[test_case("2 * 4 * 8 * 72 / 144", 32; "long multiply operation")] +#[test_case("2 * 3 + 7", 13; "long mixed operation")] +#[test_case("true && false || true && true", true; "long logic operation")] +#[test_case("2 + 3 - 1", 4; "long add/sub operation")] +#[test_case("-2 + 4", 2; "neg pos addition")] +#[test_case("2 < 3 >= 1", true; "type prop: chained cmp")] +#[test_case("3 * 2 - 1 / 4 * 2", 6; "large op 2")] +#[test_case("true || unbound || unbound", true; "Or short cut")] +#[test_case("true == true || false == true && false", true; "Incorrect equality precedence")] +#[test_case("5 < 10 || 10 < 5 && false", true; "Incorrect less-than precedence")] +#[test_case("true || false && false", true; "Incorrect AND precedence")] +#[test_case("false && true || true", true; "Incorrect OR precedence")] +#[test_case("5 + 5 == 10 || 10 - 5 == 5 && false", true; "Incorrect addition precedence")] +#[test_case("6 / 2 == 3 || 2 * 3 == 6 && false", true; "Incorrect division precedence")] +#[test_case("(true || false) && false", false; "Incorrect parentheses precedence")] +#[test_case("'foo' in 'foot'", true; "in string operator")] +#[test_case("'foot' in 'foo'", false; "in string operator false")] +#[test_case("type(3) == type(3)", true; "type eq")] +#[test_case("type(null) == null_type", true; "null_type eq")] +#[test_case("type(3) == int", true; "int type eq")] +#[test_case("type(3u) == uint", true; "uint type eq")] +#[test_case("type('foo') == string", true; "string type eq")] +#[test_case("type(true) == bool", true; "bool type eq true")] +#[test_case("type(false) == bool", true; "bool type eq false")] +#[test_case("type(3.2) == double", true; "double type eq")] +#[test_case("type(3.2) == float", true; "float type eq")] +#[test_case("type(true) == double", false; "bool type neq")] +#[test_case("type(true) != double", true; "bool type neq 2")] +#[test_case("type([1,2,3]) == type([])", true; "list type neq")] +#[test_case("type({'foo': 3}) == type({})", true; "map type neq")] #[test_case("coalesce()", CelValue::from_null(); "coalesce none")] -#[test_case("coalesce(null, 3)", 3.into(); "coalesce explicit null")] -#[test_case("coalesce(foo, 4)", 4.into(); "coalesce unbound var")] -#[test_case("coalesce(1, 2, 3)", 1.into(); "coalesce first val ok")] -#[test_case(".1", 0.1.into(); "dot leading floating point")] -#[test_case("-.1", (-0.1).into(); "neg dot leading floating point")] -#[test_case("2+3 in [5]", true.into(); "check in binding")] -#[test_case("foo.b || true", true.into(); "Error bypassing")] -#[test_case(r#""\u00fc""#, "ü".into(); "Test unicode short lower")] -#[test_case(r#""\u00FC""#, "ü".into(); "Test unicode short upper")] -#[test_case(r#""\U000000fc""#, "ü".into(); "Test unicode long lower")] -#[test_case(r#""\U000000FC""#, "ü".into(); "Test unicode long upper")] -#[test_case(r#""\x48""#, "H".into(); "Test hex escape lower")] -#[test_case(r#""\X48""#, "H".into(); "Test hex escape upper")] -#[test_case("'fooBaR'.endsWithI('bar')", true.into(); "Test endsWithI")] -#[test_case("'FoObar'.startsWithI('foo')", true.into(); "Test startsWithI")] -#[test_case("' foo '.trim()", "foo".into(); "Test trim")] -#[test_case("' foo '.trimStart()", "foo ".into(); "Test trimStart")] -#[test_case("' foo '.trimEnd()", " foo".into(); "Test trimEnd")] -#[test_case("'foo'.toUpper()", "FOO".into(); "test toUpper")] -#[test_case("'FOO'.toLower()", "foo".into(); "test toLower")] -#[test_case(r#"'foo bar\t\tbaz'.splitWhiteSpace()"#, CelValue::from_val_slice(&["foo".into(), "bar".into(), "baz".into()]); "test splitWhiteSpace")] -#[test_case("{'foo': x}.map(k, k)", CelValue::from_val_slice(&["foo".into()]); "test map on map")] -#[test_case("{'foo': x, 'bar': y}.filter(k, k == 'foo')", CelValue::from_val_slice(&["foo".into()]); "test filter on map")] -#[test_case(r#"f"{3}""#, "3".into(); "test basic format string")] -#[test_case(r#"f"{({"foo": 3}).foo)}""#, "3".into(); "test fstring with map")] -#[test_case(r#"f"{[1,2,3][2]}""#, "3".into(); "test fstring with list")] -#[test_case("timestamp('2024-07-30 12:00:00+00:00') - timestamp('2024-07-30 11:55:00+00:00') == duration('5m')", true.into(); "test timestamp sub 1")] -#[test_case("timestamp('2024-07-30 11:55:00+00:00') - timestamp('2024-07-30 12:00:00+00:00')", Duration::new(-300, 0).unwrap().into(); "test timestamp sub 2")] -#[test_case("timestamp('2023-12-25T12:00:00Z').getDayOfMonth()", 24.into(); "getDayOfMonth")] -#[test_case("timestamp('2023-12-25T7:00:00Z').getDayOfMonth('America/Los_Angeles')", 23.into(); "getDayOfMonth with timezone")] -#[test_case("int(1)", 1.into(); "identity -- int")] -#[test_case("uint(1u)", 1u64.into(); "identity -- uint")] -#[test_case("double(5.5)", 5.5.into(); "identity -- double")] -#[test_case("string('hello')", "hello".into(); "identity -- string")] -#[test_case("bytes(bytes('abc'))", crate::types::CelBytes::from_vec(vec![97u8, 98u8, 99u8]).into(); "identity -- bytes 1")] -#[test_case("bytes(b'abc')", crate::types::CelBytes::from_vec(vec![97u8, 98u8, 99u8]).into(); "identity -- bytes 2")] -#[test_case("duration(duration('100s')) == duration('100s')", true.into(); "identity -- duration")] -#[test_case("duration('2h') + duration('1h1m') >= duration('3h')", true.into(); "duration add + comp")] -#[test_case("timestamp(timestamp(100000000)) == timestamp(100000000)", true.into(); "identity -- timestamp")] -#[test_case("bool(true)", true.into(); "bool true")] -#[test_case("bool(false)", false.into(); "bool false")] -#[test_case("bool('1')", true.into(); "'1' -> bool")] -#[test_case("bool('t')", true.into(); "'t' -> bool")] -#[test_case("bool('true')", true.into(); "'true' -> bool 1")] -#[test_case("bool('TRUE')", true.into(); "'TRUE' -> bool 2")] -#[test_case("bool('True')", true.into(); "'True' -> bool 3")] -#[test_case("bool('0')", false.into(); "'0' -> bool")] -#[test_case("bool('f')", false.into(); "'f' -> bool")] -#[test_case("bool('false')", false.into(); "'false' -> bool 1")] -#[test_case("bool('FALSE')", false.into(); "'FALSE' -> bool 2")] -#[test_case("bool('False')", false.into(); "'False' -> bool 3")] -#[test_case("!true", false.into(); "not true")] -#[test_case("!false", true.into(); "not false")] -#[test_case("1 + 2 == 3 && 4 + 5 == 9", true.into(); "and operation with expressions")] -#[test_case("1 + 2 == 3 || 4 + 5 == 10", true.into(); "or operation with expressions")] -#[test_case("! (1 + 2 == 4)", true.into(); "negated expression")] -#[test_case("size([1, 2, 3].filter(x, x > 1)) == 2", true.into(); "filter and size")] -#[test_case("max(1, 2, 3) + min(4, 5, 6) == 7", true.into(); "max and min")] -#[test_case("['hello', 'world'].reduce(curr, next, curr + ' ' + next, '')", " hello world".into(); "reduce with strings")] -#[test_case("timestamp('2024-07-30 12:00:00Z') > timestamp('2023-07-30 12:00:00Z')", true.into(); "timestamp comparison")] -#[test_case("{'a': 1, 'b': 2, 'c': 3}.filter(k, k == 'b').size() == 1", true.into(); "filter on map with modulo")] -#[test_case("[1, 2, 3, 4].all(x, x < 5) && [1, 2, 3, 4].exists(x, x == 3)", true.into(); "all and exists")] -#[test_case("coalesce(null, null, 'hello', null) == 'hello'", true.into(); "coalesce with multiple nulls")] -#[test_case("duration('3h').getHours()", 3.into(); "duration.getHours")] -#[test_case("duration('1s234ms').getMilliseconds()", 234.into(); "duration.getMilliseconds")] -#[test_case("duration('1h30m').getMinutes()", 90.into(); "duration.getMinutes")] -#[test_case("duration('1m30s').getSeconds()", 90.into(); "duration.getSeconds")] -fn test_equation(prog: &str, res: CelValue) { +#[test_case("coalesce(null, 3)", 3; "coalesce explicit null")] +#[test_case("coalesce(foo, 4)", 4; "coalesce unbound var")] +#[test_case("coalesce(1, 2, 3)", 1; "coalesce first val ok")] +#[test_case(".1", 0.1; "dot leading floating point")] +#[test_case("-.1", -0.1; "neg dot leading floating point")] +#[test_case("2+3 in [5]", true; "check in binding")] +#[test_case("foo.b || true", true; "Error bypassing")] +#[test_case(r#""\u00fc""#, "ü"; "Test unicode short lower")] +#[test_case(r#""\u00FC""#, "ü"; "Test unicode short upper")] +#[test_case(r#""\U000000fc""#, "ü"; "Test unicode long lower")] +#[test_case(r#""\U000000FC""#, "ü"; "Test unicode long upper")] +#[test_case(r#""\x48""#, "H"; "Test hex escape lower")] +#[test_case(r#""\X48""#, "H"; "Test hex escape upper")] +#[test_case("'fooBaR'.endsWithI('bar')", true; "Test endsWithI")] +#[test_case("'FoObar'.startsWithI('foo')", true; "Test startsWithI")] +#[test_case("' foo '.trim()", "foo"; "Test trim")] +#[test_case("' foo '.trimStart()", "foo "; "Test trimStart")] +#[test_case("' foo '.trimEnd()", " foo"; "Test trimEnd")] +#[test_case("'foo'.toUpper()", "FOO"; "test toUpper")] +#[test_case("'FOO'.toLower()", "foo"; "test toLower")] +#[test_case(r#"'foo bar\t\tbaz'.splitWhiteSpace()"#, vec!["foo", "bar", "baz"]; "test splitWhiteSpace")] +#[test_case("{'foo': x}.map(k, k)", vec!["foo"]; "test map on map")] +#[test_case("{'foo': x, 'bar': y}.filter(k, k == 'foo')", vec!["foo"]; "test filter on map")] +#[test_case(r#"f"{3}""#, "3"; "test basic format string")] +#[test_case(r#"f"{({"foo": 3}).foo)}""#, "3"; "test fstring with map")] +#[test_case(r#"f"{[1,2,3][2]}""#, "3"; "test fstring with list")] +#[test_case("timestamp('2024-07-30 12:00:00+00:00') - timestamp('2024-07-30 11:55:00+00:00') == duration('5m')", true; "test timestamp sub 1")] +#[test_case("timestamp('2024-07-30 11:55:00+00:00') - timestamp('2024-07-30 12:00:00+00:00')", Duration::new(-300, 0).unwrap(); "test timestamp sub 2")] +#[test_case("timestamp('2023-12-25T12:00:00Z').getDayOfMonth()", 24; "getDayOfMonth")] +#[test_case("timestamp('2023-12-25T7:00:00Z').getDayOfMonth('America/Los_Angeles')", 23; "getDayOfMonth with timezone")] +#[test_case("int(1)", 1; "identity -- int")] +#[test_case("uint(1u)", 1u64; "identity -- uint")] +#[test_case("double(5.5)", 5.5; "identity -- double")] +#[test_case("string('hello')", "hello"; "identity -- string")] +#[test_case("bytes(bytes('abc'))", crate::types::CelBytes::from_vec(vec![97u8, 98u8, 99u8]); "identity -- bytes 1")] +#[test_case("bytes(b'abc')", crate::types::CelBytes::from_vec(vec![97u8, 98u8, 99u8]); "identity -- bytes 2")] +#[test_case("duration(duration('100s')) == duration('100s')", true; "identity -- duration")] +#[test_case("duration('2h') + duration('1h1m') >= duration('3h')", true; "duration add + comp")] +#[test_case("timestamp(timestamp(100000000)) == timestamp(100000000)", true; "identity -- timestamp")] +#[test_case("bool(true)", true; "bool true")] +#[test_case("bool(false)", false; "bool false")] +#[test_case("bool('1')", true; "'1' -> bool")] +#[test_case("bool('t')", true; "'t' -> bool")] +#[test_case("bool('true')", true; "'true' -> bool 1")] +#[test_case("bool('TRUE')", true; "'TRUE' -> bool 2")] +#[test_case("bool('True')", true; "'True' -> bool 3")] +#[test_case("bool('0')", false; "'0' -> bool")] +#[test_case("bool('f')", false; "'f' -> bool")] +#[test_case("bool('false')", false; "'false' -> bool 1")] +#[test_case("bool('FALSE')", false; "'FALSE' -> bool 2")] +#[test_case("bool('False')", false; "'False' -> bool 3")] +#[test_case("!true", false; "not true")] +#[test_case("!false", true; "not false")] +#[test_case("1 + 2 == 3 && 4 + 5 == 9", true; "and operation with expressions")] +#[test_case("1 + 2 == 3 || 4 + 5 == 10", true; "or operation with expressions")] +#[test_case("! (1 + 2 == 4)", true; "negated expression")] +#[test_case("size([1, 2, 3].filter(x, x > 1)) == 2", true; "filter and size")] +#[test_case("max(1, 2, 3) + min(4, 5, 6) == 7", true; "max and min")] +#[test_case("['hello', 'world'].reduce(curr, next, curr + ' ' + next, '')", " hello world"; "reduce with strings")] +#[test_case("timestamp('2024-07-30 12:00:00Z') > timestamp('2023-07-30 12:00:00Z')", true; "timestamp comparison")] +#[test_case("{'a': 1, 'b': 2, 'c': 3}.filter(k, k == 'b').size() == 1", true; "filter on map with modulo")] +#[test_case("[1, 2, 3, 4].all(x, x < 5) && [1, 2, 3, 4].exists(x, x == 3)", true; "all and exists")] +#[test_case("coalesce(null, null, 'hello', null) == 'hello'", true; "coalesce with multiple nulls")] +#[test_case("duration('3h').getHours()", 3; "duration.getHours")] +#[test_case("duration('1s234ms').getMilliseconds()", 234; "duration.getMilliseconds")] +#[test_case("duration('1h30m').getMinutes()", 90; "duration.getMinutes")] +#[test_case("duration('1m30s').getSeconds()", 90; "duration.getSeconds")] +#[test_case("match 3 { case int: true, _: false}", false; "match int" )] +fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); ctx.add_program_str("main", prog).unwrap(); let eval_res = ctx.exec("main", &exec_ctx).unwrap(); - assert_eq!(eval_res, res); + assert_eq!(eval_res, res.into()); } #[test] From 84b2a002dacad1cd752628794e2fb6ca44af4e7c Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Fri, 24 Jan 2025 22:32:38 -0800 Subject: [PATCH 02/16] Added labels to unresolved bytecode --- rscel/src/compiler/compiled_prog.rs | 122 ++++++++++++++++++++++++---- rscel/src/compiler/compiler.rs | 93 +++++++++++++-------- 2 files changed, 163 insertions(+), 52 deletions(-) diff --git a/rscel/src/compiler/compiled_prog.rs b/rscel/src/compiler/compiled_prog.rs index fab8525..779a09e 100644 --- a/rscel/src/compiler/compiled_prog.rs +++ b/rscel/src/compiler/compiled_prog.rs @@ -1,11 +1,39 @@ +use std::collections::HashMap; + use crate::{ interp::JmpWhen, program::ProgramDetails, types::CelByteCode, ByteCode, CelError, CelValue, CelValueDyn, Program, }; +#[derive(Debug, Clone)] +pub enum PreResolvedByteCode { + Bytecode(ByteCode), + Jmp { + label: u32, + }, + JmpCond { + when: JmpWhen, + label: u32, + leave_val: bool, + }, + Label(u32), +} + +impl From for PreResolvedByteCode { + fn from(value: ByteCode) -> Self { + PreResolvedByteCode::Bytecode(value) + } +} + +impl From for Vec { + fn from(value: CelByteCode) -> Self { + value.into_iter().map(|b| b.into()).collect() + } +} + #[derive(Debug, Clone)] pub enum NodeValue { - Bytecode(CelByteCode), + Bytecode(Vec), ConstExpr(CelValue), } @@ -19,9 +47,8 @@ pub struct CompiledProg { macro_rules! compile { ($bytecode:expr, $const_expr:expr, $( $child : ident),+) => { { - use crate::compiler::compiled_prog::NodeValue; + use crate::compiler::compiled_prog::{NodeValue, PreResolvedByteCode}; use crate::program::ProgramDetails; - use crate::types::CelByteCode; let mut new_details = ProgramDetails::new(); @@ -40,7 +67,7 @@ macro_rules! compile { } } ($($child,)+) => { - let mut new_bytecode = CelByteCode::new(); + let mut new_bytecode = Vec::::new(); $( new_bytecode.extend($child.into_bytecode().into_iter()); @@ -63,7 +90,7 @@ macro_rules! compile { impl CompiledProg { pub fn empty() -> CompiledProg { CompiledProg { - inner: NodeValue::Bytecode(CelByteCode::new()), + inner: NodeValue::Bytecode(Vec::new()), details: ProgramDetails::new(), } } @@ -77,14 +104,14 @@ impl CompiledProg { pub fn with_bytecode(bytecode: CelByteCode) -> CompiledProg { CompiledProg { - inner: NodeValue::Bytecode(bytecode), + inner: NodeValue::Bytecode(bytecode.into()), details: ProgramDetails::new(), } } - pub fn with_code_points(bytecode: Vec) -> CompiledProg { + pub fn with_code_points(bytecode: Vec) -> CompiledProg { CompiledProg { - inner: NodeValue::Bytecode(bytecode.into()), + inner: NodeValue::Bytecode(bytecode.into_iter().map(|b| b.into()).collect()), details: ProgramDetails::new(), } } @@ -122,7 +149,7 @@ impl CompiledProg { .into_iter() .map(|c| c.inner.into_bytecode().into_iter()) .flatten() - .chain(bytecode.into_iter()) + .chain(bytecode.into_iter().map(|b| b.into())) .collect(), ) }; @@ -149,9 +176,9 @@ impl CompiledProg { }, None => CompiledProg { inner: NodeValue::Bytecode( - [ByteCode::Push(c1), ByteCode::Push(c2)] + [ByteCode::Push(c1).into(), ByteCode::Push(c2).into()] .into_iter() - .chain(bytecode.into_iter()) + .chain(bytecode.into_iter().map(|b| b.into())) .collect(), ), details: new_details, @@ -162,7 +189,7 @@ impl CompiledProg { c1.into_bytecode() .into_iter() .chain(c2.into_bytecode().into_iter()) - .chain(bytecode.into_iter()) + .chain(bytecode.into_iter().map(|b| b.into())) .collect(), ), details: new_details, @@ -174,7 +201,7 @@ impl CompiledProg { let mut details = self.details; details.add_source(source); - Program::new(details, self.inner.into_bytecode()) + Program::new(details, resolve_bytecode(self.inner.into_bytecode())) } pub fn add_ident(mut self, ident: &str) -> CompiledProg { @@ -268,11 +295,14 @@ impl CompiledProg { when: JmpWhen::False, dist: (true_clause_bytecode.len() as u32) + 1, // +1 to jmp over the next jump leave_val: false, - }] + } + .into()] .into_iter(), ) .chain(true_clause_bytecode.into_iter()) - .chain([ByteCode::Jmp(false_clause_bytecode.len() as u32)].into_iter()) + .chain( + [ByteCode::Jmp(false_clause_bytecode.len() as u32).into()].into_iter(), + ) .chain(false_clause_bytecode.into_iter()) .collect(), ), @@ -289,7 +319,7 @@ impl CompiledProg { } } - pub fn into_bytecode(self) -> CelByteCode { + pub fn into_unresolved_bytecode(self) -> Vec { self.inner.into_bytecode() } @@ -310,10 +340,66 @@ impl NodeValue { matches!(*self, NodeValue::ConstExpr(_)) } - pub fn into_bytecode(self) -> CelByteCode { + pub fn into_bytecode(self) -> Vec { match self { NodeValue::Bytecode(b) => b, - NodeValue::ConstExpr(c) => CelByteCode::from_code_point(ByteCode::Push(c)), + NodeValue::ConstExpr(c) => vec![ByteCode::Push(c).into()], } } } + +pub fn resolve_bytecode(code: Vec) -> CelByteCode { + let mut curr_loc: usize = 0; + let mut locations = HashMap::::new(); + let mut ret = CelByteCode::new(); + + // determine label locations + for c in code.iter() { + match c { + PreResolvedByteCode::Label(i) => { + if locations.contains_key(i) { + panic!("Duplicate label found!"); + } + locations.insert(*i, curr_loc); + } + _ => { + curr_loc += 1; + } + } + } + + curr_loc = 0; + + // resolve the label locations + for c in code.into_iter() { + match c { + PreResolvedByteCode::Bytecode(byte_code) => { + curr_loc += 1; + ret.push(byte_code); + } + PreResolvedByteCode::Jmp { label } => { + curr_loc += 1; + let jmp_loc = locations[&label]; + let offset = jmp_loc - curr_loc; + ret.push(ByteCode::Jmp(offset as u32)); + } + PreResolvedByteCode::JmpCond { + when, + label, + leave_val, + } => { + curr_loc += 1; + let jmp_loc = locations[&label]; + let offset = jmp_loc - curr_loc; + ret.push(ByteCode::JmpCond { + when, + dist: offset as u32, + leave_val, + }); + } + PreResolvedByteCode::Label(_) => {} + } + } + + ret +} diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index ef0dde6..dbd312e 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::{ ast_node::AstNode, - compiled_prog::CompiledProg, + compiled_prog::{resolve_bytecode, CompiledProg, PreResolvedByteCode}, grammar::*, source_range::SourceRange, syntax_error::SyntaxError, @@ -19,6 +19,8 @@ use crate::compile; pub struct CelCompiler<'l> { tokenizer: &'l mut dyn Tokenizer, bindings: BindContext<'l>, + + next_label: u32, } impl<'l> CelCompiler<'l> { @@ -26,6 +28,7 @@ impl<'l> CelCompiler<'l> { CelCompiler { tokenizer, bindings: BindContext::for_compile(), + next_label: 0, } } @@ -49,6 +52,12 @@ impl<'l> CelCompiler<'l> { Ok(prog) } + fn get_label(&mut self) -> u32 { + let n = self.next_label; + self.next_label += 1; + n + } + fn parse_expression(&mut self) -> CelResult<(CompiledProg, AstNode)> { let (lhs_node, lhs_ast) = self.parse_conditional_or()?; @@ -98,11 +107,14 @@ impl<'l> CelCompiler<'l> { self.tokenizer.next()?; let (rhs_node, rhs_ast) = self.parse_conditional_and()?; - let jmp_node = CompiledProg::with_code_points(vec![ByteCode::JmpCond { + let label = self.get_label(); + + let jmp_node = CompiledProg::with_code_points(vec![PreResolvedByteCode::JmpCond { when: JmpWhen::True, - dist: rhs_node.bytecode_len() as u32 + 1, + label, leave_val: true, - }]); + } + .into()]); let range = current_ast.range().surrounding(rhs_ast.range()); @@ -114,7 +126,7 @@ impl<'l> CelCompiler<'l> { range, ); current_node = compile!( - [ByteCode::Or], + [ByteCode::Or.into(), PreResolvedByteCode::Label(label)], current_node.or(&rhs_node), current_node, jmp_node, @@ -136,9 +148,11 @@ impl<'l> CelCompiler<'l> { self.tokenizer.next()?; let (rhs_node, rhs_ast) = self.parse_relation()?; - let jmp_node = CompiledProg::with_code_points(vec![ByteCode::JmpCond { + let label = self.get_label(); + + let jmp_node = CompiledProg::with_code_points(vec![PreResolvedByteCode::JmpCond { when: JmpWhen::False, - dist: rhs_node.bytecode_len() as u32 + 1, + label: label, leave_val: true, }]); @@ -152,7 +166,7 @@ impl<'l> CelCompiler<'l> { range, ); current_node = compile!( - [ByteCode::And], + [ByteCode::And.into(), PreResolvedByteCode::Label(label)], current_node.and(rhs_node), current_node, jmp_node, @@ -186,7 +200,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Lt], + [ByteCode::Lt.into()], current_node.lt(rhs_node), current_node, rhs_node @@ -207,7 +221,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Le], + [ByteCode::Le.into()], current_node.le(rhs_node), current_node, rhs_node @@ -228,7 +242,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Eq], + [ByteCode::Eq.into()], CelValueDyn::eq(¤t_node, &rhs_node), current_node, rhs_node @@ -249,7 +263,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Ne], + [ByteCode::Ne.into()], current_node.neq(rhs_node), current_node, rhs_node @@ -270,7 +284,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Ge], + [ByteCode::Ge.into()], current_node.ge(rhs_node), current_node, rhs_node @@ -291,7 +305,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Gt], + [ByteCode::Gt.into()], current_node.gt(rhs_node), current_node, rhs_node @@ -311,7 +325,7 @@ impl<'l> CelCompiler<'l> { range, ); current_node = compile!( - [ByteCode::In], + [ByteCode::In.into()], current_node.in_(rhs_node), current_node, rhs_node @@ -345,7 +359,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Add], + [ByteCode::Add.into()], current_node + rhs_node, current_node, rhs_node @@ -367,7 +381,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Sub], + [ByteCode::Sub.into()], current_node - rhs_node, current_node, rhs_node @@ -400,7 +414,7 @@ impl<'l> CelCompiler<'l> { range, ); current_node = compile!( - [ByteCode::Mul], + [ByteCode::Mul.into()], current_node * rhs_node, current_node, rhs_node @@ -422,7 +436,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Div], + [ByteCode::Div.into()], current_node / rhs_node, current_node, rhs_node @@ -444,7 +458,7 @@ impl<'l> CelCompiler<'l> { ); current_node = compile!( - [ByteCode::Mod], + [ByteCode::Mod.into()], current_node % rhs_node, current_node, rhs_node @@ -506,7 +520,7 @@ impl<'l> CelCompiler<'l> { self.tokenizer.next()?; let (not_list, ast) = self.parse_not_list()?; - let node = compile!([ByteCode::Not], not_list, not_list); + let node = compile!([ByteCode::Not.into()], not_list, not_list); let range = ast.range().surrounding(loc); @@ -539,7 +553,7 @@ impl<'l> CelCompiler<'l> { self.tokenizer.next()?; let (neg_list, ast) = self.parse_neg_list()?; - let node = compile!([ByteCode::Neg], neg_list, neg_list); + let node = compile!([ByteCode::Neg.into()], neg_list, neg_list); let range = ast.range().surrounding(loc); @@ -653,7 +667,10 @@ impl<'l> CelCompiler<'l> { args_ast.push(ast); args_node = args_node.append_result(CompiledProg::with_code_points(vec![ - ByteCode::Push(a.into_bytecode().into()), + ByteCode::Push( + resolve_bytecode(a.into_unresolved_bytecode()).into(), + ) + .into(), ])) } @@ -661,7 +678,8 @@ impl<'l> CelCompiler<'l> { .consume_child(args_node) .consume_child(CompiledProg::with_code_points(vec![ByteCode::Call( args_len as u32, - )])); + ) + .into()])); member_prime_node = self.check_for_const(member_prime_node); @@ -697,7 +715,7 @@ impl<'l> CelCompiler<'l> { loc: rbracket_loc, }) => { member_prime_node = compile!( - [ByteCode::Index], + [ByteCode::Index.into()], member_prime_node.index(index_node), member_prime_node, index_node @@ -746,8 +764,10 @@ impl<'l> CelCompiler<'l> { token: Token::Ident(val), loc, }) => Ok(( - CompiledProg::with_code_points(vec![ByteCode::Push(CelValue::from_ident(&val))]) - .add_ident(&val), + CompiledProg::with_code_points(vec![ + ByteCode::Push(CelValue::from_ident(&val)).into() + ]) + .add_ident(&val), AstNode::new(Primary::Ident(Ident(val.clone())), loc), )), Some(TokenWithLoc { @@ -943,13 +963,13 @@ impl<'l> CelCompiler<'l> { token: Token::FStringLit(segments), loc, }) => { - let mut bytecode = Vec::new(); + let mut bytecode = Vec::::new(); for segment in segments.iter() { - bytecode.push(ByteCode::Push(CelValue::Ident("string".to_string()))); + bytecode.push(ByteCode::Push(CelValue::Ident("string".to_string())).into()); match segment { FStringSegment::Lit(c) => { - bytecode.push(ByteCode::Push(CelValue::String(c.clone()))) + bytecode.push(ByteCode::Push(CelValue::String(c.clone())).into()) } FStringSegment::Expr(e) => { let mut tok = StringTokenizer::with_input(&e); @@ -957,14 +977,19 @@ impl<'l> CelCompiler<'l> { let (e, _) = comp.parse_expression()?; - bytecode.push(ByteCode::Push(CelValue::ByteCode(e.into_bytecode()))); + bytecode.push( + ByteCode::Push(CelValue::ByteCode(resolve_bytecode( + e.into_unresolved_bytecode(), + ))) + .into(), + ); } } - bytecode.push(ByteCode::Call(1)); + bytecode.push(ByteCode::Call(1).into()); } // Reverse it so its evaluated in order on the stack - bytecode.push(ByteCode::FmtString(segments.len() as u32)); + bytecode.push(ByteCode::FmtString(segments.len() as u32).into()); Ok(( CompiledProg::with_code_points(bytecode), @@ -1066,7 +1091,7 @@ impl<'l> CelCompiler<'l> { fn check_for_const(&self, member_prime_node: CompiledProg) -> CompiledProg { let mut i = Interpreter::empty(); i.add_bindings(&self.bindings); - let bc = member_prime_node.into_bytecode(); + let bc = resolve_bytecode(member_prime_node.into_unresolved_bytecode()); let r = i.run_raw(&bc, true); match r { From 1880e81c8a9af93b88b0e43edd0c7ae156b36660 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:19:35 -0800 Subject: [PATCH 03/16] Allow for negative jumps --- python/rscel.pyi | 14 + rscel/src/compiler/compiled_prog.rs | 133 +---- .../src/compiler/compiled_prog/preresolved.rs | 159 +++++ rscel/src/compiler/compiler.rs | 58 +- rscel/src/interp/interp.rs | 556 +++++++++++++++++ rscel/src/interp/mod.rs | 562 +----------------- rscel/src/interp/types.rs | 158 +---- rscel/src/interp/types/bytecode.rs | 82 +++ rscel/src/interp/types/celstackvalue.rs | 52 ++ rscel/src/interp/types/rscallable.rs | 26 + 10 files changed, 961 insertions(+), 839 deletions(-) create mode 100644 python/rscel.pyi create mode 100644 rscel/src/compiler/compiled_prog/preresolved.rs create mode 100644 rscel/src/interp/interp.rs create mode 100644 rscel/src/interp/types/bytecode.rs create mode 100644 rscel/src/interp/types/celstackvalue.rs create mode 100644 rscel/src/interp/types/rscallable.rs diff --git a/python/rscel.pyi b/python/rscel.pyi new file mode 100644 index 0000000..581e7ce --- /dev/null +++ b/python/rscel.pyi @@ -0,0 +1,14 @@ +from typing import Any, Callable, Tuple + + +CelBasicType = int | float | str | bool | None +CelArrayType = list[CelBasicType | 'CelArrayType' | 'CelDict'] +CelDict = dict[str, 'CelValue'] +CelValue = CelDict | CelArrayType | CelBasicType + +CelCallable = Callable[[*Tuple[CelValue, ...]], CelValue] + +CelBinding = dict[str, CelValue | CelCallable | Any] + +def eval(prog: str, binding: CelBinding) -> CelValue: + ... diff --git a/rscel/src/compiler/compiled_prog.rs b/rscel/src/compiler/compiled_prog.rs index 779a09e..a9ff3b0 100644 --- a/rscel/src/compiler/compiled_prog.rs +++ b/rscel/src/compiler/compiled_prog.rs @@ -1,39 +1,14 @@ -use std::collections::HashMap; +mod preresolved; use crate::{ interp::JmpWhen, program::ProgramDetails, types::CelByteCode, ByteCode, CelError, CelValue, CelValueDyn, Program, }; - -#[derive(Debug, Clone)] -pub enum PreResolvedByteCode { - Bytecode(ByteCode), - Jmp { - label: u32, - }, - JmpCond { - when: JmpWhen, - label: u32, - leave_val: bool, - }, - Label(u32), -} - -impl From for PreResolvedByteCode { - fn from(value: ByteCode) -> Self { - PreResolvedByteCode::Bytecode(value) - } -} - -impl From for Vec { - fn from(value: CelByteCode) -> Self { - value.into_iter().map(|b| b.into()).collect() - } -} +pub use preresolved::{PreResolvedByteCode, PreResolvedCodePoint}; #[derive(Debug, Clone)] pub enum NodeValue { - Bytecode(Vec), + Bytecode(PreResolvedByteCode), ConstExpr(CelValue), } @@ -67,7 +42,7 @@ macro_rules! compile { } } ($($child,)+) => { - let mut new_bytecode = Vec::::new(); + let mut new_bytecode = PreResolvedByteCode::new(); $( new_bytecode.extend($child.into_bytecode().into_iter()); @@ -90,7 +65,7 @@ macro_rules! compile { impl CompiledProg { pub fn empty() -> CompiledProg { CompiledProg { - inner: NodeValue::Bytecode(Vec::new()), + inner: NodeValue::Bytecode(PreResolvedByteCode::new()), details: ProgramDetails::new(), } } @@ -109,13 +84,22 @@ impl CompiledProg { } } - pub fn with_code_points(bytecode: Vec) -> CompiledProg { + pub fn with_code_points(bytecode: Vec) -> CompiledProg { CompiledProg { - inner: NodeValue::Bytecode(bytecode.into_iter().map(|b| b.into()).collect()), + inner: NodeValue::Bytecode(bytecode.into_iter().collect()), details: ProgramDetails::new(), } } + pub fn append_if_bytecode(&mut self, b: impl IntoIterator) { + match &mut self.inner { + NodeValue::Bytecode(bytecode) => { + bytecode.extend(b); + } + NodeValue::ConstExpr(_) => { /* do nothing */ } + } + } + pub fn with_const(val: CelValue) -> CompiledProg { CompiledProg { inner: NodeValue::ConstExpr(val), @@ -176,10 +160,13 @@ impl CompiledProg { }, None => CompiledProg { inner: NodeValue::Bytecode( - [ByteCode::Push(c1).into(), ByteCode::Push(c2).into()] - .into_iter() - .chain(bytecode.into_iter().map(|b| b.into())) - .collect(), + [ + PreResolvedCodePoint::Bytecode(ByteCode::Push(c1)), + PreResolvedCodePoint::Bytecode(ByteCode::Push(c2)), + ] + .into_iter() + .chain(bytecode.into_iter().map(|b| b.into())) + .collect(), ), details: new_details, }, @@ -201,7 +188,7 @@ impl CompiledProg { let mut details = self.details; details.add_source(source); - Program::new(details, resolve_bytecode(self.inner.into_bytecode())) + Program::new(details, self.inner.into_bytecode().resolve()) } pub fn add_ident(mut self, ident: &str) -> CompiledProg { @@ -291,17 +278,17 @@ impl CompiledProg { .into_bytecode() .into_iter() .chain( - [ByteCode::JmpCond { + [PreResolvedCodePoint::Bytecode(ByteCode::JmpCond { when: JmpWhen::False, - dist: (true_clause_bytecode.len() as u32) + 1, // +1 to jmp over the next jump + dist: i32::try_from(true_clause_bytecode.len() + 1) + .expect("Jump distance too far"), leave_val: false, - } - .into()] + })] .into_iter(), ) .chain(true_clause_bytecode.into_iter()) .chain( - [ByteCode::Jmp(false_clause_bytecode.len() as u32).into()].into_iter(), + [ByteCode::Jmp(false_clause_bytecode.len() as i32).into()].into_iter(), ) .chain(false_clause_bytecode.into_iter()) .collect(), @@ -319,7 +306,7 @@ impl CompiledProg { } } - pub fn into_unresolved_bytecode(self) -> Vec { + pub fn into_unresolved_bytecode(self) -> PreResolvedByteCode { self.inner.into_bytecode() } @@ -340,66 +327,10 @@ impl NodeValue { matches!(*self, NodeValue::ConstExpr(_)) } - pub fn into_bytecode(self) -> Vec { + pub fn into_bytecode(self) -> PreResolvedByteCode { match self { NodeValue::Bytecode(b) => b, - NodeValue::ConstExpr(c) => vec![ByteCode::Push(c).into()], - } - } -} - -pub fn resolve_bytecode(code: Vec) -> CelByteCode { - let mut curr_loc: usize = 0; - let mut locations = HashMap::::new(); - let mut ret = CelByteCode::new(); - - // determine label locations - for c in code.iter() { - match c { - PreResolvedByteCode::Label(i) => { - if locations.contains_key(i) { - panic!("Duplicate label found!"); - } - locations.insert(*i, curr_loc); - } - _ => { - curr_loc += 1; - } + NodeValue::ConstExpr(c) => [ByteCode::Push(c)].into_iter().collect(), } } - - curr_loc = 0; - - // resolve the label locations - for c in code.into_iter() { - match c { - PreResolvedByteCode::Bytecode(byte_code) => { - curr_loc += 1; - ret.push(byte_code); - } - PreResolvedByteCode::Jmp { label } => { - curr_loc += 1; - let jmp_loc = locations[&label]; - let offset = jmp_loc - curr_loc; - ret.push(ByteCode::Jmp(offset as u32)); - } - PreResolvedByteCode::JmpCond { - when, - label, - leave_val, - } => { - curr_loc += 1; - let jmp_loc = locations[&label]; - let offset = jmp_loc - curr_loc; - ret.push(ByteCode::JmpCond { - when, - dist: offset as u32, - leave_val, - }); - } - PreResolvedByteCode::Label(_) => {} - } - } - - ret } diff --git a/rscel/src/compiler/compiled_prog/preresolved.rs b/rscel/src/compiler/compiled_prog/preresolved.rs new file mode 100644 index 0000000..b05cca7 --- /dev/null +++ b/rscel/src/compiler/compiled_prog/preresolved.rs @@ -0,0 +1,159 @@ +use std::collections::HashMap; + +use crate::{interp::JmpWhen, types::CelByteCode, ByteCode}; + +#[derive(Debug, Clone)] +pub enum PreResolvedCodePoint { + Bytecode(ByteCode), + Jmp { + label: u32, + }, + JmpCond { + when: JmpWhen, + label: u32, + leave_val: bool, + }, + Label(u32), +} + +#[derive(Debug, Clone)] +pub struct PreResolvedByteCode { + inner: Vec, + len: usize, +} + +impl From for PreResolvedCodePoint { + fn from(value: ByteCode) -> Self { + PreResolvedCodePoint::Bytecode(value) + } +} + +impl From for Vec { + fn from(value: CelByteCode) -> Self { + value.into_iter().map(|b| b.into()).collect() + } +} + +impl PreResolvedByteCode { + pub fn new() -> Self { + PreResolvedByteCode { + inner: Vec::new(), + len: 0, + } + } + + pub fn extend(&mut self, byte_codes: impl IntoIterator) { + for b in byte_codes.into_iter() { + match &b { + PreResolvedCodePoint::Label(_) => {} + _ => self.len += 1, + } + + self.inner.push(b) + } + } + + pub fn into_iter(self) -> impl Iterator { + self.inner.into_iter() + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn resolve(self) -> CelByteCode { + let mut curr_loc: usize = 0; + let mut locations = HashMap::::new(); + let mut ret = CelByteCode::new(); + + // determine label locations + for c in self.inner.iter() { + match c { + PreResolvedCodePoint::Label(i) => { + if locations.contains_key(i) { + panic!("Duplicate label found!"); + } + locations.insert(*i, curr_loc); + } + _ => { + curr_loc += 1; + } + } + } + + curr_loc = 0; + + // resolve the label locations + for c in self.inner.into_iter() { + match c { + PreResolvedCodePoint::Bytecode(byte_code) => { + curr_loc += 1; + ret.push(byte_code); + } + PreResolvedCodePoint::Jmp { label } => { + curr_loc += 1; + let jmp_loc = locations[&label]; + let offset = (jmp_loc as isize) - (curr_loc as isize); + ret.push(ByteCode::Jmp( + i32::try_from(offset).expect("Attempt to jump farther than possible"), + )); + } + PreResolvedCodePoint::JmpCond { + when, + label, + leave_val, + } => { + curr_loc += 1; + let jmp_loc = locations[&label]; + let offset = (jmp_loc as isize) - (curr_loc as isize); + ret.push(ByteCode::JmpCond { + when, + dist: offset as i32, + leave_val, + }); + } + PreResolvedCodePoint::Label(_) => {} + } + } + + ret + } +} + +impl From for PreResolvedByteCode { + fn from(value: CelByteCode) -> Self { + value.into_iter().collect() + } +} + +impl FromIterator for PreResolvedByteCode { + fn from_iter>(iter: T) -> Self { + let v: Vec<_> = iter.into_iter().map(|b| b.into()).collect(); + let l = v.len(); + + PreResolvedByteCode { inner: v, len: l } + } +} + +impl FromIterator for PreResolvedByteCode { + fn from_iter>(iter: T) -> Self { + let mut code_points = Vec::new(); + let mut size = 0; + + for code_point in iter.into_iter() { + match &code_point { + PreResolvedCodePoint::Label(_) => {} + _ => { + size += 1; + } + } + + code_points.push(code_point); + } + + PreResolvedByteCode { + inner: code_points, + len: size, + } + } +} diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index dbd312e..b7689a7 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::{ ast_node::AstNode, - compiled_prog::{resolve_bytecode, CompiledProg, PreResolvedByteCode}, + compiled_prog::{CompiledProg, PreResolvedCodePoint}, grammar::*, source_range::SourceRange, syntax_error::SyntaxError, @@ -102,19 +102,20 @@ impl<'l> CelCompiler<'l> { fn parse_conditional_or(&mut self) -> CelResult<(CompiledProg, AstNode)> { let (mut current_node, mut current_ast) = into_unary(self.parse_conditional_and()?); + let label = self.get_label(); + loop { if let Some(Token::OrOr) = self.tokenizer.peek()?.as_token() { self.tokenizer.next()?; let (rhs_node, rhs_ast) = self.parse_conditional_and()?; - let label = self.get_label(); - - let jmp_node = CompiledProg::with_code_points(vec![PreResolvedByteCode::JmpCond { - when: JmpWhen::True, - label, - leave_val: true, - } - .into()]); + let jmp_node = + CompiledProg::with_code_points(vec![PreResolvedCodePoint::JmpCond { + when: JmpWhen::True, + label, + leave_val: true, + } + .into()]); let range = current_ast.range().surrounding(rhs_ast.range()); @@ -126,7 +127,7 @@ impl<'l> CelCompiler<'l> { range, ); current_node = compile!( - [ByteCode::Or.into(), PreResolvedByteCode::Label(label)], + [ByteCode::Or.into()], current_node.or(&rhs_node), current_node, jmp_node, @@ -137,24 +138,27 @@ impl<'l> CelCompiler<'l> { } } + current_node.append_if_bytecode([PreResolvedCodePoint::Label(label)]); + Ok((current_node, current_ast)) } fn parse_conditional_and(&mut self) -> CelResult<(CompiledProg, AstNode)> { let (mut current_node, mut current_ast) = into_unary(self.parse_relation()?); + let label = self.get_label(); + loop { if let Some(Token::AndAnd) = self.tokenizer.peek()?.as_token() { self.tokenizer.next()?; let (rhs_node, rhs_ast) = self.parse_relation()?; - let label = self.get_label(); - - let jmp_node = CompiledProg::with_code_points(vec![PreResolvedByteCode::JmpCond { - when: JmpWhen::False, - label: label, - leave_val: true, - }]); + let jmp_node = + CompiledProg::with_code_points(vec![PreResolvedCodePoint::JmpCond { + when: JmpWhen::False, + label: label, + leave_val: true, + }]); let range = current_ast.range().surrounding(rhs_ast.range()); @@ -166,7 +170,7 @@ impl<'l> CelCompiler<'l> { range, ); current_node = compile!( - [ByteCode::And.into(), PreResolvedByteCode::Label(label)], + [ByteCode::And.into()], current_node.and(rhs_node), current_node, jmp_node, @@ -176,6 +180,8 @@ impl<'l> CelCompiler<'l> { break; } } + current_node.append_if_bytecode([PreResolvedCodePoint::Label(label)]); + Ok((current_node, current_ast)) } @@ -667,10 +673,8 @@ impl<'l> CelCompiler<'l> { args_ast.push(ast); args_node = args_node.append_result(CompiledProg::with_code_points(vec![ - ByteCode::Push( - resolve_bytecode(a.into_unresolved_bytecode()).into(), - ) - .into(), + ByteCode::Push(a.into_unresolved_bytecode().resolve().into()) + .into(), ])) } @@ -963,7 +967,7 @@ impl<'l> CelCompiler<'l> { token: Token::FStringLit(segments), loc, }) => { - let mut bytecode = Vec::::new(); + let mut bytecode = Vec::::new(); for segment in segments.iter() { bytecode.push(ByteCode::Push(CelValue::Ident("string".to_string())).into()); @@ -978,9 +982,9 @@ impl<'l> CelCompiler<'l> { let (e, _) = comp.parse_expression()?; bytecode.push( - ByteCode::Push(CelValue::ByteCode(resolve_bytecode( - e.into_unresolved_bytecode(), - ))) + ByteCode::Push(CelValue::ByteCode( + e.into_unresolved_bytecode().resolve(), + )) .into(), ); } @@ -1091,7 +1095,7 @@ impl<'l> CelCompiler<'l> { fn check_for_const(&self, member_prime_node: CompiledProg) -> CompiledProg { let mut i = Interpreter::empty(); i.add_bindings(&self.bindings); - let bc = resolve_bytecode(member_prime_node.into_unresolved_bytecode()); + let bc = member_prime_node.into_unresolved_bytecode().resolve(); let r = i.run_raw(&bc, true); match r { diff --git a/rscel/src/interp/interp.rs b/rscel/src/interp/interp.rs new file mode 100644 index 0000000..44461da --- /dev/null +++ b/rscel/src/interp/interp.rs @@ -0,0 +1,556 @@ +pub use super::types::{ByteCode, CelStackValue, JmpWhen, RsCallable}; +use crate::{types::CelByteCode, CelValueDyn}; +use std::{collections::HashMap, fmt}; + +use crate::{ + context::construct_type, utils::ScopedCounter, BindContext, CelContext, CelError, CelResult, + CelValue, RsCelFunction, RsCelMacro, +}; + +struct InterpStack<'a, 'b> { + stack: Vec>, + + ctx: &'a Interpreter<'b>, +} + +impl<'a, 'b> InterpStack<'a, 'b> { + fn new(ctx: &'b Interpreter) -> InterpStack<'a, 'b> { + InterpStack { + stack: Vec::new(), + ctx, + } + } + + fn push(&mut self, val: CelStackValue<'b>) { + self.stack.push(val); + } + + fn push_val(&mut self, val: CelValue) { + self.stack.push(CelStackValue::Value(val)); + } + + fn pop(&mut self) -> CelResult { + match self.stack.pop() { + Some(stack_val) => { + if let CelStackValue::Value(val) = stack_val { + if let CelValue::Ident(name) = val { + if let Some(val) = self.ctx.get_type_by_name(&name) { + return Ok(CelStackValue::Value(val.clone())); + } + + if let Some(val) = self.ctx.get_param_by_name(&name) { + return Ok(CelStackValue::Value(val.clone())); + } + + if let Some(ctx) = self.ctx.cel { + // Allow for loaded programs to run as values + if let Some(prog) = ctx.get_program(&name) { + return self.ctx.run_raw(prog.bytecode(), true).map(|x| x.into()); + } + } + + Ok(CelValue::from_err(CelError::binding(&name)).into()) + } else { + Ok(val.into()) + } + } else { + Ok(stack_val) + } + } + None => Err(CelError::runtime("No value on stack!")), + } + } + + fn pop_val(&mut self) -> CelResult { + self.pop()?.into_value() + } + + fn pop_noresolve(&mut self) -> CelResult> { + match self.stack.pop() { + Some(val) => Ok(val), + None => Err(CelError::runtime("No value on stack!")), + } + } + + fn pop_tryresolve(&mut self) -> CelResult> { + match self.stack.pop() { + Some(val) => match val.try_into()? { + CelValue::Ident(name) => { + if let Some(val) = self.ctx.get_param_by_name(&name) { + Ok(val.clone().into()) + } else { + Ok(CelStackValue::Value(CelValue::from_ident(&name))) + } + } + other => Ok(CelStackValue::Value(other.into())), + }, + None => Err(CelError::runtime("No value on stack!")), + } + } +} + +impl<'a, 'b> fmt::Debug for InterpStack<'a, 'b> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.stack) + } +} + +pub struct Interpreter<'a> { + cel: Option<&'a CelContext>, + bindings: Option<&'a BindContext<'a>>, + depth: ScopedCounter, +} + +impl<'a> Interpreter<'a> { + pub fn new(cel: &'a CelContext, bindings: &'a BindContext) -> Interpreter<'a> { + Interpreter { + cel: Some(cel), + bindings: Some(bindings), + depth: ScopedCounter::new(), + } + } + + pub fn empty() -> Interpreter<'a> { + Interpreter { + cel: None, + bindings: None, + depth: ScopedCounter::new(), + } + } + + pub fn add_bindings(&mut self, bindings: &'a BindContext) { + self.bindings = Some(bindings); + } + + pub fn cel_copy(&self) -> Option { + self.cel.cloned() + } + + pub fn bindings_copy(&self) -> Option { + self.bindings.cloned() + } + + pub fn run_program(&self, name: &str) -> CelResult { + match self.cel { + Some(cel) => match cel.get_program(name) { + Some(prog) => self.run_raw(prog.bytecode(), true), + None => Err(CelError::binding(&name)), + }, + None => Err(CelError::internal("No CEL context bound to interpreter")), + } + } + + pub fn run_raw(&self, prog: &CelByteCode, resolve: bool) -> CelResult { + let mut pc: usize = 0; + let mut stack = InterpStack::new(self); + + let count = self.depth.inc(); + + if count.count() > 32 { + return Err(CelError::runtime("Max call depth excceded")); + } + + while pc < prog.len() { + let oldpc = pc; + pc += 1; + match &prog[oldpc] { + ByteCode::Push(val) => stack.push(val.clone().into()), + ByteCode::Or => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.or(&v2)) + } + ByteCode::And => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.and(v2)) + } + ByteCode::Not => { + let v1 = stack.pop_val()?; + + stack.push_val(!v1); + } + ByteCode::Neg => { + let v1 = stack.pop_val()?; + + stack.push_val(-v1); + } + ByteCode::Add => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1 + v2); + } + ByteCode::Sub => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1 - v2); + } + ByteCode::Mul => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1 * v2); + } + ByteCode::Div => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1 / v2); + } + ByteCode::Mod => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1 % v2); + } + ByteCode::Lt => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.lt(v2)); + } + ByteCode::Le => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.le(v2)); + } + ByteCode::Eq => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(CelValueDyn::eq(&v1, &v2)); + } + ByteCode::Ne => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.neq(v2)); + } + ByteCode::Ge => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.ge(v2)); + } + ByteCode::Gt => { + let v2 = stack.pop_val()?; + let v1 = stack.pop_val()?; + + stack.push_val(v1.gt(v2)); + } + ByteCode::In => { + let rhs = stack.pop_val()?; + let lhs = stack.pop_val()?; + + stack.push_val(lhs.in_(rhs)); + } + ByteCode::Jmp(dist) => pc = pc + *dist as usize, + ByteCode::JmpCond { + when, + dist, + leave_val, + } => { + let mut v1 = stack.pop_val()?; + match when { + JmpWhen::True => { + if cfg!(feature = "type_prop") { + if v1.is_truthy() { + v1 = CelValue::true_(); + pc += *dist as usize + } + } else if let CelValue::Err(ref _e) = v1 { + // do nothing + } else if let CelValue::Bool(v) = v1 { + if v { + pc += *dist as usize + } + } else { + return Err(CelError::invalid_op(&format!( + "JMP TRUE invalid on type {:?}", + v1.as_type() + ))); + } + } + JmpWhen::False => { + if cfg!(feature = "type_prop") { + if !v1.is_truthy() { + v1 = CelValue::false_(); + pc += *dist as usize + } + } else if let CelValue::Bool(v) = v1 { + if !v { + pc += *dist as usize + } + } else { + return Err(CelError::invalid_op(&format!( + "JMP FALSE invalid on type {:?}", + v1.as_type() + ))); + } + } + }; + if *leave_val { + stack.push_val(v1); + } + } + ByteCode::MkList(size) => { + let mut v = Vec::new(); + + for _ in 0..*size { + v.push(stack.pop_val()?) + } + + v.reverse(); + stack.push_val(v.into()); + } + ByteCode::MkDict(size) => { + let mut map = HashMap::new(); + + for _ in 0..*size { + let key = if let CelValue::String(key) = stack.pop_val()? { + key + } else { + return Err(CelError::value("Only strings can be used as Object keys")); + }; + + map.insert(key, stack.pop_val()?); + } + + stack.push_val(map.into()); + } + ByteCode::Index => { + let index = stack.pop_val()?; + let obj = stack.pop_val()?; + + stack.push_val(obj.index(index)); + } + ByteCode::Access => { + let index = stack.pop_noresolve()?; + if let CelValue::Ident(ident) = index.as_value()? { + let obj = stack.pop()?.into_value()?; + match obj { + CelValue::Map(ref map) => match map.get(ident.as_str()) { + Some(val) => stack.push_val(val.clone()), + None => match self.callable_by_name(ident.as_str()) { + Ok(callable) => stack.push(CelStackValue::BoundCall { + callable, + value: obj, + }), + Err(_) => { + stack.push( + CelValue::from_err(CelError::attribute( + "obj", + ident.as_str(), + )) + .into(), + ); + } + }, + }, + #[cfg(feature = "protobuf")] + CelValue::Message(msg) => { + let desc = msg.descriptor_dyn(); + + if let Some(field) = desc.field_by_name(ident.as_str()) { + stack.push_val( + field.get_singular_field_or_default(msg.as_ref()).into(), + ) + } else { + return Err(CelError::attribute("msg", ident.as_str())); + } + } + CelValue::Dyn(d) => { + stack.push_val(d.access(ident.as_str())); + } + _ => { + if let Some(bindings) = self.bindings { + if bindings.get_func(ident.as_str()).is_some() + || bindings.get_macro(ident.as_str()).is_some() + { + stack.push(CelStackValue::BoundCall { + callable: self.callable_by_name(ident.as_str())?, + value: obj, + }); + } else { + stack.push( + CelValue::from_err(CelError::attribute( + "obj", + ident.as_str(), + )) + .into(), + ); + } + } else { + return Err(CelError::Runtime( + "Invalid state: no bindings".to_string(), + )); + } + } + } + } else { + let obj_type = stack.pop()?.into_value()?.as_type(); + stack.push( + CelValue::from_err(CelError::value(&format!( + "Index operator invalid between {:?} and {:?}", + index.into_value()?.as_type(), + obj_type + ))) + .into(), + ); + } + } + ByteCode::Call(n_args) => { + let mut args = Vec::new(); + + for _ in 0..*n_args { + args.push(stack.pop()?.into_value()?) + } + + match stack.pop_noresolve()? { + CelStackValue::BoundCall { callable, value } => match callable { + RsCallable::Function(func) => { + let arg_values = self.resolve_args(args)?; + stack.push_val(func(value, arg_values)); + } + RsCallable::Macro(macro_) => { + stack.push_val(self.call_macro(&value, &args, macro_)?); + } + }, + CelStackValue::Value(value) => match value { + CelValue::Ident(func_name) => { + if let Some(func) = self.get_func_by_name(&func_name) { + let arg_values = self.resolve_args(args)?; + stack.push_val(func(CelValue::from_null(), arg_values)); + } else if let Some(macro_) = self.get_macro_by_name(&func_name) { + stack.push_val(self.call_macro( + &CelValue::from_null(), + &args, + macro_, + )?); + } else if let Some(CelValue::Type(type_name)) = + self.get_type_by_name(&func_name) + { + let arg_values = self.resolve_args(args)?; + stack.push_val(construct_type(type_name, arg_values)); + } else { + stack.push_val(CelValue::from_err(CelError::runtime( + &format!("{} is not callable", func_name), + ))); + } + } + CelValue::Type(type_name) => { + let arg_values = self.resolve_args(args)?; + stack.push_val(construct_type(&type_name, arg_values)); + } + other => stack.push_val( + CelValue::from_err(CelError::runtime(&format!( + "{:?} cannot be called", + other + ))) + .into(), + ), + }, + }; + } + ByteCode::FmtString(nsegments) => { + let mut segments = Vec::new(); + for _ in 0..*nsegments { + segments.push(stack.pop_val()?); + } + + let mut working = String::new(); + for seg in segments.into_iter().rev() { + if let CelValue::String(s) = seg { + working.push_str(&s) + } else { + return Err(CelError::Runtime( + "Expected string from format string specifier".to_string(), + )); + } + } + + stack.push_val(CelValue::String(working)); + } + }; + } + + if resolve { + match stack.pop() { + Ok(val) => { + let cel: CelValue = val.try_into()?; + cel.into_result() + } + Err(err) => Err(err), + } + } else { + match stack.pop_tryresolve() { + Ok(val) => { + let cel: CelValue = val.try_into()?; + cel.into_result() + } + Err(err) => Err(err), + } + } + } + + fn call_macro( + &self, + this: &CelValue, + args: &Vec, + macro_: &RsCelMacro, + ) -> Result { + let mut v = Vec::new(); + for arg in args.iter() { + if let CelValue::ByteCode(bc) = arg { + v.push(bc); + } else { + return Err(CelError::internal("macro args must be bytecode")); + } + } + let res = macro_(self, this.clone(), &v); + Ok(res) + } + + fn resolve_args(&self, args: Vec) -> Result, CelError> { + let mut arg_values = Vec::new(); + for arg in args.into_iter() { + if let CelValue::ByteCode(bc) = arg { + arg_values.push(self.run_raw(&bc, true)?); + } else { + arg_values.push(arg) + } + } + Ok(arg_values) + } + + fn get_param_by_name(&self, name: &str) -> Option<&'a CelValue> { + self.bindings?.get_param(name) + } + + fn get_func_by_name(&self, name: &str) -> Option<&'a RsCelFunction> { + self.bindings?.get_func(name) + } + + fn get_macro_by_name(&self, name: &str) -> Option<&'a RsCelMacro> { + self.bindings?.get_macro(name) + } + + fn get_type_by_name(&self, name: &str) -> Option<&'a CelValue> { + self.bindings?.get_type(name) + } + + fn callable_by_name(&self, name: &str) -> CelResult { + if let Some(func) = self.get_func_by_name(name) { + Ok(RsCallable::Function(func)) + } else if let Some(macro_) = self.get_macro_by_name(name) { + Ok(RsCallable::Macro(macro_)) + } else { + Err(CelError::value(&format!("{} is not callable", name))) + } + } +} diff --git a/rscel/src/interp/mod.rs b/rscel/src/interp/mod.rs index c3ca5e3..e14bac4 100644 --- a/rscel/src/interp/mod.rs +++ b/rscel/src/interp/mod.rs @@ -1,564 +1,8 @@ +mod interp; mod types; -use crate::{types::CelByteCode, CelValueDyn}; -use std::{collections::HashMap, fmt}; -pub use types::{ByteCode, JmpWhen}; -use crate::{ - context::construct_type, utils::ScopedCounter, BindContext, CelContext, CelError, CelResult, - CelValue, RsCelFunction, RsCelMacro, -}; - -use types::CelStackValue; - -use self::types::RsCallable; - -struct InterpStack<'a, 'b> { - stack: Vec>, - - ctx: &'a Interpreter<'b>, -} - -impl<'a, 'b> InterpStack<'a, 'b> { - fn new(ctx: &'b Interpreter) -> InterpStack<'a, 'b> { - InterpStack { - stack: Vec::new(), - ctx, - } - } - - fn push(&mut self, val: CelStackValue<'b>) { - self.stack.push(val); - } - - fn push_val(&mut self, val: CelValue) { - self.stack.push(CelStackValue::Value(val)); - } - - fn pop(&mut self) -> CelResult { - match self.stack.pop() { - Some(stack_val) => { - if let CelStackValue::Value(val) = stack_val { - if let CelValue::Ident(name) = val { - if let Some(val) = self.ctx.get_type_by_name(&name) { - return Ok(CelStackValue::Value(val.clone())); - } - - if let Some(val) = self.ctx.get_param_by_name(&name) { - return Ok(CelStackValue::Value(val.clone())); - } - - if let Some(ctx) = self.ctx.cel { - // Allow for loaded programs to run as values - if let Some(prog) = ctx.get_program(&name) { - return self.ctx.run_raw(prog.bytecode(), true).map(|x| x.into()); - } - } - - Ok(CelValue::from_err(CelError::binding(&name)).into()) - } else { - Ok(val.into()) - } - } else { - Ok(stack_val) - } - } - None => Err(CelError::runtime("No value on stack!")), - } - } - - fn pop_val(&mut self) -> CelResult { - self.pop()?.into_value() - } - - fn pop_noresolve(&mut self) -> CelResult> { - match self.stack.pop() { - Some(val) => Ok(val), - None => Err(CelError::runtime("No value on stack!")), - } - } - - fn pop_tryresolve(&mut self) -> CelResult> { - match self.stack.pop() { - Some(val) => match val.try_into()? { - CelValue::Ident(name) => { - if let Some(val) = self.ctx.get_param_by_name(&name) { - Ok(val.clone().into()) - } else { - Ok(CelStackValue::Value(CelValue::from_ident(&name))) - } - } - other => Ok(CelStackValue::Value(other.into())), - }, - None => Err(CelError::runtime("No value on stack!")), - } - } -} - -impl<'a, 'b> fmt::Debug for InterpStack<'a, 'b> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self.stack) - } -} - -pub struct Interpreter<'a> { - cel: Option<&'a CelContext>, - bindings: Option<&'a BindContext<'a>>, - depth: ScopedCounter, -} - -impl<'a> Interpreter<'a> { - pub fn new(cel: &'a CelContext, bindings: &'a BindContext) -> Interpreter<'a> { - Interpreter { - cel: Some(cel), - bindings: Some(bindings), - depth: ScopedCounter::new(), - } - } - - pub fn empty() -> Interpreter<'a> { - Interpreter { - cel: None, - bindings: None, - depth: ScopedCounter::new(), - } - } - - pub fn add_bindings(&mut self, bindings: &'a BindContext) { - self.bindings = Some(bindings); - } - - pub fn cel_copy(&self) -> Option { - self.cel.cloned() - } - - pub fn bindings_copy(&self) -> Option { - self.bindings.cloned() - } - - pub fn run_program(&self, name: &str) -> CelResult { - match self.cel { - Some(cel) => match cel.get_program(name) { - Some(prog) => self.run_raw(prog.bytecode(), true), - None => Err(CelError::binding(&name)), - }, - None => Err(CelError::internal("No CEL context bound to interpreter")), - } - } - - pub fn run_raw(&self, prog: &CelByteCode, resolve: bool) -> CelResult { - let mut pc: usize = 0; - let mut stack = InterpStack::new(self); - - let count = self.depth.inc(); - - if count.count() > 32 { - return Err(CelError::runtime("Max call depth excceded")); - } - - while pc < prog.len() { - let oldpc = pc; - pc += 1; - match &prog[oldpc] { - ByteCode::Push(val) => stack.push(val.clone().into()), - ByteCode::Or => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.or(&v2)) - } - ByteCode::And => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.and(v2)) - } - ByteCode::Not => { - let v1 = stack.pop_val()?; - - stack.push_val(!v1); - } - ByteCode::Neg => { - let v1 = stack.pop_val()?; - - stack.push_val(-v1); - } - ByteCode::Add => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1 + v2); - } - ByteCode::Sub => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1 - v2); - } - ByteCode::Mul => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1 * v2); - } - ByteCode::Div => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1 / v2); - } - ByteCode::Mod => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1 % v2); - } - ByteCode::Lt => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.lt(v2)); - } - ByteCode::Le => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.le(v2)); - } - ByteCode::Eq => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(CelValueDyn::eq(&v1, &v2)); - } - ByteCode::Ne => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.neq(v2)); - } - ByteCode::Ge => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.ge(v2)); - } - ByteCode::Gt => { - let v2 = stack.pop_val()?; - let v1 = stack.pop_val()?; - - stack.push_val(v1.gt(v2)); - } - ByteCode::In => { - let rhs = stack.pop_val()?; - let lhs = stack.pop_val()?; - - stack.push_val(lhs.in_(rhs)); - } - ByteCode::Jmp(dist) => pc = pc + *dist as usize, - ByteCode::JmpCond { - when, - dist, - leave_val, - } => { - let mut v1 = stack.pop_val()?; - match when { - JmpWhen::True => { - if cfg!(feature = "type_prop") { - if v1.is_truthy() { - v1 = CelValue::true_(); - pc += *dist as usize - } - } else if let CelValue::Err(ref _e) = v1 { - // do nothing - } else if let CelValue::Bool(v) = v1 { - if v { - pc += *dist as usize - } - } else { - return Err(CelError::invalid_op(&format!( - "JMP TRUE invalid on type {:?}", - v1.as_type() - ))); - } - } - JmpWhen::False => { - if cfg!(feature = "type_prop") { - if !v1.is_truthy() { - v1 = CelValue::false_(); - pc += *dist as usize - } - } else if let CelValue::Bool(v) = v1 { - if !v { - pc += *dist as usize - } - } else { - return Err(CelError::invalid_op(&format!( - "JMP FALSE invalid on type {:?}", - v1.as_type() - ))); - } - } - }; - if *leave_val { - stack.push_val(v1); - } - } - ByteCode::MkList(size) => { - let mut v = Vec::new(); - - for _ in 0..*size { - v.push(stack.pop_val()?) - } - - v.reverse(); - stack.push_val(v.into()); - } - ByteCode::MkDict(size) => { - let mut map = HashMap::new(); - - for _ in 0..*size { - let key = if let CelValue::String(key) = stack.pop_val()? { - key - } else { - return Err(CelError::value("Only strings can be used as Object keys")); - }; - - map.insert(key, stack.pop_val()?); - } - - stack.push_val(map.into()); - } - ByteCode::Index => { - let index = stack.pop_val()?; - let obj = stack.pop_val()?; - - stack.push_val(obj.index(index)); - } - ByteCode::Access => { - let index = stack.pop_noresolve()?; - if let CelValue::Ident(ident) = index.as_value()? { - let obj = stack.pop()?.into_value()?; - match obj { - CelValue::Map(ref map) => match map.get(ident.as_str()) { - Some(val) => stack.push_val(val.clone()), - None => match self.callable_by_name(ident.as_str()) { - Ok(callable) => stack.push(CelStackValue::BoundCall { - callable, - value: obj, - }), - Err(_) => { - stack.push( - CelValue::from_err(CelError::attribute( - "obj", - ident.as_str(), - )) - .into(), - ); - } - }, - }, - #[cfg(feature = "protobuf")] - CelValue::Message(msg) => { - let desc = msg.descriptor_dyn(); - - if let Some(field) = desc.field_by_name(ident.as_str()) { - stack.push_val( - field.get_singular_field_or_default(msg.as_ref()).into(), - ) - } else { - return Err(CelError::attribute("msg", ident.as_str())); - } - } - CelValue::Dyn(d) => { - stack.push_val(d.access(ident.as_str())); - } - _ => { - if let Some(bindings) = self.bindings { - if bindings.get_func(ident.as_str()).is_some() - || bindings.get_macro(ident.as_str()).is_some() - { - stack.push(CelStackValue::BoundCall { - callable: self.callable_by_name(ident.as_str())?, - value: obj, - }); - } else { - stack.push( - CelValue::from_err(CelError::attribute( - "obj", - ident.as_str(), - )) - .into(), - ); - } - } else { - return Err(CelError::Runtime( - "Invalid state: no bindings".to_string(), - )); - } - } - } - } else { - let obj_type = stack.pop()?.into_value()?.as_type(); - stack.push( - CelValue::from_err(CelError::value(&format!( - "Index operator invalid between {:?} and {:?}", - index.into_value()?.as_type(), - obj_type - ))) - .into(), - ); - } - } - ByteCode::Call(n_args) => { - let mut args = Vec::new(); - - for _ in 0..*n_args { - args.push(stack.pop()?.into_value()?) - } - - match stack.pop_noresolve()? { - CelStackValue::BoundCall { callable, value } => match callable { - RsCallable::Function(func) => { - let arg_values = self.resolve_args(args)?; - stack.push_val(func(value, arg_values)); - } - RsCallable::Macro(macro_) => { - stack.push_val(self.call_macro(&value, &args, macro_)?); - } - }, - CelStackValue::Value(value) => match value { - CelValue::Ident(func_name) => { - if let Some(func) = self.get_func_by_name(&func_name) { - let arg_values = self.resolve_args(args)?; - stack.push_val(func(CelValue::from_null(), arg_values)); - } else if let Some(macro_) = self.get_macro_by_name(&func_name) { - stack.push_val(self.call_macro( - &CelValue::from_null(), - &args, - macro_, - )?); - } else if let Some(CelValue::Type(type_name)) = - self.get_type_by_name(&func_name) - { - let arg_values = self.resolve_args(args)?; - stack.push_val(construct_type(type_name, arg_values)); - } else { - stack.push_val(CelValue::from_err(CelError::runtime( - &format!("{} is not callable", func_name), - ))); - } - } - CelValue::Type(type_name) => { - let arg_values = self.resolve_args(args)?; - stack.push_val(construct_type(&type_name, arg_values)); - } - other => stack.push_val( - CelValue::from_err(CelError::runtime(&format!( - "{:?} cannot be called", - other - ))) - .into(), - ), - }, - }; - } - ByteCode::FmtString(nsegments) => { - let mut segments = Vec::new(); - for _ in 0..*nsegments { - segments.push(stack.pop_val()?); - } - - let mut working = String::new(); - for seg in segments.into_iter().rev() { - if let CelValue::String(s) = seg { - working.push_str(&s) - } else { - return Err(CelError::Runtime( - "Expected string from format string specifier".to_string(), - )); - } - } - - stack.push_val(CelValue::String(working)); - } - }; - } - - if resolve { - match stack.pop() { - Ok(val) => { - let cel: CelValue = val.try_into()?; - cel.into_result() - } - Err(err) => Err(err), - } - } else { - match stack.pop_tryresolve() { - Ok(val) => { - let cel: CelValue = val.try_into()?; - cel.into_result() - } - Err(err) => Err(err), - } - } - } - - fn call_macro( - &self, - this: &CelValue, - args: &Vec, - macro_: &RsCelMacro, - ) -> Result { - let mut v = Vec::new(); - for arg in args.iter() { - if let CelValue::ByteCode(bc) = arg { - v.push(bc); - } else { - return Err(CelError::internal("macro args must be bytecode")); - } - } - let res = macro_(self, this.clone(), &v); - Ok(res) - } - - fn resolve_args(&self, args: Vec) -> Result, CelError> { - let mut arg_values = Vec::new(); - for arg in args.into_iter() { - if let CelValue::ByteCode(bc) = arg { - arg_values.push(self.run_raw(&bc, true)?); - } else { - arg_values.push(arg) - } - } - Ok(arg_values) - } - - fn get_param_by_name(&self, name: &str) -> Option<&'a CelValue> { - self.bindings?.get_param(name) - } - - fn get_func_by_name(&self, name: &str) -> Option<&'a RsCelFunction> { - self.bindings?.get_func(name) - } - - fn get_macro_by_name(&self, name: &str) -> Option<&'a RsCelMacro> { - self.bindings?.get_macro(name) - } - - fn get_type_by_name(&self, name: &str) -> Option<&'a CelValue> { - self.bindings?.get_type(name) - } - - fn callable_by_name(&self, name: &str) -> CelResult { - if let Some(func) = self.get_func_by_name(name) { - Ok(RsCallable::Function(func)) - } else if let Some(macro_) = self.get_macro_by_name(name) { - Ok(RsCallable::Macro(macro_)) - } else { - Err(CelError::value(&format!("{} is not callable", name))) - } - } -} +pub use interp::Interpreter; +pub use types::*; #[cfg(test)] mod test { diff --git a/rscel/src/interp/types.rs b/rscel/src/interp/types.rs index 08ed349..3f11874 100644 --- a/rscel/src/interp/types.rs +++ b/rscel/src/interp/types.rs @@ -1,153 +1,7 @@ -use serde::{Deserialize, Serialize}; +mod bytecode; +mod celstackvalue; +mod rscallable; -use crate::{CelError, CelResult, CelValue, RsCelFunction, RsCelMacro}; -use std::fmt; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub enum JmpWhen { - True, - False, -} - -#[derive(Clone, PartialEq, Serialize, Deserialize)] -pub enum ByteCode { - Push(CelValue), - Or, - And, - Not, - Neg, - Add, - Sub, - Mul, - Div, - Mod, - Lt, - Le, - Eq, - Ne, - Ge, - Gt, - In, - Jmp(u32), - JmpCond { - when: JmpWhen, - dist: u32, - leave_val: bool, - }, - MkList(u32), - MkDict(u32), - Index, - Access, - Call(u32), - FmtString(u32), -} - -impl fmt::Debug for ByteCode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use ByteCode::*; - - match self { - Push(val) => write!(f, "PUSH {:?}", val), - Or => write!(f, "OR"), - And => write!(f, "AND"), - Not => write!(f, "NOT"), - Neg => write!(f, "NEG"), - Add => write!(f, "ADD"), - Sub => write!(f, "SUB"), - Mul => write!(f, "MUL"), - Div => write!(f, "DIV"), - Mod => write!(f, "MOD"), - Lt => write!(f, "LT"), - Le => write!(f, "LE"), - Eq => write!(f, "EQ"), - Ne => write!(f, "NE"), - Ge => write!(f, "GE"), - Gt => write!(f, "GT"), - In => write!(f, "IN"), - Jmp(dist) => write!(f, "JMP {}", dist), - JmpCond { - when, - dist, - leave_val: _, - } => write!(f, "JMP {:?} {}", when, dist), - MkList(size) => write!(f, "MKLIST {}", size), - MkDict(size) => write!(f, "MKDICT {}", size), - Index => write!(f, "INDEX"), - Access => write!(f, "ACCESS"), - Call(size) => write!(f, "CALL {}", size), - FmtString(size) => write!(f, "FMT {}", size), - } - } -} - -/// Wrapper enum that contains either an RsCelCallable or an RsCelFunction. Used -/// as a ValueCell value. -#[derive(Clone)] -pub enum RsCallable<'a> { - Function(&'a RsCelFunction), - Macro(&'a RsCelMacro), -} - -impl<'a> fmt::Debug for RsCallable<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Function(_) => write!(f, "Function"), - Self::Macro(_) => write!(f, "Macro"), - } - } -} - -impl<'a> PartialEq for RsCallable<'a> { - fn eq(&self, _other: &Self) -> bool { - false - } -} - -#[derive(Debug, Clone)] -pub enum CelStackValue<'a> { - Value(CelValue), - BoundCall { - callable: RsCallable<'a>, - value: CelValue, - }, -} - -impl<'a> CelStackValue<'a> { - pub fn into_value(self) -> CelResult { - match self { - CelStackValue::Value(val) => Ok(val), - _ => Err(CelError::internal("Expected value")), - } - } - - pub fn as_value(&'a self) -> CelResult<&'a CelValue> { - match self { - CelStackValue::Value(val) => Ok(val), - _ => Err(CelError::internal("Expected value")), - } - } - - pub fn as_bound_call(&'a self) -> Option<(&'a RsCallable<'a>, &'a CelValue)> { - match self { - CelStackValue::BoundCall { callable, value } => Some((callable, value)), - _ => None, - } - } -} - -impl<'a> Into> for CelValue { - fn into(self) -> CelStackValue<'a> { - CelStackValue::Value(self) - } -} - -impl<'a> TryInto for CelStackValue<'a> { - type Error = CelError; - fn try_into(self) -> Result { - if let CelStackValue::Value(val) = self { - Ok(val) - } else { - Err(CelError::internal("Expected value 2")) - } - } -} +pub use bytecode::{ByteCode, JmpWhen}; +pub use celstackvalue::CelStackValue; +pub use rscallable::RsCallable; diff --git a/rscel/src/interp/types/bytecode.rs b/rscel/src/interp/types/bytecode.rs new file mode 100644 index 0000000..f545ede --- /dev/null +++ b/rscel/src/interp/types/bytecode.rs @@ -0,0 +1,82 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +use crate::CelValue; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub enum JmpWhen { + True, + False, +} + +#[derive(Clone, PartialEq, Serialize, Deserialize)] +pub enum ByteCode { + Push(CelValue), + Or, + And, + Not, + Neg, + Add, + Sub, + Mul, + Div, + Mod, + Lt, + Le, + Eq, + Ne, + Ge, + Gt, + In, + Jmp(i32), + JmpCond { + when: JmpWhen, + dist: i32, + leave_val: bool, + }, + MkList(u32), + MkDict(u32), + Index, + Access, + Call(u32), + FmtString(u32), +} + +impl fmt::Debug for ByteCode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ByteCode::*; + + match self { + Push(val) => write!(f, "PUSH {:?}", val), + Or => write!(f, "OR"), + And => write!(f, "AND"), + Not => write!(f, "NOT"), + Neg => write!(f, "NEG"), + Add => write!(f, "ADD"), + Sub => write!(f, "SUB"), + Mul => write!(f, "MUL"), + Div => write!(f, "DIV"), + Mod => write!(f, "MOD"), + Lt => write!(f, "LT"), + Le => write!(f, "LE"), + Eq => write!(f, "EQ"), + Ne => write!(f, "NE"), + Ge => write!(f, "GE"), + Gt => write!(f, "GT"), + In => write!(f, "IN"), + Jmp(dist) => write!(f, "JMP {}", dist), + JmpCond { + when, + dist, + leave_val: _, + } => write!(f, "JMP {:?} {}", when, dist), + MkList(size) => write!(f, "MKLIST {}", size), + MkDict(size) => write!(f, "MKDICT {}", size), + Index => write!(f, "INDEX"), + Access => write!(f, "ACCESS"), + Call(size) => write!(f, "CALL {}", size), + FmtString(size) => write!(f, "FMT {}", size), + } + } +} diff --git a/rscel/src/interp/types/celstackvalue.rs b/rscel/src/interp/types/celstackvalue.rs new file mode 100644 index 0000000..d677842 --- /dev/null +++ b/rscel/src/interp/types/celstackvalue.rs @@ -0,0 +1,52 @@ +use crate::{CelError, CelResult, CelValue}; + +use super::RsCallable; + +#[derive(Debug, Clone)] +pub enum CelStackValue<'a> { + Value(CelValue), + BoundCall { + callable: RsCallable<'a>, + value: CelValue, + }, +} + +impl<'a> CelStackValue<'a> { + pub fn into_value(self) -> CelResult { + match self { + CelStackValue::Value(val) => Ok(val), + _ => Err(CelError::internal("Expected value")), + } + } + + pub fn as_value(&'a self) -> CelResult<&'a CelValue> { + match self { + CelStackValue::Value(val) => Ok(val), + _ => Err(CelError::internal("Expected value")), + } + } + + pub fn as_bound_call(&'a self) -> Option<(&'a RsCallable<'a>, &'a CelValue)> { + match self { + CelStackValue::BoundCall { callable, value } => Some((callable, value)), + _ => None, + } + } +} + +impl<'a> Into> for CelValue { + fn into(self) -> CelStackValue<'a> { + CelStackValue::Value(self) + } +} + +impl<'a> TryInto for CelStackValue<'a> { + type Error = CelError; + fn try_into(self) -> Result { + if let CelStackValue::Value(val) = self { + Ok(val) + } else { + Err(CelError::internal("Expected value 2")) + } + } +} diff --git a/rscel/src/interp/types/rscallable.rs b/rscel/src/interp/types/rscallable.rs new file mode 100644 index 0000000..117eba3 --- /dev/null +++ b/rscel/src/interp/types/rscallable.rs @@ -0,0 +1,26 @@ +use std::fmt; + +use crate::{RsCelFunction, RsCelMacro}; + +/// Wrapper enum that contains either an RsCelCallable or an RsCelFunction. Used +/// as a ValueCell value. +#[derive(Clone)] +pub enum RsCallable<'a> { + Function(&'a RsCelFunction), + Macro(&'a RsCelMacro), +} + +impl<'a> fmt::Debug for RsCallable<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Function(_) => write!(f, "Function"), + Self::Macro(_) => write!(f, "Macro"), + } + } +} + +impl<'a> PartialEq for RsCallable<'a> { + fn eq(&self, _other: &Self) -> bool { + false + } +} From 28b561ce751eb12f5a95992503365cc558c0a199 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 26 Jan 2025 13:57:13 -0800 Subject: [PATCH 04/16] add labels to turnary --- rscel/src/compiler/compiled_prog.rs | 92 +++-------------------------- rscel/src/compiler/compiler.rs | 92 ++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 85 deletions(-) diff --git a/rscel/src/compiler/compiled_prog.rs b/rscel/src/compiler/compiled_prog.rs index a9ff3b0..40da0b6 100644 --- a/rscel/src/compiler/compiled_prog.rs +++ b/rscel/src/compiler/compiled_prog.rs @@ -1,9 +1,6 @@ mod preresolved; -use crate::{ - interp::JmpWhen, program::ProgramDetails, types::CelByteCode, ByteCode, CelError, CelValue, - CelValueDyn, Program, -}; +use crate::{program::ProgramDetails, types::CelByteCode, ByteCode, CelValue, Program}; pub use preresolved::{PreResolvedByteCode, PreResolvedCodePoint}; #[derive(Debug, Clone)] @@ -91,6 +88,14 @@ impl CompiledProg { } } + pub fn details(&self) -> &ProgramDetails { + &self.details + } + + pub fn into_parts(self) -> (NodeValue, ProgramDetails) { + (self.inner, self.details) + } + pub fn append_if_bytecode(&mut self, b: impl IntoIterator) { match &mut self.inner { NodeValue::Bytecode(bytecode) => { @@ -219,85 +224,6 @@ impl CompiledProg { r } - pub fn into_turnary( - mut self, - true_clause: CompiledProg, - false_clause: CompiledProg, - ) -> CompiledProg { - self.details.union_from(true_clause.details); - self.details.union_from(false_clause.details); - - if let NodeValue::ConstExpr(i) = self.inner { - if i.is_err() { - CompiledProg { - inner: NodeValue::ConstExpr(i), - details: self.details, - } - } else { - if cfg!(feature = "type_prop") { - if i.is_truthy() { - CompiledProg { - inner: true_clause.inner, - details: self.details, - } - } else { - CompiledProg { - inner: false_clause.inner, - details: self.details, - } - } - } else { - if let CelValue::Bool(b) = i { - if b { - CompiledProg { - inner: true_clause.inner, - details: self.details, - } - } else { - CompiledProg { - inner: false_clause.inner, - details: self.details, - } - } - } else { - CompiledProg { - inner: NodeValue::ConstExpr(CelValue::from_err(CelError::Value( - format!("{} cannot be converted to bool", i.as_type()), - ))), - details: self.details, - } - } - } - } - } else { - let true_clause_bytecode = true_clause.inner.into_bytecode(); - let false_clause_bytecode = false_clause.inner.into_bytecode(); - CompiledProg { - inner: NodeValue::Bytecode( - self.inner - .into_bytecode() - .into_iter() - .chain( - [PreResolvedCodePoint::Bytecode(ByteCode::JmpCond { - when: JmpWhen::False, - dist: i32::try_from(true_clause_bytecode.len() + 1) - .expect("Jump distance too far"), - leave_val: false, - })] - .into_iter(), - ) - .chain(true_clause_bytecode.into_iter()) - .chain( - [ByteCode::Jmp(false_clause_bytecode.len() as i32).into()].into_iter(), - ) - .chain(false_clause_bytecode.into_iter()) - .collect(), - ), - details: self.details, - } - } - } - #[inline] pub fn bytecode_len(&self) -> usize { match self.inner { diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index b7689a7..44d44b1 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::{ ast_node::AstNode, - compiled_prog::{CompiledProg, PreResolvedCodePoint}, + compiled_prog::{CompiledProg, NodeValue, PreResolvedCodePoint}, grammar::*, source_range::SourceRange, syntax_error::SyntaxError, @@ -64,7 +64,10 @@ impl<'l> CelCompiler<'l> { match self.tokenizer.peek()?.as_token() { Some(Token::Question) => { self.tokenizer.next()?; + let (expr_node, mut details) = lhs_node.into_parts(); + let (true_clause_node, true_clause_ast) = self.parse_conditional_or()?; + let (true_clause_node, true_clause_details) = true_clause_node.into_parts(); let next = self.tokenizer.next()?; if next.as_token() != Some(&Token::Colon) { @@ -74,11 +77,96 @@ impl<'l> CelCompiler<'l> { } let (false_clause_node, false_clause_ast) = self.parse_expression()?; + let (false_clause_node, false_clause_details) = false_clause_node.into_parts(); let range = lhs_ast.range().surrounding(false_clause_ast.range()); + details.union_from(true_clause_details); + details.union_from(false_clause_details); + + let turnary_node = if let NodeValue::ConstExpr(i) = expr_node { + if i.is_err() { + CompiledProg { + inner: NodeValue::ConstExpr(i), + details, + } + } else { + if cfg!(feature = "type_prop") { + if i.is_truthy() { + CompiledProg { + inner: true_clause_node, + details, + } + } else { + CompiledProg { + inner: false_clause_node, + details, + } + } + } else { + if let CelValue::Bool(b) = i { + if b { + CompiledProg { + inner: true_clause_node, + details, + } + } else { + CompiledProg { + inner: false_clause_node, + details, + } + } + } else { + CompiledProg { + inner: NodeValue::ConstExpr(CelValue::from_err( + CelError::Value(format!( + "{} cannot be converted to bool", + i.as_type() + )), + )), + details, + } + } + } + } + } else { + let true_clause_bytecode = true_clause_node.into_bytecode(); + let false_clause_bytecode = false_clause_node.into_bytecode(); + + let after_true_clause = self.get_label(); + let end_label = self.get_label(); + + CompiledProg { + inner: NodeValue::Bytecode( + expr_node + .into_bytecode() + .into_iter() + .chain( + [PreResolvedCodePoint::JmpCond { + when: JmpWhen::False, + label: after_true_clause, + leave_val: false, + }] + .into_iter(), + ) + .chain(true_clause_bytecode.into_iter()) + .chain( + [ + PreResolvedCodePoint::Jmp { label: end_label }, + PreResolvedCodePoint::Label(after_true_clause), + ] + .into_iter(), + ) + .chain(false_clause_bytecode.into_iter()) + .chain([PreResolvedCodePoint::Label(end_label)].into_iter()) + .collect(), + ), + details, + } + }; + Ok(( - lhs_node.into_turnary(true_clause_node, false_clause_node), + turnary_node, AstNode::new( Expr::Ternary { condition: Box::new(lhs_ast), From 4bef7a661479cb621751e0727c9591652fe0b8c6 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sat, 15 Feb 2025 21:41:57 -0800 Subject: [PATCH 05/16] Removed the leave_val argument in jmpwhen --- .../src/compiler/compiled_prog/preresolved.rs | 17 +++--------- rscel/src/compiler/compiler.rs | 22 +++++++++------- rscel/src/interp/interp.rs | 26 ++++++++++--------- rscel/src/interp/types/bytecode.rs | 16 +++++------- 4 files changed, 35 insertions(+), 46 deletions(-) diff --git a/rscel/src/compiler/compiled_prog/preresolved.rs b/rscel/src/compiler/compiled_prog/preresolved.rs index b05cca7..d57dffe 100644 --- a/rscel/src/compiler/compiled_prog/preresolved.rs +++ b/rscel/src/compiler/compiled_prog/preresolved.rs @@ -5,14 +5,8 @@ use crate::{interp::JmpWhen, types::CelByteCode, ByteCode}; #[derive(Debug, Clone)] pub enum PreResolvedCodePoint { Bytecode(ByteCode), - Jmp { - label: u32, - }, - JmpCond { - when: JmpWhen, - label: u32, - leave_val: bool, - }, + Jmp { label: u32 }, + JmpCond { when: JmpWhen, label: u32 }, Label(u32), } @@ -98,18 +92,13 @@ impl PreResolvedByteCode { i32::try_from(offset).expect("Attempt to jump farther than possible"), )); } - PreResolvedCodePoint::JmpCond { - when, - label, - leave_val, - } => { + PreResolvedCodePoint::JmpCond { when, label } => { curr_loc += 1; let jmp_loc = locations[&label]; let offset = (jmp_loc as isize) - (curr_loc as isize); ret.push(ByteCode::JmpCond { when, dist: offset as i32, - leave_val, }); } PreResolvedCodePoint::Label(_) => {} diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index 44d44b1..78d892c 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -145,7 +145,6 @@ impl<'l> CelCompiler<'l> { [PreResolvedCodePoint::JmpCond { when: JmpWhen::False, label: after_true_clause, - leave_val: false, }] .into_iter(), ) @@ -197,13 +196,14 @@ impl<'l> CelCompiler<'l> { self.tokenizer.next()?; let (rhs_node, rhs_ast) = self.parse_conditional_and()?; - let jmp_node = - CompiledProg::with_code_points(vec![PreResolvedCodePoint::JmpCond { + let jmp_node = CompiledProg::with_code_points(vec![ + PreResolvedCodePoint::Bytecode(ByteCode::Test), + PreResolvedCodePoint::Bytecode(ByteCode::Dup), + PreResolvedCodePoint::JmpCond { when: JmpWhen::True, label, - leave_val: true, - } - .into()]); + }, + ]); let range = current_ast.range().surrounding(rhs_ast.range()); @@ -241,12 +241,14 @@ impl<'l> CelCompiler<'l> { self.tokenizer.next()?; let (rhs_node, rhs_ast) = self.parse_relation()?; - let jmp_node = - CompiledProg::with_code_points(vec![PreResolvedCodePoint::JmpCond { + let jmp_node = CompiledProg::with_code_points(vec![ + PreResolvedCodePoint::Bytecode(ByteCode::Test), + PreResolvedCodePoint::Bytecode(ByteCode::Dup), + PreResolvedCodePoint::JmpCond { when: JmpWhen::False, label: label, - leave_val: true, - }]); + }, + ]); let range = current_ast.range().surrounding(rhs_ast.range()); diff --git a/rscel/src/interp/interp.rs b/rscel/src/interp/interp.rs index 44461da..ec946cc 100644 --- a/rscel/src/interp/interp.rs +++ b/rscel/src/interp/interp.rs @@ -154,7 +154,18 @@ impl<'a> Interpreter<'a> { let oldpc = pc; pc += 1; match &prog[oldpc] { - ByteCode::Push(val) => stack.push(val.clone().into()), + ByteCode::Push(val) => stack.push_val(val.clone()), + ByteCode::Test => { + let v = stack.pop_val()?; + + stack.push_val(v.is_truthy().into()); + } + ByteCode::Dup => { + let v = stack.pop_val()?; + + stack.push_val(v.clone()); + stack.push_val(v); + } ByteCode::Or => { let v2 = stack.pop_val()?; let v1 = stack.pop_val()?; @@ -250,17 +261,12 @@ impl<'a> Interpreter<'a> { stack.push_val(lhs.in_(rhs)); } ByteCode::Jmp(dist) => pc = pc + *dist as usize, - ByteCode::JmpCond { - when, - dist, - leave_val, - } => { - let mut v1 = stack.pop_val()?; + ByteCode::JmpCond { when, dist } => { + let v1 = stack.pop_val()?; match when { JmpWhen::True => { if cfg!(feature = "type_prop") { if v1.is_truthy() { - v1 = CelValue::true_(); pc += *dist as usize } } else if let CelValue::Err(ref _e) = v1 { @@ -279,7 +285,6 @@ impl<'a> Interpreter<'a> { JmpWhen::False => { if cfg!(feature = "type_prop") { if !v1.is_truthy() { - v1 = CelValue::false_(); pc += *dist as usize } } else if let CelValue::Bool(v) = v1 { @@ -294,9 +299,6 @@ impl<'a> Interpreter<'a> { } } }; - if *leave_val { - stack.push_val(v1); - } } ByteCode::MkList(size) => { let mut v = Vec::new(); diff --git a/rscel/src/interp/types/bytecode.rs b/rscel/src/interp/types/bytecode.rs index f545ede..eb5a2de 100644 --- a/rscel/src/interp/types/bytecode.rs +++ b/rscel/src/interp/types/bytecode.rs @@ -13,6 +13,8 @@ pub enum JmpWhen { #[derive(Clone, PartialEq, Serialize, Deserialize)] pub enum ByteCode { Push(CelValue), + Test, + Dup, Or, And, Not, @@ -30,11 +32,7 @@ pub enum ByteCode { Gt, In, Jmp(i32), - JmpCond { - when: JmpWhen, - dist: i32, - leave_val: bool, - }, + JmpCond { when: JmpWhen, dist: i32 }, MkList(u32), MkDict(u32), Index, @@ -49,6 +47,8 @@ impl fmt::Debug for ByteCode { match self { Push(val) => write!(f, "PUSH {:?}", val), + Test => write!(f, "TEST"), + Dup => write!(f, "DUP"), Or => write!(f, "OR"), And => write!(f, "AND"), Not => write!(f, "NOT"), @@ -66,11 +66,7 @@ impl fmt::Debug for ByteCode { Gt => write!(f, "GT"), In => write!(f, "IN"), Jmp(dist) => write!(f, "JMP {}", dist), - JmpCond { - when, - dist, - leave_val: _, - } => write!(f, "JMP {:?} {}", when, dist), + JmpCond { when, dist } => write!(f, "JMP {:?} {}", when, dist), MkList(size) => write!(f, "MKLIST {}", size), MkDict(size) => write!(f, "MKDICT {}", size), Index => write!(f, "INDEX"), From 43e0f79fc68877ad5cccc7be8bc57413dfdbf145 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 16 Feb 2025 07:56:03 -0800 Subject: [PATCH 06/16] Fixup the test --- Cargo.toml | 4 +- python/Cargo.toml | 4 +- rscel-macro/Cargo.toml | 6 +-- rscel/Cargo.toml | 18 ++++----- rscel/src/interp/interp.rs | 59 +++++++++++++----------------- rscel/src/interp/types/bytecode.rs | 9 +++++ wasm/Cargo.toml | 6 +-- 7 files changed, 54 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 17d8540..3bfe572 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,5 +18,5 @@ lto = false lto = true [workspace.dependencies] -chrono = { version = "0.4.38", features = ["serde"] } -serde_json = { version = "1.0.121", features = ["raw_value"] } +chrono = { version = "0.4.39", features = ["serde"] } +serde_json = { version = "1.0.138", features = ["raw_value"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 8d5da7d..a5b13f1 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -11,7 +11,7 @@ crate-type = ["cdylib"] [dependencies] rscel = { path = "../rscel" } -pyo3 = { version = "0.23", features = [ +pyo3 = { version = "0.23.4", features = [ "py-clone", "extension-module", "chrono", @@ -21,4 +21,4 @@ serde_json = { workspace = true } bincode = "1.3.3" [build-dependencies] -pyo3-build-config = "0.23" +pyo3-build-config = "0.23.4" diff --git a/rscel-macro/Cargo.toml b/rscel-macro/Cargo.toml index 9aa118e..098ec3d 100644 --- a/rscel-macro/Cargo.toml +++ b/rscel-macro/Cargo.toml @@ -10,6 +10,6 @@ readme = "../README.md" proc-macro = true [dependencies] -proc-macro2 = "1.0.92" -quote = "1.0.37" -syn = { version = "2.0.91", features = ["full"] } +proc-macro2 = "1.0.93" +quote = "1.0.38" +syn = { version = "2.0.98", features = ["full"] } diff --git a/rscel/Cargo.toml b/rscel/Cargo.toml index 9e37e24..c7c27c6 100644 --- a/rscel/Cargo.toml +++ b/rscel/Cargo.toml @@ -14,18 +14,18 @@ type_prop = [] protobuf = ["dep:protobuf"] [build-dependencies] -protobuf-codegen = { version = "3.4.0" } -protoc-bin-vendored = { version = "3.0.0" } +protobuf-codegen = { version = "3.7.1" } +protoc-bin-vendored = { version = "3.1.0" } [dependencies] -rscel-macro = { path = "../rscel-macro", version = "1.0.4" } +rscel-macro = { path = "../rscel-macro" } test-case = "3.3.1" -regex = "1.10.5" -serde = { version = "1.0.204", features = ["derive", "rc"] } -serde_with = { version = "3.9.0", features = ["chrono"] } +regex = "1.11.1" +serde = { version = "1.0.217", features = ["derive", "rc"] } +serde_with = { version = "3.12.0", features = ["chrono"] } serde_json = { workspace = true } chrono = { workspace = true } -duration-str = "0.11.2" -protobuf = { version = "3.5.0", optional = true } -chrono-tz = "0.9.0" +duration-str = "0.13.0" +protobuf = { version = "3.7.1", optional = true } +chrono-tz = "0.10.1" num-traits = "0.2.19" diff --git a/rscel/src/interp/interp.rs b/rscel/src/interp/interp.rs index ec946cc..7734019 100644 --- a/rscel/src/interp/interp.rs +++ b/rscel/src/interp/interp.rs @@ -158,7 +158,18 @@ impl<'a> Interpreter<'a> { ByteCode::Test => { let v = stack.pop_val()?; - stack.push_val(v.is_truthy().into()); + if v.is_err() { + stack.push_val(v); + } else if cfg!(feature = "type_prop") { + stack.push_val(v.is_truthy().into()); + } else if let CelValue::Bool(b) = v { + stack.push_val(b.into()) + } else { + return Err(CelError::invalid_op(&format!( + "TEST invalid on type {:?}", + v.as_type() + ))); + } } ByteCode::Dup => { let v = stack.pop_val()?; @@ -263,42 +274,24 @@ impl<'a> Interpreter<'a> { ByteCode::Jmp(dist) => pc = pc + *dist as usize, ByteCode::JmpCond { when, dist } => { let v1 = stack.pop_val()?; - match when { - JmpWhen::True => { - if cfg!(feature = "type_prop") { - if v1.is_truthy() { - pc += *dist as usize - } - } else if let CelValue::Err(ref _e) = v1 { - // do nothing - } else if let CelValue::Bool(v) = v1 { - if v { - pc += *dist as usize - } - } else { - return Err(CelError::invalid_op(&format!( - "JMP TRUE invalid on type {:?}", - v1.as_type() - ))); + match v1 { + CelValue::Bool(b) => { + if b == when.as_bool() { + pc += *dist as usize } } - JmpWhen::False => { - if cfg!(feature = "type_prop") { - if !v1.is_truthy() { - pc += *dist as usize - } - } else if let CelValue::Bool(v) = v1 { - if !v { - pc += *dist as usize - } - } else { - return Err(CelError::invalid_op(&format!( - "JMP FALSE invalid on type {:?}", - v1.as_type() - ))); + CelValue::Err(_) => { + if *when == JmpWhen::False { + pc += *dist as usize } } - }; + v => { + return Err(CelError::invalid_op(&format!( + "JMP TRUE invalid on type {:?}", + v.as_type() + ))) + } + } } ByteCode::MkList(size) => { let mut v = Vec::new(); diff --git a/rscel/src/interp/types/bytecode.rs b/rscel/src/interp/types/bytecode.rs index eb5a2de..76c940c 100644 --- a/rscel/src/interp/types/bytecode.rs +++ b/rscel/src/interp/types/bytecode.rs @@ -10,6 +10,15 @@ pub enum JmpWhen { False, } +impl JmpWhen { + pub fn as_bool(&self) -> bool { + match self { + JmpWhen::True => true, + JmpWhen::False => false, + } + } +} + #[derive(Clone, PartialEq, Serialize, Deserialize)] pub enum ByteCode { Push(CelValue), diff --git a/wasm/Cargo.toml b/wasm/Cargo.toml index 8b5b368..1129217 100644 --- a/wasm/Cargo.toml +++ b/wasm/Cargo.toml @@ -13,9 +13,9 @@ console_error_panic_hook = ["dep:console_error_panic_hook"] [dependencies] rscel = { path = "../rscel" } -num = "0.4.2" -wasm-bindgen = "0.2.92" +num = "0.4.3" +wasm-bindgen = "0.2.100" console_error_panic_hook = {version = "0.1.7", optional = true} serde-wasm-bindgen = "0.6.5" -js-sys = "0.3.69" +js-sys = "0.3.77" chrono = { workspace = true } From b740fb9906913bea097e9feb214fb28c6b3adbef Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 16 Feb 2025 21:20:48 -0800 Subject: [PATCH 07/16] wip --- rscel/src/compiler/string_tokenizer.rs | 2 ++ rscel/src/compiler/tokens.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/rscel/src/compiler/string_tokenizer.rs b/rscel/src/compiler/string_tokenizer.rs index 6911d2e..57dc320 100644 --- a/rscel/src/compiler/string_tokenizer.rs +++ b/rscel/src/compiler/string_tokenizer.rs @@ -126,6 +126,7 @@ impl<'l> StringTokenizer<'l> { self.parse_keywords_or_ident("b", &[]) } } + 'c' => self.parse_keywords_or_ident("c", &[("case", Token::Case)]), 'f' => { if let Some('\'') = self.scanner.peek() { self.scanner.next(); @@ -138,6 +139,7 @@ impl<'l> StringTokenizer<'l> { } } 'i' => self.parse_keywords_or_ident("i", &[("in", Token::In)]), + 'm' => self.parse_keywords_or_ident("m", &[("match", Token::Match)]), 'n' => self.parse_keywords_or_ident("n", &[("null", Token::Null)]), 'r' => { if let Some('\'') = self.scanner.peek() { diff --git a/rscel/src/compiler/tokens.rs b/rscel/src/compiler/tokens.rs index 5c85b98..df75439 100644 --- a/rscel/src/compiler/tokens.rs +++ b/rscel/src/compiler/tokens.rs @@ -31,6 +31,8 @@ pub enum Token { NotEqual, // != In, // 'in' Null, // 'null' + Match, // 'match' + Case, // 'case' BoolLit(bool), // true | false IntLit(u64), // [-+]?[0-9]+ UIntLit(u64), // [0-9]+u From 471195cb283b59d8facb3f55b2e7f1d74b522d5c Mon Sep 17 00:00:00 2001 From: matt Date: Fri, 21 Mar 2025 21:13:32 -0700 Subject: [PATCH 08/16] wip: starting pattern gen --- rscel/src/compiler/compiler.rs | 71 +++++++++++++++++++++++++++----- rscel/src/compiler/tokens.rs | 1 + rscel/src/tests/general_tests.rs | 2 +- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index 87271d1..2b7ba3b 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -110,9 +110,8 @@ impl<'l> CelCompiler<'l> { let mut range = condition_ast.range(); - let (node_bytecode, mut node_details) = condition_node.into_parts(); - - let mut node_bytecode = node_bytecode.into_bytecode(); + let (mut node_value, mut node_details) = condition_node.into_parts(); + let mut node_bytecode = node_value.into_bytecode(); let next = self.tokenizer.next()?; if next.as_token() != Some(&Token::LBrace) { @@ -125,19 +124,31 @@ impl<'l> CelCompiler<'l> { let mut all_parts = Vec::new(); + let mut comma_seen = true; + loop { - let lbrace = self.tokenizer.peek()?; - if lbrace.as_token() != Some(&Token::LBrace) { - range = range.surrounding(lbrace.unwrap().loc); + // the rbrace at the end of the match + let rbrace = self.tokenizer.peek()?; + if rbrace.as_token() == Some(&Token::RBrace) { + range = range.surrounding(rbrace.unwrap().loc); break; } + if !comma_seen { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Expected COMMA")) + .into()); + } + comma_seen = false; + + // case let case_token = self.tokenizer.next()?; if case_token.as_token() != Some(&Token::Case) { return Err(SyntaxError::from_location(self.tokenizer.location()) .with_message(format!("Unexpected token {:?}, expected CASE", next)) .into()); } + //pattern let (pattern_prog, pattern_ast) = self.parse_match_pattern()?; let (pattern_bytecode, pattern_details) = pattern_prog.into_parts(); let pattern_bytecode: Vec<_> = [ByteCode::Dup] @@ -149,6 +160,7 @@ impl<'l> CelCompiler<'l> { let pattern_range = pattern_ast.range(); + // colon after pattern let colon_token = self.tokenizer.next()?; if colon_token.as_token() != Some(&Token::Colon) { return Err(SyntaxError::from_location(self.tokenizer.location()) @@ -156,6 +168,7 @@ impl<'l> CelCompiler<'l> { .into()); } + // eval expression let (expr_prog, expr_ast) = self.parse_expression()?; let (expr_bytecode, expr_details) = expr_prog.into_parts(); let expr_bytecode: Vec<_> = [ByteCode::Pop] @@ -175,12 +188,23 @@ impl<'l> CelCompiler<'l> { }, case_range, )); + // + // comma after pattern + let comma_token = self.tokenizer.peek()?; + if comma_token.as_token() == Some(&Token::Comma) { + comma_seen = true; + self.tokenizer.next()?; + } } - let mut pattern_segment = CelByteCode::new(); - let mut expr_segment = CelByteCode::new(); + let after_match_l = self + + for (pattern_bytecode, expr_bytecode) in all_parts.into_iter().rev() { + + node_bytecode.push(ByteCode::Dup); - for (pattern_bytecode, expr_bytecode) in all_parts.into_iter().rev() {} + + } Ok(( CompiledProg::new(NodeValue::Bytecode(node_bytecode), node_details), @@ -195,7 +219,34 @@ impl<'l> CelCompiler<'l> { } fn parse_match_pattern(&mut self) -> CelResult<(CompiledProg, AstNode)> { - todo!() + let start = self.tokenizer.location(); + + if let Some(t) = self.tokenizer.next()? { + match t.token { + Token::UnderScore => { + let range = SourceRange::new(start, self.tokenizer.location()); + return Ok(( + CompiledProg::with_bytecode(CelByteCode::from_vec(vec![ + ByteCode::Pop, // pop off the pattern value + ByteCode::Push(CelValue::true_()), // push true + ])), + AstNode::new( + MatchPattern::Any(AstNode::new(MatchAnyPattern {}, range)), + range, + ), + )); + } + other => { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Expected PATTERN got {:?}", other)) + .into()) + } + } + } else { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message(format!("Expected PATTERN got NOTHING")) + .into()); + } } fn parse_conditional_or(&mut self) -> CelResult<(CompiledProg, AstNode)> { diff --git a/rscel/src/compiler/tokens.rs b/rscel/src/compiler/tokens.rs index df75439..f9db0fe 100644 --- a/rscel/src/compiler/tokens.rs +++ b/rscel/src/compiler/tokens.rs @@ -23,6 +23,7 @@ pub enum Token { RParen, // ) LessThan, // < GreaterThan, // > + UnderScore, // _ OrOr, // || AndAnd, // && LessEqual, // <= diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index aade850..4cac71a 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -185,7 +185,7 @@ fn test_contains() { #[test_case("duration('1s234ms').getMilliseconds()", 234; "duration.getMilliseconds")] #[test_case("duration('1h30m').getMinutes()", 90; "duration.getMinutes")] #[test_case("duration('1m30s').getSeconds()", 90; "duration.getSeconds")] -#[test_case("match 3 { case int: true, _: false}", false; "match int" )] +#[test_case("match 3 { case int: true, _: false}", true; "match int" )] fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); From d0a45c4bbcc8eb61642d8a55607da04629b94ee8 Mon Sep 17 00:00:00 2001 From: matt Date: Sat, 22 Mar 2025 21:15:04 -0700 Subject: [PATCH 09/16] Have basic type matching working --- rscel/src/compiler/compiled_prog.rs | 8 --- rscel/src/compiler/compiler.rs | 95 ++++++++++++++++++-------- rscel/src/compiler/grammar.rs | 19 ++++++ rscel/src/compiler/tokens.rs | 1 - rscel/src/interp/interp.rs | 100 ++++++++++++++++------------ rscel/src/tests/general_tests.rs | 6 +- 6 files changed, 149 insertions(+), 80 deletions(-) diff --git a/rscel/src/compiler/compiled_prog.rs b/rscel/src/compiler/compiled_prog.rs index 080706e..89cd25e 100644 --- a/rscel/src/compiler/compiled_prog.rs +++ b/rscel/src/compiler/compiled_prog.rs @@ -78,10 +78,6 @@ impl CompiledProg { } } - pub fn details(&self) -> &ProgramDetails { - &self.details - } - pub fn with_bytecode(bytecode: CelByteCode) -> CompiledProg { CompiledProg { inner: NodeValue::Bytecode(bytecode.into()), @@ -244,10 +240,6 @@ impl CompiledProg { self.inner.into_bytecode() } - pub fn into_parts(self) -> (NodeValue, ProgramDetails) { - (self.inner, self.details) - } - pub fn is_const(&self) -> bool { self.inner.is_const() } diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index b92543b..1d43f37 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -12,8 +12,7 @@ use super::{ use crate::{ interp::{Interpreter, JmpWhen}, types::CelByteCode, - BindContext, ByteCode, CelError, CelResult, CelValue, CelValueDyn, Program, ProgramDetails, - StringTokenizer, + BindContext, ByteCode, CelError, CelResult, CelValue, CelValueDyn, Program, StringTokenizer, }; use crate::compile; @@ -54,7 +53,7 @@ impl<'l> CelCompiler<'l> { Ok(prog) } - fn get_label(&mut self) -> u32 { + fn new_label(&mut self) -> u32 { let n = self.next_label; self.next_label += 1; n @@ -154,8 +153,8 @@ impl<'l> CelCompiler<'l> { let true_clause_bytecode = true_clause_node.into_bytecode(); let false_clause_bytecode = false_clause_node.into_bytecode(); - let after_true_clause = self.get_label(); - let end_label = self.get_label(); + let after_true_clause = self.new_label(); + let end_label = self.new_label(); CompiledProg { inner: NodeValue::Bytecode( @@ -203,7 +202,7 @@ impl<'l> CelCompiler<'l> { let mut range = condition_ast.range(); - let (mut node_value, mut node_details) = condition_node.into_parts(); + let (node_value, mut node_details) = condition_node.into_parts(); let mut node_bytecode = node_value.into_bytecode(); let next = self.tokenizer.next()?; @@ -244,10 +243,7 @@ impl<'l> CelCompiler<'l> { //pattern let (pattern_prog, pattern_ast) = self.parse_match_pattern()?; let (pattern_bytecode, pattern_details) = pattern_prog.into_parts(); - let pattern_bytecode: Vec<_> = [ByteCode::Dup.into()] - .into_iter() - .chain(pattern_bytecode.into_bytecode().into_iter()) - .collect(); + let pattern_bytecode = pattern_bytecode.into_bytecode(); node_details.union_from(pattern_details); @@ -290,12 +286,31 @@ impl<'l> CelCompiler<'l> { } } - let after_match_l = self.get_label(); + // consume the RBRACE + self.tokenizer.next()?; + + // After match expression label + let after_match_s_l = self.new_label(); + + for (pattern_bytecode, expr_bytecode) in all_parts.into_iter() { + let after_case_l = self.new_label(); - for (pattern_bytecode, expr_bytecode) in all_parts.into_iter().rev() { node_bytecode.push(ByteCode::Dup); + node_bytecode.extend(pattern_bytecode.into_iter()); + node_bytecode.push(PreResolvedCodePoint::JmpCond { + when: JmpWhen::False, + label: after_case_l, + }); + + node_bytecode.extend(expr_bytecode); + node_bytecode.push(PreResolvedCodePoint::Jmp { + label: after_match_s_l, + }); + node_bytecode.push(PreResolvedCodePoint::Label(after_case_l)); } + node_bytecode.push(PreResolvedCodePoint::Label(after_match_s_l)); + Ok(( CompiledProg::new(NodeValue::Bytecode(node_bytecode), node_details), AstNode::new( @@ -313,18 +328,44 @@ impl<'l> CelCompiler<'l> { if let Some(t) = self.tokenizer.next()? { match t.token { - Token::UnderScore => { + Token::Ident(i) => { let range = SourceRange::new(start, self.tokenizer.location()); - return Ok(( - CompiledProg::with_bytecode(CelByteCode::from_vec(vec![ - ByteCode::Pop, // pop off the pattern value - ByteCode::Push(CelValue::true_()), // push true - ])), - AstNode::new( - MatchPattern::Any(AstNode::new(MatchAnyPattern {}, range)), - range, + let (bytecode_vec, pattern_type) = match i.as_str() { + "int" | "uint" | "float" | "double" | "string" | "bool" | "bytes" + | "list" | "object" | "null" | "timestamp" | "duration" => ( + vec![ + ByteCode::Push(CelValue::Ident("type".to_owned())), + ByteCode::Call(1), + ByteCode::Push(CelValue::Ident(i.clone())), + ByteCode::Eq, + ], + MatchPattern::Type(AstNode::new( + MatchTypePattern::from_type_str(&i), + range, + )), ), - )); + "_" => { + ( + vec![ + ByteCode::Pop, // pop off the pattern value + ByteCode::Push(CelValue::true_()), // push true + ], + MatchPattern::Any(AstNode::new(MatchAnyPattern {}, range)), + ) + } + _ => { + return Err(SyntaxError::from_location(self.tokenizer.location()) + .with_message( + "_ is the only identifier allowed in case expressions" + .to_owned(), + ) + .into()); + } + }; + Ok(( + CompiledProg::with_bytecode(CelByteCode::from_vec(bytecode_vec)), + AstNode::new(pattern_type, range), + )) } other => { return Err(SyntaxError::from_location(self.tokenizer.location()) @@ -342,7 +383,7 @@ impl<'l> CelCompiler<'l> { fn parse_conditional_or(&mut self) -> CelResult<(CompiledProg, AstNode)> { let (mut current_node, mut current_ast) = into_unary(self.parse_conditional_and()?); - let label = self.get_label(); + let label = self.new_label(); loop { if let Some(Token::OrOr) = self.tokenizer.peek()?.as_token() { @@ -387,7 +428,7 @@ impl<'l> CelCompiler<'l> { fn parse_conditional_and(&mut self) -> CelResult<(CompiledProg, AstNode)> { let (mut current_node, mut current_ast) = into_unary(self.parse_relation()?); - let label = self.get_label(); + let label = self.new_label(); loop { if let Some(Token::AndAnd) = self.tokenizer.peek()?.as_token() { @@ -921,8 +962,8 @@ impl<'l> CelCompiler<'l> { ])) } - member_prime_node = member_prime_node - .consume_child(args_node) + member_prime_node = args_node + .consume_child(member_prime_node) .consume_child(CompiledProg::with_code_points(vec![ByteCode::Call( args_len as u32, ) @@ -1213,7 +1254,6 @@ impl<'l> CelCompiler<'l> { let mut bytecode = Vec::::new(); for segment in segments.iter() { - bytecode.push(ByteCode::Push(CelValue::Ident("string".to_string())).into()); match segment { FStringSegment::Lit(c) => { bytecode.push(ByteCode::Push(CelValue::String(c.clone())).into()) @@ -1232,6 +1272,7 @@ impl<'l> CelCompiler<'l> { ); } } + bytecode.push(ByteCode::Push(CelValue::Ident("string".to_string())).into()); bytecode.push(ByteCode::Call(1).into()); } diff --git a/rscel/src/compiler/grammar.rs b/rscel/src/compiler/grammar.rs index 2876e3a..f958d50 100644 --- a/rscel/src/compiler/grammar.rs +++ b/rscel/src/compiler/grammar.rs @@ -58,6 +58,25 @@ pub enum MatchTypePattern { Duration, } +impl MatchTypePattern { + pub fn from_type_str(s: &str) -> Self { + match s { + "int" => MatchTypePattern::Int, + "uint" => MatchTypePattern::Uint, + "float" | "double" => MatchTypePattern::Float, + "string" => MatchTypePattern::String, + "bool" => MatchTypePattern::Bool, + "bytes" => MatchTypePattern::Bytes, + "list" => MatchTypePattern::List, + "object" => MatchTypePattern::Object, + "null" => MatchTypePattern::Null, + "timestamp" => MatchTypePattern::Timestamp, + "duration" => MatchTypePattern::Duration, + _ => panic!("Unknown type"), + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct MatchAnyPattern; diff --git a/rscel/src/compiler/tokens.rs b/rscel/src/compiler/tokens.rs index f9db0fe..df75439 100644 --- a/rscel/src/compiler/tokens.rs +++ b/rscel/src/compiler/tokens.rs @@ -23,7 +23,6 @@ pub enum Token { RParen, // ) LessThan, // < GreaterThan, // > - UnderScore, // _ OrOr, // || AndAnd, // && LessEqual, // <= diff --git a/rscel/src/interp/interp.rs b/rscel/src/interp/interp.rs index 7734019..182a79c 100644 --- a/rscel/src/interp/interp.rs +++ b/rscel/src/interp/interp.rs @@ -155,6 +155,9 @@ impl<'a> Interpreter<'a> { pc += 1; match &prog[oldpc] { ByteCode::Push(val) => stack.push_val(val.clone()), + ByteCode::Pop => { + stack.pop_val()?; + } ByteCode::Test => { let v = stack.pop_val()?; @@ -400,56 +403,67 @@ impl<'a> Interpreter<'a> { } } ByteCode::Call(n_args) => { - let mut args = Vec::new(); + match stack.pop_noresolve()? { + CelStackValue::BoundCall { callable, value } => { + let mut args = Vec::new(); - for _ in 0..*n_args { - args.push(stack.pop()?.into_value()?) - } + for _ in 0..*n_args { + args.push(stack.pop()?.into_value()?) + } - match stack.pop_noresolve()? { - CelStackValue::BoundCall { callable, value } => match callable { - RsCallable::Function(func) => { - let arg_values = self.resolve_args(args)?; - stack.push_val(func(value, arg_values)); + match callable { + RsCallable::Function(func) => { + let arg_values = self.resolve_args(args)?; + stack.push_val(func(value, arg_values)); + } + RsCallable::Macro(macro_) => { + stack.push_val(self.call_macro(&value, &args, macro_)?); + } } - RsCallable::Macro(macro_) => { - stack.push_val(self.call_macro(&value, &args, macro_)?); + } + CelStackValue::Value(value) => { + let mut args = Vec::new(); + + for _ in 0..*n_args { + args.push(stack.pop()?.into_value()?) } - }, - CelStackValue::Value(value) => match value { - CelValue::Ident(func_name) => { - if let Some(func) = self.get_func_by_name(&func_name) { - let arg_values = self.resolve_args(args)?; - stack.push_val(func(CelValue::from_null(), arg_values)); - } else if let Some(macro_) = self.get_macro_by_name(&func_name) { - stack.push_val(self.call_macro( - &CelValue::from_null(), - &args, - macro_, - )?); - } else if let Some(CelValue::Type(type_name)) = - self.get_type_by_name(&func_name) - { + + match value { + CelValue::Ident(func_name) => { + if let Some(func) = self.get_func_by_name(&func_name) { + let arg_values = self.resolve_args(args)?; + stack.push_val(func(CelValue::from_null(), arg_values)); + } else if let Some(macro_) = self.get_macro_by_name(&func_name) + { + stack.push_val(self.call_macro( + &CelValue::from_null(), + &args, + macro_, + )?); + } else if let Some(CelValue::Type(type_name)) = + self.get_type_by_name(&func_name) + { + let arg_values = self.resolve_args(args)?; + stack.push_val(construct_type(type_name, arg_values)); + } else { + stack.push_val(CelValue::from_err(CelError::runtime( + &format!("{} is not callable", func_name), + ))); + } + } + CelValue::Type(type_name) => { let arg_values = self.resolve_args(args)?; - stack.push_val(construct_type(type_name, arg_values)); - } else { - stack.push_val(CelValue::from_err(CelError::runtime( - &format!("{} is not callable", func_name), - ))); + stack.push_val(construct_type(&type_name, arg_values)); } + other => stack.push_val( + CelValue::from_err(CelError::runtime(&format!( + "{:?} cannot be called", + other + ))) + .into(), + ), } - CelValue::Type(type_name) => { - let arg_values = self.resolve_args(args)?; - stack.push_val(construct_type(&type_name, arg_values)); - } - other => stack.push_val( - CelValue::from_err(CelError::runtime(&format!( - "{:?} cannot be called", - other - ))) - .into(), - ), - }, + } }; } ByteCode::FmtString(nsegments) => { diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index 4cac71a..f0c2219 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -185,7 +185,11 @@ fn test_contains() { #[test_case("duration('1s234ms').getMilliseconds()", 234; "duration.getMilliseconds")] #[test_case("duration('1h30m').getMinutes()", 90; "duration.getMinutes")] #[test_case("duration('1m30s').getSeconds()", 90; "duration.getSeconds")] -#[test_case("match 3 { case int: true, _: false}", true; "match int" )] +#[test_case("match 'foo' {case int: false, case _: true}", true; "match else")] +#[test_case("match 3 { case int: true, case _: false}", true; "match int" )] +#[test_case("match 2.0 { case float: true, case _: flase}", true; "match float")] +#[test_case("match 'foo' { case string: true, case _: false}", true; "match string")] +#[test_case("match false { case bool: true, case _: false}", true; "match bool")] fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); From 4fb38c59208fc38eebb25a1e8b103f98cd547f5a Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 23 Mar 2025 13:12:45 -0700 Subject: [PATCH 10/16] Have OR parsed as pattern --- rscel/src/compiler/compiler.rs | 88 +++++++++++++++++--------------- rscel/src/compiler/grammar.rs | 1 + rscel/src/tests/general_tests.rs | 3 ++ rscel/src/types/cel_value.rs | 11 ++-- 4 files changed, 59 insertions(+), 44 deletions(-) diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index 1d43f37..4ffb0b1 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -11,7 +11,6 @@ use super::{ }; use crate::{ interp::{Interpreter, JmpWhen}, - types::CelByteCode, BindContext, ByteCode, CelError, CelResult, CelValue, CelValueDyn, Program, StringTokenizer, }; @@ -326,58 +325,65 @@ impl<'l> CelCompiler<'l> { fn parse_match_pattern(&mut self) -> CelResult<(CompiledProg, AstNode)> { let start = self.tokenizer.location(); - if let Some(t) = self.tokenizer.next()? { - match t.token { - Token::Ident(i) => { + if let Some(t) = self.tokenizer.peek()? { + if let Token::Ident(i) = t.token() { + let i = i.clone(); + if i == "_" { + self.tokenizer.next()?; let range = SourceRange::new(start, self.tokenizer.location()); - let (bytecode_vec, pattern_type) = match i.as_str() { - "int" | "uint" | "float" | "double" | "string" | "bool" | "bytes" - | "list" | "object" | "null" | "timestamp" | "duration" => ( - vec![ + + return Ok(( + CompiledProg::with_bytecode( + [ + ByteCode::Pop, // pop off the pattern value + ByteCode::Push(CelValue::true_()), // push true + ] + .into_iter() + .collect(), + ), + AstNode::new( + MatchPattern::Any(AstNode::new(MatchAnyPattern {}, range)), + range, + ), + )); + } else if self.bindings.get_type(&i).is_some() { + self.tokenizer.next()?; + return Ok(( + CompiledProg::with_bytecode( + [ ByteCode::Push(CelValue::Ident("type".to_owned())), ByteCode::Call(1), ByteCode::Push(CelValue::Ident(i.clone())), ByteCode::Eq, - ], + ] + .into_iter() + .collect(), + ), + AstNode::new( MatchPattern::Type(AstNode::new( MatchTypePattern::from_type_str(&i), - range, + SourceRange::new(start, self.tokenizer.location()), )), + SourceRange::new(start, self.tokenizer.location()), ), - "_" => { - ( - vec![ - ByteCode::Pop, // pop off the pattern value - ByteCode::Push(CelValue::true_()), // push true - ], - MatchPattern::Any(AstNode::new(MatchAnyPattern {}, range)), - ) - } - _ => { - return Err(SyntaxError::from_location(self.tokenizer.location()) - .with_message( - "_ is the only identifier allowed in case expressions" - .to_owned(), - ) - .into()); - } - }; - Ok(( - CompiledProg::with_bytecode(CelByteCode::from_vec(bytecode_vec)), - AstNode::new(pattern_type, range), - )) - } - other => { - return Err(SyntaxError::from_location(self.tokenizer.location()) - .with_message(format!("Expected PATTERN got {:?}", other)) - .into()) + )); } } - } else { - return Err(SyntaxError::from_location(self.tokenizer.location()) - .with_message(format!("Expected PATTERN got NOTHING")) - .into()); } + + let (or_prod, or_ast) = self.parse_conditional_or()?; + let or_details = or_prod.details().clone(); + let mut or_bc = or_prod.into_unresolved_bytecode(); + + or_bc.push(ByteCode::Eq); + + Ok(( + CompiledProg::new(NodeValue::Bytecode(or_bc), or_details), + AstNode::new( + MatchPattern::Or(or_ast), + SourceRange::new(start, self.tokenizer.location()), + ), + )) } fn parse_conditional_or(&mut self) -> CelResult<(CompiledProg, AstNode)> { diff --git a/rscel/src/compiler/grammar.rs b/rscel/src/compiler/grammar.rs index f958d50..56ad897 100644 --- a/rscel/src/compiler/grammar.rs +++ b/rscel/src/compiler/grammar.rs @@ -39,6 +39,7 @@ pub struct MatchCase { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum MatchPattern { + Or(AstNode), Type(AstNode), Any(AstNode), } diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index f0c2219..d13adaa 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -190,6 +190,9 @@ fn test_contains() { #[test_case("match 2.0 { case float: true, case _: flase}", true; "match float")] #[test_case("match 'foo' { case string: true, case _: false}", true; "match string")] #[test_case("match false { case bool: true, case _: false}", true; "match bool")] +#[test_case("match 3 { case 3: true, case _: flase}", true; "match int literal")] +#[test_case("match 3.0 { case 3.0: true, case _: flase}", true; "match float literal")] +#[test_case("match '3' { case '3': true, case _: flase}", true; "match string literal")] fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); diff --git a/rscel/src/types/cel_value.rs b/rscel/src/types/cel_value.rs index 86e34b6..f8515b6 100644 --- a/rscel/src/types/cel_value.rs +++ b/rscel/src/types/cel_value.rs @@ -23,6 +23,11 @@ use crate::{interp::ByteCode, CelError, CelResult, CelValueDyn}; use super::{cel_byte_code::CelByteCode, CelBytes}; +pub type CelTimeStamp = DateTime; +pub type CelValueVec = Vec; +pub type CelValueMap = HashMap; +pub type CelValueBytes = Vec; + /// The basic value of the CEL interpreter. /// /// Houses all possible types and implements most of the valid operations within the @@ -40,13 +45,13 @@ pub enum CelValue { Bool(bool), String(String), Bytes(CelBytes), - List(Vec), - Map(HashMap), + List(CelValueVec), + Map(CelValueMap), Null, Ident(String), Type(String), #[serde(with = "ts_milliseconds")] - TimeStamp(DateTime), + TimeStamp(CelTimeStamp), #[serde( serialize_with = "DurationMilliSeconds::::serialize_as", deserialize_with = "DurationMilliSeconds::::deserialize_as" From 95cb7ad97c60f1d9fae1bfcac4ce99092922ee98 Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 23 Mar 2025 18:40:26 -0700 Subject: [PATCH 11/16] Finished match v1 --- rscel/src/compiler/compiler.rs | 19 +++++++- rscel/src/compiler/compiler/pattern_utils.rs | 48 ++++++++++++++++++++ rscel/src/compiler/grammar.rs | 15 +++++- rscel/src/tests/general_tests.rs | 6 +++ 4 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 rscel/src/compiler/compiler/pattern_utils.rs diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index 4ffb0b1..c440b88 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -1,5 +1,9 @@ use std::collections::HashMap; +mod pattern_utils; + +use pattern_utils::PrefixPattern; + use super::{ ast_node::AstNode, compiled_prog::{CompiledProg, NodeValue, PreResolvedCodePoint}, @@ -324,6 +328,7 @@ impl<'l> CelCompiler<'l> { fn parse_match_pattern(&mut self) -> CelResult<(CompiledProg, AstNode)> { let start = self.tokenizer.location(); + let mut prefix_pattern = PrefixPattern::Eq; if let Some(t) = self.tokenizer.peek()? { if let Token::Ident(i) = t.token() { @@ -369,18 +374,28 @@ impl<'l> CelCompiler<'l> { )); } } + + if let Some(token_prefix_pattern) = PrefixPattern::from_token(t.token()) { + self.tokenizer.next()?; + prefix_pattern = token_prefix_pattern; + } } + let op_range = SourceRange::new(start, self.tokenizer.location()); + let (or_prod, or_ast) = self.parse_conditional_or()?; let or_details = or_prod.details().clone(); let mut or_bc = or_prod.into_unresolved_bytecode(); - or_bc.push(ByteCode::Eq); + or_bc.push(prefix_pattern.as_bytecode()); Ok(( CompiledProg::new(NodeValue::Bytecode(or_bc), or_details), AstNode::new( - MatchPattern::Or(or_ast), + MatchPattern::Cmp { + op: AstNode::new(prefix_pattern.as_ast(), op_range), + or: or_ast, + }, SourceRange::new(start, self.tokenizer.location()), ), )) diff --git a/rscel/src/compiler/compiler/pattern_utils.rs b/rscel/src/compiler/compiler/pattern_utils.rs new file mode 100644 index 0000000..d4dcd80 --- /dev/null +++ b/rscel/src/compiler/compiler/pattern_utils.rs @@ -0,0 +1,48 @@ +use crate::{compiler::tokens::Token, ByteCode}; + +use super::MatchCmpOp; + +pub enum PrefixPattern { + Eq, + Neq, + Gt, + Ge, + Lt, + Le, +} + +impl PrefixPattern { + pub fn from_token(token: &Token) -> Option { + match token { + Token::EqualEqual => Some(PrefixPattern::Eq), + Token::NotEqual => Some(PrefixPattern::Neq), + Token::GreaterThan => Some(PrefixPattern::Gt), + Token::GreaterEqual => Some(PrefixPattern::Ge), + Token::LessThan => Some(PrefixPattern::Lt), + Token::LessEqual => Some(PrefixPattern::Le), + _ => None, + } + } + + pub fn as_bytecode(&self) -> ByteCode { + match self { + PrefixPattern::Eq => ByteCode::Eq, + PrefixPattern::Neq => ByteCode::Ne, + PrefixPattern::Gt => ByteCode::Gt, + PrefixPattern::Ge => ByteCode::Ge, + PrefixPattern::Lt => ByteCode::Lt, + PrefixPattern::Le => ByteCode::Le, + } + } + + pub fn as_ast(&self) -> MatchCmpOp { + match self { + PrefixPattern::Eq => MatchCmpOp::Eq, + PrefixPattern::Neq => MatchCmpOp::Neq, + PrefixPattern::Gt => MatchCmpOp::Gt, + PrefixPattern::Ge => MatchCmpOp::Ge, + PrefixPattern::Lt => MatchCmpOp::Lt, + PrefixPattern::Le => MatchCmpOp::Le, + } + } +} diff --git a/rscel/src/compiler/grammar.rs b/rscel/src/compiler/grammar.rs index 56ad897..3cd85a7 100644 --- a/rscel/src/compiler/grammar.rs +++ b/rscel/src/compiler/grammar.rs @@ -39,11 +39,24 @@ pub struct MatchCase { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum MatchPattern { - Or(AstNode), + Cmp { + op: AstNode, + or: AstNode, + }, Type(AstNode), Any(AstNode), } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum MatchCmpOp { + Eq, + Neq, + Gt, + Ge, + Lt, + Le, +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum MatchTypePattern { Int, diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index d13adaa..2982781 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -193,6 +193,12 @@ fn test_contains() { #[test_case("match 3 { case 3: true, case _: flase}", true; "match int literal")] #[test_case("match 3.0 { case 3.0: true, case _: flase}", true; "match float literal")] #[test_case("match '3' { case '3': true, case _: flase}", true; "match string literal")] +#[test_case("match 3 { case >2: true, case _: false}", true; "match greater than")] +#[test_case("match 3 { case >=2: true, case _: false}", true; "match greater equal")] +#[test_case("match 3 { case >=3: true, case _: false}", true; "match greater equal equal")] +#[test_case("match 3 { case <2: false, case _: true}", true; "match less than")] +#[test_case("match 3 { case <=2: false, case _: true}", true; "match less equal")] +#[test_case("match 3 { case <=3: true, case _: false}", true; "match less equal equal")] fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); From 2e764c62f3b5a1c7858b8367557ba40019c70b58 Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 23 Mar 2025 18:54:28 -0700 Subject: [PATCH 12/16] Make bad match return null --- rscel/src/compiler/compiler.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rscel/src/compiler/compiler.rs b/rscel/src/compiler/compiler.rs index c440b88..31bbea5 100644 --- a/rscel/src/compiler/compiler.rs +++ b/rscel/src/compiler/compiler.rs @@ -312,7 +312,11 @@ impl<'l> CelCompiler<'l> { node_bytecode.push(PreResolvedCodePoint::Label(after_case_l)); } - node_bytecode.push(PreResolvedCodePoint::Label(after_match_s_l)); + node_bytecode.extend([ + ByteCode::Pop.into(), + ByteCode::Push(CelValue::from_null()).into(), + PreResolvedCodePoint::Label(after_match_s_l), + ]); Ok(( CompiledProg::new(NodeValue::Bytecode(node_bytecode), node_details), From 3a81748ffe0950a8cbf856b6a0e12c8c3b7dbea8 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 13 Apr 2025 18:36:27 -0700 Subject: [PATCH 13/16] implement negative index as feature --- rscel/Cargo.toml | 3 +- rscel/src/tests/general_tests.rs | 2 ++ rscel/src/types/cel_value.rs | 52 ++++++++++++++++++++++++-------- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/rscel/Cargo.toml b/rscel/Cargo.toml index c7c27c6..1f2bd79 100644 --- a/rscel/Cargo.toml +++ b/rscel/Cargo.toml @@ -7,10 +7,11 @@ license = { workspace = true } readme = "../README.md" [features] -default = ["type_prop", "protobuf"] +default = ["type_prop", "protobuf", "neg_index"] ast_ser = [] debug_output = [] type_prop = [] +neg_index = [] protobuf = ["dep:protobuf"] [build-dependencies] diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index 2982781..e304bee 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -199,6 +199,8 @@ fn test_contains() { #[test_case("match 3 { case <2: false, case _: true}", true; "match less than")] #[test_case("match 3 { case <=2: false, case _: true}", true; "match less equal")] #[test_case("match 3 { case <=3: true, case _: false}", true; "match less equal equal")] +#[test_case("[1,2,3][-1]", 3; "negative index 1")] +#[test_case("[1,2,3][-2]", 2; "negative index 2")] fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); diff --git a/rscel/src/types/cel_value.rs b/rscel/src/types/cel_value.rs index f8515b6..9c3a71b 100644 --- a/rscel/src/types/cel_value.rs +++ b/rscel/src/types/cel_value.rs @@ -26,7 +26,6 @@ use super::{cel_byte_code::CelByteCode, CelBytes}; pub type CelTimeStamp = DateTime; pub type CelValueVec = Vec; pub type CelValueMap = HashMap; -pub type CelValueBytes = Vec; /// The basic value of the CEL interpreter. /// @@ -552,26 +551,53 @@ impl CelValue { pub fn index(self, ival: CelValue) -> CelValue { self.error_prop_or(ival, |obj, index| match obj { CelValue::List(list) => { - let index = if let CelValue::UInt(ref index) = index { - *index as usize + if let CelValue::UInt(index) = index { + if index as usize >= list.len() { + return CelValue::from_err(CelError::value("List access out of bounds")); + } + + return list[index as usize].clone(); } else if let CelValue::Int(index) = index { if index < 0 { - return CelValue::from_err(CelError::value( - "Negative index is not allowed", - )); + if cfg!(feature = "neg_index") { + let adjusted_index: isize = match TryInto::::try_into(list.len()) + { + Ok(v) => v, + Err(_) => { + return CelValue::from_err(CelError::value( + "List access out of bounds", + )) + } + } + (index as isize); + + if adjusted_index < 0 + || TryInto::::try_into(adjusted_index).unwrap() >= list.len() + { + return CelValue::from_err(CelError::value( + "List access out of bounds 3", + )); + } + + list[adjusted_index as usize].clone() + } else { + return CelValue::from_err(CelError::value( + "Negative index is not allowed", + )); + } + } else { + if index as usize >= list.len() { + return CelValue::from_err(CelError::value( + "List access out of bounds", + )); + } + + list[index as usize].clone() } - index as usize } else { return CelValue::from_err(CelError::value( "List index can only be int or uint", )); - }; - - if index >= list.len() { - return CelValue::from_err(CelError::value("List access out of bounds")); } - - return list[index].clone(); } CelValue::Map(map) => { if let CelValue::String(index) = index { From 3a6dd0e7b51ca8b0fb72d87b34735ac04af06d42 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 13 Apr 2025 18:38:41 -0700 Subject: [PATCH 14/16] version bump --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 3bfe572..336ce58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ default-members = ["rscel"] resolver = "2" [workspace.package] -version = "1.0.4" +version = "1.0.5" edition = "2021" description = "Cel interpreter in rust" license = "MIT" From 9a6a1e7de0964c488c48b272efd498f700eeed09 Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 13 Apr 2025 19:09:42 -0700 Subject: [PATCH 15/16] update pyi file --- python/rscel.pyi | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/python/rscel.pyi b/python/rscel.pyi index 581e7ce..112ce94 100644 --- a/python/rscel.pyi +++ b/python/rscel.pyi @@ -12,3 +12,51 @@ CelBinding = dict[str, CelValue | CelCallable | Any] def eval(prog: str, binding: CelBinding) -> CelValue: ... + +class CelProgram: + def __init__(self): + ... + + def add_source(self, source: str): + ... + + def add_serialized_json(self, source: str): + ... + + def add_serialized_bincode(self, bincode: bytes): + ... + + def serialize_to_json(self) -> str: + ... + + def serialize_to_bincode(self) -> bytes: + ... + + def details_json(self) -> str: + ... + +class BindContext: + def __init__(self): + ... + + def bind_param(self, name: str, val: CelValue): + ... + + def bind_func(self, name: str, val: Callable[[Any, Any]]): + ... + + def bind(self, name: str, val: Any): + ... + +class CelContext: + def __init__(self): + ... + + def add_program_string(self, name: str, source: str): + ... + + def add_program(self, name: str, prog: "CelProgram"): + ... + + def exec(self, name: str, bindings: BindContext) -> CelValue: + ... From 24378fe6f172be120baf839e787e18b9fa0c021b Mon Sep 17 00:00:00 2001 From: 1BADragon <6611786+1BADragon@users.noreply.github.com> Date: Sun, 13 Apr 2025 19:19:41 -0700 Subject: [PATCH 16/16] fix the tests --- rscel/src/tests/general_tests.rs | 2 -- rscel/src/tests/mod.rs | 1 + rscel/src/tests/neg_index_tests.rs | 20 ++++++++++++++++++++ wasm/tests/package-lock.json | 2 +- 4 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 rscel/src/tests/neg_index_tests.rs diff --git a/rscel/src/tests/general_tests.rs b/rscel/src/tests/general_tests.rs index e304bee..2982781 100644 --- a/rscel/src/tests/general_tests.rs +++ b/rscel/src/tests/general_tests.rs @@ -199,8 +199,6 @@ fn test_contains() { #[test_case("match 3 { case <2: false, case _: true}", true; "match less than")] #[test_case("match 3 { case <=2: false, case _: true}", true; "match less equal")] #[test_case("match 3 { case <=3: true, case _: false}", true; "match less equal equal")] -#[test_case("[1,2,3][-1]", 3; "negative index 1")] -#[test_case("[1,2,3][-2]", 2; "negative index 2")] fn test_equation(prog: &str, res: impl Into) { let mut ctx = CelContext::new(); let exec_ctx = BindContext::new(); diff --git a/rscel/src/tests/mod.rs b/rscel/src/tests/mod.rs index 28d2315..0c26457 100644 --- a/rscel/src/tests/mod.rs +++ b/rscel/src/tests/mod.rs @@ -1,4 +1,5 @@ mod general_tests; +mod neg_index_tests; mod type_prop_tests; #[cfg(test_protos)] diff --git a/rscel/src/tests/neg_index_tests.rs b/rscel/src/tests/neg_index_tests.rs new file mode 100644 index 0000000..c403d7f --- /dev/null +++ b/rscel/src/tests/neg_index_tests.rs @@ -0,0 +1,20 @@ +use crate::{BindContext, CelContext}; + +#[test] +fn test_neg_index() { + let mut ctx = CelContext::new(); + let bindings = BindContext::new(); + + ctx.add_program_str("test1", "[1,2,3][-1]") + .expect("Failed to compile program"); + ctx.add_program_str("test2", "[1,2,3][-2]") + .expect("Failed to compile program"); + + if cfg!(feature = "neg_index") { + assert_eq!(ctx.exec("test1", &bindings).unwrap(), 3.into()); + assert_eq!(ctx.exec("test2", &bindings).unwrap(), 2.into()); + } else { + assert!(ctx.exec("test1", &bindings).is_err()); + assert!(ctx.exec("test2", &bindings).is_err()); + } +} diff --git a/wasm/tests/package-lock.json b/wasm/tests/package-lock.json index b0c9309..e389e03 100644 --- a/wasm/tests/package-lock.json +++ b/wasm/tests/package-lock.json @@ -16,7 +16,7 @@ }, "../pkg": { "name": "rscel_wasm", - "version": "1.0.4", + "version": "1.0.5", "license": "MIT" }, "node_modules/@esbuild/aix-ppc64": {