Skip to content

WIP: AD stuff #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
julia 0.6
julia 1.0
MacroTools
LinearAlgebra
20 changes: 16 additions & 4 deletions src/DiffRules.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
__precompile__()

# Top-level module: pulls in the scalar derivative rules and the forward- and
# reverse-mode rule APIs from the files included below.
module DiffRules

# NOTE(review): these two includes and the `diffrules/` includes further down look like
# the same files before and after a move — confirm that only one pair should remain.
include("api.jl")
include("rules.jl")
using MacroTools

# Short type aliases. `SymOrExpr` covers anything that can name a module or an argument
# in a rule definition (a bare symbol or an expression).
const SymOrExpr = Union{Symbol, Expr}
const AVM = AbstractVecOrMat
const AM = AbstractMatrix

# Simple derivative expressions.
include("diffrules/api.jl")
include("diffrules/rules.jl")

# Forwards-mode stuff.
include("forward/api.jl")

# Reverse-mode stuff.
include("reverse/api.jl")
include("reverse/generic.jl")

end # module
33 changes: 20 additions & 13 deletions src/api.jl → src/diffrules/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,15 @@ Examples:
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))

"""
macro define_diffrule(def)
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
lhs = def.args[1]
rhs = def.args[2]
@assert isa(lhs, Expr) && lhs.head == :call "LHS is not a function call"
qualified_f = lhs.args[1]
@assert isa(qualified_f, Expr) && qualified_f.head == :(.) "Function is not qualified by module"
M = qualified_f.args[1]
f = _get_quoted_symbol(qualified_f.args[2])
args = lhs.args[2:end]
rule = Expr(:->, Expr(:tuple, args...), rhs)
key = Expr(:tuple, Expr(:quote, M), Expr(:quote, f), length(args))
macro define_diffrule(def::Expr)

def_ = splitdef(def)
module_, name_ = _split_qualified_name(def_[:name])

key = Expr(:tuple, Expr(:quote, module_), Expr(:quote, name_), length(def_[:args]))
expr = Expr(:->, Expr(:tuple, def_[:args]...), def_[:body])
return esc(quote
$DiffRules.DEFINED_DIFFRULES[$key] = $rule
$DiffRules.DEFINED_DIFFRULES[$key] = $expr
$key
end)
end
Expand Down Expand Up @@ -123,3 +118,15 @@ function _get_quoted_symbol(ex::QuoteNode)
@assert isa(ex.value, Symbol) "Function not a single symbol"
ex.value
end

"""
_split_qualified_name(name::Union{Symbol, Expr})

Split a function name qualified by a module into the module name and a name.
"""
function _split_qualified_name(name::Expr)
    # Only `M.f`-style expressions (head `.`) carry a module qualifier.
    @assert name.head === :(.) _sqf_error
    qualifier = name.args[1]
    unqualified = _get_quoted_symbol(name.args[2])
    return qualifier, unqualified
end

# A bare symbol has no module qualifier at all, so it is always rejected.
function _split_qualified_name(name::Symbol)
    return error(_sqf_error)
end

# Error text shared by both methods above.
const _sqf_error = "Function is not qualified by module"
File renamed without changes.
71 changes: 71 additions & 0 deletions src/forward/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Key identifying a forward rule: a (module name, function name, argument signature)
# tuple.
const ForwardRuleKey = Tuple{SymOrExpr, Symbol, Any}

# Registry of forward-mode rules. `add_forward_rule!` stores exactly one rule per key, so
# different type signatures of the same function live under different keys (the signature
# is the third element of the key). Each stored rule is a function that takes the
# argument names followed by the names of their sensitivities and returns an expression
# for the forwards-mode sensitivity.
# NOTE(review): the signature is an `Expr` when registered through `@forward_rule` but a
# `Tuple` type when generated from `DEFINED_DIFFRULES` — confirm this mix is intended.
const DEFINED_FORWARD_RULES = Dict{ForwardRuleKey, Any}()

"""
@forward_rule def

Define a new forward-mode sensitivity for `M.f`. The first `N` arguments are the arguments
to `M.f`, while the second `N` are the corresponding forwards-mode sensitivities w.r.t.
the respective argument.

NOTE: We don't currently have a mechanism for not including sensitivities w.r.t. a
particular argument, which seems kind of important.

Examples:

