Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 24 additions & 39 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -1217,10 +1217,11 @@

;; Runs rules over the egraph with the given egg parameters.
;; Invariant: the returned egraph is never unsound
(define (egraph-run-rules egg-graph0 egg-rules params)
(define node-limit (dict-ref params 'node #f))
(define iter-limit (dict-ref params 'iteration #f))
(define scheduler (dict-ref params 'scheduler 'backoff))
(define (egraph-run-rules egg-graph0
egg-rules
#:node-limit [node-limit #f]
#:iter-limit [iter-limit #f]
#:scheduler [scheduler 'backoff])
(define ffi-rules (map cdr egg-rules))

;; run the rules
Expand All @@ -1247,14 +1248,18 @@

; run the schedule
(define egg-graph*
(for/fold ([egg-graph egg-graph]) ([(rules params) (in-dict schedule)])
; run rules in the egraph
(define egg-rules
(expand-rules (match rules
[`lift (platform-lifting-rules)]
[`lower (platform-lowering-rules)]
[else rules])))
(define-values (egg-graph* iteration-data) (egraph-run-rules egg-graph egg-rules params))
(for/fold ([egg-graph egg-graph]) ([step (in-list schedule)])
(define-values (egg-graph* iteration-data)
(match step
['lift
(define rules (expand-rules (platform-lifting-rules)))
(egraph-run-rules egg-graph rules #:iter-limit 1 #:scheduler 'simple)]
['lower
(define rules (expand-rules (platform-lowering-rules)))
(egraph-run-rules egg-graph rules #:iter-limit 1 #:scheduler 'simple)]
['rewrite
(define rules (expand-rules (*rules*)))
(egraph-run-rules egg-graph rules #:node-limit (*node-limit*))]))

; get cost statistics
(for ([iter (in-list iteration-data)]
Expand Down Expand Up @@ -1290,37 +1295,17 @@

;; Constructs an egg runner.
;;
;; The schedule is a list of pairs specifying
;; - a list of rules
;; - scheduling parameters:
;; - node limit: `(node . <number>)`
;; - iteration limit: `(iteration . <number>)`
;; - scheduler: `(scheduler . <name>)` [default: backoff]
;; - `simple`: run all rules without banning
;; - `backoff`: ban rules if the fire too much
;; The schedule is a list of step symbols:
;; - `lift`: run lifting rules for 1 iteration with simple scheduler
;; - `rewrite`: run rewrite rules up to node limit with backoff scheduler
;; - `lower`: run lowering rules for 1 iteration with simple scheduler
(define (make-egraph batch brfs reprs schedule ctx)
(define (oops! fmt . args)
(apply error 'verify-schedule! fmt args))
; verify the schedule
(for ([instr (in-list schedule)])
(match instr
[(cons rules params)
;; `run` instruction

(unless (or (equal? `lift rules)
(equal? `lower rules)
(and (list? rules) (andmap rule? rules)))
(oops! "expected list of rules: `~a`" rules))

(for ([param (in-list params)])
(match param
[(cons 'node (? nonnegative-integer?)) (void)]
[(cons 'iteration (? nonnegative-integer?)) (void)]
[(cons 'scheduler mode)
(unless (set-member? '(simple backoff) mode)
(oops! "in instruction `~a`, unknown scheduler `~a`" instr mode))]
[_ (oops! "in instruction `~a`, unknown parameter `~a`" instr param)]))]
[_ (oops! "expected `(<rules> . <params>)`, got `~a`" instr)]))
(for ([step (in-list schedule)])
(unless (memq step '(lift lower rewrite))
(oops! "unknown schedule step `~a`" step)))

(define-values (root-ids egg-graph) (egraph-run-schedule batch brfs schedule ctx))

Expand Down
7 changes: 2 additions & 5 deletions src/core/patch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
approxs*)

(define (run-lowering altns global-batch)
(define schedule `((lower . ((iteration . 1) (scheduler . simple)))))
(define schedule '(lower))

; run egg
(define brfs (map alt-expr altns))
Expand Down Expand Up @@ -161,10 +161,7 @@
(define rules (*rules*))

; egg schedule (3-phases for mathematical rewrites and implementation selection)
(define schedule
(list `(lift . ((iteration . 1) (scheduler . simple)))
`(,rules . ((node . ,(*node-limit*))))
`(lower . ((iteration . 1) (scheduler . simple)))))
(define schedule '(lift rewrite lower))

(define brfs (map alt-expr altns))
(define reprs (map (batch-reprs global-batch (*context*)) brfs))
Expand Down
7 changes: 1 addition & 6 deletions src/core/preprocess.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,7 @@

;; make egg runner
(define-values (batch brfs) (progs->batch (cons spec (map cdr identities))))
(define runner
(make-egraph batch
brfs
(make-list (length brfs) (context-repr ctx))
`((,(*rules*) . ((node . ,(*node-limit*)))))
ctx))
(define runner (make-egraph batch brfs (make-list (length brfs) (context-repr ctx)) '(rewrite) ctx))

;; collect equalities
(for/list ([(ident spec*) (in-dict identities)]
Expand Down