Skip to content

Commit 31e078a

Browse files
authored
Merge pull request #1330 from herbie-fp/no-immutable-batches
No immutable batches
2 parents 2db7c82 + 35ea367 commit 31e078a

File tree

11 files changed

+293
-195
lines changed

11 files changed

+293
-195
lines changed

src/core/alt-table.rkt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828

2929
(define (alt-batch-cost batch repr)
3030
(define node-cost-proc (platform-node-cost-proc (*active-platform*)))
31-
(define costs (make-vector (vector-length (batch-nodes batch)) 0))
32-
(for ([node (in-vector (batch-nodes batch))]
31+
(define costs (make-vector (batch-length batch) 0))
32+
(for ([node (in-batch batch)]
3333
[i (in-naturals)])
3434
(define cost
3535
(match node

src/core/batch.rkt

Lines changed: 88 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,45 @@
22

33
(require "../syntax/syntax.rkt"
44
"../utils/common.rkt"
5-
"../utils/alternative.rkt") ; for unbatchify-alts
5+
"../utils/alternative.rkt" ; for unbatchify-alts
6+
"dvector.rkt")
67

78
(provide progs->batch ; List<Expr> -> Batch
89
batch->progs ; Batch -> ?(or List<Root> Vector<Root>) -> List<Expr>
10+
911
(struct-out batch)
10-
(struct-out batchref)
11-
(struct-out mutable-batch)
12+
make-batch ; Batch
13+
batch-push! ; Batch -> Node -> Idx
14+
batch-munge! ; Batch -> Expr -> Root
15+
batch-copy ; Batch -> Batch
1216
batch-length ; Batch -> Integer
1317
batch-tree-size ; Batch -> Integer
1418
batch-free-vars
15-
batch-ref ; Batch -> Idx -> Expr
16-
deref ; Batchref -> Expr
19+
in-batch ; Batch -> Sequence<Node>
20+
batch-ref ; Batch -> Idx -> Node
21+
batch-pull ; Batch -> Idx -> Expr
1722
batch-replace ; Batch -> (Expr<Batchref> -> Expr<Batchref>) -> Batch
18-
debatchref ; Batchref -> Expr
1923
batch-alive-nodes ; Batch -> ?Vector<Root> -> Vector<Idx>
2024
batch-reconstruct-exprs ; Batch -> Vector<Expr>
2125
batch-remove-zombie ; Batch -> ?Vector<Root> -> Batch
22-
mutable-batch-munge! ; Mutable-batch -> Expr -> Root
23-
make-mutable-batch ; Mutable-batch
24-
batch->mutable-batch ; Batch -> Mutable-batch
25-
batch-copy-mutable-nodes! ; Batch -> Mutable-batch -> Void
26-
mutable-batch-push! ; Mutable-batch -> Node -> Idx
27-
batch-copy
26+
27+
(struct-out batchref)
28+
deref ; Batchref -> Expr
29+
debatchref ; Batchref -> Expr
30+
2831
unbatchify-alts)
2932

3033
;; Batches store these recursive structures, flattened
31-
(struct batch ([nodes #:mutable] [roots #:mutable]))
32-
33-
(struct mutable-batch ([nodes #:mutable] [index #:mutable] cache))
34+
(struct batch ([nodes #:mutable] [index #:mutable] cache [roots #:mutable]))
3435

3536
(struct batchref (batch idx) #:transparent)
3637

38+
(define (make-batch)
39+
(batch (make-dvector) (make-hash) (make-hasheq) (vector)))
40+
41+
(define (in-batch batch [start 0] [end #f] [step 1])
42+
(in-dvector (batch-nodes batch) start end step))
43+
3744
;; This function defines the recursive structure of expressions
3845
(define (expr-recurse expr f)
3946
(match expr
@@ -55,69 +62,58 @@
5562
(map (curry alt-map unmunge) altns))
5663

5764
(define (batch-length b)
58-
(cond
59-
[(batch? b) (vector-length (batch-nodes b))]
60-
[(mutable-batch? b) (hash-count (mutable-batch-index b))]
61-
[else (error 'batch-length "Invalid batch" b)]))
62-
63-
(define (make-mutable-batch)
64-
(mutable-batch '() (make-hash) (make-hasheq)))
65+
(dvector-length (batch-nodes b)))
6566

66-
(define (mutable-batch-push! b term)
67-
(define hashcons (mutable-batch-index b))
67+
(define (batch-push! b term)
68+
(define hashcons (batch-index b))
6869
(hash-ref! hashcons
6970
term
7071
(lambda ()
71-
(define new-idx (hash-count hashcons))
72-
(hash-set! hashcons term new-idx)
73-
(set-mutable-batch-nodes! b (cons term (mutable-batch-nodes b)))
74-
new-idx)))
75-
76-
(define (mutable-batch->batch b roots)
77-
(batch (list->vector (reverse (mutable-batch-nodes b))) roots))
78-
79-
(define (batch->mutable-batch b)
80-
(mutable-batch (reverse (vector->list (batch-nodes b))) (batch-restore-index b) (make-hasheq)))
81-
82-
(define (batch-copy-mutable-nodes! b mb)
83-
(set-batch-nodes! b (list->vector (reverse (mutable-batch-nodes mb)))))
72+
(define idx (hash-count hashcons))
73+
(hash-set! hashcons term idx)
74+
(dvector-add! (batch-nodes b) term)
75+
idx)))
8476

8577
(define (batch-copy b)
86-
(batch (vector-copy (batch-nodes b)) (vector-copy (batch-roots b))))
78+
(batch (dvector-copy (batch-nodes b))
79+
(hash-copy (batch-index b))
80+
(hash-copy (batch-cache b))
81+
(vector-copy (batch-roots b))))
8782

8883
(define (deref x)
8984
(match-define (batchref b idx) x)
90-
(expr-recurse (vector-ref (batch-nodes b) idx) (lambda (ref) (batchref b ref))))
85+
(expr-recurse (batch-ref b idx) (lambda (ref) (batchref b ref))))
9186

9287
(define (debatchref x)
9388
(match-define (batchref b idx) x)
94-
(batch-ref b idx))
89+
(batch-pull b idx))
9590

9691
(define (progs->batch exprs #:vars [vars '()])
97-
(define out (make-mutable-batch))
92+
(define out (make-batch))
9893

9994
(for ([var (in-list vars)])
100-
(mutable-batch-push! out var))
95+
(batch-push! out var))
10196
(define roots
10297
(for/vector #:length (length exprs)
10398
([expr (in-list exprs)])
104-
(mutable-batch-munge! out expr)))
99+
(batch-munge! out expr)))
105100

106-
(mutable-batch->batch out roots))
101+
(set-batch-roots! out roots)
102+
out)
107103

108104
(define (batch-tree-size b)
109-
(define len (vector-length (batch-nodes b)))
105+
(define len (batch-length b))
110106
(define counts (make-vector len 0))
111107
(for ([i (in-naturals)]
112-
[node (in-vector (batch-nodes b))])
108+
[node (in-batch b)])
113109
(define args (reap [sow] (expr-recurse node sow)))
114110
(vector-set! counts i (apply + 1 (map (curry vector-ref counts) args))))
115111
(apply + (map (curry vector-ref counts) (vector->list (batch-roots b)))))
116112

117-
(define (mutable-batch-munge! b expr)
118-
(define cache (mutable-batch-cache b))
113+
(define (batch-munge! b expr)
114+
(define cache (batch-cache b))
119115
(define (munge prog)
120-
(hash-ref! cache prog (lambda () (mutable-batch-push! b (expr-recurse prog munge)))))
116+
(hash-ref! cache prog (lambda () (batch-push! b (expr-recurse prog munge)))))
121117
(munge expr))
122118

123119
(define (batch->progs b [roots (batch-roots b)])
@@ -126,9 +122,9 @@
126122
(vector-ref exprs root)))
127123

128124
(define (batch-free-vars batch)
129-
(define out (make-vector (vector-length (batch-nodes batch))))
125+
(define out (make-vector (batch-length batch)))
130126
(for ([i (in-naturals)]
131-
[node (in-vector (batch-nodes batch))])
127+
[node (in-batch batch)])
132128
(define fv
133129
(cond
134130
[(symbol? node) (set node)]
@@ -140,9 +136,9 @@
140136
out)
141137

142138
(define (batch-replace b f)
143-
(define out (make-mutable-batch))
139+
(define out (make-batch))
144140
(define mapping (make-vector (batch-length b) -1))
145-
(for ([node (in-vector (batch-nodes b))]
141+
(for ([node (in-batch b)]
146142
[idx (in-naturals)])
147143
(define replacement (f (expr-recurse node (lambda (x) (batchref b x)))))
148144
(define final-idx
@@ -154,31 +150,31 @@
154150
(when (= -1 (vector-ref mapping idx))
155151
(error 'batch-replace "Replacement ~a references unknown index ~a" replacement idx))
156152
(vector-ref mapping idx)]
157-
[_ (mutable-batch-push! out (expr-recurse expr loop))])))
153+
[_ (batch-push! out (expr-recurse expr loop))])))
158154
(vector-set! mapping idx final-idx))
159155
(define roots (vector-map (curry vector-ref mapping) (batch-roots b)))
160-
(mutable-batch->batch out roots))
156+
(set-batch-roots! out roots)
157+
out)
161158

162159
;; Function returns indices of alive nodes within a batch for given roots,
163160
;; where alive node is a child of a root + meets a condition - (condition node)
164161
(define (batch-alive-nodes batch
165162
[roots (batch-roots batch)]
166163
#:keep-vars-alive [keep-vars-alive #f]
167164
#:condition [condition (const #t)])
168-
(define nodes (batch-nodes batch))
169-
(define nodes-length (batch-length batch))
170-
(define alive-mask (make-vector nodes-length #f))
165+
(define len (batch-length batch))
166+
(define alive-mask (make-vector len #f))
171167
(for ([root (in-vector roots)])
172168
(vector-set! alive-mask root #t))
173-
(for ([i (in-range (- nodes-length 1) -1 -1)]
174-
[node (in-vector nodes (- nodes-length 1) -1 -1)]
175-
[alv (in-vector alive-mask (- nodes-length 1) -1 -1)]
169+
(for ([i (in-range (- len 1) -1 -1)]
170+
[node (in-batch batch (- len 1) -1 -1)]
171+
[alv (in-vector alive-mask (- len 1) -1 -1)]
176172
#:when (or (and alv (condition node)) (and keep-vars-alive (symbol? node))))
177173
(unless alv ; if keep-vars-alive then alv may not be #t, making sure it's #t
178174
(vector-set! alive-mask i #t))
179175
(expr-recurse node
180176
(λ (n)
181-
(when (condition (vector-ref nodes n))
177+
(when (condition (batch-ref batch n))
182178
(vector-set! alive-mask n #t)))))
183179
; Return indices of alive nodes in ascending order
184180
(for/vector ([alv (in-vector alive-mask)]
@@ -189,7 +185,7 @@
189185
;; Function constructs a vector of expressions for the given nodes of a batch
190186
(define (batch-reconstruct-exprs batch)
191187
(define exprs (make-vector (batch-length batch)))
192-
(for ([node (in-vector (batch-nodes batch))]
188+
(for ([node (in-batch batch)]
193189
[idx (in-naturals)])
194190
(vector-set! exprs idx (expr-recurse node (lambda (x) (vector-ref exprs x)))))
195191
exprs)
@@ -199,36 +195,34 @@
199195
;; Space complexity: O(|N| + |N*| + |R|), where |N*| is a length of nodes without zombie nodes
200196
;; The flag keep-vars is used in compiler.rkt when vars should be preserved no matter what
201197
(define (batch-remove-zombie batch [roots (batch-roots batch)] #:keep-vars [keep-vars #f])
202-
(define nodes (batch-nodes batch))
203-
(define nodes-length (batch-length batch))
204-
(match (zero? nodes-length)
198+
(define len (batch-length batch))
199+
(match (zero? len)
205200
[#f
206201
(define alive-nodes (batch-alive-nodes batch roots #:keep-vars-alive keep-vars))
207202

208-
(define mappings (make-vector nodes-length -1))
203+
(define mappings (make-vector len -1))
209204
(define (remap idx)
210205
(vector-ref mappings idx))
211206

212-
(define out (make-mutable-batch))
207+
(define out (make-batch))
213208
(for ([alv (in-vector alive-nodes)])
214-
(define node (vector-ref nodes alv))
215-
(vector-set! mappings alv (mutable-batch-push! out (expr-recurse node remap))))
209+
(define node (batch-ref batch alv))
210+
(vector-set! mappings alv (batch-push! out (expr-recurse node remap))))
216211

217212
(define roots* (vector-map (curry vector-ref mappings) roots))
218-
(mutable-batch->batch out roots*)]
213+
(set-batch-roots! out roots*)
214+
out]
219215
[#t (batch-copy batch)]))
220216

221217
(define (batch-ref batch reg)
218+
(dvector-ref (batch-nodes batch) reg))
219+
220+
(define (batch-pull batch reg)
222221
(define (unmunge reg)
223-
(define node (vector-ref (batch-nodes batch) reg))
222+
(define node (batch-ref batch reg))
224223
(expr-recurse node unmunge))
225224
(unmunge reg))
226225

227-
(define (batch-restore-index batch)
228-
(make-hash (for/list ([node (in-vector (batch-nodes batch))]
229-
[n (in-naturals)])
230-
(cons node n))))
231-
232226
; Tests for progs->batch and batch->progs
233227
(module+ test
234228
(require rackunit)
@@ -253,26 +247,26 @@
253247
(module+ test
254248
(require rackunit)
255249
(define (zombie-test #:nodes nodes #:roots roots)
256-
(define in-batch (batch nodes roots))
250+
(define in-batch (batch nodes (make-hash) (make-hasheq) roots))
257251
(define out-batch (batch-remove-zombie in-batch))
258252
(check-equal? (batch->progs out-batch) (batch->progs in-batch))
259253
(batch-nodes out-batch))
260254

261-
(check-equal? (vector 0 '(sqrt 0) 2 '(pow 2 1))
262-
(zombie-test #:nodes (vector 0 1 '(sqrt 0) 2 '(pow 3 2)) #:roots (vector 4)))
263-
(check-equal? (vector 0 '(sqrt 0) '(exp 1))
264-
(zombie-test #:nodes (vector 0 6 '(pow 0 1) '(* 2 0) '(sqrt 0) '(exp 4))
255+
(check-equal? (create-dvector 0 '(sqrt 0) 2 '(pow 2 1))
256+
(zombie-test #:nodes (create-dvector 0 1 '(sqrt 0) 2 '(pow 3 2)) #:roots (vector 4)))
257+
(check-equal? (create-dvector 0 '(sqrt 0) '(exp 1))
258+
(zombie-test #:nodes (create-dvector 0 6 '(pow 0 1) '(* 2 0) '(sqrt 0) '(exp 4))
265259
#:roots (vector 5)))
266-
(check-equal? (vector 0 1/2 '(+ 0 1))
267-
(zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0)) #:roots (vector 2)))
268-
(check-equal? (vector 0 1/2 '(exp 1) (approx 2 0))
269-
(zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0) '(exp 1) (approx 4 0))
260+
(check-equal? (create-dvector 0 1/2 '(+ 0 1))
261+
(zombie-test #:nodes (create-dvector 0 1/2 '(+ 0 1) '(* 2 0)) #:roots (vector 2)))
262+
(check-equal? (create-dvector 0 1/2 '(exp 1) (approx 2 0))
263+
(zombie-test #:nodes (create-dvector 0 1/2 '(+ 0 1) '(* 2 0) '(exp 1) (approx 4 0))
270264
#:roots (vector 5)))
271-
(check-equal? (vector 'x 2 1/2 '(* 0 0) (approx 3 1) '(pow 2 4))
272-
(zombie-test #:nodes
273-
(vector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6))
274-
#:roots (vector 7)))
275-
(check-equal? (vector 'x 2 1/2 '(sqrt 1) '(* 0 0) (approx 4 1) '(pow 2 5))
276-
(zombie-test #:nodes
277-
(vector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6))
278-
#:roots (vector 7 3))))
265+
(check-equal?
266+
(create-dvector 'x 2 1/2 '(* 0 0) (approx 3 1) '(pow 2 4))
267+
(zombie-test #:nodes (create-dvector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6))
268+
#:roots (vector 7)))
269+
(check-equal?
270+
(create-dvector 'x 2 1/2 '(sqrt 1) '(* 0 0) (approx 4 1) '(pow 2 5))
271+
(zombie-test #:nodes (create-dvector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6))
272+
#:roots (vector 7 3))))

src/core/compiler.rkt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373

7474
(define instructions
7575
(for/vector #:length (- (batch-length batch*) num-vars)
76-
([node (in-vector (batch-nodes batch*) num-vars)])
76+
([node (in-batch batch* num-vars)])
7777
(match node
7878
[(literal value (app get-representation repr)) (list (const (real->repr value repr)))]
7979
[(list op args ...) (cons (impl-info op 'fl) args)])))

0 commit comments

Comments
 (0)