Skip to content

Commit bf9f88a

Browse files
authored
Merge pull request #1389 from herbie-fp/codex/refactor-make-series-macro-to-function
Fix make-series builder cache access
2 parents a5a9652 + 380ef27 commit bf9f88a

File tree

1 file changed

+108
-137
lines changed

1 file changed

+108
-137
lines changed

src/core/taylor.rkt

Lines changed: 108 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -241,41 +241,35 @@
241241
(adder 0)
242242
((cdr series) (+ n (- offset offset*)))))))]))
243243

244-
(define-syntax-rule (make-series n offset cache n* body ...)
245-
(cons offset
246-
(λ (n)
247-
(when (>= n (dvector-length cache))
248-
(for ([n* (in-range (dvector-length cache) (add1 n))])
249-
(let ([value (reducer (adder (begin
250-
body ...)))])
251-
(dvector-set! cache n* value))))
252-
(dvector-ref cache n))))
244+
(define (make-series offset builder)
245+
(define cache (make-dvector 10))
246+
(define fetch (curry dvector-ref cache))
247+
(define (lookup n)
248+
(when (>= n (dvector-length cache))
249+
(for ([i (in-range (dvector-length cache) (add1 n))])
250+
(define value (reducer (adder (builder fetch i))))
251+
(dvector-set! cache i value)))
252+
(dvector-ref cache n))
253+
(cons offset lookup))
253254

254255
(define (taylor-add . terms)
255256
;(->* () #:rest (listof term?) term?)
256257
(match-define `((,offset . ,serieses) ...) (apply align-series terms))
257-
(define cache (make-dvector 10))
258-
(make-series n
259-
(car offset)
260-
cache
261-
n*
262-
(make-sum (for/list ([series serieses])
263-
(series n*)))))
258+
(make-series (car offset)
259+
(λ (f n)
260+
(make-sum (for/list ([series serieses])
261+
(series n))))))
264262

265263
(define (taylor-negate term)
266264
;(-> term? term?)
267-
(define cache (make-dvector 10))
268-
(make-series n (car term) cache n* (list 'neg ((cdr term) n*))))
265+
(make-series (car term) (λ (f n) (list 'neg ((cdr term) n)))))
269266

270267
(define (taylor-mult left right)
271268
;(-> term? term? term?)
272-
(define cache (make-dvector 10))
273-
(make-series n
274-
(+ (car left) (car right))
275-
cache
276-
n*
277-
(make-sum (for/list ([i (range (+ n* 1))])
278-
(list '* ((cdr left) i) ((cdr right) (- n* i)))))))
269+
(make-series (+ (car left) (car right))
270+
(λ (f n)
271+
(make-sum (for/list ([i (range (+ n 1))])
272+
(list '* ((cdr left) i) ((cdr right) (- n i))))))))
279273

280274
(define (normalize-series series)
281275
;(-> term? term?)
@@ -297,16 +291,12 @@
297291
This happens if the inverted series doesn't have a constant term,
298292
so we extract that case out."
299293
(match-define (cons offset b) (normalize-series term))
300-
(define cache (make-dvector 10))
301-
302-
(make-series n
303-
(- offset)
304-
cache
305-
n*
306-
(if (zero? n*)
307-
`(/ 1 ,(b 0))
308-
`(neg (+ ,@(for/list ([i (range n*)])
309-
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0)))))))))
294+
(make-series (- offset)
295+
(λ (f n)
296+
(if (zero? n)
297+
`(/ 1 ,(b 0))
298+
`(neg (+ ,@(for/list ([i (range n)])
299+
`(* ,(f i) (/ ,(b (- n i)) ,(b 0))))))))))
310300

