|
| 1 | +#lang racket |
| 2 | + |
| 3 | +(require "platform.rkt" |
| 4 | + "syntax.rkt" |
| 5 | + "types.rkt" |
| 6 | + "generators.rkt" |
| 7 | + "../utils/errors.rkt" |
| 8 | + "../config.rkt") |
| 9 | + |
| 10 | +(provide define-representation |
| 11 | + define-operation |
| 12 | + define-operations |
| 13 | + fpcore-context |
| 14 | + if-impl |
| 15 | + if-cost |
| 16 | + (rename-out [platform-module-begin #%module-begin]) |
| 17 | + (except-out (all-from-out racket) #%module-begin) |
| 18 | + (all-from-out "platform.rkt") |
| 19 | + (all-from-out "generators.rkt") |
| 20 | + (all-from-out "types.rkt")) |
| 21 | + |
| 22 | +(define platform-being-defined (make-parameter #f)) |
| 23 | + |
| 24 | +;; Specification checking and operator implementation creation moved |
| 25 | +;; from syntax.rkt |
| 26 | +(define (check-spec! name ctx spec) |
| 27 | + (define (bad! fmt . args) |
| 28 | + (error name "~a in `~a`" (apply format fmt args) spec)) |
| 29 | + |
| 30 | + (define (type-error! expr actual-ty expect-ty) |
| 31 | + (bad! "expression `~a` has type `~a`, expected `~a`" expr actual-ty expect-ty)) |
| 32 | + |
| 33 | + (match-define (context vars repr var-reprs) ctx) |
| 34 | + (define itypes (map representation-type var-reprs)) |
| 35 | + (define otype (representation-type repr)) |
| 36 | + |
| 37 | + (unless (= (length itypes) (length vars)) |
| 38 | + (bad! "arity mismatch; expected ~a, got ~a" (length itypes) (length vars))) |
| 39 | + |
| 40 | + (define env (map cons vars itypes)) |
| 41 | + (define actual-ty |
| 42 | + (let type-of ([expr spec]) |
| 43 | + (match expr |
| 44 | + [(? number?) 'real] |
| 45 | + [(? symbol?) |
| 46 | + (cond |
| 47 | + [(assq expr env) |
| 48 | + => |
| 49 | + cdr] |
| 50 | + [else (bad! "unbound variable `~a`" expr)])] |
| 51 | + [`(if ,cond ,ift ,iff) |
| 52 | + (define cond-ty (type-of cond)) |
| 53 | + (unless (equal? cond-ty 'bool) |
| 54 | + (type-error! cond cond-ty 'bool)) |
| 55 | + (define ift-ty (type-of ift)) |
| 56 | + (define iff-ty (type-of iff)) |
| 57 | + (unless (equal? ift-ty iff-ty) |
| 58 | + (type-error! iff iff-ty ift-ty)) |
| 59 | + ift-ty] |
| 60 | + [`(,op ,args ...) |
| 61 | + (unless (operator-exists? op) |
| 62 | + (bad! "at `~a`, `~a` not an operator" expr op)) |
| 63 | + (define itypes (operator-info op 'itype)) |
| 64 | + (unless (= (length itypes) (length args)) |
| 65 | + (bad! "arity mismatch at `~a`: expected `~a`, got `~a`" |
| 66 | + expr |
| 67 | + (length itypes) |
| 68 | + (length args))) |
| 69 | + (for ([arg (in-list args)] |
| 70 | + [itype (in-list itypes)]) |
| 71 | + (define arg-ty (type-of arg)) |
| 72 | + (unless (equal? itype arg-ty) |
| 73 | + (type-error! arg arg-ty itype))) |
| 74 | + (operator-info op 'otype)] |
| 75 | + [_ (bad! "expected an expression, got `~a`" expr)]))) |
| 76 | + |
| 77 | + (unless (equal? actual-ty otype) |
| 78 | + (type-error! spec actual-ty otype))) |
| 79 | + |
| 80 | +(define fpcore-context (make-parameter '_)) |
| 81 | + |
| 82 | +(define (fpcore-parameterize spec) |
| 83 | + (let loop ([ctx (fpcore-context)]) |
| 84 | + (match ctx |
| 85 | + ['_ spec] |
| 86 | + [(list arg ...) (map loop arg)] |
| 87 | + [_ ctx]))) |
| 88 | + |
| 89 | +(define/contract (create-operator-impl! name |
| 90 | + ctx |
| 91 | + spec |
| 92 | + #:impl [fl-proc #f] |
| 93 | + #:fpcore [fpcore #f] |
| 94 | + #:cost [cost #f]) |
| 95 | + (->* (symbol? context? any/c) |
| 96 | + (#:impl (or/c procedure? generator? #f) #:fpcore any/c #:cost (or/c #f real? procedure?)) |
| 97 | + operator-impl?) |
| 98 | + ;; check specification |
| 99 | + (check-spec! name ctx spec) |
| 100 | + ;; synthesize operator (if the spec contains exactly one operator) |
| 101 | + (define op |
| 102 | + (match spec |
| 103 | + [(list op (or (? number?) (? symbol?)) ...) op] |
| 104 | + [_ #f])) |
| 105 | + ;; check FPCore translation |
| 106 | + (match (fpcore-parameterize (or fpcore spec)) |
| 107 | + [`(! ,props ... (,op ,args ...)) |
| 108 | + (unless (even? (length props)) |
| 109 | + (error 'create-operator-impl! "~a: umatched property in ~a" name fpcore)) |
| 110 | + (unless (symbol? op) |
| 111 | + (error 'create-operator-impl! "~a: expected symbol `~a`" name op)) |
| 112 | + (for ([arg (in-list args)] |
| 113 | + #:unless (or (symbol? arg) (number? arg))) |
| 114 | + (error 'create-operator-impl! "~a: expected terminal `~a`" name arg))] |
| 115 | + [`(,op ,args ...) |
| 116 | + (unless (symbol? op) |
| 117 | + (error 'create-operator-impl! "~a: expected symbol `~a`" name op)) |
| 118 | + (for ([arg (in-list args)] |
| 119 | + #:unless (or (symbol? arg) (number? arg))) |
| 120 | + (error 'create-operator-impl! "~a: expected terminal `~a`" name arg))] |
| 121 | + [(? symbol?) (void)] |
| 122 | + [_ (error 'create-operator-impl! "Invalid fpcore for ~a: ~a" name fpcore)]) |
| 123 | + ;; check or synthesize floating-point operation |
| 124 | + (define fl-proc* |
| 125 | + (match fl-proc |
| 126 | + [(? generator?) ((generator-gen fl-proc) spec ctx)] |
| 127 | + [(? procedure?) fl-proc] |
| 128 | + [#f (error 'create-operator-impl! "fl-proc is not provided for `~a` implementation" name)])) |
| 129 | + (unless (procedure-arity-includes? fl-proc* (length (context-vars ctx)) #t) |
| 130 | + (error 'arity-check |
| 131 | + "Procedure `~a` accepts ~a arguments, but ~a is provided" |
| 132 | + name |
| 133 | + (procedure-arity fl-proc*) |
| 134 | + (length (context-vars ctx)))) |
| 135 | + (define-values (cost* aggregate*) |
| 136 | + (cond |
| 137 | + [(number? cost) (values cost +)] |
| 138 | + [(procedure? cost) (values 0 cost)] |
| 139 | + [else (values cost +)])) |
| 140 | + (operator-impl name ctx spec (fpcore-parameterize (or fpcore spec)) fl-proc* cost* aggregate*)) |
| 141 | + |
| 142 | +(define-syntax (make-operator-impl stx) |
| 143 | + (define (oops! why [sub-stx #f]) |
| 144 | + (raise-syntax-error 'make-operator-impl why stx sub-stx)) |
| 145 | + (syntax-case stx (:) |
| 146 | + [(_ (id [var : repr] ...) rtype fields ...) |
| 147 | + (let ([id #'id] |
| 148 | + [vars (syntax->list #'(var ...))] |
| 149 | + [fields #'(fields ...)]) |
| 150 | + (unless (identifier? id) |
| 151 | + (oops! "expected identifier" id)) |
| 152 | + (for ([var (in-list vars)] |
| 153 | + #:unless (identifier? var)) |
| 154 | + (oops! "expected identifier" var)) |
| 155 | + (define spec #f) |
| 156 | + (define core #f) |
| 157 | + (define fl-expr #f) |
| 158 | + (define op-cost #f) |
| 159 | + |
| 160 | + (let loop ([fields fields]) |
| 161 | + (syntax-case fields () |
| 162 | + [() |
| 163 | + (unless spec |
| 164 | + (oops! "missing `#:spec` keyword")) |
| 165 | + (with-syntax ([id id] |
| 166 | + [spec spec] |
| 167 | + [core core] |
| 168 | + [fl-expr fl-expr] |
| 169 | + [op-cost op-cost]) |
| 170 | + #'(create-operator-impl! 'id |
| 171 | + (context '(var ...) rtype (list repr ...)) |
| 172 | + 'spec |
| 173 | + #:impl fl-expr |
| 174 | + #:fpcore 'core |
| 175 | + #:cost op-cost))] |
| 176 | + [(#:spec expr rest ...) |
| 177 | + (cond |
| 178 | + [spec (oops! "multiple #:spec clauses" stx)] |
| 179 | + [else |
| 180 | + (set! spec #'expr) |
| 181 | + (loop #'(rest ...))])] |
| 182 | + [(#:spec) (oops! "expected value after keyword `#:spec`" stx)] |
| 183 | + [(#:fpcore expr rest ...) |
| 184 | + (cond |
| 185 | + [core (oops! "multiple #:fpcore clauses" stx)] |
| 186 | + [else |
| 187 | + (set! core #'expr) |
| 188 | + (loop #'(rest ...))])] |
| 189 | + [(#:fpcore) (oops! "expected value after keyword `#:fpcore`" stx)] |
| 190 | + [(#:impl expr rest ...) |
| 191 | + (cond |
| 192 | + [fl-expr (oops! "multiple #:fl clauses" stx)] |
| 193 | + [else |
| 194 | + (set! fl-expr #'expr) |
| 195 | + (loop #'(rest ...))])] |
| 196 | + [(#:impl) (oops! "expected value after keyword `#:fl`" stx)] |
| 197 | + [(#:cost cost rest ...) |
| 198 | + (cond |
| 199 | + [op-cost (oops! "multiple #:cost clauses" stx)] |
| 200 | + [else |
| 201 | + (set! op-cost #'cost) |
| 202 | + (loop #'(rest ...))])] |
| 203 | + [(#:cost) (oops! "expected value after keyword `#:cost`" stx)] |
| 204 | + |
| 205 | + ; bad |
| 206 | + [_ (oops! "bad syntax" fields)])))] |
| 207 | + [_ (oops! "bad syntax")])) |
| 208 | + |
| 209 | +;; Platform registration functions moved from platform.rkt |
| 210 | +(define (platform-register-representation! platform #:repr repr #:cost cost) |
| 211 | + (define reprs (platform-representations platform)) |
| 212 | + (define repr-costs (platform-representation-costs platform)) |
| 213 | + ; Duplicate check |
| 214 | + (when (hash-has-key? reprs (representation-name repr)) |
| 215 | + (raise-herbie-error "Duplicate representation ~a in platform ~a" |
| 216 | + (representation-name repr) |
| 217 | + (*platform-name*))) |
| 218 | + ; Update tables |
| 219 | + (hash-set! reprs (representation-name repr) repr) |
| 220 | + (hash-set! repr-costs (representation-name repr) cost)) |
| 221 | + |
| 222 | +(define (platform-register-implementation! platform impl) |
| 223 | + (unless impl |
| 224 | + (raise-herbie-error "Platform ~a missing implementation" (*platform-name*))) |
| 225 | + ; Reprs check |
| 226 | + (define reprs (platform-representations platform)) |
| 227 | + (define otype (context-repr (operator-impl-ctx impl))) |
| 228 | + (define itype (context-var-reprs (operator-impl-ctx impl))) |
| 229 | + (define impl-reprs (map representation-name (remove-duplicates (cons otype itype)))) |
| 230 | + (unless (andmap (curry hash-has-key? reprs) impl-reprs) |
| 231 | + (raise-herbie-error "Platform ~a missing representation of ~a implementation" |
| 232 | + (*platform-name*) |
| 233 | + (operator-impl-name impl))) |
| 234 | + ; Cost check |
| 235 | + (define impl-cost (operator-impl-cost impl)) |
| 236 | + (unless impl-cost |
| 237 | + (raise-herbie-error "Missing cost for ~a" (operator-impl-name impl))) |
| 238 | + ; Duplicate check |
| 239 | + (define impls (platform-implementations platform)) |
| 240 | + (when (hash-has-key? impls (operator-impl-name impl)) |
| 241 | + (raise-herbie-error "Impl ~a is already registered in platform ~a" |
| 242 | + (operator-impl-name impl) |
| 243 | + (*platform-name*))) |
| 244 | + ; Update table |
| 245 | + (hash-set! impls (operator-impl-name impl) impl)) |
| 246 | + |
| 247 | +(define-syntax (platform-register-implementations! stx) |
| 248 | + (syntax-case stx () |
| 249 | + [(_ platform ([name ([var : repr] ...) otype spec fl fpcore cost] ...)) |
| 250 | + #'(begin |
| 251 | + (platform-register-implementation! platform |
| 252 | + (make-operator-impl (name [var : repr] ...) |
| 253 | + otype |
| 254 | + #:spec spec |
| 255 | + #:impl fl |
| 256 | + #:fpcore fpcore |
| 257 | + #:cost cost)) ...)])) |
| 258 | + |
| 259 | +(define-syntax-rule (define-representation repr #:cost cost) |
| 260 | + (platform-register-representation! (platform-being-defined) #:repr repr #:cost cost)) |
| 261 | + |
| 262 | +(define-syntax-rule (define-operation (name [arg irepr] ...) orepr flags ...) |
| 263 | + (let ([impl (make-operator-impl (name [arg : irepr] ...) orepr flags ...)]) |
| 264 | + (platform-register-implementation! (platform-being-defined) impl))) |
| 265 | + |
| 266 | +(define-syntax (define-operations stx) |
| 267 | + (syntax-case stx () |
| 268 | + [(_ ([arg irepr] ...) orepr #:fpcore fc [name flags ...] ...) |
| 269 | + #'(parameterize ([fpcore-context 'fc]) |
| 270 | + (begin |
| 271 | + (define-operation (name [arg irepr] ...) orepr flags ...) ...))] |
| 272 | + [(_ ([arg irepr] ...) orepr [name flags ...] ...) |
| 273 | + #'(begin |
| 274 | + (define-operation (name [arg irepr] ...) orepr flags ...) ...)])) |
| 275 | + |
| 276 | +(define-syntax (platform-module-begin stx) |
| 277 | + (with-syntax ([local-platform (datum->syntax stx 'platform)]) |
| 278 | + (syntax-case stx () |
| 279 | + [(_ content ...) |
| 280 | + #'(#%module-begin (define local-platform (make-empty-platform)) |
| 281 | + (define old-platform-being-defined (platform-being-defined)) |
| 282 | + (platform-being-defined local-platform) |
| 283 | + content ... |
| 284 | + (platform-being-defined old-platform-being-defined) |
| 285 | + (validate-platform! local-platform) |
| 286 | + (provide local-platform) |
| 287 | + (module+ main |
| 288 | + (display-platform local-platform)) |
| 289 | + (module test racket/base |
| 290 | + ))]))) |
| 291 | + |
| 292 | +(define (if-impl c t f) |
| 293 | + (if c t f)) |
| 294 | + |
| 295 | +(define (if-cost base) |
| 296 | + (lambda (c t f) (+ base c (max t f)))) |
0 commit comments