|
25 | 25 | [(define (write-proc opt port mode) |
26 | 26 | (fprintf port "#<option ~a>" (option-split-indices opt)))]) |
27 | 27 |
|
28 | | -(define (pareto-regimes sorted start-prog ctx) |
| 28 | +(define (pareto-regimes sorted start-prog ctx pcontext) |
29 | 29 | (timeline-event! 'regimes) |
30 | | - (define err-lsts (batch-errors (map alt-expr sorted) (*pcontext*) ctx)) |
| 30 | + (define err-lsts (batch-errors (map alt-expr sorted) pcontext ctx)) |
31 | 31 | (define branches |
32 | 32 | (if (null? sorted) |
33 | 33 | '() |
|
44 | 44 | ; Only return one option if not pareto mode |
45 | 45 | [(and (not (*pareto-mode*)) (not (equal? alts sorted))) '()] |
46 | 46 | [else |
47 | | - (define-values (opt new-errs) (infer-splitpoints branch-exprs alts err-lsts #:errs errs ctx)) |
| 47 | + (define-values (opt new-errs) |
| 48 | + (infer-splitpoints branch-exprs alts err-lsts #:errs errs ctx pcontext)) |
48 | 49 | (define high (si-cidx (argmax (λ (x) (si-cidx x)) (option-split-indices opt)))) |
49 | 50 | (cons opt (loop (take alts high) new-errs (take err-lsts high)))]))) |
50 | 51 |
|
51 | 52 | ;; `infer-splitpoints` and `combine-alts` are split so the mainloop |
52 | 53 | ;; can insert a timeline break between them. |
53 | 54 |
|
54 | | -(define (infer-splitpoints branch-exprs alts err-lsts* #:errs [cerrs (hash)] ctx) |
| 55 | +(define (infer-splitpoints branch-exprs alts err-lsts* #:errs [cerrs (hash)] ctx pcontext) |
55 | 56 | (timeline-push! 'inputs (map (compose ~a alt-expr) alts)) |
56 | 57 | (define sorted-bexprs |
57 | 58 | (sort branch-exprs (lambda (x y) (< (hash-ref cerrs x -1) (hash-ref cerrs y -1))))) |
|
67 | 68 | ([bexpr sorted-bexprs] |
68 | 69 | ;; stop if we've computed this (and following) branch-expr on more alts and it's still worse |
69 | 70 | #:break (> (hash-ref cerrs bexpr -1) best-err)) |
70 | | - (define opt (option-on-expr alts err-lsts bexpr ctx)) |
| 71 | + (define opt (option-on-expr alts err-lsts bexpr ctx pcontext)) |
71 | 72 | (define err |
72 | 73 | (+ (errors-score (option-errors opt)) |
73 | 74 | (length (option-split-indices opt)))) ;; one-bit penalty per split |
|
113 | 114 | #:when (critical-subexpression? expr subexpr)) |
114 | 115 | subexpr)) |
115 | 116 |
|
116 | | -(define (option-on-expr alts err-lsts expr ctx) |
| 117 | +(define (option-on-expr alts err-lsts expr ctx pcontext) |
117 | 118 | (define timeline-stop! (timeline-start! 'times (~a expr))) |
118 | 119 |
|
119 | 120 | (define fn (compile-prog expr ctx)) |
120 | 121 | (define repr (repr-of expr ctx)) |
121 | 122 |
|
122 | 123 | (define big-table ; pt ; splitval ; alt1-err ; alt2-err ; ... |
123 | | - (for/list ([(pt ex) (in-pcontext (*pcontext*))] |
| 124 | + (for/list ([(pt ex) (in-pcontext pcontext)] |
124 | 125 | [err-lst err-lsts]) |
125 | 126 | (list* (fn pt) pt err-lst))) |
126 | 127 | (match-define (list splitvals* pts* err-lsts* ...) |
|
151 | 152 |
|
152 | 153 | (module+ test |
153 | 154 | (define ctx (make-debug-context '(x))) |
154 | | - (parameterize ([*pcontext* (mk-pcontext '(#(0.5) #(4.0)) '(1.0 1.0))]) |
155 | | - (define alts (map make-alt (list '(fmin.f64 x 1) '(fmax.f64 x 1)))) |
156 | | - (define err-lsts `((,(expt 2.0 53) 1.0) (1.0 ,(expt 2.0 53)))) |
157 | | - |
158 | | - (define (test-regimes expr goal) |
159 | | - (check (lambda (x y) (equal? (map si-cidx (option-split-indices x)) y)) |
160 | | - (option-on-expr alts err-lsts expr ctx) |
161 | | - goal)) |
162 | | - |
163 | | - ;; This is a basic sanity test |
164 | | - (test-regimes 'x '(1 0)) |
165 | | - |
166 | | - ;; This test ensures we handle equal points correctly. All points |
167 | | - ;; are equal along the `1` axis, so we should only get one |
168 | | - ;; splitpoint (the second, since it is better at the further point). |
169 | | - (test-regimes (literal 1 'binary64) '(0)) |
170 | | - |
171 | | - (test-regimes `(if (==.f64 x ,(literal 0.5 'binary64)) |
172 | | - ,(literal 1 'binary64) |
173 | | - (NAN.f64)) |
174 | | - '(1 0)))) |
| 155 | + (define pctx (mk-pcontext '(#(0.5) #(4.0)) '(1.0 1.0))) |
| 156 | + (define alts (map make-alt (list '(fmin.f64 x 1) '(fmax.f64 x 1)))) |
| 157 | + (define err-lsts `((,(expt 2.0 53) 1.0) (1.0 ,(expt 2.0 53)))) |
| 158 | + |
| 159 | + (define (test-regimes expr goal) |
| 160 | + (check (lambda (x y) (equal? (map si-cidx (option-split-indices x)) y)) |
| 161 | + (option-on-expr alts err-lsts expr ctx pctx) |
| 162 | + goal)) |
| 163 | + |
| 164 | + ;; This is a basic sanity test |
| 165 | + (test-regimes 'x '(1 0)) |
| 166 | + |
| 167 | + ;; This test ensures we handle equal points correctly. All points |
| 168 | + ;; are equal along the `1` axis, so we should only get one |
| 169 | + ;; splitpoint (the second, since it is better at the further point). |
| 170 | + (test-regimes (literal 1 'binary64) '(0)) |
| 171 | + |
| 172 | + (test-regimes `(if (==.f64 x ,(literal 0.5 'binary64)) |
| 173 | + ,(literal 1 'binary64) |
| 174 | + (NAN.f64)) |
| 175 | + '(1 0))) |
175 | 176 |
|
176 | 177 | ;; Given error-lsts, returns a list of sp objects representing where the optimal splitpoints are. |
177 | 178 | (define (valid-splitindices? can-split? split-indices) |
|
0 commit comments