Skip to content

Commit 5b27dd3

Browse files
authored
Merge pull request #1299 from herbie-fp/codex/use-roots-from-batch-in-make-egraph-and-make-egglog-runner
Refactor roots handling
2 parents 70ca116 + 456c431 commit 5b27dd3

File tree

5 files changed

+34
-39
lines changed

5 files changed

+34
-39
lines changed

src/core/egg-herbie.rkt

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
[egraph-pointer (egraph_copy (egraph-data-egraph-pointer eg-data))]))
6868

6969
; Adds expressions returning the root ids
70-
(define (egraph-add-exprs egg-data batch roots ctx)
70+
(define (egraph-add-exprs egg-data batch ctx)
7171
(match-define (egraph-data ptr id->spec) egg-data)
7272

7373
; normalizes an approx spec
@@ -108,13 +108,12 @@
108108
[(list op ids ...) (egraph_add_node ptr (~s op) (list->u32vec ids))]
109109
[(? (disjoin symbol? number?) x) (egraph_add_node ptr (~s x) 0-vec)]))
110110

111-
(define insert-batch (batch-remove-zombie batch roots))
112-
(define mappings (build-vector (batch-length insert-batch) values))
111+
(define mappings (build-vector (batch-length batch) values))
113112
(define (remap x)
114113
(vector-ref mappings x))
115114

