Skip to content

Commit aca6619

Browse files
authored
Merge pull request #457 from herbie-fp/remove-egraph-addresults
Remove `EGraphAddResult`
2 parents f16b201 + 4b23437 commit aca6619

File tree

5 files changed

+59
-111
lines changed

5 files changed

+59
-111
lines changed

egg-herbie/egg-interface.rkt

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
racket/runtime-path)
66

77
(provide egraph_create egraph_destroy egraph_add_expr
8-
egraph_addresult_destroy egraph_run egraph_run_with_iter_limit
8+
egraph_run egraph_run_with_iter_limit
99
egraph_get_stop_reason
1010
egraph_get_simplest egraph_get_variants
1111
_EGraphIter destroy_egraphiters egraph_get_cost
1212
egraph_is_unsound_detected egraph_get_times_applied
1313
destroy_string
14-
(struct-out EGraphAddResult)
1514
(struct-out EGraphIter)
1615
(struct-out FFIRule))
1716

@@ -27,10 +26,6 @@
2726

2827
(define _egraph-pointer (_cpointer 'egraph))
2928

30-
(define-cstruct _EGraphAddResult
31-
([id _uint]
32-
[successp _bool]))
33-
3429
(define-cstruct _EGraphIter
3530
([numnodes _uint]
3631
[numeclasses _uint]
@@ -51,9 +46,7 @@
5146
(define-eggmath destroy_string (_fun _pointer -> _void))
5247

5348
;; egraph pointer, s-expr string -> node number
54-
(define-eggmath egraph_add_expr (_fun _egraph-pointer _string/utf-8 -> _EGraphAddResult-pointer))
55-
56-
(define-eggmath egraph_addresult_destroy (_fun _EGraphAddResult-pointer -> _void))
49+
(define-eggmath egraph_add_expr (_fun _egraph-pointer _string/utf-8 -> _uint))
5750

5851
(define-eggmath destroy_egraphiters (_fun _uint _EGraphIter-pointer -> _void))
5952

egg-herbie/main.rkt

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
(module+ test (require rackunit))
88

9-
(provide egraph-run egraph-add-exprs with-egraph
9+
(provide egraph-run egraph-add-expr with-egraph
1010
egraph-get-simplest egraph-get-variants
1111
egg-expr->expr egg-exprs->exprs egg-add-exn?
1212
make-ffi-rules free-ffi-rules egraph-get-cost
@@ -166,34 +166,14 @@
166166
(struct egg-add-exn exn:fail ())
167167

168168
;; result function is a function that takes the ids of the nodes
169-
;; egraph-add-exprs returns the result of result-function
170-
(define (egraph-add-exprs eg-data exprs result-function)
171-
(define egg-exprs
172-
(map
173-
(lambda (expr) (expr->egg-expr expr eg-data))
174-
exprs))
175-
176-
(define expr-results
177-
(map
178-
(lambda (expr)
179-
(egraph_add_expr (egraph-data-egraph-pointer eg-data) expr))
180-
egg-exprs))
181-
182-
(define node-ids
183-
(for/list ([result expr-results])
184-
(if (EGraphAddResult-successp result)
185-
(EGraphAddResult-id result)
186-
(raise (egg-add-exn
187-
(string-append "Failed to add expr to egraph")
188-
(current-continuation-marks))))))
189-
190-
(define res (result-function node-ids))
191-
192-
(for/list ([result expr-results])
193-
(egraph_addresult_destroy result))
194-
195-
res)
196-
169+
(define (egraph-add-expr eg-data expr)
170+
(define egg-expr (expr->egg-expr expr eg-data))
171+
(define result (egraph_add_expr (egraph-data-egraph-pointer eg-data) egg-expr))
172+
(when (= result 0)
173+
(raise (egg-add-exn
174+
"Failed to add expr to egraph"
175+
(current-continuation-marks))))
176+
(- result 1))
197177

198178
(module+ test
199179

egg-herbie/src/lib.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ pub unsafe extern "C" fn egraph_destroy(ptr: *mut Context) {
4343
std::mem::drop(Box::from_raw(ptr))
4444
}
4545

46-
#[no_mangle]
47-
pub unsafe extern "C" fn egraph_addresult_destroy(ptr: *mut EGraphAddResult) {
48-
std::mem::drop(Box::from_raw(ptr))
49-
}
50-
5146
#[no_mangle]
5247
pub unsafe extern "C" fn destroy_egraphiters(size: u32, ptr: *mut EGraphIter) {
5348
let _array: &[EGraphIter] = slice::from_raw_parts(ptr, size as usize);
@@ -58,13 +53,6 @@ pub unsafe extern "C" fn destroy_string(ptr: *mut c_char) {
5853
let _str = CString::from_raw(ptr);
5954
}
6055

61-
// a struct to report failure if the add fails
62-
#[repr(C)]
63-
pub struct EGraphAddResult {
64-
id: u32,
65-
successp: bool,
66-
}
67-
6856
#[repr(C)]
6957
pub struct EGraphIter {
7058
numnodes: u32,
@@ -110,10 +98,7 @@ fn runner_egraphiters(runner: &Runner) -> *mut EGraphIter {
11098
}
11199

112100
#[no_mangle]
113-
pub unsafe extern "C" fn egraph_add_expr(
114-
ptr: *mut Context,
115-
expr: *const c_char,
116-
) -> *mut EGraphAddResult {
101+
pub unsafe extern "C" fn egraph_add_expr(ptr: *mut Context, expr: *const c_char) -> u32 {
117102
ffirun(|| {
118103
let _ = env_logger::try_init();
119104
let ctx = &mut *ptr;
@@ -125,20 +110,18 @@ pub unsafe extern "C" fn egraph_add_expr(
125110
assert_eq!(ctx.iteration, 0);
126111

127112
let result = match cstring_to_recexpr(expr) {
128-
None => EGraphAddResult {
129-
id: 0,
130-
successp: false,
131-
},
113+
None => 0 as u32,
132114
Some(rec_expr) => {
133115
runner = runner.with_expr(&rec_expr);
134116
let id = *runner.roots.last().unwrap();
135117
let id = usize::from(id) as u32;
136-
EGraphAddResult { id, successp: true }
118+
assert!(id < u32::MAX);
119+
id + 1 as u32
137120
}
138121
};
139122

140123
ctx.runner = Some(runner);
141-
Box::into_raw(Box::new(result))
124+
result
142125
})
143126
}
144127

src/core/matcher.rkt

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -95,40 +95,37 @@
9595
(define result-thunk
9696
(with-egraph
9797
(λ (egg-graph)
98-
(egraph-add-exprs
99-
egg-graph
100-
exprs
101-
(λ (node-ids)
102-
(define iter-data (egg-run-rules egg-graph #:limit iter-limit (*node-limit*) irules node-ids #t))
103-
(for ([rule rules])
104-
(define count (egraph-get-times-applied egg-graph (rule-name rule)))
105-
(when (> count 0) (timeline-push! 'rules (~a (rule-name rule)) count)))
106-
(cond
107-
[(egraph-is-unsound-detected egg-graph)
108-
; unsoundness detected, fallback
109-
(match* (exprs iter-limit)
110-
[((list (? list?) (? list?) (? list?) ...) #f) ; run expressions individually
111-
(λ ()
112-
(for/list ([expr exprs] [root-loc root-locs])
113-
(timeline-push! 'method "egg-rewrite")
114-
(car (loop (list expr) (list root-loc) #f))))]
115-
[((list (? list?)) #f) ; run expressions with iter limit
116-
(λ ()
117-
(let ([limit (- (length iter-data) 2)])
118-
(timeline-push! 'method "egg-rewrite-iter-limit")
119-
(loop exprs root-locs limit)))]
120-
[(_ (? number?)) ; give up
121-
(timeline-push! 'method "egg-rewrite-fail")
122-
(λ () '(()))])]
123-
[else
124-
(define variants
125-
(for/list ([id node-ids] [expr exprs] [root-loc root-locs] [expr-repr reprs])
126-
(define egg-rule (rule "egg-rr" 'x 'x (list expr-repr) expr-repr))
127-
(define output (egraph-get-variants egg-graph id expr))
128-
(define extracted (egg-exprs->exprs output egg-graph))
129-
(for/list ([variant (remove-duplicates extracted)])
130-
(list (change egg-rule root-loc (list (cons 'x variant)))))))
131-
(λ () variants)]))))))
98+
(define node-ids (map (curry egraph-add-expr egg-graph) exprs))
99+
(define iter-data (egg-run-rules egg-graph #:limit iter-limit (*node-limit*) irules node-ids #t))
100+
(for ([rule rules])
101+
(define count (egraph-get-times-applied egg-graph (rule-name rule)))
102+
(when (> count 0) (timeline-push! 'rules (~a (rule-name rule)) count)))
103+
(cond
104+
[(egraph-is-unsound-detected egg-graph)
105+
; unsoundness detected, fallback
106+
(match* (exprs iter-limit)
107+
[((list (? list?) (? list?) (? list?) ...) #f) ; run expressions individually
108+
(λ ()
109+
(for/list ([expr exprs] [root-loc root-locs])
110+
(timeline-push! 'method "egg-rewrite")
111+
(car (loop (list expr) (list root-loc) #f))))]
112+
[((list (? list?)) #f) ; run expressions with iter limit
113+
(λ ()
114+
(let ([limit (- (length iter-data) 2)])
115+
(timeline-push! 'method "egg-rewrite-iter-limit")
116+
(loop exprs root-locs limit)))]
117+
[(_ (? number?)) ; give up
118+
(timeline-push! 'method "egg-rewrite-fail")
119+
(λ () '(()))])]
120+
[else
121+
(define variants
122+
(for/list ([id node-ids] [expr exprs] [root-loc root-locs] [expr-repr reprs])
123+
(define egg-rule (rule "egg-rr" 'x 'x (list expr-repr) expr-repr))
124+
(define output (egraph-get-variants egg-graph id expr))
125+
(define extracted (egg-exprs->exprs output egg-graph))
126+
(for/list ([variant (remove-duplicates extracted)])
127+
(list (change egg-rule root-loc (list (cons 'x variant)))))))
128+
(λ () variants)]))))
132129
(result-thunk)))
133130

134131
;; Recursive rewrite chooser

src/core/simplify.rkt

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -80,26 +80,21 @@
8080

8181
(with-egraph
8282
(lambda (egg-graph)
83-
(egraph-add-exprs
84-
egg-graph
85-
exprs
86-
(lambda (node-ids)
87-
(define iter-data (egg-run-rules egg-graph (*node-limit*) irules node-ids (and precompute? true)))
83+
(define node-ids (map (curry egraph-add-expr egg-graph) exprs))
84+
(define iter-data (egg-run-rules egg-graph (*node-limit*) irules node-ids (and precompute? true)))
8885

89-
(when (egraph-is-unsound-detected egg-graph)
90-
(warn 'unsound-rules #:url "faq.html#unsound-rules"
91-
"Unsound rule application detected in e-graph. Results from simplify may not be sound."))
86+
(when (egraph-is-unsound-detected egg-graph)
87+
(warn 'unsound-rules #:url "faq.html#unsound-rules"
88+
"Unsound rule application detected in e-graph. Results from simplify may not be sound."))
9289

93-
(for ([rule rls])
94-
(define count (egraph-get-times-applied egg-graph (rule-name rule)))
95-
(when (> count 0)
96-
(timeline-push! 'rules (~a (rule-name rule)) count)))
90+
(for ([rule rls])
91+
(define count (egraph-get-times-applied egg-graph (rule-name rule)))
92+
(when (> count 0)
93+
(timeline-push! 'rules (~a (rule-name rule)) count)))
9794

98-
(map
99-
(lambda (id)
100-
(for/list ([iter (in-range (length iter-data))])
101-
(egg-expr->expr (egraph-get-simplest egg-graph id iter) egg-graph)))
102-
node-ids))))))
95+
(for/list ([id node-ids])
96+
(for/list ([iter (in-range (length iter-data))])
97+
(egg-expr->expr (egraph-get-simplest egg-graph id iter) egg-graph))))))
10398

10499
(define (stop-reason->string sr)
105100
(match sr

0 commit comments

Comments
 (0)