Skip to content

Commit c01b628

Browse files
committed
Add comptime eval for user pure functions
PureFunctions registry tracks functions with pure bodies is_pure_expr_with_fns checks purity with knowledge of registered pure fns Functions calling other pure functions are also considered pure Calls to pure functions with constant args are folded at comptime Works with nested pure function calls: (quad 3) -> 12 try_const_eval_with_fns substitutes parameters and evaluates compile_all method for compiling multiple sexps with shared pure function
1 parent 571d98b commit c01b628

File tree

1 file changed

+231
-8
lines changed

1 file changed

+231
-8
lines changed

src/compiler.rs

Lines changed: 231 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,98 @@
11
use crate::bytecode::{Chunk, ConstIdx, Op, Reg};
22
use crate::value::Value;
3+
use std::collections::HashMap;
4+
5+
/// A pure function definition (for compile-time evaluation)
6+
#[derive(Clone)]
7+
struct PureFunction {
8+
params: Vec<String>,
9+
body: Value,
10+
}
11+
12+
/// Registry of pure functions for compile-time evaluation
13+
#[derive(Clone, Default)]
14+
struct PureFunctions {
15+
funcs: HashMap<String, PureFunction>,
16+
}
17+
18+
impl PureFunctions {
19+
fn new() -> Self {
20+
PureFunctions {
21+
funcs: HashMap::new(),
22+
}
23+
}
24+
25+
fn register(&mut self, name: &str, params: Vec<String>, body: Value) {
26+
self.funcs.insert(name.to_string(), PureFunction { params, body });
27+
}
28+
29+
fn get(&self, name: &str) -> Option<&PureFunction> {
30+
self.funcs.get(name)
31+
}
32+
}
333

434
pub struct Compiler {
535
chunk: Chunk,
6-
locals: Vec<String>, // local variable names, index = register
36+
locals: Vec<String>,
737
scope_depth: usize,
38+
pure_fns: PureFunctions,
39+
}
40+
41+
/// Check if an expression is pure (no side effects)
42+
fn is_pure_expr_with_fns(expr: &Value, pure_fns: &PureFunctions) -> bool {
43+
match expr {
44+
// Literals are pure
45+
Value::Int(_) | Value::Float(_) | Value::Bool(_) | Value::Nil | Value::String(_) => true,
46+
// Symbols are pure (just variable references)
47+
Value::Symbol(_) => true,
48+
// Lists need to be checked
49+
Value::List(items) if items.is_empty() => true,
50+
Value::List(items) => {
51+
let first = &items[0];
52+
if let Some(sym) = first.as_symbol() {
53+
match sym {
54+
// Pure built-in operations
55+
"+" | "-" | "*" | "/" | "mod" | "<" | "<=" | ">" | ">=" | "=" | "!=" | "not" => {
56+
items[1..].iter().all(|e| is_pure_expr_with_fns(e, pure_fns))
57+
}
58+
// Conditional is pure if branches are pure
59+
"if" => items[1..].iter().all(|e| is_pure_expr_with_fns(e, pure_fns)),
60+
// Let is pure if bindings and body are pure
61+
"let" => {
62+
if items.len() >= 3 {
63+
if let Some(bindings) = items[1].as_list() {
64+
bindings.iter().all(|e| is_pure_expr_with_fns(e, pure_fns)) &&
65+
items[2..].iter().all(|e| is_pure_expr_with_fns(e, pure_fns))
66+
} else {
67+
false
68+
}
69+
} else {
70+
false
71+
}
72+
}
73+
// Quote is pure
74+
"quote" => true,
75+
// Check if it's a known pure function
76+
_ => {
77+
if pure_fns.get(sym).is_some() {
78+
// It's a call to a known pure function, check args are pure
79+
items[1..].iter().all(|e| is_pure_expr_with_fns(e, pure_fns))
80+
} else {
81+
false
82+
}
83+
}
84+
}
85+
} else {
86+
false
87+
}
88+
}
89+
_ => false,
90+
}
891
}
992

