Skip to content

Commit f731d97

Browse files
authored
Merge pull request #168 from Y-Nak/egglog
Egglog
2 parents 0552c17 + c20c4ea commit f731d97

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+2981
-25
lines changed

Cargo.lock

Lines changed: 692 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/codegen/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ sonatina-ir = { path = "../ir", version = "0.0.3-alpha" }
2222
sonatina-triple = { path = "../triple", version = "0.0.3-alpha" }
2323
sonatina-macros = { path = "../macros", version = "0.0.3-alpha" }
2424
dashmap = { version = "6.1", features = ["rayon"] }
25+
egglog = "0.5"
2526
indexmap = { version = "2.11" }
2627

2728
[dev-dependencies]
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
//! Elaboration: Convert egraph terms back to sonatina IR.
2+
3+
use rustc_hash::FxHashMap;
4+
5+
use sonatina_ir::{
6+
inst::{arith::*, cmp::*, logic::*},
7+
BlockId, Function, Type, Value, ValueId,
8+
};
9+
10+
/// Represents an egglog term that can be elaborated back to IR.
11+
#[derive(Debug, Clone)]
12+
pub enum EggTerm {
13+
Var(String),
14+
Const(i64, Type),
15+
// Binary ops
16+
Add(Box<EggTerm>, Box<EggTerm>),
17+
Sub(Box<EggTerm>, Box<EggTerm>),
18+
Mul(Box<EggTerm>, Box<EggTerm>),
19+
Udiv(Box<EggTerm>, Box<EggTerm>),
20+
Sdiv(Box<EggTerm>, Box<EggTerm>),
21+
Umod(Box<EggTerm>, Box<EggTerm>),
22+
Smod(Box<EggTerm>, Box<EggTerm>),
23+
// Shifts
24+
Shl(Box<EggTerm>, Box<EggTerm>),
25+
Shr(Box<EggTerm>, Box<EggTerm>),
26+
Sar(Box<EggTerm>, Box<EggTerm>),
27+
// Unary
28+
Neg(Box<EggTerm>),
29+
Not(Box<EggTerm>),
30+
// Logic
31+
And(Box<EggTerm>, Box<EggTerm>),
32+
Or(Box<EggTerm>, Box<EggTerm>),
33+
Xor(Box<EggTerm>, Box<EggTerm>),
34+
// Comparisons
35+
Lt(Box<EggTerm>, Box<EggTerm>),
36+
Gt(Box<EggTerm>, Box<EggTerm>),
37+
Le(Box<EggTerm>, Box<EggTerm>),
38+
Ge(Box<EggTerm>, Box<EggTerm>),
39+
Slt(Box<EggTerm>, Box<EggTerm>),
40+
Sgt(Box<EggTerm>, Box<EggTerm>),
41+
Sle(Box<EggTerm>, Box<EggTerm>),
42+
Sge(Box<EggTerm>, Box<EggTerm>),
43+
Eq(Box<EggTerm>, Box<EggTerm>),
44+
Ne(Box<EggTerm>, Box<EggTerm>),
45+
IsZero(Box<EggTerm>),
46+
}
47+
48+
/// Elaborator converts egraph terms back to sonatina IR instructions.
49+
pub struct Elaborator<'a> {
50+
func: &'a mut Function,
51+
block: BlockId,
52+
/// Maps variable names to their ValueIds
53+
var_map: FxHashMap<String, ValueId>,
54+
}
55+
56+
impl<'a> Elaborator<'a> {
57+
pub fn new(func: &'a mut Function, block: BlockId) -> Self {
58+
Self {
59+
func,
60+
block,
61+
var_map: FxHashMap::default(),
62+
}
63+
}
64+
65+
/// Register an existing value with a variable name.
66+
pub fn bind_var(&mut self, name: String, value: ValueId) {
67+
self.var_map.insert(name, value);
68+
}
69+
70+
/// Elaborate a term into IR, returning the resulting ValueId.
71+
pub fn elaborate(&mut self, term: &EggTerm, ty: Type) -> ValueId {
72+
match term {
73+
EggTerm::Var(name) => *self.var_map.get(name).expect("undefined variable"),
74+
EggTerm::Const(val, ty) => self
75+
.func
76+
.dfg
77+
.make_imm_value(sonatina_ir::Immediate::from_i256((*val).into(), *ty)),
78+
EggTerm::Add(lhs, rhs) => self.elaborate_binary::<Add>(lhs, rhs, ty),
79+
EggTerm::Sub(lhs, rhs) => self.elaborate_binary::<Sub>(lhs, rhs, ty),
80+
EggTerm::Mul(lhs, rhs) => self.elaborate_binary::<Mul>(lhs, rhs, ty),
81+
EggTerm::Udiv(lhs, rhs) => self.elaborate_binary::<Udiv>(lhs, rhs, ty),
82+
EggTerm::Sdiv(lhs, rhs) => self.elaborate_binary::<Sdiv>(lhs, rhs, ty),
83+
EggTerm::Umod(lhs, rhs) => self.elaborate_binary::<Umod>(lhs, rhs, ty),
84+
EggTerm::Smod(lhs, rhs) => self.elaborate_binary::<Smod>(lhs, rhs, ty),
85+
EggTerm::Shl(bits, val) => self.elaborate_shift::<Shl>(bits, val, ty),
86+
EggTerm::Shr(bits, val) => self.elaborate_shift::<Shr>(bits, val, ty),
87+
EggTerm::Sar(bits, val) => self.elaborate_shift::<Sar>(bits, val, ty),
88+
EggTerm::Neg(arg) => self.elaborate_unary::<Neg>(arg, ty),
89+
EggTerm::Not(arg) => self.elaborate_unary::<Not>(arg, ty),
90+
EggTerm::And(lhs, rhs) => self.elaborate_binary::<And>(lhs, rhs, ty),
91+
EggTerm::Or(lhs, rhs) => self.elaborate_binary::<Or>(lhs, rhs, ty),
92+
EggTerm::Xor(lhs, rhs) => self.elaborate_binary::<Xor>(lhs, rhs, ty),
93+
EggTerm::Lt(lhs, rhs) => self.elaborate_cmp::<Lt>(lhs, rhs, ty),
94+
EggTerm::Gt(lhs, rhs) => self.elaborate_cmp::<Gt>(lhs, rhs, ty),
95+
EggTerm::Le(lhs, rhs) => self.elaborate_cmp::<Le>(lhs, rhs, ty),
96+
EggTerm::Ge(lhs, rhs) => self.elaborate_cmp::<Ge>(lhs, rhs, ty),
97+
EggTerm::Slt(lhs, rhs) => self.elaborate_cmp::<Slt>(lhs, rhs, ty),
98+
EggTerm::Sgt(lhs, rhs) => self.elaborate_cmp::<Sgt>(lhs, rhs, ty),
99+
EggTerm::Sle(lhs, rhs) => self.elaborate_cmp::<Sle>(lhs, rhs, ty),
100+
EggTerm::Sge(lhs, rhs) => self.elaborate_cmp::<Sge>(lhs, rhs, ty),
101+
EggTerm::Eq(lhs, rhs) => self.elaborate_cmp::<Eq>(lhs, rhs, ty),
102+
EggTerm::Ne(lhs, rhs) => self.elaborate_cmp::<Ne>(lhs, rhs, ty),
103+
EggTerm::IsZero(arg) => self.elaborate_iszero(arg, ty),
104+
}
105+
}
106+
107+
fn elaborate_binary<I>(&mut self, lhs: &EggTerm, rhs: &EggTerm, ty: Type) -> ValueId
108+
where
109+
I: BinaryInst,
110+
{
111+
let lhs_val = self.elaborate(lhs, ty);
112+
let rhs_val = self.elaborate(rhs, ty);
113+
let is = self.func.inst_set();
114+
let inst = I::new(is, lhs_val, rhs_val);
115+
self.make_inst_value(inst, ty)
116+
}
117+
118+
fn elaborate_shift<I>(&mut self, bits: &EggTerm, val: &EggTerm, ty: Type) -> ValueId
119+
where
120+
I: ShiftInst,
121+
{
122+
let bits_val = self.elaborate(bits, ty);
123+
let val_val = self.elaborate(val, ty);
124+
let is = self.func.inst_set();
125+
let inst = I::new(is, bits_val, val_val);
126+
self.make_inst_value(inst, ty)
127+
}
128+
129+
fn elaborate_unary<I>(&mut self, arg: &EggTerm, ty: Type) -> ValueId
130+
where
131+
I: UnaryInst,
132+
{
133+
let arg_val = self.elaborate(arg, ty);
134+
let is = self.func.inst_set();
135+
let inst = I::new(is, arg_val);
136+
self.make_inst_value(inst, ty)
137+
}
138+
139+
fn elaborate_cmp<I>(&mut self, lhs: &EggTerm, rhs: &EggTerm, ty: Type) -> ValueId
140+
where
141+
I: BinaryInst,
142+
{
143+
let lhs_val = self.elaborate(lhs, ty);
144+
let rhs_val = self.elaborate(rhs, ty);
145+
let is = self.func.inst_set();
146+
let inst = I::new(is, lhs_val, rhs_val);
147+
self.make_inst_value(inst, Type::I1)
148+
}
149+
150+
fn elaborate_iszero(&mut self, arg: &EggTerm, ty: Type) -> ValueId {
151+
let arg_val = self.elaborate(arg, ty);
152+
let is = self.func.inst_set();
153+
let inst = IsZero::new(is.has_is_zero().unwrap(), arg_val);
154+
self.make_inst_value(inst, Type::I1)
155+
}
156+
157+
fn make_inst_value<I: sonatina_ir::Inst>(&mut self, inst: I, ty: Type) -> ValueId {
158+
let inst_id = self.func.dfg.make_inst(inst);
159+
let value = Value::Inst { inst: inst_id, ty };
160+
let value_id = self.func.dfg.make_value(value);
161+
self.func.dfg.attach_result(inst_id, value_id);
162+
self.func.layout.append_inst(inst_id, self.block);
163+
value_id
164+
}
165+
}
166+
167+
/// Trait for binary instructions that can be constructed.
168+
trait BinaryInst: sonatina_ir::Inst + Sized {
169+
fn new(is: &dyn sonatina_ir::InstSetBase, lhs: ValueId, rhs: ValueId) -> Self;
170+
}
171+
172+
/// Trait for shift instructions.
173+
trait ShiftInst: sonatina_ir::Inst + Sized {
174+
fn new(is: &dyn sonatina_ir::InstSetBase, bits: ValueId, value: ValueId) -> Self;
175+
}
176+
177+
/// Trait for unary instructions.
178+
trait UnaryInst: sonatina_ir::Inst + Sized {
179+
fn new(is: &dyn sonatina_ir::InstSetBase, arg: ValueId) -> Self;
180+
}
181+
182+
// Implement BinaryInst for binary ops
183+
macro_rules! impl_binary {
184+
($ty:ty, $has:ident) => {
185+
impl BinaryInst for $ty {
186+
fn new(is: &dyn sonatina_ir::InstSetBase, lhs: ValueId, rhs: ValueId) -> Self {
187+
<$ty>::new(is.$has().unwrap(), lhs, rhs)
188+
}
189+
}
190+
};
191+
}
192+
193+
impl_binary!(Add, has_add);
194+
impl_binary!(Sub, has_sub);
195+
impl_binary!(Mul, has_mul);
196+
impl_binary!(Udiv, has_udiv);
197+
impl_binary!(Sdiv, has_sdiv);
198+
impl_binary!(Umod, has_umod);
199+
impl_binary!(Smod, has_smod);
200+
impl_binary!(And, has_and);
201+
impl_binary!(Or, has_or);
202+
impl_binary!(Xor, has_xor);
203+
impl_binary!(Lt, has_lt);
204+
impl_binary!(Gt, has_gt);
205+
impl_binary!(Le, has_le);
206+
impl_binary!(Ge, has_ge);
207+
impl_binary!(Slt, has_slt);
208+
impl_binary!(Sgt, has_sgt);
209+
impl_binary!(Sle, has_sle);
210+
impl_binary!(Sge, has_sge);
211+
impl_binary!(Eq, has_eq);
212+
impl_binary!(Ne, has_ne);
213+
214+
// Implement ShiftInst for shift ops
215+
macro_rules! impl_shift {
216+
($ty:ty, $has:ident) => {
217+
impl ShiftInst for $ty {
218+
fn new(is: &dyn sonatina_ir::InstSetBase, bits: ValueId, value: ValueId) -> Self {
219+
<$ty>::new(is.$has().unwrap(), bits, value)
220+
}
221+
}
222+
};
223+
}
224+
225+
impl_shift!(Shl, has_shl);
226+
impl_shift!(Shr, has_shr);
227+
impl_shift!(Sar, has_sar);
228+
229+
// Implement UnaryInst for unary ops
230+
macro_rules! impl_unary {
231+
($ty:ty, $has:ident) => {
232+
impl UnaryInst for $ty {
233+
fn new(is: &dyn sonatina_ir::InstSetBase, arg: ValueId) -> Self {
234+
<$ty>::new(is.$has().unwrap(), arg)
235+
}
236+
}
237+
};
238+
}
239+
240+
impl_unary!(Neg, has_neg);
241+
impl_unary!(Not, has_not);
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
; Sonatina IR Pure Expression Definitions for egglog
2+
; Pure operations that can float freely in the egraph
3+
; Maps to sonatina-ir/src/inst/{arith,cmp,logic,cast}.rs
4+
; NOTE: Requires types.egg to be loaded first
5+
6+
(datatype Expr
    ; ===== Leaf values =====
    (Const i64 Type)                ; immediate
    (Arg i64 Type)                  ; function argument: (index, type)
    (Global i64 Type)               ; reference to a global variable
    (Undef Type)                    ; undefined value
    ; Result of a side-effecting instruction (phi, call, ...), kept opaque.
    ; The i64 is a unique ID (e.g. the ValueId) that keeps distinct results apart.
    (SideEffectResult i64 Type)

    ; ===== Memory addresses (data.rs) =====
    ; Result of an alloca, i.e. one unique stack allocation.
    ; The i64 is the unique ID (ValueId) of the alloca instruction.
    (AllocaResult i64 Type)
    ; GEP (GetElementPointer): an address computed from a base plus indices.
    (GepBase Expr)                  ; base pointer that starts a GEP chain
    (GepOffset Expr Expr i64)       ; (base_gep, index_expr, field_idx)

    ; ===== Memory loads =====
    ; Result of a memory load: (unique_id, mem_state_id, type).
    ; mem_state_id names the memory state the load reads from.
    (LoadResult i64 i64 Type)

    ; ===== Arithmetic (arith.rs) =====
    (Neg Expr)
    (Add Expr Expr)
    (Sub Expr Expr)
    (Mul Expr Expr)
    (Udiv Expr Expr)
    (Sdiv Expr Expr)
    (Umod Expr Expr)
    (Smod Expr Expr)
    (Shl Expr Expr)                 ; operand order: bits, value
    (Shr Expr Expr)
    (Sar Expr Expr)

    ; ===== Logic (logic.rs) =====
    (Not Expr)
    (And Expr Expr)
    (Or Expr Expr)
    (Xor Expr Expr)

    ; ===== Comparisons (cmp.rs) =====
    ; Unsigned
    (Lt Expr Expr)
    (Gt Expr Expr)
    (Le Expr Expr)
    (Ge Expr Expr)
    ; Signed
    (Slt Expr Expr)
    (Sgt Expr Expr)
    (Sle Expr Expr)
    (Sge Expr Expr)
    (Eq Expr Expr)
    (Ne Expr Expr)
    (IsZero Expr)

    ; ===== Casts (cast.rs) =====
    (Sext Expr Type)
    (Zext Expr Type)
    (Trunc Expr Type)
    (Bitcast Expr Type)
    (IntToPtr Expr Type)
    (PtrToInt Expr Type)

    ; ===== Aggregates (pure parts of data.rs) =====
    (ExtractValue Expr i64)
    (InsertValue Expr i64 Expr)
)
74+
75+
; === Memory State Tracking ===
; A memory state is identified by an integer ID:
;   mem_state_id 0          -> InitMem, the initial memory state
;   mem_state_id N (N > 0)  -> AfterStore carrying that ID

