|
294 | 294 | [`(Explanation ,body ...) `(Explanation ,@(map (lambda (e) (loop e type)) body))] |
295 | 295 | [(list 'Rewrite=> rule expr) (list 'Rewrite=> (get-canon-rule-name rule rule) (loop expr type))] |
296 | 296 | [(list 'Rewrite<= rule expr) (list 'Rewrite<= (get-canon-rule-name rule rule) (loop expr type))] |
297 | | - [(list 'if cond ift iff) |
298 | | - (if (representation? type) |
299 | | - (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) |
300 | | - (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] |
301 | 297 | [(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))] |
302 | 298 | [(list op args ...) |
303 | 299 | #:when (string-contains? (~a op) "unsound") |
|
324 | 320 | (cons '(*.f64 x y) '(*.f64 $var0 $var1)) |
325 | 321 | (cons '(+.f64 (*.f64 x y) #s(literal 2 binary64)) '(+.f64 (*.f64 $var0 $var1) 2)) |
326 | 322 | (cons '(cos.f32 (PI.f32)) '(cos.f32 (PI.f32))) |
327 | | - (cons '(if (TRUE) x y) '(if (TRUE) $var0 $var1)))) |
| 323 | + (cons '(if.f64 (TRUE) x y) '(if.f64 (TRUE) $var0 $var1)))) |
328 | 324 |
|
329 | 325 | (let ([egg-graph (make-egraph-data)]) |
330 | 326 | (for ([(in expected-out) (in-dict test-exprs)]) |
|
540 | 536 | [(cons f _) ; application |
541 | 537 | (cond |
542 | 538 | [(eq? f '$approx) (platform-reprs (*active-platform*))] |
543 | | - [(eq? f 'if) (all-reprs/types)] |
544 | 539 | [(string-contains? (~a f) "unsound") (list 'real)] |
545 | 540 | [(impl-exists? f) (list (impl-info f 'otype))] |
| 541 | + [(eq? f 'if) '(real bool)] |
546 | 542 | [else (list (operator-info f 'otype))])])) |
547 | 543 |
|
548 | 544 | ;; Rebuilds an e-node using typed e-classes |
|
556 | 552 | (define spec (u32vector-ref ids 0)) |
557 | 553 | (define impl (u32vector-ref ids 1)) |
558 | 554 | (list '$approx (lookup spec (representation-type type)) (lookup impl type))] |
559 | | - [(eq? f 'if) ; if expression |
560 | | - (define cond (u32vector-ref ids 0)) |
561 | | - (define ift (u32vector-ref ids 1)) |
562 | | - (define iff (u32vector-ref ids 2)) |
563 | | - (define cond-type |
564 | | - (if (representation? type) |
565 | | - (get-representation 'bool) |
566 | | - 'bool)) |
567 | | - (list 'if (lookup cond cond-type) (lookup ift type) (lookup iff type))] |
568 | 555 | [(string-contains? (~a f) "unsound") |
569 | 556 | (define op (string->symbol (string-replace (symbol->string f) "unsound-" ""))) |
570 | 557 | (list* op (map (λ (x) (lookup (u32vector-ref ids x) 'real)) (range (u32vector-length ids))))] |
571 | 558 | [else |
572 | 559 | (define itypes |
573 | | - (if (impl-exists? f) |
574 | | - (impl-info f 'itype) |
575 | | - (operator-info f 'itype))) |
| 560 | + (cond |
| 561 | + [(impl-exists? f) (impl-info f 'itype)] |
| 562 | + [(eq? f 'if) (list 'bool type type)] |
| 563 | + [else (operator-info f 'itype)])) |
576 | 564 | ; unsafe since we don't check that |itypes| = |ids| |
577 | 565 | ; optimize for common cases to avoid extra allocations |
578 | 566 | (cons |
|
840 | 828 | [(? number?) (platform-repr-cost (*active-platform*) type)] |
841 | 829 | [(? symbol?) (platform-repr-cost (*active-platform*) type)] |
842 | 830 | [(list '$approx x y) 0] |
843 | | - [(list 'if c x y) |
844 | | - (match (platform-if-cost (*active-platform*)) |
845 | | - [`(max ,n) n] ; Not quite right |
846 | | - [`(sum ,n) n])] |
847 | 831 | [(list op args ...) (impl-info op 'cost)]) |
848 | 832 | 1)) |
849 | 833 | (values (string->symbol (format "~a.~a" n k)) |
|
1032 | 1016 | (representation-type type) |
1033 | 1017 | type)) |
1034 | 1018 | (approx (loop spec spec-type) (loop impl type))] |
1035 | | - [(list 'if cond ift iff) |
1036 | | - (if (representation? type) |
1037 | | - (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) |
1038 | | - (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] |
1039 | 1019 | [(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))] |
| 1020 | + [(list 'if c t f) (list 'if (loop c 'bool) (loop t 'real) (loop f 'real))] |
1040 | 1021 | [(list op args ...) (cons op (map loop args (operator-info op 'itype)))]))) |
1041 | 1022 |
|
1042 | 1023 | (define (eggref id) |
|
1067 | 1048 | (define final-spec (egg-parsed->expr spec* spec-type)) |
1068 | 1049 | (define final-spec-idx (mutable-batch-munge! out final-spec)) |
1069 | 1050 | (approx final-spec-idx (loop impl type))] |
1070 | | - [(list 'if (app eggref cond) (app eggref ift) (app eggref iff)) |
1071 | | - (if (representation? type) |
1072 | | - (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) |
1073 | | - (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] |
1074 | 1051 | [(list (? impl-exists? impl) (app eggref args) ...) |
1075 | 1052 | (define args* |
1076 | 1053 | (for/list ([arg (in-list args)] |
1077 | 1054 | [type (in-list (impl-info impl 'itype))]) |
1078 | 1055 | (loop arg type))) |
1079 | 1056 | (cons impl args*)] |
| 1057 | + [(list 'if c t f) |
| 1058 | + (list 'if (loop (eggref c) 'bool) (loop (eggref t) type) (loop (eggref f) type))] |
1080 | 1059 | [(list (? operator-exists? op) (app eggref args) ...) |
1081 | 1060 | (define args* |
1082 | 1061 | (for/list ([arg (in-list args)] |
|
1120 | 1099 | [(? symbol?) 1] |
1121 | 1100 | ; approx node |
1122 | 1101 | [(list '$approx _ impl) (rec impl)] |
1123 | | - [(list 'if cond ift iff) (+ 1 (rec cond) (rec ift) (rec iff))] |
1124 | 1102 | [(list (? impl-exists? impl) args ...) |
1125 | 1103 | (match (pow-impl-args impl args) |
1126 | 1104 | [(cons _ e) |
|
1150 | 1128 | ((node-cost-proc node repr))] |
1151 | 1129 | ; approx node |
1152 | 1130 | [(list '$approx _ impl) (rec impl)] |
1153 | | - [(list 'if cond ift iff) ; if expression |
1154 | | - (define cost-proc (node-cost-proc node type)) |
1155 | | - (cost-proc (rec cond) (rec ift) (rec iff))] |
1156 | 1131 | [(list (? impl-exists?) args ...) ; impls |
1157 | 1132 | (define cost-proc (node-cost-proc node type)) |
1158 | 1133 | (apply cost-proc (map rec args))])] |
|
0 commit comments