Skip to content

Commit 4dceae1

Browse files
authored
Merge pull request #1288 from herbie-fp/codex/add-extraction-for-sound--operations
Add sound- operation extraction
2 parents 8ff70ad + bfa4abd commit 4dceae1

File tree

8 files changed

+88
-83
lines changed

8 files changed

+88
-83
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ herbie-compiled/
2626

2727
# Python
2828
.env/
29+
.worktrees

AGENTS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11

22
# Testing
33

4-
- Run `make fmt` to format the code before presentingcode. This is
4+
- Run `make fmt` to format the code before finishing a task. This is
55
mandatory and PRs that don't follow the coding style are rejected.
6+
- Always check your `git diff` before finishing a task. There's often
7+
leftover or dead code, and you should delete it.
68
- Run `racket src/main.rkt report bench/tutorial.fpcore tmp` to test
79
that your changes work; this should take about 5-10 seconds and all
810
of the tests should pass, getting basically perfect accuracy.

src/core/egg-herbie.rkt

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,10 @@
288288
[(list 'Rewrite=> rule expr) (list 'Rewrite=> (get-canon-rule-name rule rule) (loop expr type))]
289289
[(list 'Rewrite<= rule expr) (list 'Rewrite<= (get-canon-rule-name rule rule) (loop expr type))]
290290
[(list op args ...)
291-
#:when (string-contains? (~a op) "unsound")
292-
(define op* (string->symbol (string-replace (symbol->string (car expr)) "unsound-" "")))
293-
(cons op* (map loop args (map (const 'real) args)))]
291+
#:when (string-prefix? (symbol->string op) "sound-")
292+
(define op* (string->symbol (substring (symbol->string (car expr)) (string-length "sound-"))))
293+
(define args* (drop-right args 1))
294+
(cons op* (map loop args* (map (const 'real) args*)))]
294295
[(list op args ...)
295296
;; Unfortunately the type parameter doesn't tell us much because mixed exprs exist
296297
;; so if we see something like (and a b) we literally don't know which "and" it is
@@ -333,6 +334,8 @@
333334
(check-equal? out expected-out)
334335
(check-equal? computed-in in)))
335336

337+
(check-equal? (egg-expr->expr '(sound-sqrt $var0 $var1) ctx) '(sqrt x))
338+
336339
(set! ctx (context '(x a b c r) <binary64> (make-list 5 <binary64>)))
337340
(define extended-expr-list
338341
; specifications
@@ -522,17 +525,16 @@
522525
;; Synthesizes lowering rules for a given platform.
523526
(define (platform-lowering-rules [pform (*active-platform*)])
524527
(define impls (platform-impls pform))
525-
(append* (for/list ([impl (in-list impls)])
526-
(hash-ref! (*lowering-rules*)
527-
(cons impl pform)
528-
(lambda ()
529-
(define name (sym-append 'lower- impl))
530-
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
531-
(list (rule name spec-expr impl-expr '(lowering))
532-
(rule (sym-append 'lower-unsound- impl)
533-
(add-unsound spec-expr)
534-
impl-expr
535-
'(lowering))))))))
528+
(append*
529+
(for/list ([impl (in-list impls)])
530+
(hash-ref!
531+
(*lowering-rules*)
532+
(cons impl pform)
533+
(lambda ()
534+
(define name (sym-append 'lower- impl))
535+
(define-values (vars spec-expr impl-expr) (impl->rule-parts impl))
536+
(list (rule name spec-expr impl-expr '(lowering))
537+
(rule (sym-append 'lower-sound- impl) (add-sound spec-expr) impl-expr '(lowering))))))))
536538

537539
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
538540
;; Racket egraph
@@ -575,7 +577,7 @@
575577
[(cons f _) ; application
576578
(cond
577579
[(eq? f '$approx) (platform-reprs (*active-platform*))]
578-
[(string-contains? (~a f) "unsound") (list 'real)]
580+
[(string-prefix? (symbol->string f) "sound-") (list 'real)]
579581
[else
580582
(filter values
581583
(list (and (impl-exists? f) (impl-info f 'otype))
@@ -592,9 +594,11 @@
592594
(define spec (u32vector-ref ids 0))
593595
(define impl (u32vector-ref ids 1))
594596
(list '$approx (lookup spec (representation-type type)) (lookup impl type))]
595-
[(string-contains? (~a f) "unsound")
596-
(define op (string->symbol (string-replace (symbol->string f) "unsound-" "")))
597-
(list* op (map (λ (x) (lookup (u32vector-ref ids x) 'real)) (range (u32vector-length ids))))]
597+
[(string-prefix? (~a f) "sound-")
598+
(define op (string->symbol (substring (symbol->string f) (string-length "sound-"))))
599+
(list* op
600+
(map (λ (x) (lookup (u32vector-ref ids x) 'real))
601+
(range (- (u32vector-length ids) 1))))]
598602
[else
599603
(define itypes
600604
(cond

src/core/egglog-herbie.rkt

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,15 +437,14 @@
437437
(hash-set! (id->e1) unsound-op (serialize-op unsound-op))
438438
(hash-set! (e1->id) (serialize-op unsound-op) unsound-op)
439439

440+
(define sound-op (sym-append "sound-" op))
441+
(hash-set! (id->e1) sound-op (serialize-op sound-op))
442+
(hash-set! (e1->id) (serialize-op sound-op) sound-op)
443+
440444
(define arity (length (operator-info op 'itype)))
441-
(list `(,(serialize-op op) ,@(for/list ([i (in-range arity)])
442-
'M)
443-
:cost
444-
4294967295)
445-
`(,(serialize-op unsound-op) ,@(for/list ([i (in-range arity)])
446-
'M)
447-
:cost
448-
4294967295)))))
445+
(list `(,(serialize-op op) ,@(make-list arity 'M) :cost 4294967295)
446+
`(,(serialize-op sound-op) ,@(make-list arity 'M) M :cost 4294967295)
447+
`(,(serialize-op unsound-op) ,@(make-list arity 'M) :cost 4294967295)))))
449448

450449
(define (platform-impl-nodes pform min-cost)
451450
(for/list ([impl (in-list (platform-impls pform))])

src/core/preprocess.rkt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,12 @@
6868
(make-sort-identities spec ctx)))
6969

7070
;; make egg runner
71-
(define rules (*sound-rules*))
72-
7371
(define-values (batch brfs) (progs->batch (cons spec (map cdr identities))))
7472
(define runner
7573
(make-egraph batch
7674
brfs
7775
(make-list (length brfs) (context-repr ctx))
78-
`((,rules . ((node . ,(*node-limit*)))))
76+
`((,(*rules*) . ((node . ,(*node-limit*)))))
7977
ctx))
8078

8179
;; collect equalities

src/core/prove-rules.rkt

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292

9393
[`(< (* ,a ,a) 0) '(FALSE)]
9494
[`(< (sqrt ,a) 0) '(FALSE)]
95+
[`(< (fabs ,a) 0) '(FALSE)]
9596
[`(,(or '< '==) (cosh ,a) ,(? (conjoin number? (curryr < 1)))) '(FALSE)]
9697
[`(,(or '< '==) (exp ,a) ,(? (conjoin number? (curryr <= 0)))) '(FALSE)]
9798
[`(,(or '< '==) (* ,a ,a) ,(? (conjoin number? (curryr < 0)))) '(FALSE)]
@@ -159,29 +160,34 @@
159160
xs
160161
(simplify-conditions simple1)))
161162

163+
;; The prover must prove: rhs-bad => lhs-bad
164+
;; IOW we can weaken the RHS or strengthen the LHS
165+
162166
(define soundness-proofs
163167
'((pow-plus (implies (< b -1) (< b 0)))
164168
(pow-sqr (implies (even-denominator? (* 2 b)) (even-denominator? b)))
165169
(hang-0p-tan (implies (== (cos (/ a 2)) 0) (== (cos a) -1)))
166-
(hang-0p-tan-rev (implies (== (cos (/ a 2)) 0) (== (cos a) -1)))
167-
(hang-0m-tan (implies (== (cos (/ a 2)) 0) (== (cos a) -1)))
168-
(hang-0m-tan-rev (implies (== (cos (/ a 2)) 0) (== (cos a) -1)))
170+
(hang-0p-tan-rev (implies (== (cos a) -1) (== (cos (/ a 2)) 0)))
171+
(hang-0m-tan (implies (== (cos (/ (neg a) 2)) 0) (== (cos a) -1)))
172+
(hang-0m-tan-rev (implies (== (cos a) -1) (== (cos (/ (neg a) 2)) 0)))
169173
(tanh-sum (implies (== (* (tanh x) (tanh y)) -1) (FALSE)))
170174
(tanh-def-a (implies (== (+ (exp x) (exp (neg x))) 0) (FALSE)))
171-
(acosh-def (implies (< x 1) (or (< x -1) (== x -1) (< (fabs x) 1))))
175+
(acosh-def (implies (< x -1) (< x 1))
176+
(implies (== x -1) (< x 1))
177+
(implies (< (fabs x) 1) (< x 1)))
172178
(acosh-def-rev (implies (< x 1) (or (< x -1) (== x -1) (< (fabs x) 1))))
173179
(sqrt-undiv (implies (< (/ x y) 0) (or (< x 0) (< y 0))))
174180
(sqrt-unprod (implies (< (* x y) 0) (or (< x 0) (< y 0))))
181+
(sqrt-pow2 (implies (and (< x 0) _) (< x 0)))
175182
(tan-sum-rev (implies (== (cos (+ x y)) 0) (== (* (tan x) (tan y)) 1)))
176183
(sum-log (implies (< (* x y) 0) (or (< x 0) (< y 0))))
177184
(diff-log (implies (< (/ x y) 0) (or (< x 0) (< y 0))))
178185
(exp-to-pow (implies (and a b) a))
179-
(sinh-acosh (implies (< (fabs x) 1) (< x 1)))
180186
(acosh-2-rev (implies (< (fabs x) 1) (< x 1)))
181187
(tanh-acosh (implies (< (fabs x) 1) (< x 1)) (implies (== x 0) (< x 1)))
188+
(sinh-acosh (implies (< (fabs x) 1) (< x 1)))
182189
(hang-p0-tan (implies (== (cos (/ a 2)) 0) (== (sin a) 0)))
183190
(hang-m0-tan (implies (== (cos (/ a 2)) 0) (== (sin a) 0)))
184-
(sqrt-pow2 (implies (and a b) a))
185191
(pow-div (implies (< (- b c) 0) (or (< b 0) (> c 0)))
186192
(implies (even-denominator? (- b c)) (or (even-denominator? b) (even-denominator? c))))
187193
(pow-prod-up (implies (< (+ b c) 0) (or (< b 0) (< c 0)))
@@ -196,16 +202,15 @@
196202
(simplify-conditions (map (curryr rewrite-all a b) terms))))
197203

198204
(define (rewrite-unsound? lhs rhs [proof '()])
199-
(define lhs-bad (execute-proof proof (undefined-conditions lhs)))
205+
(define lhs-bad (simplify-conditions (undefined-conditions lhs)))
200206
(define rhs-bad (execute-proof proof (undefined-conditions rhs)))
201207
(define extra (set-remove (set-subtract rhs-bad lhs-bad) '(FALSE)))
202208
(if (empty? extra)
203209
(values #f #f)
204210
(values lhs-bad extra)))
205211

206212
(define (potentially-unsound)
207-
(define num 0)
208-
(for ([rule (in-list (*sound-rules*))])
213+
(for ([rule (in-list (*rules*))])
209214
(test-case (~a (rule-name rule))
210215
(define proof (dict-ref soundness-proofs (rule-name rule) '()))
211216
(define-values (lhs-bad rhs-bad) (rewrite-unsound? (rule-input rule) (rule-output rule) proof))

src/core/rules.rkt

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
"../syntax/syntax.rkt")
77

88
(provide *rules*
9-
*sound-rules*
109
(struct-out rule)
11-
add-unsound)
10+
add-sound)
1211

1312
;; A rule represents "find-and-replacing" `input` by `output`. Both
1413
;; are patterns, meaning that symbols represent pattern variables.
@@ -22,26 +21,18 @@
2221
(define (rule-enabled? rule)
2322
(ormap (curry flag-set? 'rules) (rule-tags rule)))
2423

25-
(define (rule-sound? rule)
26-
(set-member? (rule-tags rule) 'sound))
27-
2824
(define (*rules*)
2925
(filter rule-enabled? *all-rules*))
3026

31-
(define (*sound-rules*)
32-
(filter (conjoin rule-enabled? rule-sound?) *all-rules*))
33-
34-
(define (add-unsound expr)
27+
(define (add-sound expr)
3528
(match expr
36-
[(list op args ...) (cons (sym-append "unsound-" op) (map add-unsound args))]
29+
[(list (and (or '/ 'pow 'log) op) args ...)
30+
`(,(sym-append "sound-" op) ,@(map add-sound args) ,(gensym))]
31+
[(list op args ...) (cons op (map add-sound expr))]
3732
[_ expr]))
3833

39-
(define-syntax define-rule
40-
(syntax-rules ()
41-
[(define-rule rname group input output)
42-
(set! *all-rules* (cons (rule 'rname 'input 'output '(group sound)) *all-rules*))]
43-
[(define-rule rname group input output #:unsound)
44-
(set! *all-rules* (cons (rule 'rname 'input (add-unsound 'output) '(group)) *all-rules*))]))
34+
(define-syntax-rule (define-rule rname group input output)
35+
(set! *all-rules* (cons (rule 'rname 'input 'output '(group)) *all-rules*)))
4536

4637
(define-syntax-rule (define-rules group
4738
[rname input output flags ...] ...)
@@ -154,14 +145,14 @@
154145
(define-rules arithmetic
155146
[mult-flip (/ a b) (* a (/ 1 b))]
156147
[mult-flip-rev (* a (/ 1 b)) (/ a b)]
157-
[div-flip (/ a b) (/ 1 (/ b a)) #:unsound] ; unsound @ a = 0, b != 0
148+
[div-flip (/ a b) (sound-/ 1 (sound-/ b a 0) (/ a b))]
158149
[div-flip-rev (/ 1 (/ b a)) (/ a b)])
159150

160151
; Fractions
161152
(define-rules arithmetic
162-
[sum-to-mult (+ a b) (* (+ 1 (/ b a)) a) #:unsound] ; unsound @ a = 0, b = 1
153+
#;[sum-to-mult (+ a b) (* (+ 1 (/ b a)) a) #:unsound] ; unsound @ a = 0, b = 1
163154
[sum-to-mult-rev (* (+ 1 (/ b a)) a) (+ a b)]
164-
[sub-to-mult (- a b) (* (- 1 (/ b a)) a) #:unsound] ; unsound @ a = 0, b = 1
155+
#;[sub-to-mult (- a b) (* (- 1 (/ b a)) a) #:unsound] ; unsound @ a = 0, b = 1
165156
[sub-to-mult-rev (* (- 1 (/ b a)) a) (- a b)]
166157
[add-to-fraction (+ c (/ b a)) (/ (+ (* c a) b) a)]
167158
[add-to-fraction-rev (/ (+ (* c a) b) a) (+ c (/ b a))]
@@ -170,9 +161,9 @@
170161
[common-denominator (+ (/ a b) (/ c d)) (/ (+ (* a d) (* c b)) (* b d))])
171162

172163
(define-rules polynomials
173-
[sqr-pow (pow a b) (* (pow a (/ b 2)) (pow a (/ b 2))) #:unsound] ; unsound @ a = -1, b = 1
174-
[flip-+ (+ a b) (/ (- (* a a) (* b b)) (- a b)) #:unsound] ; unsound @ a = b = 1
175-
[flip-- (- a b) (/ (- (* a a) (* b b)) (+ a b)) #:unsound]) ; unsound @ a = -1, b = 1
164+
#;[sqr-pow (pow a b) (* (pow a (/ b 2)) (pow a (/ b 2))) #:unsound] ; unsound @ a = -1, b = 1
165+
[flip-+ (+ a b) (sound-/ (- (* a a) (* b b)) (- a b) (+ a b))]
166+
[flip-- (- a b) (sound-/ (- (* a a) (* b b)) (+ a b) (- a b))])
176167

177168
; Difference of cubes
178169
(define-rules polynomials
@@ -181,9 +172,9 @@
181172
[difference-cubes-rev (* (+ (* a a) (+ (* b b) (* a b))) (- a b)) (- (pow a 3) (pow b 3))]
182173
[sum-cubes-rev (* (+ (* a a) (- (* b b) (* a b))) (+ a b)) (+ (pow a 3) (pow b 3))])
183174

184-
(define-rules polynomials ; unsound @ a = b = 0
185-
[flip3-+ (+ a b) (/ (+ (pow a 3) (pow b 3)) (+ (* a a) (- (* b b) (* a b)))) #:unsound]
186-
[flip3-- (- a b) (/ (- (pow a 3) (pow b 3)) (+ (* a a) (+ (* b b) (* a b)))) #:unsound])
175+
(define-rules polynomials
176+
[flip3-+ (+ a b) (sound-/ (+ (pow a 3) (pow b 3)) (+ (* a a) (- (* b b) (* a b))) (+ a b))]
177+
[flip3-- (- a b) (sound-/ (- (pow a 3) (pow b 3)) (+ (* a a) (+ (* b b) (* a b))) (- a b))])
187178

188179
; Dealing with fractions
189180
(define-rules fractions
@@ -245,9 +236,9 @@
245236
[sqrt-undiv (/ (sqrt x) (sqrt y)) (sqrt (/ x y))])
246237

247238
(define-rules arithmetic
248-
[sqrt-prod (sqrt (* x y)) (* (sqrt x) (sqrt y)) #:unsound] ; unsound @ x = y = -1
249-
[sqrt-div (sqrt (/ x y)) (/ (sqrt x) (sqrt y)) #:unsound] ; unsound @ x = y = -1
250-
[add-sqr-sqrt x (* (sqrt x) (sqrt x)) #:unsound]) ; unsound @ x = -1
239+
[sqrt-prod (sqrt (* x y)) (* (sqrt (fabs x)) (sqrt (fabs y)))]
240+
[sqrt-div (sqrt (/ x y)) (/ (sqrt (fabs x)) (sqrt (fabs y)))]
241+
[add-sqr-sqrt x (copysign (* (sqrt (fabs x)) (sqrt (fabs x))) x)])
251242

252243
; Cubing
253244
(define-rules arithmetic
@@ -289,7 +280,7 @@
289280
; Exponentials
290281
(define-rules exponents
291282
[add-log-exp x (log (exp x))]
292-
[add-exp-log x (exp (log x)) #:unsound] ; unsound @ x = 0
283+
#;[add-exp-log x (exp (log x)) #:unsound] ; unsound @ x = -1
293284
[rem-exp-log (exp (log x)) x]
294285
[rem-log-exp (log (exp x)) x])
295286

@@ -348,13 +339,14 @@
348339
[pow-div (/ (pow a b) (pow a c)) (pow a (- b c))])
349340

350341
(define-rules exponents
351-
[pow-plus-rev (pow a (+ b 1)) (* (pow a b) a) #:unsound] ; unsound @ a = 0, b = -1/2
352-
[pow-neg (pow a (neg b)) (/ 1 (pow a b)) #:unsound]) ; unsound @ a = 0, b = -1
342+
[pow-plus-rev (pow a (+ b 1)) (* (sound-pow a b 1) a)]
343+
[pow-neg (pow a (neg b)) (sound-/ 1 (sound-pow a b 0) 0)])
353344

354345
(define-rules exponents
355-
[pow-to-exp (pow a b) (exp (* (log a) b)) #:unsound] ; unsound @ a = -1, b = 1
356-
[pow-add (pow a (+ b c)) (* (pow a b) (pow a c)) #:unsound] ; unsound @ a = -1, b = c = 1/2
357-
[pow-sub (pow a (- b c)) (/ (pow a b) (pow a c)) #:unsound] ; unsound @ a = -1, b = c = 1/2
346+
#;[pow-to-exp (pow a b) (exp (* (log a) b)) #:unsound] ; unsound @ a = -1, b = 1
347+
#;[pow-add (pow a (+ b c)) (* (pow a b) (pow a c)) #:unsound] ; unsound @ a = -1, b = c = 1/2
348+
#;[pow-sub (pow a (- b c)) (/ (pow a b) (pow a c)) #:unsound] ; unsound @ a = -1, b = c = 1/2
349+
#;
358350
[unpow-prod-down (pow (* b c) a) (* (pow b a) (pow c a)) #:unsound]) ; unsound @ a = 1/2, b = c = -1
359351

360352
; Logarithms
@@ -364,9 +356,9 @@
364356
[log-pow-rev (* b (log a)) (log (pow a b))])
365357

366358
(define-rules exponents
367-
[log-prod (log (* a b)) (+ (log a) (log b)) #:unsound] ; unsound @ a = b = -1
368-
[log-div (log (/ a b)) (- (log a) (log b)) #:unsound] ; unsound @ a = b = -1
369-
[log-pow (log (pow a b)) (* b (log a)) #:unsound]) ; unsound @ a = -1, b = 2
359+
[log-prod (log (* a b)) (+ (log (fabs a)) (log (fabs b)))]
360+
[log-div (log (/ a b)) (- (log (fabs a)) (log (fabs b)))]
361+
[log-pow (log (pow a b)) (* b (sound-log (fabs a) 0))])
370362

371363
(define-rules exponents
372364
[sum-log (+ (log a) (log b)) (log (* a b))]

src/core/test-rules.rkt

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,22 @@
1313

1414
(activate-platform! (*platform-name*))
1515

16+
(define skip-rules '(log-pow))
17+
1618
(define num-test-points (make-parameter 100))
1719
(define double-repr (get-representation 'binary64))
1820

1921
(define (env->ctx p1 p2)
2022
(define vars (set-union (free-variables p1) (free-variables p2)))
2123
(context vars double-repr (map (const double-repr) vars)))
2224

23-
(define (drop-unsound expr)
25+
(define (drop-sound expr)
2426
(match expr
25-
[(list op args ...)
26-
#:when (string-contains? (~a op) "unsound")
27-
(define op* (string->symbol (string-replace (symbol->string (car expr)) "unsound-" "")))
28-
(cons op* (map drop-unsound args))]
27+
[(list op args ... extra)
28+
#:when (string-contains? (~a op) "sound")
29+
(define op* (string->symbol (substring (symbol->string (car expr)) (string-length "sound-"))))
30+
(cons op* (map drop-sound args))]
31+
[(list op args ...) (cons op (map drop-sound args))]
2932
[_ expr]))
3033

3134
(define (check-rule test-rule)
@@ -35,7 +38,7 @@
3538
(match-define (list pts exs1 exs2)
3639
(parameterize ([*num-points* (num-test-points)]
3740
[*max-find-range-depth* 0])
38-
(sample-points '(TRUE) (list p1 (drop-unsound p2)) (list ctx ctx))))
41+
(sample-points '(TRUE) (list p1 (drop-sound p2)) (list ctx ctx))))
3942

4043
(for ([pt (in-list pts)]
4144
[v1 (in-list exs1)]
@@ -55,6 +58,7 @@
5558
(check-rule rule))))
5659

5760
(module+ test
58-
(for* ([rule (in-list (*rules*))])
61+
(for* ([rule (in-list (*rules*))]
62+
#:unless (set-member? skip-rules (rule-name rule)))
5963
(test-case (~a (rule-name rule))
6064
(check-rule rule))))

0 commit comments

Comments
 (0)