|
1 | 1 | #lang racket |
2 | 2 |
|
| 3 | +(require math/bigfloat) |
3 | 4 | (require "../syntax/platform.rkt" |
4 | 5 | "../syntax/syntax.rkt" |
5 | 6 | "../syntax/types.rkt" |
|
17 | 18 |
|
18 | 19 | (provide find-preprocessing |
19 | 20 | preprocess-pcontext |
20 | | - remove-unnecessary-preprocessing) |
| 21 | + remove-unnecessary-preprocessing |
| 22 | + compile-preprocessing) |
21 | 23 |
|
22 | | -(define (has-fabs-neg-impls? repr) |
23 | | - (and (get-fpcore-impl '- (repr->prop repr) (list repr)) |
24 | | - (get-fpcore-impl 'fabs (repr->prop repr) (list repr)))) |
| 24 | +(define (has-fabs-impl? repr) |
| 25 | + (get-fpcore-impl 'fabs (repr->prop repr) (list repr))) |
| 26 | + |
| 27 | +(define (has-fmin-fmax-impl? repr) |
| 28 | + (and (get-fpcore-impl 'fmin (repr->prop repr) (list repr repr)) |
| 29 | + (get-fpcore-impl 'fmax (repr->prop repr) (list repr repr)))) |
25 | 30 |
|
26 | 31 | (define (has-copysign-impl? repr) |
27 | | - (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr))) |
| 32 | + (and (get-fpcore-impl '* (repr->prop repr) (list repr repr)) |
| 33 | + (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr)))) |
28 | 34 |
|
29 | 35 | ;; The even identities: f(x) = f(-x) |
30 | 36 | ;; Requires `neg` and `fabs` operator implementations. |
31 | 37 | (define (make-even-identities spec ctx) |
32 | 38 | (for/list ([var (in-list (context-vars ctx))] |
33 | 39 | [repr (in-list (context-var-reprs ctx))] |
34 | | - #:when (has-fabs-neg-impls? repr)) |
| 40 | + #:when (has-fabs-impl? repr)) |
35 | 41 | (cons `(abs ,var) (replace-expression spec var `(neg ,var))))) |
36 | 42 |
|
37 | 43 | ;; The odd identities: f(x) = -f(-x) |
38 | 44 | ;; Requires `neg` and `fabs` operator implementations. |
39 | 45 | (define (make-odd-identities spec ctx) |
40 | 46 | (for/list ([var (in-list (context-vars ctx))] |
41 | 47 | [repr (in-list (context-var-reprs ctx))] |
42 | | - #:when (and (has-fabs-neg-impls? repr) (has-copysign-impl? repr))) |
| 48 | + #:when (and (has-fabs-impl? repr) (has-copysign-impl? (context-repr ctx)))) |
43 | 49 | (cons `(negabs ,var) (replace-expression `(neg ,spec) var `(neg ,var))))) |
44 | 50 |
|
45 | | -;; Swap identities: f(a, b) = f(b, a) |
46 | | -(define (make-swap-identities spec ctx) |
| 51 | +;; Sort identities: f(a, b) = f(b, a) |
| 52 | +;; TODO: require both vars have the same repr |
| 53 | +(define (make-sort-identities spec ctx) |
47 | 54 | (define pairs (combinations (context-vars ctx) 2)) |
48 | | - (for/list ([pair (in-list pairs)]) |
| 55 | + (for/list ([pair (in-list pairs)] |
| 56 | + ;; Can only sort same-repr variables |
| 57 | + #:when (equal? (context-lookup ctx (first pair)) (context-lookup ctx (second pair))) |
| 58 | + #:when (has-fmin-fmax-impl? (context-lookup ctx (first pair)))) |
49 | 59 | (match-define (list a b) pair) |
50 | | - (cons `(swap ,a ,b) (replace-vars `((,a . ,b) (,b . ,a)) spec)))) |
| 60 | + (cons `(sort ,a ,b) (replace-vars `((,a . ,b) (,b . ,a)) spec)))) |
51 | 61 |
|
52 | 62 | ;; See https://pavpanchekha.com/blog/symmetric-expressions.html |
53 | 63 | (define (find-preprocessing expr ctx) |
|
56 | 66 | ;; identities |
57 | 67 | (define even-identities (make-even-identities spec ctx)) |
58 | 68 | (define odd-identities (make-odd-identities spec ctx)) |
59 | | - (define swap-identities (make-swap-identities spec ctx)) |
60 | | - (define identities (append even-identities odd-identities swap-identities)) |
| 69 | + (define sort-identities (make-sort-identities spec ctx)) |
| 70 | + (define identities (append even-identities odd-identities sort-identities)) |
61 | 71 |
|
62 | 72 | ;; make egg runner |
63 | 73 | (define rules (*sound-rules*)) |
|
69 | 79 | (make-list (vector-length (batch-roots batch)) (context-repr ctx)) |
70 | 80 | `((,rules . ((node . ,(*node-limit*))))))) |
71 | 81 |
|
72 | | - ;; TODO : FIGURE HOW TO IMPLEMENT PREPROCESS |
73 | | - |
74 | 82 | ;; collect equalities |
75 | 83 | (define abs-instrs |
76 | 84 | (for/list ([(ident spec*) (in-dict even-identities)] |
|
82 | 90 | #:when (egraph-equal? runner spec spec*)) |
83 | 91 | ident)) |
84 | 92 |
|
85 | | - (define swaps |
86 | | - (for/list ([(ident spec*) (in-dict swap-identities)] |
87 | | - #:when (egraph-equal? runner spec spec*)) |
88 | | - (match-define (list 'swap a b) ident) |
89 | | - (list a b))) |
90 | | - (define components (connected-components (context-vars ctx) swaps)) |
91 | 93 | (define sort-instrs |
92 | | - (for/list ([component (in-list components)] |
93 | | - #:when (> (length component) 1)) |
94 | | - (cons 'sort component))) |
95 | | - |
| 94 | + (for/list ([(ident spec*) (in-dict sort-identities)] |
| 95 | + #:when (egraph-equal? runner spec spec*)) |
| 96 | + ident)) |
96 | 97 | (append abs-instrs negabs-instrs sort-instrs)) |
97 | 98 |
|
98 | | -(define (connected-components variables swaps) |
99 | | - (define components (disjoint-set (length variables))) |
100 | | - (for ([swap (in-list swaps)]) |
101 | | - (match-define (list a b) swap) |
102 | | - (disjoint-set-union! components |
103 | | - (disjoint-set-find! components (index-of variables a)) |
104 | | - (disjoint-set-find! components (index-of variables b)))) |
105 | | - (group-by (compose (curry disjoint-set-find! components) (curry index-of variables)) variables)) |
106 | | - |
107 | 99 | (define (preprocess-pcontext context pcontext preprocessing) |
108 | 100 | (define preprocess |
109 | 101 | (apply compose |
|
130 | 122 | (define variables (context-vars context)) |
131 | 123 | (define sort* (curryr sort (curryr </total (context-repr context)))) |
132 | 124 | (match instruction |
133 | | - [(list 'sort component ...) |
134 | | - (define indices (indexes-where variables (curryr member component))) |
| 125 | + [(list 'sort a b) |
| 126 | + (define indices (indexes-where variables (curry set-member? (list a b)))) |
| 127 | + (define repr (context-lookup context a)) |
135 | 128 | (lambda (x y) |
136 | 129 | (define subsequence (map (curry vector-ref x) indices)) |
137 | | - (define sorted (sort* subsequence)) |
| 130 | + (define sorted (sort subsequence (curryr </total repr))) |
138 | 131 | (values (vector-set* x indices sorted) y))] |
139 | 132 | [(list 'abs variable) |
140 | 133 | (define index (index-of variables variable)) |
141 | 134 | (define var-repr (context-lookup context variable)) |
142 | | - (define abs-proc (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl)) |
143 | | - (lambda (x y) (values (vector-update x index abs-proc) y))] |
| 135 | + (define fabs (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl)) |
| 136 | + (lambda (x y) (values (vector-update x index fabs) y))] |
144 | 137 | [(list 'negabs variable) |
145 | 138 | (define index (index-of variables variable)) |
146 | 139 | (define var-repr (context-lookup context variable)) |
147 | | - (define neg-var (impl-info (get-fpcore-impl '- (repr->prop var-repr) (list var-repr)) 'fl)) |
148 | | - |
149 | 140 | (define repr (context-repr context)) |
150 | | - (define neg-expr (impl-info (get-fpcore-impl '- (repr->prop repr) (list repr)) 'fl)) |
151 | | - |
| 141 | + (define fabs (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl)) |
| 142 | + (define mul (impl-info (get-fpcore-impl '* (repr->prop repr) (list repr repr)) 'fl)) |
| 143 | + (define copysign (impl-info (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr)) 'fl)) |
| 144 | + (define repr1 ((representation-bf->repr repr) 1.bf)) |
152 | 145 | (lambda (x y) |
153 | | - ;; Negation is involutive, i.e. it is its own inverse, so t^1(y') = -y' |
154 | | - (if (negative? (repr->real (vector-ref x index) (context-repr context))) |
155 | | - (values (vector-update x index neg-var) (neg-expr y)) |
156 | | - (values x y)))])) |
| 146 | + (values (vector-update x index fabs) (mul (copysign repr1 (vector-ref x index)) y)))])) |
157 | 147 |
|
158 | 148 | ; until fixed point, iterate through preprocessing attempting to drop preprocessing with no effect on error |
159 | 149 | (define (remove-unnecessary-preprocessing expression |
|
182 | 172 | (define pcontext2 (preprocess-pcontext context pcontext preprocessing2)) |
183 | 173 | (<= (errors-score (errors expression pcontext1 context)) |
184 | 174 | (errors-score (errors expression pcontext2 context)))) |
| 175 | + |
| 176 | +(define (compile-preprocessing expression context preprocessing) |
| 177 | + (match preprocessing |
| 178 | + ; Not handled yet |
| 179 | + [(list 'sort a b) |
| 180 | + (define repr (context-lookup context a)) |
| 181 | + (define fmin (get-fpcore-impl 'fmin (repr->prop repr) (list repr repr))) |
| 182 | + (define fmax (get-fpcore-impl 'fmax (repr->prop repr) (list repr repr))) |
| 183 | + (replace-vars (list (cons a `(,fmin ,a ,b)) (cons b `(,fmax ,a ,b))) expression)] |
| 184 | + [(list 'abs var) |
| 185 | + (define repr (context-lookup context var)) |
| 186 | + (define fabs (get-fpcore-impl 'fabs (repr->prop repr) (list repr))) |
| 187 | + (define replacement `(,fabs ,var)) |
| 188 | + (replace-expression expression var replacement)] |
| 189 | + [(list 'negabs var) |
| 190 | + (define repr (context-lookup context var)) |
| 191 | + (define fabs (get-fpcore-impl 'fabs (repr->prop repr) (list repr))) |
| 192 | + (define replacement `(,fabs ,var)) |
| 193 | + (define mul (get-fpcore-impl '* (repr->prop repr) (list repr repr))) |
| 194 | + (define copysign (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr))) |
| 195 | + `(,mul (,copysign ,(literal 1 (representation-name repr)) ,var) |
| 196 | + ,(replace-expression expression var replacement))])) |
0 commit comments