Skip to content

Commit 21d4cc5

Browse files
committed
Merge branch 'main' into 1-split-rule
2 parents 16e6a4a + 4dceae1 commit 21d4cc5

File tree

15 files changed

+459
-204
lines changed

15 files changed

+459
-204
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ herbie-compiled/
2626

2727
# Python
2828
.env/
29+
.worktrees

AGENTS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11

22
# Testing
33

4-
- Run `make fmt` to format the code before presentingcode. This is
4+
- Run `make fmt` to format the code before finishing a task. This is
55
mandatory and PRs that don't follow the coding style are rejected.
6+
- Always check your `git diff` before finishing a task. There's often
7+
leftover or dead code, and you should delete it.
68
- Run `racket src/main.rkt report bench/tutorial.fpcore tmp` to test
79
that your changes work; this should take about 5-10 seconds and all
810
of the tests should pass, getting basically perfect accuracy.

bench/physics/kalman.fpcore

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
(FPCore (dt r)
2+
:name "Kalman filter per K"
3+
:pre (and (> dt 0) (> r 0))
4+
; initializing matrices
5+
(let ([p00 25.0]
6+
[p01 0.0]
7+
[p02 0.0]
8+
[p10 0.0]
9+
[p11 10.0]
10+
[p12 0.0]
11+
[p20 0.0]
12+
[p21 0.0]
13+
[p22 1.0]
14+
15+
[f00 1.0]
16+
[f01 dt]
17+
[f02 (* 0.5 (* dt dt))]
18+
[f10 0.0]
19+
[f11 1.0]
20+
[f12 dt]
21+
[f20 0.0]
22+
[f21 0.0]
23+
[f22 1.0]
24+
25+
[q00 (* 0.25 (* dt (* dt (* dt dt))))]
26+
[q01 (* 0.5 (* dt (* dt dt)))]
27+
[q02 (* 0.5 (* dt dt))]
28+
[q10 (* 0.5 (* dt (* dt dt)))]
29+
[q11 (* dt dt)]
30+
[q12 dt]
31+
[q20 (* 0.5 (* dt dt))]
32+
[q21 dt]
33+
[q22 1])
34+
; axbT_33(P, F, P)
35+
(let ([p00* (+ (+ (* p00 f00) (* p01 f01)) (* p02 f02))]
36+
[p01* (+ (+ (* p00 f10) (* p01 f11)) (* p02 f12))]
37+
[p02* (+ (+ (* p00 f20) (* p01 f21)) (* p02 f22))]
38+
[p10* (+ (+ (* p10 f00) (* p11 f01)) (* p12 f02))]
39+
[p11* (+ (+ (* p10 f10) (* p11 f11)) (* p12 f12))]
40+
[p12* (+ (+ (* p10 f20) (* p11 f21)) (* p12 f22))]
41+
[p20* (+ (+ (* p20 f00) (* p21 f01)) (* p22 f02))]
42+
[p21* (+ (+ (* p20 f10) (* p21 f11)) (* p22 f12))]
43+
[p22* (+ (+ (* p20 f20) (* p21 f21)) (* p22 f22))])
44+
; axb_33(F, P, P)
45+
(let ([p00** (+ (+ (* f00 p00*) (* f01 p10*)) (* f02 p20*))]
46+
[p01** (+ (+ (* f00 p01*) (* f01 p11*)) (* f02 p21*))]
47+
[p02** (+ (+ (* f00 p02*) (* f01 p12*)) (* f02 p22*))]
48+
[p10** (+ (+ (* f10 p00*) (* f11 p10*)) (* f12 p20*))]
49+
[p11** (+ (+ (* f10 p01*) (* f11 p11*)) (* f12 p21*))]
50+
[p12** (+ (+ (* f10 p02*) (* f11 p12*)) (* f12 p22*))]
51+
[p20** (+ (+ (* f20 p00*) (* f21 p10*)) (* f22 p20*))]
52+
[p21** (+ (+ (* f20 p01*) (* f21 p11*)) (* f22 p21*))]
53+
[p22** (+ (+ (* f20 p02*) (* f21 p12*)) (* f22 p22*))])
54+
; a_add_b_33(P, Q, P)
55+
(let ([p00*** (+ p00** q00)]
56+
[p01*** (+ p01** q01)]
57+
[p02*** (+ p02** q02)]
58+
[p10*** (+ p10** q10)]
59+
[p11*** (+ p11** q11)]
60+
[p12*** (+ p12** q12)]
61+
[p20*** (+ p20** q20)]
62+
[p21*** (+ p21** q21)]
63+
[p22*** (+ p22** q22)])
64+
; update_K
65+
(let ([K0 (/ p00*** (+ p00*** r))]
66+
[K1 (/ p10*** (+ p00*** r))]
67+
[K2 (/ p20*** (+ p00*** r))])
68+
K0))))))
69+
70+
(FPCore (x0 x1 x2 dt r sensor)
71+
:name "Kalman filter per x"
72+
:pre (and (> dt 0) (> r 0) (> sensor 0))
73+
; initializing matrices
74+
(let ([p00 25.0]
75+
[p01 0.0]
76+
[p02 0.0]
77+
[p10 0.0]
78+
[p11 10.0]
79+
[p12 0.0]
80+
[p20 0.0]
81+
[p21 0.0]
82+
[p22 1.0]
83+
84+
[f00 1.0]
85+
[f01 dt]
86+
[f02 (* 0.5 (* dt dt))]
87+
[f10 0.0]
88+
[f11 1.0]
89+
[f12 dt]
90+
[f20 0.0]
91+
[f21 0.0]
92+
[f22 1.0]
93+
94+
[q00 (* 0.25 (* dt (* dt (* dt dt))))]
95+
[q01 (* 0.5 (* dt (* dt dt)))]
96+
[q02 (* 0.5 (* dt dt))]
97+
[q10 (* 0.5 (* dt (* dt dt)))]
98+
[q11 (* dt dt)]
99+
[q12 dt]
100+
[q20 (* 0.5 (* dt dt))]
101+
[q21 dt]
102+
[q22 1])
103+
; axbT_33(P, F, P)
104+
(let ([p00* (+ (+ (* p00 f00) (* p01 f01)) (* p02 f02))]
105+
[p01* (+ (+ (* p00 f10) (* p01 f11)) (* p02 f12))]
106+
[p02* (+ (+ (* p00 f20) (* p01 f21)) (* p02 f22))]
107+
[p10* (+ (+ (* p10 f00) (* p11 f01)) (* p12 f02))]
108+
[p11* (+ (+ (* p10 f10) (* p11 f11)) (* p12 f12))]
109+
[p12* (+ (+ (* p10 f20) (* p11 f21)) (* p12 f22))]
110+
[p20* (+ (+ (* p20 f00) (* p21 f01)) (* p22 f02))]
111+
[p21* (+ (+ (* p20 f10) (* p21 f11)) (* p22 f12))]
112+
[p22* (+ (+ (* p20 f20) (* p21 f21)) (* p22 f22))])
113+
; axb_33(F, P, P)
114+
(let ([p00** (+ (+ (* f00 p00*) (* f01 p10*)) (* f02 p20*))]
115+
[p01** (+ (+ (* f00 p01*) (* f01 p11*)) (* f02 p21*))]
116+
[p02** (+ (+ (* f00 p02*) (* f01 p12*)) (* f02 p22*))]
117+
[p10** (+ (+ (* f10 p00*) (* f11 p10*)) (* f12 p20*))]
118+
[p11** (+ (+ (* f10 p01*) (* f11 p11*)) (* f12 p21*))]
119+
[p12** (+ (+ (* f10 p02*) (* f11 p12*)) (* f12 p22*))]
120+
[p20** (+ (+ (* f20 p00*) (* f21 p10*)) (* f22 p20*))]
121+
[p21** (+ (+ (* f20 p01*) (* f21 p11*)) (* f22 p21*))]
122+
[p22** (+ (+ (* f20 p02*) (* f21 p12*)) (* f22 p22*))])
123+
; a_add_b_33(P, Q, P)
124+
(let ([p00*** (+ p00** q00)]
125+
[p01*** (+ p01** q01)]
126+
[p02*** (+ p02** q02)]
127+
[p10*** (+ p10** q10)]
128+
[p11*** (+ p11** q11)]
129+
[p12*** (+ p12** q12)]
130+
[p20*** (+ p20** q20)]
131+
[p21*** (+ p21** q21)]
132+
[p22*** (+ p22** q22)])
133+
; update_K
134+
(let ([K0 (/ p00*** (+ p00*** r))]
135+
[K1 (/ p10*** (+ p00*** r))]
136+
[K2 (/ p20*** (+ p00*** r))])
137+
; predict_x
138+
(let ([x0* (+ x0 (+ x1 (* 0.5 (* dt (* dt x2)))))]
139+
[x1* (+ x1 x2)]
140+
[x2* x2])
141+
(let ([y (- sensor x0*)])
142+
; update_x
143+
(let ([x0** (+ x0* (* K0 y))]
144+
[x1** (+ x1* (* K1 y))]
145+
[x2** (+ x2* (* K2 y))])
146+
x0**)))))))))
147+
148+
(FPCore (dt r)
149+
:name "Kalman filter per P"
150+
:pre (and (> dt 0) (> r 0))
151+
; initializing matrices
152+
(let ([p00 25.0]
153+
[p01 0.0]
154+
[p02 0.0]
155+
[p10 0.0]
156+
[p11 10.0]
157+
[p12 0.0]
158+
[p20 0.0]
159+
[p21 0.0]
160+
[p22 1.0]
161+
162+
[f00 1.0]
163+
[f01 dt]
164+
[f02 (* 0.5 (* dt dt))]
165+
[f10 0.0]
166+
[f11 1.0]
167+
[f12 dt]
168+
[f20 0.0]
169+
[f21 0.0]
170+
[f22 1.0]
171+
172+
[q00 (* 0.25 (* dt (* dt (* dt dt))))]
173+
[q01 (* 0.5 (* dt (* dt dt)))]
174+
[q02 (* 0.5 (* dt dt))]
175+
[q10 (* 0.5 (* dt (* dt dt)))]
176+
[q11 (* dt dt)]
177+
[q12 dt]
178+
[q20 (* 0.5 (* dt dt))]
179+
[q21 dt]
180+
[q22 1])
181+
; axbT_33(P, F, P)
182+
(let ([p00* (+ (+ (* p00 f00) (* p01 f01)) (* p02 f02))]
183+
[p01* (+ (+ (* p00 f10) (* p01 f11)) (* p02 f12))]
184+
[p02* (+ (+ (* p00 f20) (* p01 f21)) (* p02 f22))]
185+
[p10* (+ (+ (* p10 f00) (* p11 f01)) (* p12 f02))]
186+
[p11* (+ (+ (* p10 f10) (* p11 f11)) (* p12 f12))]
187+
[p12* (+ (+ (* p10 f20) (* p11 f21)) (* p12 f22))]
188+
[p20* (+ (+ (* p20 f00) (* p21 f01)) (* p22 f02))]
189+
[p21* (+ (+ (* p20 f10) (* p21 f11)) (* p22 f12))]
190+
[p22* (+ (+ (* p20 f20) (* p21 f21)) (* p22 f22))])
191+
; axb_33(F, P, P)
192+
(let ([p00** (+ (+ (* f00 p00*) (* f01 p10*)) (* f02 p20*))]
193+
[p01** (+ (+ (* f00 p01*) (* f01 p11*)) (* f02 p21*))]
194+
[p02** (+ (+ (* f00 p02*) (* f01 p12*)) (* f02 p22*))]
195+
[p10** (+ (+ (* f10 p00*) (* f11 p10*)) (* f12 p20*))]
196+
[p11** (+ (+ (* f10 p01*) (* f11 p11*)) (* f12 p21*))]
197+
[p12** (+ (+ (* f10 p02*) (* f11 p12*)) (* f12 p22*))]
198+
[p20** (+ (+ (* f20 p00*) (* f21 p10*)) (* f22 p20*))]
199+
[p21** (+ (+ (* f20 p01*) (* f21 p11*)) (* f22 p21*))]
200+
[p22** (+ (+ (* f20 p02*) (* f21 p12*)) (* f22 p22*))])
201+
; a_add_b_33(P, Q, P)
202+
(let ([p00*** (+ p00** q00)]
203+
[p01*** (+ p01** q01)]
204+
[p02*** (+ p02** q02)]
205+
[p10*** (+ p10** q10)]
206+
[p11*** (+ p11** q11)]
207+
[p12*** (+ p12** q12)]
208+
[p20*** (+ p20** q20)]
209+
[p21*** (+ p21** q21)]
210+
[p22*** (+ p22** q22)])
211+
; update_K
212+
(let ([K0 (/ p00*** (+ p00*** r))]
213+
[K1 (/ p10*** (+ p00*** r))]
214+
[K2 (/ p20*** (+ p00*** r))])
215+
; update_P
216+
(let ([eyekh00 (- 1 K0)]
217+
[eyekh01 0.0]
218+
[eyekh02 0.0]
219+
[eyekh10 (- K1)]
220+
[eyekh11 1.0]
221+
[eyekh12 0.0]
222+
[eyekh20 (- K2)]
223+
[eyekh21 0.0]
224+
[eyekh22 1.0])
225+
; axb_33(&eyekh, P, P)
226+
(let ([p00**** (+ (+ (* eyekh00 p00***) (* eyekh01 p10***)) (* eyekh02 p20***))]
227+
[p01**** (+ (+ (* eyekh00 p01***) (* eyekh01 p11***)) (* eyekh02 p21***))]
228+
[p02**** (+ (+ (* eyekh00 p02***) (* eyekh01 p12***)) (* eyekh02 p22***))]
229+
[p10**** (+ (+ (* eyekh10 p00***) (* eyekh11 p10***)) (* eyekh12 p20***))]
230+
[p11**** (+ (+ (* eyekh10 p01***) (* eyekh11 p11***)) (* eyekh12 p21***))]
231+
[p12**** (+ (+ (* eyekh10 p02***) (* eyekh11 p12***)) (* eyekh12 p22***))]
232+
[p20**** (+ (+ (* eyekh20 p00***) (* eyekh21 p10***)) (* eyekh22 p20***))]
233+
[p21**** (+ (+ (* eyekh20 p01***) (* eyekh21 p11***)) (* eyekh22 p21***))]
234+
[p22**** (+ (+ (* eyekh20 p02***) (* eyekh21 p12***)) (* eyekh22 p22***))])
235+
p00****))))))))

