Skip to content

Commit 413bec8

Browse files
authored
Merge pull request #989 from herbie-fp/dump-egraph
Add code to dump egraphs in `egraph-serialize` format
2 parents 8171fb9 + 029b8ef commit 413bec8

File tree

2 files changed

+60
-3
lines changed

2 files changed

+60
-3
lines changed

src/config.rkt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
numerics
2020
special
2121
bools
22-
branches)]))
22+
branches)]
23+
[dump . (egg)]))
2324

2425
(define default-flags
2526
#hash([precision . ()]
@@ -36,7 +37,8 @@
3637
numerics
3738
special
3839
bools
39-
branches)]))
40+
branches)]
41+
[dump . ()]))
4042

4143
(define (check-flag-deprecated! category flag)
4244
(match* (category flag)

src/core/egg-herbie.rkt

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
u32vector-set!
88
u32vector-ref
99
list->u32vector
10-
u32vector->list))
10+
u32vector->list)
11+
json) ; for dumping
1112

1213
(require "programs.rkt"
1314
"rules.rkt"
@@ -879,6 +880,36 @@
879880
; construct the `regraph` instance
880881
(regraph eclasses types leaf? constants specs parents canon egg->herbie))
881882

883+
(define (regraph-nodes->json regraph)
884+
(define cost (platform-node-cost-proc (*active-platform*)))
885+
(for/hash ([n (in-naturals)]
886+
[eclass (in-vector (regraph-eclasses regraph))]
887+
#:when true
888+
[k (in-naturals)]
889+
[enode eclass])
890+
(define type (vector-ref (regraph-types regraph) n))
891+
(define cost
892+
(if (representation? type)
893+
(match enode
894+
[(? number?) (platform-repr-cost (*active-platform*) type)]
895+
[(? symbol?) (platform-repr-cost (*active-platform*) type)]
896+
[(list '$approx x y) 0]
897+
[(list 'if c x y)
898+
(match (platform-impl-cost (*active-platform*) 'if)
899+
[`(max ,n) n] ; Not quite right
900+
[`(sum ,n) n])]
901+
[(list op args ...) (platform-impl-cost (*active-platform*) op)])
902+
1))
903+
(values (string->symbol (format "~a.~a" n k))
904+
(hash 'op
905+
(~a (if (list? enode) (car enode) enode))
906+
'children
907+
(if (list? enode) (map ~a (cdr enode)) '())
908+
'eclass
909+
(~a n)
910+
'cost
911+
cost))))
912+
882913
;; Egraph node has children.
883914
;; Nullary operators have no children!
884915
(define (node-has-children? node)
@@ -1322,6 +1353,26 @@
13221353
; make the runner
13231354
(egg-runner batch roots reprs schedule ctx))
13241355

1356+
(define (regraph-dump regraph root-ids reprs)
1357+
(define dump-dir "dump-egg")
1358+
(unless (directory-exists? dump-dir)
1359+
(make-directory dump-dir))
1360+
(define name
1361+
(for/first ([i (in-naturals)]
1362+
#:unless (file-exists? (build-path dump-dir (format "~a.json" i))))
1363+
(build-path dump-dir (format "~a.json" i))))
1364+
(define nodes (regraph-nodes->json regraph))
1365+
(define canon (regraph-canon regraph))
1366+
(define roots
1367+
(filter values
1368+
(for/list ([id (in-list root-ids)]
1369+
[type (in-list reprs)])
1370+
(hash-ref canon (cons id type) #f))))
1371+
(call-with-output-file
1372+
name
1373+
#:exists 'replace
1374+
(lambda (p) (write-json (hash 'nodes nodes 'root_eclasses (map ~a roots) 'class_data (hash)) p))))
1375+
13251376
;; Runs egg using an egg runner.
13261377
;;
13271378
;; Argument `cmd` specifies what to get from the e-graph:
@@ -1342,13 +1393,17 @@
13421393
(define regraph (make-regraph egg-graph))
13431394
(define extract-id (extractor regraph))
13441395
(define reprs (egg-runner-reprs runner))
1396+
(when (flag-set? 'dump 'egg)
1397+
(regraph-dump regraph root-ids reprs))
13451398
(for/list ([id (in-list root-ids)]
13461399
[repr (in-list reprs)])
13471400
(regraph-extract-best regraph extract-id id repr))]
13481401
[`(multi . ,extractor) ; multi expression extraction
13491402
(define regraph (make-regraph egg-graph))
13501403
(define extract-id (extractor regraph))
13511404
(define reprs (egg-runner-reprs runner))
1405+
(when (flag-set? 'dump 'egg)
1406+
(regraph-dump regraph root-ids reprs))
13521407

13531408
; List of roots inside the batch
13541409
(for/list ([id (in-list root-ids)]

0 commit comments

Comments
 (0)