@forward_rule Base.cos(x::Real, ẋ::Real) = :(-\$ẋ * sin(\$x))
@forward_rule Main.foo(x, y, ẋ, ẏ) = :(\$x + \$ẋ - \$y * \$ẏ)
"""
macro forward_rule(def::Expr)
    # Build the registration call first, then escape it as a whole so that it is
    # evaluated in the caller's scope.
    registration = _forward_rule(def)
    return esc(registration)
end

# Turn a rule definition `M.f(args..., sensitivities...) = body` into an expression that
# registers the rule through `DiffRules.add_forward_rule!`.
function _forward_rule(def::Expr)

    # Decompose the definition; `where` clauses and keyword arguments are rejected.
    parts = splitdef(def)
    @assert parts[:whereparams] == () "where parameters not currently supported"
    @assert isempty(parts[:kwargs]) "There exists a keyword argument"

    # Decompose every argument; slurps and default values are rejected.
    raw_args = parts[:args]
    split_args = map(splitarg, raw_args)
    @assert !any(a -> a[3] !== false, split_args) "At least one argument is slurping"
    @assert !any(a -> a[4] !== nothing, split_args) "At least one argument has a default value"

    # Quote the pieces of the key so that they reach the registry as symbols/expressions.
    mod_node, fn_node = map(QuoteNode, _split_qualified_name(parts[:name]))
    arg_types = [a[2] for a in split_args]
    sig_node = QuoteNode(Expr(:curly, :Tuple, arg_types...))

    # The rule itself is an anonymous function with the definition's arguments and body.
    rule = Expr(:->, Expr(:tuple, raw_args...), parts[:body])

    return :(DiffRules.add_forward_rule!(($mod_node, $fn_node, $sig_node), $rule))
end

# Register `body` as the forward rule for `key`, replacing any rule already stored under
# that key. Returns `body`.
function add_forward_rule!(key::ForwardRuleKey, body::Any)
    setindex!(DEFINED_FORWARD_RULES, body, key)
    return body
end

"""
    arity(key::ForwardRuleKey)

Return the number of entries in the argument signature stored in `key[3]`.

