Skip to content

Commit 8f5509f

Browse files
committed
Precision optimization machinery
1 parent fc97a4b commit 8f5509f

File tree

3 files changed

+192
-4
lines changed

3 files changed

+192
-4
lines changed

eval/optimal.rkt

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#lang racket
2+
3+
(require "../ops/all.rkt"
4+
"machine.rkt"
5+
"run.rkt"
6+
"main.rkt")
7+
8+
(provide rival-machine-test-precision
9+
rival-machine-search-precision
10+
rival-machine-find-optimal-precisions)
11+
12+
; Evaluate `machine` at the point `pt` with `prec-vec` installed as the
; precision vector.
; Returns #t when the run reports both good? and done?, and #f otherwise.
(define (rival-machine-test-precision machine pt prec-vec)
  ; Wrap every coordinate of the point with `ival` and load the inputs
  (define inputs
    (build-vector (vector-length pt)
                  (lambda (i) (ival (vector-ref pt i)))))
  (rival-machine-load machine inputs)

  ; Install the custom precisions; iteration 1 means "don't use the initial
  ; precision vector"
  (set-rival-machine-iteration! machine 1)
  (vector-copy! (rival-machine-precisions machine) 0 prec-vec)
  (vector-copy! (rival-machine-repeats machine) 0 (rival-machine-initial-repeats machine))
  (rival-machine-run machine (rival-machine-default-hint machine))

  ; Only the first two of the five returned values decide the outcome
  (call-with-values
   (lambda () (rival-machine-return machine))
   (lambda (good? done? _bad? _stuck? _fvec)
     (and good? done?))))
