Skip to content

Commit ecc905e

Browse files
authored
Merge pull request #1419 from herbie-fp/egglog-simpler-schedules
Egglog simpler schedules
2 parents e9e9760 + eeb97c3 commit ecc905e

File tree

2 files changed

+79
-107
lines changed

2 files changed

+79
-107
lines changed

src/core/egglog-herbie-tests.rkt

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,11 +368,7 @@
368368

369369
(define reprs (make-list (length brfs) (context-repr ctx)))
370370

371-
(define rules (*rules*))
372-
(define schedule
373-
`((lift . ((iteration . 1) (scheduler . simple)))
374-
(,rules . ((node . ,(*node-limit*)) (scheduler . simple)))
375-
(lower . ((iteration . 1) (scheduler . simple)))))
371+
(define schedule '(lift rewrite lower))
376372

377373
(when (find-executable-path "egglog")
378-
(run-egglog-multi-extractor (egglog-runner batch brfs reprs schedule ctx) batch)))
374+
(run-egglog-multi-extractor (make-egglog-runner batch brfs reprs schedule ctx) batch)))

src/core/egglog-herbie.rkt

Lines changed: 77 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -55,39 +55,18 @@
5555

5656
;; Constructs an egglog runner - structurally serves the same purpose as egg-runner
5757
;;
58-
;; The schedule is a list of pairs specifying
59-
;; - a list of rules
60-
;; - scheduling parameters:
61-
;; - node limit: `(node . <number>)`
62-
;; - iteration limit: `(iteration . <number>)`
63-
;; - constant fold: `(const-fold? . <boolean>)` [default: #t]
64-
;; - scheduler: `(scheduler . <name>)` [default: backoff]
65-
;; - `simple`: run all rules without banning
66-
;; - `backoff`: ban rules if the fire too much
58+
;; The schedule is a list of step symbols:
59+
;; - `lift`: run lifting rules for 1 iteration with simple scheduler
60+
;; - `rewrite`: run rewrite rules up to node limit with backoff scheduler
61+
;; - `unsound`: run sound-removal rules for 1 iteration with simple scheduler
62+
;; - `lower`: run lowering rules for 1 iteration with simple scheduler
6763
(define (make-egglog-runner batch brfs reprs schedule ctx)
6864
(define (oops! fmt . args)
6965
(apply error 'verify-schedule! fmt args))
7066
; verify the schedule
71-
(for ([instr (in-list schedule)])
72-
(match instr
73-
[(cons rules params)
74-
75-
;; `run` instruction
76-
(unless (or (equal? `lift rules)
77-
(equal? `lower rules)
78-
(and (list? rules) (andmap rule? rules)))
79-
(oops! "expected list of rules: `~a`" rules))
80-
81-
(for ([param (in-list params)])
82-
(match param
83-
[(cons 'node (? nonnegative-integer?)) (void)]
84-
[(cons 'iteration (? nonnegative-integer?)) (void)]
85-
[(cons 'const-fold? (? boolean?)) (void)]
86-
[(cons 'scheduler mode)
87-
(unless (set-member? '(simple backoff) mode)
88-
(oops! "in instruction `~a`, unknown scheduler `~a`" instr mode))]
89-
[_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))]
90-
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
67+
(for ([step (in-list schedule)])
68+
(unless (memq step '(lift lower unsound rewrite))
69+
(oops! "unknown schedule step `~a`" step)))
9170

9271
; make the runner
9372
(egglog-runner batch brfs reprs schedule ctx))
@@ -120,26 +99,20 @@
12099

121100
;; 2. User Rules which comes from schedule (need to be translated)
122101
(define tag-schedule
123-
(for/list ([i (in-naturals 1)]
124-
[element (in-list (egglog-runner-schedule runner))])
125-
126-
(define rule-type (car element))
127-
(define schedule-params (cdr element))
128-
(define tag
129-
(match rule-type
130-
['lift 'lifting]
131-
['lower 'lowering]
132-
[_
133-
(define curr-tag (string->symbol (string-append "?tag" (number->string i))))
134-
;; Add rulesets
135-
(egglog-program-add! `(ruleset ,curr-tag) curr-program)
136-
137-
;; Add the actual egglog rewrite rules
138-
(egglog-program-add-list! (egglog-rewrite-rules rule-type curr-tag) curr-program)
139-
140-
curr-tag]))
141-
142-
(cons tag schedule-params)))
102+
(for/list ([step (in-list (egglog-runner-schedule runner))])
103+
(match step
104+
['lift 'lift]
105+
['lower 'lower]
106+
['unsound
107+
;; Add the unsound rules
108+
(egglog-program-add-list! (egglog-rewrite-rules (*sound-removal-rules*) 'unsound)
109+
curr-program)
110+
'unsound]
111+
['rewrite
112+
;; Add the rewrite ruleset and rules
113+
(egglog-program-add! `(ruleset rewrite) curr-program)
114+
(egglog-program-add-list! (egglog-rewrite-rules (*rules*) 'rewrite) curr-program)
115+
'rewrite])))
143116

144117
;; 3. Inserting expressions into the egglog program and getting a Listof (exprs . extract bindings)
145118

@@ -214,28 +187,35 @@
214187
; run-schedule specifies the schedule of rulesets to saturate the egraph
215188
; For performance, it stores the schedule in reverse order, and is reversed at the end
216189

217-
(for ([(tag schedule-params) (in-dict tag-schedule)])
190+
(for ([tag (in-list tag-schedule)])
218191
(match tag
219-
['lifting
220-
(send-to-egglog (list '(run-schedule (saturate lifting)))
192+
['lift
193+
(send-to-egglog (list '(run-schedule (saturate lift)))
221194
egglog-process
222195
egglog-output
223196
egglog-in
224197
err
225198
dump-file)]
226199

227-
['lowering
228-
(send-to-egglog (list '(run-schedule (saturate lowering)))
200+
['lower
201+
(send-to-egglog (list '(run-schedule (saturate lower)))
229202
egglog-process
230203
egglog-output
231204
egglog-in
232205
err
233206
dump-file)]
234207

235-
[_
236-
;; Run the current ruleset tag interleaved with const-fold until the best iteration
208+
['unsound
209+
(send-to-egglog (list '(run-schedule (saturate unsound)))
210+
egglog-process
211+
egglog-output
212+
egglog-in
213+
err
214+
dump-file)]
215+
216+
['rewrite
217+
;; Run the rewrite ruleset interleaved with const-fold until the best iteration
237218
(egglog-unsound-detected-subprocess tag
238-
schedule-params
239219
egglog-process
240220
egglog-output
241221
egglog-in
@@ -343,25 +323,27 @@
343323
,@(platform-impl-nodes pform min-cost)))
344324
(egglog-program-add! typed-graph curr-program)
345325

346-
(egglog-program-add! `(constructor lower (M String) MTy :unextractable) curr-program)
326+
(egglog-program-add! `(constructor do-lower (M String) MTy :unextractable) curr-program)
347327

348-
(egglog-program-add! `(constructor lift (MTy) M :unextractable) curr-program)
328+
(egglog-program-add! `(constructor do-lift (MTy) M :unextractable) curr-program)
349329

350330
(egglog-program-add! `(ruleset const-fold) curr-program)
351331

352-
(egglog-program-add! `(ruleset lowering) curr-program)
332+
(egglog-program-add! `(ruleset lower) curr-program)
333+
334+
(egglog-program-add! `(ruleset lift) curr-program)
353335

354-
(egglog-program-add! `(ruleset lifting) curr-program)
336+
(egglog-program-add! `(ruleset unsound) curr-program)
355337

356-
;;; Adding function unsound before rules
338+
;;; Adding bad-merge detection
357339

358-
;; unsound functions
359-
(egglog-program-add! `(function unsound () bool :merge (or old new)) curr-program)
360-
(egglog-program-add! `(ruleset unsound-rule) curr-program)
361-
(egglog-program-add! `(set (unsound) false) curr-program)
340+
;; bad-merge detection function and rules
341+
(egglog-program-add! `(function bad-merge? () bool :merge (or old new)) curr-program)
342+
(egglog-program-add! `(ruleset bad-merge-rule) curr-program)
343+
(egglog-program-add! `(set (bad-merge?) false) curr-program)
362344

363345
(egglog-program-add!
364-
`(rule ((= (Num c1) (Num c2)) (!= c1 c2)) ((set (unsound) true)) :ruleset unsound-rule)
346+
`(rule ((= (Num c1) (Num c2)) (!= c1 c2)) ((set (bad-merge?) true)) :ruleset bad-merge-rule)
365347
curr-program)
366348

367349
(for ([curr-expr const-fold])
@@ -487,9 +469,9 @@
487469
(let etx (,(typed-num-id repr)
488470
n)
489471
)
490-
(union (lower e tx) etx))
472+
(union (do-lower e tx) etx))
491473
:ruleset
492-
lowering)))
474+
lower)))
493475

494476
(define (num-lifting-rules)
495477
(for/list ([repr (in-list (all-repr-names))]
@@ -498,12 +480,12 @@
498480
((let se (Num
499481
n)
500482
)
501-
(union (lift e) se))
483+
(union (do-lift e) se))
502484
:ruleset
503-
lifting)))
485+
lift)))
504486

505487
(define (approx-lifting-rule)
506-
`(rule ((= e (Approx spec impl))) ((union (lift e) spec)) :ruleset lifting))
488+
`(rule ((= e (Approx spec impl))) ((union (do-lift e) spec)) :ruleset lift))
507489

508490
(define (impl-lowering-rules pform)
509491
(for/list ([impl (in-list (platform-impls pform))])
@@ -513,16 +495,16 @@
513495
,@(for/list ([v (in-list (impl-info impl 'vars))]
514496
[vt (in-list (impl-info impl 'itype))])
515497
`(= ,(string->symbol (string-append "t" (symbol->string v)))
516-
(lower ,v ,(symbol->string (representation-name vt))))))
498+
(do-lower ,v ,(symbol->string (representation-name vt))))))
517499
((let t0 ,(symbol->string (representation-name (impl-info impl 'otype)))
518500
)
519501
(let et0 (,(string->symbol (string-append (symbol->string (serialize-impl impl)) "Ty"))
520502
,@(for/list ([v (in-list (impl-info impl 'vars))])
521503
(string->symbol (string-append "t" (symbol->string v)))))
522504
)
523-
(union (lower e t0) et0))
505+
(union (do-lower e t0) et0))
524506
:ruleset
525-
lowering)))
507+
lower)))
526508

