Skip to content

Commit 8be6c94

Browse files
authored
Merge pull request #1366 from herbie-fp/batchify-series
Batchifying `run-improve!`
2 parents 2b27f77 + a063b7a commit 8be6c94

File tree

8 files changed

+180
-136
lines changed

8 files changed

+180
-136
lines changed

src/core/alt-table.rkt

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,39 @@
1212
"programs.rkt")
1313

1414
(provide (contract-out
15-
(make-alt-table (pcontext? alt? any/c . -> . alt-table?))
15+
(make-alt-table (batch? pcontext? alt? any/c . -> . alt-table?))
1616
(atab-active-alts (alt-table? . -> . (listof alt?)))
1717
(atab-all-alts (alt-table? . -> . (listof alt?)))
1818
(atab-not-done-alts (alt-table? . -> . (listof alt?)))
19-
(atab-eval-altns (alt-table? (listof alt?) context? . -> . (values any/c any/c)))
19+
(atab-eval-altns (alt-table? batch? (listof alt?) context? . -> . (values any/c any/c)))
2020
(atab-add-altns (alt-table? (listof alt?) any/c any/c context? . -> . alt-table?))
2121
(atab-set-picked (alt-table? (listof alt?) . -> . alt-table?))
2222
(atab-completed? (alt-table? . -> . boolean?))
23-
(atab-min-errors (alt-table? . -> . (listof real?)))))
23+
(atab-min-errors (alt-table? . -> . (listof real?)))
24+
(alt-batch-costs (batch? representation? . -> . (batchref? . -> . real?)))))
2425

2526
;; Public API
2627

