Skip to content

Commit 24d0a0e

Browse files
authored
Merge pull request #1238 from herbie-fp/compile-preprocessing
Compile away preprocessing
2 parents 12353b3 + 88a2994 commit 24d0a0e

File tree

5 files changed

+86
-187
lines changed

5 files changed

+86
-187
lines changed

src/core/mainloop.rkt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,17 @@
5050
(define alternatives (extract!))
5151
(timeline-event! 'preprocess)
5252
(for/list ([altn alternatives])
53-
(define expr (alt-expr altn))
54-
(define preprocessing (alt-preprocessing altn))
55-
(alt-add-preprocessing altn
56-
(remove-unnecessary-preprocessing expr context pcontext preprocessing))))
53+
(apply-preprocessing altn context pcontext)))
54+
55+
(define (apply-preprocessing altn context pcontext)
56+
(define expr (alt-expr altn))
57+
(define initial-preprocessing (alt-preprocessing altn))
58+
(define useful-preprocessing
59+
(remove-unnecessary-preprocessing expr context pcontext initial-preprocessing))
60+
(define expr*
61+
(for/fold ([expr expr]) ([preprocessing (in-list (reverse useful-preprocessing))])
62+
(compile-preprocessing expr context preprocessing)))
63+
(alt expr* 'add-preprocessing (list altn) '()))
5764