10-
/// Try to evaluate a constant expression recursively
11-
fn try_const_eval(expr: &Value) -> Option<Value> {
93+
94+
/// Try to evaluate a constant expression recursively, including pure function calls
95+
fn try_const_eval_with_fns(expr: &Value, pure_fns: &PureFunctions) -> Option<Value> {
1296
match expr {
1397
// Literals are constants
1498
Value::Int(_) | Value::Float(_) | Value::Bool(_) | Value::Nil | Value::String(_) => {
@@ -18,16 +102,56 @@ fn try_const_eval(expr: &Value) -> Option<Value> {
18102
Value::List(items) if !items.is_empty() => {
19103
let op = items[0].as_symbol()?;
20104
let args = &items[1..];
21-
// Recursively evaluate arguments
22-
let const_args: Option<Vec<Value>> = args.iter().map(try_const_eval).collect();
105+
106+
// First try built-in operations
107+
let const_args: Option<Vec<Value>> = args.iter()
108+
.map(|a| try_const_eval_with_fns(a, pure_fns))
109+
.collect();
23110
let const_args = const_args?;
24111
let const_refs: Vec<&Value> = const_args.iter().collect();
25-
fold_op(op, &const_refs)
112+
113+
if let Some(result) = fold_op(op, &const_refs) {
114+
return Some(result);
115+
}
116+
117+
// Try pure user-defined functions
118+
if let Some(pure_fn) = pure_fns.get(op) {
119+
if pure_fn.params.len() == const_args.len() {
120+
// Substitute parameters with constant values
121+
let substituted = substitute(&pure_fn.body, &pure_fn.params, &const_args);
122+
// Recursively evaluate the substituted body
123+
return try_const_eval_with_fns(&substituted, pure_fns);
124+
}
125+
}
126+
127+
None
26128
}
27129
_ => None,
28130
}
29131
}
30132

133+
/// Substitute parameters with values in an expression
134+
fn substitute(expr: &Value, params: &[String], args: &[Value]) -> Value {
135+
match expr {
136+
Value::Symbol(name) => {
137+
for (i, param) in params.iter().enumerate() {
138+
if param == name.as_ref() {
139+
return args[i].clone();
140+
}
141+
}
142+
expr.clone()
143+
}
144+
Value::List(items) => {
145+
let new_items: Vec<Value> = items
146+
.iter()
147+
.map(|item| substitute(item, params, args))
148+
.collect();
149+
Value::list(new_items)
150+
}
151+
_ => expr.clone(),
152+
}
153+
}
154+
31155
/// Try to fold an operation with constant arguments
32156
fn fold_op(op: &str, args: &[&Value]) -> Option<Value> {
33157
match op {
@@ -227,6 +351,7 @@ impl Compiler {
227351
chunk: Chunk::new(),
228352
locals: Vec::new(),
229353
scope_depth: 0,
354+
pure_fns: PureFunctions::new(),
230355
}
231356
}
232357

@@ -239,6 +364,28 @@ impl Compiler {
239364
Ok(compiler.chunk)
240365
}
241366

367+
/// Compile multiple expressions, allowing pure function definitions to be used
368+
/// in subsequent expressions
369+
pub fn compile_all(exprs: &[Value]) -> Result<Chunk, String> {
370+
let mut compiler = Compiler::new();
371+
let dest = compiler.alloc_reg();
372+
373+
if exprs.is_empty() {
374+
compiler.emit(Op::LoadNil(dest));
375+
} else {
376+
// Compile all but last expression (not in tail position)
377+
for expr in &exprs[..exprs.len() - 1] {
378+
compiler.compile_expr(expr, dest, false)?;
379+
}
380+
// Last expression in tail position
381+
compiler.compile_expr(&exprs[exprs.len() - 1], dest, true)?;
382+
}
383+
384+
compiler.emit(Op::Return(dest));
385+
compiler.chunk.num_registers = compiler.locals.len().max(1) as u8 + 16;
386+
Ok(compiler.chunk)
387+
}
388+
242389
pub fn compile_function(params: &[String], body: &Value) -> Result<Chunk, String> {
243390
let mut compiler = Compiler::new();
244391
compiler.chunk.num_params = params.len() as u8;
@@ -400,6 +547,39 @@ impl Compiler {
400547
.as_symbol()
401548
.ok_or("def expects a symbol as first argument")?;
402549

550+
// Check if we're defining a pure function: (def name (fn (params) body))
551+
if let Some(fn_expr) = args[1].as_list() {
552+
if fn_expr.len() >= 3 {
553+
if let Some(fn_sym) = fn_expr[0].as_symbol() {
554+
if fn_sym == "fn" {
555+
if let Some(params_list) = fn_expr[1].as_list() {
556+
// Get parameter names
557+
let params: Option<Vec<String>> = params_list
558+
.iter()
559+
.map(|p| p.as_symbol().map(|s| s.to_string()))
560+
.collect();
561+
562+
if let Some(params) = params {
563+
// Get the body (handle multi-expression body)
564+
let body = if fn_expr.len() == 3 {
565+
fn_expr[2].clone()
566+
} else {
567+
let mut do_list = vec![Value::symbol("do")];
568+
do_list.extend(fn_expr[2..].iter().cloned());
569+
Value::list(do_list)
570+
};
571+
572+
// Check if body is pure (with knowledge of already-registered pure fns)
573+
if is_pure_expr_with_fns(&body, &self.pure_fns) {
574+
self.pure_fns.register(name, params, body);
575+
}
576+
}
577+
}
578+
}
579+
}
580+
}
581+
}
582+
403583
// Compile value
404584
self.compile_expr(&args[1], dest, false)?;
405585

@@ -513,9 +693,9 @@ impl Compiler {
513693
}
514694

515695
fn compile_call(&mut self, items: &[Value], dest: Reg, tail_pos: bool) -> Result<(), String> {
516-
// Try constant folding for the entire call expression
696+
// Try constant folding for the entire call expression (including pure user functions)
517697
let call_expr = Value::list(items.to_vec());
518-
if let Some(folded) = try_const_eval(&call_expr) {
698+
if let Some(folded) = try_const_eval_with_fns(&call_expr, &self.pure_fns) {
519699
let idx = self.add_constant(folded);
520700
self.emit(Op::LoadConst(dest, idx));
521701
return Ok(());
@@ -703,4 +883,47 @@ mod tests {
703883
assert!(matches!(chunk.code[0], Op::LoadConst(0, _)));
704884
assert_eq!(chunk.constants[0], Value::Int(2));
705885
}
886+
887+
#[test]
888+
fn test_pure_function_folding() {
889+
use crate::parser::parse_all;
890+
891+
// Define a pure function and call it with constants
892+
let exprs = parse_all("(def square (fn (n) (* n n))) (square 5)").unwrap();
893+
let chunk = Compiler::compile_all(&exprs).unwrap();
894+
895+
// The call (square 5) should be folded to 25
896+
// Look for LoadConst 25 in the chunk
897+
let has_25 = chunk.constants.iter().any(|c| *c == Value::Int(25));
898+
assert!(has_25, "Expected constant 25 from folding (square 5)");
899+
900+
// Should NOT have a function call for (square 5)
901+
// (there may be a call for def though, so just check we have the constant)
902+
}
903+
904+
#[test]
905+
fn test_pure_function_nested() {
906+
use crate::parser::parse_all;
907+
908+
// Define two pure functions
909+
let exprs = parse_all("(def double (fn (x) (* x 2))) (def quad (fn (x) (double (double x)))) (quad 3)").unwrap();
910+
let chunk = Compiler::compile_all(&exprs).unwrap();
911+
912+
// (quad 3) = (double (double 3)) = (double 6) = 12
913+
let has_12 = chunk.constants.iter().any(|c| *c == Value::Int(12));
914+
assert!(has_12, "Expected constant 12 from folding (quad 3)");
915+
}
916+
917+
#[test]
918+
fn test_impure_function_not_folded() {
919+
use crate::parser::parse_all;
920+
921+
// A function that calls println is not pure
922+
let exprs = parse_all("(def greet (fn (x) (println x))) (greet 5)").unwrap();
923+
let chunk = Compiler::compile_all(&exprs).unwrap();
924+
925+
// Should have a function call (not folded)
926+
let has_call = chunk.code.iter().any(|op| matches!(op, Op::Call(_, _, _) | Op::TailCall(_, _)));
927+
assert!(has_call, "Impure function should not be folded");
928+
}
706929
}

0 commit comments

Comments
 (0)