11use crate :: bytecode:: { Chunk , ConstIdx , Op , Reg } ;
22use crate :: value:: Value ;
3+ use std:: collections:: HashMap ;
4+
5+ /// A pure function definition (for compile-time evaluation)
6+ #[ derive( Clone ) ]
7+ struct PureFunction {
8+ params : Vec < String > ,
9+ body : Value ,
10+ }
11+
12+ /// Registry of pure functions for compile-time evaluation
13+ #[ derive( Clone , Default ) ]
14+ struct PureFunctions {
15+ funcs : HashMap < String , PureFunction > ,
16+ }
17+
18+ impl PureFunctions {
19+ fn new ( ) -> Self {
20+ PureFunctions {
21+ funcs : HashMap :: new ( ) ,
22+ }
23+ }
24+
25+ fn register ( & mut self , name : & str , params : Vec < String > , body : Value ) {
26+ self . funcs . insert ( name. to_string ( ) , PureFunction { params, body } ) ;
27+ }
28+
29+ fn get ( & self , name : & str ) -> Option < & PureFunction > {
30+ self . funcs . get ( name)
31+ }
32+ }
333
434pub struct Compiler {
535 chunk : Chunk ,
6- locals : Vec < String > , // local variable names, index = register
36+ locals : Vec < String > ,
737 scope_depth : usize ,
38+ pure_fns : PureFunctions ,
39+ }
40+
41+ /// Check if an expression is pure (no side effects)
42+ fn is_pure_expr_with_fns ( expr : & Value , pure_fns : & PureFunctions ) -> bool {
43+ match expr {
44+ // Literals are pure
45+ Value :: Int ( _) | Value :: Float ( _) | Value :: Bool ( _) | Value :: Nil | Value :: String ( _) => true ,
46+ // Symbols are pure (just variable references)
47+ Value :: Symbol ( _) => true ,
48+ // Lists need to be checked
49+ Value :: List ( items) if items. is_empty ( ) => true ,
50+ Value :: List ( items) => {
51+ let first = & items[ 0 ] ;
52+ if let Some ( sym) = first. as_symbol ( ) {
53+ match sym {
54+ // Pure built-in operations
55+ "+" | "-" | "*" | "/" | "mod" | "<" | "<=" | ">" | ">=" | "=" | "!=" | "not" => {
56+ items[ 1 ..] . iter ( ) . all ( |e| is_pure_expr_with_fns ( e, pure_fns) )
57+ }
58+ // Conditional is pure if branches are pure
59+ "if" => items[ 1 ..] . iter ( ) . all ( |e| is_pure_expr_with_fns ( e, pure_fns) ) ,
60+ // Let is pure if bindings and body are pure
61+ "let" => {
62+ if items. len ( ) >= 3 {
63+ if let Some ( bindings) = items[ 1 ] . as_list ( ) {
64+ bindings. iter ( ) . all ( |e| is_pure_expr_with_fns ( e, pure_fns) ) &&
65+ items[ 2 ..] . iter ( ) . all ( |e| is_pure_expr_with_fns ( e, pure_fns) )
66+ } else {
67+ false
68+ }
69+ } else {
70+ false
71+ }
72+ }
73+ // Quote is pure
74+ "quote" => true ,
75+ // Check if it's a known pure function
76+ _ => {
77+ if pure_fns. get ( sym) . is_some ( ) {
78+ // It's a call to a known pure function, check args are pure
79+ items[ 1 ..] . iter ( ) . all ( |e| is_pure_expr_with_fns ( e, pure_fns) )
80+ } else {
81+ false
82+ }
83+ }
84+ }
85+ } else {
86+ false
87+ }
88+ }
89+ _ => false ,
90+ }
891}
992
10- /// Try to evaluate a constant expression recursively
11- fn try_const_eval ( expr : & Value ) -> Option < Value > {
93+
94+ /// Try to evaluate a constant expression recursively, including pure function calls
95+ fn try_const_eval_with_fns ( expr : & Value , pure_fns : & PureFunctions ) -> Option < Value > {
1296 match expr {
1397 // Literals are constants
1498 Value :: Int ( _) | Value :: Float ( _) | Value :: Bool ( _) | Value :: Nil | Value :: String ( _) => {
@@ -18,16 +102,56 @@ fn try_const_eval(expr: &Value) -> Option<Value> {
18102 Value :: List ( items) if !items. is_empty ( ) => {
19103 let op = items[ 0 ] . as_symbol ( ) ?;
20104 let args = & items[ 1 ..] ;
21- // Recursively evaluate arguments
22- let const_args: Option < Vec < Value > > = args. iter ( ) . map ( try_const_eval) . collect ( ) ;
105+
106+ // First try built-in operations
107+ let const_args: Option < Vec < Value > > = args. iter ( )
108+ . map ( |a| try_const_eval_with_fns ( a, pure_fns) )
109+ . collect ( ) ;
23110 let const_args = const_args?;
24111 let const_refs: Vec < & Value > = const_args. iter ( ) . collect ( ) ;
25- fold_op ( op, & const_refs)
112+
113+ if let Some ( result) = fold_op ( op, & const_refs) {
114+ return Some ( result) ;
115+ }
116+
117+ // Try pure user-defined functions
118+ if let Some ( pure_fn) = pure_fns. get ( op) {
119+ if pure_fn. params . len ( ) == const_args. len ( ) {
120+ // Substitute parameters with constant values
121+ let substituted = substitute ( & pure_fn. body , & pure_fn. params , & const_args) ;
122+ // Recursively evaluate the substituted body
123+ return try_const_eval_with_fns ( & substituted, pure_fns) ;
124+ }
125+ }
126+
127+ None
26128 }
27129 _ => None ,
28130 }
29131}
30132
133+ /// Substitute parameters with values in an expression
134+ fn substitute ( expr : & Value , params : & [ String ] , args : & [ Value ] ) -> Value {
135+ match expr {
136+ Value :: Symbol ( name) => {
137+ for ( i, param) in params. iter ( ) . enumerate ( ) {
138+ if param == name. as_ref ( ) {
139+ return args[ i] . clone ( ) ;
140+ }
141+ }
142+ expr. clone ( )
143+ }
144+ Value :: List ( items) => {
145+ let new_items: Vec < Value > = items
146+ . iter ( )
147+ . map ( |item| substitute ( item, params, args) )
148+ . collect ( ) ;
149+ Value :: list ( new_items)
150+ }
151+ _ => expr. clone ( ) ,
152+ }
153+ }
154+
31155/// Try to fold an operation with constant arguments
32156fn fold_op ( op : & str , args : & [ & Value ] ) -> Option < Value > {
33157 match op {
@@ -227,6 +351,7 @@ impl Compiler {
227351 chunk : Chunk :: new ( ) ,
228352 locals : Vec :: new ( ) ,
229353 scope_depth : 0 ,
354+ pure_fns : PureFunctions :: new ( ) ,
230355 }
231356 }
232357
@@ -239,6 +364,28 @@ impl Compiler {
239364 Ok ( compiler. chunk )
240365 }
241366
367+ /// Compile multiple expressions, allowing pure function definitions to be used
368+ /// in subsequent expressions
369+ pub fn compile_all ( exprs : & [ Value ] ) -> Result < Chunk , String > {
370+ let mut compiler = Compiler :: new ( ) ;
371+ let dest = compiler. alloc_reg ( ) ;
372+
373+ if exprs. is_empty ( ) {
374+ compiler. emit ( Op :: LoadNil ( dest) ) ;
375+ } else {
376+ // Compile all but last expression (not in tail position)
377+ for expr in & exprs[ ..exprs. len ( ) - 1 ] {
378+ compiler. compile_expr ( expr, dest, false ) ?;
379+ }
380+ // Last expression in tail position
381+ compiler. compile_expr ( & exprs[ exprs. len ( ) - 1 ] , dest, true ) ?;
382+ }
383+
384+ compiler. emit ( Op :: Return ( dest) ) ;
385+ compiler. chunk . num_registers = compiler. locals . len ( ) . max ( 1 ) as u8 + 16 ;
386+ Ok ( compiler. chunk )
387+ }
388+
242389 pub fn compile_function ( params : & [ String ] , body : & Value ) -> Result < Chunk , String > {
243390 let mut compiler = Compiler :: new ( ) ;
244391 compiler. chunk . num_params = params. len ( ) as u8 ;
@@ -400,6 +547,39 @@ impl Compiler {
400547 . as_symbol ( )
401548 . ok_or ( "def expects a symbol as first argument" ) ?;
402549
550+ // Check if we're defining a pure function: (def name (fn (params) body))
551+ if let Some ( fn_expr) = args[ 1 ] . as_list ( ) {
552+ if fn_expr. len ( ) >= 3 {
553+ if let Some ( fn_sym) = fn_expr[ 0 ] . as_symbol ( ) {
554+ if fn_sym == "fn" {
555+ if let Some ( params_list) = fn_expr[ 1 ] . as_list ( ) {
556+ // Get parameter names
557+ let params: Option < Vec < String > > = params_list
558+ . iter ( )
559+ . map ( |p| p. as_symbol ( ) . map ( |s| s. to_string ( ) ) )
560+ . collect ( ) ;
561+
562+ if let Some ( params) = params {
563+ // Get the body (handle multi-expression body)
564+ let body = if fn_expr. len ( ) == 3 {
565+ fn_expr[ 2 ] . clone ( )
566+ } else {
567+ let mut do_list = vec ! [ Value :: symbol( "do" ) ] ;
568+ do_list. extend ( fn_expr[ 2 ..] . iter ( ) . cloned ( ) ) ;
569+ Value :: list ( do_list)
570+ } ;
571+
572+ // Check if body is pure (with knowledge of already-registered pure fns)
573+ if is_pure_expr_with_fns ( & body, & self . pure_fns ) {
574+ self . pure_fns . register ( name, params, body) ;
575+ }
576+ }
577+ }
578+ }
579+ }
580+ }
581+ }
582+
403583 // Compile value
404584 self . compile_expr ( & args[ 1 ] , dest, false ) ?;
405585
@@ -513,9 +693,9 @@ impl Compiler {
513693 }
514694
515695 fn compile_call ( & mut self , items : & [ Value ] , dest : Reg , tail_pos : bool ) -> Result < ( ) , String > {
516- // Try constant folding for the entire call expression
696+ // Try constant folding for the entire call expression (including pure user functions)
517697 let call_expr = Value :: list ( items. to_vec ( ) ) ;
518- if let Some ( folded) = try_const_eval ( & call_expr) {
698+ if let Some ( folded) = try_const_eval_with_fns ( & call_expr, & self . pure_fns ) {
519699 let idx = self . add_constant ( folded) ;
520700 self . emit ( Op :: LoadConst ( dest, idx) ) ;
521701 return Ok ( ( ) ) ;
@@ -703,4 +883,47 @@ mod tests {
703883 assert ! ( matches!( chunk. code[ 0 ] , Op :: LoadConst ( 0 , _) ) ) ;
704884 assert_eq ! ( chunk. constants[ 0 ] , Value :: Int ( 2 ) ) ;
705885 }
886+
887+ #[ test]
888+ fn test_pure_function_folding ( ) {
889+ use crate :: parser:: parse_all;
890+
891+ // Define a pure function and call it with constants
892+ let exprs = parse_all ( "(def square (fn (n) (* n n))) (square 5)" ) . unwrap ( ) ;
893+ let chunk = Compiler :: compile_all ( & exprs) . unwrap ( ) ;
894+
895+ // The call (square 5) should be folded to 25
896+ // Look for LoadConst 25 in the chunk
897+ let has_25 = chunk. constants . iter ( ) . any ( |c| * c == Value :: Int ( 25 ) ) ;
898+ assert ! ( has_25, "Expected constant 25 from folding (square 5)" ) ;
899+
900+ // Should NOT have a function call for (square 5)
901+ // (there may be a call for def though, so just check we have the constant)
902+ }
903+
904+ #[ test]
905+ fn test_pure_function_nested ( ) {
906+ use crate :: parser:: parse_all;
907+
908+ // Define two pure functions
909+ let exprs = parse_all ( "(def double (fn (x) (* x 2))) (def quad (fn (x) (double (double x)))) (quad 3)" ) . unwrap ( ) ;
910+ let chunk = Compiler :: compile_all ( & exprs) . unwrap ( ) ;
911+
912+ // (quad 3) = (double (double 3)) = (double 6) = 12
913+ let has_12 = chunk. constants . iter ( ) . any ( |c| * c == Value :: Int ( 12 ) ) ;
914+ assert ! ( has_12, "Expected constant 12 from folding (quad 3)" ) ;
915+ }
916+
917+ #[ test]
918+ fn test_impure_function_not_folded ( ) {
919+ use crate :: parser:: parse_all;
920+
921+ // A function that calls println is not pure
922+ let exprs = parse_all ( "(def greet (fn (x) (println x))) (greet 5)" ) . unwrap ( ) ;
923+ let chunk = Compiler :: compile_all ( & exprs) . unwrap ( ) ;
924+
925+ // Should have a function call (not folded)
926+ let has_call = chunk. code . iter ( ) . any ( |op| matches ! ( op, Op :: Call ( _, _, _) | Op :: TailCall ( _, _) ) ) ;
927+ assert ! ( has_call, "Impure function should not be folded" ) ;
928+ }
706929}
0 commit comments