Skip to content

Commit b0629e1

Browse files
committed
recurse on spec everywhere
1 parent 1a2a641 commit b0629e1

File tree

5 files changed

+49
-36
lines changed

5 files changed

+49
-36
lines changed

src/core/batch.rkt

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,23 @@
1010
(struct-out batchref) ; temporarily for patch.rkt
1111
(struct-out mutable-batch) ; temporarily for patch.rkt
1212
batch-length ; Batch -> Integer
13-
batch-ref ; Batch -> Index -> Expr
13+
batch-ref ; Batch -> Idx -> Expr
1414
deref ; Batchref -> Expr
1515
batch-replace ; Batch -> (Expr<Batchref> -> Expr<Batchref>) -> Batch
1616
egg-nodes->batch ; Nodes -> Spec-maps -> Batch -> (Listof Batchref)
1717
batchref->expr ; Batchref -> Expr
1818
batch-remove-zombie ; Batch -> *(Vectorof Root) -> Batch
19-
mutable-batch-add-expr! ; Mutable-batch -> Root
19+
mutable-batch-munge! ; Mutable-batch -> Root
2020
mutable-batch->batch ; Mutable-batch -> Batch
21+
make-mutable-batch ; Mutable-batch
22+
mutable-batch-devour-batchref! ; Mutable-batch -> Batchref -> Idx
2123
batch->mutable-batch ; Batch -> Mutable-batch
22-
batch-push!) ; Mutable-batch -> Expr -> Index
24+
batch-push!) ; Mutable-batch -> Expr -> Idx
2325

2426
;; This function defines the recursive structure of expressions
2527
(define (expr-recurse expr f)
2628
(match expr
27-
[(approx spec impl) (approx spec (f impl))]
29+
[(approx spec impl) (approx (f spec) (f impl))]
2830
[(list op args ...) (cons op (map f args))]
2931
[_ expr]))
3032

@@ -89,11 +91,18 @@
8991
(timeline-push! 'compiler size (batch-length final)))
9092
final)
9193