527509
(define (impl-lifting-rules pform)
528510
(for/list ([impl (in-list (platform-impls pform))])
@@ -534,12 +516,12 @@
534516
,@(impl-info impl 'vars)))
535517
,@(for/list ([v (in-list (impl-info impl 'vars))]
536518
[vt (in-list (impl-info impl 'itype))])
537-
`(= ,(string->symbol (string-append "s" (symbol->string v))) (lift ,v))))
519+
`(= ,(string->symbol (string-append "s" (symbol->string v))) (do-lift ,v))))
538520
((let se ,(expr->egglog-spec-serialized spec-expr "s")
539521
)
540-
(union (lift e) se))
522+
(union (do-lift e) se))
541523
:ruleset
542-
lifting)))
524+
lift)))
543525

544526
(define (expr->egglog-spec-serialized expr s)
545527
(let loop ([expr expr])
@@ -678,7 +660,7 @@
678660
,@(for/list ([arg (in-list args)])
679661
(remap arg (spec? (batchref batch n)))))]
680662

681-
[(hole ty spec) `(lower ,(remap spec #t) ,(symbol->string ty))]))
663+
[(hole ty spec) `(do-lower ,(remap spec #t) ,(symbol->string ty))]))
682664

683665
(if node*
684666
(vector-set! mappings n (insert-node! node* n root?))
@@ -697,9 +679,9 @@
697679
(let ety (,(typed-var-id (representation-name repr))
698680
,(symbol->string var))
699681
)
700-
(union (lower e ty) ety))
682+
(union (do-lower e ty) ety))
701683
:ruleset
702-
lowering))
684+
lower))
703685

