Skip to content

Commit 73ab1ec

Browse files
authored
Merge pull request #1282 from herbie-fp/codex/refactor--start-prog--global-usage
Pass start program explicitly
2 parents dd66cc0 + 56dca9a commit 73ab1ec

File tree

4 files changed

+19
-22
lines changed

4 files changed

+19
-22
lines changed

src/core/bsearch.rkt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@
3232
;; The last splitpoint uses +nan.0 for pt and represents the "else"
3333
(struct sp (cidx bexpr point) #:prefab)
3434

35-
(define (combine-alts best-option ctx)
35+
(define (combine-alts best-option start-prog ctx)
3636
(match-define (option splitindices alts pts expr _) best-option)
3737
(match splitindices
3838
[(list (si cidx _)) (list-ref alts cidx)]
3939
[_
4040
(timeline-event! 'bsearch)
41-
(define splitpoints (sindices->spoints pts expr alts splitindices ctx))
41+
(define splitpoints (sindices->spoints pts expr alts splitindices start-prog ctx))
4242

4343
(define expr*
4444
(for/fold ([expr (alt-expr (list-ref alts (sp-cidx (last splitpoints))))])
@@ -119,16 +119,16 @@
119119
;; float form always come from the range [f(idx1), f(idx2)). If the
120120
;; float form of a split is f(idx2), or entirely outside that range,
121121
;; problems may arise.
122-
(define/contract (sindices->spoints points expr alts sindices ctx)
123-
(-> (listof vector?) any/c (listof alt?) (listof si?) context? valid-splitpoints?)
122+
(define/contract (sindices->spoints points expr alts sindices start-prog ctx)
123+
(-> (listof vector?) any/c (listof alt?) (listof si?) any/c context? valid-splitpoints?)
124124
(define repr (repr-of expr ctx))
125125

126126
(define eval-expr (compile-prog expr ctx))
127127

128128
(define var (gensym 'branch))
129129
(define ctx* (context-extend ctx var repr))
130130
(define progs (map (compose (curryr extract-subexpression var expr ctx) alt-expr) alts))
131-
(define start-prog (extract-subexpression (*start-prog*) var expr ctx))
131+
(define start-prog-sub (extract-subexpression start-prog var expr ctx))
132132

133133
; Not totally clear if this should actually use the precondition
134134
(define start-real-compiler
@@ -163,7 +163,7 @@
163163
(define use-binary
164164
(and (flag-set? 'reduce 'binary-search)
165165
;; Binary search is only valid if we correctly extracted the branch expression
166-
(andmap identity (cons start-prog progs))))
166+
(andmap identity (cons start-prog-sub progs))))
167167

168168
(append (for/list ([si1 sindices]
169169
[si2 (cdr sindices)])

src/core/mainloop.rkt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
(define/reset ^patched^ #f)
3131
(define/reset ^table^ #f)
3232

33+
;; Starting program for the current run
34+
(define *start-prog* (make-parameter #f))
35+
3336
;; These high-level functions give the high-level workflow of Herbie:
3437
;; - Initial steps: explain, preprocessing, initialize the alt table
3538
;; - the loop: choose some alts, localize, run the patch table, and finalize
@@ -66,7 +69,7 @@
6669
(timeline-push-alts! '())
6770

6871
(define all-alts (atab-all-alts (^table^)))
69-
(define joined-alts (make-regime! all-alts)) ;; HERE
72+
(define joined-alts (make-regime! all-alts (*start-prog*))) ;; HERE
7073
(define annotated-alts (add-derivations! joined-alts))
7174

7275
(timeline-push! 'stop (if (atab-completed? (^table^)) "done" "fuel") 1)
@@ -237,7 +240,7 @@
237240
(finalize-iter!)
238241
(void))
239242

240-
(define (make-regime! alts)
243+
(define (make-regime! alts start-prog)
241244
(define ctx (*context*))
242245
(define repr (context-repr ctx))
243246

@@ -247,9 +250,9 @@
247250
(equal? (representation-type repr) 'real)
248251
(not (null? (context-vars ctx)))
249252
(get-fpcore-impl '<= '() (list repr repr)))
250-
(define opts (pareto-regimes (sort alts < #:key (curryr alt-cost repr)) ctx))
253+
(define opts (pareto-regimes (sort alts < #:key (curryr alt-cost repr)) start-prog ctx))
251254
(for/list ([opt (in-list opts)])
252-
(combine-alts opt ctx))]
255+
(combine-alts opt start-prog ctx))]
253256
[else (list (argmin score-alt alts))]))
254257

255258
(define (add-derivations! alts)

src/core/regimes.rkt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
[(define (write-proc opt port mode)
2626
(fprintf port "#<option ~a>" (option-split-indices opt)))])
2727

28-
(define (pareto-regimes sorted ctx)
28+
(define (pareto-regimes sorted start-prog ctx)
2929
(timeline-event! 'regimes)
3030
(define err-lsts (batch-errors (map alt-expr sorted) (*pcontext*) ctx))
3131
(define branches
3232
(if (null? sorted)
3333
'()
34-
(exprs-to-branch-on sorted ctx)))
34+
(exprs-to-branch-on sorted start-prog ctx)))
3535
(define branch-exprs
3636
(if (flag-set? 'reduce 'branch-expressions)
3737
branches
@@ -87,11 +87,11 @@
8787
(timeline-push! 'oracle (errors-score (map (curry apply max) err-lsts)))
8888
(values best errs))
8989

90-
(define (exprs-to-branch-on alts ctx)
90+
(define (exprs-to-branch-on alts start-prog ctx)
9191
(define alt-critexprs
9292
(for/list ([alt (in-list alts)])
9393
(all-critical-subexpressions (alt-expr alt) ctx)))
94-
(define start-critexprs (all-critical-subexpressions (*start-prog*) ctx))
94+
(define start-critexprs (all-critical-subexpressions start-prog ctx))
9595
;; We can only binary search if the branch expression is critical
9696
;; for all of the alts and also for the start prgoram.
9797
(filter (λ (e) (equal? (representation-type (repr-of e ctx)) 'real))
@@ -151,8 +151,7 @@
151151

152152
(module+ test
153153
(define ctx (make-debug-context '(x)))
154-
(parameterize ([*start-prog* (literal 1 'binary64)]
155-
[*pcontext* (mk-pcontext '(#(0.5) #(4.0)) '(1.0 1.0))])
154+
(parameterize ([*pcontext* (mk-pcontext '(#(0.5) #(4.0)) '(1.0 1.0))])
156155
(define alts (map make-alt (list '(fmin.f64 x 1) '(fmax.f64 x 1))))
157156
(define err-lsts `((,(expt 2.0 53) 1.0) (1.0 ,(expt 2.0 53))))
158157

src/utils/alternative.rkt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
alt?
77
alt-expr
88
alt-add-event
9-
*start-prog*
109
*all-alts*
1110
alt-cost
1211
alt-equal?
@@ -43,9 +42,5 @@
4342
(define (alt-map f altn)
4443
(f (struct-copy alt altn [prevs (map (curry alt-map f) (alt-prevs altn))])))
4544

46-
;; A useful parameter for many of Herbie's subsystems, though
47-
;; ultimately one that should be located somewhere else or perhaps
48-
;; exorcised
49-
50-
(define *start-prog* (make-parameter #f))
45+
;; Keeps track of all alts so far.
5146
(define *all-alts* (make-parameter '()))

0 commit comments

Comments
 (0)