|
2 | 2 |
|
3 | 3 | (require "../syntax/syntax.rkt" |
4 | 4 | "../utils/common.rkt" |
5 | | - "../utils/alternative.rkt") ; for unbatchify-alts |
| 5 | + "../utils/alternative.rkt" ; for unbatchify-alts |
| 6 | + "dvector.rkt") |
6 | 7 |
|
7 | 8 | (provide progs->batch ; List<Expr> -> Batch |
8 | 9 | batch->progs ; Batch -> ?(or List<Root> Vector<Root>) -> List<Expr> |
| 10 | + |
9 | 11 | (struct-out batch) |
10 | | - (struct-out batchref) |
11 | | - (struct-out mutable-batch) |
| 12 | + make-batch ; Batch |
| 13 | + batch-push! ; Batch -> Node -> Idx |
| 14 | + batch-munge! ; Batch -> Expr -> Root |
| 15 | + batch-copy ; Batch -> Batch |
12 | 16 | batch-length ; Batch -> Integer |
13 | 17 | batch-tree-size ; Batch -> Integer |
14 | 18 | batch-free-vars |
15 | | - batch-ref ; Batch -> Idx -> Expr |
16 | | - deref ; Batchref -> Expr |
| 19 | + in-batch ; Batch -> Sequence<Node> |
| 20 | + batch-ref ; Batch -> Idx -> Node |
| 21 | + batch-pull ; Batch -> Idx -> Expr |
17 | 22 | batch-replace ; Batch -> (Expr<Batchref> -> Expr<Batchref>) -> Batch |
18 | | - debatchref ; Batchref -> Expr |
19 | 23 | batch-alive-nodes ; Batch -> ?Vector<Root> -> Vector<Idx> |
20 | 24 | batch-reconstruct-exprs ; Batch -> Vector<Expr> |
21 | 25 | batch-remove-zombie ; Batch -> ?Vector<Root> -> Batch |
22 | | - mutable-batch-munge! ; Mutable-batch -> Expr -> Root |
23 | | - make-mutable-batch ; Mutable-batch |
24 | | - batch->mutable-batch ; Batch -> Mutable-batch |
25 | | - batch-copy-mutable-nodes! ; Batch -> Mutable-batch -> Void |
26 | | - mutable-batch-push! ; Mutable-batch -> Node -> Idx |
27 | | - batch-copy |
| 26 | + |
| 27 | + (struct-out batchref) |
| 28 | + deref ; Batchref -> Expr |
| 29 | + debatchref ; Batchref -> Expr |
| 30 | + |
28 | 31 | unbatchify-alts) |
29 | 32 |
|
30 | 33 | ;; Batches store these recursive structures, flattened |
31 | | -(struct batch ([nodes #:mutable] [roots #:mutable])) |
32 | | - |
33 | | -(struct mutable-batch ([nodes #:mutable] [index #:mutable] cache)) |
| 34 | +(struct batch ([nodes #:mutable] [index #:mutable] cache [roots #:mutable])) |
34 | 35 |
|
35 | 36 | (struct batchref (batch idx) #:transparent) |
36 | 37 |
|
| 38 | +(define (make-batch) |
| 39 | + (batch (make-dvector) (make-hash) (make-hasheq) (vector))) |
| 40 | + |
| 41 | +(define (in-batch batch [start 0] [end #f] [step 1]) |
| 42 | + (in-dvector (batch-nodes batch) start end step)) |
| 43 | + |
37 | 44 | ;; This function defines the recursive structure of expressions |
38 | 45 | (define (expr-recurse expr f) |
39 | 46 | (match expr |
|
55 | 62 | (map (curry alt-map unmunge) altns)) |
56 | 63 |
|
57 | 64 | (define (batch-length b) |
58 | | - (cond |
59 | | - [(batch? b) (vector-length (batch-nodes b))] |
60 | | - [(mutable-batch? b) (hash-count (mutable-batch-index b))] |
61 | | - [else (error 'batch-length "Invalid batch" b)])) |
62 | | - |
63 | | -(define (make-mutable-batch) |
64 | | - (mutable-batch '() (make-hash) (make-hasheq))) |
| 65 | + (dvector-length (batch-nodes b))) |
65 | 66 |
|
66 | | -(define (mutable-batch-push! b term) |
67 | | - (define hashcons (mutable-batch-index b)) |
| 67 | +(define (batch-push! b term) |
| 68 | + (define hashcons (batch-index b)) |
68 | 69 | (hash-ref! hashcons |
69 | 70 | term |
70 | 71 | (lambda () |
71 | | - (define new-idx (hash-count hashcons)) |
72 | | - (hash-set! hashcons term new-idx) |
73 | | - (set-mutable-batch-nodes! b (cons term (mutable-batch-nodes b))) |
74 | | - new-idx))) |
75 | | - |
76 | | -(define (mutable-batch->batch b roots) |
77 | | - (batch (list->vector (reverse (mutable-batch-nodes b))) roots)) |
78 | | - |
79 | | -(define (batch->mutable-batch b) |
80 | | - (mutable-batch (reverse (vector->list (batch-nodes b))) (batch-restore-index b) (make-hasheq))) |
81 | | - |
82 | | -(define (batch-copy-mutable-nodes! b mb) |
83 | | - (set-batch-nodes! b (list->vector (reverse (mutable-batch-nodes mb))))) |
| 72 | + (define idx (hash-count hashcons)) |
| 73 | + (hash-set! hashcons term idx) |
| 74 | + (dvector-add! (batch-nodes b) term) |
| 75 | + idx))) |
84 | 76 |
|
85 | 77 | (define (batch-copy b) |
86 | | - (batch (vector-copy (batch-nodes b)) (vector-copy (batch-roots b)))) |
| 78 | + (batch (dvector-copy (batch-nodes b)) |
| 79 | + (hash-copy (batch-index b)) |
| 80 | + (hash-copy (batch-cache b)) |
| 81 | + (vector-copy (batch-roots b)))) |
87 | 82 |
|
88 | 83 | (define (deref x) |
89 | 84 | (match-define (batchref b idx) x) |
90 | | - (expr-recurse (vector-ref (batch-nodes b) idx) (lambda (ref) (batchref b ref)))) |
| 85 | + (expr-recurse (batch-ref b idx) (lambda (ref) (batchref b ref)))) |
91 | 86 |
|
92 | 87 | (define (debatchref x) |
93 | 88 | (match-define (batchref b idx) x) |
94 | | - (batch-ref b idx)) |
| 89 | + (batch-pull b idx)) |
95 | 90 |
|
96 | 91 | (define (progs->batch exprs #:vars [vars '()]) |
97 | | - (define out (make-mutable-batch)) |
| 92 | + (define out (make-batch)) |
98 | 93 |
|
99 | 94 | (for ([var (in-list vars)]) |
100 | | - (mutable-batch-push! out var)) |
| 95 | + (batch-push! out var)) |
101 | 96 | (define roots |
102 | 97 | (for/vector #:length (length exprs) |
103 | 98 | ([expr (in-list exprs)]) |
104 | | - (mutable-batch-munge! out expr))) |
| 99 | + (batch-munge! out expr))) |
105 | 100 |
|
106 | | - (mutable-batch->batch out roots)) |
| 101 | + (set-batch-roots! out roots) |
| 102 | + out) |
107 | 103 |
|
108 | 104 | (define (batch-tree-size b) |
109 | | - (define len (vector-length (batch-nodes b))) |
| 105 | + (define len (batch-length b)) |
110 | 106 | (define counts (make-vector len 0)) |
111 | 107 | (for ([i (in-naturals)] |
112 | | - [node (in-vector (batch-nodes b))]) |
| 108 | + [node (in-batch b)]) |
113 | 109 | (define args (reap [sow] (expr-recurse node sow))) |
114 | 110 | (vector-set! counts i (apply + 1 (map (curry vector-ref counts) args)))) |
115 | 111 | (apply + (map (curry vector-ref counts) (vector->list (batch-roots b))))) |
116 | 112 |
|
117 | | -(define (mutable-batch-munge! b expr) |
118 | | - (define cache (mutable-batch-cache b)) |
| 113 | +(define (batch-munge! b expr) |
| 114 | + (define cache (batch-cache b)) |
119 | 115 | (define (munge prog) |
120 | | - (hash-ref! cache prog (lambda () (mutable-batch-push! b (expr-recurse prog munge))))) |
| 116 | + (hash-ref! cache prog (lambda () (batch-push! b (expr-recurse prog munge))))) |
121 | 117 | (munge expr)) |
122 | 118 |
|
123 | 119 | (define (batch->progs b [roots (batch-roots b)]) |
|
126 | 122 | (vector-ref exprs root))) |
127 | 123 |
|
128 | 124 | (define (batch-free-vars batch) |
129 | | - (define out (make-vector (vector-length (batch-nodes batch)))) |
| 125 | + (define out (make-vector (batch-length batch))) |
130 | 126 | (for ([i (in-naturals)] |
131 | | - [node (in-vector (batch-nodes batch))]) |
| 127 | + [node (in-batch batch)]) |
132 | 128 | (define fv |
133 | 129 | (cond |
134 | 130 | [(symbol? node) (set node)] |
|
140 | 136 | out) |
141 | 137 |
|
142 | 138 | (define (batch-replace b f) |
143 | | - (define out (make-mutable-batch)) |
| 139 | + (define out (make-batch)) |
144 | 140 | (define mapping (make-vector (batch-length b) -1)) |
145 | | - (for ([node (in-vector (batch-nodes b))] |
| 141 | + (for ([node (in-batch b)] |
146 | 142 | [idx (in-naturals)]) |
147 | 143 | (define replacement (f (expr-recurse node (lambda (x) (batchref b x))))) |
148 | 144 | (define final-idx |
|
154 | 150 | (when (= -1 (vector-ref mapping idx)) |
155 | 151 | (error 'batch-replace "Replacement ~a references unknown index ~a" replacement idx)) |
156 | 152 | (vector-ref mapping idx)] |
157 | | - [_ (mutable-batch-push! out (expr-recurse expr loop))]))) |
| 153 | + [_ (batch-push! out (expr-recurse expr loop))]))) |
158 | 154 | (vector-set! mapping idx final-idx)) |
159 | 155 | (define roots (vector-map (curry vector-ref mapping) (batch-roots b))) |
160 | | - (mutable-batch->batch out roots)) |
| 156 | + (set-batch-roots! out roots) |
| 157 | + out) |
161 | 158 |
|
162 | 159 | ;; Function returns indices of alive nodes within a batch for given roots, |
163 | 160 | ;; where alive node is a child of a root + meets a condition - (condition node) |
164 | 161 | (define (batch-alive-nodes batch |
165 | 162 | [roots (batch-roots batch)] |
166 | 163 | #:keep-vars-alive [keep-vars-alive #f] |
167 | 164 | #:condition [condition (const #t)]) |
168 | | - (define nodes (batch-nodes batch)) |
169 | | - (define nodes-length (batch-length batch)) |
170 | | - (define alive-mask (make-vector nodes-length #f)) |
| 165 | + (define len (batch-length batch)) |
| 166 | + (define alive-mask (make-vector len #f)) |
171 | 167 | (for ([root (in-vector roots)]) |
172 | 168 | (vector-set! alive-mask root #t)) |
173 | | - (for ([i (in-range (- nodes-length 1) -1 -1)] |
174 | | - [node (in-vector nodes (- nodes-length 1) -1 -1)] |
175 | | - [alv (in-vector alive-mask (- nodes-length 1) -1 -1)] |
| 169 | + (for ([i (in-range (- len 1) -1 -1)] |
| 170 | + [node (in-batch batch (- len 1) -1 -1)] |
| 171 | + [alv (in-vector alive-mask (- len 1) -1 -1)] |
176 | 172 | #:when (or (and alv (condition node)) (and keep-vars-alive (symbol? node)))) |
177 | 173 | (unless alv ; if keep-vars-alive then alv may not be #t, making sure it's #t |
178 | 174 | (vector-set! alive-mask i #t)) |
179 | 175 | (expr-recurse node |
180 | 176 | (λ (n) |
181 | | - (when (condition (vector-ref nodes n)) |
| 177 | + (when (condition (batch-ref batch n)) |
182 | 178 | (vector-set! alive-mask n #t))))) |
183 | 179 | ; Return indices of alive nodes in ascending order |
184 | 180 | (for/vector ([alv (in-vector alive-mask)] |
|
189 | 185 | ;; Function constructs a vector of expressions for the given nodes of a batch |
190 | 186 | (define (batch-reconstruct-exprs batch) |
191 | 187 | (define exprs (make-vector (batch-length batch))) |
192 | | - (for ([node (in-vector (batch-nodes batch))] |
| 188 | + (for ([node (in-batch batch)] |
193 | 189 | [idx (in-naturals)]) |
194 | 190 | (vector-set! exprs idx (expr-recurse node (lambda (x) (vector-ref exprs x))))) |
195 | 191 | exprs) |
|
199 | 195 | ;; Space complexity: O(|N| + |N*| + |R|), where |N*| is a length of nodes without zombie nodes |
200 | 196 | ;; The flag keep-vars is used in compiler.rkt when vars should be preserved no matter what |
201 | 197 | (define (batch-remove-zombie batch [roots (batch-roots batch)] #:keep-vars [keep-vars #f]) |
202 | | - (define nodes (batch-nodes batch)) |
203 | | - (define nodes-length (batch-length batch)) |
204 | | - (match (zero? nodes-length) |
| 198 | + (define len (batch-length batch)) |
| 199 | + (match (zero? len) |
205 | 200 | [#f |
206 | 201 | (define alive-nodes (batch-alive-nodes batch roots #:keep-vars-alive keep-vars)) |
207 | 202 |
|
208 | | - (define mappings (make-vector nodes-length -1)) |
| 203 | + (define mappings (make-vector len -1)) |
209 | 204 | (define (remap idx) |
210 | 205 | (vector-ref mappings idx)) |
211 | 206 |
|
212 | | - (define out (make-mutable-batch)) |
| 207 | + (define out (make-batch)) |
213 | 208 | (for ([alv (in-vector alive-nodes)]) |
214 | | - (define node (vector-ref nodes alv)) |
215 | | - (vector-set! mappings alv (mutable-batch-push! out (expr-recurse node remap)))) |
| 209 | + (define node (batch-ref batch alv)) |
| 210 | + (vector-set! mappings alv (batch-push! out (expr-recurse node remap)))) |
216 | 211 |
|
217 | 212 | (define roots* (vector-map (curry vector-ref mappings) roots)) |
218 | | - (mutable-batch->batch out roots*)] |
| 213 | + (set-batch-roots! out roots*) |
| 214 | + out] |
219 | 215 | [#t (batch-copy batch)])) |
220 | 216 |
|
221 | 217 | (define (batch-ref batch reg) |
| 218 | + (dvector-ref (batch-nodes batch) reg)) |
| 219 | + |
| 220 | +(define (batch-pull batch reg) |
222 | 221 | (define (unmunge reg) |
223 | | - (define node (vector-ref (batch-nodes batch) reg)) |
| 222 | + (define node (batch-ref batch reg)) |
224 | 223 | (expr-recurse node unmunge)) |
225 | 224 | (unmunge reg)) |
226 | 225 |
|
227 | | -(define (batch-restore-index batch) |
228 | | - (make-hash (for/list ([node (in-vector (batch-nodes batch))] |
229 | | - [n (in-naturals)]) |
230 | | - (cons node n)))) |
231 | | - |
232 | 226 | ; Tests for progs->batch and batch->progs |
233 | 227 | (module+ test |
234 | 228 | (require rackunit) |
|
253 | 247 | (module+ test |
254 | 248 | (require rackunit) |
255 | 249 | (define (zombie-test #:nodes nodes #:roots roots) |
256 | | - (define in-batch (batch nodes roots)) |
| 250 | + (define in-batch (batch nodes (make-hash) (make-hasheq) roots)) |
257 | 251 | (define out-batch (batch-remove-zombie in-batch)) |
258 | 252 | (check-equal? (batch->progs out-batch) (batch->progs in-batch)) |
259 | 253 | (batch-nodes out-batch)) |
260 | 254 |
|
261 | | - (check-equal? (vector 0 '(sqrt 0) 2 '(pow 2 1)) |
262 | | - (zombie-test #:nodes (vector 0 1 '(sqrt 0) 2 '(pow 3 2)) #:roots (vector 4))) |
263 | | - (check-equal? (vector 0 '(sqrt 0) '(exp 1)) |
264 | | - (zombie-test #:nodes (vector 0 6 '(pow 0 1) '(* 2 0) '(sqrt 0) '(exp 4)) |
| 255 | + (check-equal? (create-dvector 0 '(sqrt 0) 2 '(pow 2 1)) |
| 256 | + (zombie-test #:nodes (create-dvector 0 1 '(sqrt 0) 2 '(pow 3 2)) #:roots (vector 4))) |
| 257 | + (check-equal? (create-dvector 0 '(sqrt 0) '(exp 1)) |
| 258 | + (zombie-test #:nodes (create-dvector 0 6 '(pow 0 1) '(* 2 0) '(sqrt 0) '(exp 4)) |
265 | 259 | #:roots (vector 5))) |
266 | | - (check-equal? (vector 0 1/2 '(+ 0 1)) |
267 | | - (zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0)) #:roots (vector 2))) |
268 | | - (check-equal? (vector 0 1/2 '(exp 1) (approx 2 0)) |
269 | | - (zombie-test #:nodes (vector 0 1/2 '(+ 0 1) '(* 2 0) '(exp 1) (approx 4 0)) |
| 260 | + (check-equal? (create-dvector 0 1/2 '(+ 0 1)) |
| 261 | + (zombie-test #:nodes (create-dvector 0 1/2 '(+ 0 1) '(* 2 0)) #:roots (vector 2))) |
| 262 | + (check-equal? (create-dvector 0 1/2 '(exp 1) (approx 2 0)) |
| 263 | + (zombie-test #:nodes (create-dvector 0 1/2 '(+ 0 1) '(* 2 0) '(exp 1) (approx 4 0)) |
270 | 264 | #:roots (vector 5))) |
271 | | - (check-equal? (vector 'x 2 1/2 '(* 0 0) (approx 3 1) '(pow 2 4)) |
272 | | - (zombie-test #:nodes |
273 | | - (vector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6)) |
274 | | - #:roots (vector 7))) |
275 | | - (check-equal? (vector 'x 2 1/2 '(sqrt 1) '(* 0 0) (approx 4 1) '(pow 2 5)) |
276 | | - (zombie-test #:nodes |
277 | | - (vector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6)) |
278 | | - #:roots (vector 7 3)))) |
| 265 | + (check-equal? |
| 266 | + (create-dvector 'x 2 1/2 '(* 0 0) (approx 3 1) '(pow 2 4)) |
| 267 | + (zombie-test #:nodes (create-dvector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6)) |
| 268 | + #:roots (vector 7))) |
| 269 | + (check-equal? |
| 270 | + (create-dvector 'x 2 1/2 '(sqrt 1) '(* 0 0) (approx 4 1) '(pow 2 5)) |
| 271 | + (zombie-test #:nodes (create-dvector 'x 2 1/2 '(sqrt 1) '(cbrt 1) '(* 0 0) (approx 5 1) '(pow 2 6)) |
| 272 | + #:roots (vector 7 3)))) |
0 commit comments