|
73 | 73 | [egraph-pointer (egraph_copy (egraph-data-egraph-pointer eg-data))])) |
74 | 74 |
|
75 | 75 | ; Adds expressions returning the root ids |
76 | | -; TODO: take a batch rather than list of expressions |
77 | | -(define (egraph-add-exprs egg-data exprs ctx) |
| 76 | +(define (egraph-add-exprs egg-data batch roots ctx) |
78 | 77 | (match-define (egraph-data ptr herbie->egg-dict egg->herbie-dict id->spec) egg-data) |
79 | 78 |
|
80 | 79 | ; lookups the egg name of a variable |
|
125 | 124 | [(? symbol? x) (egraph_add_node ptr (symbol->string x) 0-vec root?)] |
126 | 125 | [(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)])) |
127 | 126 |
|
| 127 | + ; The function recurses on spec |
| 128 | + (define (batch-parse-approx batch) |
| 129 | + (batch-replace batch |
| 130 | + (lambda (node) |
| 131 | + (match node |
| 132 | + [(approx spec impl) (list '$approx spec impl)] |
| 133 | + [_ node])))) |
| 134 | + |
| 135 | + (set-batch-roots! batch roots) ; make sure that we work with the right roots |
| 136 | + ; the algorithm may crash if batch-length is zero |
| 137 | + (define insert-batch |
| 138 | + (if (zero? (batch-length batch)) batch (remove-zombie-nodes (batch-parse-approx batch)))) |
| 139 | + |
| 140 | + (define mappings (build-vector (batch-length insert-batch) values)) |
| 141 | + (define (remap x) |
| 142 | + (vector-ref mappings x)) |
| 143 | + |
| 144 | + ; Inserting nodes bottom-up |
| 145 | + (define root-mask (make-vector (batch-length insert-batch) #f)) |
| 146 | + (for ([root (in-vector (batch-roots insert-batch))]) |
| 147 | + (vector-set! root-mask root #t)) |
| 148 | + (for ([node (in-vector (batch-nodes insert-batch))] |
| 149 | + [root? (in-vector root-mask)] |
| 150 | + [n (in-naturals)]) |
| 151 | + (define node* |
| 152 | + (match node |
| 153 | + [(literal v _) v] |
| 154 | + [(? number?) node] |
| 155 | + [(? symbol?) (normalize-var node)] |
| 156 | + [(list '$approx spec impl) |
| 157 | + (hash-ref! id->spec |
| 158 | + (remap spec) |
| 159 | + (lambda () |
| 160 | + (define spec* (normalize-spec (batch-ref insert-batch spec))) |
| 161 | + (define type (representation-type (repr-of-node insert-batch impl ctx))) |
| 162 | + (cons spec* type))) ; preserved spec and type for extraction |
| 163 | + (list '$approx (remap spec) (remap impl))] |
| 164 | + [(list op (app remap args) ...) (cons op args)])) |
| 165 | + |
| 166 | + (vector-set! mappings n (insert-node! node* root?))) |
| 167 | + |
| 168 | + ;------------------------- DEBUGGING |
128 | 169 | ; expr -> id |
129 | 170 | ; expression cache |
130 | | - (define expr->id (make-hash)) |
| 171 | + #;(define expr->id (make-hash)) |
131 | 172 |
|
132 | 173 | ; expr -> natural |
133 | 174 | ; inserts an expresison into the e-graph, returning its e-class id. |
134 | | - (define (insert! expr [root? #f]) |
135 | | - ; transform the expression into a node pointing |
136 | | - ; to its child e-classes |
137 | | - (define node |
138 | | - (match expr |
139 | | - [(? number?) expr] |
140 | | - [(? symbol?) (normalize-var expr)] |
141 | | - [(literal v _) v] |
142 | | - [(approx spec impl) |
143 | | - (define spec* (insert! spec)) |
144 | | - (define impl* (insert! impl)) |
145 | | - (hash-ref! id->spec |
146 | | - spec* |
147 | | - (lambda () |
148 | | - (define spec* (normalize-spec spec)) ; preserved spec for extraction |
149 | | - (define type (representation-type (repr-of impl ctx))) ; track type of spec |
150 | | - (cons spec* type))) |
151 | | - (list '$approx spec* impl*)] |
152 | | - [(list op args ...) (cons op (map insert! args))])) |
153 | | - ; always insert the node if it is a root since |
154 | | - ; the e-graph tracks which nodes are roots |
155 | | - (cond |
156 | | - [root? (insert-node! node #t)] |
157 | | - [else (hash-ref! expr->id node (lambda () (insert-node! node #f)))])) |
158 | | - |
159 | | - (for/list ([expr (in-list exprs)]) |
160 | | - (insert! expr #t))) |
| 175 | + #;(define (insert! expr [root? #f]) |
| 176 | + ; transform the expression into a node pointing |
| 177 | + ; to its child e-classes |
| 178 | + (define node |
| 179 | + (match expr |
| 180 | + [(literal v _) v] |
| 181 | + [(? number?) expr] |
| 182 | + [(? symbol?) (normalize-var expr)] |
| 183 | + [(list '$approx spec impl) |
| 184 | + (define spec* (insert! (vector-ref nodes spec))) |
| 185 | + (define impl* (insert! (vector-ref nodes impl))) |
| 186 | + (hash-ref! id->spec |
| 187 | + spec* |
| 188 | + (lambda () |
| 189 | + (define spec* (normalize-spec (batch-ref insert-batch spec))) |
| 190 | + (define type (representation-type (repr-of-node insert-batch impl ctx))) |
| 191 | + (cons spec* type))) |
| 192 | + (list '$approx spec* impl*)] |
| 193 | + [(list op args ...) (cons op (map insert! (map (curry vector-ref nodes) args)))])) |
| 194 | + ; always insert the node if it is a root since |
| 195 | + ; the e-graph tracks which nodes are roots |
| 196 | + (cond |
| 197 | + [root? (insert-node! node #t)] |
| 198 | + [else (hash-ref! expr->id node (lambda () (insert-node! node #f)))])) |
| 199 | + |
| 200 | + #;(define nodes (batch-nodes insert-batch)) |
| 201 | + #;(for/list ([root (in-vector (batch-roots insert-batch))]) |
| 202 | + (insert! (vector-ref nodes root) #t)) |
| 203 | + ; ---------------------- END OF DEBUGGING |
| 204 | + |
| 205 | + (for/list ([root (in-vector (batch-roots insert-batch))]) |
| 206 | + (remap root))) |
161 | 207 |
|
162 | 208 | ;; runs rules on an egraph (optional iteration limit) |
163 | 209 | (define (egraph-run egraph-data ffi-rules node-limit iter-limit scheduler const-folding?) |
|
226 | 272 | (egraph_find (egraph-data-egraph-pointer egraph-data) id)) |
227 | 273 |
|
228 | 274 | (define (egraph-expr-equal? egraph-data expr goal ctx) |
229 | | - (match-define (list id1 id2) (egraph-add-exprs egraph-data (list expr goal) ctx)) |
| 275 | + (define batch (progs->batch (list expr goal))) |
| 276 | + (match-define (list id1 id2) (egraph-add-exprs egraph-data batch (batch-roots batch) ctx)) |
230 | 277 | (= id1 id2)) |
231 | 278 |
|
232 | 279 | ;; returns a flattened list of terms or #f if it failed to expand the proof due to budget |
|
1198 | 1245 | (loop (sub1 num-iters)))] |
1199 | 1246 | [else (values egg-graph iteration-data)]))) |
1200 | 1247 |
|
1201 | | -(define (egraph-run-schedule exprs schedule ctx) |
| 1248 | +(define (egraph-run-schedule batch roots schedule ctx) |
1202 | 1249 | ; allocate the e-graph |
1203 | 1250 | (define egg-graph (make-egraph)) |
1204 | 1251 |
|
1205 | 1252 | ; insert expressions into the e-graph |
1206 | | - (define root-ids (egraph-add-exprs egg-graph exprs ctx)) |
| 1253 | + (define root-ids (egraph-add-exprs egg-graph batch roots ctx)) |
1207 | 1254 |
|
1208 | 1255 | ; run the schedule |
1209 | 1256 | (define egg-graph* |
|
1235 | 1282 |
|
1236 | 1283 | ;; Herbie's version of an egg runner. |
1237 | 1284 | ;; Defines parameters for running rewrite rules with egg |
1238 | | -(struct egg-runner (exprs reprs schedule ctx) |
| 1285 | +(struct egg-runner (batch roots reprs schedule ctx) |
1239 | 1286 | #:transparent ; for equality |
1240 | 1287 | #:methods gen:custom-write ; for abbreviated printing |
1241 | 1288 | [(define (write-proc alt port mode) |
|
1252 | 1299 | ;; - scheduler: `(scheduler . <name>)` [default: backoff] |
1253 | 1300 | ;; - `simple`: run all rules without banning |
1254 | 1301 | ;; - `backoff`: ban rules if the fire too much |
1255 | | -(define (make-egg-runner exprs reprs schedule #:context [ctx (*context*)]) |
| 1302 | +(define (make-egg-runner batch roots reprs schedule #:context [ctx (*context*)]) |
1256 | 1303 | (define (oops! fmt . args) |
1257 | 1304 | (apply error 'verify-schedule! fmt args)) |
1258 | 1305 | ; verify the schedule |
|
1273 | 1320 | [_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))] |
1274 | 1321 | [_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)])) |
1275 | 1322 | ; make the runner |
1276 | | - (egg-runner exprs reprs schedule ctx)) |
| 1323 | + (egg-runner batch roots reprs schedule ctx)) |
1277 | 1324 |
|
1278 | 1325 | ;; Runs egg using an egg runner. |
1279 | 1326 | ;; |
|
1285 | 1332 | ;; Run egg using runner |
1286 | 1333 | (define ctx (egg-runner-ctx runner)) |
1287 | 1334 | (define-values (root-ids egg-graph) |
1288 | | - (egraph-run-schedule (egg-runner-exprs runner) (egg-runner-schedule runner) ctx)) |
| 1335 | + (egraph-run-schedule (egg-runner-batch runner) |
| 1336 | + (egg-runner-roots runner) |
| 1337 | + (egg-runner-schedule runner) |
| 1338 | + ctx)) |
1289 | 1339 | ; Perform extraction |
1290 | 1340 | (match cmd |
1291 | 1341 | [`(single . ,extractor) ; single expression extraction |
|
0 commit comments