Skip to content

Commit 60e458c

Browse files
authored
Merge pull request #1298 from herbie-fp/codex/define-pcontext-in-mainloop-and-pass-as-argument
Refactor pcontext handling
2 parents 4492f5c + c260a39 commit 60e458c

File tree

7 files changed

+55
-56
lines changed

7 files changed

+55
-56
lines changed

src/api/sandbox.rkt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,13 @@
107107
(unless pcontext
108108
(error 'explain "cannot run without a pcontext"))
109109

110-
(*pcontext* pcontext)
111110
(define-values (fperrors
112111
sorted-explanations-table
113112
confusion-matrix
114113
maybe-confusion-matrix
115114
total-confusion-matrix
116115
freqs)
117-
(explain (test-input test) (*context*) (*pcontext*)))
116+
(explain (test-input test) (*context*) pcontext))
118117

119118
sorted-explanations-table)
120119

@@ -127,8 +126,7 @@
127126
(unless pcontext
128127
(error 'get-local-error "cannnot run without a pcontext"))
129128

130-
(*pcontext* pcontext)
131-
(local-error-as-tree (test-input test) (*context*)))
129+
(local-error-as-tree (test-input test) (*context*) pcontext))
132130

133131
(define (get-sample test)
134132
(random) ;; Tick the random number generator, for backwards compatibility

src/core/bsearch.rkt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@
3232
;; The last splitpoint uses +nan.0 for pt and represents the "else"
3333
(struct sp (cidx bexpr point) #:prefab)
3434

35-
(define (combine-alts best-option start-prog ctx)
35+
(define (combine-alts best-option start-prog ctx pcontext)
3636
(match-define (option splitindices alts pts expr _) best-option)
3737
(match splitindices
3838
[(list (si cidx _)) (list-ref alts cidx)]
3939
[_
4040
(timeline-event! 'bsearch)
41-
(define splitpoints (sindices->spoints pts expr alts splitindices start-prog ctx))
41+
(define splitpoints (sindices->spoints pts expr alts splitindices start-prog ctx pcontext))
4242

4343
(define expr*
4444
(for/fold ([expr (alt-expr (list-ref alts (sp-cidx (last splitpoints))))])
@@ -119,8 +119,8 @@
119119
;; float form always come from the range [f(idx1), f(idx2)). If the
120120
;; float form of a split is f(idx2), or entirely outside that range,
121121
;; problems may arise.
122-
(define/contract (sindices->spoints points expr alts sindices start-prog ctx)
123-
(-> (listof vector?) any/c (listof alt?) (listof si?) any/c context? valid-splitpoints?)
122+
(define/contract (sindices->spoints points expr alts sindices start-prog ctx pcontext)
123+
(-> (listof vector?) any/c (listof alt?) (listof si?) any/c context? pcontext? valid-splitpoints?)
124124
(define repr (repr-of expr ctx))
125125

126126
(define eval-expr (compile-prog expr ctx))
@@ -135,7 +135,7 @@
135135
(and start-prog (make-real-compiler (list (prog->spec start-prog)) (list ctx*))))
136136

137137
(define (prepend-macro v)
138-
(prepend-argument start-real-compiler v (*pcontext*)))
138+
(prepend-argument start-real-compiler v pcontext))
139139

140140
(define (find-split expr1 expr2 v1 v2)
141141
(define (pred v)

src/core/explain.rkt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737

3838
(define (actual-errors expr ctx pcontext)
3939
(match-define (cons subexprs pt-errorss)
40-
(parameterize ([*pcontext* pcontext])
41-
(flip-lists (hash->list (first (compute-local-errors (list (all-subexpressions expr)) ctx))))))
40+
(flip-lists
41+
(hash->list (first (compute-local-errors (list (all-subexpressions expr)) ctx pcontext)))))
4242

4343
(define pt-worst-subexpr
4444
(append* (reap [sow]

src/core/localize.rkt

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@
5858
(define (make-matrix roots pcontext)
5959
(for/vector #:length (vector-length roots)
6060
([node (in-vector roots)])
61-
(make-vector (pcontext-length (*pcontext*)))))
61+
(make-vector (pcontext-length pcontext))))
6262

6363
; Compute local error or each sampled point at each node in `prog`.
64-
(define (compute-local-errors subexprss ctx)
64+
(define (compute-local-errors subexprss ctx pcontext)
6565
(define exprs-list (append* subexprss)) ; unroll subexprss
6666
(define reprs-list (map (curryr repr-of ctx) exprs-list))
6767
(define ctx-list
@@ -75,9 +75,9 @@
7575

7676
(define subexprs-fn (eval-progs-real (map prog->spec exprs-list) ctx-list))
7777

78-
(define errs (make-matrix roots (*pcontext*)))
78+
(define errs (make-matrix roots pcontext))
7979

80-
(for ([(pt ex) (in-pcontext (*pcontext*))]
80+
(for ([(pt ex) (in-pcontext pcontext)]
8181
[pt-idx (in-naturals)])
8282
(define exacts (list->vector (subexprs-fn pt)))
8383
(define (get-exact idx)
@@ -110,7 +110,7 @@
110110
((representation-bf->repr repr) 0.bf))))
111111

112112
;; Compute local error or each sampled point at each node in `prog`.
113-
(define (compute-errors subexprss ctx)
113+
(define (compute-errors subexprss ctx pcontext)
114114
;; We compute the actual (float) result
115115
(define exprs-list (append* subexprss)) ; unroll subexprss
116116
(define actual-value-fn (compile-progs exprs-list ctx))
@@ -146,14 +146,14 @@
146146
(define nodes (batch-nodes expr-batch))
147147
(define roots (batch-roots expr-batch))
148148

149-
(define ulp-errs (make-matrix roots (*pcontext*)))
150-
(define exacts-out (make-matrix roots (*pcontext*)))
151-
(define approx-out (make-matrix roots (*pcontext*)))
152-
(define true-error-out (make-matrix roots (*pcontext*)))
149+
(define ulp-errs (make-matrix roots pcontext))
150+
(define exacts-out (make-matrix roots pcontext))
151+
(define approx-out (make-matrix roots pcontext))
152+
(define true-error-out (make-matrix roots pcontext))
153153

154154
(define spec-vec (list->vector spec-list))
155155
(define ctx-vec (list->vector ctx-list))
156-
(for ([(pt ex) (in-pcontext (*pcontext*))]
156+
(for ([(pt ex) (in-pcontext pcontext)]
157157
[pt-idx (in-naturals)])
158158

159159
(define exacts (list->vector (subexprs-fn pt)))
@@ -211,8 +211,8 @@
211211
;; Compute the local error of every subexpression of `prog`
212212
;; and returns the error information as an S-expr in the
213213
;; same shape as `prog`
214-
(define (local-error-as-tree expr ctx)
215-
(define data-hash (first (compute-errors (list (all-subexpressions expr)) ctx)))
214+
(define (local-error-as-tree expr ctx pcontext)
215+
(define data-hash (first (compute-errors (list (all-subexpressions expr)) ctx pcontext)))
216216

217217
(define (translate-booleans value)
218218
(match value

src/core/mainloop.rkt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
;; Starting program for the current run
3434
(define *start-prog* (make-parameter #f))
35+
(define *pcontext* (make-parameter #f))
3536

3637
;; These high-level functions give the high-level workflow of Herbie:
3738
;; - Initial steps: explain, preprocessing, initialize the alt table
@@ -251,9 +252,10 @@
251252
(equal? (representation-type repr) 'real)
252253
(not (null? (context-vars ctx)))
253254
(get-fpcore-impl '<= '() (list repr repr)))
254-
(define opts (pareto-regimes (sort alts < #:key (curryr alt-cost repr)) start-prog ctx))
255+
(define opts
256+
(pareto-regimes (sort alts < #:key (curryr alt-cost repr)) start-prog ctx (*pcontext*)))
255257
(for/list ([opt (in-list opts)])
256-
(combine-alts opt start-prog ctx))]
258+
(combine-alts opt start-prog ctx (*pcontext*)))]
257259
[else (list (argmin score-alt alts))]))
258260

259261
(define (add-derivations! alts)

src/core/points.rkt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
"batch.rkt"
66
"compiler.rkt")
77

8-
(provide *pcontext*
9-
in-pcontext
8+
(provide in-pcontext
109
mk-pcontext
1110
for/pcontext
1211
pcontext?
@@ -21,7 +20,6 @@
2120
;; ground-truth information. They contain 1) a set of sampled input
2221
;; points; and 2) a ground-truth output for each input.
2322

24-
(define *pcontext* (make-parameter #f))
2523
(struct pcontext (points exacts) #:prefab)
2624

2725
(define (in-pcontext pcontext)

src/core/regimes.rkt

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
[(define (write-proc opt port mode)
2626
(fprintf port "#<option ~a>" (option-split-indices opt)))])
2727

28-
(define (pareto-regimes sorted start-prog ctx)
28+
(define (pareto-regimes sorted start-prog ctx pcontext)
2929
(timeline-event! 'regimes)
30-
(define err-lsts (batch-errors (map alt-expr sorted) (*pcontext*) ctx))
30+
(define err-lsts (batch-errors (map alt-expr sorted) pcontext ctx))
3131
(define branches
3232
(if (null? sorted)
3333
'()
@@ -44,14 +44,15 @@
4444
; Only return one option if not pareto mode
4545
[(and (not (*pareto-mode*)) (not (equal? alts sorted))) '()]
4646
[else
47-
(define-values (opt new-errs) (infer-splitpoints branch-exprs alts err-lsts #:errs errs ctx))
47+
(define-values (opt new-errs)
48+
(infer-splitpoints branch-exprs alts err-lsts #:errs errs ctx pcontext))
4849
(define high (si-cidx (argmax (λ (x) (si-cidx x)) (option-split-indices opt))))
4950
(cons opt (loop (take alts high) new-errs (take err-lsts high)))])))
5051

5152
;; `infer-splitpoints` and `combine-alts` are split so the mainloop
5253
;; can insert a timeline break between them.
5354

54-
(define (infer-splitpoints branch-exprs alts err-lsts* #:errs [cerrs (hash)] ctx)
55+
(define (infer-splitpoints branch-exprs alts err-lsts* #:errs [cerrs (hash)] ctx pcontext)
5556
(timeline-push! 'inputs (map (compose ~a alt-expr) alts))
5657
(define sorted-bexprs
5758
(sort branch-exprs (lambda (x y) (< (hash-ref cerrs x -1) (hash-ref cerrs y -1)))))
@@ -67,7 +68,7 @@
6768
([bexpr sorted-bexprs]
6869
;; stop if we've computed this (and following) branch-expr on more alts and it's still worse
6970
#:break (> (hash-ref cerrs bexpr -1) best-err))
70-
(define opt (option-on-expr alts err-lsts bexpr ctx))
71+
(define opt (option-on-expr alts err-lsts bexpr ctx pcontext))
7172
(define err
7273
(+ (errors-score (option-errors opt))
7374
(length (option-split-indices opt)))) ;; one-bit penalty per split
@@ -113,14 +114,14 @@
113114
#:when (critical-subexpression? expr subexpr))
114115
subexpr))
115116

116-
(define (option-on-expr alts err-lsts expr ctx)
117+
(define (option-on-expr alts err-lsts expr ctx pcontext)
117118
(define timeline-stop! (timeline-start! 'times (~a expr)))
118119

119120
(define fn (compile-prog expr ctx))
120121
(define repr (repr-of expr ctx))
121122

122123
(define big-table ; pt ; splitval ; alt1-err ; alt2-err ; ...
123-
(for/list ([(pt ex) (in-pcontext (*pcontext*))]
124+
(for/list ([(pt ex) (in-pcontext pcontext)]
124125
[err-lst err-lsts])
125126
(list* (fn pt) pt err-lst)))
126127
(match-define (list splitvals* pts* err-lsts* ...)
@@ -151,27 +152,27 @@
151152

152153
(module+ test
153154
(define ctx (make-debug-context '(x)))
154-
(parameterize ([*pcontext* (mk-pcontext '(#(0.5) #(4.0)) '(1.0 1.0))])
155-
(define alts (map make-alt (list '(fmin.f64 x 1) '(fmax.f64 x 1))))
156-
(define err-lsts `((,(expt 2.0 53) 1.0) (1.0 ,(expt 2.0 53))))
157-
158-
(define (test-regimes expr goal)
159-
(check (lambda (x y) (equal? (map si-cidx (option-split-indices x)) y))
160-
(option-on-expr alts err-lsts expr ctx)
161-
goal))
162-
163-
;; This is a basic sanity test
164-
(test-regimes 'x '(1 0))
165-
166-
;; This test ensures we handle equal points correctly. All points
167-
;; are equal along the `1` axis, so we should only get one
168-
;; splitpoint (the second, since it is better at the further point).
169-
(test-regimes (literal 1 'binary64) '(0))
170-
171-
(test-regimes `(if (==.f64 x ,(literal 0.5 'binary64))
172-
,(literal 1 'binary64)
173-
(NAN.f64))
174-
'(1 0))))
155+
(define pctx (mk-pcontext '(#(0.5) #(4.0)) '(1.0 1.0)))
156+
(define alts (map make-alt (list '(fmin.f64 x 1) '(fmax.f64 x 1))))
157+
(define err-lsts `((,(expt 2.0 53) 1.0) (1.0 ,(expt 2.0 53))))
158+
159+
(define (test-regimes expr goal)
160+
(check (lambda (x y) (equal? (map si-cidx (option-split-indices x)) y))
161+
(option-on-expr alts err-lsts expr ctx pctx)
162+
goal))
163+
164+
;; This is a basic sanity test
165+
(test-regimes 'x '(1 0))
166+
167+
;; This test ensures we handle equal points correctly. All points
168+
;; are equal along the `1` axis, so we should only get one
169+
;; splitpoint (the second, since it is better at the further point).
170+
(test-regimes (literal 1 'binary64) '(0))
171+
172+
(test-regimes `(if (==.f64 x ,(literal 0.5 'binary64))
173+
,(literal 1 'binary64)
174+
(NAN.f64))
175+
'(1 0)))
175176

176177
;; Given error-lsts, returns a list of sp objects representing where the optimal splitpoints are.
177178
(define (valid-splitindices? can-split? split-indices)

0 commit comments

Comments
 (0)