|
1 | 1 | #lang racket/base |
2 | 2 |
|
3 | | -(require racket/function |
4 | | - racket/list |
5 | | - racket/match) |
6 | | -(require "../mpfr.rkt" |
| 3 | +(require "tricks.rkt" |
7 | 4 | "../ops/all.rkt" |
8 | 5 | "machine.rkt") |
| 6 | + |
9 | 7 | (provide backward-pass) |
10 | 8 |
|
11 | 9 | (define (backward-pass machine) |
|
35 | 33 |
|
36 | 34 | ; Step 1b. Checking if an operation should be computed again at all |
37 | 35 | (define vuseful (make-vector (vector-length ivec) #f)) |
38 | | - (for ([root (in-vector rootvec)] #:when (>= root varc)) |
| 36 | + (for ([root (in-vector rootvec)] |
| 37 | + #:when (>= root varc)) |
39 | 38 | (vector-set! vuseful (- root varc) #t)) |
40 | 39 | (for ([reg (in-vector vregs (- (vector-length vregs) 1) (- varc 1) -1)] |
41 | 40 | [instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] |
|
44 | 43 | (cond |
45 | 44 | [(and (ival-lo-fixed? reg) (ival-hi-fixed? reg)) (vector-set! vuseful i #f)] |
46 | 45 | [useful? |
47 | | - (for ([arg (in-list (cdr instr))] #:when (>= arg varc)) |
| 46 | + (for ([arg (in-list (cdr instr))] |
| 47 | + #:when (>= arg varc)) |
48 | 48 | (vector-set! vuseful (- arg varc) #t))])) |
49 | 49 |
|
50 | 50 | ; Step 2. Precision tuning |
|
74 | 74 | (unless any-false? |
75 | 75 | (set-rival-machine-bumps! machine (add1 bumps)) |
76 | 76 | (define slack (get-slack)) |
77 | | - (for ([prec (in-vector vprecs)] [n (in-range (vector-length vprecs))]) |
| 77 | + (for ([prec (in-vector vprecs)] |
| 78 | + [n (in-range (vector-length vprecs))]) |
78 | 79 | (define prec* (min (*rival-max-precision*) (+ prec slack))) |
79 | 80 | (when (equal? prec* (*rival-max-precision*)) |
80 | 81 | (*sampling-iteration* (*rival-max-iterations*))) |
|
95 | 96 | (define output (vector-ref vregs n)) ; output of the current instr |
96 | 97 |
|
97 | 98 | (define intro (vector-ref vprecs-new (- n varc))) ; intro for the current instruction |
98 | | - (define ampls (get-ampls op output srcs)) ; ampls for the tail instructions |
| 99 | + (define ampls (get-bounds op output srcs)) ; ampls for the tail instructions |
99 | 100 |
|
100 | 101 | (define final-parent-precision |
101 | 102 | (max (+ intro (vector-ref vstart-precs (- n varc))) (*base-tuning-precision*))) |
|
113 | 114 | ; check whether this op already has a precision that is higher |
114 | 115 | (when (> (+ intro ampl) (vector-ref vprecs-new (- x varc))) |
115 | 116 | (vector-set! vprecs-new (- x varc) (+ intro ampl)))))) |
116 | | - |
; Reports whether the lower and upper endpoints of interval `x` carry
; different MPFR signs; this is the test used here for an interval that
; straddles zero.
(define (crosses-zero? x)
  (let ([lower-sign (mpfr-sign (ival-lo x))]
        [upper-sign (mpfr-sign (ival-hi x))])
    (not (equal? lower-sign upper-sign))))
119 | | - |
; Upper bound on the binary exponent of the values in interval `x`.
; The interval is assumed to be valid. It is critical not to rely on mpfr-exp
; of an infinity or of 0: the results are platform-dependent. Infinite
; endpoints are therefore replaced by a slack term.
; NOTE(review): the half-infinite branches still call mpfr-exp on the finite
; endpoint, which may be 0; the result is clamped with (max ... 0) - confirm
; that this is sufficient on every platform.
(define (maxlog x)
  (define lower (ival-lo x))
  (define upper (ival-hi x))
  (define lower-inf? (bfinfinite? lower))
  (define upper-inf? (bfinfinite? upper))
  (cond
    ; x = [-inf, inf]
    [(and upper-inf? lower-inf?) (get-slack)]
    ; x = [..., inf]
    [upper-inf? (+ (max (mpfr-exp lower) 0) (get-slack))]
    ; x = [-inf, ...]
    [lower-inf? (+ (max (mpfr-exp upper) 0) (get-slack))]
    ; x does not contain inf (noted by the author as safe with respect to 0.bf)
    [else (add1 (max (mpfr-exp lower) (mpfr-exp upper)))]))
132 | | - |
; Lower bound on the binary exponent of the values in interval `x`.
; Whenever the interval touches or crosses zero, or the relevant endpoint is
; infinite, the bound is pushed down by the current slack.
(define (minlog x)
  (define lower (ival-lo x))
  (define upper (ival-hi x))
  ; exponent of `endpoint`, clamped to at most 0 and lowered by the slack
  (define (slacked-exp endpoint)
    (- (min (mpfr-exp endpoint) 0) (get-slack)))
  (cond
    ; x = [0.bf, ...]
    [(bfzero? lower) (if (bfinfinite? upper) (- (get-slack)) (slacked-exp upper))]
    ; x = [..., 0.bf]
    [(bfzero? upper) (if (bfinfinite? lower) (- (get-slack)) (slacked-exp lower))]
    ; x = [-..., +...]
    [(crosses-zero? x)
     (cond
       [(and (bfinfinite? upper) (bfinfinite? lower)) (- (get-slack))]
       [(bfinfinite? upper) (slacked-exp lower)]
       [(bfinfinite? lower) (slacked-exp upper)]
       [else (- (min (mpfr-exp lower) (mpfr-exp upper) 0) (get-slack))])]
    ; From here on x lies strictly on one side of zero. Both endpoints cannot
    ; be infinite, since:
    ;  - [inf, inf] is not a valid interval
    ;  - [-inf, inf] crosses zero
    [(bfinfinite? lower) (mpfr-exp upper)]
    [(bfinfinite? upper) (mpfr-exp lower)]
    [else (min (mpfr-exp lower) (mpfr-exp upper))]))
155 | | - |
; Log-scale span of interval `x`. The span computation is currently disabled,
; so every interval reports 0. The disabled formula, kept for reference:
;   (define lo (ival-lo x))
;   (define hi (ival-hi x))
;   (if (or (bfzero? lo) (bfinfinite? lo) (bfzero? hi) (bfinfinite? hi))
;       (get-slack)
;       (+ (abs (- (mpfr-exp lo) (mpfr-exp hi))) 1))
(define (logspan x)
  0)
163 | | - |
; Calculates one ampl factor per input of an instruction, using condition
; formulas. An ampl is the additional precision that has to be added to the
; evaluation of an input so that the output is fixed in its precision when the
; instruction is evaluated again.
;   op   - the interval operation; dispatch is on (object-name op)
;   z    - the output interval of the instruction
;   srcs - the list of input intervals
; Returns a list of exponents, one per argument of the operation.
(define (get-ampls op z srcs)
  ; maxlog(v) - minlog(z): the term shared by most cancellation-prone operations
  (define (input-over-output v)
    (- (maxlog v) (minlog z)))
  ; max(|maxlog(v)|, |minlog(v)|)
  (define (largest-abs-log v)
    (max (abs (maxlog v)) (abs (minlog v))))

  (case (object-name op)
    ; k = 1: logspan(y)
    ; k = 2: logspan(x)
    [(ival-mult)
     (let ([x (first srcs)]
           [y (second srcs)])
       (list (logspan y) ; exponent per x
             (logspan x)))] ; exponent per y

    ; k = 1: logspan(y)
    ; k = 2: logspan(x) + 2 * logspan(y)
    [(ival-div)
     (let ([x (first srcs)]
           [y (second srcs)])
       (list (logspan y) ; exponent per x
             (+ (logspan x) (* 2 (logspan y)))))] ; exponent per y

    ; sqrt: logspan(x)/2 - 1
    ; cbrt: logspan(x)*2/3 - 1
    [(ival-sqrt ival-cbrt)
     (let ([x (first srcs)])
       (list (quotient (logspan x) 2)))]

    ; k = 1: maxlog(x) - minlog(z)
    ; k = 2: maxlog(y) - minlog(z)
    [(ival-add ival-sub)
     (let ([x (first srcs)]
           [y (second srcs)])
       (list (input-over-output x) ; exponent per x
             (input-over-output y)))] ; exponent per y

    ; k = 1: maxlog(y) + logspan(x) + logspan(z)
    ; k = 2: maxlog(y) + max(|minlog(x)|,|maxlog(x)|) + logspan(z)
    [(ival-pow)
     (let* ([x (first srcs)]
            [y (second srcs)]
            ; When the output crosses zero and x is negative, y was fractional
            ; and not fixed (a specific of Rival); extra slack lets y converge.
            [extra (if (and (crosses-zero? z) (bfnegative? (ival-lo x))) (get-slack) 0)])
       (list (+ (maxlog y) (logspan x) (logspan z)) ; exponent per x
             (+ (maxlog y) (largest-abs-log x) (logspan z) extra)))] ; exponent per y

    ; maxlog(x) + logspan(z)
    [(ival-exp ival-exp2)
     (let ([x (car srcs)])
       (list (+ (maxlog x) (logspan z))))]

    ; maxlog(x) + max(|minlog(z)|,|maxlog(z)|) + logspan(z) + 1
    [(ival-tan)
     (let ([x (first srcs)])
       (list (+ (maxlog x) (largest-abs-log z) (logspan z) 1)))]

    ; maxlog(x) - minlog(z)
    [(ival-sin)
     (let ([x (first srcs)])
       (list (input-over-output x)))]

    ; maxlog(x) - minlog(z) + min(maxlog(x), 0)
    [(ival-cos)
     (let ([x (first srcs)])
       (list (+ (input-over-output x) (min (maxlog x) 0))))]

    ; maxlog(x) + logspan(z) - min(minlog(x), 0)
    [(ival-sinh)
     (let ([x (first srcs)])
       (list (- (+ (maxlog x) (logspan z)) (min (minlog x) 0))))]

    ; maxlog(x) + logspan(z) + min(maxlog(x), 0)
    [(ival-cosh)
     (let ([x (first srcs)])
       (list (+ (maxlog x) (logspan z) (min (maxlog x) 0))))]

    ; log:   logspan(x) - minlog(z)
    ; log2:  logspan(x) - minlog(z) + 1
    ; log10: logspan(x) - minlog(z) - 1
    [(ival-log ival-log2 ival-log10)
     (let ([x (first srcs)])
       (list (+ (- (logspan x) (minlog z)) 1)))]

    ; maxlog(x) - log[1-x^2]/2 - minlog(z)
    ; The middle term is a condition of uncertainty; it is assumed to be equal
    ; to the slack once maxlog(z) >= 2.
    [(ival-asin)
     (let* ([x (first srcs)]
            [uncertainty (if (>= (maxlog z) 2) (get-slack) 0)])
       (list (+ (input-over-output x) uncertainty)))]

    ; maxlog(x) - log[1-x^2]/2 - minlog(z)
    ; The middle term is a condition of uncertainty; it is assumed to be equal
    ; to the slack once maxlog(x) >= 1.
    [(ival-acos)
     (let* ([x (first srcs)]
            [uncertainty (if (>= (maxlog x) 1) (get-slack) 0)])
       (list (+ (input-over-output x) uncertainty)))]

    ; logspan(x) - min(|minlog(x)|, |maxlog(x)|) - minlog(z)
    [(ival-atan)
     (let ([x (first srcs)])
       (list (- (logspan x) (min (abs (minlog x)) (abs (maxlog x))) (minlog z))))]

    ; x mod y = x - y*q, where q is rnd_down(x/y)
    ; k = 1: maxlog(x) - minlog(z)
    ; k = 2: ~ log[y * rnd_down(x/y)] - log[mod(x,y)] <= maxlog(x) - minlog(z)
    ; Both logarithms in k = 2 are conditions of uncertainty, so slack is added
    ; for y when y crosses zero.
    [(ival-fmod ival-remainder)
     (let* ([x (first srcs)]
            [y (second srcs)]
            [uncertainty (if (crosses-zero? y) (get-slack) 0)])
       (list (input-over-output x) ; exponent per x
             (+ (input-over-output x) uncertainty)))] ; exponent per y

    ; Currently log1p has a very poor approximation
    ; maxlog(x) - log[1+x] - minlog(z)
    ; The log[1+x] term is treated like a slack if x is negative.
    [(ival-log1p)
     (let* ([x (first srcs)]
            [xhi (ival-hi x)]
            [xlo (ival-lo x)]
            [uncertainty (if (or (equal? (mpfr-sign xlo) -1) (equal? (mpfr-sign xhi) -1))
                             (get-slack)
                             0)])
       (list (+ (input-over-output x) uncertainty)))]

    ; Currently expm1 has a very poor solution for negative values
    ; log[Γ_expm1] = log[x * e^x / expm1] <= max(1 + maxlog(x), 1 + maxlog(x) - minlog(z))
    [(ival-expm1)
     (let ([x (first srcs)])
       (list (max (+ 1 (maxlog x)) (+ 1 (input-over-output x)))))]

    ; maxlog(x) + maxlog(y) - 2*max(minlog(x), minlog(y)) - minlog(z)
    ; The same exponent is used for both arguments.
    [(ival-atan2)
     (let* ([x (first srcs)]
            [y (second srcs)]
            [shared (- (+ (maxlog x) (maxlog y)) (* 2 (max (minlog x) (minlog y))) (minlog z))])
       (list shared shared))]

    ; logspan(z) + logspan(x)
    [(ival-tanh)
     (let ([x (first srcs)])
       (list (+ (logspan z) (logspan x))))]

    ; log[Γ_atanh] = maxlog(x) - log[(1-x^2)] - minlog(z) = 1 if x < 0.5, otherwise slack
    ; The log[(1-x^2)] term is a possible uncertainty.
    [(ival-atanh)
     (let ([x (first srcs)])
       (list (if (>= (maxlog x) 1) (get-slack) 1)))]

    ; log[Γ_acosh] = log[x / (sqrt(x-1) * sqrt(x+1) * acosh)] <= -minlog(z) + slack
    [(ival-acosh)
     (let* ([z-exp (minlog z)]
            ; when acosh(x) < 1
            [uncertainty (if (< z-exp 2) (get-slack) 0)])
       (list (- uncertainty z-exp)))]

    ; same as multiplication
    [(ival-pow2)
     (let ([x (first srcs)])
       (list (+ (logspan x) 1)))]

    ; TODO: no dedicated formula yet for these operations; a plain slack is used
    [(ival-erfc ival-erf
                ival-lgamma
                ival-tgamma
                ival-asinh
                ival-logb
                ival-ceil
                ival-floor
                ival-rint
                ival-round
                ival-trunc)
     (list (get-slack))]

    ; any other operation: zero exponent for every argument
    [else (map (lambda (_src) 0) srcs)]))
358 | | - |
; Slack, in bits, used wherever an exact precision estimate is unavailable.
; It is read from the current (*sampling-iteration*) and doubles with every
; iteration: 512 at iteration 1, 1024 at 2, 2048 at 3, 4096 at 4, 8192 at 5.
;
; The previous lookup table only covered iterations 1..5 and failed with a
; generic "no matching clause" error for anything else. The caller can set
; (*sampling-iteration*) to (*rival-max-iterations*), so a maximum above 5
; would crash. The closed form 512 * 2^(iteration - 1) returns the same values
; for 1..5 and keeps doubling afterwards.
;
; Raises an argument error if (*sampling-iteration*) is not an exact positive
; integer.
(define (get-slack)
  (define iteration (*sampling-iteration*))
  (unless (exact-positive-integer? iteration)
    (raise-argument-error 'get-slack "exact-positive-integer?" iteration))
  ; 512 = 2^9, shifted left once per iteration after the first
  (arithmetic-shift 512 (sub1 iteration)))
0 commit comments