Skip to content

Commit 455c599

Browse files
authored
Merge pull request #1127 from herbie-fp/simplify-patch
Add `(hole repr spec)` terms to simplify Taylor expansion
2 parents 3d631b2 + 9395a1c commit 455c599

File tree

7 files changed

+39
-55
lines changed

7 files changed

+39
-55
lines changed

src/core/batch.rkt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
(define (expr-recurse expr f)
2727
(match expr
2828
[(approx spec impl) (approx (f spec) (f impl))]
29+
[(hole precision spec) (hole precision (f spec))]
2930
[(list op args ...) (cons op (map f args))]
3031
[_ expr]))
3132

src/core/derivations.rkt

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,35 +35,13 @@
3535
(define (build! altn)
3636
(match altn
3737
; recursive rewrite using egg (impl -> impl)
38-
[(alt expr `(rr ,loc ,(? egg-runner? runner) #f) `(,prev) _)
38+
[(alt expr `(,(or 'rr 'simplify) ,loc ,(? egg-runner? runner) #f) `(,prev) _)
3939
(define start-expr (location-get loc (alt-expr prev)))
4040
(define end-expr (location-get loc expr))
4141
(define rewrite (cons start-expr end-expr))
4242
(hash-set! alt->query&rws (altn->key altn) (cons runner rewrite))
4343
(hash-update! query->rws runner (lambda (rws) (set-add rws rewrite)) '())]
4444

45-
; simplify using egg
46-
; usually: impl -> impl
47-
; taylor: spec -> approx (_, impl)
48-
[(alt expr `(simplify ,loc ,(? egg-runner? runner) #f) `(,prev) _)
49-
(define rewrite
50-
(match (alt-event prev)
51-
[(list 'taylor _ ...)
52-
; simplify after taylor: spec -> approx (_, impl)
53-
(define start-expr (location-get loc (alt-expr prev)))
54-
(define end-expr (location-get loc expr))
55-
(unless (approx? end-expr)
56-
(error 'make-proof-tables "expected approx node, got ~a" end-expr))
57-
(cons start-expr (approx-impl end-expr))]
58-
[_
59-
; simplify after other: impl -> impl
60-
(define start-expr (location-get loc (alt-expr prev)))
61-
(define end-expr (location-get loc expr))
62-
(cons start-expr end-expr)]))
63-
64-
(hash-set! alt->query&rws (altn->key altn) (cons runner rewrite))
65-
(hash-update! query->rws runner (lambda (rws) (set-add rws rewrite)) '())]
66-
6745
; everything else
6846
[_ (void)])
6947

src/core/egg-herbie.rkt

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,23 @@
132132
(for ([node (in-vector (batch-nodes insert-batch))]
133133
[root? (in-vector root-mask)]
134134
[n (in-naturals)])
135-
(define node*
135+
(define idx
136136
(match node
137-
[(literal v _) v]
138-
[(? number?) node]
139-
[(? symbol?) (normalize-var node)]
137+
[(literal v _) (insert-node! v root?)]
138+
[(? number?) (insert-node! node root?)]
139+
[(? symbol?) (insert-node! (normalize-var node) root?)]
140+
[(hole prec spec) (remap spec)] ; "hole" terms currently disappear
140141
[(approx spec impl)
141142
(hash-ref! id->spec
142143
(remap spec)
143144
(lambda ()
144145
(define spec* (normalize-spec (batch-ref insert-batch spec)))
145146
(define type (representation-type (repr-of-node insert-batch impl ctx)))
146147
(cons spec* type))) ; preserved spec and type for extraction
147-
(list '$approx (remap spec) (remap impl))]
148-
[(list op (app remap args) ...) (cons op args)]))
148+
(insert-node! (list '$approx (remap spec) (remap impl)) root?)]
149+
[(list op (app remap args) ...) (insert-node! (cons op args) root?)]))
149150

150-
(vector-set! mappings n (insert-node! node* root?)))
151+
(vector-set! mappings n idx))
151152

152153
(for/list ([root (in-vector (batch-roots insert-batch))])
153154
(remap root)))
@@ -276,6 +277,7 @@
276277
(hash-set! egg->herbie-dict replacement (cons expr (context-lookup ctx expr)))
277278
replacement))]
278279
[(approx spec impl) (list '$approx (loop spec) (loop impl))]
280+
[(hole precision spec) (loop spec)]
279281
[(list op args ...) (cons op (map loop args))])))
280282

281283
(define (flatten-let expr)

src/core/patch.rkt

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,8 @@
5555
[outputs (in-list simplification-options)])
5656
(match-define (cons _ simplified) outputs)
5757
(define prev (car (alt-prevs altn)))
58-
(for ([batchreff (in-list simplified)])
59-
(define spec (prog->spec (debatchref (alt-expr prev))))
60-
(define idx ; Munge
61-
(mutable-batch-push! global-batch-mutable
62-
(approx (mutable-batch-munge! global-batch-mutable spec)
63-
(batchref-idx batchreff))))
64-
(sow (alt (batchref global-batch idx) `(simplify ,runner #f) (list altn) '()))))
58+
(for ([bref (in-list simplified)])
59+
(sow (alt bref `(simplify ,runner #f) (list altn) '()))))
6560
(batch-copy-mutable-nodes! global-batch global-batch-mutable))) ; Update global-batch
6661

6762
(timeline-push! 'count (length approxs) (length simplified))
@@ -80,25 +75,26 @@
8075
#;(log ,log-x ,exp-x))))
8176

8277
(define (taylor-alts starting-exprs altns global-batch)
83-
(define exprs
84-
(for/list ([expr (in-list starting-exprs)])
85-
(prog->spec expr)))
86-
(define free-vars (map free-variables exprs))
78+
(define specs (map prog->spec starting-exprs))
79+
(define free-vars (map free-variables specs))
8780
(define vars (context-vars (*context*)))
8881

8982
(reap [sow]
9083
(define global-batch-mutable (batch->mutable-batch global-batch)) ; Create a mutable batch
9184
(for* ([var (in-list vars)]
9285
[transform-type transforms-to-try])
9386
(match-define (list name f finv) transform-type)
94-
(define timeline-stop! (timeline-start! 'series (~a exprs) (~a var) (~a name)))
95-
(define genexprs (approximate exprs var #:transform (cons f finv)))
87+
(define timeline-stop! (timeline-start! 'series (~a specs) (~a var) (~a name)))
88+
(define genexprs (approximate specs var #:transform (cons f finv)))
9689
(for ([genexpr (in-list genexprs)]
90+
[spec (in-list specs)]
91+
[expr (in-list starting-exprs)]
9792
[altn (in-list altns)]
9893
[fv (in-list free-vars)]
99-
#:when (member var fv)) ; check whether var exists in expr at all
94+
#:when (set-member? fv var)) ; check whether var exists in expr at all
10095
(for ([i (in-range (*taylor-order-limit*))])
101-
(define gen (genexpr))
96+
(define repr (repr-of expr (*context*)))
97+
(define gen (approx spec (hole (representation-name repr) (genexpr))))
10298
(define idx (mutable-batch-munge! global-batch-mutable gen)) ; Munge gen
10399
(sow (alt (batchref global-batch idx) `(taylor ,name ,var) (list altn) '()))))
104100
(timeline-stop!))

src/core/programs.rkt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
[(literal val precision) (get-representation precision)]
4141
[(? variable?) (context-lookup ctx node)]
4242
[(approx _ impl) (repr-of-node batch impl ctx)]
43+
[(hole precision spec) (get-representation precision)]
4344
[(list 'if cond ift iff) (repr-of-node batch ift ctx)]
4445
[(list op args ...) (impl-info op 'otype)]))
4546

src/reports/history.rkt

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,17 @@
7373
(values (format-accuracy err repr-bits #:unit "%")
7474
(format "~a on training set" (format-accuracy err2 repr-bits #:unit "%"))))
7575

76-
(define (remove-literals expr)
77-
(match expr
78-
[(? symbol?) expr]
79-
[(? number?) expr]
80-
[(? literal?) (literal-value expr)]
81-
[(approx spec impl) (approx (remove-literals spec) (remove-literals impl))]
82-
[(list op args ...) (cons op (map remove-literals args))]))
83-
8476
(define (expr->fpcore expr ctx #:ident [ident #f])
85-
(list 'FPCore (context-vars ctx) (remove-literals expr)))
77+
(list 'FPCore
78+
(context-vars ctx)
79+
(let loop ([expr expr])
80+
(match expr
81+
[(? symbol?) expr]
82+
[(? number?) expr]
83+
[(? literal?) (literal-value expr)]
84+
[(approx spec impl) (loop impl)]
85+
[(hole precision spec) (loop spec)]
86+
[(list op args ...) (cons op (map loop args))]))))
8687

8788
(define (mixed->fpcore expr ctx)
8889
(define expr*
@@ -92,6 +93,7 @@
9293
[(? number?) expr]
9394
[(? literal?) (literal-value expr)]
9495
[(approx _ impl) (loop impl)]
96+
[(hole precision spec) (loop spec)]
9597
[`(if ,cond ,ift ,iff)
9698
`(if ,(loop cond)
9799
,(loop ift)

src/syntax/syntax.rkt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
(provide (struct-out literal)
1212
(struct-out approx)
13+
(struct-out hole)
1314
variable?
1415
constant-operator?
1516
operator-exists?
@@ -493,9 +494,12 @@
493494
(struct literal (value precision) #:prefab)
494495

495496
;; An approximation of a specification by
496-
;; an arbitrary floating-point expression.
497+
;; a floating-point expression.
497498
(struct approx (spec impl) #:prefab)
498499

500+
;; An unknown floating-point expression that implements a given spec
501+
(struct hole (precision spec) #:prefab)
502+
499503
;; name -> (vars repr body) ;; name -> (vars prec body)
500504
(define *functions* (make-parameter (make-hasheq)))
501505

0 commit comments

Comments
 (0)