Skip to content

Commit a5a9652

Browse files
authored
Merge pull request #1388 from herbie-fp/codex/refactor-cache-initialization-in-taylor-xyz-functions
Unify base cases and normal cases in Taylor
2 parents bb9418a + 5819e9e commit a5a9652

File tree

1 file changed

+66
-67
lines changed

1 file changed

+66
-67
lines changed

src/core/taylor.rkt

Lines changed: 66 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,15 @@
298298
so we extract that case out."
299299
(match-define (cons offset b) (normalize-series term))
300300
(define cache (make-dvector 10))
301-
(dvector-set! cache 0 (reducer (adder `(/ 1 ,(b 0)))))
302301

303302
(make-series n
304303
(- offset)
305304
cache
306305
n*
307-
`(neg (+ ,@(for/list ([i (range n*)])
308-
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0))))))))
306+
(if (zero? n*)
307+
`(/ 1 ,(b 0))
308+
`(neg (+ ,@(for/list ([i (range n*)])
309+
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0)))))))))
309310

310311
(define (taylor-quotient num denom)
311312
;(-> term? term? term?)
@@ -315,15 +316,15 @@
315316
(match-define (cons noff a) (normalize-series num))
316317
(match-define (cons doff b) (normalize-series denom))
317318
(define cache (make-dvector 10))
318-
(dvector-set! cache 0 (reducer (adder `(/ ,(a 0) ,(b 0)))))
319-
320319
(make-series n
321320
(- noff doff)
322321
cache
323322
n*
324-
`(- (/ ,(a n*) ,(b 0))
325-
(+ ,@(for/list ([i (range n*)])
326-
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0))))))))
323+
(if (zero? n*)
324+
`(/ ,(a 0) ,(b 0))
325+
`(- (/ ,(a n*) ,(b 0))
326+
(+ ,@(for/list ([i (range n*)])
327+
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0)))))))))
327328

