diff --git a/cranelift/codegen/src/isa/aarch64/inst.isle b/cranelift/codegen/src/isa/aarch64/inst.isle index d48c2f63fd0b..d851e2c2b3f3 100644 --- a/cranelift/codegen/src/isa/aarch64/inst.isle +++ b/cranelift/codegen/src/isa/aarch64/inst.isle @@ -2474,7 +2474,7 @@ ;; Helper for emitting `MInst.FpuCSel16` / `MInst.FpuCSel32` / `MInst.FpuCSel64` ;; instructions. -(decl fpu_csel (Type Cond Reg Reg) ConsumesFlags) +(decl rec fpu_csel (Type Cond Reg Reg) ConsumesFlags) (rule (fpu_csel $F16 cond if_true if_false) (fpu_csel $F32 cond if_true if_false)) @@ -2526,7 +2526,7 @@ ;; Helper for emitting `MInst.MovToFpu` instructions. (spec (mov_to_fpu x s) (provide (= result (zero_ext 64 (conv_to s x))))) -(decl mov_to_fpu (Reg ScalarSize) Reg) +(decl rec mov_to_fpu (Reg ScalarSize) Reg) (rule (mov_to_fpu x size) (let ((dst WritableReg (temp_writable_reg $I8X16)) (_ Unit (emit (MInst.MovToFpu dst x size)))) @@ -4017,7 +4017,7 @@ ;; Note that we must make sure that all bits outside the lowest 16 are set to 0 ;; because this function is also used to load wider constants (that have zeros ;; in their most significant bits). -(decl constant_f16 (u16) Reg) +(decl rec constant_f16 (u16) Reg) (rule 3 (constant_f16 n) (if-let false (use_fp16)) (constant_f32 n)) @@ -4036,7 +4036,7 @@ ;; Note that we must make sure that all bits outside the lowest 32 are set to 0 ;; because this function is also used to load wider constants (that have zeros ;; in their most significant bits). -(decl constant_f32 (u32) Reg) +(decl rec constant_f32 (u32) Reg) (rule 3 (constant_f32 0) (vec_dup_imm (asimd_mov_mod_imm_zero (ScalarSize.Size32)) false @@ -4099,7 +4099,7 @@ ;; ;; The 64-bit input here only uses the low bits for the lane size in ;; `VectorSize` and all other bits are ignored. -(decl splat_const (u64 VectorSize) Reg) +(decl rec splat_const (u64 VectorSize) Reg) ;; If the splat'd constant can itself be reduced in size then attempt to do so ;; as it will make it easier to create the immediates in the instructions below. @@ -4956,7 +4956,7 @@ (MInst.CSel dst (Cond.Eq) tmp1 tmp2) (value_reg dst)))) -(decl lower_bmask (Type Type ValueRegs) ValueRegs) +(decl rec lower_bmask (Type Type ValueRegs) ValueRegs) ;; For conversions that exactly fit a register, we can use csetm. diff --git a/cranelift/codegen/src/isa/pulley_shared/lower.isle b/cranelift/codegen/src/isa/pulley_shared/lower.isle index b9a9c6d02055..0dda1809eebf 100644 --- a/cranelift/codegen/src/isa/pulley_shared/lower.isle +++ b/cranelift/codegen/src/isa/pulley_shared/lower.isle @@ -11,7 +11,7 @@ ;; needs to handle situations such as when the `Value` is 64-bits an explicit ;; comparison must be made. Additionally if `Value` is smaller than 32-bits ;; then it must be sign-extended up to at least 32 bits. -(decl lower_cond (Value) Cond) +(decl rec lower_cond (Value) Cond) (rule 0 (lower_cond val @ (value_type (fits_in_32 _))) (Cond.If32 (zext32 val))) (rule 1 (lower_cond val @ (value_type $I64)) (Cond.IfXneq64I32 val 0)) @@ -737,7 +737,7 @@ (rule (lower (icmp cc a b @ (value_type (ty_int ty)))) (lower_icmp ty cc a b)) -(decl lower_icmp (Type IntCC Value Value) XReg) +(decl rec lower_icmp (Type IntCC Value Value) XReg) (rule (lower_icmp $I64 (IntCC.Equal) a b) (pulley_xeq64 a b)) @@ -846,7 +846,7 @@ (rule 1 (lower (icmp cc a @ (value_type (ty_vec128 ty)) b)) (lower_vcmp ty cc a b)) -(decl lower_vcmp (Type IntCC Value Value) VReg) +(decl rec lower_vcmp (Type IntCC Value Value) VReg) (rule (lower_vcmp $I8X16 (IntCC.Equal) a b) (pulley_veq8x16 a b)) (rule (lower_vcmp $I8X16 (IntCC.NotEqual) a b) (pulley_vneq8x16 a b)) (rule (lower_vcmp $I8X16 (IntCC.SignedLessThan) a b) (pulley_vslt8x16 a b)) @@ -890,7 +890,7 @@ (rule 1 (lower (fcmp cc a b @ (value_type (ty_vec128 ty)))) (lower_vfcmp ty cc a b)) -(decl lower_fcmp (Type FloatCC Value Value) XReg) +(decl rec lower_fcmp (Type FloatCC Value Value) XReg) (rule (lower_fcmp $F32 (FloatCC.Equal) a b) (pulley_feq32 a b)) (rule (lower_fcmp $F64 (FloatCC.Equal) a b) (pulley_feq64 a b)) @@ -921,7 +921,7 @@ (if-let true (floatcc_unordered cc)) (pulley_xbxor32_s8 (lower_fcmp ty (floatcc_complement cc) a b) 1)) -(decl lower_vfcmp (Type FloatCC Value Value) VReg) +(decl rec lower_vfcmp (Type FloatCC Value Value) VReg) (rule (lower_vfcmp $F32X4 (FloatCC.Equal) a b) (pulley_veqf32x4 a b)) (rule (lower_vfcmp $F64X2 (FloatCC.Equal) a b) (pulley_veqf64x2 a b)) diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index fa17ae74d9ae..684bf68ecb5d 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -1873,7 +1873,7 @@ ;; Immediate Loading rules ;; TODO: Loading the zero reg directly causes a bunch of regalloc errors, we should look into it. ;; TODO: Load floats using `fld` instead of `ld` -(decl imm (Type u64) Reg) +(decl rec imm (Type u64) Reg) ;; Special-case 0.0 for floats to use the `(zero_reg)` directly. ;; See #7162 for why this doesn't fall out of the rules below. @@ -2470,7 +2470,7 @@ (rule 0 (load_op_reg_type _) $I64) ;; Helper constructor to build a load instruction. -(decl gen_load (AMode LoadOP MemFlags) Reg) +(decl rec gen_load (AMode LoadOP MemFlags) Reg) (rule (gen_load amode op flags) (let ((dst WritableReg (temp_writable_reg (load_op_reg_type op))) (_ Unit (emit (MInst.Load dst op flags amode)))) @@ -2661,7 +2661,7 @@ (decl gen_stack_addr (StackSlot Offset32) Reg) (extern constructor gen_stack_addr gen_stack_addr) -(decl gen_select_xreg (IntegerCompare XReg XReg) XReg) +(decl rec gen_select_xreg (IntegerCompare XReg XReg) XReg) (rule 6 (gen_select_xreg (int_compare_decompose cc x y) x y) (if-let (IntCC.UnsignedLessThan) (intcc_without_eq cc)) @@ -2994,7 +2994,7 @@ ;; Generates a bitcast instruction. ;; Args are: src, src_ty, dst_ty -(decl gen_bitcast (Reg Type Type) Reg) +(decl rec gen_bitcast (Reg Type Type) Reg) (rule 9 (gen_bitcast r (ty_supported_float_size $F16) (ty_supported_vec _)) (if-let false (has_zvfh)) (rv_vfmv_sf r (vstate_from_type $F32))) (rule 8 (gen_bitcast r (ty_supported_vec ty) (ty_supported_float_size $F16)) (if-let false (has_zvfh)) (gen_bitcast (gen_bitcast r ty $I16) $I16 $F16)) @@ -3214,7 +3214,7 @@ (convert FloatCompare IntegerCompare float_to_int_compare) ;; Compare two floating point numbers and return a zero/non-zero result. -(decl fcmp_to_float_compare (FloatCC Type FReg FReg) FloatCompare) +(decl rec fcmp_to_float_compare (FloatCC Type FReg FReg) FloatCompare) ;; Direct codegen for unordered comparisons is not that efficient, so invert ;; the comparison to get an ordered comparison and generate that. Then invert diff --git a/cranelift/codegen/src/isa/riscv64/inst_vector.isle b/cranelift/codegen/src/isa/riscv64/inst_vector.isle index e0de755fccde..d7c9808c5571 100644 --- a/cranelift/codegen/src/isa/riscv64/inst_vector.isle +++ b/cranelift/codegen/src/isa/riscv64/inst_vector.isle @@ -1501,7 +1501,7 @@ ;;;; Multi-Instruction Helpers ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(decl gen_extractlane (Type VReg u8) Reg) +(decl rec gen_extractlane (Type VReg u8) Reg) ;; When extracting lane 0 for floats, we can use `vfmv.f.s` directly. (rule 3 (gen_extractlane (ty_vec_fits_in_register ty) src 0) @@ -1731,7 +1731,7 @@ ;; Builds a vector mask corresponding to the FloatCC operation. -(decl gen_fcmp_mask (Type FloatCC Value Value) VReg) +(decl rec gen_fcmp_mask (Type FloatCC Value Value) VReg) ;; FloatCC.Equal diff --git a/cranelift/codegen/src/isa/riscv64/lower.isle b/cranelift/codegen/src/isa/riscv64/lower.isle index d1a759dcba2a..f7fa4a297d01 100644 --- a/cranelift/codegen/src/isa/riscv64/lower.isle +++ b/cranelift/codegen/src/isa/riscv64/lower.isle @@ -1044,7 +1044,7 @@ ;; Constructs a sequence of instructions that reverse all bits in `x` up to ;; the given type width. -(decl gen_bitrev (Type XReg) XReg) +(decl rec gen_bitrev (Type XReg) XReg) (rule 0 (gen_bitrev (ty_16_or_32 (ty_int ty)) x) (if-let shift_amt (u64_to_imm12 (u64_wrapping_sub 64 (ty_bits ty)))) @@ -1069,7 +1069,7 @@ ;; Builds a sequence of instructions that swaps the bytes in `x` up to the given ;; type width. -(decl gen_bswap (Type XReg) XReg) +(decl rec gen_bswap (Type XReg) XReg) ;; This is only here to make the rule below work. bswap.i8 isn't valid (rule 0 (gen_bswap $I8 x) x) @@ -2263,7 +2263,7 @@ (rule 0 (lower (icmp cc x @ (value_type (fits_in_64 ty)) y)) (lower_icmp cc x y)) -(decl lower_icmp (IntCC Value Value) XReg) +(decl rec lower_icmp (IntCC Value Value) XReg) (rule 0 (lower_icmp cc x y) (lower_int_compare (icmp_to_int_compare cc x y))) @@ -2352,7 +2352,7 @@ (rule 20 (lower (icmp cc x @ (value_type $I128) y)) (lower_icmp_i128 cc x y)) -(decl lower_icmp_i128 (IntCC ValueRegs ValueRegs) XReg) +(decl rec lower_icmp_i128 (IntCC ValueRegs ValueRegs) XReg) (rule 0 (lower_icmp_i128 (IntCC.Equal) x y) (let ((lo XReg (rv_xor (value_regs_get x 0) (value_regs_get y 0))) (hi XReg (rv_xor (value_regs_get x 1) (value_regs_get y 1)))) diff --git a/cranelift/codegen/src/isa/s390x/inst.isle b/cranelift/codegen/src/isa/s390x/inst.isle index d9dd041e2059..3a0a800bd75d 100644 --- a/cranelift/codegen/src/isa/s390x/inst.isle +++ b/cranelift/codegen/src/isa/s390x/inst.isle @@ -2884,7 +2884,7 @@ ;; Helpers for generating immediate values ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Allocate a temporary register, initialized with an immediate. -(decl imm (Type u64) Reg) +(decl rec imm (Type u64) Reg) ;; 16-bit (or smaller) result type, any value (rule 7 (imm (fits_in_16 (ty_int ty)) n) @@ -2986,7 +2986,7 @@ (vec_load ty (memarg_const (emit_u128_be_const n)))) ;; Variant with replicated immediate. -(decl vec_imm_splat (Type u64) Reg) +(decl rec vec_imm_splat (Type u64) Reg) (rule 1 (vec_imm_splat (ty_vec128 ty) 0) (vec_imm_byte_mask ty 0)) (rule 2 (vec_imm_splat ty @ (multi_lane 8 _) n) @@ -3289,7 +3289,7 @@ (rule (lower_bool $I8 cond) (select_bool_imm $I8 cond 1 0)) ;; Lower a boolean condition to the values -1/0. -(decl lower_bool_to_mask (Type ProducesBool) Reg) +(decl rec lower_bool_to_mask (Type ProducesBool) Reg) (rule 0 (lower_bool_to_mask (fits_in_64 ty) producer) (select_bool_imm ty producer -1 0)) diff --git a/cranelift/codegen/src/isa/x64/inst.isle b/cranelift/codegen/src/isa/x64/inst.isle index b8e1fa8481c4..3185592c7c23 100644 --- a/cranelift/codegen/src/isa/x64/inst.isle +++ b/cranelift/codegen/src/isa/x64/inst.isle @@ -1929,7 +1929,7 @@ ;; ;; Note that if `Type` is less than 64-bits then the upper bits of the `imm` ;; argument will be set to zero and lost. -(decl imm (Type u64) Reg) +(decl rec imm (Type u64) Reg) ;; Base case: integers of up to at most 32-bits. ;; @@ -3346,7 +3346,7 @@ (ConsumesFlags.ConsumesFlagsSideEffect (MInst.JmpCondOr cc1 cc2 taken not_taken))) ;; Conditional jump based on a `CondResult` -(decl jmp_cond_result (CondResult MachLabel MachLabel) SideEffectNoResult) +(decl rec jmp_cond_result (CondResult MachLabel MachLabel) SideEffectNoResult) (rule (jmp_cond_result (CondResult.CC producer cc) taken not_taken) (with_flags_side_effect producer (jmp_cond cc taken not_taken))) (rule (jmp_cond_result cond @ (CondResult.And _ _ _) taken not_taken) @@ -3549,7 +3549,7 @@ (rule 5 (emit_cmp (IntCC.NotEqual) a (u64_from_iconst 0)) (is_nonzero a)) (rule 6 (emit_cmp (IntCC.NotEqual) (u64_from_iconst 0) a) (is_nonzero a)) -(decl emit_cmp_i128 (CC Gpr Gpr Gpr Gpr) CondResult) +(decl rec emit_cmp_i128 (CC Gpr Gpr Gpr Gpr) CondResult) ;; Eliminate cases which compare something "or equal" by swapping arguments. (rule 2 (emit_cmp_i128 (CC.NLE) a_hi a_lo b_hi b_lo) (emit_cmp_i128 (CC.L) b_hi b_lo a_hi a_lo)) diff --git a/cranelift/codegen/src/isa/x64/lower.isle b/cranelift/codegen/src/isa/x64/lower.isle index b9c7aa76e77f..c462ce1f907b 100644 --- a/cranelift/codegen/src/isa/x64/lower.isle +++ b/cranelift/codegen/src/isa/x64/lower.isle @@ -628,7 +628,7 @@ ;; Get the address of the mask to use when fixing up the lanes that weren't ;; correctly generated by the 16x8 shift. -(decl ishl_i8x16_mask (RegMemImm) SyntheticAmode) +(decl rec ishl_i8x16_mask (RegMemImm) SyntheticAmode) ;; When the shift amount is known, we can statically (i.e. at compile time) ;; determine the mask to use and only emit that. @@ -732,7 +732,7 @@ ;; Get the address of the mask to use when fixing up the lanes that weren't ;; correctly generated by the 16x8 shift. -(decl ushr_i8x16_mask (RegMemImm) SyntheticAmode) +(decl rec ushr_i8x16_mask (RegMemImm) SyntheticAmode) ;; When the shift amount is known, we can statically (i.e. at compile time) ;; determine the mask to use and only emit that. @@ -1422,7 +1422,7 @@ ;;;; Rules for `bmask` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(decl lower_bmask (Type Type ValueRegs) ValueRegs) +(decl rec lower_bmask (Type Type ValueRegs) ValueRegs) ;; Values that fit in a register ;; @@ -2178,7 +2178,7 @@ (rule (lower (select cond x y)) (lower_select (is_nonzero_cmp cond) x y)) -(decl lower_select (CondResult Value Value) InstOutput) +(decl rec lower_select (CondResult Value Value) InstOutput) (rule 0 (lower_select cond a @ (value_type (ty_int (fits_in_64 ty))) b) (lower_select_gpr ty cond a b)) (rule 1 (lower_select cond a @ (value_type (is_xmm_type ty)) b) @@ -4276,7 +4276,7 @@ ;; Emits either a `round{ss,sd,ps,pd}` instruction, as appropriate, or generates ;; the appropriate libcall and sequence to call that. -(decl x64_round (Type RegMem RoundImm) Xmm) +(decl rec x64_round (Type RegMem RoundImm) Xmm) (rule 1 (x64_round $F32 a imm) (if-let true (has_sse41)) (x64_roundss a imm)) @@ -4683,7 +4683,7 @@ ;; performant thing in the world so this is primarily here for completeness ;; of lowerings on all x86 cpus but if rules are ideally gated on the presence ;; of SSSE3 to use the `pshufb` instruction itself. -(decl lower_pshufb (Xmm RegMem) Xmm) +(decl rec lower_pshufb (Xmm RegMem) Xmm) (rule 1 (lower_pshufb src mask) (if-let true (has_ssse3)) (x64_pshufb src mask)) diff --git a/cranelift/codegen/src/prelude_opt.isle b/cranelift/codegen/src/prelude_opt.isle index e9b9dcdc0d6e..dcb1bebc3598 100644 --- a/cranelift/codegen/src/prelude_opt.isle +++ b/cranelift/codegen/src/prelude_opt.isle @@ -131,7 +131,7 @@ ;; so that `iconst.i8 255` will give you a `-1_i64`. ;; When constructing, the rule will fail if the value cannot be represented in ;; the target type. If it fits, it'll be masked accordingly in the constant. -(decl iconst_s (Type i64) Value) +(decl rec iconst_s (Type i64) Value) (extractor (iconst_s ty c) (inst_data_value_tupled (iconst_sextend_etor ty c))) (rule 0 (iconst_s ty c) (if-let c_masked (u64_and (i64_cast_unsigned c) @@ -147,7 +147,7 @@ ;; so that `iconst.i8 255` will give you a `255_u64`. ;; When constructing, the rule will fail if the value cannot be represented in ;; the target type. -(decl iconst_u (Type u64) Value) +(decl rec iconst_u (Type u64) Value) (extractor (iconst_u ty c) (iconst ty (u64_from_imm64 c))) (rule 0 (iconst_u ty c) (if-let true (u64_lt_eq c (ty_umax ty))) diff --git a/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle b/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle index b63de1ab152b..0040f65b9bad 100644 --- a/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle +++ b/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle @@ -62,7 +62,7 @@ ;; One step in amode processing: take an existing amode and add ;; another value to it. -(decl amode_add (Amode Value) Amode) +(decl rec amode_add (Amode Value) Amode) ;; -- Top-level driver: pull apart the addends. ;; diff --git a/cranelift/isle/isle/src/ast.rs b/cranelift/isle/isle/src/ast.rs index 6fbff36b2daf..ada5747348dd 100644 --- a/cranelift/isle/isle/src/ast.rs +++ b/cranelift/isle/isle/src/ast.rs @@ -80,6 +80,8 @@ pub struct Decl { pub multi: bool, /// Whether this term's constructor can fail to match. pub partial: bool, + /// Whether this term is recursive. + pub rec: bool, pub pos: Pos, } diff --git a/cranelift/isle/isle/src/compile.rs b/cranelift/isle/isle/src/compile.rs index 01a02e1f7aba..a27c961cd224 100644 --- a/cranelift/isle/isle/src/compile.rs +++ b/cranelift/isle/isle/src/compile.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use crate::ast::Def; use crate::error::Errors; use crate::files::Files; -use crate::{ast, codegen, overlap, sema}; +use crate::{ast, codegen, overlap, recursion, sema}; /// Compile the given AST definitions into Rust source code. pub fn compile( @@ -26,6 +26,7 @@ pub fn compile( Ok(terms) => terms, Err(errs) => return Err(Errors::new(errs, files)), }; + recursion::check(&terms, &term_env).map_err(|errs| Errors::new(errs, files.clone()))?; Ok(codegen::codegen( files, &type_env, &term_env, &terms, options, diff --git a/cranelift/isle/isle/src/error.rs b/cranelift/isle/isle/src/error.rs index 250a6db369fe..bb316e0146bf 100644 --- a/cranelift/isle/isle/src/error.rs +++ b/cranelift/isle/isle/src/error.rs @@ -23,6 +23,7 @@ impl std::fmt::Debug for Errors { Error::TypeError { msg, .. } => format!("type error: {msg}"), Error::UnreachableError { msg, .. } => format!("unreachable rule: {msg}"), Error::OverlapError { msg, .. } => format!("overlap error: {msg}"), + Error::RecursionError { msg, .. } => format!("recursion error: {msg}"), Error::ShadowedError { .. } => { "more general higher-priority rule shadows other rules".to_string() } @@ -33,7 +34,8 @@ impl std::fmt::Debug for Errors { Error::ParseError { span, .. } | Error::TypeError { span, .. } - | Error::UnreachableError { span, .. } => { + | Error::UnreachableError { span, .. } + | Error::RecursionError { span, .. } => { vec![Label::primary(span.from.file, span)] } @@ -127,6 +129,15 @@ pub enum Error { rules: Vec, }, + /// Recurive rules error. Term is recursive without explicit opt-in, or vice versa. + RecursionError { + /// The error message. + msg: String, + + /// The location of the term declaration. + span: Span, + }, + /// The rules can never match because another rule will always match first. ShadowedError { /// The locations of the unmatchable rules. diff --git a/cranelift/isle/isle/src/lib.rs b/cranelift/isle/isle/src/lib.rs index 1ccc6ae9b207..369b1b815b85 100644 --- a/cranelift/isle/isle/src/lib.rs +++ b/cranelift/isle/isle/src/lib.rs @@ -29,6 +29,7 @@ mod log; pub mod overlap; pub mod parser; pub mod printer; +pub mod recursion; pub mod sema; pub mod serialize; pub mod stablemapset; diff --git a/cranelift/isle/isle/src/parser.rs b/cranelift/isle/isle/src/parser.rs index 949bafaf4884..6583c69aa6ed 100644 --- a/cranelift/isle/isle/src/parser.rs +++ b/cranelift/isle/isle/src/parser.rs @@ -336,6 +336,7 @@ impl<'a> Parser<'a> { let pure = self.eat_sym_str("pure")?; let multi = self.eat_sym_str("multi")?; let partial = self.eat_sym_str("partial")?; + let rec = self.eat_sym_str("rec")?; let term = self.parse_ident()?; @@ -355,6 +356,7 @@ impl<'a> Parser<'a> { pure, multi, partial, + rec, pos, }) } diff --git a/cranelift/isle/isle/src/printer.rs b/cranelift/isle/isle/src/printer.rs index b511f1f2faec..7159c44c3337 100644 --- a/cranelift/isle/isle/src/printer.rs +++ b/cranelift/isle/isle/src/printer.rs @@ -255,6 +255,7 @@ impl ToSExpr for Decl { pure, multi, partial, + rec, pos: _, } = self; let mut parts = vec![SExpr::atom("decl")]; @@ -267,6 +268,9 @@ impl ToSExpr for Decl { if *partial { parts.push(SExpr::atom("partial")); } + if *rec { + parts.push(SExpr::atom("rec")); + } parts.push(term.to_sexpr()); parts.push(SExpr::list(arg_tys)); parts.push(ret_ty.to_sexpr()); diff --git a/cranelift/isle/isle/src/recursion.rs b/cranelift/isle/isle/src/recursion.rs new file mode 100644 index 000000000000..f7ea38da8d6d --- /dev/null +++ b/cranelift/isle/isle/src/recursion.rs @@ -0,0 +1,92 @@ +//! Recursion checking for ISLE terms. + +use std::collections::{HashMap, HashSet}; + +use crate::{ + error::{Error, Span}, + sema::{TermEnv, TermId}, + trie_again::{Binding, RuleSet}, +}; + +/// Check for recursive terms. +pub fn check(terms: &[(TermId, RuleSet)], termenv: &TermEnv) -> Result<(), Vec> { + let term_rule_sets: HashMap = terms + .iter() + .map(|(term_id, rule_set)| (*term_id, rule_set)) + .collect(); + + let mut errors = Vec::new(); + for (term_id, _) in terms { + // Check if this term is involved in a reference cycle. + let reachable = terms_reachable_from(*term_id, &term_rule_sets); + let is_cyclic = reachable.contains(term_id); + + // Lookup if this term is explicitly marked recursive. + let term = &termenv.terms[term_id.index()]; + let is_marked_recursive = term.is_recursive(); + + // Require the two to agree. + match (is_cyclic, is_marked_recursive) { + (true, true) | (false, false) => {} + (true, false) => { + errors.push(Error::RecursionError { + msg: "Term is recursive but does not have the `rec` attribute".to_string(), + span: Span::new_single(term.decl_pos), + }); + } + (false, true) => { + errors.push(Error::RecursionError { + msg: "Term has the `rec` attribute but is not recursive".to_string(), + span: Span::new_single(term.decl_pos), + }); + } + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } +} + +/// Search for all terms reachable from the source. +fn terms_reachable_from( + source: TermId, + term_rule_sets: &HashMap, +) -> HashSet { + let mut reachable = HashSet::new(); + let mut stack = vec![source]; + + while let Some(term_id) = stack.pop() { + if !term_rule_sets.contains_key(&term_id) { + continue; + } + + let used = terms_in_rule_set(&term_rule_sets[&term_id]); + for used_term_id in used { + if reachable.contains(&used_term_id) { + continue; + } + reachable.insert(used_term_id); + stack.push(used_term_id); + } + } + + reachable +} + +fn terms_in_rule_set(rule_set: &RuleSet) -> HashSet { + rule_set + .bindings + .iter() + .filter_map(binding_used_term) + .collect() +} + +fn binding_used_term(binding: &Binding) -> Option { + match binding { + Binding::Constructor { term, .. } | Binding::Extractor { term, .. } => Some(*term), + _ => None, + } +} diff --git a/cranelift/isle/isle/src/sema.rs b/cranelift/isle/isle/src/sema.rs index e910ffa290ce..14c322a4e975 100644 --- a/cranelift/isle/isle/src/sema.rs +++ b/cranelift/isle/isle/src/sema.rs @@ -389,6 +389,8 @@ pub struct TermFlags { pub multi: bool, /// Whether the term is marked as `partial`. pub partial: bool, + /// Whether the term is marked as `rec`. + pub rec: bool, } impl TermFlags { @@ -516,6 +518,17 @@ impl Term { ) } + /// Is this term marked as recursive? + pub fn is_recursive(&self) -> bool { + matches!( + self.kind, + TermKind::Decl { + flags: TermFlags { rec: true, .. }, + .. + } + ) + } + /// Does this term have a constructor? pub fn has_constructor(&self) -> bool { matches!( @@ -903,6 +916,7 @@ pub trait ExprVisitor { pure: bool, infallible: bool, multi: bool, + rec: bool, ) -> Self::ExprId; } @@ -967,6 +981,7 @@ impl Expr { flags.pure, /* infallible = */ !flags.partial, flags.multi, + flags.rec, ) } TermKind::Decl { @@ -1463,6 +1478,7 @@ impl TermEnv { pure: decl.pure, multi: decl.multi, partial: decl.partial, + rec: decl.rec, }; self.terms.push(Term { id: tid, diff --git a/cranelift/isle/isle/src/trie_again.rs b/cranelift/isle/isle/src/trie_again.rs index bf0fd976c797..1ffb578106c2 100644 --- a/cranelift/isle/isle/src/trie_again.rs +++ b/cranelift/isle/isle/src/trie_again.rs @@ -649,6 +649,7 @@ impl sema::ExprVisitor for RuleSetBuilder { pure: bool, infallible: bool, multi: bool, + _rec: bool, ) -> BindingId { let instance = if pure { 0