The signature may be a type such as `Tuple{Real, Real}` or an expression such as
`:(Tuple{Real, Real})`; both forms are used as keys for forward rules.
"""
arity(key::ForwardRuleKey) = _forward_signature_length(key[3])

# Signature given as a type: count its type parameters by name rather than by reading
# the third field of `DataType` positionally.
_forward_signature_length(sig::DataType) = length(sig.parameters)

# Signature given as a `Tuple{...}` expression: every argument after the leading `Tuple`
# is one entry of the signature.
function _forward_signature_length(sig::Expr)
    if !(sig.head === :curly && !isempty(sig.args) && sig.args[1] === :Tuple)
        throw(ArgumentError("expected a `Tuple{...}` signature expression, got $sig"))
    end
    return length(sig.args) - 1
end

# Create forward rules from all of the existing diff rules. Each diff rule maps argument
# symbols to derivative expressions; the forward rule multiplies every partial derivative
# by the sensitivity of the corresponding argument and sums the results.
for ((M, f, nargs), rules) in DEFINED_DIFFRULES
    if nargs == 1
        add_forward_rule!(
            (M, f, Tuple{Vararg{Real, 2}}),
            (x::Symbol, ẋ::Symbol) -> :($ẋ * $(rules(x))),
        )
    elseif nargs == 2
        # Evaluate the rule on placeholder symbols only to detect an undefined (`:NaN`)
        # partial derivative; such functions get no forward rule.
        ∂f∂x, ∂f∂y = rules(:x, :y)
        (∂f∂x == :NaN || ∂f∂y == :NaN) && continue
        add_forward_rule!(
            (M, f, Tuple{Vararg{Real, 4}}),
            function (x::Symbol, y::Symbol, ẋ::Symbol, ẏ::Symbol)
                # Re-evaluate the partials with the symbols actually supplied, so that
                # the resulting expression refers to the caller's variable names.
                ∂x, ∂y = rules(x, y)
                return :($ẋ * $∂x + $ẏ * $∂y)
            end,
        )
    else
        error("Cannot derive a forward rule for $M.$f: unsupported number of arguments ($nargs)")
    end
end
29 changes: 29 additions & 0 deletions src/reverse/DiffLinearAlgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
__precompile__(true)

module DiffLinearAlgebra

using LinearAlgebra

# Aliases used repeatedly throughout the package, one per line.

# Abstract array types.
const AV = AbstractVector
const AM = AbstractMatrix
const AVM = AbstractVecOrMat
const AA = AbstractArray

# Strided array types.
const SV = StridedVector
const SM = StridedMatrix
const SVM = StridedVecOrMat
const SA = StridedArray

# Scalar-or-array unions.
const AS = Union{Real, AA}
const ASVM = Union{Real, AVM}

# Types of `Val{n}`, for dispatching on an argument position.
const Arg1 = Type{Val{1}}
const Arg2 = Type{Val{2}}
const Arg3 = Type{Val{3}}
const Arg4 = Type{Val{4}}
const Arg5 = Type{Val{5}}
const Arg6 = Type{Val{6}}

# Floating point types.
const BF = Union{Float32, Float64}

# Short name for this module.
const DLA = DiffLinearAlgebra

export ∇, DLA, import_expr

# Container and meta-data for defined operations.
include("util.jl")

# Operations, logged for external use via `ops`.
include("generic.jl")
include("blas.jl")
include("diagonal.jl")
include("triangular.jl")
include("uniformscaling.jl")
include("factorization/cholesky.jl")

end # module
94 changes: 94 additions & 0 deletions src/reverse/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# FORMAT of rule function: (y, ȳ, x₁, x₂, ...) where `y` is the output, `ȳ` is the
# sensitivity of the output, and `x₁, x₂, ...` are the inputs. If the sensitivity
# computation doesn't require one of the arguments, it simply won't be used in the
# resulting expression.

# A (<module_name>, <function_name>, <argument_signature>, <argument_numbers>)-Tuple.
# The signature is a `Tuple{...}` expression covering the output, its sensitivity and the
# inputs, in that order. The argument numbers are the positions of the inputs the rule
# is taken with respect to.
const ReverseRuleKey = Tuple{SymOrExpr, Symbol, Expr, Tuple{Vararg{Int}}}

# All of the defined reverse rules. Keys are of the form:
# (Module, Function, type-tuple, argument numbers)
# `add_reverse_rule!` stores one rule per key.
const DEFINED_REVERSE_RULES = Dict{ReverseRuleKey, Any}()

"""
@reverse_rule z z̄ M.f(wrt(x::Real), y) = :(...)

Define a new reverse-mode sensitivity for `M.f` w.r.t. the first argument. `z` is the output
from the forward-pass, `z̄` is the reverse-mode sensitivity w.r.t. `z`.

Examples:

@reverse_rule z::Real z̄ Base.cos(x::Real) = :(-\$z̄ * sin(\$x))
@reverse_rule z z̄::Real Main.foo(x, y) = :(\$x + \$z - \$y * \$z̄)
"""
macro reverse_rule(z::SymOrExpr, z̄::SymOrExpr, def::Expr)
    # Build the registration call first, then escape it as a whole so that it is
    # evaluated in the caller's scope.
    registration = _reverse_rule(z, z̄, def)
    return esc(registration)
end

# Turn an output `z`, its sensitivity `z̄` and a rule definition `M.f(args...) = body` into
# an expression that registers the rule through `DiffRules.add_reverse_rule!`.
function _reverse_rule(z::SymOrExpr, z̄::SymOrExpr, def::Expr)
    # Decompose the definition; `where` clauses and keyword arguments are rejected.
    parts = splitdef(def)
    @assert parts[:whereparams] == () "where parameters not currently supported"
    @assert isempty(parts[:kwargs]) "There exists a keyword argument"

    mod_node, fn_node = map(QuoteNode, _split_qualified_name(parts[:name]))

    # Strip the `wrt(...)` markers, then decompose the output, its sensitivity and the
    # inputs (in that order); slurps and default values are rejected.
    wrts, inputs = process_args(parts[:args])
    all_args = vcat(splitarg(z), splitarg(z̄), map(splitarg, inputs))
    @assert !any(a -> a[3] !== false, all_args) "At least one argument is slurping"
    @assert !any(a -> a[4] !== nothing, all_args) "At least one argument has a default value"

    # The signature holds the declared types, the rule function the bare names.
    sig_node = QuoteNode(Expr(:curly, :Tuple, (a[2] for a in all_args)...))
    rule = Expr(:->, Expr(:tuple, (a[1] for a in all_args)...), parts[:body])

    return :(DiffRules.add_reverse_rule!(($mod_node, $fn_node, $sig_node, $wrts), $rule))
end

# Find the arguments wrapped in `wrt(...)`. Returns a tuple with the positions of the
# wrapped arguments and a vector of all arguments with the `wrt` wrapper removed.
function process_args(args::Array{Any})
    positions = Int[]
    unwrapped = Vector{Any}(undef, length(args))
    for (n, arg) in enumerate(args)
        is_wrt = arg isa Expr && arg.head == :call && arg.args[1] == :wrt
        if !is_wrt
            # Ordinary argument: passed through untouched.
            unwrapped[n] = arg
            continue
        end
        # `wrt` must wrap exactly one argument.
        @assert length(arg.args) == 2
        push!(positions, n)
        unwrapped[n] = arg.args[2]
    end
    return Tuple(positions), unwrapped
end

# Accept any tuple as a key by first converting it to a `ReverseRuleKey`.
function add_reverse_rule!(key::Tuple, rule::Any)
    typed_key = ReverseRuleKey(key)
    return add_reverse_rule!(typed_key, rule)
end

# Register `rule` as the reverse rule for `key`, replacing any rule already stored under
# that key. Returns `rule`.
function add_reverse_rule!(key::ReverseRuleKey, rule::Any)
    setindex!(DEFINED_REVERSE_RULES, rule, key)
    return rule
end

# Number of entries in the `Tuple{...}` signature expression stored in `key`.
function arity(key::ReverseRuleKey)
    typ = key[3]
    @assert typ.head === :curly && typ.args[1] === :Tuple
    # Everything after the leading `Tuple` is one entry of the signature.
    entries = typ.args[2:end]
    return length(entries)
end

# Convenience method: use the signature expression stored as the third element of `key`.
make_named_signature(names::AbstractVector, key::ReverseRuleKey) =
    make_named_signature(names, key[3])
# Pair each name with the matching entry of a `Tuple{...}` signature expression,
# producing one `name::type` expression per entry.
function make_named_signature(names::AbstractVector, type_tuple::Expr)
    @assert type_tuple.head === :curly &&
        type_tuple.args[1] === :Tuple &&
        length(type_tuple.args) - 1 === length(names)
    # Drop the leading `Tuple`; what remains lines up with `names` one-to-one.
    entry_types = type_tuple.args[2:end]
    return map((name, type) -> Expr(:(::), name, type), names, entry_types)
end

# Create reverse rules from all of the existing diff rules. Each diff rule maps argument
# symbols to derivative expressions; the reverse rule w.r.t. an argument multiplies the
# output sensitivity `z̄` by the partial derivative w.r.t. that argument.
for ((M, f, nargs), rules) in DEFINED_DIFFRULES
    if nargs == 1
        reverse_rule = (z::Symbol, z̄::Symbol, x::Symbol) -> :($z̄ * $(rules(x)))
        add_reverse_rule!((M, f, :(Tuple{Real, Real, Real}), (1,)), reverse_rule)
    elseif nargs == 2
        # Evaluate the rule on placeholder symbols only to detect an undefined (`:NaN`)
        # partial derivative; such functions get no reverse rules.
        ∂f∂x, ∂f∂y = rules(:x, :y)
        (∂f∂x == :NaN || ∂f∂y == :NaN) && continue
        # `z̄` is interpolated so that the expression refers to the symbol supplied by the
        # caller, exactly as in the single-argument rule above.
        rev_rule_1 =
            (z::Symbol, z̄::Symbol, x::Symbol, y::Symbol) -> :($z̄ * $(rules(x, y)[1]))
        rev_rule_2 =
            (z::Symbol, z̄::Symbol, x::Symbol, y::Symbol) -> :($z̄ * $(rules(x, y)[2]))
        add_reverse_rule!((M, f, :(Tuple{Real, Real, Real, Real}), (1,)), rev_rule_1)
        add_reverse_rule!((M, f, :(Tuple{Real, Real, Real, Real}), (2,)), rev_rule_2)
    else
        error("Cannot derive a reverse rule for $M.$f: unsupported number of arguments ($nargs)")
    end
end
Loading