31+
32+
; Binary search for the lowest precision at index `idx` that still lets the
; machine succeed at `pt`, holding every other entry of `prec-vec` fixed.
;
; The entry already stored at `idx` in `prec-vec` is the upper bound
; (max-prec); `min-prec` (default 2) is the lower bound of the search.
; `prec-vec` itself is never mutated: the search works on a private copy.
;
; Returns the precision the search converges to in [min-prec, max-prec]; it is
; always a value at which the evaluation was observed to succeed. When
; max-prec is already <= min-prec, max-prec is returned unchanged.
; Raises an error (it does NOT return #f) when evaluation fails even at
; max-prec.
;
; The result is the true minimum only if success is monotone in the precision
; at `idx`, i.e. once a precision works every larger one up to max-prec works.
(define (rival-machine-search-precision machine pt prec-vec idx #:min-prec [min-prec 2])
  (define test-vec (vector-copy prec-vec))
  (define max-prec (vector-ref test-vec idx))

  ; The upper bound must succeed, otherwise there is nothing to search for
  (unless (rival-machine-test-precision machine pt test-vec)
    (error 'rival-machine-search-precision "max-prec does not succeed"))

  ; Invariant: `hi` is always a precision that is known to succeed
  (let loop ([lo min-prec]
             [hi max-prec])
    (cond
      [(>= lo hi) hi]
      [else
       (define mid (quotient (+ lo hi) 2))
       (vector-set! test-vec idx mid)
       (if (rival-machine-test-precision machine pt test-vec)
           (loop lo mid)
           (loop (+ mid 1) hi))])))
53+
54+
; Benchmark `thunk`: after one discarded warm-up call, time `n` batches
; (keyword #:min) of `m` consecutive calls each (keyword #:sum) and return the
; smallest per-call average, in milliseconds, observed over the batches.
(define (time-min thunk #:min [n 5] #:sum [m 10])
  ; Average wall-clock milliseconds per call over one batch of m calls
  (define (batch-average)
    (define t0 (current-inexact-milliseconds))
    (for ([_call (in-range m)])
      (thunk))
    (/ (- (current-inexact-milliseconds) t0) m))

  (thunk) ; Warm-up run, result discarded
  (apply min
         (for/list ([_batch (in-range n)])
           (batch-average))))
63+
64+
; Find optimal precisions for a machine at a given point.
;
; Returns (list optimal-precs optimal-time max-precs final-time), where the
; two precision entries are vectors and the two times are `time-min`
; measurements (milliseconds per evaluation) of running the machine with the
; respective vector. Returns #f when the machine does not succeed at `pt`
; with the extracted precision vector.
(define (rival-machine-find-optimal-precisions machine pt)
  ; Extract the precision assignment, assuming no slack. rival-apply is called
  ; only for its effect on the machine; its result is not used.
  (rival-apply machine pt)
  (set-rival-machine-iteration! machine 1) ; Don't use initial precision vector
  (rival-machine-adjust machine (rival-machine-default-hint machine))
  (define max-precs (vector-copy (rival-machine-precisions machine)))

  (cond
    [(rival-machine-test-precision machine pt max-precs)
     ; Baseline: time the evaluation with the extracted precision vector
     (define final-time
       (time-min (lambda () (rival-machine-test-precision machine pt max-precs))))

     ; Start with max precisions and lower one instruction at a time, walking
     ; from the last instruction back to the first; each search sees the
     ; entries already lowered by the earlier searches.
     (define optimal-precs (vector-copy max-precs))
     (define n-instrs (vector-length (rival-machine-instructions machine)))
     (for ([idx (in-range (- n-instrs 1) -1 -1)])
       (vector-set! optimal-precs
                    idx
                    (rival-machine-search-precision machine pt optimal-precs idx)))

     ; Time the evaluation with the optimal precisions
     (define optimal-time
       (time-min (lambda () (rival-machine-test-precision machine pt optimal-precs))))

     ; Return both precision vectors together with their timings
     (list optimal-precs optimal-time max-precs final-time)]
    [else #f]))

infra/optimize.rkt

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#lang racket
2+
3+
(require json
4+
math/bigfloat)
5+
(require "main.rkt"
6+
"eval/machine.rkt"
7+
"eval/optimal.rkt"
8+
"utils.rkt")
9+
10+
; Parse and return the first datum in the string `s` using `read`
; (the eof object when `s` contains no datum).
(define (read-from-string s)
  (call-with-input-string s read))
12+
13+
; Analyze one benchmark record `rec` (a hash with 'exprs, 'vars, 'discs and
; 'points entries). Compiles the expressions, measures for every point the
; evaluation time with the current and with the optimal precisions, logs a
; summary to stderr, and, if any point is slow enough to report, writes the
; benchmark definition plus at most 10 `(optimize ...)` lines to `output-port`.
(define (analyze-program rec bench-id min-speedup output-port)
  (define exprs (map read-from-string (hash-ref rec 'exprs)))
  (define vars (map read-from-string (hash-ref rec 'vars)))
  ; Shape check only: the discretizations must be one bool followed by flonums
  (match-define `(bool flonum ...) (map read-from-string (hash-ref rec 'discs)))
  (define discs
    (cons boolean-discretization
          (for/list ([_expr (in-list (cdr exprs))])
            flonum-discretization)))

  (define machine
    (parameterize ([*rival-max-precision* 32256])
      (rival-compile exprs vars discs)))

  ; Measure a single point; yields (list pt-id pt optimal-time cur-time), or
  ; #f (after logging to stderr) when the point could not be optimized
  (define (measure entry pt-id)
    (match-define (list pt _sollya-exs _sollya-status _sollya-apply-time) entry)
    (define pt-vec
      (parameterize ([bf-precision 53])
        (list->vector (map bf pt))))
    (match (rival-machine-find-optimal-precisions machine pt-vec)
      [(list _optimal-precs optimal-time _cur-precs cur-time)
       (list pt-id pt optimal-time cur-time)]
      [#f
       (eprintf "; Benchmark ~a point ~a, failure to optimize\n" bench-id pt-id)
       #f]))

  (define rows
    (filter values
            (for/list ([entry (in-list (hash-ref rec 'points))]
                       [pt-id (in-naturals)])
              (measure entry pt-id))))

  ; A row is reported when the current time exceeds min-speedup times the
  ; optimal time and the absolute gap is above .001 (at least 1 us of speedup)
  (define (slow-row? row)
    (match-define (list _pt-id _pt opt-time cur-time) row)
    (define saved (- cur-time opt-time))
    (and (> cur-time (* min-speedup opt-time))
         (> saved .001)))

  ; Total time that could be saved over all measured points, scaled by 1000
  (define total-cur (apply + (map fourth rows)))
  (define total-opt (apply + (map third rows)))
  (define dt (* 1000 (- total-cur total-opt)))
  (define dt-text (~r dt #:precision '(= 1)))

  ; Slow rows, largest cur-time first
  (define worst-first (sort (filter slow-row? rows) > #:key fourth))
  (eprintf "; Benchmark ~a, total ~aµs available, ~a bad points\n"
           bench-id
           dt-text
           (length worst-first))
  (unless (null? worst-first)
    (fprintf output-port "; Benchmark ~a, total ~aµs available\n" bench-id dt-text)
    (fprintf output-port
             "(define (b~a ~a)\n ~a)\n"
             bench-id
             (string-join (map ~s vars) " ")
             (string-join (map ~s exprs) " "))

    ; The parallel (in-range 10) clause caps the output at 10 points
    (for ([row (in-list worst-first)]
          [_rank (in-range 10)])
      (define pt (second row))
      (fprintf output-port "(optimize b~a ~a)\n" bench-id (string-join (map ~a pt) " ")))
    (fprintf output-port "\n")))
61+
62+
; Command-line entry point. Reads benchmark records (a stream of JSON values)
; from the points file and writes the report produced by analyze-program to
; the output file, replacing any existing file of that name.
(module+ main
  (require racket/cmdline)
  (define min-speedup (make-parameter 1.2))

  ; Run analyze-program on every JSON record read from `input`, numbering the
  ; records from 0, with all reports going to `out-port`
  (define (analyze-all input out-port)
    (for ([record (in-port read-json input)]
          [bench-id (in-naturals)])
      (analyze-program record bench-id (min-speedup) out-port)))

  (command-line
   #:once-each [("--min") n "Minimum speedup to report" (min-speedup (string->number n))]
   #:args ([points-file "infra/points.json"]
           [output-file "optimaize.rival"])
   (printf "Analyzing points bad precision assignment (min speedup: ~a)...\n\n" (min-speedup))
   ; The output file is opened (and replaced) before the input file is read
   (call-with-output-file
    output-file
    #:exists 'replace
    (lambda (out-port)
      (call-with-input-file
       points-file
       (lambda (input)
         (analyze-all input out-port)))))))

repl.rkt

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
profile)
1111
(require "eval/main.rkt"
1212
"eval/machine.rkt"
13+
"eval/optimal.rkt"
1314
"utils.rkt")
1415
(provide repl-main
1516
repl-profile)
@@ -56,9 +57,9 @@
5657
#:value [value (lambda (_iter exec) exec)])
5758
(define entry
5859
(for/first ([exec (in-list execs)]
59-
#:when (and (or (not target-iter) (= (car exec) target-iter))
60-
(or (not target-id) (= (execution-number (cdr exec)) target-id))
61-
(or (not target-name) (= (execution-name (cdr exec)) target-name))))
60+
#:when (or (not target-iter) (= (car exec) target-iter))
61+
#:when (or (not target-id) (= (execution-number (cdr exec)) target-id))
62+
#:when (or (not target-name) (eq? (execution-name (cdr exec)) target-name)))
6263
exec))
6364
(if entry
6465
(value (car entry) (cdr entry))
@@ -85,7 +86,7 @@
8586
(define (repl-value->string val)
8687
(cond
8788
[(bigfloat? val) (bigfloat->string val)]
88-
[(number? val) (~r val)]
89+
[(number? val) (~a val)]
8990
[else (~a val)]))
9091

9192
(define (repl-precision-bits repl)
@@ -220,6 +221,27 @@
220221
(when print?
221222
(write-explain machine)
222223
(printf "\nTotal: ~aµs\n" (~r (* (- end start) 1000) #:precision '(= 1))))]
224+
[`(optimize ,name ,(? (disjoin real? boolean?) vals) ...)
 ; REPL command: search for the optimal precision of every instruction of the
 ; named machine at the given point and report the time saved.
 (define machine (repl-get-machine repl name))
 (check-args! name machine vals)
 (define result
   (parameterize ([bf-precision (repl-precision-bits repl)])
     (rival-machine-find-optimal-precisions machine (list->vector (map ->bf vals)))))
 (when print?
   (match result
     ; rival-machine-find-optimal-precisions yields #f when the machine does
     ; not succeed at this point; report that instead of failing to
     ; destructure the result.
     [#f (printf "~a could not be optimized at this point\n" name)]
     [(list optimal-precs optimal-time cur-precs cur-time)
      (printf "~a optimal ~aµs faster (~a×)\n"
              name
              (~r (* 1000 (- cur-time optimal-time)) #:precision '(= 1))
              ; Avoid dividing by a zero optimal time
              (if (zero? optimal-time) "" (~r (/ cur-time optimal-time) #:precision '(= 3))))
      ; One row per instruction: name, current precision, optimal precision
      (define ivec (rival-machine-instructions machine))
      (for ([instr (in-vector ivec)]
            [final (in-vector cur-precs)]
            [optimal (in-vector optimal-precs)])
        (define instr-name (normalize-function-name (~a (object-name (car instr)))))
        (printf "~a ~a ~a\n"
                (~a instr-name #:width 20 #:align 'left)
                (~a final #:width 6 #:align 'right)
                (~a optimal #:width 6 #:align 'right)))
      (newline)]))]
223245
[(or '(help) 'help)
224246
(displayln "This is the Rival REPL, a demo of the Rival real evaluator.")
225247
(newline)
@@ -229,6 +251,7 @@
229251
(displayln " (eval <name> <vals> ...) Evaluate a named function")
230252
(displayln
231253
" (explain <name> <vals> ...) Show profile for evaluating a named function")
254+
(displayln " (optimize <name> <vals> ...) Find optimal precisions and report speedup")
232255
(newline)
233256
(displayln "A closed expression can always be used in place of a named function.")]
234257
[_ (printf "Unknown command ~a; use help for command list\n" cmd)])))

0 commit comments

Comments
 (0)