Skip to content

Commit f998550

Browse files
authored
Merge pull request #987 from herbie-fp/artem-batch-egg-input
Batch-input to egg
2 parents 6490ee3 + c43e889 commit f998550

File tree

9 files changed

+157
-72
lines changed

9 files changed

+157
-72
lines changed

src/core/batch.rkt

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
deref ; Batchref -> Expr
1414
batch-replace ; Batch -> (Expr<Batchref> -> Expr<Batchref>) -> Batch
1515
egg-nodes->batch ; Nodes -> Spec-maps -> Batch -> (Listof Batchref)
16-
batchref->expr) ; Batchref -> Expr
16+
batchref->expr ; Batchref -> Expr
17+
batch-extract-exprs ; Batch -> (Listof Root) -> (Listof Expr)
18+
remove-zombie-nodes) ; Batch -> Batch
1719

1820
;; This function defines the recursive structure of expressions
1921
(define (expr-recurse expr f)
@@ -83,6 +85,14 @@
8385
(timeline-push! 'compiler size (batch-length final)))
8486
final)
8587

88+
(define (batch-extract-exprs b roots)
89+
(define exprs (make-vector (batch-length b)))
90+
(for ([node (in-vector (batch-nodes b))]
91+
[idx (in-naturals)])
92+
(vector-set! exprs idx (expr-recurse node (lambda (x) (vector-ref exprs x)))))
93+
(for/list ([root roots])
94+
(vector-ref exprs root)))
95+
8696
(define (batch->progs b)
8797
(define exprs (make-vector (batch-length b)))
8898
(for ([node (in-vector (batch-nodes b))]
@@ -111,7 +121,9 @@
111121
(define roots (vector-map (curry vector-ref mapping) (batch-roots b)))
112122
(mutable-batch->batch out roots))
113123

114-
; The function removes any zombie nodes from batch
124+
; The function removes any zombie nodes from batch with respect to the roots
125+
; Time complexity: O(|R| + |N|), where |R| is the number of roots and |N| is the number of nodes
126+
; Space complexity: O(|N| + |N*| + |R|), where |N*| is the number of nodes left after zombie nodes are removed
115127
(define (remove-zombie-nodes input-batch)
116128
(define nodes (batch-nodes input-batch))
117129
(define roots (batch-roots input-batch))
@@ -123,29 +135,21 @@
123135
(for ([node (in-vector nodes (- nodes-length 1) -1 -1)]
124136
[zmb (in-vector zombie-mask (- nodes-length 1) -1 -1)]
125137
#:when (not zmb))
126-
(match node
127-
[(list op args ...) (map (λ (n) (vector-set! zombie-mask n #f)) args)]
128-
[(approx spec impl) (vector-set! zombie-mask impl #f)]
129-
[_ void]))
138+
(expr-recurse node (λ (n) (vector-set! zombie-mask n #f))))
130139

131-
(define mappings (build-vector nodes-length values))
140+
(define mappings (make-vector nodes-length -1))
141+
(define (remap idx)
142+
(vector-ref mappings idx))
132143

133-
(define nodes* '())
144+
(define out (make-mutable-batch))
134145
(for ([node (in-vector nodes)]
135146
[zmb (in-vector zombie-mask)]
136-
[n (in-naturals)])
137-
(if zmb
138-
(for ([i (in-range n nodes-length)])
139-
(vector-set! mappings i (sub1 (vector-ref mappings i))))
140-
(set! nodes*
141-
(cons (match node
142-
[(list op args ...) (cons op (map (curry vector-ref mappings) args))]
143-
[(approx spec impl) (approx spec (vector-ref mappings impl))]
144-
[_ node])
145-
nodes*))))
146-
(set! nodes* (list->vector (reverse nodes*)))
147+
[n (in-naturals)]
148+
#:unless zmb)
149+
(vector-set! mappings n (batch-push! out (expr-recurse node remap))))
150+
147151
(define roots* (vector-map (curry vector-ref mappings) roots))
148-
(batch nodes* roots* (batch-vars input-batch)))
152+
(mutable-batch->batch out roots*))
149153

150154
(define (batch-ref batch reg)
151155
(define (unmunge reg)
@@ -249,6 +253,7 @@
249253
(define (zombie-test #:nodes nodes #:roots roots)
250254
(define in-batch (batch nodes roots '()))
251255
(define out-batch (remove-zombie-nodes in-batch))
256+
(check-equal? (batch->progs out-batch) (batch->progs in-batch))
252257
(batch-nodes out-batch))
253258

254259
(check-equal? (vector 0 '(sqrt 0) 2 '(pow 2 1))

src/core/egg-herbie.rkt

Lines changed: 87 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@
7373
[egraph-pointer (egraph_copy (egraph-data-egraph-pointer eg-data))]))
7474

7575
; Adds expressions returning the root ids
76-
; TODO: take a batch rather than list of expressions
77-
(define (egraph-add-exprs egg-data exprs ctx)
76+
(define (egraph-add-exprs egg-data batch roots ctx)
7877
(match-define (egraph-data ptr herbie->egg-dict egg->herbie-dict id->spec) egg-data)
7978

8079
; looks up the egg name of a variable
@@ -125,39 +124,86 @@
125124
[(? symbol? x) (egraph_add_node ptr (symbol->string x) 0-vec root?)]
126125
[(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)]))
127126

127+
; Rewrites every `approx` node as a `($approx spec impl)` list, so that spec is recursed on as well
128+
(define (batch-parse-approx batch)
129+
(batch-replace batch
130+
(lambda (node)
131+
(match node
132+
[(approx spec impl) (list '$approx spec impl)]
133+
[_ node]))))
134+
135+
(set-batch-roots! batch roots) ; make sure that we work with the right roots
136+
; the algorithm may crash if batch-length is zero
137+
(define insert-batch
138+
(if (zero? (batch-length batch)) batch (remove-zombie-nodes (batch-parse-approx batch))))
139+
140+
(define mappings (build-vector (batch-length insert-batch) values))
141+
(define (remap x)
142+
(vector-ref mappings x))
143+
144+
; Inserting nodes bottom-up
145+
(define root-mask (make-vector (batch-length insert-batch) #f))
146+
(for ([root (in-vector (batch-roots insert-batch))])
147+
(vector-set! root-mask root #t))
148+
(for ([node (in-vector (batch-nodes insert-batch))]
149+
[root? (in-vector root-mask)]
150+
[n (in-naturals)])
151+
(define node*
152+
(match node
153+
[(literal v _) v]
154+
[(? number?) node]
155+
[(? symbol?) (normalize-var node)]
156+
[(list '$approx spec impl)
157+
(hash-ref! id->spec
158+
(remap spec)
159+
(lambda ()
160+
(define spec* (normalize-spec (batch-ref insert-batch spec)))
161+
(define type (representation-type (repr-of-node insert-batch impl ctx)))
162+
(cons spec* type))) ; preserved spec and type for extraction
163+
(list '$approx (remap spec) (remap impl))]
164+
[(list op (app remap args) ...) (cons op args)]))
165+
166+
(vector-set! mappings n (insert-node! node* root?)))
167+
168+
;------------------------- DEBUGGING
128169
; expr -> id
129170
; expression cache
130-
(define expr->id (make-hash))
171+
#;(define expr->id (make-hash))
131172

132173
; expr -> natural
133174
; inserts an expression into the e-graph, returning its e-class id.
134-
(define (insert! expr [root? #f])
135-
; transform the expression into a node pointing
136-
; to its child e-classes
137-
(define node
138-
(match expr
139-
[(? number?) expr]
140-
[(? symbol?) (normalize-var expr)]
141-
[(literal v _) v]
142-
[(approx spec impl)
143-
(define spec* (insert! spec))
144-
(define impl* (insert! impl))
145-
(hash-ref! id->spec
146-
spec*
147-
(lambda ()
148-
(define spec* (normalize-spec spec)) ; preserved spec for extraction
149-
(define type (representation-type (repr-of impl ctx))) ; track type of spec
150-
(cons spec* type)))
151-
(list '$approx spec* impl*)]
152-
[(list op args ...) (cons op (map insert! args))]))
153-
; always insert the node if it is a root since
154-
; the e-graph tracks which nodes are roots
155-
(cond
156-
[root? (insert-node! node #t)]
157-
[else (hash-ref! expr->id node (lambda () (insert-node! node #f)))]))
158-
159-
(for/list ([expr (in-list exprs)])
160-
(insert! expr #t)))
175+
#;(define (insert! expr [root? #f])
176+
; transform the expression into a node pointing
177+
; to its child e-classes
178+
(define node
179+
(match expr
180+
[(literal v _) v]
181+
[(? number?) expr]
182+
[(? symbol?) (normalize-var expr)]
183+
[(list '$approx spec impl)
184+
(define spec* (insert! (vector-ref nodes spec)))
185+
(define impl* (insert! (vector-ref nodes impl)))
186+
(hash-ref! id->spec
187+
spec*
188+
(lambda ()
189+
(define spec* (normalize-spec (batch-ref insert-batch spec)))
190+
(define type (representation-type (repr-of-node insert-batch impl ctx)))
191+
(cons spec* type)))
192+
(list '$approx spec* impl*)]
193+
[(list op args ...) (cons op (map insert! (map (curry vector-ref nodes) args)))]))
194+
; always insert the node if it is a root since
195+
; the e-graph tracks which nodes are roots
196+
(cond
197+
[root? (insert-node! node #t)]
198+
[else (hash-ref! expr->id node (lambda () (insert-node! node #f)))]))
199+
200+
#;(define nodes (batch-nodes insert-batch))
201+
#;(for/list ([root (in-vector (batch-roots insert-batch))])
202+
(insert! (vector-ref nodes root) #t))
203+
; ---------------------- END OF DEBUGGING
204+
205+
(for/list ([root (in-vector (batch-roots insert-batch))])
206+
(remap root)))
161207

162208
;; runs rules on an egraph (optional iteration limit)
163209
(define (egraph-run egraph-data ffi-rules node-limit iter-limit scheduler const-folding?)
@@ -226,7 +272,8 @@
226272
(egraph_find (egraph-data-egraph-pointer egraph-data) id))
227273

228274
(define (egraph-expr-equal? egraph-data expr goal ctx)
229-
(match-define (list id1 id2) (egraph-add-exprs egraph-data (list expr goal) ctx))
275+
(define batch (progs->batch (list expr goal)))
276+
(match-define (list id1 id2) (egraph-add-exprs egraph-data batch (batch-roots batch) ctx))
230277
(= id1 id2))
231278

232279
;; returns a flattened list of terms or #f if it failed to expand the proof due to budget
@@ -1198,12 +1245,12 @@
11981245
(loop (sub1 num-iters)))]
11991246
[else (values egg-graph iteration-data)])))
12001247

1201-
(define (egraph-run-schedule exprs schedule ctx)
1248+
(define (egraph-run-schedule batch roots schedule ctx)
12021249
; allocate the e-graph
12031250
(define egg-graph (make-egraph))
12041251

12051252
; insert expressions into the e-graph
1206-
(define root-ids (egraph-add-exprs egg-graph exprs ctx))
1253+
(define root-ids (egraph-add-exprs egg-graph batch roots ctx))
12071254

12081255
; run the schedule
12091256
(define egg-graph*
@@ -1235,7 +1282,7 @@
12351282

12361283
;; Herbie's version of an egg runner.
12371284
;; Defines parameters for running rewrite rules with egg
1238-
(struct egg-runner (exprs reprs schedule ctx)
1285+
(struct egg-runner (batch roots reprs schedule ctx)
12391286
#:transparent ; for equality
12401287
#:methods gen:custom-write ; for abbreviated printing
12411288
[(define (write-proc alt port mode)
@@ -1252,7 +1299,7 @@
12521299
;; - scheduler: `(scheduler . <name>)` [default: backoff]
12531300
;; - `simple`: run all rules without banning
12541301
;; - `backoff`: ban rules if they fire too much
1255-
(define (make-egg-runner exprs reprs schedule #:context [ctx (*context*)])
1302+
(define (make-egg-runner batch roots reprs schedule #:context [ctx (*context*)])
12561303
(define (oops! fmt . args)
12571304
(apply error 'verify-schedule! fmt args))
12581305
; verify the schedule
@@ -1273,7 +1320,7 @@
12731320
[_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))]
12741321
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
12751322
; make the runner
1276-
(egg-runner exprs reprs schedule ctx))
1323+
(egg-runner batch roots reprs schedule ctx))
12771324

12781325
;; Runs egg using an egg runner.
12791326
;;
@@ -1285,7 +1332,10 @@
12851332
;; Run egg using runner
12861333
(define ctx (egg-runner-ctx runner))
12871334
(define-values (root-ids egg-graph)
1288-
(egraph-run-schedule (egg-runner-exprs runner) (egg-runner-schedule runner) ctx))
1335+
(egraph-run-schedule (egg-runner-batch runner)
1336+
(egg-runner-roots runner)
1337+
(egg-runner-schedule runner)
1338+
ctx))
12891339
; Perform extraction
12901340
(match cmd
12911341
[`(single . ,extractor) ; single expression extraction

src/core/localize.rkt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@
3838
(define lowering-rules (platform-lowering-rules))
3939

4040
; egg runner (2-phases for real rewrites and implementation selection)
41+
(define batch (progs->batch progs))
4142
(define runner
42-
(make-egg-runner progs
43+
(make-egg-runner batch
44+
(batch-roots batch)
4345
reprs
4446
`((,lifting-rules . ((iteration . 1) (scheduler . simple)))
4547
(,rules . ((node . ,(*node-limit*))))

src/core/mainloop.rkt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
"preprocess.rkt"
2020
"programs.rkt"
2121
"../utils/timeline.rkt"
22-
"soundiness.rkt")
22+
"soundiness.rkt"
23+
"batch.rkt")
2324
(provide run-improve!)
2425

2526
;; The Herbie main loop goes through a simple iterative process:
@@ -374,7 +375,8 @@
374375
; egg runner
375376
(define exprs (map alt-expr alts))
376377
(define reprs (map (lambda (expr) (repr-of expr (*context*))) exprs))
377-
(define runner (make-egg-runner exprs reprs schedule))
378+
(define batch (progs->batch exprs))
379+
(define runner (make-egg-runner batch (batch-roots batch) reprs schedule))
378380

379381
; run egg
380382
(define simplified

src/core/patch.rkt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
`((,lowering-rules . ((iteration . 1) (scheduler . simple))))))
4141

4242
; run egg
43-
(define runner (make-egg-runner (map alt-expr approxs) reprs schedule))
43+
(define batch (progs->batch (map alt-expr approxs)))
44+
(define runner (make-egg-runner batch (batch-roots batch) reprs schedule))
4445
(define simplification-options
4546
(simplify-batch runner
4647
(typed-egg-extractor
@@ -89,7 +90,7 @@
8990
[altn (in-list altns)]
9091
[fv (in-list free-vars)]
9192
#:when (member var fv)) ; check whether var exists in expr at all
92-
(for ([_ (in-range (*taylor-order-limit*))])
93+
(for ([i (in-range (*taylor-order-limit*))])
9394
(define gen (genexpr))
9495
(unless (spec-has-nan? gen)
9596
(sow (alt gen `(taylor ,name ,var) (list altn) '())))))
@@ -133,7 +134,8 @@
133134
(define exprs (map alt-expr altns))
134135
(define reprs (map (curryr repr-of (*context*)) exprs))
135136
(timeline-push! 'inputs (map ~a exprs))
136-
(define runner (make-egg-runner exprs reprs schedule #:context (*context*)))
137+
(define batch (progs->batch exprs))
138+
(define runner (make-egg-runner batch (batch-roots batch) reprs schedule #:context (*context*)))
137139
; batchrefss is a (listof (listof batchref))
138140
(define batchrefss (run-egg runner `(multi . ,extractor)))
139141

src/core/preprocess.rkt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"programs.rkt"
1414
"points.rkt"
1515
"../utils/timeline.rkt"
16-
"../utils/float.rkt")
16+
"../utils/float.rkt"
17+
"batch.rkt")
1718

1819
(provide find-preprocessing
1920
preprocess-pcontext
@@ -66,7 +67,8 @@
6667
(,lowering-rules . ((iteration . 1) (scheduler . simple)))))
6768

6869
; egg query
69-
(define runner (make-egg-runner (list expr) (list (context-repr ctx)) schedule))
70+
(define batch (progs->batch (list expr)))
71+
(define runner (make-egg-runner batch (batch-roots batch) (list (context-repr ctx)) schedule))
7072

7173
; run egg
7274
(define simplified
@@ -100,8 +102,11 @@
100102

101103
;; make egg runner
102104
(define rules (real-rules (*simplify-rules*)))
105+
106+
(define batch (progs->batch specs))
103107
(define runner
104-
(make-egg-runner specs
108+
(make-egg-runner batch
109+
(batch-roots batch)
105110
(map (lambda (_) (context-repr ctx)) specs)
106111
`((,rules . ((node . ,(*node-limit*)))))))
107112

0 commit comments

Comments
 (0)