5865
(define (extract!)
5966
(timeline-push-alts! '())

src/core/preprocess.rkt

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#lang racket
22

3+
(require math/bigfloat)
34
(require "../syntax/platform.rkt"
45
"../syntax/syntax.rkt"
56
"../syntax/types.rkt"
@@ -17,37 +18,46 @@
1718

1819
(provide find-preprocessing
1920
preprocess-pcontext
20-
remove-unnecessary-preprocessing)
21+
remove-unnecessary-preprocessing
22+
compile-preprocessing)
2123

22-
(define (has-fabs-neg-impls? repr)
23-
(and (get-fpcore-impl '- (repr->prop repr) (list repr))
24-
(get-fpcore-impl 'fabs (repr->prop repr) (list repr))))
24+
(define (has-fabs-impl? repr)
25+
(get-fpcore-impl 'fabs (repr->prop repr) (list repr)))
26+
27+
(define (has-fmin-fmax-impl? repr)
28+
(and (get-fpcore-impl 'fmin (repr->prop repr) (list repr repr))
29+
(get-fpcore-impl 'fmax (repr->prop repr) (list repr repr))))
2530

2631
(define (has-copysign-impl? repr)
27-
(get-fpcore-impl 'copysign (repr->prop repr) (list repr repr)))
32+
(and (get-fpcore-impl '* (repr->prop repr) (list repr repr))
33+
(get-fpcore-impl 'copysign (repr->prop repr) (list repr repr))))
2834

2935
;; The even identities: f(x) = f(-x)
3036
;; Requires `neg` and `fabs` operator implementations.
3137
(define (make-even-identities spec ctx)
3238
(for/list ([var (in-list (context-vars ctx))]
3339
[repr (in-list (context-var-reprs ctx))]
34-
#:when (has-fabs-neg-impls? repr))
40+
#:when (has-fabs-impl? repr))
3541
(cons `(abs ,var) (replace-expression spec var `(neg ,var)))))
3642

3743
;; The odd identities: f(x) = -f(-x)
3844
;; Requires `neg` and `fabs` operator implementations.
3945
(define (make-odd-identities spec ctx)
4046
(for/list ([var (in-list (context-vars ctx))]
4147
[repr (in-list (context-var-reprs ctx))]
42-
#:when (and (has-fabs-neg-impls? repr) (has-copysign-impl? repr)))
48+
#:when (and (has-fabs-impl? repr) (has-copysign-impl? (context-repr ctx))))
4349
(cons `(negabs ,var) (replace-expression `(neg ,spec) var `(neg ,var)))))
4450

45-
;; Swap identities: f(a, b) = f(b, a)
46-
(define (make-swap-identities spec ctx)
51+
;; Sort identities: f(a, b) = f(b, a)
52+
;; TODO: require both vars have the same repr
53+
(define (make-sort-identities spec ctx)
4754
(define pairs (combinations (context-vars ctx) 2))
48-
(for/list ([pair (in-list pairs)])
55+
(for/list ([pair (in-list pairs)]
56+
;; Can only sort same-repr variables
57+
#:when (equal? (context-lookup ctx (first pair)) (context-lookup ctx (second pair)))
58+
#:when (has-fmin-fmax-impl? (context-lookup ctx (first pair))))
4959
(match-define (list a b) pair)
50-
(cons `(swap ,a ,b) (replace-vars `((,a . ,b) (,b . ,a)) spec))))
60+
(cons `(sort ,a ,b) (replace-vars `((,a . ,b) (,b . ,a)) spec))))
5161

5262
;; See https://pavpanchekha.com/blog/symmetric-expressions.html
5363
(define (find-preprocessing expr ctx)
@@ -56,8 +66,8 @@
5666
;; identities
5767
(define even-identities (make-even-identities spec ctx))
5868
(define odd-identities (make-odd-identities spec ctx))
59-
(define swap-identities (make-swap-identities spec ctx))
60-
(define identities (append even-identities odd-identities swap-identities))
69+
(define sort-identities (make-sort-identities spec ctx))
70+
(define identities (append even-identities odd-identities sort-identities))
6171

6272
;; make egg runner
6373
(define rules (*sound-rules*))
@@ -69,8 +79,6 @@
6979
(make-list (vector-length (batch-roots batch)) (context-repr ctx))
7080
`((,rules . ((node . ,(*node-limit*)))))))
7181

72-
;; TODO : FIGURE HOW TO IMPLEMENT PREPROCESS
73-
7482
;; collect equalities
7583
(define abs-instrs
7684
(for/list ([(ident spec*) (in-dict even-identities)]
@@ -82,28 +90,12 @@
8290
#:when (egraph-equal? runner spec spec*))
8391
ident))
8492

85-
(define swaps
86-
(for/list ([(ident spec*) (in-dict swap-identities)]
87-
#:when (egraph-equal? runner spec spec*))
88-
(match-define (list 'swap a b) ident)
89-
(list a b)))
90-
(define components (connected-components (context-vars ctx) swaps))
9193
(define sort-instrs
92-
(for/list ([component (in-list components)]
93-
#:when (> (length component) 1))
94-
(cons 'sort component)))
95-
94+
(for/list ([(ident spec*) (in-dict sort-identities)]
95+
#:when (egraph-equal? runner spec spec*))
96+
ident))
9697
(append abs-instrs negabs-instrs sort-instrs))
9798

98-
(define (connected-components variables swaps)
99-
(define components (disjoint-set (length variables)))
100-
(for ([swap (in-list swaps)])
101-
(match-define (list a b) swap)
102-
(disjoint-set-union! components
103-
(disjoint-set-find! components (index-of variables a))
104-
(disjoint-set-find! components (index-of variables b))))
105-
(group-by (compose (curry disjoint-set-find! components) (curry index-of variables)) variables))
106-
10799
(define (preprocess-pcontext context pcontext preprocessing)
108100
(define preprocess
109101
(apply compose
@@ -130,30 +122,28 @@
130122
(define variables (context-vars context))
131123
(define sort* (curryr sort (curryr </total (context-repr context))))
132124
(match instruction
133-
[(list 'sort component ...)
134-
(define indices (indexes-where variables (curryr member component)))
125+
[(list 'sort a b)
126+
(define indices (indexes-where variables (curry set-member? (list a b))))
127+
(define repr (context-lookup context a))
135128
(lambda (x y)
136129
(define subsequence (map (curry vector-ref x) indices))
137-
(define sorted (sort* subsequence))
130+
(define sorted (sort subsequence (curryr </total repr)))
138131
(values (vector-set* x indices sorted) y))]
139132
[(list 'abs variable)
140133
(define index (index-of variables variable))
141134
(define var-repr (context-lookup context variable))
142-
(define abs-proc (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl))
143-
(lambda (x y) (values (vector-update x index abs-proc) y))]
135+
(define fabs (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl))
136+
(lambda (x y) (values (vector-update x index fabs) y))]
144137
[(list 'negabs variable)
145138
(define index (index-of variables variable))
146139
(define var-repr (context-lookup context variable))
147-
(define neg-var (impl-info (get-fpcore-impl '- (repr->prop var-repr) (list var-repr)) 'fl))
148-
149140
(define repr (context-repr context))
150-
(define neg-expr (impl-info (get-fpcore-impl '- (repr->prop repr) (list repr)) 'fl))
151-
141+
(define fabs (impl-info (get-fpcore-impl 'fabs (repr->prop var-repr) (list var-repr)) 'fl))
142+
(define mul (impl-info (get-fpcore-impl '* (repr->prop repr) (list repr repr)) 'fl))
143+
(define copysign (impl-info (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr)) 'fl))
144+
(define repr1 ((representation-bf->repr repr) 1.bf))
152145
(lambda (x y)
153-
;; Negation is involutive, i.e. it is its own inverse, so t^1(y') = -y'
154-
(if (negative? (repr->real (vector-ref x index) (context-repr context)))
155-
(values (vector-update x index neg-var) (neg-expr y))
156-
(values x y)))]))
146+
(values (vector-update x index fabs) (mul (copysign repr1 (vector-ref x index)) y)))]))
157147

158148
; until fixed point, iterate through preprocessing attempting to drop preprocessing with no effect on error
159149
(define (remove-unnecessary-preprocessing expression
@@ -182,3 +172,25 @@
182172
(define pcontext2 (preprocess-pcontext context pcontext preprocessing2))
183173
(<= (errors-score (errors expression pcontext1 context))
184174
(errors-score (errors expression pcontext2 context))))
175+
176+
(define (compile-preprocessing expression context preprocessing)
177+
(match preprocessing
178+
; Not handled yet
179+
[(list 'sort a b)
180+
(define repr (context-lookup context a))
181+
(define fmin (get-fpcore-impl 'fmin (repr->prop repr) (list repr repr)))
182+
(define fmax (get-fpcore-impl 'fmax (repr->prop repr) (list repr repr)))
183+
(replace-vars (list (cons a `(,fmin ,a ,b)) (cons b `(,fmax ,a ,b))) expression)]
184+
[(list 'abs var)
185+
(define repr (context-lookup context var))
186+
(define fabs (get-fpcore-impl 'fabs (repr->prop repr) (list repr)))
187+
(define replacement `(,fabs ,var))
188+
(replace-expression expression var replacement)]
189+
[(list 'negabs var)
190+
(define repr (context-lookup context var))
191+
(define fabs (get-fpcore-impl 'fabs (repr->prop repr) (list repr)))
192+
(define replacement `(,fabs ,var))
193+
(define mul (get-fpcore-impl '* (repr->prop repr) (list repr repr)))
194+
(define copysign (get-fpcore-impl 'copysign (repr->prop repr) (list repr repr)))
195+
`(,mul (,copysign ,(literal 1 (representation-name repr)) ,var)
196+
,(replace-expression expression var replacement))]))

src/core/rules.rkt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@
241241
[fabs-cbrt (fabs (/ (cbrt a) a)) (/ (cbrt a) a)]
242242
[fabs-cbrt-rev (/ (cbrt a) a) (fabs (/ (cbrt a) a))])
243243

244+
; Copysign
245+
(define-rules arithmetic
246+
[copysign-neg (copysign a (neg b)) (neg (copysign a b))]
247+
[neg-copysign (neg (copysign a b)) (copysign a (neg b))]
248+
[copysign-other-neg (copysign (neg a) b) (copysign a b)]
249+
[copysign-fabs (copysign a (fabs b)) (fabs a)]
250+
[copysign-other-fabs (copysign (fabs a) b) (copysign a b)]
251+
[fabs-copysign (fabs (copysign a b)) (fabs a)])
252+
244253
; Square root
245254
(define-rules arithmetic
246255
[sqrt-pow2 (pow (sqrt x) y) (pow x (/ y 2))]
@@ -284,6 +293,11 @@
284293
[cbrt-div-cbrt (/ (cbrt x) (fabs (cbrt x))) (copysign 1 x)]
285294
[cbrt-div-cbrt2 (/ (fabs (cbrt x)) (cbrt x)) (copysign 1 x)])
286295

296+
; Min and max
297+
(define-rules arithmetic
298+
[fmin-swap (fmin a b) (fmin b a)]
299+
[fmax-swap (fmax a b) (fmax b a)])
300+
287301
; Exponentials
288302
(define-rules exponents
289303
[add-log-exp x (log (exp x))]

src/reports/common.rkt

Lines changed: 2 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -155,117 +155,15 @@
155155
(core->tex prog* #:loc (and loc (cons 2 loc)) #:color "blue")
156156
"ERROR"))
157157

158-
(define (combine-fpcore-instruction i e c)
159-
(match i
160-
[(list 'abs x)
161-
(define x* (string->symbol (string-append (symbol->string x) "_m")))
162-
(define e* (replace-expression e x x*))
163-
(define p (index-of (context-vars c) x))
164-
(define c* (struct-copy context c [vars (list-set (context-vars c) p x*)]))
165-
(cons e* c*)]
166-
[(list 'negabs x)
167-
(define x-string (symbol->string x))
168-
(define x-sign (string->symbol (string-append x-string "_s")))
169-
(define x* (string->symbol (string-append x-string "_m")))
170-
(define p (index-of (context-vars c) x))
171-
(define r (list-ref (context-var-reprs c) p))
172-
(define c* (struct-copy context c [vars (list-set (context-vars c) p x*)]))
173-
(define c** (context-extend c* x-sign r))
174-
(define *-impl (get-fpcore-impl '* (repr->prop (context-repr c)) (list r (context-repr c))))
175-
(define e* (list *-impl x-sign (replace-expression e x x*)))
176-
(cons e* c**)]
177-
[_ (cons e c)]))
178-
179-
(define (format-prelude-instruction instruction ctx ctx* language converter)
180-
(define (converter* e c)
181-
(define fpcore (program->fpcore e c))
182-
(define output (converter fpcore "code"))
183-
(define lines (string-split output "\n"))
184-
(match language
185-
["FPCore" (pretty-format e #:mode 'display)]
186-
["Fortran" (string-trim (third lines) #px"\\s+code\\s+=\\s+")]
187-
["MATLAB" (string-trim (second lines) #px"\\s+tmp\\s+=\\s+")]
188-
["Wolfram" (string-trim (first lines) #px".*:=\\s+")]
189-
["TeX" output]
190-
[_ (string-trim (second lines) #px"\\s+return\\s+")]))
191-
(match instruction
192-
[(list 'abs x)
193-
(define x* (string->symbol (string-append (symbol->string x) "_m")))
194-
(define r (list-ref (context-var-reprs ctx) (index-of (context-vars ctx) x)))
195-
(define fabs-impl (get-fpcore-impl 'fabs (repr->prop r) (list r)))
196-
(define e (list fabs-impl x))
197-
(define c (context (list x) r r))
198-
(list (format "~a = ~a" x* (converter* e c)))]
199-
[(list 'negabs x)
200-
; TODO: why are x* and x-sign unused?
201-
(define x* (string->symbol (format "~a_m" x)))
202-
(define r (context-lookup ctx x))
203-
(define fabs-impl (get-fpcore-impl 'fabs (repr->prop r) (list r)))
204-
(define copysign-impl (get-fpcore-impl 'copysign (repr->prop r) (list r r)))
205-
(define e* (list fabs-impl x))
206-
(define x-sign (string->symbol (format "~a_s" x)))
207-
(define e-sign (list copysign-impl (literal 1 (representation-name r)) x))
208-
(define c (context (list x) r r))
209-
(list (format "~a = ~a" (format "~a\\_m" x) (converter* e* c))
210-
(format "~a = ~a" (format "~a\\_s" x) (converter* e-sign c)))]
211-
[(list 'sort vs ...)
212-
(define vs (context-vars ctx))
213-
(define vs* (context-vars ctx*))
214-
;; We added some sign-* variables to the front of the variable
215-
;; list in `ctx*`, we only want the originals here
216-
(list (format-sort-instruction (take-right vs* (length vs)) language))]))
217-
218-
(define (format-sort-instruction vs l)
219-
(match l
220-
["C" (format "assert(~a);" (format-less-than-condition vs))]
221-
["Java" (format "assert ~a;" (format-less-than-condition vs))]
222-
["Python"
223-
(define comma-joined (comma-join vs))
224-
(format "[~a] = sort([~a])" comma-joined comma-joined)]
225-
["Julia"
226-
(define comma-joined (comma-join vs))
227-
(format "~a = sort([~a])" comma-joined comma-joined)]
228-
["MATLAB"
229-
(define comma-joined (comma-join vs))
230-
(format "~a = num2cell(sort([~a])){:}" comma-joined comma-joined)]
231-
["TeX"
232-
(define comma-joined (comma-join vs))
233-
(format "[~a] = \\mathsf{sort}([~a])\\\\" comma-joined comma-joined)]
234-
[_
235-
(match vs
236-
[(list x y) (format sort-note (format "~a and ~a" x y))]
237-
[(list vs ...)
238-
(format sort-note
239-
(string-join (map ~a vs)
240-
", "
241-
;; "Lil Jon, he always tells the truth"
242-
#:before-last ", and "))])]))
243-
244-
(define (format-less-than-condition variables)
245-
(string-join (for/list ([a (in-list variables)]
246-
[b (in-list (cdr variables))])
247-
(format "~a < ~a" a b))
248-
" && "))
249-
250-
(define (comma-join vs)
251-
(string-join (map ~a vs) ", "))
252-
253-
(define sort-note "NOTE: ~a should be sorted in increasing order before calling this function.")
254-
255158
(define (render-program expr
256159
ctx
257160
#:ident [identifier #f]
258161
#:pre [precondition '(TRUE)]
259162
#:instructions [instructions empty])
260163
(define output-repr (context-repr ctx))
261-
(match-define (cons expr* ctx*)
262-
(foldl (match-lambda*
263-
[(list i (cons e c)) (combine-fpcore-instruction i e c)])
264-
(cons expr ctx)
265-
instructions))
266164
(define out-prog
267165
(parameterize ([*expr-cse-able?* at-least-two-ops?])
268-
(core-cse (program->fpcore expr* ctx* #:ident identifier))))
166+
(core-cse (program->fpcore expr ctx #:ident identifier))))
269167

270168
(define output-prec (representation-name output-repr))
271169
(define out-prog* (fpcore-add-props out-prog (list ':precision output-prec)))
@@ -281,19 +179,7 @@
281179
(symbol->string identifier)
282180
"code"))
283181
(define out (converter out-prog* name))
284-
(define prelude-lines
285-
(string-join
286-
(append-map (lambda (instruction)
287-
(format-prelude-instruction instruction ctx ctx* lang converter))
288-
instructions)
289-
(if (equal? lang "TeX") "\\\\\n" "\n")
290-
#:after-last "\n"))
291-
(sow (cons lang
292-
((if (equal? lang "TeX")
293-
(curry format "\\begin{array}{l}\n~a\\\\\n~a\\end{array}\n")
294-
string-append)
295-
prelude-lines
296-
out)))))))
182+
(sow (cons lang out))))))
297183

298184
(define math-out (dict-ref versions "TeX" ""))
299185

0 commit comments

Comments
 (0)