Skip to content

Commit bb9418a

Browse files
authored
Merge pull request #1387 from herbie-fp/codex/create-branch-for-taylor-xyz-functions
Extend cached Taylor series evaluation
2 parents b7bf174 + 9dfb555 commit bb9418a

File tree

1 file changed

+101
-107
lines changed

1 file changed

+101
-107
lines changed

src/core/taylor.rkt

Lines changed: 101 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -241,16 +241,7 @@
241241
(adder 0)
242242
((cdr series) (+ n (- offset offset*)))))))]))
243243

244-
(define-syntax-rule (make-cached-series n offset cache body ...)
245-
(cons offset
246-
(λ (n)
247-
(unless (and (> (dvector-capacity cache) n) (dvector-ref cache n))
248-
(let ([value (reducer (adder (begin
249-
body ...)))])
250-
(dvector-set! cache n value)))
251-
(dvector-ref cache n))))
252-
253-
(define-syntax-rule (make-cached-series/extend n offset cache n* body ...)
244+
(define-syntax-rule (make-series n offset cache n* body ...)
254245
(cons offset
255246
(λ (n)
256247
(when (>= n (dvector-length cache))
@@ -264,25 +255,27 @@
264255
;(->* () #:rest (listof term?) term?)
265256
(match-define `((,offset . ,serieses) ...) (apply align-series terms))
266257
(define cache (make-dvector 10))
267-
(make-cached-series n
268-
(car offset)
269-
cache
270-
(make-sum (for/list ([series serieses])
271-
(series n)))))
258+
(make-series n
259+
(car offset)
260+
cache
261+
n*
262+
(make-sum (for/list ([series serieses])
263+
(series n*)))))
272264

273265
(define (taylor-negate term)
274266
;(-> term? term?)
275267
(define cache (make-dvector 10))
276-
(make-cached-series n (car term) cache (list 'neg ((cdr term) n))))
268+
(make-series n (car term) cache n* (list 'neg ((cdr term) n*))))
277269

278270
(define (taylor-mult left right)
279271
;(-> term? term? term?)
280272
(define cache (make-dvector 10))
281-
(make-cached-series n
282-
(+ (car left) (car right))
283-
cache
284-
(make-sum (for/list ([i (range (+ n 1))])
285-
(list '* ((cdr left) i) ((cdr right) (- n i)))))))
273+
(make-series n
274+
(+ (car left) (car right))
275+
cache
276+
n*
277+
(make-sum (for/list ([i (range (+ n* 1))])
278+
(list '* ((cdr left) i) ((cdr right) (- n* i)))))))
286279

287280
(define (normalize-series series)
288281
;(-> term? term?)
@@ -307,12 +300,12 @@
307300
(define cache (make-dvector 10))
308301
(dvector-set! cache 0 (reducer (adder `(/ 1 ,(b 0)))))
309302

310-
(make-cached-series/extend n
311-
(- offset)
312-
cache
313-
n*
314-
`(neg (+ ,@(for/list ([i (range n*)])
315-
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0))))))))
303+
(make-series n
304+
(- offset)
305+
cache
306+
n*
307+
`(neg (+ ,@(for/list ([i (range n*)])
308+
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0))))))))
316309

317310
(define (taylor-quotient num denom)
318311
;(-> term? term? term?)
@@ -324,13 +317,13 @@
324317
(define cache (make-dvector 10))
325318
(dvector-set! cache 0 (reducer (adder `(/ ,(a 0) ,(b 0)))))
326319

327-
(make-cached-series/extend n
328-
(- noff doff)
329-
cache
330-
n*
331-
`(- (/ ,(a n*) ,(b 0))
332-
(+ ,@(for/list ([i (range n*)])
333-
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0))))))))
320+
(make-series n
321+
(- noff doff)
322+
cache
323+
n*
324+
`(- (/ ,(a n*) ,(b 0))
325+
(+ ,@(for/list ([i (range n*)])
326+
`(* ,(dvector-ref cache i) (/ ,(b (- n* i)) ,(b 0))))))))
334327

