|
| 1 | +//! Elaboration: Convert egraph terms back to sonatina IR. |
| 2 | +
|
| 3 | +use rustc_hash::FxHashMap; |
| 4 | + |
| 5 | +use sonatina_ir::{ |
| 6 | + inst::{arith::*, cmp::*, logic::*}, |
| 7 | + BlockId, Function, Type, Value, ValueId, |
| 8 | +}; |
| 9 | + |
| 10 | +/// Represents an egglog term that can be elaborated back to IR. |
| 11 | +#[derive(Debug, Clone)] |
| 12 | +pub enum EggTerm { |
| 13 | + Var(String), |
| 14 | + Const(i64, Type), |
| 15 | + // Binary ops |
| 16 | + Add(Box<EggTerm>, Box<EggTerm>), |
| 17 | + Sub(Box<EggTerm>, Box<EggTerm>), |
| 18 | + Mul(Box<EggTerm>, Box<EggTerm>), |
| 19 | + Udiv(Box<EggTerm>, Box<EggTerm>), |
| 20 | + Sdiv(Box<EggTerm>, Box<EggTerm>), |
| 21 | + Umod(Box<EggTerm>, Box<EggTerm>), |
| 22 | + Smod(Box<EggTerm>, Box<EggTerm>), |
| 23 | + // Shifts |
| 24 | + Shl(Box<EggTerm>, Box<EggTerm>), |
| 25 | + Shr(Box<EggTerm>, Box<EggTerm>), |
| 26 | + Sar(Box<EggTerm>, Box<EggTerm>), |
| 27 | + // Unary |
| 28 | + Neg(Box<EggTerm>), |
| 29 | + Not(Box<EggTerm>), |
| 30 | + // Logic |
| 31 | + And(Box<EggTerm>, Box<EggTerm>), |
| 32 | + Or(Box<EggTerm>, Box<EggTerm>), |
| 33 | + Xor(Box<EggTerm>, Box<EggTerm>), |
| 34 | + // Comparisons |
| 35 | + Lt(Box<EggTerm>, Box<EggTerm>), |
| 36 | + Gt(Box<EggTerm>, Box<EggTerm>), |
| 37 | + Le(Box<EggTerm>, Box<EggTerm>), |
| 38 | + Ge(Box<EggTerm>, Box<EggTerm>), |
| 39 | + Slt(Box<EggTerm>, Box<EggTerm>), |
| 40 | + Sgt(Box<EggTerm>, Box<EggTerm>), |
| 41 | + Sle(Box<EggTerm>, Box<EggTerm>), |
| 42 | + Sge(Box<EggTerm>, Box<EggTerm>), |
| 43 | + Eq(Box<EggTerm>, Box<EggTerm>), |
| 44 | + Ne(Box<EggTerm>, Box<EggTerm>), |
| 45 | + IsZero(Box<EggTerm>), |
| 46 | +} |
| 47 | + |
| 48 | +/// Elaborator converts egraph terms back to sonatina IR instructions. |
| 49 | +pub struct Elaborator<'a> { |
| 50 | + func: &'a mut Function, |
| 51 | + block: BlockId, |
| 52 | + /// Maps variable names to their ValueIds |
| 53 | + var_map: FxHashMap<String, ValueId>, |
| 54 | +} |
| 55 | + |
| 56 | +impl<'a> Elaborator<'a> { |
| 57 | + pub fn new(func: &'a mut Function, block: BlockId) -> Self { |
| 58 | + Self { |
| 59 | + func, |
| 60 | + block, |
| 61 | + var_map: FxHashMap::default(), |
| 62 | + } |
| 63 | + } |
| 64 | + |
| 65 | + /// Register an existing value with a variable name. |
| 66 | + pub fn bind_var(&mut self, name: String, value: ValueId) { |
| 67 | + self.var_map.insert(name, value); |
| 68 | + } |
| 69 | + |
| 70 | + /// Elaborate a term into IR, returning the resulting ValueId. |
| 71 | + pub fn elaborate(&mut self, term: &EggTerm, ty: Type) -> ValueId { |
| 72 | + match term { |
| 73 | + EggTerm::Var(name) => *self.var_map.get(name).expect("undefined variable"), |
| 74 | + EggTerm::Const(val, ty) => self |
| 75 | + .func |
| 76 | + .dfg |
| 77 | + .make_imm_value(sonatina_ir::Immediate::from_i256((*val).into(), *ty)), |
| 78 | + EggTerm::Add(lhs, rhs) => self.elaborate_binary::<Add>(lhs, rhs, ty), |
| 79 | + EggTerm::Sub(lhs, rhs) => self.elaborate_binary::<Sub>(lhs, rhs, ty), |
| 80 | + EggTerm::Mul(lhs, rhs) => self.elaborate_binary::<Mul>(lhs, rhs, ty), |
| 81 | + EggTerm::Udiv(lhs, rhs) => self.elaborate_binary::<Udiv>(lhs, rhs, ty), |
| 82 | + EggTerm::Sdiv(lhs, rhs) => self.elaborate_binary::<Sdiv>(lhs, rhs, ty), |
| 83 | + EggTerm::Umod(lhs, rhs) => self.elaborate_binary::<Umod>(lhs, rhs, ty), |
| 84 | + EggTerm::Smod(lhs, rhs) => self.elaborate_binary::<Smod>(lhs, rhs, ty), |
| 85 | + EggTerm::Shl(bits, val) => self.elaborate_shift::<Shl>(bits, val, ty), |
| 86 | + EggTerm::Shr(bits, val) => self.elaborate_shift::<Shr>(bits, val, ty), |
| 87 | + EggTerm::Sar(bits, val) => self.elaborate_shift::<Sar>(bits, val, ty), |
| 88 | + EggTerm::Neg(arg) => self.elaborate_unary::<Neg>(arg, ty), |
| 89 | + EggTerm::Not(arg) => self.elaborate_unary::<Not>(arg, ty), |
| 90 | + EggTerm::And(lhs, rhs) => self.elaborate_binary::<And>(lhs, rhs, ty), |
| 91 | + EggTerm::Or(lhs, rhs) => self.elaborate_binary::<Or>(lhs, rhs, ty), |
| 92 | + EggTerm::Xor(lhs, rhs) => self.elaborate_binary::<Xor>(lhs, rhs, ty), |
| 93 | + EggTerm::Lt(lhs, rhs) => self.elaborate_cmp::<Lt>(lhs, rhs, ty), |
| 94 | + EggTerm::Gt(lhs, rhs) => self.elaborate_cmp::<Gt>(lhs, rhs, ty), |
| 95 | + EggTerm::Le(lhs, rhs) => self.elaborate_cmp::<Le>(lhs, rhs, ty), |
| 96 | + EggTerm::Ge(lhs, rhs) => self.elaborate_cmp::<Ge>(lhs, rhs, ty), |
| 97 | + EggTerm::Slt(lhs, rhs) => self.elaborate_cmp::<Slt>(lhs, rhs, ty), |
| 98 | + EggTerm::Sgt(lhs, rhs) => self.elaborate_cmp::<Sgt>(lhs, rhs, ty), |
| 99 | + EggTerm::Sle(lhs, rhs) => self.elaborate_cmp::<Sle>(lhs, rhs, ty), |
| 100 | + EggTerm::Sge(lhs, rhs) => self.elaborate_cmp::<Sge>(lhs, rhs, ty), |
| 101 | + EggTerm::Eq(lhs, rhs) => self.elaborate_cmp::<Eq>(lhs, rhs, ty), |
| 102 | + EggTerm::Ne(lhs, rhs) => self.elaborate_cmp::<Ne>(lhs, rhs, ty), |
| 103 | + EggTerm::IsZero(arg) => self.elaborate_iszero(arg, ty), |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + fn elaborate_binary<I>(&mut self, lhs: &EggTerm, rhs: &EggTerm, ty: Type) -> ValueId |
| 108 | + where |
| 109 | + I: BinaryInst, |
| 110 | + { |
| 111 | + let lhs_val = self.elaborate(lhs, ty); |
| 112 | + let rhs_val = self.elaborate(rhs, ty); |
| 113 | + let is = self.func.inst_set(); |
| 114 | + let inst = I::new(is, lhs_val, rhs_val); |
| 115 | + self.make_inst_value(inst, ty) |
| 116 | + } |
| 117 | + |
| 118 | + fn elaborate_shift<I>(&mut self, bits: &EggTerm, val: &EggTerm, ty: Type) -> ValueId |
| 119 | + where |
| 120 | + I: ShiftInst, |
| 121 | + { |
| 122 | + let bits_val = self.elaborate(bits, ty); |
| 123 | + let val_val = self.elaborate(val, ty); |
| 124 | + let is = self.func.inst_set(); |
| 125 | + let inst = I::new(is, bits_val, val_val); |
| 126 | + self.make_inst_value(inst, ty) |
| 127 | + } |
| 128 | + |
| 129 | + fn elaborate_unary<I>(&mut self, arg: &EggTerm, ty: Type) -> ValueId |
| 130 | + where |
| 131 | + I: UnaryInst, |
| 132 | + { |
| 133 | + let arg_val = self.elaborate(arg, ty); |
| 134 | + let is = self.func.inst_set(); |
| 135 | + let inst = I::new(is, arg_val); |
| 136 | + self.make_inst_value(inst, ty) |
| 137 | + } |
| 138 | + |
| 139 | + fn elaborate_cmp<I>(&mut self, lhs: &EggTerm, rhs: &EggTerm, ty: Type) -> ValueId |
| 140 | + where |
| 141 | + I: BinaryInst, |
| 142 | + { |
| 143 | + let lhs_val = self.elaborate(lhs, ty); |
| 144 | + let rhs_val = self.elaborate(rhs, ty); |
| 145 | + let is = self.func.inst_set(); |
| 146 | + let inst = I::new(is, lhs_val, rhs_val); |
| 147 | + self.make_inst_value(inst, Type::I1) |
| 148 | + } |
| 149 | + |
| 150 | + fn elaborate_iszero(&mut self, arg: &EggTerm, ty: Type) -> ValueId { |
| 151 | + let arg_val = self.elaborate(arg, ty); |
| 152 | + let is = self.func.inst_set(); |
| 153 | + let inst = IsZero::new(is.has_is_zero().unwrap(), arg_val); |
| 154 | + self.make_inst_value(inst, Type::I1) |
| 155 | + } |
| 156 | + |
| 157 | + fn make_inst_value<I: sonatina_ir::Inst>(&mut self, inst: I, ty: Type) -> ValueId { |
| 158 | + let inst_id = self.func.dfg.make_inst(inst); |
| 159 | + let value = Value::Inst { inst: inst_id, ty }; |
| 160 | + let value_id = self.func.dfg.make_value(value); |
| 161 | + self.func.dfg.attach_result(inst_id, value_id); |
| 162 | + self.func.layout.append_inst(inst_id, self.block); |
| 163 | + value_id |
| 164 | + } |
| 165 | +} |
| 166 | + |
| 167 | +/// Trait for binary instructions that can be constructed. |
| 168 | +trait BinaryInst: sonatina_ir::Inst + Sized { |
| 169 | + fn new(is: &dyn sonatina_ir::InstSetBase, lhs: ValueId, rhs: ValueId) -> Self; |
| 170 | +} |
| 171 | + |
| 172 | +/// Trait for shift instructions. |
| 173 | +trait ShiftInst: sonatina_ir::Inst + Sized { |
| 174 | + fn new(is: &dyn sonatina_ir::InstSetBase, bits: ValueId, value: ValueId) -> Self; |
| 175 | +} |
| 176 | + |
| 177 | +/// Trait for unary instructions. |
| 178 | +trait UnaryInst: sonatina_ir::Inst + Sized { |
| 179 | + fn new(is: &dyn sonatina_ir::InstSetBase, arg: ValueId) -> Self; |
| 180 | +} |
| 181 | + |
| 182 | +// Implement BinaryInst for binary ops |
| 183 | +macro_rules! impl_binary { |
| 184 | + ($ty:ty, $has:ident) => { |
| 185 | + impl BinaryInst for $ty { |
| 186 | + fn new(is: &dyn sonatina_ir::InstSetBase, lhs: ValueId, rhs: ValueId) -> Self { |
| 187 | + <$ty>::new(is.$has().unwrap(), lhs, rhs) |
| 188 | + } |
| 189 | + } |
| 190 | + }; |
| 191 | +} |
| 192 | + |
| 193 | +impl_binary!(Add, has_add); |
| 194 | +impl_binary!(Sub, has_sub); |
| 195 | +impl_binary!(Mul, has_mul); |
| 196 | +impl_binary!(Udiv, has_udiv); |
| 197 | +impl_binary!(Sdiv, has_sdiv); |
| 198 | +impl_binary!(Umod, has_umod); |
| 199 | +impl_binary!(Smod, has_smod); |
| 200 | +impl_binary!(And, has_and); |
| 201 | +impl_binary!(Or, has_or); |
| 202 | +impl_binary!(Xor, has_xor); |
| 203 | +impl_binary!(Lt, has_lt); |
| 204 | +impl_binary!(Gt, has_gt); |
| 205 | +impl_binary!(Le, has_le); |
| 206 | +impl_binary!(Ge, has_ge); |
| 207 | +impl_binary!(Slt, has_slt); |
| 208 | +impl_binary!(Sgt, has_sgt); |
| 209 | +impl_binary!(Sle, has_sle); |
| 210 | +impl_binary!(Sge, has_sge); |
| 211 | +impl_binary!(Eq, has_eq); |
| 212 | +impl_binary!(Ne, has_ne); |
| 213 | + |
| 214 | +// Implement ShiftInst for shift ops |
| 215 | +macro_rules! impl_shift { |
| 216 | + ($ty:ty, $has:ident) => { |
| 217 | + impl ShiftInst for $ty { |
| 218 | + fn new(is: &dyn sonatina_ir::InstSetBase, bits: ValueId, value: ValueId) -> Self { |
| 219 | + <$ty>::new(is.$has().unwrap(), bits, value) |
| 220 | + } |
| 221 | + } |
| 222 | + }; |
| 223 | +} |
| 224 | + |
| 225 | +impl_shift!(Shl, has_shl); |
| 226 | +impl_shift!(Shr, has_shr); |
| 227 | +impl_shift!(Sar, has_sar); |
| 228 | + |
| 229 | +// Implement UnaryInst for unary ops |
| 230 | +macro_rules! impl_unary { |
| 231 | + ($ty:ty, $has:ident) => { |
| 232 | + impl UnaryInst for $ty { |
| 233 | + fn new(is: &dyn sonatina_ir::InstSetBase, arg: ValueId) -> Self { |
| 234 | + <$ty>::new(is.$has().unwrap(), arg) |
| 235 | + } |
| 236 | + } |
| 237 | + }; |
| 238 | +} |
| 239 | + |
| 240 | +impl_unary!(Neg, has_neg); |
| 241 | +impl_unary!(Not, has_not); |
0 commit comments