Skip to content

Commit caeb09b

Browse files
committed
Simplify scheduling with pre-set schedule steps
1 parent 4dceae1 commit caeb09b

File tree

3 files changed

+27
-50
lines changed

3 files changed

+27
-50
lines changed

src/core/egg-herbie.rkt

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,10 +1217,11 @@
12171217

12181218
;; Runs rules over the egraph with the given egg parameters.
12191219
;; Invariant: the returned egraph is never unsound
1220-
(define (egraph-run-rules egg-graph0 egg-rules params)
1221-
(define node-limit (dict-ref params 'node #f))
1222-
(define iter-limit (dict-ref params 'iteration #f))
1223-
(define scheduler (dict-ref params 'scheduler 'backoff))
1220+
(define (egraph-run-rules egg-graph0
1221+
egg-rules
1222+
#:node-limit [node-limit #f]
1223+
#:iter-limit [iter-limit #f]
1224+
#:scheduler [scheduler 'backoff])
12241225
(define ffi-rules (map cdr egg-rules))
12251226

12261227
;; run the rules
@@ -1247,14 +1248,18 @@
12471248

12481249
; run the schedule
12491250
(define egg-graph*
1250-
(for/fold ([egg-graph egg-graph]) ([(rules params) (in-dict schedule)])
1251-
; run rules in the egraph
1252-
(define egg-rules
1253-
(expand-rules (match rules
1254-
[`lift (platform-lifting-rules)]
1255-
[`lower (platform-lowering-rules)]
1256-
[else rules])))
1257-
(define-values (egg-graph* iteration-data) (egraph-run-rules egg-graph egg-rules params))
1251+
(for/fold ([egg-graph egg-graph]) ([step (in-list schedule)])
1252+
(define-values (egg-graph* iteration-data)
1253+
(match step
1254+
['lift
1255+
(define rules (expand-rules (platform-lifting-rules)))
1256+
(egraph-run-rules egg-graph rules #:iter-limit 1 #:scheduler 'simple)]
1257+
['lower
1258+
(define rules (expand-rules (platform-lowering-rules)))
1259+
(egraph-run-rules egg-graph rules #:iter-limit 1 #:scheduler 'simple)]
1260+
['rewrite
1261+
(define rules (expand-rules (*rules*)))
1262+
(egraph-run-rules egg-graph rules #:node-limit (*node-limit*))]))
12581263

12591264
; get cost statistics
12601265
(for ([iter (in-list iteration-data)]
@@ -1290,37 +1295,17 @@
12901295

12911296
;; Constructs an egg runner.
12921297
;;
1293-
;; The schedule is a list of pairs specifying
1294-
;; - a list of rules
1295-
;; - scheduling parameters:
1296-
;; - node limit: `(node . <number>)`
1297-
;; - iteration limit: `(iteration . <number>)`
1298-
;; - scheduler: `(scheduler . <name>)` [default: backoff]
1299-
;; - `simple`: run all rules without banning
1300-
;; - `backoff`: ban rules if the fire too much
1298+
;; The schedule is a list of step symbols:
1299+
;; - `lift`: run lifting rules for 1 iteration with simple scheduler
1300+
;; - `rewrite`: run rewrite rules up to node limit with backoff scheduler
1301+
;; - `lower`: run lowering rules for 1 iteration with simple scheduler
13011302
(define (make-egraph batch brfs reprs schedule ctx)
13021303
(define (oops! fmt . args)
13031304
(apply error 'verify-schedule! fmt args))
13041305
; verify the schedule
1305-
(for ([instr (in-list schedule)])
1306-
(match instr
1307-
[(cons rules params)
1308-
;; `run` instruction
1309-
1310-
(unless (or (equal? `lift rules)
1311-
(equal? `lower rules)
1312-
(and (list? rules) (andmap rule? rules)))
1313-
(oops! "expected list of rules: `~a`" rules))
1314-
1315-
(for ([param (in-list params)])
1316-
(match param
1317-
[(cons 'node (? nonnegative-integer?)) (void)]
1318-
[(cons 'iteration (? nonnegative-integer?)) (void)]
1319-
[(cons 'scheduler mode)
1320-
(unless (set-member? '(simple backoff) mode)
1321-
(oops! "in instruction `~a`, unknown scheduler `~a`" instr mode))]
1322-
[_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))]
1323-
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
1306+
(for ([step (in-list schedule)])
1307+
(unless (memq step '(lift lower rewrite))
1308+
(oops! "unknown schedule step `~a`" step)))
13241309

13251310
(define-values (root-ids egg-graph) (egraph-run-schedule batch brfs schedule ctx))
13261311

src/core/patch.rkt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
approxs*)
9191

9292
(define (run-lowering altns global-batch)
93-
(define schedule `((lower . ((iteration . 1) (scheduler . simple)))))
93+
(define schedule '(lower))
9494

9595
; run egg
9696
(define brfs (map alt-expr altns))
@@ -161,10 +161,7 @@
161161
(define rules (*rules*))
162162

163163
; egg schedule (3-phases for mathematical rewrites and implementation selection)
164-
(define schedule
165-
(list `(lift . ((iteration . 1) (scheduler . simple)))
166-
`(,rules . ((node . ,(*node-limit*))))
167-
`(lower . ((iteration . 1) (scheduler . simple)))))
164+
(define schedule '(lift rewrite lower))
168165

169166
(define brfs (map alt-expr altns))
170167
(define reprs (map (batch-reprs global-batch (*context*)) brfs))

src/core/preprocess.rkt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,7 @@
6969

7070
;; make egg runner
7171
(define-values (batch brfs) (progs->batch (cons spec (map cdr identities))))
72-
(define runner
73-
(make-egraph batch
74-
brfs
75-
(make-list (length brfs) (context-repr ctx))
76-
`((,(*rules*) . ((node . ,(*node-limit*)))))
77-
ctx))
72+
(define runner (make-egraph batch brfs (make-list (length brfs) (context-repr ctx)) '(rewrite) ctx))
7873

7974
;; collect equalities
8075
(for/list ([(ident spec*) (in-dict identities)]

0 commit comments

Comments
 (0)