311301
(define (taylor-quotient num denom)
312302
;(-> term? term? term?)
@@ -315,16 +305,13 @@
315305
so we extract that case out."
316306
(match-define (cons noff a) (normalize-series num))
317307
(match-define (cons doff b) (normalize-series denom))
318-
(define cache (make-dvector 10))
319-
(make-series n
320-
(- noff doff)
321-
cache
322-
n*
323-
(if (zero? n*)
324-
`(/ ,(a 0) ,(b 0))
325-
`(- (/ ,(a n*) ,(b 0))
326-
(+ ,@(for/list ([i (range n*)])
327-
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0)))))))))
308+
(make-series (- noff doff)
309+
(λ (f n)
310+
(if (zero? n)
311+
`(/ ,(a 0) ,(b 0))
312+
`(- (/ ,(a n) ,(b 0))
313+
(+ ,@(for/list ([i (range n)])
314+
`(* ,(f i) (/ ,(b (- n i)) ,(b 0))))))))))
328315

329316
(define (modulo-series var n series)
330317
;(-> symbol? number? term? term?)
@@ -349,47 +336,40 @@
349336
(define (taylor-sqrt var num)
350337
;(-> symbol? term? term?)
351338
(match-define (cons offset* coeffs*) (modulo-series var 2 num))
352-
(define cache (make-dvector 10))
353-
(make-series n
354-
(/ offset* 2)
355-
cache
356-
n*
357-
(cond
358-
[(zero? n*) `(sqrt ,(coeffs* 0))]
359-
[(= n* 1) `(/ ,(coeffs* 1) (* 2 (sqrt ,(coeffs* 0))))]
360-
[(even? n*)
361-
`(/ (- ,(coeffs* n*)
362-
(pow ,(dvector-ref cache (/ n* 2)) 2)
363-
(+ ,@(for/list ([k (in-naturals 1)]
364-
#:break (>= k (- n* k)))
365-
`(* 2 (* ,(dvector-ref cache k) ,(dvector-ref cache (- n* k)))))))
366-
(* 2 ,(dvector-ref cache 0)))]
367-
[(odd? n*)
368-
`(/ (- ,(coeffs* n*)
369-
(+ ,@(for/list ([k (in-naturals 1)]
370-
#:break (>= k (- n* k)))
371-
`(* 2 (* ,(dvector-ref cache k) ,(dvector-ref cache (- n* k)))))))
372-
(* 2 ,(dvector-ref cache 0)))])))
339+
(make-series (/ offset* 2)
340+
(λ (f n)
341+
(cond
342+
[(zero? n) `(sqrt ,(coeffs* 0))]
343+
[(= n 1) `(/ ,(coeffs* 1) (* 2 (sqrt ,(coeffs* 0))))]
344+
[(even? n)
345+
`(/ (- ,(coeffs* n)
346+
(pow ,(f (/ n 2)) 2)
347+
(+ ,@(for/list ([k (in-naturals 1)]
348+
#:break (>= k (- n k)))
349+
`(* 2 (* ,(f k) ,(f (- n k)))))))
350+
(* 2 ,(f 0)))]
351+
[(odd? n)
352+
`(/ (- ,(coeffs* n)
353+
(+ ,@(for/list ([k (in-naturals 1)]
354+
#:break (>= k (- n k)))
355+
`(* 2 (* ,(f k) ,(f (- n k)))))))
356+
(* 2 ,(f 0)))]))))
373357

374358
(define (taylor-cbrt var num)
375359
;(-> symbol? term? term?)
376360
(match-define (cons offset* coeffs*) (modulo-series var 3 num))
377-
(define cache (make-dvector 10))
378-
(make-series
379-
n
380-
(/ offset* 3)
381-
cache
382-
n*
383-
(cond
384-
[(zero? n*) `(cbrt ,(coeffs* 0))]
385-
[(= n* 1) `(/ ,(coeffs* 1) (* 3 (cbrt (* ,(dvector-ref cache 0) ,(dvector-ref cache 0)))))]
386-
[else
387-
`(/ (- ,(coeffs* n*)
388-
,@(for*/list ([terms (n-sum-to 3 n*)]
389-
#:unless (set-member? terms n*))
390-
(match-define (list a b c) terms)
391-
`(* ,(dvector-ref cache a) ,(dvector-ref cache b) ,(dvector-ref cache c))))
392-
(* 3 ,(dvector-ref cache 0) ,(dvector-ref cache 0)))])))
361+
(make-series (/ offset* 3)
362+
(λ (f n)
363+
(cond
364+
[(zero? n) `(cbrt ,(coeffs* 0))]
365+
[(= n 1) `(/ ,(coeffs* 1) (* 3 (cbrt (* ,(f 0) ,(f 0)))))]
366+
[else
367+
`(/ (- ,(coeffs* n)
368+
,@(for*/list ([terms (n-sum-to 3 n)]
369+
#:unless (set-member? terms n))
370+
(match-define (list a b c) terms)
371+
`(* ,(f a) ,(f b) ,(f c))))
372+
(* 3 ,(f 0) ,(f 0)))]))))
393373

394374
(define (taylor-pow coeffs n)
395375
;(-> term? number? term?)
@@ -421,67 +401,58 @@
421401

422402
(define (taylor-exp coeffs)
423403
;(-> (-> number? batchref?) term?)
424-
(define cache (make-dvector 10))
425-
(make-series n
426-
0
427-
cache
428-
n*
429-
(if (zero? n)
430-
`(exp ,(coeffs 0))
431-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
432-
[nums (for/list ([i (in-range 1 (+ n 1))]
433-
[coeff (in-vector coeffs*)]
434-
#:unless (equal? (deref coeff) 0))
435-
i)])
436-
`(* (exp ,(coeffs 0))
437-
(+ ,@(for/list ([p (all-partitions n (sort nums >))])
438-
`(* ,@(for/list ([(count num) (in-dict p)])
439-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
440-
,(factorial count)))))))))))
404+
(make-series 0
405+
(λ (f n)
406+
(if (zero? n)
407+
`(exp ,(coeffs 0))
408+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
409+
[nums (for/list ([i (in-range 1 (+ n 1))]
410+
[coeff (in-vector coeffs*)]
411+
#:unless (equal? (deref coeff) 0))
412+
i)])
413+
`(* (exp ,(coeffs 0))
414+
(+ ,@(for/list ([p (all-partitions n (sort nums >))])
415+
`(* ,@(for/list ([(count num) (in-dict p)])
416+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
417+
,(factorial count))))))))))))
441418

442419
(define (taylor-sin coeffs)
443420
;(-> (-> number? batchref?) term?)
444-
(define cache (make-dvector 10))
445-
(make-series n
446-
0
447-
cache
448-
n*
449-
(if (zero? n*)
450-
0
451-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
452-
[nums (for/list ([i (in-range 1 (+ n* 1))]
453-
[coeff (in-vector coeffs*)]
454-
#:unless (equal? (deref coeff) 0))
455-
i)])
456-
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
457-
(if (= (modulo (apply + (map car p)) 2) 1)
458-
`(* ,(if (= (modulo (apply + (map car p)) 4) 1) 1 -1)
459-
,@(for/list ([(count num) (in-dict p)])
460-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
461-
,(factorial count))))
462-
0)))))))
421+
(make-series 0
422+
(λ (f n)
423+
(if (zero? n)
424+
0
425+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
426+
[nums (for/list ([i (in-range 1 (+ n 1))]
427+
[coeff (in-vector coeffs*)]
428+
#:unless (equal? (deref coeff) 0))
429+
i)])
430+
`(+ ,@(for/list ([p (all-partitions n (sort nums >))])
431+
(if (= (modulo (apply + (map car p)) 2) 1)
432+
`(* ,(if (= (modulo (apply + (map car p)) 4) 1) 1 -1)
433+
,@(for/list ([(count num) (in-dict p)])
434+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
435+
,(factorial count))))
436+
0))))))))
463437

464438
(define (taylor-cos coeffs)
465439
;(-> (-> number? batchref?) term?)
466-
(define cache (make-dvector 10))
467-
(make-series n
468-
0
469-
cache
470-
n*
471-
(if (zero? n*)
472-
1
473-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
474-
[nums (for/list ([i (in-range 1 (+ n* 1))]
475-
[coeff (in-vector coeffs*)]
476-
#:unless (equal? (deref coeff) 0))
477-
i)])
478-
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
479-
(if (= (modulo (apply + (map car p)) 2) 0)
480-
`(* ,(if (= (modulo (apply + (map car p)) 4) 0) 1 -1)
481-
,@(for/list ([(count num) (in-dict p)])
482-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
483-
,(factorial count))))
484-
0)))))))
440+
(make-series 0
441+
(λ (f n)
442+
(if (zero? n)
443+
1
444+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
445+
[nums (for/list ([i (in-range 1 (+ n 1))]
446+
[coeff (in-vector coeffs*)]
447+
#:unless (equal? (deref coeff) 0))
448+
i)])
449+
`(+ ,@(for/list ([p (all-partitions n (sort nums >))])
450+
(if (= (modulo (apply + (map car p)) 2) 0)
451+
`(* ,(if (= (modulo (apply + (map car p)) 4) 0) 1 -1)
452+
,@(for/list ([(count num) (in-dict p)])
453+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
454+
,(factorial count))))
455+
0))))))))
485456

486457
;; This is a hyper-specialized symbolic differentiator for log(f(x))
487458

0 commit comments

Comments
 (0)