335328
(define (modulo-series var n series)
336329
;(-> symbol? number? term? term?)
@@ -359,25 +352,24 @@
359352
(dvector-set! cache 0 (reducer (adder `(sqrt ,(coeffs* 0)))))
360353
(dvector-set! cache 1 (reducer (adder `(/ ,(coeffs* 1) (* 2 (sqrt ,(coeffs* 0)))))))
361354

362-
(make-cached-series/extend
363-
n
364-
(/ offset* 2)
365-
cache
366-
n*
367-
(cond
368-
[(even? n*)
369-
`(/ (- ,(coeffs* n*)
370-
(pow ,(dvector-ref cache (/ n* 2)) 2)
371-
(+ ,@(for/list ([k (in-naturals 1)]
372-
#:break (>= k (- n* k)))
373-
`(* 2 (* ,(dvector-ref cache k) ,(dvector-ref cache (- n* k)))))))
374-
(* 2 ,(dvector-ref cache 0)))]
375-
[(odd? n*)
376-
`(/ (- ,(coeffs* n*)
377-
(+ ,@(for/list ([k (in-naturals 1)]
378-
#:break (>= k (- n* k)))
379-
`(* 2 (* ,(dvector-ref cache k) ,(dvector-ref cache (- n* k)))))))
380-
(* 2 ,(dvector-ref cache 0)))])))
355+
(make-series n
356+
(/ offset* 2)
357+
cache
358+
n*
359+
(cond
360+
[(even? n*)
361+
`(/ (- ,(coeffs* n*)
362+
(pow ,(dvector-ref cache (/ n* 2)) 2)
363+
(+ ,@(for/list ([k (in-naturals 1)]
364+
#:break (>= k (- n* k)))
365+
`(* 2 (* ,(dvector-ref cache k) ,(dvector-ref cache (- n* k)))))))
366+
(* 2 ,(dvector-ref cache 0)))]
367+
[(odd? n*)
368+
`(/ (- ,(coeffs* n*)
369+
(+ ,@(for/list ([k (in-naturals 1)]
370+
#:break (>= k (- n* k)))
371+
`(* 2 (* ,(dvector-ref cache k) ,(dvector-ref cache (- n* k)))))))
372+
(* 2 ,(dvector-ref cache 0)))])))
381373

382374
(define (taylor-cbrt var num)
383375
;(-> symbol? term? term?)
@@ -389,17 +381,16 @@
389381
(reducer (adder `(/ ,(coeffs* 1)
390382
(* 3 (cbrt (* ,(dvector-ref cache 0) ,(dvector-ref cache 0))))))))
391383

392-
(make-cached-series/extend
393-
n
394-
(/ offset* 3)
395-
cache
396-
n*
397-
`(/ (- ,(coeffs* n*)
398-
,@(for*/list ([terms (n-sum-to 3 n*)]
399-
#:unless (set-member? terms n*))
400-
(match-define (list a b c) terms)
401-
`(* ,(dvector-ref cache a) ,(dvector-ref cache b) ,(dvector-ref cache c))))
402-
(* 3 ,(dvector-ref cache 0) ,(dvector-ref cache 0)))))
384+
(make-series n
385+
(/ offset* 3)
386+
cache
387+
n*
388+
`(/ (- ,(coeffs* n*)
389+
,@(for*/list ([terms (n-sum-to 3 n*)]
390+
#:unless (set-member? terms n*))
391+
(match-define (list a b c) terms)
392+
`(* ,(dvector-ref cache a) ,(dvector-ref cache b) ,(dvector-ref cache c))))
393+
(* 3 ,(dvector-ref cache 0) ,(dvector-ref cache 0)))))
403394

404395
(define (taylor-pow coeffs n)
405396
;(-> term? number? term?)
@@ -434,61 +425,64 @@
434425
(define cache (make-dvector 10))
435426
(dvector-set! cache 0 (reducer (adder `(exp ,(coeffs 0)))))
436427

437-
(make-cached-series n
438-
0
439-
cache
440-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
441-
[nums (for/list ([i (in-range 1 (+ n 1))]
442-
[coeff (in-vector coeffs*)]
443-
#:unless (equal? (deref coeff) 0))
444-
i)])
445-
`(* (exp ,(coeffs 0))
446-
(+ ,@(for/list ([p (all-partitions n (sort nums >))])
447-
`(* ,@(for/list ([(count num) (in-dict p)])
448-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
449-
,(factorial count))))))))))
428+
(make-series n
429+
0
430+
cache
431+
n*
432+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
433+
[nums (for/list ([i (in-range 1 (+ n* 1))]
434+
[coeff (in-vector coeffs*)]
435+
#:unless (equal? (deref coeff) 0))
436+
i)])
437+
`(* (exp ,(coeffs 0))
438+
(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
439+
`(* ,@(for/list ([(count num) (in-dict p)])
440+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
441+
,(factorial count))))))))))
450442