92-
(define (mutable-batch-add-expr! b expr)
94+
(define (mutable-batch-munge! b expr)
9395
(define (munge prog)
9496
(batch-push! b (expr-recurse prog munge)))
9597
(munge expr))
9698

99+
(define (mutable-batch-devour-batchref! b ref)
100+
(match-define (batchref b* idx) ref)
101+
(define nodes* (batch-nodes b*))
102+
(define (munge idx)
103+
(batch-push! b (expr-recurse (vector-ref nodes* idx) munge)))
104+
(munge idx))
105+
97106
(define (batch->progs b [roots (batch-roots b)])
98107
(define exprs (make-vector (batch-length b)))
99108
(for ([node (in-vector (batch-nodes b))]
@@ -125,7 +134,8 @@
125134
; The function removes any zombie nodes from batch with respect to the roots
126135
; Time complexity: O(|R| + |N|), where |R| - number of roots, |N| - length of nodes
127136
; Space complexity: O(|N| + |N*| + |R|), where |N*| is a length of nodes without zombie nodes
128-
(define (batch-remove-zombie input-batch [roots (batch-roots input-batch)])
137+
; The flag keep-vars is used in compiler.rkt when vars should be preserved no matter what
138+
(define (batch-remove-zombie input-batch [roots (batch-roots input-batch)] #:keep-vars [keep-vars #f])
129139
(define nodes (batch-nodes input-batch))
130140
(define nodes-length (batch-length input-batch))
131141

@@ -142,6 +152,10 @@
142152
(vector-ref mappings idx))
143153

144154
(define out (make-mutable-batch))
155+
(when keep-vars
156+
(for ([var (in-list (batch-vars input-batch))])
157+
(batch-push! out var)))
158+
145159
(for ([node (in-vector nodes)]
146160
[zmb (in-vector zombie-mask)]
147161
[n (in-naturals)]
@@ -155,7 +169,7 @@
155169
(define (unmunge reg)
156170
(define node (vector-ref (batch-nodes batch) reg))
157171
(match node
158-
[(approx spec impl) (approx spec (unmunge impl))]
172+
[(approx spec impl) (approx (unmunge spec) (unmunge impl))]
159173
[(list op regs ...) (cons op (map unmunge regs))]
160174
[_ node]))
161175
(unmunge reg))
@@ -167,7 +181,6 @@
167181

168182
(define (egg-nodes->batch egg-nodes id->spec input-batch rename-dict)
169183
(define out (batch->mutable-batch input-batch))
170-
171184
; This fuction here is only because of cycles in loads:( Can not be imported from egg-herbie.rkt
172185
(define (egg-parsed->expr expr rename-dict type)
173186
(let loop ([expr expr]
@@ -202,7 +215,8 @@
202215
(error 'regraph-extract-variants "no initial approx node in eclass"))
203216
(define spec-type (if (representation? type) (representation-type type) type))
204217
(define final-spec (egg-parsed->expr spec* rename-dict spec-type))
205-
(approx final-spec (add-enode (eggref impl) type))]
218+
(define final-spec-idx (mutable-batch-munge! out final-spec))
219+
(approx final-spec-idx (add-enode (eggref impl) type))]
206220
[(list 'if cond ift iff)
207221
(if (representation? type)
208222
(list 'if
@@ -272,12 +286,14 @@
272286
#:roots (vector 5)))
273287
(check-equal? (vector 0 1/2 '(+ 0 1))
274288
(zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0)) #:roots (vector 2)))
275-
(check-equal? (vector 0 (approx '(exp 2) 0))
276-
(zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0) (approx '(exp 2) 0))
277-
#:roots (vector 4)))
278-
(check-equal? (vector 2 1/2 (approx '(* x x) 0) '(pow 1 2))
279-
(zombie-test #:nodes (vector 2 1/2 '(sqrt 0) '(cbrt 0) (approx '(* x x) 0) '(pow 1 4))
289+
(check-equal? (vector 0 1/2 '(exp 1) (approx 2 0))
290+
(zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0) '(exp 1) (approx 4 0))
280291
#:roots (vector 5)))
281-
(check-equal? (vector 2 1/2 '(sqrt 0) (approx '(* x x) 0) '(pow 1 3))
282-
(zombie-test #:nodes (vector 2 1/2 '(sqrt 0) '(cbrt 0) (approx '(* x x) 0) '(pow 1 4))
283-
#:roots (vector 5 2))))
292+
(check-equal? (vector 'x 2 1/2 '(* 0 0) (approx 3 1) '(pow 2 4))
293+
(zombie-test #:nodes
294+
(vector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6))
295+
#:roots (vector 7)))
296+
(check-equal? (vector 'x 2 1/2 '(sqrt 1) '(* 0 0) (approx 4 1) '(pow 2 5))
297+
(zombie-test #:nodes
298+
(vector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6))
299+
#:roots (vector 7 3))))

src/core/compiler.rkt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@
6161
;; Requires some hooks to complete the translation.
6262
(define (make-compiler exprs vars)
6363
(define num-vars (length vars))
64-
(define batch (batch-remove-approx (progs->batch exprs #:timeline-push #t #:vars vars)))
64+
65+
; Here we need to keep vars even though no roots refer to the vars
66+
(define batch
67+
(batch-remove-zombie (batch-remove-approx (progs->batch exprs #:timeline-push #t #:vars vars))
68+
#:keep-vars #t))
6569

6670
(define instructions
6771
(for/vector #:length (- (batch-length batch) num-vars)

src/core/egg-herbie.rkt

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,9 @@
124124
[(? symbol? x) (egraph_add_node ptr (symbol->string x) 0-vec root?)]
125125
[(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)]))
126126

127-
; The function recurses on spec
128-
(define (batch-parse-approx batch)
129-
(batch-replace batch
130-
(lambda (node)
131-
(match node
132-
[(approx spec impl) (list '$approx spec impl)]
133-
[_ node]))))
134-
135127
(set-batch-roots! batch roots) ; make sure that we work with the right roots
136128
; the algorithm may crash if batch-length is zero
137-
(define insert-batch
138-
(if (zero? (batch-length batch)) batch (batch-remove-zombie (batch-parse-approx batch))))
129+
(define insert-batch (if (zero? (batch-length batch)) batch (batch-remove-zombie batch)))
139130

140131
(define mappings (build-vector (batch-length insert-batch) values))
141132
(define (remap x)
@@ -153,7 +144,7 @@
153144
[(literal v _) v]
154145
[(? number?) node]
155146
[(? symbol?) (normalize-var node)]
156-
[(list '$approx spec impl)
147+
[(approx spec impl)
157148
(hash-ref! id->spec
158149
(remap spec)
159150
(lambda ()

src/core/mainloop.rkt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,11 @@
221221

222222
;; Converts a patch to full alt with valid history
223223
(define (reconstruct! alts)
224-
225-
;; extracts the base expression of a patch
224+
(define reconstruct-batch (make-mutable-batch))
225+
;; extracts the base expressions of a patch as a batchref
226226
(define (get-starting-expr altn)
227227
(match* ((alt-event altn) (alt-prevs altn))
228-
[((list 'patch expr _) _) expr]
228+
[((list 'patch expr _) _) expr] ; here original Expr can be pulled as well
229229
[(_ (list prev)) (get-starting-expr prev)]
230230
[(_ _) (error 'get-starting-spec "unexpected: ~a" altn)]))
231231

src/core/patch.rkt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@
6464
(define prev (car (alt-prevs altn)))
6565
(for ([expr (in-list simplified)])
6666
(define spec (prog->spec (batchref->expr (alt-expr prev))))
67-
(match-define (batchref b idx) expr)
68-
(define idx* (batch-push! global-batch-mutable (approx spec idx)))
69-
(sow (alt (batchref global-batch idx*) `(simplify ,runner #f #f) (list altn) '()))))))
67+
(define idx
68+
(batch-push! global-batch-mutable
69+
(approx (mutable-batch-munge! global-batch-mutable spec)
70+
(batchref-idx expr))))
71+
(sow (alt (batchref global-batch idx) `(simplify ,runner #f #f) (list altn) '()))))))
7072

7173
; Commit changes to global-batch
7274
(set-batch-nodes! global-batch (list->vector (reverse (mutable-batch-nodes global-batch-mutable))))
@@ -113,7 +115,7 @@
113115
(for ([i (in-range (*taylor-order-limit*))])
114116
(define gen (genexpr))
115117
(unless (spec-has-nan? gen)
116-
(define idx (mutable-batch-add-expr! global-batch-mutable gen))
118+
(define idx (mutable-batch-munge! global-batch-mutable gen))
117119
; we create a batchref that doesn't exist yet in global-batch, we update it later
118120
(sow (alt (batchref global-batch idx) `(taylor ,name ,var) (list altn) '())))))
119121
(timeline-stop!))))

0 commit comments

Comments
 (0)