2728
(struct alt-table (point-idx->alts alt->point-idxs alt->done? alt->cost pcontext all) #:prefab)
2829

29-
(define (alt-batch-cost batch brfs repr)
30+
(define (alt-batch-costs batch repr)
3031
(define node-cost-proc (platform-node-cost-proc (*active-platform*)))
31-
(define costs
32-
(batch-map batch
33-
(λ (get-args-costs node)
34-
(match node
35-
[(? literal?) ((node-cost-proc node repr))]
36-
[(? symbol?) ((node-cost-proc node repr))]
37-
[(? number?) 0] ; specs
38-
[(approx _ impl) (get-args-costs impl)]
39-
[(list (? (negate impl-exists?) impl) args ...) 0] ; specs
40-
[(list impl args ...)
41-
(define cost-proc (node-cost-proc node repr))
42-
(define itypes (impl-info impl 'itype))
43-
(apply cost-proc (map get-args-costs args))]))))
44-
(map costs brfs))
45-
46-
(define (make-alt-table pcontext initial-alt ctx)
47-
(define cost (alt-cost initial-alt (context-repr ctx)))
48-
(define errs (errors (alt-expr initial-alt) pcontext ctx))
32+
(batch-map batch
33+
(λ (get-args-costs node)
34+
(match node
35+
[(? literal?) ((node-cost-proc node repr))]
36+
[(? symbol?) ((node-cost-proc node repr))]
37+
[(? number?) 0] ; specs
38+
[(approx _ impl) (get-args-costs impl)]
39+
[(list (? (negate impl-exists?) impl) args ...) 0] ; specs
40+
[(list impl args ...)
41+
(define cost-proc (node-cost-proc node repr))
42+
(define itypes (impl-info impl 'itype))
43+
(apply cost-proc (map get-args-costs args))]))))
44+
45+
(define (make-alt-table batch pcontext initial-alt ctx)
46+
(define cost ((alt-batch-costs batch (context-repr ctx)) (alt-expr initial-alt)))
47+
(define errs (batchref-errors (alt-expr initial-alt) pcontext ctx))
4948
(alt-table (for/vector #:length (pcontext-length pcontext)
5049
([err (in-list errs)])
5150
(list (pareto-point cost err (list initial-alt))))
@@ -180,10 +179,10 @@
180179
[alt->done? (hash-remove* alt->done? altns)]
181180
[alt->cost (hash-remove* alt->cost altns)]))
182181

183-
(define (atab-eval-altns atab altns ctx)
184-
(define-values (batch brfs) (progs->batch (map alt-expr altns) #:vars (context-vars ctx)))
182+
(define (atab-eval-altns atab batch altns ctx)
183+
(define brfs (map alt-expr altns))
185184
(define errss (batch-errors batch brfs (alt-table-pcontext atab) ctx))
186-
(define costs (alt-batch-cost batch brfs (context-repr ctx)))
185+
(define costs (map (alt-batch-costs batch (context-repr ctx)) brfs))
187186
(values errss costs))
188187

189188
(define (atab-add-altns atab altns errss costs ctx)
@@ -215,22 +214,26 @@
215214
(match-define (alt-table point-idx->alts alt->point-idxs alt->done? alt->cost pcontext _) atab)
216215
(define max-error (+ 1 (expt 2 (representation-total-bits (context-repr ctx)))))
217216

218-
(define point-idx->alts*
219-
(for/vector #:length (vector-length point-idx->alts)
220-
([pcurve (in-vector point-idx->alts)]
221-
[err (in-list errs)])
222-
(cond
223-
[(< err max-error) ; Only include points if they are valid
224-
(define ppt (pareto-point cost err (list altn)))
225-
(pareto-union (list ppt) pcurve #:combine append)]
226-
[else pcurve])))
227-
228-
(alt-table point-idx->alts*
229-
(hash-set alt->point-idxs altn #f)
230-
(hash-set alt->done? altn #f)
231-
(hash-set alt->cost altn cost)
232-
pcontext
233-
#f))
217+
;; Check whether altn is already inserted into atab
218+
(match (hash-has-key? alt->point-idxs altn)
219+
[#f
220+
(define point-idx->alts*
221+
(for/vector #:length (vector-length point-idx->alts)
222+
([pcurve (in-vector point-idx->alts)]
223+
[err (in-list errs)])
224+
(cond
225+
[(< err max-error) ; Only include points if they are valid
226+
(define ppt (pareto-point cost err (list altn)))
227+
(pareto-union (list ppt) pcurve #:combine append)]
228+
[else pcurve])))
229+
230+
(alt-table point-idx->alts*
231+
(hash-set alt->point-idxs altn #f)
232+
(hash-set alt->done? altn #f)
233+
(hash-set alt->cost altn cost)
234+
pcontext
235+
#f)]
236+
[_ atab]))
234237

235238
(define (atab-min-errors atab)
236239
(define pnt-idx->alts (alt-table-point-idx->alts atab))

src/core/batch.rkt

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
(struct batchref (batch idx) #:transparent)
4242

43+
;; --------------------------------- CORE BATCH FUNCTION ------------------------------------
44+
4345
(define (batch-empty)
4446
(batch (make-dvector) (make-hash) (make-hasheq)))
4547

@@ -107,13 +109,11 @@
107109
(values out brfs))
108110

109111
(define (batch->progs b brfs)
110-
(brfs-belong-to-batch? b brfs)
111112
(map (batch-exprs b) brfs))
112113

113114
;; batch-map does not iterate over nodes that are not a child of brf
114115
;; A lot of parts of Herbie rely on that
115116
(define (batch-map batch f)
116-
(define len (batch-length batch))
117117
(define out (make-dvector))
118118
(define visited (make-dvector))
119119
(λ (brf)
@@ -131,22 +131,24 @@
131131
(dvector-set! visited idx #t)
132132
res]))))
133133

134-
(define (batch-apply b brfs f)
135-
(define out (batch-empty))
136-
(define apply-f
137-
(λ (remap node)
138-
(define node-with-batchrefs (expr-recurse node (lambda (ref) (batchref b ref))))
139-
(define node* (f node-with-batchrefs))
140-
(define brf*
141-
(let loop ([node* node*])
142-
(match node*
143-
[(? batchref? brf) (remap (batchref-idx brf))]
144-
[_ (batch-push! out (expr-recurse node* (compose batchref-idx loop)))])))
145-
brf*))
146-
(define brfs* (map (batch-map b apply-f) brfs))
147-
(values out brfs*))
134+
(define (batch-ref batch reg)
135+
(dvector-ref (batch-nodes batch) reg))
148136

149-
(define (batch-apply! b f)
137+
(define (batch-pull brf)
138+
(define (unmunge brf)
139+
(expr-recurse (deref brf) unmunge))
140+
(unmunge brf))
141+
142+
(define (brfs-belong-to-batch? batch brfs)
143+
(unless (andmap (compose (curry equal? batch) batchref-batch) brfs)
144+
(error 'brfs-belong-to-batch? "One of batchrefs does not belong to the provided batch")))
145+
146+
;; --------------------------------- CUSTOM BATCH FUNCTION ------------------------------------
147+
148+
;; out - batch to where write new nodes
149+
;; b - batch from which to read nodes
150+
;; f - function to modify nodes from b
151+
(define (batch-apply-internal out b f)
150152
(batch-map b
151153
(λ (remap node)
152154
(define node-with-batchrefs (expr-recurse node (lambda (ref) (batchref b ref))))
@@ -155,9 +157,22 @@
155157
(let loop ([node* node*])
156158
(match node*
157159
[(? batchref? brf) (remap (batchref-idx brf))]
158-
[_ (batch-push! b (expr-recurse node* (compose batchref-idx loop)))])))
160+
[_ (batch-push! out (expr-recurse node* (compose batchref-idx loop)))])))
159161
brf*)))
160162

163+
;; Allocates new batch
164+
(define (batch-apply b brfs f)
165+
(define out (batch-empty))
166+
(define apply-f (batch-apply-internal out b f))
167+
(define brfs* (map apply-f brfs))
168+
(values out brfs*))
169+
170+
;; Modifies batch in-place
171+
(define (batch-apply! b f)
172+
(batch-apply-internal b b f))
173+
174+
;; Function returns indices of children nodes within a batch for given roots,
175+
;; where a child node is a child of a root + meets a condition - (condition node)
161176
(define (batch-reachable batch brfs #:condition [condition (const #t)])
162177
; Little check
163178
(brfs-belong-to-batch? batch brfs)
@@ -207,21 +222,11 @@
207222
(apply + 1 (map get-children-counts args)))))
208223
(apply + (map counts brfs)))
209224

210-
(define (brfs-belong-to-batch? batch brfs)
211-
(unless (andmap (compose (curry equal? batch) batchref-batch) brfs)
212-
(error 'brfs-belong-to-batch? "One of batchrefs does not belong to the provided batch")))
213-
214225
;; The function removes any zombie nodes from batch with respect to the brfs
215226
(define (batch-copy-only batch brfs)
216227
(batch-apply batch brfs identity))
217228

218-
(define (batch-ref batch reg)
219-
(dvector-ref (batch-nodes batch) reg))
220-
221-
(define (batch-pull brf)
222-
(define (unmunge brf)
223-
(expr-recurse (deref brf) unmunge))
224-
(unmunge brf))
229+
;; --------------------------------- TESTS ---------------------------------------
225230

226231
; Tests for progs->batch and batch->progs
227232
(module+ test

src/core/compiler.rkt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,9 @@
7474
(define vars (context-vars ctx))
7575
(define num-vars (length vars))
7676

77-
(timeline-push! 'compiler (batch-tree-size batch brfs) (batch-length batch))
78-
7977
; Here we need to keep vars even though no roots refer to the vars
8078
(define-values (batch* brfs*) (batch-for-compiler batch brfs vars))
79+
(timeline-push! 'compiler (batch-tree-size batch* brfs*) (batch-length batch*))
8180

8281
(define instructions
8382
(for/vector #:length (- (batch-length batch*) num-vars)

src/core/mainloop.rkt

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
(define *start-prog* (make-parameter #f))
3535
(define *pcontext* (make-parameter #f))
3636
(define *preprocessing* (make-parameter '()))
37+
(define *global-batch* (make-parameter #f))
3738

3839
;; These high-level functions give the high-level workflow of Herbie:
3940
;; - Initial steps: explain, preprocessing, initialize the alt table
@@ -47,9 +48,11 @@
4748
(define pcontext* (preprocess-pcontext context pcontext preprocessing))
4849
(*pcontext* pcontext*)
4950
(*start-prog* initial)
51+
(*global-batch* (batch-empty))
5052
(*preprocessing* preprocessing)
51-
(define start-alt (alt initial 'start '()))
52-
(^table^ (make-alt-table pcontext start-alt context))
53+
(define initial-brf (batch-add! (*global-batch*) initial))
54+
(define start-alt (alt initial-brf 'start '()))
55+
(^table^ (make-alt-table (*global-batch*) pcontext start-alt context))
5356

5457
(for ([iteration (in-range (*num-iterations*))]
5558
#:break (atab-completed? (^table^)))
@@ -63,8 +66,7 @@
6366

6467
(define (extract!)
6568
(timeline-push-alts! '())
66-
67-
(define all-alts (atab-all-alts (^table^)))
69+
(define all-alts (unbatchify-alts (*global-batch*) (atab-all-alts (^table^))))
6870
(define joined-alts (make-regime! all-alts (*start-prog*))) ;; HERE
6971
(define annotated-alts (add-derivations! joined-alts))
7072

@@ -111,15 +113,19 @@
111113

112114
(define (score-alt alt)
113115
(errors-score (errors (alt-expr alt) (*pcontext*) (*context*))))
116+
(define (batch-score-alts altns)
117+
(map errors-score (batch-errors (*global-batch*) (map alt-expr altns) (*pcontext*) (*context*))))
114118

115119
; Pareto mode alt picking
116120
(define (choose-mult-alts altns)
117121
(define repr (context-repr (*context*)))
118122
(cond
119123
[(< (length altns) (*pareto-pick-limit*)) altns] ; take max
120124
[else
121-
(define best (argmin score-alt altns))
122-
(define altns* (sort (set-remove altns best) < #:key (curryr alt-cost repr)))
125+
(define scores (batch-score-alts altns))
126+
(define best (list-ref altns (index-of scores (argmin identity scores))))
127+
(define alt-costs (alt-batch-costs (*global-batch*) repr))
128+
(define altns* (sort (set-remove altns best) < #:key (compose alt-costs alt-expr)))
123129
(define simplest (car altns*))
124130
(define altns** (cdr altns*))
125131
(define div-size (round (/ (length altns**) (- (*pareto-pick-limit*) 1))))
@@ -128,16 +134,18 @@
128134
(list-ref altns** (- (* i div-size) 1))))]))
129135

130136
(define (timeline-push-alts! picked-alts)
137+
(define exprs (batch-exprs (*global-batch*)))
131138
(define fresh-alts (atab-not-done-alts (^table^)))
132139
(define repr (context-repr (*context*)))
133-
(for ([alt (atab-active-alts (^table^))])
140+
(for ([alt (atab-active-alts (^table^))]
141+
[sc (in-list (batch-score-alts (atab-active-alts (^table^))))])
134142
(timeline-push! 'alts
135-
(~a (alt-expr alt))
143+
(~a (exprs (alt-expr alt)))
136144
(cond
137145
[(set-member? picked-alts alt) "next"]
138146
[(set-member? fresh-alts alt) "fresh"]
139147
[else "done"])
140-
(score-alt alt)
148+
sc
141149
(~a (representation-name repr)))))
142150

143151
(define (choose-alts!)
@@ -150,7 +158,7 @@
150158
(void))
151159

152160
;; Converts a patch to full alt with valid history
153-
(define (reconstruct! global-batch alts)
161+
(define (reconstruct! alts)
154162
;; extracts the base expressions of a patch as a batchref
155163
(define (get-starting-expr altn)
156164
(match (alt-prevs altn)
@@ -169,7 +177,7 @@
169177
[(list 'evaluate) (list 'evaluate loc0)]
170178
[(list 'taylor name var) (list 'taylor loc0 name var)]
171179
[(list 'rr input proof) (list 'rr loc0 input proof)]))
172-
(define expr* (batch-location-set global-batch (alt-expr orig) loc0 (alt-expr altn)))
180+
(define expr* (batch-location-set (*global-batch*) (alt-expr orig) loc0 (alt-expr altn)))
173181
(alt expr* event* (list (loop (first prevs))))])))
174182

175183
(^patched^ (remove-duplicates
@@ -183,8 +191,6 @@
183191
(reconstruct-alt altn loc full-altn))))))
184192
#:key (compose batchref-idx alt-expr)))
185193

186-
(^patched^ (unbatchify-alts global-batch (^patched^)))
187-
; No need to unmunge ^next-alts^
188194
(void))
189195

190196
;; Finish iteration
@@ -197,7 +203,7 @@
197203
(define orig-fresh-alts (atab-not-done-alts (^table^)))
198204
(define orig-done-alts (set-subtract orig-all-alts (atab-not-done-alts (^table^))))
199205

200-
(define-values (errss costs) (atab-eval-altns (^table^) (^patched^) (*context*)))
206+
(define-values (errss costs) (atab-eval-altns (^table^) (*global-batch*) (^patched^) (*context*)))
201207
(timeline-event! 'prune)
202208
(^table^ (atab-add-altns (^table^) (^patched^) errss costs (*context*)))
203209
(define final-fresh-alts (atab-not-done-alts (^table^)))
@@ -232,14 +238,10 @@
232238
(unless (^next-alts^)
233239
(choose-alts!))
234240

235-
(define-values (global-batch brfs) (progs->batch (map alt-expr (^next-alts^))))
236-
(define (make-batchref x brf)
237-
(struct-copy alt x [expr brf]))
238-
239-
(^next-alts^ (map make-batchref (^next-alts^) brfs))
240-
(define brfs* (batch-reachable global-batch brfs #:condition node-is-impl?))
241+
(define brfs (map alt-expr (^next-alts^)))
242+
(define brfs* (batch-reachable (*global-batch*) brfs #:condition node-is-impl?))
241243

242-
(reconstruct! global-batch (generate-candidates global-batch brfs*))
244+
(reconstruct! (generate-candidates (*global-batch*) brfs*))
243245
(finalize-iter!)
244246
(void))
245247

0 commit comments

Comments
 (0)