-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathapi.jl
132 lines (94 loc) · 4.01 KB
/
api.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Global registry of differentiation rules. Each key is a `(module, function name, arity)`
# tuple; each value is whatever `@define_diffrule` stores under that key (a function that
# takes the argument expressions and returns the derivative expression(s)).
const DEFINED_DIFFRULES = Dict{Tuple{Union{Expr,Symbol},Symbol,Int},Any}()
"""
    @define_diffrule M.f(x) = :(df_dx(\$x))
    @define_diffrule M.f(x, y) = :(df_dx(\$x, \$y)), :(df_dy(\$x, \$y))
    ⋮

Register a symbolic differentiation rule for the module-qualified function `M.f`. The
arguments named on the LHS act as bindings to Julia expressions. The macro call evaluates
to the key under which the rule was stored, i.e. `(M, f, arity)`.

The LHS must be a function call with an explicit (non-splatted) argument list. The RHS is
the derivative expression; for an `n`-ary function it is an `n`-tuple of expressions whose
`i`th entry is the derivative of `f` with respect to the `i`th argument. Every use of an
argument on the RHS should be interpolated.

Differentiation rules are purely symbolic, so the arguments must not carry type
annotations.

# Examples

    @define_diffrule Base.cos(x) = :(-sin(\$x))

    @define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))

    @define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
"""
macro define_diffrule(def::Expr)
    parts = splitdef(def)
    # The name must be module-qualified (`M.f`); split it into its two halves.
    modname, fname = _split_qualified_name(parts[:name])
    arguments = parts[:args]
    # Registry key `(module, function, arity)`, built as an expression so that it is
    # evaluated at the macro call site.
    rule_key = Expr(:tuple, Expr(:quote, modname), Expr(:quote, fname), length(arguments))
    # The rule itself: an anonymous function from the arguments to the RHS of `def`.
    rule = Expr(:->, Expr(:tuple, arguments...), parts[:body])
    return esc(quote
        $DiffRules.DEFINED_DIFFRULES[$rule_key] = $rule
        $rule_key
    end)
end
"""
    diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)

Look up the differentiation rule registered for `M.f` with `length(args)` arguments and
apply it to `args`, returning the derivative expression with the argument(s) interpolated
into it.

For an `n`-ary function the result is an `n`-tuple of expressions, where the `i`th
expression is the derivative of `f` with respect to the `i`th argument.

Throws a `KeyError` if no rule is registered for `(M, f, length(args))`.

# Examples

    julia> DiffRules.diffrule(:Base, :sin, 1)
    :(cos(1))

    julia> DiffRules.diffrule(:Base, :sin, :x)
    :(cos(x))

    julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
    :(cos(x * y ^ 2))

    julia> DiffRules.diffrule(:Base, :^, :(x + 2), :c)
    (:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
"""
function diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)
    # Rules are stored per arity, so the number of arguments is part of the key.
    rule = DEFINED_DIFFRULES[(M, f, length(args))]
    return rule(args...)
end
"""
    hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)

Return `true` if a differentiation rule is registered for `M.f` with the given `arity`,
and `false` otherwise.

Here, `arity` is the number of arguments accepted by `f`.

# Examples

    julia> DiffRules.hasdiffrule(:Base, :sin, 1)
    true

    julia> DiffRules.hasdiffrule(:Base, :sin, 2)
    false

    julia> DiffRules.hasdiffrule(:Base, :-, 1)
    true

    julia> DiffRules.hasdiffrule(:Base, :-, 2)
    true

    julia> DiffRules.hasdiffrule(:Base, :-, 3)
    false
"""
function hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)
    rule_key = (M, f, arity)
    return haskey(DEFINED_DIFFRULES, rule_key)
end
"""
    diffrules()

Return the collection of keys that can be used to access all defined differentiation
rules.

Each key is of the form `(M::Union{Expr,Symbol}, f::Symbol, arity::Int)`, where `arity`
is the number of arguments accepted by `f`.

# Examples

    julia> first(DiffRules.diffrules())
    (:Base, :asind, 1)
"""
function diffrules()
    return keys(DEFINED_DIFFRULES)
end
# Extract the bare function symbol from the quoted half of a qualified name.
#
# To stay compatible with both v0.6 and v0.7, the quoted function name has to be accepted
# either as an `Expr(:quote, ...)` or as a `QuoteNode`. Once v0.6 support is dropped it
# will always arrive as a `QuoteNode` (#23885).
function _get_quoted_symbol(ex::Expr)
    @assert ex.head == :quote
    quoted = ex.args
    # Exactly one quoted item is expected, and it must be a plain symbol.
    @assert length(quoted) == 1 && quoted[1] isa Symbol "Function not a single symbol"
    return quoted[1]
end

function _get_quoted_symbol(ex::QuoteNode)
    sym = ex.value
    @assert sym isa Symbol "Function not a single symbol"
    return sym
end
# Message shared by both methods below when the name carries no module qualifier.
const _sqf_error = "Function is not qualified by module"

"""
    _split_qualified_name(name::Union{Symbol, Expr})

Split a module-qualified function name (`M.f`) into the module part and the function
symbol, returned as a 2-tuple.

An `Expr` whose head is not `.` fails an assertion, and a bare `Symbol` raises an error;
both report that the function is not qualified by a module.
"""
function _split_qualified_name(name::Expr)
    @assert name.head === Symbol(".") _sqf_error
    qualifier, quoted = name.args[1], name.args[2]
    return qualifier, _get_quoted_symbol(quoted)
end

function _split_qualified_name(name::Symbol)
    # A bare symbol has no module part at all.
    throw(ErrorException(_sqf_error))
end