451443
(define (taylor-sin coeffs)
452444
;(-> (-> number? batchref?) term?)
453445
(define cache (make-dvector 10))
454446
(dvector-set! cache 0 (adder 0))
455447

456-
(make-cached-series n
457-
0
458-
cache
459-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
460-
[nums (for/list ([i (in-range 1 (+ n 1))]
461-
[coeff (in-vector coeffs*)]
462-
#:unless (equal? (deref coeff) 0))
463-
i)])
464-
`(+ ,@(for/list ([p (all-partitions n (sort nums >))])
465-
(if (= (modulo (apply + (map car p)) 2) 1)
466-
`(* ,(if (= (modulo (apply + (map car p)) 4) 1) 1 -1)
467-
,@(for/list ([(count num) (in-dict p)])
468-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
469-
,(factorial count))))
470-
0))))))
448+
(make-series n
449+
0
450+
cache
451+
n*
452+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
453+
[nums (for/list ([i (in-range 1 (+ n* 1))]
454+
[coeff (in-vector coeffs*)]
455+
#:unless (equal? (deref coeff) 0))
456+
i)])
457+
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
458+
(if (= (modulo (apply + (map car p)) 2) 1)
459+
`(* ,(if (= (modulo (apply + (map car p)) 4) 1) 1 -1)
460+
,@(for/list ([(count num) (in-dict p)])
461+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
462+
,(factorial count))))
463+
0))))))
471464

472465
(define (taylor-cos coeffs)
473466
;(-> (-> number? batchref?) term?)
474467
(define cache (make-dvector 10))
475468
(dvector-set! cache 0 (adder 1))
476469

477-
(make-cached-series n
478-
0
479-
cache
480-
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n 1))))]
481-
[nums (for/list ([i (in-range 1 (+ n 1))]
482-
[coeff (in-vector coeffs*)]
483-
#:unless (equal? (deref coeff) 0))
484-
i)])
485-
`(+ ,@(for/list ([p (all-partitions n (sort nums >))])
486-
(if (= (modulo (apply + (map car p)) 2) 0)
487-
`(* ,(if (= (modulo (apply + (map car p)) 4) 0) 1 -1)
488-
,@(for/list ([(count num) (in-dict p)])
489-
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
490-
,(factorial count))))
491-
0))))))
470+
(make-series n
471+
0
472+
cache
473+
n*
474+
(let* ([coeffs* (list->vector (map coeffs (range 1 (+ n* 1))))]
475+
[nums (for/list ([i (in-range 1 (+ n* 1))]
476+
[coeff (in-vector coeffs*)]
477+
#:unless (equal? (deref coeff) 0))
478+
i)])
479+
`(+ ,@(for/list ([p (all-partitions n* (sort nums >))])
480+
(if (= (modulo (apply + (map car p)) 2) 0)
481+
`(* ,(if (= (modulo (apply + (map car p)) 4) 0) 1 -1)
482+
,@(for/list ([(count num) (in-dict p)])
483+
`(/ (pow ,(vector-ref coeffs* (- num 1)) ,count)
484+
,(factorial count))))
485+
0))))))
492486

493487
;; This is a hyper-specialized symbolic differentiator for log(f(x))
494488

0 commit comments

Comments
 (0)