116115
; Inserting nodes bottom-up
117-
(for ([node (in-vector (batch-nodes insert-batch))]
116+
(for ([node (in-vector (batch-nodes batch))]
118117
[n (in-naturals)])
119118
(define idx
120119
(match node
@@ -125,20 +124,20 @@
125124
[(approx spec impl) (insert-node! (list '$approx (remap spec) (remap impl)))]
126125
[(list op (app remap args) ...) (insert-node! (cons op args))]))
127126
(vector-set! mappings n idx))
128-
(for ([root (in-vector (batch-roots insert-batch))])
127+
(for ([root (in-vector (batch-roots batch))])
129128
(egraph_add_root ptr (remap root)))
130129

131-
(for ([node (in-vector (batch-nodes insert-batch))]
130+
(for ([node (in-vector (batch-nodes batch))]
132131
#:when (approx? node))
133132
(match-define (approx spec impl) node)
134133
(hash-ref! id->spec
135134
(remap spec)
136135
(lambda ()
137-
(define spec* (normalize-spec (batch-ref insert-batch spec)))
138-
(define type (representation-type (repr-of-node insert-batch impl ctx)))
136+
(define spec* (normalize-spec (batch-ref batch spec)))
137+
(define type (representation-type (repr-of-node batch impl ctx)))
139138
(cons spec* type))))
140139

141-
(for/list ([root (in-vector (batch-roots insert-batch))])
140+
(for/list ([root (in-vector (batch-roots batch))])
142141
(remap root)))
143142

144143
;; runs rules on an egraph (optional iteration limit)
@@ -202,7 +201,7 @@
202201

203202
(define (egraph-expr-equal? egraph-data expr goal ctx)
204203
(define batch (progs->batch (list expr goal)))
205-
(match-define (list id1 id2) (egraph-add-exprs egraph-data batch (batch-roots batch) ctx))
204+
(match-define (list id1 id2) (egraph-add-exprs egraph-data batch ctx))
206205
(= id1 id2))
207206

208207
;; returns a flattened list of terms or #f if it failed to expand the proof due to budget
@@ -1222,12 +1221,12 @@
12221221
(loop (sub1 num-iters)))]
12231222
[else (values egg-graph iteration-data)])))
12241223

1225-
(define (egraph-run-schedule batch roots schedule ctx)
1224+
(define (egraph-run-schedule batch schedule ctx)
12261225
; allocate the e-graph
12271226
(define egg-graph (make-egraph-data))
12281227

12291228
; insert expressions into the e-graph
1230-
(define root-ids (egraph-add-exprs egg-graph batch roots ctx))
1229+
(define root-ids (egraph-add-exprs egg-graph batch ctx))
12311230

12321231
; run the schedule
12331232
(define egg-graph*
@@ -1266,7 +1265,7 @@
12661265

12671266
;; Herbie's version of an egg runner.
12681267
;; Defines parameters for running rewrite rules with egg
1269-
(struct egg-runner (batch roots reprs schedule ctx new-roots egg-graph)
1268+
(struct egg-runner (batch reprs schedule ctx new-roots egg-graph)
12701269
#:transparent ; for equality
12711270
#:methods gen:custom-write ; for abbreviated printing
12721271
[(define (write-proc alt port mode)
@@ -1282,7 +1281,7 @@
12821281
;; - scheduler: `(scheduler . <name>)` [default: backoff]
12831282
;; - `simple`: run all rules without banning
12841283
;; - `backoff`: ban rules if the fire too much
1285-
(define (make-egraph batch roots reprs schedule ctx)
1284+
(define (make-egraph batch reprs schedule ctx)
12861285
(define (oops! fmt . args)
12871286
(apply error 'verify-schedule! fmt args))
12881287
; verify the schedule
@@ -1306,10 +1305,10 @@
13061305
[_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))]
13071306
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
13081307

1309-
(define-values (root-ids egg-graph) (egraph-run-schedule batch roots schedule ctx))
1308+
(define-values (root-ids egg-graph) (egraph-run-schedule batch schedule ctx))
13101309

13111310
; make the runner
1312-
(egg-runner batch roots reprs schedule ctx root-ids egg-graph))
1311+
(egg-runner batch reprs schedule ctx root-ids egg-graph))
13131312

13141313
(define (regraph-dump regraph root-ids reprs)
13151314
(define dump-dir "dump-egg")

src/core/egglog-herbie-tests.rkt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,6 @@
363363
#s(approx (- (sin (+ x 1)) (sin x)) #s(hole binary64 (- (sin (- 1 (* -1 x))) (sin x))))
364364
#s(approx (sin (+ x 1)) #s(hole binary64 (sin (- 1 (* -1 x))))))))
365365

366-
(define roots (batch-roots batch))
367-
368366
(define ctx (make-debug-context '(x eps)))
369367

370368
(define reprs (make-list (vector-length (batch-roots batch)) (context-repr ctx)))
@@ -376,4 +374,4 @@
376374
(lower . ((iteration . 1) (scheduler . simple)))))
377375

378376
(when (find-executable-path "egglog")
379-
(run-egglog-multi-extractor (egglog-runner batch roots reprs schedule ctx) batch)))
377+
(run-egglog-multi-extractor (egglog-runner batch reprs schedule ctx) batch)))

src/core/egglog-herbie.rkt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
;; Herbie's version of an egglog runner.
5050
;; Defines parameters for running rewrite rules with egglog
51-
(struct egglog-runner (batch roots reprs schedule ctx)
51+
(struct egglog-runner (batch reprs schedule ctx)
5252
#:transparent ; for equality
5353
#:methods gen:custom-write ; for abbreviated printing
5454
[(define (write-proc alt port mode)
@@ -65,7 +65,7 @@
6565
;; - scheduler: `(scheduler . <name>)` [default: backoff]
6666
;; - `simple`: run all rules without banning
6767
;; - `backoff`: ban rules if the fire too much
68-
(define (make-egglog-runner batch roots reprs schedule ctx)
68+
(define (make-egglog-runner batch reprs schedule ctx)
6969
(define (oops! fmt . args)
7070
(apply error 'verify-schedule! fmt args))
7171
; verify the schedule
@@ -91,12 +91,11 @@
9191
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
9292

9393
; make the runner
94-
(egglog-runner batch roots reprs schedule ctx))
94+
(egglog-runner batch reprs schedule ctx))
9595

9696
;; Runs egglog using an egglog runner by extracting multiple variants
9797
(define (run-egglog-multi-extractor runner output-batch) ; multi expression extraction
98-
(define insert-batch
99-
(batch-remove-zombie (egglog-runner-batch runner) (egglog-runner-roots runner)))
98+
(define insert-batch (egglog-runner-batch runner))
10099
(define curr-program (make-egglog-program))
101100

102101
;; Dump-file

src/core/patch.rkt

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,17 @@
7676
(define schedule `((lower . ((iteration . 1) (scheduler . simple)))))
7777

7878
; run egg
79-
(define exprs (map (compose debatchref alt-expr) altns))
80-
(define input-batch (progs->batch exprs))
81-
8279
(define roots (list->vector (map (compose batchref-idx alt-expr) altns)))
83-
(define reprs (map (curryr repr-of (*context*)) exprs))
80+
(define reprs
81+
(for/list ([root (in-vector roots)])
82+
(repr-of-node global-batch root (*context*))))
83+
84+
(define batch* (batch-remove-zombie global-batch roots))
8485

8586
(define runner
8687
(if (flag-set? 'generate 'egglog)
87-
(make-egglog-runner input-batch (batch-roots input-batch) reprs schedule (*context*))
88-
(make-egraph global-batch roots reprs schedule (*context*))))
88+
(make-egglog-runner batch* reprs schedule (*context*))
89+
(make-egraph batch* reprs schedule (*context*))))
8990

9091
(define batchrefss
9192
(if (flag-set? 'generate 'egglog)
@@ -154,18 +155,17 @@
154155
`(,rules . ((node . ,(*node-limit*))))
155156
`(lower . ((iteration . 1) (scheduler . simple)))))
156157

157-
; run egg
158-
(define exprs (map (compose debatchref alt-expr) altns))
159-
(define input-batch (progs->batch exprs))
160-
161158
(define roots (list->vector (map (compose batchref-idx alt-expr) altns)))
162-
(define reprs (map (curryr repr-of (*context*)) exprs))
163-
(timeline-push! 'inputs (map ~a exprs))
159+
(define reprs
160+
(for/list ([root (in-vector roots)])
161+
(repr-of-node global-batch root (*context*))))
162+
(define batch* (batch-remove-zombie global-batch roots))
163+
(timeline-push! 'inputs (map (compose ~a debatchref alt-expr) altns))
164164

165165
(define runner
166166
(if (flag-set? 'generate 'egglog)
167-
(make-egglog-runner input-batch (batch-roots input-batch) reprs schedule (*context*))
168-
(make-egraph global-batch roots reprs schedule (*context*))))
167+
(make-egglog-runner batch* reprs schedule (*context*))
168+
(make-egraph batch* reprs schedule (*context*))))
169169

170170
(define batchrefss
171171
(if (flag-set? 'generate 'egglog)

src/core/preprocess.rkt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@
7272
(define batch (progs->batch (cons spec (map cdr identities))))
7373
(define runner
7474
(make-egraph batch
75-
(batch-roots batch)
7675
(make-list (vector-length (batch-roots batch)) (context-repr ctx))
7776
`((,rules . ((node . ,(*node-limit*)))))
7877
ctx))

0 commit comments

Comments
 (0)