Skip to content

Commit c6c7da7

Browse files
authored
Merge pull request #1397 from herbie-fp/codex/refactor-taylor-series-functions-to-use-series-struct-ra1who
Refactor Taylor series representation
2 parents 2951533 + 49cb4dc commit c6c7da7

File tree

1 file changed

+89
-64
lines changed

1 file changed

+89
-64
lines changed

src/core/taylor.rkt

Lines changed: 89 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@
4040
#:iters [iters 5])
4141
(define replacer (batch-replace-expression! batch var ((cdr tform) var)))
4242
(for/list ([ta taylor-approxs])
43-
(match-define (cons offset coeffs) ta)
43+
(define offset (series-offset ta))
4444
(define i 0)
4545
(define terms '())
4646

4747
(define (next [iter 0])
48-
(define coeff (reducer (replacer (coeffs i))))
48+
(define coeff (reducer (replacer (series-ref ta i))))
4949
(set! i (+ i 1))
5050
(match (deref coeff)
5151
[0
@@ -169,33 +169,33 @@
169169
[`(cbrt ,arg) (taylor-cbrt var (recurse arg))]
170170
[`(exp ,arg)
171171
(define arg* (normalize-series (recurse arg)))
172-
(if (positive? (car arg*))
172+
(if (positive? (series-offset arg*))
173173
(taylor-exact brf)
174174
(taylor-exp (zero-series arg*)))]
175175
[`(sin ,arg)
176176
(define arg* (normalize-series (recurse arg)))
177177
(cond
178-
[(positive? (car arg*)) (taylor-exact brf)]
179-
[(= (car arg*) 0)
178+
[(positive? (series-offset arg*)) (taylor-exact brf)]
179+
[(= (series-offset arg*) 0)
180180
; Our taylor-sin function assumes that a0 is 0,
181181
; because that way it is especially simple. We correct for this here
182182
; We use the identity sin (x + y) = sin x cos y + cos x sin y
183-
(taylor-add (taylor-mult (taylor-exact (adder `(sin ,((cdr arg*) 0))))
183+
(taylor-add (taylor-mult (taylor-exact (adder `(sin ,(series-ref arg* 0))))
184184
(taylor-cos (zero-series arg*)))
185-
(taylor-mult (taylor-exact (adder `(cos ,((cdr arg*) 0))))
185+
(taylor-mult (taylor-exact (adder `(cos ,(series-ref arg* 0))))
186186
(taylor-sin (zero-series arg*))))]
187187
[else (taylor-sin (zero-series arg*))])]
188188
[`(cos ,arg)
189189
(define arg* (normalize-series (recurse arg)))
190190
(cond
191-
[(positive? (car arg*)) (taylor-exact brf)]
192-
[(= (car arg*) 0)
191+
[(positive? (series-offset arg*)) (taylor-exact brf)]
192+
[(= (series-offset arg*) 0)
193193
; Our taylor-cos function assumes that a0 is 0,
194194
; because that way it is especially simple. We correct for this here
195195
; We use the identity cos (x + y) = cos x cos y - sin x sin y
196-
(taylor-add (taylor-mult (taylor-exact (adder `(cos ,((cdr arg*) 0))))
196+
(taylor-add (taylor-mult (taylor-exact (adder `(cos ,(series-ref arg* 0))))
197197
(taylor-cos (zero-series arg*)))
198-
(taylor-negate (taylor-mult (taylor-exact (adder `(sin ,((cdr arg*) 0))))
198+
(taylor-negate (taylor-mult (taylor-exact (adder `(sin ,(series-ref arg* 0))))
199199
(taylor-sin (zero-series arg*)))))]
200200
[else (taylor-cos (zero-series arg*))])]
201201
[`(log ,arg) (taylor-log var (recurse arg))]
@@ -204,11 +204,12 @@
204204
(taylor-pow (normalize-series (recurse base)) (deref power))]
205205
[_ (taylor-exact brf)]))))
206206

207-
; A taylor series is represented by a function f : nat -> expr,
208-
; representing the coefficients (the 1 / n! terms not included),
209-
; and an integer offset to the exponent
207+
; A taylor series is represented by a struct containing a coefficient builder,
208+
; a cache of computed coefficients, and an integer offset to the exponent
210209

211-
; (define term? (cons/c number? (-> number? batchref?)))
210+
; (define term? series?)
211+
212+
(struct series (offset f cache) #:transparent)
212213

213214
(define (taylor-exact . terms)
214215
;(->* () #:rest (listof batchref?) term?)
@@ -229,65 +230,75 @@
229230
n)))
230231

231232
(define (make-series offset builder)
232-
(define cache (make-dvector 10))
233-
(define fetch (curry dvector-ref cache))
234-
(define (lookup n)
235-
(when (>= n (dvector-length cache))
236-
(for ([i (in-range (dvector-length cache) (add1 n))])
237-
(define value (reducer (adder (builder fetch i))))
238-
(dvector-set! cache i value)))
239-
(dvector-ref cache n))
240-
(cons offset lookup))
233+
(series offset builder (make-dvector 10)))
234+
235+
(define (series-ref s n)
236+
(define cache (series-cache s))
237+
(define builder (series-f s))
238+
(define fetch (λ (i) (dvector-ref cache i)))
239+
(when (>= n (dvector-length cache))
240+
(for ([i (in-range (dvector-length cache) (add1 n))])
241+
(define value (reducer (adder (builder fetch i))))
242+
(dvector-set! cache i value)))
243+
(dvector-ref cache n))
244+
245+
(define (series-function s)
246+
(λ (n) (series-ref s n)))
241247

242248
(define (taylor-add left right)
243249
;(-> term? term? term?)
244-
(match-define (cons left-offset left-series) left)
245-
(match-define (cons right-offset right-series) right)
250+
(define left-offset (series-offset left))
251+
(define right-offset (series-offset right))
246252
(define target-offset (max left-offset right-offset))
247253
(define (align offset series)
248254
(define shift (- offset target-offset))
249255
(cond
250-
[(zero? shift) series]
256+
[(zero? shift) (series-function series)]
251257
[else
252258
(λ (n)
253259
(if (negative? (+ n shift))
254260
(adder 0)
255-
(series (+ n shift))))]))
256-
(define left* (align left-offset left-series))
257-
(define right* (align right-offset right-series))
261+
(series-ref series (+ n shift))))]))
262+
(define left* (align left-offset left))
263+
(define right* (align right-offset right))
258264
(make-series target-offset (λ (f n) (make-sum (list (left* n) (right* n))))))
259265

260266
(define (taylor-negate term)
261267
;(-> term? term?)
262-
(make-series (car term) (λ (f n) (list 'neg ((cdr term) n)))))
268+
(make-series (series-offset term) (λ (f n) (list 'neg (series-ref term n)))))
263269

264270
(define (taylor-mult left right)
265271
;(-> term? term? term?)
266-
(make-series (+ (car left) (car right))
272+
(make-series (+ (series-offset left) (series-offset right))
267273
(λ (f n)
268274
(make-sum (for/list ([i (range (+ n 1))])
269-
(list '* ((cdr left) i) ((cdr right) (- n i))))))))
275+
(list '* (series-ref left i) (series-ref right (- n i))))))))
270276

271-
(define (normalize-series series)
277+
(define (normalize-series s)
272278
;(-> term? term?)
273279
"Fixes up the series to have a non-zero zeroth term,
274280
allowing a possibly negative offset"
275-
(match-define (cons offset coeffs) series)
281+
(define offset (series-offset s))
282+
(define coeffs (series-function s))
276283
(define slack (first-nonzero-exp coeffs))
277-
(cons (- offset slack) (compose coeffs (curry + slack))))
284+
(if (zero? slack)
285+
s
286+
(make-series (- offset slack) (λ (f n) (deref (series-ref s (+ n slack)))))))
278287

279-
(define ((zero-series series) n)
280-
;(-> (cons/c number? (-> number? batchref?)) (-> number? batchref?))
281-
(if (< n (- (car series)))
288+
(define ((zero-series s) n)
289+
;(-> series? (-> number? batchref?))
290+
(if (< n (- (series-offset s)))
282291
(adder 0)
283-
((cdr series) (+ n (car series)))))
292+
(series-ref s (+ n (series-offset s)))))
284293

285294
(define (taylor-invert term)
286295
;(-> term? term?)
287296
"This gets tricky, because the function might have a pole at 0.
288297
This happens if the inverted series doesn't have a constant term,
289298
so we extract that case out."
290-
(match-define (cons offset b) (normalize-series term))
299+
(define normalized (normalize-series term))
300+
(define offset (series-offset normalized))
301+
(define b (series-function normalized))
291302
(make-series (- offset)
292303
(λ (f n)
293304
(if (zero? n)
@@ -300,8 +311,12 @@
300311
"This gets tricky, because the function might have a pole at 0.
301312
This happens if the inverted series doesn't have a constant term,
302313
so we extract that case out."
303-
(match-define (cons noff a) (normalize-series num))
304-
(match-define (cons doff b) (normalize-series denom))
314+
(define normalized-num (normalize-series num))
315+
(define normalized-denom (normalize-series denom))
316+
(define noff (series-offset normalized-num))
317+
(define doff (series-offset normalized-denom))
318+
(define a (series-function normalized-num))
319+
(define b (series-function normalized-denom))
305320
(make-series (- noff doff)
306321
(λ (f n)
307322
(if (zero? n)
@@ -312,27 +327,33 @@
312327

313328
(define (modulo-series var n series)
314329
;(-> symbol? number? term? term?)
315-
(match-define (cons offset coeffs) (normalize-series series))
330+
(define normalized (normalize-series series))
331+
(define offset (series-offset normalized))
332+
(define coeffs (series-function normalized))
316333
(define offset* (+ offset (modulo (- offset) n)))
317-
(define cache (make-dvector 2)) ;; never called mor than twice
318-
(define (coeffs* i)
319-
(unless (and (> (dvector-capacity cache) i) (dvector-ref cache i))
320-
(define res
321-
(match i
322-
[0
323-
(adder (make-sum (for/list ([j (in-range (modulo offset n))])
324-
`(* ,(coeffs j) (pow ,var ,(+ j (modulo (- offset) n)))))))]
325-
[_
326-
#:when (< i n)
327-
(adder 0)]
328-
[_ (coeffs (+ (- i n) (modulo offset n)))]))
329-
(dvector-set! cache i res))
330-
(dvector-ref cache i))
331-
(cons offset* (if (= offset offset*) coeffs coeffs*)))
334+
(if (= offset offset*)
335+
normalized
336+
(let ([cache (make-dvector 2)]) ;; never called more than twice
337+
(define (coeffs* i)
338+
(unless (and (> (dvector-capacity cache) i) (dvector-ref cache i))
339+
(define res
340+
(match i
341+
[0
342+
(adder (make-sum (for/list ([j (in-range (modulo offset n))])
343+
`(* ,(coeffs j) (pow ,var ,(+ j (modulo (- offset) n)))))))]
344+
[_
345+
#:when (< i n)
346+
(adder 0)]
347+
[_ (coeffs (+ (- i n) (modulo offset n)))]))
348+
(dvector-set! cache i res))
349+
(dvector-ref cache i))
350+
(make-series offset* (λ (f i) (deref (coeffs* i)))))))
332351

333352
(define (taylor-sqrt var num)
334353
;(-> symbol? term? term?)
335-
(match-define (cons offset* coeffs*) (modulo-series var 2 num))
354+
(define normalized (modulo-series var 2 num))
355+
(define offset* (series-offset normalized))
356+
(define coeffs* (series-function normalized))
336357
(make-series (/ offset* 2)
337358
(λ (f n)
338359
(cond
@@ -354,7 +375,9 @@
354375

355376
(define (taylor-cbrt var num)
356377
;(-> symbol? term? term?)
357-
(match-define (cons offset* coeffs*) (modulo-series var 3 num))
378+
(define normalized (modulo-series var 3 num))
379+
(define offset* (series-offset normalized))
380+
(define coeffs* (series-function normalized))
358381
(make-series (/ offset* 3)
359382
(λ (f n)
360383
(cond
@@ -487,7 +510,9 @@
487510

488511
(define (taylor-log var arg)
489512
;(-> symbol? term? term?)
490-
(match-define (cons shift coeffs) (normalize-series arg))
513+
(define normalized (normalize-series arg))
514+
(define shift (series-offset normalized))
515+
(define coeffs (series-function normalized))
491516
(define negate? (and (number? (deref (coeffs 0))) (not (positive? (deref (coeffs 0))))))
492517
(define (maybe-negate x)
493518
(if negate?
@@ -527,7 +552,7 @@
527552
[add (λ (x) (batch-add! batch x))])
528553
(define brfs* (map (expand-taylor! batch) brfs))
529554
(define brf (car brfs*))
530-
(check-pred exact-integer? (car ((taylor 'x batch) brf)))))
555+
(check-pred exact-integer? (series-offset ((taylor 'x batch) brf)))))
531556

532557
(module+ test
533558
(require "batch-reduce.rkt")
@@ -537,7 +562,7 @@
537562
[add (λ (x) (batch-add! batch x))])
538563
(define brfs* (map (expand-taylor! batch) brfs))
539564
(define brf (car brfs*))
540-
(match-define fn (zero-series ((taylor 'x batch) brf)))
565+
(define fn (zero-series ((taylor 'x batch) brf)))
541566
(map batch-pull (build-list n fn))))
542567
(check-equal? (coeffs '(sin x)) '(0 1 0 -1/6 0 1/120 0))
543568
(check-equal? (coeffs '(sqrt (+ 1 x))) '(1 1/2 -1/8 1/16 -5/128 7/256 -21/1024))

0 commit comments

Comments
 (0)