|
1 | 1 | #lang racket |
2 | 2 |
|
3 | 3 | (require "../common.rkt" "../alternative.rkt" "../programs.rkt" "../timeline.rkt") |
4 | | -(require "../syntax/types.rkt" "../interface.rkt") |
| 4 | +(require "../syntax/types.rkt" "../interface.rkt" "../errors.rkt") |
5 | 5 | (require "../points.rkt" "../float.rkt") ; For binary search |
6 | 6 |
|
7 | 7 | (module+ test |
|
86 | 86 | (critical-subexpression? prog-body expr))) |
87 | 87 | expr)) |
88 | 88 |
|
89 | | -(define (combine-alts best-option repr) |
| 89 | +(define (combine-alts best-option repr sampler) |
90 | 90 | (match-define (option splitindices alts pts expr _) best-option) |
91 | 91 | (match splitindices |
92 | 92 | [(list (si cidx _)) (list-ref alts cidx)] |
93 | 93 | [_ |
94 | | - (define splitpoints (sindices->spoints pts expr alts splitindices repr)) |
| 94 | + (define splitpoints (sindices->spoints pts expr alts splitindices repr sampler)) |
95 | 95 | (debug #:from 'regimes "Found splitpoints:" splitpoints ", with alts" alts) |
96 | 96 |
|
97 | 97 | (define expr* |
|
191 | 191 | ;; float form always come from the range [f(idx1), f(idx2)). If the |
192 | 192 | ;; float form of a split is f(idx2), or entirely outside that range, |
193 | 193 | ;; problems may arise. |
194 | | -(define (sindices->spoints points expr alts sindices repr) |
| 194 | +(define (sindices->spoints points expr alts sindices repr sampler) |
195 | 195 | (define eval-expr |
196 | 196 | (eval-prog `(λ ,(program-variables (alt-program (car alts))) ,expr) 'fl repr)) |
197 | 197 |
|
|
207 | 207 | [*timeline-disabled* true] |
208 | 208 | [*var-reprs* (dict-set (*var-reprs*) var repr)]) |
209 | 209 | (define ctx |
210 | | - (prepare-points start-prog `(== ,(caadr start-prog) ,v) repr)) |
| 210 | + (prepare-points start-prog |
| 211 | + `(λ ,(program-variables start-prog) (== ,(caadr start-prog) ,v)) |
| 212 | + repr |
| 213 | + (λ () (cons v (sampler))))) |
211 | 214 | (< (errors-score (errors prog1 ctx repr)) |
212 | 215 | (errors-score (errors prog2 ctx repr))))) |
213 | 216 | (define pt (binary-search-floats pred v1 v2 repr)) |
|
223 | 226 |
|
224 | 227 | (sp (si-cidx sidx) expr (find-split prog1 prog2 p1 p2))) |
225 | 228 |
|
| 229 | + (define (regimes-sidx->spoint sidx) |
| 230 | + (sp (si-cidx sidx) expr (apply eval-expr (list-ref points (- (si-pidx sidx) 1))))) |
| 231 | + |
226 | 232 | (define final-sp (sp (si-cidx (last sindices)) expr +nan.0)) |
227 | 233 |
|
| 234 | + (define use-binary |
| 235 | + (and (flag-set? 'reduce 'binary-search) |
| 236 | + ;; Binary search is only valid if we correctly extracted the branch expression |
| 237 | + (andmap identity (cons start-prog progs)))) |
| 238 | + (if use-binary |
| 239 | + (debug #:from 'binary-search "Improving bounds with binary search for" expr "and" alts) |
| 240 | + (debug #:from 'binary-search "Only using regimes for bounds on" expr "and" alts)) |
| 241 | + |
228 | 242 | (append |
229 | | - (if (and (flag-set? 'reduce 'binary-search) |
230 | | - ;; Binary search is only valid if we correctly extracted the branch expression |
231 | | - (andmap identity (cons start-prog progs))) |
232 | | - (begin |
233 | | - (debug #:from 'binary-search "Improving bounds with binary search for" expr "and" alts) |
234 | | - (for/list ([si1 sindices] [si2 (cdr sindices)]) |
235 | | - (sidx->spoint si1 si2))) |
236 | | - (begin |
237 | | - (debug #:from 'binary-search "Only using regimes for bounds on" expr "and" alts) |
238 | | - (for/list ([sindex (take sindices (sub1 (length sindices)))]) |
239 | | - (sp (si-cidx sindex) expr (apply eval-expr (list-ref points (- (si-pidx sindex) 1))))))) |
| 243 | + (for/list ([si1 sindices] [si2 (cdr sindices)]) |
| 244 | + (cond |
| 245 | + [use-binary |
| 246 | + (with-handlers ([exn:fail:user:herbie:sampling? |
| 247 | + (lambda (e) (regimes-sidx->spoint si1))]) |
| 248 | + (sidx->spoint si1 si2))] |
| 249 | + [else |
| 250 | + (regimes-sidx->spoint si1)])) |
240 | 251 | (list final-sp))) |
241 | 252 |
|
242 | 253 | (define (point-with-dim index point val) |
|
0 commit comments