Skip to content

Commit 7e5a86b

Browse files
authored
Merge pull request #1044 from herbie-fp/cleanup-localize
Add helper functions `local-error` and `remove-infinities` to `localize.rkt`
2 parents 7c7e6f3 + 471bd0b commit 7e5a86b

File tree

1 file changed

+65
-75
lines changed

1 file changed

+65
-75
lines changed

src/core/localize.rkt

Lines changed: 65 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -174,46 +174,60 @@
174174
>
175175
#:key (compose errors-score car))))
176176

177+
;; The local error of an expression f(x, y) is
178+
;;
179+
;; R[f(x, y)] - f(R[x], R[y])
180+
;;
181+
;; where the `-` is interpreted as ULP difference and `E` means
182+
;; exact real evaluation rounded to target repr.
183+
;;
184+
;; Local error is high when `f` is highly sensitive to rounding error
185+
;; in its inputs `x` and `y`.
186+
187+
(define (local-error exact node repr get-exact)
188+
(match node
189+
[(? literal?) 1]
190+
[(? variable?) 1]
191+
[(approx _ impl) (ulp-difference exact (get-exact impl) repr)]
192+
[`(if ,c ,ift ,iff) 1]
193+
[(list f args ...)
194+
(define argapprox (map get-exact args))
195+
(define approx (apply (impl-info f 'fl) argapprox))
196+
(ulp-difference exact approx repr)]))
197+
198+
(define (make-matrix roots pcontext)
199+
(for/vector #:length (vector-length roots)
200+
([node (in-vector roots)])
201+
(make-vector (pcontext-length (*pcontext*)))))
202+
177203
; Compute local error or each sampled point at each node in `prog`.
178204
(define (compute-local-errors subexprss ctx)
179205
(define exprs-list (append* subexprss)) ; unroll subexprss
206+
(define reprs-list (map (curryr repr-of ctx) exprs-list))
180207
(define ctx-list
181-
(for/list ([subexpr (in-list exprs-list)])
182-
(struct-copy context ctx [repr (repr-of subexpr ctx)])))
208+
(for/list ([subexpr (in-list exprs-list)]
209+
[repr (in-list reprs-list)])
210+
(struct-copy context ctx [repr repr])))
183211

184212
(define expr-batch (progs->batch exprs-list))
185213
(define nodes (batch-nodes expr-batch))
186214
(define roots (batch-roots expr-batch))
187215

188216
(define subexprs-fn (eval-progs-real (map prog->spec exprs-list) ctx-list))
189217

190-
(define errs
191-
(for/vector #:length (vector-length roots)
192-
([node (in-vector roots)])
193-
(make-vector (pcontext-length (*pcontext*)))))
218+
(define errs (make-matrix roots (*pcontext*)))
194219

195220
(for ([(pt ex) (in-pcontext (*pcontext*))]
196221
[pt-idx (in-naturals)])
197222
(define exacts (list->vector (apply subexprs-fn pt)))
223+
(define (get-exact idx)
224+
(vector-ref exacts (vector-member idx roots)))
198225
(for ([expr (in-list exprs-list)]
199226
[root (in-vector roots)]
227+
[repr (in-list reprs-list)]
200228
[exact (in-vector exacts)]
201229
[expr-idx (in-naturals)])
202-
(define err
203-
(match (vector-ref nodes root)
204-
[(? literal?) 1]
205-
[(? variable?) 1]
206-
[(approx _ impl)
207-
(define repr (repr-of expr ctx))
208-
(ulp-difference exact (vector-ref exacts (vector-member impl roots)) repr)]
209-
[`(if ,c ,ift ,iff) 1]
210-
[(list f args ...)
211-
(define repr (impl-info f 'otype))
212-
(define argapprox
213-
(for/list ([idx (in-list args)])
214-
(vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact
215-
(define approx (apply (impl-info f 'fl) argapprox))
216-
(ulp-difference exact approx repr)]))
230+
(define err (local-error exact (vector-ref nodes root) repr get-exact))
217231
(vector-set! (vector-ref errs expr-idx) pt-idx err)))
218232

219233
(define n 0)
@@ -222,6 +236,19 @@
222236
(begin0 (values subexpr (vector->list (vector-ref errs n)))
223237
(set! n (add1 n))))))
224238

239+
;; The absolute error of expression `e` is R[e - R[e]].
240+
;; However, it's possible that R[e] is infinity or NaN;
241+
;; in this case, computing the absolute error won't work
242+
;; since those aren't real numbers. To fix this, we replace all
243+
;; non-finite R[e] with 0.
244+
(define (remove-infinities pt reprs)
245+
(for/list ([val (in-vector pt)]
246+
[repr (in-list reprs)])
247+
(define bf-val ((representation-repr->bf repr) val))
248+
(if (implies (bigfloat? bf-val) (bfrational? bf-val))
249+
val
250+
((representation-bf->repr repr) 0.bf))))
251+
225252
;; Compute local error or each sampled point at each node in `prog`.
226253
(define (compute-errors subexprss ctx)
227254
;; We compute the actual (float) result
@@ -230,9 +257,11 @@
230257

231258
;; And the real result
232259
(define spec-list (map prog->spec exprs-list))
260+
(define reprs-list (map (curryr repr-of ctx) exprs-list))
233261
(define ctx-list
234-
(for/list ([subexpr (in-list exprs-list)])
235-
(struct-copy context ctx [repr (repr-of subexpr ctx)])))
262+
(for/list ([subexpr (in-list exprs-list)]
263+
[repr (in-list reprs-list)])
264+
(struct-copy context ctx [repr repr])))
236265
(define subexprs-fn (eval-progs-real spec-list ctx-list))
237266

238267
;; And the absolute difference between the two
@@ -244,86 +273,47 @@
244273
(define delta-ctx
245274
(context (append (context-vars ctx) exact-var-names)
246275
(get-representation 'binary64)
247-
(append (context-var-reprs ctx)
248-
(for/list ([expr (in-list exprs-list)])
249-
(repr-of expr ctx)))))
276+
(append (context-var-reprs ctx) reprs-list)))
250277
(define compare-specs
251278
(for/list ([spec (in-list spec-list)]
252279
[expr (in-list exprs-list)]
280+
[repr (in-list reprs-list)]
253281
[var (in-list exact-var-names)])
254282
(cond
255283
[(number? spec) 0] ; HACK: unclear why numbers don't work in Rival but :shrug:
256-
[(equal? (representation-type (repr-of expr ctx)) 'bool)
257-
0] ; HACK: just ignore differences in booleans
284+
[(equal? (representation-type repr) 'bool) 0] ; HACK: just ignore differences in booleans
258285
[else `(fabs (- ,spec ,var))])))
259286
(define delta-fn (eval-progs-real compare-specs (map (const delta-ctx) compare-specs)))
260287

261288
(define expr-batch (progs->batch exprs-list))
262289
(define nodes (batch-nodes expr-batch))
263290
(define roots (batch-roots expr-batch))
264291

265-
(define ulp-errs
266-
(for/vector #:length (vector-length roots)
267-
([node (in-vector roots)])
268-
(make-vector (pcontext-length (*pcontext*)))))
269-
270-
(define exacts-out
271-
(for/vector #:length (vector-length roots)
272-
([node (in-vector roots)])
273-
(make-vector (pcontext-length (*pcontext*)))))
274-
275-
(define approx-out
276-
(for/vector #:length (vector-length roots)
277-
([node (in-vector roots)])
278-
(make-vector (pcontext-length (*pcontext*)))))
279-
280-
(define true-error-out
281-
(for/vector #:length (vector-length roots)
282-
([node (in-vector roots)])
283-
(make-vector (pcontext-length (*pcontext*)))))
292+
(define ulp-errs (make-matrix roots (*pcontext*)))
293+
(define exacts-out (make-matrix roots (*pcontext*)))
294+
(define approx-out (make-matrix roots (*pcontext*)))
295+
(define true-error-out (make-matrix roots (*pcontext*)))
284296

285297
(define spec-vec (list->vector spec-list))
286298
(define ctx-vec (list->vector ctx-list))
287299
(for ([(pt ex) (in-pcontext (*pcontext*))]
288300
[pt-idx (in-naturals)])
289301

290302
(define exacts (list->vector (apply subexprs-fn pt)))
291-
(define actuals (apply actual-value-fn pt))
303+
(define (get-exact idx)
304+
(vector-ref exacts (vector-member idx roots)))
292305

293-
(define actuals*
294-
(for/list ([val (in-vector actuals)]
295-
[expr (in-list exprs-list)])
296-
(define repr (repr-of expr ctx))
297-
(define bf-val ((representation-repr->bf repr) val))
298-
(if (implies (bigfloat? bf-val) (bfrational? bf-val))
299-
val
300-
((representation-bf->repr repr) 0.bf)))) ; HACK: inf and nan -> 0 for absolute error
301-
(define pt* (append pt actuals*))
306+
(define actuals (apply actual-value-fn pt))
307+
(define pt* (append pt (remove-infinities actuals reprs-list)))
302308
(define deltas (list->vector (apply delta-fn pt*)))
303309

304-
(for ([spec (in-list spec-list)]
305-
[expr (in-list exprs-list)]
310+
(for ([repr (in-list reprs-list)]
306311
[root (in-vector roots)]
307312
[exact (in-vector exacts)]
308313
[actual (in-vector actuals)]
309314
[delta (in-vector deltas)]
310315
[expr-idx (in-naturals)])
311-
(define ulp-err
312-
(match (vector-ref nodes root)
313-
[(? literal?) 1]
314-
[(? variable?) 1]
315-
[(approx _ impl)
316-
(define repr (repr-of expr ctx))
317-
(ulp-difference exact (vector-ref exacts (vector-member impl roots)) repr)]
318-
[`(if ,c ,ift ,iff) 1]
319-
[(list f args-roots ...)
320-
(define repr (impl-info f 'otype))
321-
(define argapprox
322-
(for/list ([idx (in-list args-roots)])
323-
(vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact
324-
(define approx (apply (impl-info f 'fl) argapprox))
325-
(ulp-difference exact approx repr)]))
326-
316+
(define ulp-err (local-error exact (vector-ref nodes root) repr get-exact))
327317
(vector-set! (vector-ref exacts-out expr-idx) pt-idx exact)
328318
(vector-set! (vector-ref approx-out expr-idx) pt-idx actual)
329319
(vector-set! (vector-ref ulp-errs expr-idx) pt-idx ulp-err)

0 commit comments

Comments
 (0)