|
55 | 55 |
|
56 | 56 | ;; Constructs an egglog runner - structurally serves the same purpose as egg-runner |
57 | 57 | ;; |
58 | | -;; The schedule is a list of pairs specifying |
59 | | -;; - a list of rules |
60 | | -;; - scheduling parameters: |
61 | | -;; - node limit: `(node . <number>)` |
62 | | -;; - iteration limit: `(iteration . <number>)` |
63 | | -;; - constant fold: `(const-fold? . <boolean>)` [default: #t] |
64 | | -;; - scheduler: `(scheduler . <name>)` [default: backoff] |
65 | | -;; - `simple`: run all rules without banning |
66 | | -;; - `backoff`: ban rules if the fire too much |
| 58 | +;; The schedule is a list of step symbols: |
| 59 | +;; - `lift`: run lifting rules for 1 iteration with simple scheduler |
| 60 | +;; - `rewrite`: run rewrite rules up to node limit with backoff scheduler |
| 61 | +;; - `unsound`: run sound-removal rules for 1 iteration with simple scheduler |
| 62 | +;; - `lower`: run lowering rules for 1 iteration with simple scheduler |
67 | 63 | (define (make-egglog-runner batch brfs reprs schedule ctx) |
68 | 64 | (define (oops! fmt . args) |
69 | 65 | (apply error 'verify-schedule! fmt args)) |
70 | 66 | ; verify the schedule |
71 | | - (for ([instr (in-list schedule)]) |
72 | | - (match instr |
73 | | - [(cons rules params) |
74 | | - |
75 | | - ;; `run` instruction |
76 | | - (unless (or (equal? `lift rules) |
77 | | - (equal? `lower rules) |
78 | | - (and (list? rules) (andmap rule? rules))) |
79 | | - (oops! "expected list of rules: `~a`" rules)) |
80 | | - |
81 | | - (for ([param (in-list params)]) |
82 | | - (match param |
83 | | - [(cons 'node (? nonnegative-integer?)) (void)] |
84 | | - [(cons 'iteration (? nonnegative-integer?)) (void)] |
85 | | - [(cons 'const-fold? (? boolean?)) (void)] |
86 | | - [(cons 'scheduler mode) |
87 | | - (unless (set-member? '(simple backoff) mode) |
88 | | - (oops! "in instruction `~a`, unknown scheduler `~a`" instr mode))] |
89 | | - [_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))] |
90 | | - [_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)])) |
| 67 | + (for ([step (in-list schedule)]) |
| 68 | + (unless (memq step '(lift lower unsound rewrite)) |
| 69 | + (oops! "unknown schedule step `~a`" step))) |
91 | 70 |
|
92 | 71 | ; make the runner |
93 | 72 | (egglog-runner batch brfs reprs schedule ctx)) |
|
120 | 99 |
|
121 | 100 | ;; 2. User Rules which comes from schedule (need to be translated) |
122 | 101 | (define tag-schedule |
123 | | - (for/list ([i (in-naturals 1)] |
124 | | - [element (in-list (egglog-runner-schedule runner))]) |
125 | | - |
126 | | - (define rule-type (car element)) |
127 | | - (define schedule-params (cdr element)) |
128 | | - (define tag |
129 | | - (match rule-type |
130 | | - ['lift 'lifting] |
131 | | - ['lower 'lowering] |
132 | | - [_ |
133 | | - (define curr-tag (string->symbol (string-append "?tag" (number->string i)))) |
134 | | - ;; Add rulesets |
135 | | - (egglog-program-add! `(ruleset ,curr-tag) curr-program) |
136 | | - |
137 | | - ;; Add the actual egglog rewrite rules |
138 | | - (egglog-program-add-list! (egglog-rewrite-rules rule-type curr-tag) curr-program) |
139 | | - |
140 | | - curr-tag])) |
141 | | - |
142 | | - (cons tag schedule-params))) |
| 102 | + (for/list ([step (in-list (egglog-runner-schedule runner))]) |
| 103 | + (match step |
| 104 | + ['lift 'lift] |
| 105 | + ['lower 'lower] |
| 106 | + ['unsound |
| 107 | + ;; Add the unsound rules |
| 108 | + (egglog-program-add-list! (egglog-rewrite-rules (*sound-removal-rules*) 'unsound) |
| 109 | + curr-program) |
| 110 | + 'unsound] |
| 111 | + ['rewrite |
| 112 | + ;; Add the rewrite ruleset and rules |
| 113 | + (egglog-program-add! `(ruleset rewrite) curr-program) |
| 114 | + (egglog-program-add-list! (egglog-rewrite-rules (*rules*) 'rewrite) curr-program) |
| 115 | + 'rewrite]))) |
143 | 116 |
|
144 | 117 | ;; 3. Inserting expressions into the egglog program and getting a Listof (exprs . extract bindings) |
145 | 118 |
|
|
214 | 187 | ; run-schedule specifies the schedule of rulesets to saturate the egraph |
215 | 188 | ; For performance, it stores the schedule in reverse order, and is reversed at the end |
216 | 189 |
|
217 | | - (for ([(tag schedule-params) (in-dict tag-schedule)]) |
| 190 | + (for ([tag (in-list tag-schedule)]) |
218 | 191 | (match tag |
219 | | - ['lifting |
220 | | - (send-to-egglog (list '(run-schedule (saturate lifting))) |
| 192 | + ['lift |
| 193 | + (send-to-egglog (list '(run-schedule (saturate lift))) |
221 | 194 | egglog-process |
222 | 195 | egglog-output |
223 | 196 | egglog-in |
224 | 197 | err |
225 | 198 | dump-file)] |
226 | 199 |
|
227 | | - ['lowering |
228 | | - (send-to-egglog (list '(run-schedule (saturate lowering))) |
| 200 | + ['lower |
| 201 | + (send-to-egglog (list '(run-schedule (saturate lower))) |
229 | 202 | egglog-process |
230 | 203 | egglog-output |
231 | 204 | egglog-in |
232 | 205 | err |
233 | 206 | dump-file)] |
234 | 207 |
|
235 | | - [_ |
236 | | - ;; Run the current ruleset tag interleaved with const-fold until the best iteration |
| 208 | + ['unsound |
| 209 | + (send-to-egglog (list '(run-schedule (saturate unsound))) |
| 210 | + egglog-process |
| 211 | + egglog-output |
| 212 | + egglog-in |
| 213 | + err |
| 214 | + dump-file)] |
| 215 | + |
| 216 | + ['rewrite |
| 217 | + ;; Run the rewrite ruleset interleaved with const-fold until the best iteration |
237 | 218 | (egglog-unsound-detected-subprocess tag |
238 | | - schedule-params |
239 | 219 | egglog-process |
240 | 220 | egglog-output |
241 | 221 | egglog-in |
|
343 | 323 | ,@(platform-impl-nodes pform min-cost))) |
344 | 324 | (egglog-program-add! typed-graph curr-program) |
345 | 325 |
|
346 | | - (egglog-program-add! `(constructor lower (M String) MTy :unextractable) curr-program) |
| 326 | + (egglog-program-add! `(constructor do-lower (M String) MTy :unextractable) curr-program) |
347 | 327 |
|
348 | | - (egglog-program-add! `(constructor lift (MTy) M :unextractable) curr-program) |
| 328 | + (egglog-program-add! `(constructor do-lift (MTy) M :unextractable) curr-program) |
349 | 329 |
|
350 | 330 | (egglog-program-add! `(ruleset const-fold) curr-program) |
351 | 331 |
|
352 | | - (egglog-program-add! `(ruleset lowering) curr-program) |
| 332 | + (egglog-program-add! `(ruleset lower) curr-program) |
| 333 | + |
| 334 | + (egglog-program-add! `(ruleset lift) curr-program) |
353 | 335 |
|
354 | | - (egglog-program-add! `(ruleset lifting) curr-program) |
| 336 | + (egglog-program-add! `(ruleset unsound) curr-program) |
355 | 337 |
|
356 | | - ;;; Adding function unsound before rules |
| 338 | + ;;; Adding bad-merge detection |
357 | 339 |
|
358 | | - ;; unsound functions |
359 | | - (egglog-program-add! `(function unsound () bool :merge (or old new)) curr-program) |
360 | | - (egglog-program-add! `(ruleset unsound-rule) curr-program) |
361 | | - (egglog-program-add! `(set (unsound) false) curr-program) |
| 340 | + ;; bad-merge detection function and rules |
| 341 | + (egglog-program-add! `(function bad-merge? () bool :merge (or old new)) curr-program) |
| 342 | + (egglog-program-add! `(ruleset bad-merge-rule) curr-program) |
| 343 | + (egglog-program-add! `(set (bad-merge?) false) curr-program) |
362 | 344 |
|
363 | 345 | (egglog-program-add! |
364 | | - `(rule ((= (Num c1) (Num c2)) (!= c1 c2)) ((set (unsound) true)) :ruleset unsound-rule) |
| 346 | + `(rule ((= (Num c1) (Num c2)) (!= c1 c2)) ((set (bad-merge?) true)) :ruleset bad-merge-rule) |
365 | 347 | curr-program) |
366 | 348 |
|
367 | 349 | (for ([curr-expr const-fold]) |
|
487 | 469 | (let etx (,(typed-num-id repr) |
488 | 470 | n) |
489 | 471 | ) |
490 | | - (union (lower e tx) etx)) |
| 472 | + (union (do-lower e tx) etx)) |
491 | 473 | :ruleset |
492 | | - lowering))) |
| 474 | + lower))) |
493 | 475 |
|
494 | 476 | (define (num-lifting-rules) |
495 | 477 | (for/list ([repr (in-list (all-repr-names))] |
|
498 | 480 | ((let se (Num |
499 | 481 | n) |
500 | 482 | ) |
501 | | - (union (lift e) se)) |
| 483 | + (union (do-lift e) se)) |
502 | 484 | :ruleset |
503 | | - lifting))) |
| 485 | + lift))) |
504 | 486 |
|
505 | 487 | (define (approx-lifting-rule) |
506 | | - `(rule ((= e (Approx spec impl))) ((union (lift e) spec)) :ruleset lifting)) |
| 488 | + `(rule ((= e (Approx spec impl))) ((union (do-lift e) spec)) :ruleset lift)) |
507 | 489 |
|
508 | 490 | (define (impl-lowering-rules pform) |
509 | 491 | (for/list ([impl (in-list (platform-impls pform))]) |
|
513 | 495 | ,@(for/list ([v (in-list (impl-info impl 'vars))] |
514 | 496 | [vt (in-list (impl-info impl 'itype))]) |
515 | 497 | `(= ,(string->symbol (string-append "t" (symbol->string v))) |
516 | | - (lower ,v ,(symbol->string (representation-name vt)))))) |
| 498 | + (do-lower ,v ,(symbol->string (representation-name vt)))))) |
517 | 499 | ((let t0 ,(symbol->string (representation-name (impl-info impl 'otype))) |
518 | 500 | ) |
519 | 501 | (let et0 (,(string->symbol (string-append (symbol->string (serialize-impl impl)) "Ty")) |
520 | 502 | ,@(for/list ([v (in-list (impl-info impl 'vars))]) |
521 | 503 | (string->symbol (string-append "t" (symbol->string v))))) |
522 | 504 | ) |
523 | | - (union (lower e t0) et0)) |
| 505 | + (union (do-lower e t0) et0)) |
524 | 506 | :ruleset |
525 | | - lowering))) |
| 507 | + lower))) |
526 | 508 |
|
527 | 509 | (define (impl-lifting-rules pform) |
528 | 510 | (for/list ([impl (in-list (platform-impls pform))]) |
|
534 | 516 | ,@(impl-info impl 'vars))) |
535 | 517 | ,@(for/list ([v (in-list (impl-info impl 'vars))] |
536 | 518 | [vt (in-list (impl-info impl 'itype))]) |
537 | | - `(= ,(string->symbol (string-append "s" (symbol->string v))) (lift ,v)))) |
| 519 | + `(= ,(string->symbol (string-append "s" (symbol->string v))) (do-lift ,v)))) |
538 | 520 | ((let se ,(expr->egglog-spec-serialized spec-expr "s") |
539 | 521 | ) |
540 | | - (union (lift e) se)) |
| 522 | + (union (do-lift e) se)) |
541 | 523 | :ruleset |
542 | | - lifting))) |
| 524 | + lift))) |
543 | 525 |
|
544 | 526 | (define (expr->egglog-spec-serialized expr s) |
545 | 527 | (let loop ([expr expr]) |
|
678 | 660 | ,@(for/list ([arg (in-list args)]) |
679 | 661 | (remap arg (spec? (batchref batch n)))))] |
680 | 662 |
|
681 | | - [(hole ty spec) `(lower ,(remap spec #t) ,(symbol->string ty))])) |
| 663 | + [(hole ty spec) `(do-lower ,(remap spec #t) ,(symbol->string ty))])) |
682 | 664 |
|
683 | 665 | (if node* |
684 | 666 | (vector-set! mappings n (insert-node! node* n root?)) |
|
697 | 679 | (let ety (,(typed-var-id (representation-name repr)) |
698 | 680 | ,(symbol->string var)) |
699 | 681 | ) |
700 | | - (union (lower e ty) ety)) |
| 682 | + (union (do-lower e ty) ety)) |
701 | 683 | :ruleset |
702 | | - lowering)) |
| 684 | + lower)) |
703 | 685 |
|
704 | 686 | (egglog-program-add! curr-var-lowering-rule curr-program)) |
705 | 687 |
|
|
712 | 694 | ((let se (Var |
713 | 695 | ,(symbol->string var)) |
714 | 696 | ) |
715 | | - (union (lift e) se)) |
| 697 | + (union (do-lift e) se)) |
716 | 698 | :ruleset |
717 | | - lifting)) |
| 699 | + lift)) |
718 | 700 |
|
719 | 701 | (egglog-program-add! curr-var-lifting-rule curr-program)) |
720 | 702 |
|
|
780 | 762 |
|
781 | 763 | (define curr-datatype |
782 | 764 | (match actual-binding |
783 | | - [(cons 'lower _) 'MTy] |
784 | | - [(cons 'lift _) 'M] |
| 765 | + [(cons 'do-lower _) 'MTy] |
| 766 | + [(cons 'do-lift _) 'M] |
785 | 767 |
|
786 | 768 | ;; TODO : fix this way of getting spec or impl |
787 | 769 | [_ (if root? 'MTy 'M)])) |
|
810 | 792 |
|
811 | 793 | (values (reverse all-bindings) curr-bindings)) |
812 | 794 |
|
813 | | -(define (egglog-unsound-detected-subprocess tag |
814 | | - params |
815 | | - egglog-process |
816 | | - egglog-output |
817 | | - egglog-in |
818 | | - err |
819 | | - dump-file) |
| 795 | +(define (egglog-unsound-detected-subprocess tag egglog-process egglog-output egglog-in err dump-file) |
820 | 796 |
|
821 | | - (define node-limit (dict-ref params 'node (*node-limit*))) |
822 | | - (define iter-limit (dict-ref params 'iteration (*default-egglog-iter-limit*))) |
| 797 | + (define node-limit (*node-limit*)) |
| 798 | + (define iter-limit (*default-egglog-iter-limit*)) |
823 | 799 |
|
824 | 800 | ;; Algorithm: |
825 | 801 | ;; 1. Run (PUSH) to the save the above state of the egraph |
826 | 802 | ;; 2. Repeat rules based on their ruleset tag once |
827 | | - ;; 3. Run the unsound-rule function ruleset once |
828 | | - ;; 4. Extract the (unsound) function that returns a bool |
829 | | - ;; 5. If (unsound) function returns "true", we have unsoundless, so go to Step 10 for ROLLBACK |
| 803 | + ;; 3. Run the bad-merge-rule ruleset once |
| 804 | + ;; 4. Extract the (bad-merge?) function that returns a bool |
| 805 | + ;; 5. If (bad-merge?) function returns "true", we have a bad merge, so go to Step 10 for ROLLBACK |
830 | 806 | ;; 6. Run (print-size) to get nodes of the form "node_name : num_nodes" for all nodes in egraph |
831 | 807 | ;; 7. If the total number of nodes is more than node-limit, do NOT ROLLBACK and go to Step 11 |
832 | 808 | ;; 8. Repeat rules based on the const-fold tag once and repeat Steps 3-7 |
|
852 | 828 | (list '(push) |
853 | 829 | `(run-schedule (repeat 1 ,tag)) |
854 | 830 | '(print-size) |
855 | | - '(run unsound-rule 1) |
856 | | - '(extract (unsound)))) |
| 831 | + '(run bad-merge-rule 1) |
| 832 | + '(extract (bad-merge?)))) |
857 | 833 |
|
858 | 834 | ;; Get egglog output |
859 | 835 | (define-values (math-unsound? math-node-limit? math-total-nodes) |
|
912 | 888 | (list '(push) |
913 | 889 | `(run-schedule (repeat 1 const-fold)) |
914 | 890 | '(print-size) |
915 | | - '(run unsound-rule 1) |
916 | | - '(extract (unsound)))) |
| 891 | + '(run bad-merge-rule 1) |
| 892 | + '(extract (bad-merge?)))) |
917 | 893 |
|
918 | 894 | (define-values (const-unsound? const-node-limit? const-total-nodes) |
919 | 895 | (get-egglog-output const-schedule |
|
0 commit comments