Skip to content

Commit bfbdde1

Browse files
authored
Merge pull request #1325 from herbie-fp/codex/plan-changes-to-remove-if-costs
Normalize if operator
2 parents 38a6a40 + 17bb823 commit bfbdde1

19 files changed

+106
-187
lines changed

infra/softposit.rkt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@
8282

8383
;;;;;;;;;;;;;;;;;;;;;;;;;;;;; EMPTY PLATFORM ;;;;;;;;;;;;;;;;;;;;;;;;
8484

85-
(define-if #:cost 1)
8685

8786
;;;;;;;;;;;;;;;;;;;;;;;;;;;;; REPRESENTATIONS ;;;;;;;;;;;;;;;;;;;;;;;
8887

@@ -158,8 +157,17 @@
158157
;;;;;;;;;;;;;;;;;;;;;;;;;;;;; POSIT IMPLS ;;;;;;;;;;;;;;;;;;;;;;;;;;;
159158

160159
(define-representation <posit8> #:cost 1)
160+
(define-operation (if.p8 [c <bool>] [t <posit8>] [f <posit8>]) <posit8>
161+
#:spec (if c t f) #:impl if-impl
162+
#:cost 1 #:aggregate if-cost)
161163
(define-representation <posit16> #:cost 1)
164+
(define-operation (if.p16 [c <bool>] [t <posit16>] [f <posit16>]) <posit16>
165+
#:spec (if c t f) #:impl if-impl
166+
#:cost 1 #:aggregate if-cost)
162167
(define-representation <posit32> #:cost 1)
168+
(define-operation (if.p32 [c <bool>] [t <posit32>] [f <posit32>]) <posit32>
169+
#:spec (if c t f) #:impl if-impl
170+
#:cost 1 #:aggregate if-cost)
163171

164172
(define-operations ([x <posit8>] [y <posit8>]) <bool>
165173
[==.p8 #:spec (== x y) #:impl posit8= #:cost 1]

src/core/alt-table.rkt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@
3737
[(? symbol?) ((node-cost-proc node repr))]
3838
[(? number?) 0] ; specs
3939
[(approx _ impl) (vector-ref costs impl)]
40-
[(list 'if cond ift iff)
41-
(define cost-proc (node-cost-proc node repr))
42-
(cost-proc (vector-ref costs cond) (vector-ref costs ift) (vector-ref costs iff))]
4340
[(list (? (negate impl-exists?) impl) args ...) 0] ; specs
4441
[(list impl args ...)
4542
(define cost-proc (node-cost-proc node repr))

src/core/bsearch.rkt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@
4242
(for/fold ([expr (alt-expr (list-ref alts (sp-cidx (last splitpoints))))])
4343
([splitpoint (cdr (reverse splitpoints))])
4444
(define repr (repr-of (sp-bexpr splitpoint) ctx))
45+
(define if-impl (get-fpcore-impl 'if '() (list (get-representation 'bool) repr repr)))
4546
(define <=-impl (get-fpcore-impl '<= '() (list repr repr)))
46-
`(if (,<=-impl ,(sp-bexpr splitpoint)
47-
,(literal (repr->real (sp-point splitpoint) repr) (representation-name repr)))
48-
,(alt-expr (list-ref alts (sp-cidx splitpoint)))
49-
,expr)))
47+
`(,if-impl (,<=-impl ,(sp-bexpr splitpoint)
48+
,(literal (repr->real (sp-point splitpoint) repr)
49+
(representation-name repr)))
50+
,(alt-expr (list-ref alts (sp-cidx splitpoint)))
51+
,expr)))
5052

5153
;; We don't want unused alts in our history!
5254
(define-values (alts* splitpoints*) (remove-unused-alts alts splitpoints))

src/core/compiler.rkt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@
4646
[(list op a b c) (op (vector-ref regs a) (vector-ref regs b) (vector-ref regs c))]
4747
[(list op args ...) (apply op (map (curry vector-ref regs) args))]))
4848

49-
(define (if-proc c a b)
50-
(if c a b))
51-
5249
(define (batch-remove-approx batch)
5350
(batch-replace batch
5451
(lambda (node)
@@ -79,7 +76,6 @@
7976
([node (in-vector (batch-nodes batch*) num-vars)])
8077
(match node
8178
[(literal value (app get-representation repr)) (list (const (real->repr value repr)))]
82-
[(list 'if c t f) (list if-proc c t f)]
8379
[(list op args ...) (cons (impl-info op 'fl) args)])))
8480

8581
(make-progs-interpreter vars instructions (batch-roots batch*)))

src/core/egg-herbie.rkt

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,6 @@
294294
[`(Explanation ,body ...) `(Explanation ,@(map (lambda (e) (loop e type)) body))]
295295
[(list 'Rewrite=> rule expr) (list 'Rewrite=> (get-canon-rule-name rule rule) (loop expr type))]
296296
[(list 'Rewrite<= rule expr) (list 'Rewrite<= (get-canon-rule-name rule rule) (loop expr type))]
297-
[(list 'if cond ift iff)
298-
(if (representation? type)
299-
(list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type))
300-
(list 'if (loop cond 'bool) (loop ift type) (loop iff type)))]
301297
[(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))]
302298
[(list op args ...)
303299
#:when (string-contains? (~a op) "unsound")
@@ -324,7 +320,7 @@
324320
(cons '(*.f64 x y) '(*.f64 $var0 $var1))
325321
(cons '(+.f64 (*.f64 x y) #s(literal 2 binary64)) '(+.f64 (*.f64 $var0 $var1) 2))
326322
(cons '(cos.f32 (PI.f32)) '(cos.f32 (PI.f32)))
327-
(cons '(if (TRUE) x y) '(if (TRUE) $var0 $var1))))
323+
(cons '(if.f64 (TRUE) x y) '(if.f64 (TRUE) $var0 $var1))))
328324

329325
(let ([egg-graph (make-egraph-data)])
330326
(for ([(in expected-out) (in-dict test-exprs)])
@@ -540,9 +536,9 @@
540536
[(cons f _) ; application
541537
(cond
542538
[(eq? f '$approx) (platform-reprs (*active-platform*))]
543-
[(eq? f 'if) (all-reprs/types)]
544539
[(string-contains? (~a f) "unsound") (list 'real)]
545540
[(impl-exists? f) (list (impl-info f 'otype))]
541+
[(eq? f 'if) '(real bool)]
546542
[else (list (operator-info f 'otype))])]))
547543

548544
;; Rebuilds an e-node using typed e-classes
@@ -556,23 +552,15 @@
556552
(define spec (u32vector-ref ids 0))
557553
(define impl (u32vector-ref ids 1))
558554
(list '$approx (lookup spec (representation-type type)) (lookup impl type))]
559-
[(eq? f 'if) ; if expression
560-
(define cond (u32vector-ref ids 0))
561-
(define ift (u32vector-ref ids 1))
562-
(define iff (u32vector-ref ids 2))
563-
(define cond-type
564-
(if (representation? type)
565-
(get-representation 'bool)
566-
'bool))
567-
(list 'if (lookup cond cond-type) (lookup ift type) (lookup iff type))]
568555
[(string-contains? (~a f) "unsound")
569556
(define op (string->symbol (string-replace (symbol->string f) "unsound-" "")))
570557
(list* op (map (λ (x) (lookup (u32vector-ref ids x) 'real)) (range (u32vector-length ids))))]
571558
[else
572559
(define itypes
573-
(if (impl-exists? f)
574-
(impl-info f 'itype)
575-
(operator-info f 'itype)))
560+
(cond
561+
[(impl-exists? f) (impl-info f 'itype)]
562+
[(eq? f 'if) (list 'bool type type)]
563+
[else (operator-info f 'itype)]))
576564
; unsafe since we don't check that |itypes| = |ids|
577565
; optimize for common cases to avoid extra allocations
578566
(cons
@@ -840,10 +828,6 @@
840828
[(? number?) (platform-repr-cost (*active-platform*) type)]
841829
[(? symbol?) (platform-repr-cost (*active-platform*) type)]
842830
[(list '$approx x y) 0]
843-
[(list 'if c x y)
844-
(match (platform-if-cost (*active-platform*))
845-
[`(max ,n) n] ; Not quite right
846-
[`(sum ,n) n])]
847831
[(list op args ...) (impl-info op 'cost)])
848832
1))
849833
(values (string->symbol (format "~a.~a" n k))
@@ -1032,11 +1016,8 @@
10321016
(representation-type type)
10331017
type))
10341018
(approx (loop spec spec-type) (loop impl type))]
1035-
[(list 'if cond ift iff)
1036-
(if (representation? type)
1037-
(list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type))
1038-
(list 'if (loop cond 'bool) (loop ift type) (loop iff type)))]
10391019
[(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))]
1020+
[(list 'if c t f) (list 'if (loop c 'bool) (loop t 'real) (loop f 'real))]
10401021
[(list op args ...) (cons op (map loop args (operator-info op 'itype)))])))
10411022

10421023
(define (eggref id)
@@ -1067,16 +1048,14 @@
10671048
(define final-spec (egg-parsed->expr spec* spec-type))
10681049
(define final-spec-idx (mutable-batch-munge! out final-spec))
10691050
(approx final-spec-idx (loop impl type))]
1070-
[(list 'if (app eggref cond) (app eggref ift) (app eggref iff))
1071-
(if (representation? type)
1072-
(list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type))
1073-
(list 'if (loop cond 'bool) (loop ift type) (loop iff type)))]
10741051
[(list (? impl-exists? impl) (app eggref args) ...)
10751052
(define args*
10761053
(for/list ([arg (in-list args)]
10771054
[type (in-list (impl-info impl 'itype))])
10781055
(loop arg type)))
10791056
(cons impl args*)]
1057+
[(list 'if c t f)
1058+
(list 'if (loop (eggref c) 'bool) (loop (eggref t) type) (loop (eggref f) type))]
10801059
[(list (? operator-exists? op) (app eggref args) ...)
10811060
(define args*
10821061
(for/list ([arg (in-list args)]
@@ -1120,7 +1099,6 @@
11201099
[(? symbol?) 1]
11211100
; approx node
11221101
[(list '$approx _ impl) (rec impl)]
1123-
[(list 'if cond ift iff) (+ 1 (rec cond) (rec ift) (rec iff))]
11241102
[(list (? impl-exists? impl) args ...)
11251103
(match (pow-impl-args impl args)
11261104
[(cons _ e)
@@ -1150,9 +1128,6 @@
11501128
((node-cost-proc node repr))]
11511129
; approx node
11521130
[(list '$approx _ impl) (rec impl)]
1153-
[(list 'if cond ift iff) ; if expression
1154-
(define cost-proc (node-cost-proc node type))
1155-
(cost-proc (rec cond) (rec ift) (rec iff))]
11561131
[(list (? impl-exists?) args ...) ; impls
11571132
(define cost-proc (node-cost-proc node type))
11581133
(apply cost-proc (map rec args))])]

src/core/egglog-herbie-tests.rkt

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@
5151

5252
(check-equal? (e2->expr '(Var "y")) 'y)
5353

54-
(check-equal? (e2->expr '(IfTy (Var "y")
55-
(Num (bigrat (from-string "1") (from-string "2")))
56-
(Num (bigrat (from-string "0") (from-string "1")))))
57-
'(if y 1/2 0))
54+
(check-equal? (e2->expr '(Iff64Ty (Var "y")
55+
(Num (bigrat (from-string "1") (from-string "2")))
56+
(Num (bigrat (from-string "0") (from-string "1")))))
57+
'(if.f64 y 1/2 0))
5858

5959
(check-equal? (e2->expr '(Mulf64Ty (Num (bigrat (from-string "2") (from-string "1")))
6060
(Num (bigrat (from-string "3") (from-string "1")))))
@@ -154,10 +154,10 @@
154154
(Var "z"))))
155155
'(*.f32 3/2 (+.f32 4/5 z)))
156156

157-
(check-equal? (e2->expr '(IfTy (Var "cond")
158-
(Num (bigrat (from-string "7") (from-string "8")))
159-
(Num (bigrat (from-string "-2") (from-string "3")))))
160-
'(if cond 7/8 -2/3)))
157+
(check-equal? (e2->expr '(Iff32Ty (Var "cond")
158+
(Num (bigrat (from-string "7") (from-string "8")))
159+
(Num (bigrat (from-string "-2") (from-string "3")))))
160+
'(if.f32 cond 7/8 -2/3)))
161161

162162
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
163163
;; Testing API
@@ -262,6 +262,8 @@
262262
(Neqf64Ty . !=.f64)
263263
(NotboolTy . not.bool)
264264
(OrboolTy . or.bool)
265+
(Iff32Ty . if.f32)
266+
(Iff64Ty . if.f64)
265267
(Pif32Ty . PI.f32)
266268
(Pif64Ty . PI.f64)
267269
(Powf32Ty . pow.f32)
@@ -311,7 +313,7 @@
311313
"../syntax/platform.rkt"
312314
"../utils/float.rkt"
313315
"../syntax/load-platform.rkt")
314-
(activate-platform! (*platform-name*))
316+
(activate-platform! "c")
315317

316318
(define batch
317319
(progs->batch (list '(-.f64 (sin.f64 (+.f64 x eps)) (sin.f64 x))

0 commit comments

Comments
 (0)