328329
(define (modulo-series var n series)
329330
;(-> symbol? number? term? term?)
@@ -349,14 +350,13 @@
349350
;(-> symbol? term? term?)
350351
(match-define (cons offset* coeffs*) (modulo-series var 2 num))
351352
(define cache (make-dvector 10))
352-
(dvector-set! cache 0 (reducer (adder `(sqrt ,(coeffs* 0)))))
353-
(dvector-set! cache 1 (reducer (adder `(/ ,(coeffs* 1) (* 2 (sqrt ,(coeffs* 0)))))))
354-
355353
(make-series n
356354
(/ offset* 2)
357355
cache
358356
n*
359357
(cond
358+
[(zero? n*) `(sqrt ,(coeffs* 0))]
359+
[(= n* 1) `(/ ,(coeffs* 1) (* 2 (sqrt ,(coeffs* 0))))]
360360
[(even? n*)
361361
`(/ (- ,(coeffs* n*)
362362
(pow ,(dvector-ref cache (/ n* 2)) 2)
@@ -375,22 +375,21 @@
375375
;(-> symbol? term? term?)
376376
(match-define (cons offset* coeffs*) (modulo-series var 3 num))
377377
(define cache (make-dvector 10))
378-
(dvector-set! cache 0 (reducer (adder `(cbrt ,(coeffs* 0)))))
379-
(dvector-set! cache
380-
1
381-
(reducer (adder `(/ ,(coeffs* 1)
382-
(* 3 (cbrt (* ,(dvector-ref cache 0) ,(dvector-ref cache 0))))))))
383-
384-
(make-series n
385-
(/ offset* 3)
386-
cache
387-
n*
388-
`(/ (- ,(coeffs* n*)
389-
,@(for*/list ([terms (n-sum-to 3 n*)]
390-
#:unless (set-member? terms n*))
391-
(match-define (list a b c) terms)
392-
`(* ,(dvector-ref cache a) ,(dvector-ref cache b) ,(dvector-ref cache c))))
393-
(* 3 ,(dvector-ref cache 0) ,(dvector-ref cache 0)))))
378+
(make-series
379+
n
380+
(/ offset* 3)
381+
cache
382+
n*
383+
(cond
384+
[(zero? n*) `(cbrt ,(coeffs* 0))]
385+
[(= n* 1) `(/ ,(coeffs* 1) (* 3 (cbrt (* ,(dvector-ref cache 0) ,(dvector-ref cache 0)))))]
386+
[else
387+
`(/ (- ,(coeffs* n*)
388+
,@(for*/list ([terms (n-sum-to 3 n*)]
389+
#:unless (set-member? terms n*))
390+
(match-define (list a b c) terms)
391+
`(* ,(dvector-ref cache a) ,(dvector-ref cache b) ,(dvector-ref cache c))))
392+
(* 3 ,(dvector-ref cache 0) ,(dvector-ref cache 0)))])))
394393

395394
(define (taylor-pow coeffs n)
396395
;(-> term? number? term?)
@@ -423,66 +422,66 @@
423422
(define (taylor-exp coeffs)
424423
;(-> (-> number? batchref?) term?)
425424
(define cache (make-dvector 10))
426-
(dvector-set! cache 0 (reducer (adder `(exp ,(coeffs 0)))))
427-
428425
(make-series n
429426
0
430427
cache
431428
n*
432-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
433-
[nums (for/list ([i (in-range 1 (+ n* 1))]
434-
[coeff (in-vector coeffs*)]
435-
#:unless (equal? (deref coeff) 0))
436-
i)])
437-
`(* (exp ,(coeffs 0))
438-
(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
439-
`(* ,@(for/list ([(count num) (in-dict p)])
440-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
441-
,(factorial count))))))))))
429+
(if (zero? n)
430+
`(exp ,(coeffs 0))
431+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
432+
[nums (for/list ([i (in-range 1 (+ n 1))]
433+
[coeff (in-vector coeffs*)]
434+
#:unless (equal? (deref coeff) 0))
435+
i)])
436+
`(* (exp ,(coeffs 0))
437+
(+ ,@(for/list ([p (all-partitions n (sort nums >))])
438+
`(* ,@(for/list ([(count num) (in-dict p)])
439+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
440+
,(factorial count)))))))))))
442441

443442
(define (taylor-sin coeffs)
444443
;(-> (-> number? batchref?) term?)
445444
(define cache (make-dvector 10))
446-
(dvector-set! cache 0 (adder 0))
447-
448445
(make-series n
449446
0
450447
cache
451448
n*
452-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
453-
[nums (for/list ([i (in-range 1 (+ n* 1))]
454-
[coeff (in-vector coeffs*)]
455-
#:unless (equal? (deref coeff) 0))
456-
i)])
457-
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
458-
(if (= (modulo (apply + (map car p)) 2) 1)
459-
`(* ,(if (= (modulo (apply + (map car p)) 4) 1) 1 -1)
460-
,@(for/list ([(count num) (in-dict p)])
461-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
462-
,(factorial count))))
463-
0))))))
449+
(if (zero? n*)
450+
0
451+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
452+
[nums (for/list ([i (in-range 1 (+ n* 1))]
453+
[coeff (in-vector coeffs*)]
454+
#:unless (equal? (deref coeff) 0))
455+
i)])
456+
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
457+
(if (= (modulo (apply + (map car p)) 2) 1)
458+
`(* ,(if (= (modulo (apply + (map car p)) 4) 1) 1 -1)
459+
,@(for/list ([(count num) (in-dict p)])
460+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
461+
,(factorial count))))
462+
0)))))))
464463

465464
(define (taylor-cos coeffs)
466465
;(-> (-> number? batchref?) term?)
467466
(define cache (make-dvector 10))
468-
(dvector-set! cache 0 (adder 1))
469-
470467
(make-series n
471468
0
472469
cache
473470
n*
474-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
475-
[nums (for/list ([i (in-range 1 (+ n* 1))]
476-
[coeff (in-vector coeffs*)]
477-
#:unless (equal? (deref coeff) 0))
478-
i)])
479-
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
480-
(if (= (modulo (apply + (map car p)) 2) 0)
481-
`(* ,(if (= (modulo (apply + (map car p)) 4) 0) 1 -1)
482-
,@(for/list ([(count num) (in-dict p)])
483-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
484-
,(factorial count))))
485-
0))))))
471+
(if (zero? n*)
472+
1
473+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
474+
[nums (for/list ([i (in-range 1 (+ n* 1))]
475+
[coeff (in-vector coeffs*)]
476+
#:unless (equal? (deref coeff) 0))
477+
i)])
478+
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
479+
(if (= (modulo (apply + (map car p)) 2) 0)
480+
`(* ,(if (= (modulo (apply + (map car p)) 4) 0) 1 -1)
481+
,@(for/list ([(count num) (in-dict p)])
482+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
483+
,(factorial count))))
484+
0)))))))
486485

487486
;; This is a hyper-specialized symbolic differentiator for log(f(x))
488487

0 commit comments

Comments
 (0)