Skip to content

Commit 3a50b18

Browse files
author
varun10p
committed
Type-checking logic added
1 parent a4efa40 commit 3a50b18

File tree

5 files changed

+130
-70
lines changed

5 files changed

+130
-70
lines changed

src/core/mainloop.rkt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@
368368
(timeline-event! 'simplify)
369369

370370
; egg schedule (only mathematical rewrites)
371-
(define rules (*fp-safe-simplify-rules*))
371+
(define rules (append (*fp-safe-simplify-rules*) (real-rules (*simplify-rules*))))
372372
(define schedule `((,rules . ((node . ,(*node-limit*)) (const-fold? . #f)))))
373373

374374
; egg runner

src/platforms/bool.rkt

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,31 +34,16 @@
3434
#:spec (not x)
3535
#:fpcore (! (not x))
3636
#:fl not
37-
#:identities
38-
([not-true (not (TRUE)) (FALSE)] [not-false (not (FALSE)) (TRUE)]
39-
[not-not (not (not a)) a]
40-
[not-and (not (and a b)) (or (not a) (not b))]
41-
[not-or (not (or a b)) (and (not a) (not b))]
42-
[not-lt (not (< x y)) (>= x y)]
43-
[not-gt (not (> x y)) (<= x y)]
44-
[not-lte (not (<= x y)) (> x y)]
45-
[not-gte (not (>= x y)) (< x y)]))
37+
#:identities (#:exact (not a)))
4638

4739
(define-operator-impl (and [x : bool] [y : bool])
4840
bool
4941
#:spec (and x y)
5042
#:fl and-fn
51-
#:identities
52-
([and-true-l (and (TRUE) a) a] [and-true-r (and a (TRUE)) a]
53-
[and-false-l (and (FALSE) a) (FALSE)]
54-
[and-false-r (and a (FALSE)) (FALSE)]
55-
[and-same (and a a) a]))
43+
#:identities (#:exact (and a b)))
5644

5745
(define-operator-impl (or [x : bool] [y : bool])
5846
bool
5947
#:spec (or x y)
6048
#:fl or-fn
61-
#:identities ([or-true-l (or (TRUE) a) (TRUE)] [or-true-r (or a (TRUE)) (TRUE)]
62-
[or-false-l (or (FALSE) a) a]
63-
[or-false-r (or a (FALSE)) a]
64-
[or-same (or a a) a]))
49+
#:identities (#: exact (or a b)))

src/syntax/matcher.rkt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
#lang racket
44

5-
(provide pattern-match
5+
(provide merge-bindings
6+
pattern-match
67
pattern-substitute)
78

89
;; Unions two bindings. Returns #f if they disagree.

src/syntax/platform.rkt

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"../core/programs.rkt"
66
"../core/rules.rkt"
77
"matcher.rkt"
8+
"sugar.rkt"
89
"syntax.rkt"
910
"types.rkt")
1011

@@ -549,18 +550,86 @@
549550
(string-join (map (lambda (subst) (~a (cdr subst))) isubst) "-"))))
550551
(sow (rule name* input* output* itypes* repr)))))]))))
551552

553+
(define (expr-otype expr)
554+
(match expr
555+
[(? literal?) #f]
556+
[(? variable?) #f]
557+
[(list 'if cond ift iff) (expr-otype ift)]
558+
[(list op args ...) (impl-info op 'otype)]))
559+
560+
(define (type-verify expr otype)
561+
(match expr
562+
[(? literal?) '()]
563+
[(? variable?) '((cons expr otype))]
564+
[(list 'if cond ift iff)
565+
(define bool-repr (get-representation 'bool))
566+
(define combined
567+
(merge-bindings (type-verify cond bool-repr)
568+
(merge-bindings (type-verify ift otype) (type-verify iff otype))))
569+
(unless combined
570+
(error 'type-verify "Variable types do not match in ~a" expr))
571+
combined]
572+
[(list op args ...)
573+
(define op-otype (impl-info op 'otype))
574+
(when (not (equal? op-otype otype))
575+
(error 'type-verify "Operator ~a has type ~a, expected ~a" op op-otype otype))
576+
(define bindings '())
577+
(for ([arg (in-list args)]
578+
[itype (in-list (impl-info op 'itype))])
579+
(define combined (merge-bindings bindings (type-verify arg itype)))
580+
(unless combined
581+
(error 'type-verify "Variable types do not match in ~a" expr))
582+
(set! bindings combined))
583+
bindings]))
584+
585+
(define (expr->prog expr repr)
586+
(match expr
587+
[(? literal?) (literal (get-representation repr) expr)]
588+
[(? variable?) expr]
589+
[`(if ,cond ,ift ,iff)
590+
`(if ,(expr->prog cond repr) ,(expr->prog ift repr) ,(expr->prog iff repr))]
591+
[`(,impl ,args ...) `(impl ,@(map (λ (arg) (expr->prog arg (impl-info impl 'itype))) args))]))
592+
552593
(define (*fp-safe-simplify-rules*)
553594
(reap [sow]
554595
(for ([impl (in-list (platform-impls (*active-platform*)))])
555596
(define rules (impl-info impl 'identities))
556-
(for ([name (in-hash-keys rules)])
557-
(match-define (list input output vars) (hash-ref rules name))
558-
(define itype (car (impl-info impl 'itype)))
559-
(define r
560-
(rule name
561-
input
562-
output
563-
(for/hash ([v (in-list vars)])
564-
(values v itype))
565-
(impl-info impl 'otype)))
566-
(sow r)))))
597+
(for ([identity (in-list rules)])
598+
(match identity
599+
[(list 'exact name expr)
600+
(when (not (expr-otype expr))
601+
(error "Exact identity expr cannot infer type"))
602+
(define otype (expr-otype expr))
603+
(define var-types (type-verify expr otype))
604+
(define prog (expr->prog expr otype))
605+
(define r
606+
(rule name
607+
prog
608+
(prog->spec prog)
609+
(for/hash ([binding (in-list var-types)])
610+
(values (car binding) (cdr binding)))
611+
(impl-info impl 'otype)))
612+
(sow r)]
613+
[(list 'commutes name expr rev-expr)
614+
(define vars (impl-info impl 'vars))
615+
(define itype (car (impl-info impl 'itype)))
616+
(define r
617+
(rule name
618+
(expr->prog expr)
619+
(expr->prog rev-expr)
620+
(for/hash ([v (in-list vars)])
621+
(values v itype))
622+
(impl-info impl 'otype))) ; Commutes by definition the types are matching
623+
(sow r)]
624+
[(list 'directed name lhs rhs)
625+
(define lotype (expr-otype lhs))
626+
(define rotype (expr-otype rhs))
627+
(define var-types (merge-bindings (type-verify lhs lotype) (type-verify rhs rotype)))
628+
(define r
629+
(rule name
630+
(expr->prog lhs)
631+
(expr->prog rhs)
632+
(for/hash ([binding (in-list var-types)])
633+
(values (car binding) (cdr binding)))
634+
(impl-info impl 'otype)))
635+
(sow r)])))))

src/syntax/syntax.rkt

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -389,52 +389,57 @@
389389
name)]))
390390

391391
; make hash table
392-
(define rules (make-hasheq))
393-
(define count 0)
392+
(define rules '())
393+
(define rule-names (make-hasheq))
394394
(define commutes? #f)
395395
(when identities
396-
(for ([ident (in-list identities)])
397-
(match ident
398-
[(list ident-name lhs-expr rhs-expr)
399-
(cond
400-
[(hash-has-key? rules ident-name)
401-
(raise-herbie-syntax-error "Duplicate identity ~a" ident-name)]
402-
[else
403-
(hash-set! rules
404-
(string->symbol (format "~a-~a" (symbol->string ident-name) name))
405-
(list lhs-expr
406-
rhs-expr
407-
(remove-duplicates (append (free-variables lhs-expr)
408-
(free-variables rhs-expr)))))])]
409-
[(list 'exact expr)
410-
(hash-set! rules
411-
(gensym (string->symbol (format "~a-exact-~a" name count)))
412-
(list expr expr (free-variables expr)))
413-
(set! count (+ count 1))]
414-
[(list 'commutes)
415-
(cond
416-
[commutes? (raise-herbie-syntax-error "Commutes identity already defined")]
417-
[(hash-has-key? rules (string->symbol (format "~a-commutes" name)))
418-
(raise-herbie-syntax-error "Commutes identity already manually defined")]
419-
[(not (equal? (length vars) 2))
420-
(raise-herbie-syntax-error "Cannot commute a non 2-ary operator")]
421-
[else
422-
(set! commutes? #t)
423-
(hash-set! rules
424-
(string->symbol (format "~a-commutes" name))
425-
(list `(,name ,@vars) `(,name ,@(reverse vars)) vars))])])))
396+
(set! rules
397+
(for/list ([ident (in-list identities)]
398+
[i (in-naturals)])
399+
(match ident
400+
[(list ident-name lhs-expr rhs-expr)
401+
(cond
402+
[(hash-has-key? rule-names ident-name)
403+
(raise-herbie-syntax-error "Duplicate identity ~a" ident-name)]
404+
[(not (well-formed? lhs-expr))
405+
(raise-herbie-syntax-error "Ill-formed identity expression ~a" lhs-expr)]
406+
[(not (well-formed? rhs-expr))
407+
(raise-herbie-syntax-error "Ill-formed identity expression ~a" rhs-expr)]
408+
[else
409+
(define rule-name (string->symbol (format "~a-~a" ident-name name)))
410+
(hash-set! rule-names rule-name #f)
411+
(list 'directed rule-name lhs-expr rhs-expr)])]
412+
[(list 'exact expr)
413+
(cond
414+
[(not (well-formed? expr))
415+
(raise-herbie-syntax-error "Ill-formed identity expression ~a" expr)]
416+
[else
417+
(define rule-name (gensym (string->symbol (format "~a-exact-~a" name i))))
418+
(hash-set! rule-names rule-name #f)
419+
(list 'exact rule-name expr)])]
420+
[(list 'commutes)
421+
(cond
422+
[commutes? (error "Commutes identity already defined")]
423+
[(hash-has-key? rule-names (string->symbol (format "~a-commutes" name)))
424+
(error "Commutes identity already manually defined")]
425+
[(not (equal? (length vars) 2))
426+
(raise-herbie-syntax-error "Cannot commute a non 2-ary operator")]
427+
[else
428+
(set! commutes? #t)
429+
(define rule-name (string->symbol (format "~a-commutes" name)))
430+
(hash-set! rule-names rule-name #f)
431+
(list 'commutes rule-name `(,name ,@vars) `(,name ,@(reverse vars)))])]))))
426432

427433
; update tables
428434
(define impl (operator-impl name ctx spec fpcore* fl-proc* rules))
429435
(hash-set! operator-impls name impl))
430436

431-
(define (free-variables prog)
432-
(match prog
433-
[(? literal?) '()]
434-
[(? number?) '()]
435-
[(? variable?) (list prog)]
436-
[(approx _ impl) (free-variables impl)]
437-
[(list _ args ...) (remove-duplicates (append-map free-variables args))]))
437+
(define (well-formed? expr)
438+
(match expr
439+
[(? number?) #t]
440+
[(? variable?) #t]
441+
[`(,impl ,args ...) (andmap well-formed? args)]
442+
[_ #f]))
438443

439444
(define-syntax (define-operator-impl stx)
440445
(define (oops! why [sub-stx #f])

0 commit comments

Comments
 (0)