src/core/batch-reduce.rkt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@
167167
(define (gather-multiplicative-terms brf recurse)
168168
(match (deref brf)
169169
[+nan.0 (nan-term)]
170-
[(? number? n) `(,n . ())]
170+
[(? number? n) (list n)]
171171
[(? symbol?) `(1 . ((1 . ,brf)))]
172172
[`(neg ,arg)
173173
(define terms (recurse arg))
@@ -201,10 +201,10 @@
201201
(cons exact-cbrt
202202
(for/list ([term (cdr terms)])
203203
(cons (/ (car term) 3) (cdr term))))
204-
(cons 1
205-
(list* (cons 1 (batch-add! batch `(cbrt ,(car terms))))
206-
(for/list ([term (cdr terms)])
207-
(cons (/ (car term) 3) (cdr term))))))])]
204+
(list* 1
205+
(cons 1 (batch-add! batch `(cbrt ,(car terms))))
206+
(for/list ([term (cdr terms)])
207+
(cons (/ (car term) 3) (cdr term)))))])]
208208
[`(pow ,arg ,(app deref 0))
209209
(define terms (recurse arg))
210210
(if (equal? (car terms) +nan.0)

src/core/egg-herbie.rkt

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,10 @@
288288
[(list 'Rewrite=> rule expr) (list 'Rewrite=> (get-canon-rule-name rule rule) (loop expr type))]
289289
[(list 'Rewrite<= rule expr) (list 'Rewrite<= (get-canon-rule-name rule rule) (loop expr type))]
290290
[(list op args ...)
291-
#:when (string-contains? (~a op) "unsound")
292-
(define op* (string->symbol (string-replace (symbol->string (car expr)) "unsound-" "")))
293-
(cons op* (map loop args (map (const 'real) args)))]
291+
#:when (string-prefix? (symbol->string op) "sound-")
292+
(define op* (string->symbol (substring (symbol->string (car expr)) (string-length "sound-"))))
293+
(define args* (drop-right args 1))
294+
(cons op* (map loop args* (map (const 'real) args*)))]
294295
[(list op args ...)
295296
;; Unfortunately the type parameter doesn't tell us much because mixed exprs exist
296297
;; so if we see something like (and a b) we literally don't know which "and" it is
@@ -333,6 +334,8 @@
333334
(check-equal? out expected-out)
334335
(check-equal? computed-in in)))
335336

337+
(check-equal? (egg-expr->expr '(sound-sqrt $var0 $var1) ctx) '(sqrt x))
338+
336339
(set! ctx (context '(x a b c r) <binary64> (make-list 5 <binary64>)))
337340
(define extended-expr-list
338341
; specifications
@@ -522,17 +525,16 @@
522525
;; Synthesizes lowering rules for a given platform.
523526
(define (platform-lowering-rules [pform (*active-platform*)])
524527
(define impls (platform-impls pform))
525-
(append* (for/list ([impl (in-list impls)])
526-
(hash-ref! (*lowering-rules*)
527-
(cons impl pform)
528-
(lambda ()
529-
(define name (sym-append 'lower- impl))
530-
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
531-
(list (rule name spec-expr impl-expr '(lowering))
532-
(rule (sym-append 'lower-unsound- impl)
533-
(add-unsound spec-expr)
534-
impl-expr
535-
'(lowering))))))))
528+
(append*
529+
(for/list ([impl (in-list impls)])
530+
(hash-ref!
531+
(*lowering-rules*)
532+
(cons impl pform)
533+
(lambda ()
534+
(define name (sym-append 'lower- impl))
535+
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
536+
(list (rule name spec-expr impl-expr '(lowering))
537+
(rule (sym-append 'lower-sound- impl) (add-sound spec-expr) impl-expr '(lowering))))))))
536538

537539
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
538540
;; Racket egraph
@@ -575,7 +577,7 @@
575577
[(cons f _) ; application
576578
(cond
577579
[(eq? f '$approx) (platform-reprs (*active-platform*))]
578-
[(string-contains? (~a f) "unsound") (list 'real)]
580+
[(string-prefix? (symbol->string f) "sound-") (list 'real)]
579581
[else
580582
(filter values
581583
(list (and (impl-exists? f) (impl-info f 'otype))
@@ -592,9 +594,11 @@
592594
(define spec (u32vector-ref ids 0))
593595
(define impl (u32vector-ref ids 1))
594596
(list '$approx (lookup spec (representation-type type)) (lookup impl type))]
595-
[(string-contains? (~a f) "unsound")
596-
(define op (string->symbol (string-replace (symbol->string f) "unsound-" "")))
597-
(list* op (map (λ (x) (lookup (u32vector-ref ids x) 'real)) (range (u32vector-length ids))))]
597+
[(string-prefix? (~a f) "sound-")
598+
(define op (string->symbol (substring (symbol->string f) (string-length "sound-"))))
599+
(list* op
600+
(map (λ (x) (lookup (u32vector-ref ids x) 'real))
601+
(range (- (u32vector-length ids) 1))))]
598602
[else
599603
(define itypes
600604
(cond

0 commit comments

Comments
 (0)