704686
(egglog-program-add! curr-var-lowering-rule curr-program))
705687

@@ -712,9 +694,9 @@
712694
((let se (Var
713695
,(symbol->string var))
714696
)
715-
(union (lift e) se))
697+
(union (do-lift e) se))
716698
:ruleset
717-
lifting))
699+
lift))
718700

719701
(egglog-program-add! curr-var-lifting-rule curr-program))
720702

@@ -780,8 +762,8 @@
780762

781763
(define curr-datatype
782764
(match actual-binding
783-
[(cons 'lower _) 'MTy]
784-
[(cons 'lift _) 'M]
765+
[(cons 'do-lower _) 'MTy]
766+
[(cons 'do-lift _) 'M]
785767

786768
;; TODO : fix this way of getting spec or impl
787769
[_ (if root? 'MTy 'M)]))
@@ -810,23 +792,17 @@
810792

811793
(values (reverse all-bindings) curr-bindings))
812794

813-
(define (egglog-unsound-detected-subprocess tag
814-
params
815-
egglog-process
816-
egglog-output
817-
egglog-in
818-
err
819-
dump-file)
795+
(define (egglog-unsound-detected-subprocess tag egglog-process egglog-output egglog-in err dump-file)
820796

821-
(define node-limit (dict-ref params 'node (*node-limit*)))
822-
(define iter-limit (dict-ref params 'iteration (*default-egglog-iter-limit*)))
797+
(define node-limit (*node-limit*))
798+
(define iter-limit (*default-egglog-iter-limit*))
823799

824800
;; Algorithm:
825801
;; 1. Run (PUSH) to the save the above state of the egraph
826802
;; 2. Repeat rules based on their ruleset tag once
827-
;; 3. Run the unsound-rule function ruleset once
828-
;; 4. Extract the (unsound) function that returns a bool
829-
;; 5. If (unsound) function returns "true", we have unsoundless, so go to Step 10 for ROLLBACK
803+
;; 3. Run the bad-merge-rule ruleset once
804+
;; 4. Extract the (bad-merge?) function that returns a bool
805+
;; 5. If (bad-merge?) function returns "true", we have a bad merge, so go to Step 10 for ROLLBACK
830806
;; 6. Run (print-size) to get nodes of the form "node_name : num_nodes" for all nodes in egraph
831807
;; 7. If the total number of nodes is more than node-limit, do NOT ROLLBACK and go to Step 11
832808
;; 8. Repeat rules based on the const-fold tag once and repeat Steps 3-7
@@ -852,8 +828,8 @@
852828
(list '(push)
853829
`(run-schedule (repeat 1 ,tag))
854830
'(print-size)
855-
'(run unsound-rule 1)
856-
'(extract (unsound))))
831+
'(run bad-merge-rule 1)
832+
'(extract (bad-merge?))))
857833

858834
;; Get egglog output
859835
(define-values (math-unsound? math-node-limit? math-total-nodes)
@@ -912,8 +888,8 @@
912888
(list '(push)
913889
`(run-schedule (repeat 1 const-fold))
914890
'(print-size)
915-
'(run unsound-rule 1)
916-
'(extract (unsound))))
891+
'(run bad-merge-rule 1)
892+
'(extract (bad-merge?))))
917893

918894
(define-values (const-unsound? const-node-limit? const-total-nodes)
919895
(get-egglog-output const-schedule

0 commit comments

Comments
 (0)