Skip to content

Commit 2ebec04

Browse files
authored
Merge pull request #8 from herbie-fp/unsound
Avoid unsoundness
2 parents 42c526b + 14cdcd8 commit 2ebec04

File tree

3 files changed

+46
-14
lines changed

3 files changed

+46
-14
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ edition = "2018"
66

77

88
[dependencies]
9-
egg = "0.5"
9+
egg = "0.6"
10+
1011
log = "0.4"
1112
indexmap = "1"
1213
libc = "0.2.71"

src/lib.rs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
pub mod math;
22
pub mod rules;
33

4+
use egg::Id;
45
use math::*;
56

67
use std::ffi::{CStr, CString};
78
use std::os::raw::c_char;
8-
use std::slice;
9+
use std::{slice, sync::atomic::Ordering};
910

1011
unsafe fn cstring_to_recexpr(c_string: *const c_char) -> Option<RecExpr> {
1112
match CStr::from_ptr(c_string).to_str() {
@@ -79,6 +80,7 @@ pub unsafe extern "C" fn egraph_add_expr(
7980
expr: *const c_char,
8081
) -> *mut EGraphAddResult {
8182
ffirun(|| {
83+
let _ = env_logger::try_init();
8284
let ctx = &mut *ptr;
8385
let mut runner = ctx
8486
.runner
@@ -95,6 +97,7 @@ pub unsafe extern "C" fn egraph_add_expr(
9597
Some(rec_expr) => {
9698
runner = runner.with_expr(&rec_expr);
9799
let id = *runner.roots.last().unwrap();
100+
let id = usize::from(id) as u32;
98101
EGraphAddResult { id, successp: true }
99102
}
100103
};
@@ -151,24 +154,38 @@ pub unsafe extern "C" fn egraph_run_iter(
151154
let rules: Vec<Rewrite> = rules::mk_rules(&ffi_tuples);
152155

153156
runner.egraph.analysis.constant_fold = is_constant_folding_enabled;
154-
runner = runner.with_node_limit(limit as usize).run(&rules);
157+
runner = runner
158+
.with_node_limit(limit as usize)
159+
.with_hook(|r| {
160+
if r.egraph.analysis.unsound.load(Ordering::SeqCst) {
161+
Err("Unsoundness detected".into())
162+
} else {
163+
Ok(())
164+
}
165+
})
166+
.run(&rules);
155167
}
156168
ctx.runner = Some(runner);
157169
})
158170
}
159171

160172
fn find_extracted(runner: &Runner, id: u32) -> &Extracted {
161-
let id = runner.egraph.find(id);
162-
let iter = runner
163-
.iterations
164-
.last()
165-
.expect("There should be some iterations by now!");
166-
iter.data
173+
let id = runner.egraph.find(Id::from(id as usize));
174+
let desired_iter = if runner.egraph.analysis.unsound.load(Ordering::SeqCst) {
175+
// go back one more iter, add egg can duplicate the final iter in the case of an error
176+
runner.iterations.len().saturating_sub(3)
177+
} else {
178+
runner.iterations.len().saturating_sub(1)
179+
};
180+
181+
runner.iterations[desired_iter]
182+
.data
167183
.extracted
168184
.iter()
169185
.find(|(i, _)| runner.egraph.find(*i) == id)
170186
.map(|(_, ext)| ext)
171187
.expect("Couldn't find matching extraction!")
188+
.clone()
172189
}
173190

174191
#[no_mangle]

src/math.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use egg::*;
22

3+
use std::sync::atomic::{AtomicBool, Ordering};
4+
35
use num_bigint::BigInt;
46
use num_rational::Ratio;
57
use num_traits::{Pow, Signed, Zero};
@@ -144,6 +146,7 @@ define_language! {
144146
}
145147

146148
pub struct ConstantFold {
149+
pub unsound: AtomicBool,
147150
pub constant_fold: bool,
148151
pub prune: bool,
149152
}
@@ -153,6 +156,7 @@ impl Default for ConstantFold {
153156
Self {
154157
constant_fold: true,
155158
prune: true,
159+
unsound: AtomicBool::from(false),
156160
}
157161
}
158162
}
@@ -210,11 +214,21 @@ impl Analysis<Math> for ConstantFold {
210214
}
211215

212216
fn merge(&self, to: &mut Self::Data, from: Self::Data) -> bool {
213-
if to.is_none() && from.is_some() {
214-
*to = from;
215-
true
216-
} else {
217-
false
217+
match (&to, from) {
218+
(None, None) => false,
219+
(Some(_), None) => false, // no update needed
220+
(None, Some(c)) => {
221+
*to = Some(c);
222+
true
223+
}
224+
(Some(a), Some(ref b)) => {
225+
if a != b {
226+
if !self.unsound.swap(true, Ordering::SeqCst) {
227+
log::warn!("Bad merge detected: {} != {}", a, b);
228+
}
229+
}
230+
false
231+
}
218232
}
219233
}
220234

0 commit comments

Comments
 (0)