; Per-store metadata, keyed by the mem_state_id the store produces:
;   store-prev -> ID of the preceding memory state
;   store-addr -> address expression written to
;   store-val  -> value expression stored
;   store-ty   -> type of the stored value
(function store-prev (i64) i64 :merge old)
(function store-addr (i64) Expr :merge old)
(function store-val (i64) Expr :merge old)
(function store-ty (i64) Type :merge old)

; Per-load metadata:
;   load-addr: load_id -> address expression read from
(function load-addr (i64) Expr :merge old)

; === Memory Phi (merge points) ===
; Where control flow merges, the incoming memory states have to be merged too.
; A MemPhi is a memory state that may be any one of its predecessor states.

; Marks a memory state ID as being a MemPhi.
(relation is-memphi (i64))

; memphi-pred: (memphi_id, pred_idx) -> ID of that predecessor memory state.
; Records which memory states flow into a MemPhi.
(function memphi-pred (i64 i64) i64 :merge old)

; memphi-num-preds: memphi_id -> how many predecessors the MemPhi has.
(function memphi-num-preds (i64) i64 :merge old)
107+
108+
; Type inference: expr-type maps an expression to its type.
(function expr-type (Expr) Type :merge old)

; Leaf values state their type directly.
(rule ((= e (Const _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (Arg _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (Global _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (Undef ty)))
      ((set (expr-type e) ty)))
(rule ((= e (SideEffectResult _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (LoadResult _ _ ty)))
      ((set (expr-type e) ty)))

; Unary operations take the type of their operand.
(rule ((= e (Neg x)) (= ty (expr-type x)))
      ((set (expr-type e) ty)))
(rule ((= e (Not x)) (= ty (expr-type x)))
      ((set (expr-type e) ty)))

; Every comparison yields i1.
(rule ((= e (Lt _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Gt _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Le _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Ge _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Slt _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Sgt _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Sle _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Sge _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Eq _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (Ne _ _)))
      ((set (expr-type e) (I1))))
(rule ((= e (IsZero _)))
      ((set (expr-type e) (I1))))

; A cast has its target type.
(rule ((= e (Sext _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (Zext _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (Trunc _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (Bitcast _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (IntToPtr _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (PtrToInt _ ty)))
      ((set (expr-type e) ty)))

; Memory addresses: an alloca states its type; GEP nodes take the type of
; their base expression.
(rule ((= e (AllocaResult _ ty)))
      ((set (expr-type e) ty)))
(rule ((= e (GepBase base)) (= ty (expr-type base)))
      ((set (expr-type e) ty)))
(rule ((= e (GepOffset base _ _)) (= ty (expr-type base)))
      ((set (expr-type e) ty)))

0 commit comments

Comments
 (0)