diff --git a/REQUIRE b/REQUIRE index 137767a..fdb20e5 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1 +1,3 @@ -julia 0.6 +julia 1.0 +MacroTools +LinearAlgebra diff --git a/src/DiffRules.jl b/src/DiffRules.jl index b70cae5..01ad5d0 100644 --- a/src/DiffRules.jl +++ b/src/DiffRules.jl @@ -1,8 +1,20 @@ -__precompile__() - module DiffRules -include("api.jl") -include("rules.jl") +using MacroTools + +const SymOrExpr = Union{Symbol, Expr} +const AVM = AbstractVecOrMat +const AM = AbstractMatrix + +# Simple derivative expressions. +include("diffrules/api.jl") +include("diffrules/rules.jl") + +# Forwards-mode stuff. +include("forward/api.jl") + +# Reverse-mode stuff. +include("reverse/api.jl") +include("reverse/generic.jl") end # module diff --git a/src/api.jl b/src/diffrules/api.jl similarity index 82% rename from src/api.jl rename to src/diffrules/api.jl index 24394af..3e449d5 100644 --- a/src/api.jl +++ b/src/diffrules/api.jl @@ -23,20 +23,15 @@ Examples: @define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x)) """ -macro define_diffrule(def) - @assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side" - lhs = def.args[1] - rhs = def.args[2] - @assert isa(lhs, Expr) && lhs.head == :call "LHS is not a function call" - qualified_f = lhs.args[1] - @assert isa(qualified_f, Expr) && qualified_f.head == :(.) "Function is not qualified by module" - M = qualified_f.args[1] - f = _get_quoted_symbol(qualified_f.args[2]) - args = lhs.args[2:end] - rule = Expr(:->, Expr(:tuple, args...), rhs) - key = Expr(:tuple, Expr(:quote, M), Expr(:quote, f), length(args)) +macro define_diffrule(def::Expr) + + def_ = splitdef(def) + module_, name_ = _split_qualified_name(def_[:name]) + + key = Expr(:tuple, Expr(:quote, module_), Expr(:quote, name_), length(def_[:args])) + expr = Expr(:->, Expr(:tuple, def_[:args]...), def_[:body]) return esc(quote - $DiffRules.DEFINED_DIFFRULES[$key] = $rule + $DiffRules.DEFINED_DIFFRULES[$key] = $expr $key end) end @@ -123,3 +118,15 @@ function _get_quoted_symbol(ex::QuoteNode) @assert isa(ex.value, Symbol) "Function not a single symbol" ex.value end + +""" + _split_qualified_name(name::Union{Symbol, Expr}) + +Split a function name qualified by a module into the module name and a name. +""" +function _split_qualified_name(name::Expr) + @assert name.head == Symbol(".") _sqf_error + return name.args[1], _get_quoted_symbol(name.args[2]) +end +_split_qualified_name(name::Symbol) = error(_sqf_error) +const _sqf_error = "Function is not qualified by module" diff --git a/src/rules.jl b/src/diffrules/rules.jl similarity index 100% rename from src/rules.jl rename to src/diffrules/rules.jl diff --git a/src/forward/api.jl b/src/forward/api.jl new file mode 100644 index 0000000..2295214 --- /dev/null +++ b/src/forward/api.jl @@ -0,0 +1,71 @@ +const ForwardRuleKey = Tuple{SymOrExpr, Symbol, Any} + +# Key indicates the function in terms of (module_name, function_name). Each entry contains +# a vector of implementations for different type signatures. First element of entry is +# method signature, second entry is a (Tuple of) expressions giving the forwards-mode +# sensitivities. +const DEFINED_FORWARD_RULES = Dict{ForwardRuleKey, Any}() + +""" + @forward_rule def + +Define a new forward-mode sensitivity for `M.f`. The first `N` arguments are the arguments +to `M.f`, while the second `N` are the corresponding forwards-mode sensitivities w.r.t. +the respective argument. 
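+For instance (an illustrative description of the storage format rather than additional
+API), the `Base.cos` example below ends up in `DEFINED_FORWARD_RULES` under the key
+`(:Base, :cos, :(Tuple{Real, Real}))`, and the stored value is a function mapping the
+argument names to the quoted sensitivity expression.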
+ +NOTE: We don't currently have a mechanism for not including sensitivities w.r.t. a +particular argument, which seems kind of important. + +Examples: + + @forward_rule Base.cos(x::Real, ẋ::Real) = :(-\$ẋ * sin(\$x)) + @forward_rule Main.foo(x, y, ẋ, ẏ) = :(\$x + \$ẋ - \$y * \$ẏ) +""" +macro forward_rule(def::Expr) + return esc(_forward_rule(def)) +end + +function _forward_rule(def::Expr) + + # Split up function definition and assert no whereparams or kwargs. + def_ = splitdef(def) + @assert def_[:whereparams] == () "where parameters not currently supported" + @assert isempty(def_[:kwargs]) "There exists a keyword argument" + + # Split up the arguments and assert no is slurps or default values. + args = splitarg.(def_[:args]) + @assert all(arg->arg[3] === false, args) "At least one argument is slurping" + @assert all(arg->arg[4] === nothing, args) "At least one argument has a default value" + + # Construct forward rule. + M, f = QuoteNode.(_split_qualified_name(def_[:name])) + signature = QuoteNode(:(Tuple{$(getindex.(args, 2)...)})) + expr = Expr(:->, Expr(:tuple, def_[:args]...), def_[:body]) + + return :(DiffRules.add_forward_rule!(($M, $f, $signature), $expr)) +end + +function add_forward_rule!(key::ForwardRuleKey, body::Any) + DEFINED_FORWARD_RULES[key] = body +end + +arity(key::ForwardRuleKey) = length(getfield(key[3], :3)) + +# Create forward rules from all of the existing diff rules. +for ((M, f, nargs), rules) in DEFINED_DIFFRULES + if nargs == 1 + add_forward_rule!( + (M, f, Tuple{Vararg{Real, 2}}), + (x::Symbol, ẋ::Symbol)->:($ẋ * $(rules(x))), + ) + elseif nargs == 2 + ∂f∂x, ∂f∂y = rules(:x, :y) + (∂f∂x == :NaN || ∂f∂y == :NaN) && continue + add_forward_rule!( + (M, f, Tuple{Vararg{Real, 4}}), + (x::Symbol, y::Symbol, ẋ::Symbol, ẏ::Symbol)->:($ẋ * $∂f∂x + $ẏ * $∂f∂y), + ) + else + error("Arrghh") + end +end diff --git a/src/reverse/DiffLinearAlgebra.jl b/src/reverse/DiffLinearAlgebra.jl new file mode 100644 index 0000000..868d01d --- /dev/null +++ b/src/reverse/DiffLinearAlgebra.jl @@ -0,0 +1,29 @@ +__precompile__(true) + +module DiffLinearAlgebra + + using LinearAlgebra + + # Some aliases used repeatedly throughout the package. + const AV, AM, AVM, AA = AbstractVector, AbstractMatrix, AbstractVecOrMat, AbstractArray + const SV, SM, SVM, SA = StridedVector, StridedMatrix, StridedVecOrMat, StridedArray + const AS, ASVM = Union{Real, AA}, Union{Real, AVM} + const Arg1, Arg2, Arg3 = Type{Val{1}}, Type{Val{2}}, Type{Val{3}} + const Arg4, Arg5, Arg6 = Type{Val{4}}, Type{Val{5}}, Type{Val{6}} + const BF = Union{Float32, Float64} + const DLA = DiffLinearAlgebra + + export ∇, DLA, import_expr + + # Define container and meta-data for defined operations. + include("util.jl") + + # Define operations, and log them for external use via `ops`. + include("generic.jl") + include("blas.jl") + include("diagonal.jl") + include("triangular.jl") + include("uniformscaling.jl") + include("factorization/cholesky.jl") + +end # module diff --git a/src/reverse/api.jl b/src/reverse/api.jl new file mode 100644 index 0000000..5ecff72 --- /dev/null +++ b/src/reverse/api.jl @@ -0,0 +1,94 @@ +# FORMAT of rule function: (y, ȳ, x₁, x₂, ...) where `y` is output, `ȳ` is sensitivity of +# output, `x₁, x₂, ...` are inputs. If sensitivity computation doesn't require one of the +# arguments, it simply won't be used in the resulting expression. + +# A (, , , )-Tuple. +const ReverseRuleKey = Tuple{SymOrExpr, Symbol, Expr, Tuple{Vararg{Int}}} + +# All of the defined reverse rules. 
Keys are of the form: +# (Module, Function, type-tuple, argument numbers) +const DEFINED_REVERSE_RULES = Dict{ReverseRuleKey, Any}() + +""" + @reverse_rule z z̄ M.f(wrt(x::Real), y) = :(...) + +Define a new reverse-mode sensitivity for `M.f` w.r.t. the first argument. `z` is the output +from the forward-pass, `z̄` is the reverse-mode sensitivity w.r.t. `z`. + +Examples: + + @reverse_rule z::Real z̄ Base.cos(x::Real) = :(\$z̄̇ * sin(\$x)) + @reverse_rule z z̄::Real Main.foo(x, y) = :(\$x + \$z - \$y * \$z̄) +""" +macro reverse_rule(z::SymOrExpr, z̄::SymOrExpr, def::Expr) + return esc(_reverse_rule(z, z̄, def)) +end + +function _reverse_rule(z::SymOrExpr, z̄::SymOrExpr, def::Expr) + def_ = splitdef(def) + @assert def_[:whereparams] == () "where parameters not currently supported" + @assert isempty(def_[:kwargs]) "There exists a keyword argument" + + M, f = QuoteNode.(_split_qualified_name(def_[:name])) + + wrts, args′ = process_args(def_[:args]) + args = vcat(splitarg(z), splitarg(z̄), splitarg.(args′)) + @assert all(arg->arg[3] === false, args) "At least one argument is slurping" + @assert all(arg->arg[4] === nothing, args) "At least one argument has a default value" + signature = QuoteNode(:(Tuple{$(getindex.(args, 2)...)})) + + body = Expr(:->, Expr(:tuple, getfield.(args, 1)...), def_[:body]) + return :(DiffRules.add_reverse_rule!(($M, $f, $signature, $wrts), $body)) +end + +function process_args(args::Array{Any}) + wrts, args′ = Vector{Int}(), Vector{Any}(undef, length(args)) + for (n, arg) in enumerate(args) + if arg isa Expr && arg.head == :call && arg.args[1] == :wrt + @assert length(arg.args) == 2 + push!(wrts, n) + args′[n] = arg.args[2] + else + args′[n] = arg + end + end + return (wrts...,), args′ +end + +add_reverse_rule!(key::Tuple, rule::Any) = add_reverse_rule!(ReverseRuleKey(key), rule) +function add_reverse_rule!(key::ReverseRuleKey, rule::Any) + DEFINED_REVERSE_RULES[key] = rule +end + +function arity(key::ReverseRuleKey) + typ = key[3] + @assert typ.head === :curly && typ.args[1] === :Tuple + return length(typ.args) - 1 +end + +function make_named_signature(names::AbstractVector, key::ReverseRuleKey) + return make_named_signature(names, key[3]) +end +function make_named_signature(names::AbstractVector, type_tuple::Expr) + @assert type_tuple.head === :curly && + type_tuple.args[1] === :Tuple && + length(type_tuple.args) - 1 === length(names) + return [Expr(Symbol("::"), name, type) for (name, type) in zip(names, type_tuple.args[2:end])] +end + +# Create reverse rules from all of the existing diff rules. 
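+# As an illustration of the machinery above (a hypothetical, hand-written entry shown as a
+# comment rather than registered): a reverse rule for the first argument of `Base.sin`
+# boils down to
+#
+#     add_reverse_rule!(
+#         (:Base, :sin, :(Tuple{Real, Real, Real}), (1,)),
+#         (z, z̄, x) -> :($z̄ * cos($x)),
+#     )
+#
+# i.e. the key records (module, function, signature covering output / output-sensitivity /
+# inputs, differentiated argument numbers), and the value maps argument names to the quoted
+# sensitivity code. The loop below generates entries of exactly this shape for every rule
+# in `DEFINED_DIFFRULES`.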
+for ((M, f, nargs), rules) in DEFINED_DIFFRULES + if nargs == 1 + reverse_rule = (z::Symbol, z̄::Symbol, x::Symbol)->:($z̄ * $(rules(x))) + add_reverse_rule!((M, f, :(Tuple{Real, Real, Real}), (1,)), reverse_rule) + elseif nargs == 2 + ∂f∂x, ∂f∂y = rules(:x, :y) + (∂f∂x == :NaN || ∂f∂y == :NaN) && continue + rev_rule_1 = (z::Symbol, z̄::Symbol, x::Symbol, y::Symbol)->:(z̄ * $(rules(x, y)[1])) + rev_rule_2 = (z::Symbol, z̄::Symbol, x::Symbol, y::Symbol)->:(z̄ * $(rules(x, y)[2])) + add_reverse_rule!((M, f, :(Tuple{Real, Real, Real, Real}), (1,)), rev_rule_1) + add_reverse_rule!((M, f, :(Tuple{Real, Real, Real, Real}), (2,)), rev_rule_2) + else + error("Arrghh") + end +end diff --git a/src/reverse/blas.jl b/src/reverse/blas.jl new file mode 100644 index 0000000..209b189 --- /dev/null +++ b/src/reverse/blas.jl @@ -0,0 +1,515 @@ +import LinearAlgebra: dot +import LinearAlgebra.BLAS: asum, blascopy!, nrm2, scal, scal!, gemm, gemm!, gemv, gemv!, + syrk, symm, symm!, symv, symv!, trmm, trsm, trmv, trsv, trsv!, ger! + +################################## Level 1 ################################## + +# Unit-stride `dot`. +push!(ops, DiffOp(:(LinearAlgebra.dot), + :(Tuple{DLA.SA{<:DLA.BF}, DLA.SA{<:DLA.BF}}), + [true, true] +)) +∇(::typeof(LinearAlgebra.dot), ::Arg1, p, z::BF, z̄::BF, x::SA{<:BF}, y::SA{<:BF}) = z̄ .* y +∇(::typeof(LinearAlgebra.dot), ::Arg2, p, z::BF, z̄::BF, x::SA{<:BF}, y::SA{<:BF}) = z̄ .* x +function ∇(x̄, ::typeof(LinearAlgebra.dot), ::Arg1, p, z::BF, z̄::BF, x::SA{<:BF}, y::SA{<:BF}) + x̄ .= x̄ .+ z̄ .* y + return x̄ +end +function ∇(ȳ, ::typeof(LinearAlgebra.dot), ::Arg2, p, z::BF, z̄::BF, x::SA{<:BF}, y::SA{<:BF}) + ȳ .= ȳ .+ z̄ .* x + return ȳ +end + +# Arbitrary-stride `dot`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.dot), + :(Tuple{Int, DLA.SA{<:DLA.BF}, Int, DLA.SA{<:DLA.BF}, Int}), + [false, true, false, true, false], +)) +∇(::typeof(BLAS.dot), ::Arg2, p, z::BF, z̄::BF, n::Int, x::SA{<:BF}, ix::Int, y::SA{<:BF}, iy::Int) = + scal!(n, z̄, blascopy!(n, y, iy, fill!(similar(x), zero(eltype(x))), ix), ix) +∇(::typeof(BLAS.dot), ::Arg4, p, z::BF, z̄::BF, n::Int, x::SA{<:BF}, ix::Int, y::SA{<:BF}, iy::Int) = + scal!(n, z̄, blascopy!(n, x, ix, fill!(similar(y), zero(eltype(y))), iy), iy) +function ∇(x̄, ::typeof(BLAS.dot), ::Arg2, p, z::BF, z̄::BF, n::Int, x::SA{<:BF}, ix::Int, y::SA{<:BF}, iy::Int) + x̄ .= x̄ .+ scal!(n, z̄, blascopy!(n, y, iy, fill!(similar(x), zero(eltype(x))), ix), ix) + return x̄ +end +function ∇(ȳ, ::typeof(BLAS.dot), ::Arg4, p, z::BF, z̄::BF, n::Int, x::SA{<:BF}, ix::Int, y::SA{<:BF}, iy::Int) + ȳ .= ȳ .+ scal!(n, z̄, blascopy!(n, x, ix, fill!(similar(y), zero(eltype(y))), iy), iy) + return ȳ +end + +# Unit-stride `nrm2`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.nrm2), + :(Tuple{DLA.SA{<:DLA.BF}}), + [true] +)) +∇(::typeof(nrm2), ::Arg1, p, y::BF, ȳ::BF, x::SA{<:BF}) = x .* (ȳ / y) +function ∇(x̄::AA, ::typeof(nrm2), ::Arg1, p, y::BF, ȳ::BF, x::SA{<:BF}) + x̄ .= x̄ .+ x .* (ȳ / y) + return x̄ +end + +# Arbitrary-stride `nrm2`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.nrm2), + :(Tuple{Integer, DLA.SA{<:DLA.BF}, Integer}), + [false, true, false] +)) +∇(::typeof(nrm2), ::Arg2, p, y::BF, ȳ::BF, n::Integer, x::SA{<:BF}, inc::Integer) = + scal!(n, ȳ / y, blascopy!(n, x, inc, fill!(similar(x), zero(eltype(x))), inc), inc) +function ∇(x̄::SA{<:BF}, ::typeof(nrm2), ::Arg2, p, y::BF, ȳ::BF, n::Integer, x::SA{<:BF}, inc::Integer) + x̄ .= x̄ .+ scal!(n, ȳ / y, blascopy!(n, x, inc, fill!(similar(x), zero(eltype(x))), inc), inc) + return x̄ +end + +# Unit-stride `asum`. 
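+# Hedged usage sketch, kept as a comment (`Arg1` above is just `Type{Val{1}}`, so the
+# second argument in these calls is the literal `Val{1}`):
+#
+#     x = randn(4); y = nrm2(x); ȳ = 1.0
+#     ∇(nrm2, Val{1}, (), y, ȳ, x)            # fresh gradient, ≈ x ./ y
+#     ∇(zero(x), nrm2, Val{1}, (), y, ȳ, x)   # accumulate into an existing cotangent
+#
+# Most ops in this file come in those two flavours: an allocating method, and an in-place
+# method whose first argument is the cotangent being accumulated. The `asum` registrations
+# below are the next instance of the pattern.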
+push!(ops, DiffOp(:(LinearAlgebra.BLAS.asum), + :(Tuple{DLA.SA{<:DLA.BF}}), + [true] +)) +∇(::typeof(asum), ::Arg1, p, y::BF, ȳ::BF, x::SA{<:BF}) = ȳ .* sign.(x) +function ∇(x̄::AA, ::typeof(asum), ::Arg1, p, y::BF, ȳ::BF, x::SA{<:BF}) + x̄ .= x̄ .+ ȳ .* sign.(x) + return x̄ +end + +# Arbitrary-stride `asum`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.asum), + :(Tuple{Integer, DLA.SA{<:DLA.BF}, Integer}), + [false, true, false] +)) +∇(::typeof(asum), ::Arg2, p, y::BF, ȳ::BF, n::Integer, x, inc::Integer) = + scal!(n, ȳ, blascopy!(n, sign.(x), inc, fill!(similar(x), zero(eltype(x))), inc), inc) +function ∇(x̄::SA{<:BF}, ::typeof(asum), ::Arg2, p, y::BF, ȳ::BF, n::Integer, x::SA{<:BF}, inc::Integer) + x̄ .= x̄ .+ scal!(n, ȳ, blascopy!(n, sign.(x), inc, fill!(similar(x), zero(eltype(x))), inc), inc) + return x̄ +end +# Some weird stuff going on that I haven't figured out yet. This is a very old attempt. +# let f = :(scal{T <: AbstractArray, V <: AbstractFloat}) +# ā = :(blascopy!(n, z̄, inc, zeros(X), inc) .* X) +# X̄ = :(scal!(n, a, z̄, inc)) +# end + + +################################## Level 2 ################################## + +# `gemv` sensitivities implementation. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.gemv), + :(Tuple{Char, T, DLA.SM{T}, DLA.SV{T}} where T<:DLA.BF), + [false, true, true, true], +)) +∇(::typeof(gemv), ::Arg2, _, y::SV{T}, ȳ::SV, tA::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + dot(ȳ, y) / α +∇(::typeof(gemv), ::Arg3, _, y::SV{T}, ȳ::SV, tA::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + uppercase(tA) == 'N' ? α * ȳ * x' : α * x * ȳ' +∇(Ā::SM{T}, ::typeof(gemv), ::Arg3, _, y::SV{T}, ȳ::SV{T}, tA::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + uppercase(tA) == 'N' ? ger!(α, ȳ, x, Ā) : ger!(α, x, ȳ, Ā) +∇(::typeof(gemv), ::Arg4, _, y::SV{T}, ȳ::SV{T}, tA::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + gemv(uppercase(tA) == 'N' ? 'T' : 'N', α, A, ȳ) +∇(x̄::SV{T}, ::typeof(gemv), ::Arg4, _, y::SV{T}, ȳ::SV{T}, tA::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + gemv!(uppercase(tA) == 'N' ? 'T' : 'N', α, A, ȳ, one(T), x̄) + +# `gemv` sensitivities implementation with `α = 1`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.gemv), + :(Tuple{Char, DLA.SM{T}, DLA.SV{T}} where T<:DLA.BF), + [false, true, true], +)) +∇(::typeof(gemv), ::Arg2, p, y::SV{T}, ȳ::SV{T}, tA::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(gemv, Val{3}, p, y, ȳ, tA, one(T), A, x) +∇(Ā::SM{T}, ::typeof(gemv), ::Arg2, p, y::SV{T}, ȳ::SV{T}, tA::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(Ā, gemv, Val{3}, p, y, ȳ, tA, one(T), A, x) +∇(::typeof(gemv), ::Arg3, p, y::SV{T}, ȳ::SV{T}, tA::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(gemv, Val{4}, p, y, ȳ, tA, one(T), A, x) +∇(x̄::SV{T}, ::typeof(gemv), ::Arg3, p, y::SV{T}, ȳ::SV{T}, tA::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(x̄, gemv, Val{4}, p, y, ȳ, tA, one(T), A, x) + +# `symv` sensitivity implementations. 
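+# (Aside, for orientation — a commented sanity sketch, not extra functionality: with
+#     A, x, ȳ = randn(3, 2), randn(2), randn(3); α = 0.7; y = gemv('N', α, A, x)
+# the rules above reduce to the textbook identities
+#     ∇(gemv, Val{2}, (), y, ȳ, 'N', α, A, x) == dot(ȳ, y) / α      # ᾱ
+#     ∇(gemv, Val{3}, (), y, ȳ, 'N', α, A, x) ≈ α .* ȳ .* x'        # Ā = α ȳ xᵀ
+#     ∇(gemv, Val{4}, (), y, ȳ, 'N', α, A, x) ≈ α .* (A' * ȳ)       # x̄ = α Aᵀ ȳ
+# and the α = 1 methods simply splice `one(T)` back in and forward to the `Val{3}`/`Val{4}`
+# methods above, which is why their argument indices are shifted by one.)
+# The `symv` rules below follow the same template for a symmetric `A`.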
+push!(ops, DiffOp(:(LinearAlgebra.BLAS.symv), + :(Tuple{Char, T, DLA.SM{T}, DLA.SV{T}} where T<:DLA.BF), + [false, true, true, true], +)) +∇(::typeof(symv), ::Arg2, p, y::SV{T}, ȳ::SV{T}, ul::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + dot(ȳ, y) / α +function ∇(::typeof(symv), ::Arg3, p, y::SV{T}, ȳ::SV{T}, ul::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF + Y, Ȳ, X = reshape(y, length(y), 1), reshape(ȳ, length(ȳ), 1), reshape(x, length(x), 1) + return ∇(symm, Val{4}, p, Y, Ȳ, 'L', ul, α, A, X) +end +function ∇(Ā::SM{T}, ::typeof(symv), ::Arg3, p, y::SV{T}, ȳ::SV{T}, ul::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF + Y, Ȳ, X = reshape(y, length(y), 1), reshape(ȳ, length(ȳ), 1), reshape(x, length(x), 1) + return ∇(Ā, symm, Val{4}, p, Y, Ȳ, 'L', ul, α, A, X) +end +∇(::typeof(symv), ::Arg4, p, y::SV{T}, ȳ::SV{T}, ul::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + symv(ul, α, A, ȳ) +∇(x̄::SV{T}, ::typeof(symv), ::Arg4, p, y::SV{T}, ȳ::SV{T}, ul::Char, α::T, A::SM{T}, x::SV{T}) where T<:BF = + symv!(ul, α, A, ȳ, one(T), x̄) + +# `symv` sensitivity implementations for `α=1`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.symv), + :(Tuple{Char, DLA.SM{T}, DLA.SV{T}} where T<:DLA.BF), + [false, true, true], +)) +∇(::typeof(symv), ::Arg2, p, y::SV{T}, ȳ::SV{T}, ul::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(symv, Val{3}, p, y, ȳ, ul, one(T), A, x) +∇(Ā::SM{T}, ::typeof(symv), ::Arg2, p, y::SV{T}, ȳ::SV{T}, ul::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(Ā, symv, Val{3}, p, y::SV{T}, ȳ::SV{T}, ul, one(T), A, x) +∇(::typeof(symv), ::Arg3, p, y::SV{T}, ȳ::SV{T}, ul::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(symv, Val{4}, p, y, ȳ, ul, one(T), A, x) +∇(B̄::SV{T}, ::typeof(symv), ::Arg3, p, y::SV{T}, ȳ::SV{T}, ul::Char, A::SM{T}, x::SV{T}) where T<:BF = + ∇(B̄, symv, Val{4}, p, y, ȳ, ul, one(T), A, x) + +# `trmv` sensitivity implementations. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.trmv), + :(Tuple{Char, Char, Char, DLA.SM{T}, DLA.SV{T}} where T<:DLA.BF), + [false, false, false, true, true], +)) +function ∇(::typeof(trmv), ::Arg4, p, y::SV{T}, ȳ::SV{T}, + ul::Char, ta::Char, dA::Char, + A::SM{T}, + b::SV{T}, +) where T<:BF + Ā = (uppercase(ul) == 'L' ? tril! : triu!)(uppercase(ta) == 'N' ? ȳ * b' : b * ȳ') + dA == 'U' && fill!(view(Ā, diagind(Ā)), zero(T)) + return Ā +end +∇(::typeof(trmv), ::Arg5, p, y::SV{T}, ȳ::SV{T}, + ul::Char, ta::Char, dA::Char, + A::SM{T}, + b::SV{T}, +) where T<:BF = trmv(ul, uppercase(ta) == 'N' ? 'T' : 'N', dA, A, ȳ) + +# `trsv` sensitivity implementations. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.trsv), + :(Tuple{Char, Char, Char, DLA.SM{T}, DLA.SV{T}} where T<:DLA.BF), + [false, false, false, true, true], +)) +function ∇(::typeof(trsv), ::Arg4, p, y::SV{T}, ȳ::SV{T}, + ul::Char, ta::Char, dA::Char, + A::SM{T}, + x::SV{T}, +) where T<:BF + Y, Ȳ, X = reshape(y, length(y), 1), reshape(ȳ, length(ȳ), 1), reshape(x, length(x), 1) + Ā = ∇(trsm, Val{6}, p, Y, Ȳ, 'L', ul, ta, dA, one(T), A, X) + dA == 'U' && fill!(view(Ā, diagind(Ā)), zero(T)) + return Ā +end +∇(::typeof(trsv), ::Arg5, p, y::SV{T}, ȳ::SV{T}, + ul::Char, ta::Char, dA::Char, + A::SM{T}, + x::SV{T}, +) where T<:BF = trsv(ul, uppercase(ta) == 'N' ? 'T' : 'N', dA, A, ȳ) + + +################################## Level 3 ################################## + +# `gemm` sensitivities implementation. 
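+# (Descriptive note on the triangular rules above: for y = trsv(ul, 'N', dA, A, x), i.e.
+# the solve A y = x with triangular A, the sensitivities reduce to x̄ = A⁻ᵀ ȳ and
+# Ā = -(A⁻ᵀ ȳ) yᵀ kept on the stored triangle, with the diagonal zeroed when dA = 'U';
+# trmv, the product y = A x, has the same shape with Ā = ȳ xᵀ and x̄ = Aᵀ ȳ. The
+# transposed cases (ta = 'T') follow by transposing these expressions.)
+# Level-3 `gemm` below generalises the `gemv` rules to matrix-matrix products.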
+push!(ops, DiffOp(:(LinearAlgebra.BLAS.gemm), + :(Tuple{Char, Char, T, DLA.SM{T}, DLA.SM{T}} where T<:DLA.BF), + [false, false, true, true, true], +)) +∇(::typeof(gemm), ::Arg3, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = sum(Ȳ .* Y) / α +∇(::typeof(gemm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = + uppercase(tA) == 'N' ? + uppercase(tB) == 'N' ? + gemm('N', 'T', α, Ȳ, B) : + gemm('N', 'N', α, Ȳ, B) : + uppercase(tB) == 'N' ? + gemm('N', 'T', α, B, Ȳ) : + gemm('T', 'T', α, B, Ȳ) +∇(Ā::SM{T}, ::typeof(gemm), ::Arg4, _, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = + uppercase(tA) == 'N' ? + uppercase(tB) == 'N' ? + gemm!('N', 'T', α, Ȳ, B, one(T), Ā) : + gemm!('N', 'N', α, Ȳ, B, one(T), Ā) : + uppercase(tB) == 'N' ? + gemm!('N', 'T', α, B, Ȳ, one(T), Ā) : + gemm!('T', 'T', α, B, Ȳ, one(T), Ā) +∇(::typeof(gemm), ::Arg5, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = + uppercase(tA) == 'N' ? + uppercase(tB) == 'N' ? + gemm('T', 'N', α, A, Ȳ) : + gemm('T', 'N', α, Ȳ, A) : + uppercase(tB) == 'N' ? + gemm('N', 'N', α, A, Ȳ) : + gemm('T', 'T', α, Ȳ, A) +∇(B̄::SM{T}, ::typeof(gemm), ::Arg5, _, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = + uppercase(tA) == 'N' ? + uppercase(tB) == 'N' ? + gemm!('T', 'N', α, A, Ȳ, one(T), B̄) : + gemm!('T', 'N', α, Ȳ, A, one(T), B̄) : + uppercase(tB) == 'N' ? + gemm!('N', 'N', α, A, Ȳ, one(T), B̄) : + gemm!('T', 'T', α, Ȳ, A, one(T), B̄) + +# `gemm` sensitivities implementation for `α = 1`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.gemm), + :(Tuple{Char, Char, DLA.SM{T}, DLA.SM{T}} where T<:DLA.BF), + [false, false, true, true], +)) +∇(::typeof(gemm), ::Arg3, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + A::SM{T}, + B::SM{T} +) where T<:BF = ∇(gemm, Val{4}, p, Y, Ȳ, tA, tB, one(T), A, B) +∇(Ā::SM{T}, ::typeof(gemm), ::Arg3, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(Ā, gemm, Val{4}, p, Y, Ȳ, tA, tB, one(T), A, B) +∇(::typeof(gemm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(gemm, Val{5}, p, Y, Ȳ, tA, tB, one(T), A, B) +∇(B̄::SM{T}, ::typeof(gemm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + tA::Char, + tB::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(B̄, gemm, Val{5}, p, Y, Ȳ, tA, tB, one(T), A, B) + +# # `syrk` sensitivity implementations. +# @explicit_intercepts( +# syrk, +# Tuple{Char, Char, ∇Scalar, StridedVecOrMat{<:∇Scalar}}, +# [false, false, true, true], +# ) +# function ∇(::typeof(syrk), ::Type{Arg{3}}, p, Y, Ȳ, +# uplo::Char, +# trans::Char, +# α::∇Scalar, +# A::StridedVecOrMat{<:∇Scalar}, +# ) +# g! = uppercase(uplo) == 'L' ? tril! : triu! +# return sum(g!(Ȳ .* Y)) / α +# end +# function ∇(::typeof(syrk), ::Type{Arg{4}}, p, Y, Ȳ, +# uplo::Char, +# trans::Char, +# α::∇Scalar, +# A::StridedVecOrMat{<:∇Scalar}, +# ) +# triȲ = uppercase(uplo) == 'L' ? tril(Ȳ) : triu(Ȳ) +# out = gemm('N', trans, α, triȲ .+ triȲ', A) +# return uppercase(trans) == 'N' ? out : out' +# end +# function ∇(Ā::StridedVecOrMat{T}, ::typeof(syrk), ::Type{Arg{4}}, p, Y, Ȳ, +# uplo::Char, +# trans::Char, +# α::∇Scalar, +# A::StridedVecOrMat{T}, +# ) where T<:∇Scalar +# triȲ = uppercase(uplo) == 'L' ? tril(Ȳ) : triu(Ȳ) +# out = gemm('N', trans, α, triȲ .+ triȲ', A) +# return broadcast!((ā, δā)->ā+δā, Ā, Ā, uppercase(trans) == 'N' ? 
out : out') +# end + +# # `syrk` sensitivity implementations for `α=1`. +# @explicit_intercepts( +# syrk, +# Tuple{Char, Char, StridedVecOrMat{<:∇Scalar}}, +# [false, false, true], +# ) +# ∇(::typeof(syrk), ::Type{Arg{3}}, p, Y, Ȳ, +# uplo::Char, +# trans::Char, +# A::StridedVecOrMat{<:∇Scalar}, +# ) = ∇(syrk, Arg{4}, p, Y, Ȳ, uplo, trans, one(eltype(A)), A) +# ∇(Ā::StridedVecOrMat{T}, ::typeof(syrk), ::Type{Arg{4}}, p, Y, Ȳ, +# uplo::Char, +# trans::Char, +# A::StridedVecOrMat{T}, +# ) where T<:∇Scalar = ∇(Ā, syrk, Arg{4}, p, Y, Ȳ, uplo, char, one(eltype(A)), A) + +# `symm` sensitivity implementations. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.symm), + :(Tuple{Char, Char, T, DLA.SM{T}, DLA.SM{T}} where T<:DLA.BF), + [false, false, true, true, true], +)) +∇(::typeof(symm), ::Arg3, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = sum(Ȳ .* Y) / α +function ∇(::typeof(symm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF + tmp = uppercase(side) == 'L' ? Ȳ * B' : B'Ȳ + g! = uppercase(ul) == 'L' ? tril! : triu! + return α * g!(tmp + tmp' - Diagonal(tmp)) +end +function ∇(Ā::SM{T}, ::typeof(symm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF + tmp = uppercase(side) == 'L' ? Ȳ * B' : B'Ȳ + g! = uppercase(ul) == 'L' ? tril! : triu! + return broadcast!((ā, δā)->ā + δā, Ā, Ā, α * g!(tmp + tmp' - Diagonal(tmp))) +end +∇(::typeof(symm), ::Arg5, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = symm(side, ul, α, A, Ȳ) +∇(B̄::SM{T}, ::typeof(symm), ::Arg5, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = symm!(side, ul, α, A, Ȳ, one(T), B̄) + +# `symm` sensitivity implementations for `α=1`. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.symm), + :(Tuple{Char, Char, DLA.SM{T}, DLA.SM{T}} where T<:DLA.BF), + [false, false, true, true], +)) +∇(::typeof(symm), ::Arg3, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(symm, Val{4}, p, Y, Ȳ, side, ul, one(T), A, B) +∇(Ā::SM{T}, ::typeof(symm), ::Arg3, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(Ā, symm, Val{4}, p, Y, Ȳ, side, ul, one(T), A, B) +∇(::typeof(symm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(symm, Val{5}, p, Y, Ȳ, side, ul, one(T), A, B) +∇(B̄::SM{T}, ::typeof(symm), ::Arg4, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, + ul::Char, + A::SM{T}, + B::SM{T}, +) where T<:BF = ∇(B̄, symm, Val{5}, p, Y, Ȳ, side, ul, one(T), A, B) + +# `trmm` sensitivity implementations. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.trmm), + :(Tuple{Char, Char, Char, Char, T, DLA.SM{T}, DLA.SM{T}} where T<:DLA.BF), + [false, false, false, false, true, true, true], +)) +∇(::typeof(trmm), ::Arg5, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, ul::Char, ta::Char, dA::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = sum(Ȳ .* Y) / α +function ∇(::typeof(trmm), ::Arg6, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, ul::Char, ta::Char, dA::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF + Ā_full = uppercase(side) == 'L' ? + uppercase(ta) == 'N' ? + gemm('N', 'T', α, Ȳ, B) : + gemm('N', 'T', α, B, Ȳ) : + uppercase(ta) == 'N' ? + gemm('T', 'N', α, B, Ȳ) : + gemm('T', 'N', α, Ȳ, B) + dA == 'U' && fill!(view(Ā_full, diagind(Ā_full)), zero(T)) + return (uppercase(ul) == 'L' ? tril! 
: triu!)(Ā_full) +end +∇(::typeof(trmm), ::Type{Val{7}}, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, ul::Char, ta::Char, dA::Char, + α::T, + A::SM{T}, + B::SM{T}, +) where T<:BF = + uppercase(side) == 'L' ? + uppercase(ta) == 'N' ? + trmm('L', ul, 'T', dA, α, A, Ȳ) : + trmm('L', ul, 'N', dA, α, A, Ȳ) : + uppercase(ta) == 'N' ? + trmm('R', ul, 'T', dA, α, A, Ȳ) : + trmm('R', ul, 'N', dA, α, A, Ȳ) + + +# `trsm` sensitivity implementations. +push!(ops, DiffOp(:(LinearAlgebra.BLAS.trsm), + :(Tuple{Char, Char, Char, Char, T, DLA.SM{T}, DLA.SM{T}} where T<:DLA.BF), + [false, false, false, false, true, true, true], +)) +∇(::typeof(trsm), ::Arg5, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, ul::Char, ta::Char, dA::Char, + α::T, + A::SM{T}, + X::SM{T}, +) where T<:BF = sum(Ȳ .* Y) / α +function ∇(::typeof(trsm), ::Arg6, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, ul::Char, ta::Char, dA::Char, + α::T, + A::SM{T}, + X::SM{T}, +) where T<:BF + Ā_full = uppercase(side) == 'L' ? + uppercase(ta) == 'N' ? + trsm('L', ul, 'T', dA, -one(T), A, Ȳ * Y') : + trsm('R', ul, 'T', dA, -one(T), A, Y * Ȳ') : + uppercase(ta) == 'N' ? + trsm('R', ul, 'T', dA, -one(T), A, Y'Ȳ) : + trsm('L', ul, 'T', dA, -one(T), A, Ȳ'Y) + dA == 'U' && fill!(view(Ā_full, diagind(Ā_full)), zero(T)) + return (uppercase(ul) == 'L' ? tril! : triu!)(Ā_full) +end +∇(::typeof(trsm), ::Type{Val{7}}, p, Y::SM{T}, Ȳ::SM{T}, + side::Char, ul::Char, ta::Char, dA::Char, + α::T, + A::SM{T}, + X::SM{T}, +) where T<:BF = + uppercase(side) == 'L' ? + uppercase(ta) == 'N' ? + trsm('L', ul, 'T', dA, α, A, Ȳ) : + trsm('L', ul, 'N', dA, α, A, Ȳ) : + uppercase(ta) == 'N' ? + trsm('R', ul, 'T', dA, α, A, Ȳ) : + trsm('R', ul, 'N', dA, α, A, Ȳ) diff --git a/src/reverse/diagonal.jl b/src/reverse/diagonal.jl new file mode 100644 index 0000000..479582c --- /dev/null +++ b/src/reverse/diagonal.jl @@ -0,0 +1,65 @@ +import LinearAlgebra: det, logdet, diagm, Diagonal, diag +export diag, diagm, Diagonal + +push!(ops, DiffOp(:(LinearAlgebra.diag), :(Tuple{DLA.AM}), [true])) +function ∇(::typeof(diag), ::Arg1, p, y::AV, ȳ::AV, X::AM) + X̄ = fill!(similar(X), zero(eltype(X))) + X̄[diagind(X̄)] = ȳ + return X̄ +end +function ∇(X̄::AM, ::typeof(diag), ::Arg1, p, y::AV, ȳ::AV, X::AM) + X̄_diag = view(X̄, diagind(X̄)) + X̄_diag .+= ȳ + return X̄ +end + +push!(ops, DiffOp(:(LinearAlgebra.diag), :(Tuple{DLA.AM, Integer}), [true, false])) +function ∇(::typeof(diag), ::Arg1, p, y::AV, ȳ::AV, X::AM, k::Integer) + X̄ = fill!(similar(X), zero(eltype(X))) + X̄[diagind(X̄, k)] = ȳ + return X̄ +end +function ∇(X̄::AM, ::typeof(diag), ::Arg1, p, y::AV, ȳ::AV, X::AM, k::Integer) + X̄_diag = view(X̄, diagind(X̄, k)) + X̄_diag .+= ȳ + return X̄ +end + +push!(ops, DiffOp(:(LinearAlgebra.diagm), :(Tuple{Pair{<:Integer, <:AV}}), [true])) +∇(::typeof(diagm), ::Arg1, p, Y::AM, Ȳ::AM, x::Pair{<:Integer, <:AV}) = + copyto!(similar(x.second), view(Ȳ, diagind(Ȳ, x.first))) +∇(x̄::AV, ::typeof(diagm), ::Arg1, p, Y::AM, Ȳ::AM, x::Pair{<:Integer, <:AV}) = + broadcast!(+, x̄, x̄, view(Ȳ, diagind(Ȳ, x.first))) + +push!(ops, DiffOp(:(LinearAlgebra.Diagonal), :(Tuple{DLA.AV}), [true])) +∇(::Type{Diagonal}, ::Arg1, p, Y::Diagonal{<:Real}, Ȳ::Diagonal{<:Real}, x::AV) = + copyto!(similar(x), Ȳ.diag) +∇(x̄::AV, ::Type{Diagonal}, ::Arg1, p, Y::Diagonal{<:Real}, Ȳ::Diagonal{<:Real}, x::AV) = + broadcast!(+, x̄, x̄, Ȳ.diag) + +push!(ops, DiffOp(:(LinearAlgebra.Diagonal), :(Tuple{DLA.AM}), [true])) +function ∇(::Type{Diagonal}, ::Arg1, p, Y::Diagonal{<:Real}, Ȳ::Diagonal{<:Real}, X::AM) + X̄ = zero(X) + copyto!(view(X̄, diagind(X)), 
Ȳ.diag) + return X̄ +end +function ∇(X̄::AM, ::Type{Diagonal}, ::Arg1, p, Y::Diagonal{<:Real}, Ȳ::Diagonal{<:Real}, X::AM) + X̄_diag = view(X̄, diagind(X̄)) + broadcast!(+, X̄_diag, X̄_diag, Ȳ.diag) + return X̄ +end + +push!(ops, DiffOp(:(LinearAlgebra.det), :(Tuple{Diagonal{<:Real}}), [true])) +∇(::typeof(det), ::Arg1, p, y::Real, ȳ::Real, X::Diagonal{<:Real}) = + Diagonal(ȳ .* y ./ X.diag) +function ∇(X̄::Diagonal{<:Real}, ::typeof(det), ::Arg1, p, y::Real, ȳ::Real, X::Diagonal{<:Real}) + broadcast!((x̄, x, y, ȳ)->x̄ + ȳ * y / x, X̄.diag, X̄.diag, X.diag, y, ȳ) + return X̄ +end + +push!(ops, DiffOp(:(LinearAlgebra.logdet), :(Tuple{Diagonal{<:Real}}), [true])) +∇(::typeof(logdet), ::Arg1, p, y::Real, ȳ::Real, X::Diagonal{<:Real}) = Diagonal(ȳ ./ X.diag) +function ∇(X̄::Diagonal{<:Real}, ::typeof(logdet), ::Arg1, p, y::Real, ȳ::Real, X::Diagonal{<:Real}) + broadcast!((x̄, x, ȳ)->x̄ + ȳ / x, X̄.diag, X̄.diag, X.diag, ȳ) + return X̄ +end diff --git a/src/reverse/factorization/cholesky.jl b/src/reverse/factorization/cholesky.jl new file mode 100644 index 0000000..1a40465 --- /dev/null +++ b/src/reverse/factorization/cholesky.jl @@ -0,0 +1,210 @@ +import LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! +import LinearAlgebra: chol, copytri! + +#= +See [1] for implementation details: pages 5-9 in particular. The derivations presented in +[1] assume column-major layout, whereas Julia primarily uses row-major. We therefore +implement both the derivations in [1] and their transpose, which is more appropriate to +Julia. + +[2] suggests that implementing these operations at the highest level of abstraction is the +way forward. There is, therefore, some code below to see what happens when we do this. + +[1] - "Differentiation of the Cholesky decomposition", Murray 2016 +[2] - "Auto-Differentiating Linear Algebra", Seeger et. al 2017. +=# + +const UT = UpperTriangular +∇(::typeof(chol), ::Arg1, p, U::UT{T}, Ū::AM{T}, Σ::AM{T}) where T<:BF = + chol_blocked_rev(Matrix(Ū), Matrix(U), 25, true) + +# Experimental code implementing the algebraic sensitivities discussed in [2]. +# """ +# ∇(::typeof(chol), ::Arg1, p, U::UT{T}, Ū::AM{T}, Σ::AM{T}) where T<:BF + +# Transform Ū into Σ̄ in a non-allocating manner. +# """ +# function ∇(::typeof(chol), ::Arg1, p, U::UT{T}, Ū::AM{T}, Σ::AM{T}, ::Symbol) where T<:BF +# Σ̄ = A_mul_Bt!(Ū, U) +# Σ̄ = copytri!(Σ̄, 'U') +# Σ̄ = A_ldiv_B!(U, Σ̄) +# BLAS.trsm!('R', 'U', 'T', 'N', one(T), U.data, Σ̄) +# @inbounds for n in diagind(Σ̄) +# Σ̄[n] *= 0.5 +# end +# return Σ̄ +# end + +""" + level2partition(A::AbstractMatrix, j::Int, upper::Bool) + +Returns views to various bits of the lower triangle of `A` according to the +`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then +the transposed views are returned from the upper triangle of `A`. +""" +function level2partition(A::AM, j::Int, upper::Bool) + + # Check that A is square and j is a valid index. 
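+    # (Illustrative sketch: for a 4×4 `A` with `j = 2` and `upper = false`, the views
+    # returned below are, in block form,
+    #
+    #       ⎡ ⋅  ⋅  ⋅  ⋅ ⎤
+    #       ⎢ r  d  ⋅  ⋅ ⎥        r = A[2, 1:1],   d = A[2, 2],
+    #       ⎢ B  c  ⋅  ⋅ ⎥        B = A[3:4, 1:1], c = A[3:4, 2];
+    #       ⎣ B  c  ⋅  ⋅ ⎦
+    #
+    # with `upper = true` the same blocks are read, transposed, from the upper triangle.)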
+ M, N = size(A) + (0 >= j || j > M) && throw(ArgumentError("j is out of range.")) + M != N && throw(ArgumentError("A is not square.")) + + if upper + r = view(A, 1:j-1, j) + d = view(A, j, j) + B = view(A, 1:j-1, j+1:N) + c = view(A, j, j+1:N) + else + r = view(A, j, 1:j-1) + d = view(A, j, j) + B = view(A, j+1:N, 1:j-1) + c = view(A, j+1:N, j) + end + return r, d, B, c +end + +""" + level3partition(A::AbstractMatrix, j::Int, k::Int, upper::Bool) + +Returns views to various bits of the lower triangle of `A` according to the +`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then +the transposed views are returned from the upper triangle of `A`. +""" +function level3partition(A::AM, j::Int, k::Int, upper::Bool) + + # Check that A is square and j is a valid index. + M, N = size(A) + (0 >= j || j > M) && throw(ArgumentError("j is out of range.")) + M != N && throw(ArgumentError("A is not square.")) + + # Get views into bits of A. + if upper + R = view(A, 1:j-1, j:k) + D = view(A, j:k, j:k) + B = view(A, 1:j-1, k+1:N) + C = view(A, j:k, k+1:N) + else + R = view(A, j:k, 1:j-1) + D = view(A, j:k, j:k) + B = view(A, k+1:N, 1:j-1) + C = view(A, k+1:N, j:k) + end + return R, D, B, C +end + +""" + chol_unblocked_rev!( + Ā::AbstractMatrix{T}, + L::AbstractMatrix{T}, + upper::Bool + ) where T<:Real + +Compute the reverse-mode sensitivities of the Cholesky factorisation in an unblocked manner. +If `upper` is `false`, then the sensitivites computed from and stored in the lower triangle +of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the +upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output +`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and +`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. +""" +function chol_unblocked_rev!(Σ̄::AM{T}, L::AM{T}, upper::Bool) where T<:Real + + # Check that L is square, that Σ̄ is square and that they are the same size. + M, N = size(Σ̄) + M != N && throw(ArgumentError("Σ̄ is not square.")) + + # Compute the reverse-mode diff. + j = N + for ĵ in 1:N + r, d, B, c = level2partition(L, j, upper) + r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper) + + # d̄ <- d̄ - c'c̄ / d. + d̄[1] -= dot(c, c̄) / d[1] + + # [d̄ c̄'] <- [d̄ c̄'] / d. + d̄ ./= d + c̄ ./= d + + # r̄ <- r̄ - [d̄ c̄'] [r' B']'. + r̄ = axpy!(-Σ̄[j, j], r, r̄) + r̄ = gemv!(upper ? 'N' : 'T', -one(T), B, c̄, one(T), r̄) + + # B̄ <- B̄ - c̄ r. + B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄) + d̄ ./= 2 + j -= 1 + end + return (upper ? triu! : tril!)(Σ̄) +end +chol_unblocked_rev(Σ̄::AM, L::AM, upper::Bool) = chol_unblocked_rev!(copy(Σ̄), L, upper) + +""" + chol_blocked_rev!( + Σ̄::AbstractMatrix{T}, + L::AbstractMatrix{T}, + Nb::Int, + upper::Bool + ) where T<:BF + +Compute the sensitivities of the Cholesky factorisation using a blocked, cache-friendly +procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities +of `Σ`, where `Σ = LLᵀ`. `Nb` is the block-size to use. If the upper triangle has been used +to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be +indicated by passing `upper = true`. +""" +function chol_blocked_rev!(Σ̄::AM{T}, L::AM{T}, Nb::Int, upper::Bool) where T<:BF + + # Check that L is square, that Σ̄ is square and that they are the same size. 
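+    # (Sketch of the traversal below: blocks of width `Nb` are processed from the trailing
+    # corner of the factor towards the leading one. For each block, `level3partition`
+    # exposes the panels R, D, B and C; the level-3 BLAS calls fold the sensitivity of C
+    # into B̄, D̄ and R̄, and the small Nb×Nb diagonal block is then finished off by
+    # `chol_unblocked_rev!`.)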
+ M, N = size(Σ̄) + M != N && throw(ArgumentError("Σ̄ is not square.")) + + tmp = Matrix{T}(undef, Nb, Nb) + + # Compute the reverse-mode diff. + k = N + if upper + for k̂ in 1:Nb:N + j = max(1, k - Nb + 1) + R, D, B, C = level3partition(L, j, k, true) + R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true) + + C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄) + gemm!('N', 'N', -one(T), R, C̄, one(T), B̄) + gemm!('N', 'T', -one(T), C, C̄, one(T), D̄) + chol_unblocked_rev!(D̄, D, true) + gemm!('N', 'T', -one(T), B, C̄, one(T), R̄) + if size(D̄, 1) == Nb + tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) + gemm!('N', 'N', -one(T), R, tmp, one(T), R̄) + else + gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄) + end + + k -= Nb + end + return triu!(Σ̄) + else + for k̂ in 1:Nb:N + j = max(1, k - Nb + 1) + R, D, B, C = level3partition(L, j, k, false) + R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false) + + C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄) + gemm!('N', 'N', -one(T), C̄, R, one(T), B̄) + gemm!('T', 'N', -one(T), C̄, C, one(T), D̄) + chol_unblocked_rev!(D̄, D, false) + gemm!('T', 'N', -one(T), C̄, B, one(T), R̄) + if size(D̄, 1) == Nb + tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) + gemm!('N', 'N', -one(T), tmp, R, one(T), R̄) + else + gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄) + end + + k -= Nb + end + return tril!(Σ̄) + end +end +chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, Nb::Int, upper::Bool) = + chol_blocked_rev!(copy(Σ̄), L, Nb, upper) diff --git a/src/reverse/generic.jl b/src/reverse/generic.jl new file mode 100644 index 0000000..d218dbc --- /dev/null +++ b/src/reverse/generic.jl @@ -0,0 +1,199 @@ +import LinearAlgebra: -, tr, inv, det, logdet, transpose, adjoint, norm, kron + +# ############################# Unary sensitivities ############################# + +@reverse_rule( + Y::AbstractArray{<:Real}, Ȳ::AbstractArray{<:Real}, + Base.:-(wrt(X::AbstractArray{<:Real})) = :(-$Ȳ), +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.tr(wrt(X::AbstractMatrix{<:Real})) = :(Diagonal(fill!(similar($X), $Ȳ))) +) + +@reverse_rule( + Y::AbstractMatrix{<:Real}, Ȳ::AbstractMatrix{<:Real}, + LinearAlgebra.inv(wrt(X::AbstractMatrix{<:Real})) = :(-$Y' * $Ȳ * $Y'), +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.det(wrt(X::AbstractMatrix{<:Real})) = :($Y * $Ȳ * inv($X)'), +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.logdet(wrt(X::AbstractMatrix{<:Real})) = :($Ȳ * inv($X)'), +) + +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.transpose(wrt(X::AbstractVecOrMat{<:Real})) = :($Ȳ'), +) + +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.adjoint(wrt(X::AbstractVecOrMat{<:Real})) = :($Ȳ'), +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.norm(wrt(X::AbstractArray{<:Real})) = :($Ȳ ./ $Y .* abs2.($X) ./ $X), +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.norm(wrt(X::Real)) = :($Ȳ * sign($X)), +) + + +############################# Binary sensitivities ############################# + +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.:*(wrt(A::AbstractVecOrMat{<:Real}), B::AbstractVecOrMat{<:Real}) = :($Ȳ * $B') +) +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.:*(A::AbstractVecOrMat{<:Real}, wrt(B::AbstractVecOrMat{<:Real})) = :($A' * $Ȳ) +) + +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.:/(wrt(A::AbstractVecOrMat{<:Real}), 
B::AbstractVecOrMat{<:Real}) = :($Ȳ / $B') +) +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.:/(A::AbstractVecOrMat{<:Real}, wrt(B::AbstractVecOrMat{<:Real})) = :(-($Y)' * ($Ȳ / $B')) +) + +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.:\(wrt(A::AbstractVecOrMat{<:Real}), B::AbstractVecOrMat{<:Real}) = :(-($A' \ $Ȳ) * $Y') +) +@reverse_rule( + Y::AbstractVecOrMat{<:Real}, Ȳ::AbstractVecOrMat{<:Real}, + LinearAlgebra.:\(A::AbstractVecOrMat{<:Real}, wrt(B::AbstractVecOrMat{<:Real})) = :($A' \ $Ȳ) +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.norm(wrt(A::AbstractArray{<:Real}), B::Real) = + :($Ȳ .* $Y^(1 - $B) .* abs.($A).^$B ./ $A) +) +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.norm(A::AbstractArray{<:Real}, wrt(B::Real)) = + :($Ȳ * ($Y^(1 - $B) * sum(abs.($A).^$B .* log.(abs.($A))) - $Y * log($Y)) / $B) +) + +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.norm(wrt(A::Real), B::Real) = :($Ȳ * sign($A)) +) +@reverse_rule( + Y::Real, Ȳ::Real, + LinearAlgebra.norm(A::Real, wrt(B::Real)) = :(0) +) + +@reverse_rule( + Y::AbstractMatrix{<:Real}, Ȳ::AbstractMatrix{<:Real}, + LinearAlgebra.kron(wrt(A::AbstractMatrix{<:Real}), B::AbstractMatrix{<:Real}) = + :(_kron_rev_kernel_1($Y, $Ȳ, $A, $B)), +) +@reverse_rule( + Y::AbstractMatrix{<:Real}, Ȳ::AbstractMatrix{<:Real}, + LinearAlgebra.kron(A::AbstractMatrix{<:Real}, wrt(B::AbstractMatrix{<:Real})) = + :(_kron_rev_kernel_2($Y, $Ȳ, $A, $B)), +) + +function _kron_rev_kernel_1( + Y::AbstractMatrix{<:Real}, + Ȳ::AbstractMatrix{<:Real}, + A::AbstractMatrix{<:Real}, + B::AbstractMatrix{<:Real}, +) + Ā = similar(A) + (I, J), (K, L), m = size(A), size(B), length(Y) + @inbounds for j = reverse(1:J), l = reverse(1:L), i = reverse(1:I) + āij = Ā[i, j] + for k = reverse(1:K) + āij += Ȳ[m] * B[k, l] + m -= 1 + end + Ā[i, j] = āij + end + return Ā +end + +function _kron_rev_kernel_2( + Y::AbstractMatrix{<:Real}, + Ȳ::AbstractMatrix{<:Real}, + A::AbstractMatrix{<:Real}, + B::AbstractMatrix{<:Real}, +) + B̄ = similar(B) + (I, J), (K, L), m = size(A), size(B), length(Y) + @inbounds for j = reverse(1:J), l = reverse(1:L), i = reverse(1:I) + aij = A[i, j] + for k = reverse(1:K) + B̄[k, l] += aij * Ȳ[m] + m -= 1 + end + end + return B̄ +end + +# push!(ops, DiffOp(:(LinearAlgebra.:*), :(Tuple{DLA.AVM, DLA.AVM}), [true, true])) +# ∇(::typeof(*), ::Arg1, p, Y::ASVM, Ȳ::ASVM, A::AVM, B::AVM) = Ȳ * B' +# ∇(::typeof(*), ::Arg2, p, Y::ASVM, Ȳ::ASVM, A::AVM, B::AVM) = A' * Ȳ + +# push!(ops, DiffOp(:(LinearAlgebra.:/), :(Tuple{DLA.AVM, DLA.AVM}), [true, true])) +# ∇(::typeof(/), ::Arg1, p, Y::ASVM, Ȳ::ASVM, A::AVM, B::AVM) = Ȳ / B' +# ∇(::typeof(/), ::Arg2, p, Y::ASVM, Ȳ::ASVM, A::AVM, B::AVM) = -Y' * (Ȳ / B') + +# push!(ops, DiffOp(:(LinearAlgebra.:\), :(Tuple{DLA.AVM, DLA.AVM}), [true, true])) +# ∇(::typeof(\), ::Arg1, p, Y::ASVM, Ȳ::ASVM, A::AVM, B::AVM) = -(A' \ Ȳ) * Y' +# ∇(::typeof(\), ::Arg2, p, Y::ASVM, Ȳ::ASVM, A::AVM, B::AVM) = A' \ Ȳ + +# push!(ops, DiffOp(:(LinearAlgebra.norm), :(Tuple{DLA.AA, Real}), [true, true])) +# ∇(::typeof(norm), ::Arg1, p, Y::Real, Ȳ::Real, A::AA, B::Real) = +# Ȳ .* Y^(1 - B) .* abs.(A).^B ./ A +# ∇(::typeof(norm), ::Arg2, p, Y::Real, Ȳ::Real, A::AA, B::Real) = +# Ȳ * (Y^(1 - B) * sum(abs.(A).^B .* log.(abs.(A))) - Y * log(Y)) / B + +# push!(ops, DiffOp(:(LinearAlgebra.norm), :(Tuple{Real, Real}), [true, true])) +# ∇(::typeof(norm), ::Arg1, p, Y::Real, Ȳ::Real, A::Real, B::Real) = Ȳ * sign(A) +# ∇(::typeof(norm), ::Arg2, p, Y::Real, 
Ȳ::Real, A::Real, B::Real) = 0 + +# push!(ops, DiffOp(:(LinearAlgebra.kron), :(Tuple{AM, AM}), [true, true])) +# ∇(::typeof(kron), ::Type{Val{1}}, p, Y::AM, Ȳ::AM, A::AM, B::AM) = +# ∇(zero(A), kron, Val{1}, p, Y, Ȳ, A, B) +# ∇(::typeof(kron), ::Type{Val{2}}, p, Y::AM, Ȳ::AM, A::AM, B::AM) = +# ∇(zero(B), kron, Val{2}, p, Y, Ȳ, A, B) +# function ∇(Ā::AM, ::typeof(kron), ::Type{Val{1}}, p, Y::AM, Ȳ::AM, A::AM, B::AM) +# @assert size(Ā) == size(A) +# (I, J), (K, L), m = size(A), size(B), length(Y) +# @inbounds for j = reverse(1:J), l = reverse(1:L), i = reverse(1:I) +# āij = Ā[i, j] +# for k = reverse(1:K) +# āij += Ȳ[m] * B[k, l] +# m -= 1 +# end +# Ā[i, j] = āij +# end +# return Ā +# end +# function ∇(B̄::AM, ::typeof(kron), ::Type{Val{2}}, p, Y::AM, Ȳ::AM, A::AM, B::AM) +# @assert size(B̄) == size(B) +# (I, J), (K, L), m = size(A), size(B), length(Y) +# @inbounds for j = reverse(1:J), l = reverse(1:L), i = reverse(1:I) +# aij = A[i, j] +# for k = reverse(1:K) +# B̄[k, l] += aij * Ȳ[m] +# m -= 1 +# end +# end +# return B̄ +# end diff --git a/src/reverse/triangular.jl b/src/reverse/triangular.jl new file mode 100644 index 0000000..6d6dd3a --- /dev/null +++ b/src/reverse/triangular.jl @@ -0,0 +1,51 @@ +import LinearAlgebra: det, logdet, LowerTriangular, UpperTriangular +export det, logdet, LowerTriangular, UpperTriangular + +for (ctor, ctor_sym, T, T_sym) in zip([:LowerTriangular, :UpperTriangular], + [:(:(LinearAlgebra.LowerTriangular)), :(:(LinearAlgebra.UpperTriangular))], + [:(LowerTriangular{<:Real}), :(UpperTriangular{<:Real})], + [:(:(LowerTriangular{<:Real})), :(:(UpperTriangular{<:Real}))]) + + @eval begin + + push!(ops, DiffOp($ctor_sym, :(Tuple{DLA.AM}), [true])) + ∇(::Type{$ctor}, ::Arg1, p, Y::$T, Ȳ::$T, X::AM) = Matrix(Ȳ) + ∇(X̄::AM, ::Type{$ctor}, ::Arg1, p, Y::$T, Ȳ::$T, X::AM) = broadcast!(+, X̄, X̄, Ȳ) + + push!(ops, DiffOp(:(LinearAlgebra.det), Expr(:curly, :Tuple, $T_sym), [true])) + ∇(::typeof(det), ::Arg1, p, y::Real, ȳ::Real, X::$T) = + Diagonal(ȳ .* y ./ view(X, diagind(X))) + + # Optimisation for in-place updates. + function ∇(X̄::AM, ::typeof(det), ::Arg1, p, y::Real, ȳ::Real, X::$T) + X̄_diag = view(X̄, diagind(X̄)) + broadcast!((x̄, x, y, ȳ)->x̄ + ȳ * y / x, + X̄_diag, X̄_diag, view(X, diagind(X)), y, ȳ) + return X̄ + end + + # Optimisation for in-place updates to `Diagonal` sensitivity cache. + function ∇(X̄::Diagonal, ::typeof(det), ::Arg1, p, y::Real, ȳ::Real, X::$T) + X̄.diag .+= ȳ .* y ./ view(X, diagind(X)) + return X̄ + end + + push!(ops, DiffOp(:(LinearAlgebra.logdet), Expr(:curly, :Tuple, $T_sym), [true])) + ∇(::typeof(logdet), ::Arg1, p, y::Real, ȳ::Real, X::$T) = + Diagonal(ȳ ./ view(X, diagind(X))) + + # Optimisation for in-place updates. + function ∇(X̄::AM, ::typeof(logdet), ::Arg1, p, y::Real, ȳ::Real, X::$T) + X̄_diag = view(X̄, diagind(X̄)) + broadcast!((x̄, x, ȳ)->x̄ + ȳ / x, X̄_diag, X̄_diag, view(X, diagind(X)), ȳ) + return X̄ + end + + # Optimisation for in-place updates to `Diagonal` sensitivity cache. + function ∇(X̄::Diagonal, ::typeof(logdet), ::Arg1, p, y::Real, ȳ::Real, X::$T) + X̄.diag .+= ȳ ./ view(X, diagind(X)) + return X̄ + end + + end +end diff --git a/src/reverse/uniformscaling.jl b/src/reverse/uniformscaling.jl new file mode 100644 index 0000000..e69de29 diff --git a/src/reverse/util.jl b/src/reverse/util.jl new file mode 100644 index 0000000..28a8662 --- /dev/null +++ b/src/reverse/util.jl @@ -0,0 +1,38 @@ +""" + DiffOp + +The information associated with a particular differentiable linear algebra operation. 
`f` is +either a `Symbol` or `Expr` containing the function name, and `T` is an expression +corresponding to the tuple-type of the arguments of the function. `diff_flags` is a vector +containing flags indicating whether each argument of the function is differentiable or not. +""" +struct DiffOp + f::Union{Symbol, Expr} + T::Expr + diff_flags::Vector{Bool} +end +ops = Set{DiffOp}() + +""" + importable(ex::Expr) + +Construct an expression of the form `:(import Package.Subpackage.Foo)` from an expression of +the form `:(Package.Subpackage.Foo)`. +""" +function importable(ex::Expr) + ex.head === :. || error("Expression is not valid as an import: $ex") + result = importable(ex.args[1]) + push!(result, ex.args[2].value) + return result +end +importable(sym::Symbol) = Any[sym] + +""" + import_expr(dop::DiffOp) + +Generate an expression to import `dop.f` from the appropriate package. +""" +import_expr(dop::DiffOp) = + VERSION <= VersionNumber("0.6.2") ? + Expr(:import, importable(dop.f)...) : + Expr(:import, Expr(Symbol("."), importable(dop.f)...)) diff --git a/test/REQUIRE b/test/REQUIRE index 5728fa8..4107efe 100644 --- a/test/REQUIRE +++ b/test/REQUIRE @@ -1,2 +1,3 @@ SpecialFunctions NaNMath +FDM diff --git a/test/diffrules/rules.jl b/test/diffrules/rules.jl new file mode 100644 index 0000000..4a0a2c4 --- /dev/null +++ b/test/diffrules/rules.jl @@ -0,0 +1,45 @@ +@testset "rules" begin + +non_numeric_arg_functions = [(:Base, :rem2pi, 2)] + +for (M, f, arity) in DiffRules.diffrules() + (M, f, arity) ∈ non_numeric_arg_functions && continue + if arity == 1 + @test DiffRules.hasdiffrule(M, f, 1) + deriv = DiffRules.diffrule(M, f, :goo) + modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 + @eval begin + goo = rand() + $modifier + @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) + end + elseif arity == 2 + @test DiffRules.hasdiffrule(M, f, 2) + derivs = DiffRules.diffrule(M, f, :foo, :bar) + @eval begin + foo, bar = rand(1:10), rand() + dx, dy = $(derivs[1]), $(derivs[2]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) + end + end + end +end + +# Treat rem2pi separately because of its non-numeric second argument: +derivs = DiffRules.diffrule(:Base, :rem2pi, :x, :y) +for xtype in [:Float64, :BigFloat, :Int64] + for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest] + @eval begin + x = $xtype(rand(1 : 10)) + y = $mode + dx, dy = $(derivs[1]), $(derivs[2]) + @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) + @test isnan(dy) + end + end +end + +end diff --git a/test/forward/api.jl b/test/forward/api.jl new file mode 100644 index 0000000..029f6ad --- /dev/null +++ b/test/forward/api.jl @@ -0,0 +1,46 @@ +using DiffRules: @forward_rule, DEFINED_FORWARD_RULES, arity, diffrule, _forward_rule + +@testset "api" begin + + # Check that various things that should fail, fail. + @test_throws AssertionError _forward_rule(:(M.f(x))) + @test_throws AssertionError _forward_rule(:(f(x::T) where T = :(5x))) + @test_throws AssertionError _forward_rule(:(f(x::Real; y) = :(5x))) + @test_throws AssertionError _forward_rule(:(f(x::Real...) = :(5x))) + @test_throws AssertionError _forward_rule(:(f(x::Real=5) = :(5x))) + @test_throws ErrorException _forward_rule(:(f(x::Real) = :(5x))) + + # Check that a basic rule works. 
+ foo(x) = 5x + @forward_rule Main.foo(x, ẋ) = :(5($x)^2 + $ẋ) + @test DEFINED_FORWARD_RULES[(:Main, :foo, :(Tuple{Any, Any}))](:g, :h) == :(5g^2 + h) + delete!(DEFINED_FORWARD_RULES, (:Main, :foo, :(Tuple{Any, Any}))) + + non_numeric_arg_functions = [(:Base, :rem2pi, 4)] + + # Check that all forward rules agree with basic diff rules. + for (key, body) in DEFINED_FORWARD_RULES + M, f, signature = key + (M, f, arity(key)) ∈ non_numeric_arg_functions && continue + if arity(key) == 2 + modifier = f ∈ (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) ? 1 : 0 + simple_rule_code = diffrule(M, f, :x) + forward_rule_code = body(:x, :ẋ) + @eval manual_rule = (x, ẋ)->ẋ * $simple_rule_code + @eval forward_rule = (x, ẋ)->$forward_rule_code + x, ẋ = rand() + modifier, randn() + @test manual_rule(x, ẋ) ≈ forward_rule(x, ẋ) + elseif arity(key) == 4 + ∂f∂x, ∂f∂y = diffrule(M, f, :x, :y) + forward_rule_code = body(:x, :y, :ẋ, :ẏ) + @eval manual_rule = (x, y, ẋ, ẏ)->ẋ * $∂f∂x + ẏ * $∂f∂y + @eval forward_rule = (x, y, ẋ, ẏ)->$forward_rule_code + x, y, ẋ, ẏ = rand(), rand(), rand(), rand() + manual, fwd = manual_rule(x, y, ẋ, ẏ), forward_rule(x, y, ẋ, ẏ) + @test isnan(manual) && isnan(fwd) || manual ≈ fwd + else + error("argh") + end + end + +end diff --git a/test/reverse/REQUIRE b/test/reverse/REQUIRE new file mode 100644 index 0000000..24b17bc --- /dev/null +++ b/test/reverse/REQUIRE @@ -0,0 +1 @@ +FDM v0.1.1 diff --git a/test/reverse/api.jl b/test/reverse/api.jl new file mode 100644 index 0000000..c8a7530 --- /dev/null +++ b/test/reverse/api.jl @@ -0,0 +1,62 @@ +using DiffRules: diffrules, DEFINED_REVERSE_RULES, arity, diffrule, @reverse_rule, + _reverse_rule, ReverseRuleKey + +@testset "api" begin + + # Check that various things that should fail, fail. + @test_throws AssertionError _reverse_rule(:z, :z̄, :(M.f(x))) + @test_throws AssertionError _reverse_rule(:z, :z̄, :(M.f(x::T) where T = :(5x))) + @test_throws AssertionError _reverse_rule(:z, :z̄, :(M.f(x::Real; y) = :(5x))) + @test_throws AssertionError _reverse_rule(:z, :z̄, :(M.f(x::Real...) = :(5x))) + @test_throws AssertionError _reverse_rule(:z, :z̄, :(M.f(x::Real=5) = :(5x))) + @test_throws ErrorException _reverse_rule(:z, :z̄, :(f(x::Real) = :(5x))) + + # Check that a basic rule works. + foo(x) = 5x + @reverse_rule y ȳ::Real Main.foo(wrt(x::Int)) = :($y * $ȳ * $x) + key = (:Main, :foo, :(Tuple{Any, Real, Int}), (1,)) + @test DEFINED_REVERSE_RULES[key](:g, :h, :y) == :(g * h * y) + delete!(DEFINED_REVERSE_RULES, key) + + + # non_numeric_arg_functions = [(:Base, :rem2pi, 4)] + + # # Check that all reverse rules agree with basic diff rules. + # for key in diffrules() + # M, f, arity = key + # key ∈ non_numeric_arg_functions && continue + # if arity == 1 + # rev_key = ReverseRuleKey(M, f, :(Tuple{Real, Real, Real}), (1,)) + # rev_rule = DEFINED_REVERSE_RULES[rev_key] + # modifier = f ∈ (:asec, :acsc, :asecd, :acscd, :acosh, :acoth) ? 1 : 0 + # @eval manual_rule = (z, z̄, g)->z̄ * $(diffrule(M, f, :g)) + # @eval reverse_rule = (z, z̄, g)->$(rev_rule(:z, :z̄, :g)) + # x, z, z̄ = rand() + modifier, randn(), randn() + # @test manual_rule(z, z̄, x) ≈ reverse_rule(z, z̄, x) + # elseif arity == 2 + + # ∂f∂x, ∂f∂y = diffrule(M, f, :g_x, :h_z) + + # # Grab the corresponding reverse rules. 
+ # typ = :(Tuple{Real, Real, Real, Real}) + # key1, key2 = (M, f, typ, (1,)), (M, f, typ, (2,)) + # if key1 ∈ keys(DEFINED_REVERSE_RULES) + # x, y, z, z̄ = rand(), rand(), rand(), rand() + # rev_rule_1 = DEFINED_REVERSE_RULES[key1](:z, :z̄, :g_x, :h_z) + # @eval manual_rule_1 = (z, z̄, g_x, h_z)->z̄ * $∂f∂x + # @eval reverse_rule_1 = (z, z̄, g_x, h_z)->$rev_rule_1 + # @test manual_rule_1(z, z̄, x, y) ≈ reverse_rule_1(z, z̄, x, y) + # end + # if key2 ∈ keys(DEFINED_REVERSE_RULES) + # x, y, z, z̄ = rand(), rand(), rand(), rand() + # rev_rule_2 = DEFINED_REVERSE_RULES[(M, f, typ, (2,))](:z, :z̄, :g_x, :h_z) + # @eval manual_rule_2 = (z, z̄, g_x, h_z)->z̄ * $∂f∂y + # @eval reverse_rule_2 = (z, z̄, g_x, h_z)->$rev_rule_2 + # @test manual_rule_2(z, z̄, x, y) ≈ reverse_rule_2(z, z̄, x, y) + # end + # else + # @test 1 === 0 + # end + # end + +end diff --git a/test/reverse/blas.jl b/test/reverse/blas.jl new file mode 100644 index 0000000..db84345 --- /dev/null +++ b/test/reverse/blas.jl @@ -0,0 +1,260 @@ +@testset "BLAS" begin + +import LinearAlgebra.BLAS: nrm2, asum, gemm, gemv, symm, symv, trmm, trmv, trsm, trsv + + +################################## Level 1 ################################## +let P = 10, Q = 6, rng = MersenneTwister(123456), N = 10 + + # Utility random generators. + sc, vP, vQ = ()->randn(rng), ()->randn(rng, P), ()->randn(rng, Q) + + # Unit-stride dot. + @test check_errs(N, binary_ȲD(LinearAlgebra.dot, 1, vP)..., sc, vP, vP) + @test check_errs(N, binary_ȲD(LinearAlgebra.dot, 2, vP)..., sc, vP, vP) + @test check_errs(N, binary_ȲD_inplace(LinearAlgebra.dot, 1, vP, vP())..., sc, vP, vP) + @test check_errs(N, binary_ȲD_inplace(LinearAlgebra.dot, 2, vP, vP())..., sc, vP, vP) + + # Strided dot. + _x, _y = vP(), vQ() + _dot2, _dot4 = x->BLAS.dot(5, x, 2, _y, 1), y->BLAS.dot(5, _x, 2, y, 1) + _∇dot2 = (z, z̄, x)->∇(BLAS.dot, Val{2}, (), z, z̄, 5, x, 2, _y, 1) + _∇dot4 = (z, z̄, y)->∇(BLAS.dot, Val{4}, (), z, z̄, 5, _x, 2, y, 1) + @test check_errs(N, _dot2, _∇dot2, sc, vP, vP) + @test check_errs(N, _dot4, _∇dot4, sc, vQ, vQ) + + # In-place strided dot. + _δx, _δy = vP(), vQ() + _∇dot2 = (z, z̄, x)->∇(copy(_δx), BLAS.dot, Val{2}, (), z, z̄, 5, x, 2, _y, 1) - _δx + _∇dot4 = (z, z̄, y)->∇(copy(_δy), BLAS.dot, Val{4}, (), z, z̄, 5, _x, 2, y, 1) - _δy + @test check_errs(N, _dot2, _∇dot2, sc, vP, vP) + @test check_errs(N, _dot4, _∇dot4, sc, vQ, vQ) + + # Unit-stride nrm2. + @test check_errs(N, unary_ȲD(nrm2)..., sc, vP, vP) + @test check_errs(N, unary_ȲD_inplace(nrm2, vP())..., sc, vP, vP) + + # Arbitrary-stride nrm2. + _nrm2 = x->nrm2(5, x, 2) + _∇nrm2 = (y, ȳ, x)->∇(nrm2, Val{2}, (), y, ȳ, 5, x, 2) + _∇nrm2_in_place = (y, ȳ, x)->∇(copy(_δx), nrm2, Val{2}, (), y, ȳ, 5, x, 2) - _δx + @test check_errs(N, _nrm2, _∇nrm2, sc, vP, vP) + @test check_errs(N, _nrm2, _∇nrm2_in_place, sc, vP, vP) + + # Unit-stride `asum`. + @test check_errs(N, unary_ȲD(asum)..., sc, vP, vP) + @test check_errs(N, unary_ȲD_inplace(asum, vP())..., sc, vP, vP) + + # Arbitrary-stride `asum`. + _asum = x->asum(5, x, 2) + _∇asum = (y, ȳ, x)->∇(asum, Val{2}, (), y, ȳ, 5, x, 2) + _∇asum_in_place = (y, ȳ, x)->∇(copy(_δx), asum, Val{2}, (), y, ȳ, 5, x, 2) - _δx + @test check_errs(N, _asum, _∇asum, sc, vP, vP) + @test check_errs(N, _asum, _∇asum_in_place, sc, vP, vP) +end + + +################################## Level 2 ################################## +let P = 10, Q = 6, rng = MersenneTwister(123456), N = 10 + + # Utility random generators. 
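+    # (Reading of the testing pattern in these blocks, stated as an assumption about the
+    # helper utilities rather than documented behaviour: each case fixes all but one
+    # argument of the BLAS call in a closure, builds the matching `∇` closure — allocating
+    # or in-place — and `check_errs` then compares the reverse-mode sensitivities against
+    # finite differences (via FDM) over `N` random trials.)
+    # The random generators below mirror those used in the Level 1 block.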
+ sc, vP, vQ = ()->randn(rng), ()->randn(rng, P), ()->randn(rng, Q) + mPQ, mQP = ()->randn(rng, P, Q), ()->randn(rng, Q, P) + mPP = ()->randn(rng, P, P) + + # gemv: + for tA in ['T', 'N'] + _α, _x = sc(), vQ() + _A = tA == 'N' ? mPQ() : mQP() + A = tA == 'N' ? mPQ : mQP + _δA, _δx = randn!(rng, similar(_A)), randn!(rng, similar(_x)) + + # α != 1 tests. + _gemv1 = α->gemv(tA, α, _A, _x) + _gemv2 = A->gemv(tA, _α, A, _x) + _gemv3 = x->gemv(tA, _α, _A, x) + _∇gemv1 = (y, ȳ, α)->∇(gemv, Val{2}, (), y, ȳ, tA, α, _A, _x) + _∇gemv2 = (y, ȳ, A)->∇(gemv, Val{3}, (), y, ȳ, tA, _α, A, _x) + _∇gemv2_inp = (y, ȳ, A)->∇(copy(_δA), gemv, Val{3}, (), y, ȳ, tA, _α, A, _x) - _δA + _∇gemv3 = (y, ȳ, x)->∇(gemv, Val{4}, (), y, ȳ, tA, _α, _A, x) + _∇gemv3_inp = (y, ȳ, x)->∇(copy(_δx), gemv, Val{4}, (), y, ȳ, tA, _α, _A, x) - _δx + + @test check_errs(N, _gemv1, _∇gemv1, vP, sc, sc) + @test check_errs(N, _gemv2, _∇gemv2, vP, A, A) + @test check_errs(N, _gemv2, _∇gemv2_inp, vP, A, A) + @test check_errs(N, _gemv3, _∇gemv3, vP, vQ, vQ) + @test check_errs(N, _gemv3, _∇gemv3_inp, vP, vQ, vQ) + + # α = 1 tests. + _gemv1, _gemv2 = A->gemv(tA, A, _x), x->gemv(tA, _A, x) + _∇gemv1 = (y, ȳ, A)->∇(gemv, Val{2}, (), y, ȳ, tA, A, _x) + _∇gemv1_inp = (y, ȳ, A)->∇(copy(_δA), gemv, Val{2}, (), y, ȳ, tA, A, _x) - _δA + _∇gemv2 = (y, ȳ, x)->∇(gemv, Val{3}, (), y, ȳ, tA, _A, x) + _∇gemv2_inp = (y, ȳ, x)->∇(copy(_δx), gemv, Val{3}, (), y, ȳ, tA, _A, x) - _δx + + @test check_errs(N, _gemv1, _∇gemv1, vP, A, A) + @test check_errs(N, _gemv1, _∇gemv1_inp, vP, A, A) + @test check_errs(N, _gemv2, _∇gemv2, vP, vQ, vQ) + @test check_errs(N, _gemv2, _∇gemv2_inp, vP, vQ, vQ) + end + + # symv: + for ul in ['L', 'U'] + α_gen, A_gen, x_gen = sc, mPP, vP + _α, _A, _x = α_gen(), A_gen(), x_gen() + _δA, _δx = randn!(rng, similar(_A)), randn!(rng, similar(_x)) + + # α != 1 tests. + _symv_α = α->symv(ul, α, _A, _x) + _symv_A = A->symv(ul, _α, A, _x) + _symv_x = x->symv(ul, _α, _A, x) + _∇symv_α = (y, ȳ, α)->∇(symv, Val{2}, (), y, ȳ, ul, α, _A, _x) + _∇symv_A = (y, ȳ, A)->∇(symv, Val{3}, (), y, ȳ, ul, _α, A, _x) + _∇symv_A_inp = (y, ȳ, A)->∇(copy(_δA), symv, Val{3}, (), y, ȳ, ul, _α, A, _x) - _δA + _∇symv_x = (y, ȳ, x)->∇(symv, Val{4}, (), y, ȳ, ul, _α, _A, x) + _∇symv_x_inp = (y, ȳ, x)->∇(copy(_δx), symv, Val{4}, (), y, ȳ, ul, _α, _A, x) - _δx + + @test check_errs(N, _symv_α, _∇symv_α, vP, sc, sc) + @test check_errs(N, _symv_A, _∇symv_A, vP, A_gen, A_gen) + @test check_errs(N, _symv_A, _∇symv_A_inp, vP, A_gen, A_gen) + @test check_errs(N, _symv_x, _∇symv_x, vP, x_gen, x_gen) + @test check_errs(N, _symv_x, _∇symv_x_inp, vP, x_gen, x_gen) + + # α = 1 tests. 
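+        # With α omitted the argument positions shift down by one, so Val{2} and
+        # Val{3} now pick out the sensitivities w.r.t. A and x respectively.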
+ _symv_A = A->symv(ul, A, _x) + _symv_x = x->symv(ul, _A, x) + _∇symv_A = (y, ȳ, A)->∇(symv, Val{2}, (), y, ȳ, ul, A, _x) + _∇symv_A_inp = (y, ȳ, A)->∇(copy(_δA), symv, Val{2}, (), y, ȳ, ul, A, _x) - _δA + _∇symv_x = (y, ȳ, x)->∇(symv, Val{3}, (), y, ȳ, ul, _A, x) + _∇symv_x_inp = (y, ȳ, x)->∇(copy(_δx), symv, Val{3}, (), y, ȳ, ul, _A, x) - _δx + + @test check_errs(N, _symv_A, _∇symv_A, vP, A_gen, A_gen) + @test check_errs(N, _symv_A, _∇symv_A_inp, vP, A_gen, A_gen) + @test check_errs(N, _symv_x, _∇symv_x, vP, x_gen, x_gen) + @test check_errs(N, _symv_x, _∇symv_x_inp, vP, x_gen, x_gen) + end + + for f in [trmv, trsv], ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] + A_gen, x_gen = mPP, vP + _A, _x = A_gen(), x_gen() + + _f_A = A->f(ul, tA, dA, A, _x) + _f_x = x->f(ul, tA, dA, _A, x) + _∇f_A = (y, ȳ, A)->∇(f, Val{4}, (), y, ȳ, ul, tA, dA, A, _x) + _∇f_x = (y, ȳ, x)->∇(f, Val{5}, (), y, ȳ, ul, tA, dA, _A, x) + + @test check_errs(N, _f_A, _∇f_A, vP, A_gen, A_gen) + @test check_errs(N, _f_x, _∇f_x, vP, x_gen, x_gen) + end +end + + +################################## Level 3 ################################## +let P = 5, Q = 3, rng = MersenneTwister(123456), N = 10 + + # Utility random generators. + sc, mPP, mQQ = ()->randn(rng), ()->randn(rng, P, P), ()->randn(rng, Q, Q) + mPQ, mQP = ()->randn(rng, P, Q), ()->randn(rng, Q, P) + + # gemm: + for tA in ['T', 'N'], tB in ['T', 'N'] + + # Generate conformal test matrices. + A_gen = mPQ + B_gen = tA == 'N' ? + (tB == 'N' ? mQP : mPQ) : + (tB == 'T' ? mQP : mPQ) + C_gen = tA == 'N' ? mPP : mQQ + + # α != 1 tests. + _α, _A, _B = sc(), A_gen(), B_gen() + _δA, _δB = randn!(rng, similar(_A)), randn!(rng, similar(_B)) + _gemm_α = α->gemm(tA, tB, α, _A, _B) + _gemm_A = A->gemm(tA, tB, _α, A, _B) + _gemm_B = B->gemm(tA, tB, _α, _A, B) + _∇gemm_α = (Y, Ȳ, α)->∇(gemm, Val{3}, (), Y, Ȳ, tA, tB, α, _A, _B) + _∇gemm_A = (Y, Ȳ, A)->∇(gemm, Val{4}, (), Y, Ȳ, tA, tB, _α, A, _B) + _∇gemm_B = (Y, Ȳ, B)->∇(gemm, Val{5}, (), Y, Ȳ, tA, tB, _α, _A, B) + _∇gemm_A_inp = (Y, Ȳ, A)->∇(copy(_δA), gemm, Val{4}, (), Y, Ȳ, tA, tB, _α, A, _B) - _δA + _∇gemm_B_inp = (Y, Ȳ, B)->∇(copy(_δB), gemm, Val{5}, (), Y, Ȳ, tA, tB, _α, _A, B) - _δB + + @test check_errs(N, _gemm_α, _∇gemm_α, C_gen, sc, sc) + @test check_errs(N, _gemm_A, _∇gemm_A, C_gen, A_gen, A_gen) + @test check_errs(N, _gemm_A, _∇gemm_A_inp, C_gen, A_gen, A_gen) + @test check_errs(N, _gemm_B, _∇gemm_B, C_gen, B_gen, B_gen) + @test check_errs(N, _gemm_B, _∇gemm_B_inp, C_gen, B_gen, B_gen) + + # α = 1 tests. + _A, _B = A_gen(), B_gen() + _gemm_A = A->gemm(tA, tB, A, _B) + _gemm_B = B->gemm(tA, tB, _A, B) + _∇gemm_A = (Y, Ȳ, A)->∇(gemm, Val{3}, (), Y, Ȳ, tA, tB, A, _B) + _∇gemm_B = (Y, Ȳ, B)->∇(gemm, Val{4}, (), Y, Ȳ, tA, tB, _A, B) + _∇gemm_A_inp = (Y, Ȳ, A)->∇(copy(_δA), gemm, Val{3}, (), Y, Ȳ, tA, tB, A, _B) - _δA + _∇gemm_B_inp = (Y, Ȳ, B)->∇(copy(_δB), gemm, Val{4}, (), Y, Ȳ, tA, tB, _A, B) - _δB + + @test check_errs(N, _gemm_A, _∇gemm_A, C_gen, A_gen, A_gen) + @test check_errs(N, _gemm_A, _∇gemm_A_inp, C_gen, A_gen, A_gen) + @test check_errs(N, _gemm_B, _∇gemm_B, C_gen, B_gen, B_gen) + @test check_errs(N, _gemm_B, _∇gemm_B_inp, C_gen, B_gen, B_gen) + end + + # symm: + for side in ['L', 'R'], ul in ['L', 'U'] + + # Fixed qtts. + _α, _A, _B = sc(), mPP(), side == 'L' ? mPQ() : mQP() + _δA, _δB = randn!(rng, similar(_A)), randn!(rng, similar(_B)) + + # α != 1 tests. 
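+        # Val{3}, Val{4} and Val{5} select the sensitivities w.r.t. α, A and B, i.e.
+        # the 3rd, 4th and 5th arguments of symm(side, ul, α, A, B).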
+ _symm_α = α->symm(side, ul, α, _A, _B) + _symm_A = A->symm(side, ul, _α, A, _B) + _symm_B = B->symm(side, ul, _α, _A, B) + _∇symm_α = (Y, Ȳ, α)->∇(symm, Val{3}, (), Y, Ȳ, side, ul, α, _A, _B) + _∇symm_A = (Y, Ȳ, A)->∇(symm, Val{4}, (), Y, Ȳ, side, ul, _α, A, _B) + _∇symm_B = (Y, Ȳ, B)->∇(symm, Val{5}, (), Y, Ȳ, side, ul, _α, _A, B) + _∇symm_A_inp = (Y, Ȳ, A)->∇(copy(_δA), symm, Val{4}, (), Y, Ȳ, side, ul, _α, A, _B) - _δA + _∇symm_B_inp = (Y, Ȳ, B)->∇(copy(_δB), symm, Val{5}, (), Y, Ȳ, side, ul, _α, _A, B) - _δB + + B_gen = side == 'L' ? mPQ : mQP + @test check_errs(N, _symm_α, _∇symm_α, B_gen, sc, sc) + @test check_errs(N, _symm_A, _∇symm_A, B_gen, mPP, mPP) + @test check_errs(N, _symm_A, _∇symm_A_inp, B_gen, mPP, mPP) + @test check_errs(N, _symm_B, _∇symm_B, B_gen, B_gen, B_gen) + @test check_errs(N, _symm_B, _∇symm_B_inp, B_gen, B_gen, B_gen) + + # α = 1 tests. + _symm_A = A->symm(side, ul, A, _B) + _symm_B = B->symm(side, ul, _A, B) + _∇symm_A = (Y, Ȳ, A)->∇(symm, Val{3}, (), Y, Ȳ, side, ul, A, _B) + _∇symm_B = (Y, Ȳ, B)->∇(symm, Val{4}, (), Y, Ȳ, side, ul, _A, B) + _∇symm_A_inp = (Y, Ȳ, A)->∇(copy(_δA), symm, Val{3}, (), Y, Ȳ, side, ul, A, _B) - _δA + _∇symm_B_inp = (Y, Ȳ, B)->∇(copy(_δB), symm, Val{4}, (), Y, Ȳ, side, ul, _A, B) - _δB + + B_gen = side == 'L' ? mPQ : mQP + @test check_errs(N, _symm_A, _∇symm_A, B_gen, mPP, mPP) + @test check_errs(N, _symm_A, _∇symm_A_inp, B_gen, mPP, mPP) + @test check_errs(N, _symm_B, _∇symm_B, B_gen, B_gen, B_gen) + @test check_errs(N, _symm_B, _∇symm_B_inp, B_gen, B_gen, B_gen) + end + + # trmm / trsm: + for f in [trmm, trsm], side in ['L', 'R'], ul in ['L', 'U'], tA in ['N', 'T'], dA in ['U', 'N'] + A_gen = ()->mPP() + 1e-3I + B_gen = side == 'L' ? mPQ : mQP + C_gen = B_gen + + _α, _A, _B = sc(), A_gen(), B_gen() + _trmm_α = α->f(side, ul, tA, dA, α, _A, _B) + _trmm_A = A->f(side, ul, tA, dA, _α, A, _B) + _trmm_B = B->f(side, ul, tA, dA, _α, _A, B) + _∇trmm_α = (Y, Ȳ, α)->∇(f, Val{5}, (), Y, Ȳ, side, ul, tA, dA, α, _A, _B) + _∇trmm_A = (Y, Ȳ, A)->∇(f, Val{6}, (), Y, Ȳ, side, ul, tA, dA, _α, A, _B) + _∇trmm_B = (Y, Ȳ, B)->∇(f, Val{7}, (), Y, Ȳ, side, ul, tA, dA, _α, _A, B) + + @test check_errs(N, _trmm_α, _∇trmm_α, C_gen, sc, sc) + @test check_errs(N, _trmm_A, _∇trmm_A, C_gen, A_gen, A_gen) + @test check_errs(N, _trmm_B, _∇trmm_B, C_gen, B_gen, B_gen) + end +end +end diff --git a/test/reverse/diagonal.jl b/test/reverse/diagonal.jl new file mode 100644 index 0000000..d1ba0b3 --- /dev/null +++ b/test/reverse/diagonal.jl @@ -0,0 +1,65 @@ +@testset "Diagonal" begin + + # diag: + let P = 10, rng = MersenneTwister(123456), N = 10, k = 2 + + # Some rngs. + vP, mPP, vPk = ()->randn(rng, P), ()->randn(rng, P, P), ()->randn(rng, P - k) + + # Test on-central-diagonal `diag`. + X0 = randn!(rng, similar(mPP())) + @test check_errs(N, unary_ȲD(diag)..., vP, mPP, mPP) + @test check_errs(N, unary_ȲD_inplace(diag, X0)..., vP, mPP, mPP) + + # Test off-central-diagonal `diag`. + _diag = X->diag(X, k) + _∇diag = (y, ȳ, X)->∇(diag, Val{1}, (), y, ȳ, X, k) + _∇diag_inp = (y, ȳ, X)->∇(copy(X0), diag, Val{1}, (), y, ȳ, X, k) - X0 + @test check_errs(N, _diag, _∇diag, vPk, mPP, mPP) + @test check_errs(N, _diag, _∇diag_inp, vPk, mPP, mPP) + end + + # diagm: + let P = 10, rng = MersenneTwister(123456), N = 10, k = 3 + + # Some rngs. + sc, vP = ()->randn(rng), ()->randn(rng, P) + mPP, mPPk = ()->randn(rng, P, P), ()->randn(rng, P + k, P + k) + + # Test on-central-diagonal `diagm`. 
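+        # `diagm` takes a `k => x` pair: `0 => x` fills the main diagonal, `k => x`
+        # (with k > 0) the k-th superdiagonal.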
+ x0 = randn!(rng, similar(vP())) + _diagm = x->diagm(0=>x) + _∇diagm = (Y, Ȳ, x)->∇(diagm, Val{1}, (), Y, Ȳ, 0=>x) + _∇diagm_inp = (Y, Ȳ, x)->∇(copy(x0), diagm, Val{1}, (), Y, Ȳ, 0=>x) - x0 + @test check_errs(N, _diagm, _∇diagm, mPP, vP, vP) + @test check_errs(N, _diagm, _∇diagm_inp, mPP, vP, vP) + + # Test off-central-diagonal `diagm`. + _diagm = x->diagm(k=>x) + _∇diagm = (Y, Ȳ, x)->∇(diagm, Val{1}, (), Y, Ȳ, k=>x) + _∇diagm_inp = (Y, Ȳ, x)->∇(copy(x0), diagm, Val{1}, (), Y, Ȳ, k=>x) - x0 + @test check_errs(N, _diagm, _∇diagm, mPPk, vP, vP) + @test check_errs(N, _diagm, _∇diagm_inp, mPPk, vP, vP) + end + + # Diagonal: + let P = 10, rng = MersenneTwister(123456), N = 10 + sc, vP = ()->abs(randn(rng)), ()->abs.(randn(rng, P)) + mPP, dPP = ()->abs.(randn(rng, P, P)), ()->Diagonal(abs.(randn(rng, P))) + D0 = Diagonal(randn!(rng, similar(vP()))) + + # Construction. + @test check_errs(N, unary_ȲD(Diagonal)..., dPP, vP, vP) + @test check_errs(N, unary_ȲD_inplace(Diagonal, vP())..., dPP, vP, vP) + @test check_errs(N, unary_ȲD(Diagonal)..., dPP, mPP, mPP) + @test check_errs(N, unary_ȲD_inplace(Diagonal, mPP())..., dPP, mPP, mPP) + + # Determinant. + @test check_errs(N, unary_ȲD(det)..., sc, dPP, dPP) + @test check_errs(N, unary_ȲD_inplace(det, D0)..., sc, dPP, dPP) + + # Log Determinant. + @test check_errs(N, unary_ȲD(logdet)..., sc, dPP, dPP) + @test check_errs(N, unary_ȲD_inplace(logdet, D0)..., sc, dPP, dPP) + end +end diff --git a/test/reverse/factorization/cholesky.jl b/test/reverse/factorization/cholesky.jl new file mode 100644 index 0000000..07894bb --- /dev/null +++ b/test/reverse/factorization/cholesky.jl @@ -0,0 +1,58 @@ +@testset "Cholesky" begin + import DiffLinearAlgebra: level2partition, level3partition, chol_unblocked_rev, + chol_blocked_rev + + let rng = MersenneTwister(123456), N = 5 + A = randn(rng, N, N) + r, d, B2, c = level2partition(A, 4, false) + R, D, B3, C = level3partition(A, 4, 4, false) + @test all(r .== R') + @test all(d .== D) + @test B2[1] == B3[1] + @test all(c .== C) + + # Check that level2partition with 'U' is consistent with 'L'. + rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true) + @test r == rᵀ + @test d == dᵀ + @test B2' == B2ᵀ + @test c == cᵀ + + # Check that level3partition with 'U' is consistent with 'L'. + R, D, B3, C = level3partition(A, 2, 4, false) + Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true) + @test transpose(R) == Rᵀ + @test transpose(D) == Dᵀ + @test transpose(B3) == B3ᵀ + @test transpose(C) == Cᵀ + end + + let rng = MersenneTwister(123456), N = 10 + A, Ā = Matrix.(LowerTriangular.(randn.(Ref(rng), [N, N], [N, N]))) + B, B̄ = copy.(transpose.([A, Ā])) + @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false) + @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false) + @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false) + @test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false) + @test chol_unblocked_rev(Ā, A, false) ≈ chol_unblocked_rev(B̄, B, true)' + + @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true) + @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true) + @test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true) + end + + # Check sensitivities for lower-triangular version. 
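+    # Σ_gen below returns A'A + 1e-3I, which is symmetric positive definite, so the
+    # Cholesky factorisation exists for every draw.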
+ let P = 15, rng = MersenneTwister(123456), N = 10 + + Σ_gen = ()->(A = randn(rng, P, P); A'A + 1e-3I) + S_gen, H_gen = ()->Symmetric(Σ_gen()), ()->Hermitian(Σ_gen()) + U_gen = ()->chol(S_gen()) + + _chol = Σ->chol(Symmetric(Σ)) + _∇chol = (Y, Ȳ, Σ)->∇(chol, Val{1}, (), Y, Ȳ, Symmetric(Σ)) + + @test check_errs(N, _chol, _∇chol, U_gen, Σ_gen, Σ_gen) + @test check_errs(N, _chol, _∇chol, U_gen, H_gen, H_gen) + @test check_errs(N, _chol, _∇chol, U_gen, S_gen, S_gen) + end +end diff --git a/test/reverse/generic.jl b/test/reverse/generic.jl new file mode 100644 index 0000000..566cfc4 --- /dev/null +++ b/test/reverse/generic.jl @@ -0,0 +1,97 @@ +using LinearAlgebra +using LinearAlgebra: -, tr, inv, det, logdet, transpose, adjoint, norm +using DiffRules: ReverseRuleKey, DEFINED_REVERSE_RULES, make_named_signature + +function unary_ȲD(key::ReverseRuleKey) + f = @eval $(key[1]).$(key[2]) + arg_names = vcat(gensym(), gensym(), [gensym() for _ in 1:arity(key) - 2]) + typed_args = Expr(:tuple, make_named_signature(arg_names, key)...) + body = DEFINED_REVERSE_RULES[key](arg_names...) + return f, eval(Expr(Symbol("->"), typed_args, body)) +end + + +@testset "generic" begin + +let + P, Q, rng, N = 4, 3, MersenneTwister(123456), 100 + + # Utility for generating square matrices, vectors, non-square matrices and scalars. + mPP, mQQ = ()->randn(rng, P, P), ()->randn(rng, Q, Q) + mPQ, mQP = ()->randn(rng, P, Q), ()->randn(rng, Q, P) + mPQQP = ()->randn(rng, P * Q, P * Q) + v, sc = ()->randn(rng, P), ()->randn(rng) + psd = ()->(A = randn(rng, P, P); transpose(A) * A + 1e-3I) + + # sig = :(Tuple{AbstractArray{<:Real}, AbstractArray{<:Real}, AbstractArray{<:Real}}) + # @show unary_ȲD((:Base, :-, sig, (1,))) + # f, df = unary_ȲD((:Base, :-, sig, (1,))) + # @show f(mPP()) + # Ȳ = mPP() + # @show Ȳ + # @show df(mPP(), Ȳ, mPP()) + + + sig = :(Tuple{AbstractArray{<:Real}, AbstractArray{<:Real}, AbstractArray{<:Real}}) + @test check_errs(N, unary_ȲD((:Base, :-, sig, (1,)))..., mPQ, mPQ, mPQ) + + sig = :(Tuple{Real, Real, AbstractMatrix{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :tr, sig, (1,)))..., sc, mPP, mPP) + + sig = :(Tuple{AbstractMatrix{<:Real}, AbstractMatrix{<:Real}, AbstractMatrix{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :inv, sig, (1,)))..., mPP, mPP, mPP) + + sig = :(Tuple{Real, Real, AbstractMatrix{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :det, sig, (1,)))..., sc, mPP, mPP) + + sig = :(Tuple{Real, Real, AbstractMatrix{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :logdet, sig, (1,)))..., sc, psd, psd) + + sig = :(Tuple{AbstractVecOrMat{<:Real}, AbstractVecOrMat{<:Real}, AbstractVecOrMat{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :transpose, sig, (1,)))..., mQP, mPQ, mPQ) + + sig = :(Tuple{AbstractVecOrMat{<:Real}, AbstractVecOrMat{<:Real}, AbstractVecOrMat{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :adjoint, sig, (1,)))..., mQP, mPQ, mPQ) + + sig = :(Tuple{Real, Real, AbstractArray{<:Real}}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :norm, sig, (1,)))..., sc, mPQ, mPQ) + + sig = :(Tuple{Real, Real, Real}) + @test check_errs(N, unary_ȲD((:LinearAlgebra, :norm, sig, (1,)))..., sc, sc, sc) + + # # Test all of the binary sensitivities. 
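+    # # (binary_ȲD(f, i, G) fixes the other argument with a single draw from G and
+    # # returns closures for f and its sensitivity w.r.t. argument i; see test_util.jl.)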
+ # @test check_errs(N, binary_ȲD(*, 1, mQP)..., mPP, mPQ, mPQ) + # @test check_errs(N, binary_ȲD(*, 2, mPQ)..., mPP, mQP, mQP) + # @test check_errs(N, binary_ȲD(*, 1, mQP)..., mPP, ()->mQP()', ()->mQP()') + # @test check_errs(N, binary_ȲD(*, 2, ()->mQP()')..., mPP, mQP, mQP) + # @test check_errs(N, binary_ȲD(*, 1, ()->mPQ()')..., mPP, mPQ, mPQ) + # @test check_errs(N, binary_ȲD(*, 2, mPQ)..., mPP, ()->mPQ()', ()->mPQ()') + # @test check_errs(N, binary_ȲD(*, 1, ()->mPQ()')..., mPP, ()->mQP()', ()->mQP()') + # @test check_errs(N, binary_ȲD(*, 2, ()->mQP()')..., mPP, ()->mPQ()', ()->mPQ()') + # @test check_errs(N, binary_ȲD(/, 1, mQQ)..., mPQ, mPQ, mPQ) + # @test check_errs(N, binary_ȲD(/, 2, mPQ)..., mPQ, mQQ, mQQ) + # @test check_errs(N, binary_ȲD(/, 1, mQQ)..., mPQ, ()->mQP()', ()->mQP()') + # @test check_errs(N, binary_ȲD(/, 2, ()->mQP()')..., mPQ, mQQ, mQQ) + # @test check_errs(N, binary_ȲD(/, 1, ()->mQQ()')..., mPQ, mPQ, mPQ) + # @test check_errs(N, binary_ȲD(/, 2, mPQ)..., mPQ, ()->mQQ()', ()->mQQ()') + # @test check_errs(N, binary_ȲD(/, 1, ()->mQQ()')..., mPQ, ()->mQP()', ()->mQP()') + # @test check_errs(N, binary_ȲD(/, 2, ()->mQP()')..., mPQ, ()->mQQ()', ()->mQQ()') + # @test check_errs(N, binary_ȲD(\, 1, mQP)..., mQP, mQQ, mQQ) + # @test check_errs(N, binary_ȲD(\, 2, mQQ)..., mQP, mQP, mQP) + # @test check_errs(N, binary_ȲD(\, 1, mQP)..., mQP, ()->mQQ()', ()->mQQ()') + # @test check_errs(N, binary_ȲD(\, 2, ()->mQQ()')..., mQP, mQP, mQP) + # @test check_errs(N, binary_ȲD(\, 1, ()->mPQ()')..., mQP, mQQ, mQQ) + # @test check_errs(N, binary_ȲD(\, 2, mQQ)..., mQP, ()->mPQ()', ()->mPQ()') + # @test check_errs(N, binary_ȲD(\, 1, ()->mPQ()')..., mQP, ()->mQQ()', ()->mQQ()') + # @test check_errs(N, binary_ȲD(\, 2, ()->mQQ()')..., mQP, ()->mPQ()', ()->mPQ()') + # @test check_errs(N, binary_ȲD(vecnorm, 1, sc)..., sc, mPQ, mPQ) + # @test check_errs(N, binary_ȲD(vecnorm, 2, mPQ)..., sc, sc, sc) + # @test check_errs(N, binary_ȲD(vecnorm, 1, sc)..., sc, sc, sc) + # @test check_errs(N, binary_ȲD(vecnorm, 2, sc)..., sc, sc, sc) + # @test check_errs(N, binary_ȲD(kron, 1, mQP)..., mPQQP, mPQ, mPQ) + # @test check_errs(N, binary_ȲD(kron, 2, mPQ)..., mPQQP, mQP, mQP) + # @test check_errs(N, binary_ȲD_inplace(kron, 1, mQP, mPQ())..., mPQQP, mPQ, mPQ) + # @test check_errs(N, binary_ȲD_inplace(kron, 2, mPQ, mQP())..., mPQQP, mQP, mQP) +end + +end diff --git a/test/reverse/test_imports.jl b/test/reverse/test_imports.jl new file mode 100644 index 0000000..1bda615 --- /dev/null +++ b/test/reverse/test_imports.jl @@ -0,0 +1,7 @@ +@testset "testimports" begin + # Check that everything is imported without an error. This is quite a weak criterion, + # but I'm not sure what else is easily doable. + for op in DLA.ops + @eval $(import_expr(op)) + end +end diff --git a/test/reverse/test_util.jl b/test/reverse/test_util.jl new file mode 100644 index 0000000..a334a8a --- /dev/null +++ b/test/reverse/test_util.jl @@ -0,0 +1,39 @@ +const AS = Union{AbstractArray{<:Real}, Real} + +""" + check_errs(f, ∇f, Ȳ::AS, X::AS, V::AS, ε_abs::Real=1e-7, ε_rel::Real=1e-5)::Bool +""" +check_errs(f, ∇f, Ȳ::AS, X::AS, V::AS, ε_abs::Real=1e-7, ε_rel::Real=1e-5)::Bool = + assert_approx_equal( + central_fdm(2, 1)(ϵ->sum(Ȳ .* f(X + ϵ * V))), + sum(∇f(f(X), Ȳ, X) .* V), + ε_abs, ε_rel) + +""" + check_errs(N::Int, f, ∇f, Ȳ, X, V, ε_abs=1e-7, ε_rel=1e-5)::Bool + +Call check_errs `N` times with arguments generated by 0-ary functions `Ȳ`, `X` and `V`. 
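+Returns `true` only if every one of the `N` trials passes.
+
+A minimal illustrative usage, with a hand-written sensitivity for `sin` (the
+generators below are arbitrary, not taken from the test suite):
+
+    rng = MersenneTwister(123456)
+    check_errs(10, sin, (y, ȳ, x)->ȳ * cos(x), ()->randn(rng), ()->randn(rng), ()->randn(rng))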
+""" +check_errs(N::Int, f, ∇f, Ȳ, X, V, ε_abs=1e-7, ε_rel=1e-5)::Bool = + all(map(n->check_errs(f, ∇f, Ȳ(), X(), V(), ε_abs, ε_rel), 1:N)) + + +# Utility to create the closures required for unit testing. +unary_ȲD(f) = (f, (Y, Ȳ, X)->∇(f, Val{1}, (), Y, Ȳ, X)) +unary_ȲD_inplace(f, X̄0) = (f, (Y, Ȳ, X)->∇(copy(X̄0), f, Val{1}, (), Y, Ȳ, X) - X̄0) +function binary_ȲD(f, arg::Int, G) + G_ = G() + if arg == 1 + return X->f(X, G_), (Y, Ȳ, X)->∇(f, Val{1}, (), Y, Ȳ, X, G_) + else + return X->f(G_, X), (Y, Ȳ, X)->∇(f, Val{2}, (), Y, Ȳ, G_, X) + end +end +function binary_ȲD_inplace(f, arg::Int, G, X̄0) + G_ = G() + if arg == 1 + return X->f(X, G_), (Y, Ȳ, X)->∇(copy(X̄0), f, Val{1}, (), Y, Ȳ, X, G_) - X̄0 + else + return X->f(G_, X), (Y, Ȳ, X)->∇(copy(X̄0), f, Val{2}, (), Y, Ȳ, G_, X) - X̄0 + end +end diff --git a/test/reverse/triangular.jl b/test/reverse/triangular.jl new file mode 100644 index 0000000..1e97ae1 --- /dev/null +++ b/test/reverse/triangular.jl @@ -0,0 +1,27 @@ +@testset "Triangular" begin + + for f in [LowerTriangular, UpperTriangular] + let P = 10, rng = MersenneTwister(123456), N = 10 + + sc = ()->randn(rng) + mPP = ()->abs.(randn(rng, P, P)) + dPP = ()->Diagonal(abs.(randn(rng, P))) + tPP = ()->f(abs.(randn(rng, P, P))) + m0 = randn!(rng, similar(mPP())) + + # Construction. + @test check_errs(N, unary_ȲD(f)..., tPP, mPP, mPP) + @test check_errs(N, unary_ȲD_inplace(f, m0)..., tPP, mPP, mPP) + + # Determinant. + @test check_errs(N, unary_ȲD(det)..., sc, tPP, tPP) + @test check_errs(N, unary_ȲD_inplace(det, tPP())..., sc, tPP, tPP) + @test check_errs(N, unary_ȲD_inplace(det, dPP())..., sc, tPP, tPP) + + # Log Determinant. + @test check_errs(N, unary_ȲD(logdet)..., sc, tPP, tPP) + @test check_errs(N, unary_ȲD_inplace(logdet, tPP())..., sc, tPP, tPP) + @test check_errs(N, unary_ȲD_inplace(logdet, dPP())..., sc, tPP, tPP) + end + end +end diff --git a/test/reverse/uniformscaling.jl b/test/reverse/uniformscaling.jl new file mode 100644 index 0000000..e69de29 diff --git a/test/reverse/util.jl b/test/reverse/util.jl new file mode 100644 index 0000000..b8f932f --- /dev/null +++ b/test/reverse/util.jl @@ -0,0 +1,16 @@ +@testset "util" begin + import DiffLinearAlgebra: importable + # @test Expr(:import, Expr(Symbol("."), importable(:Foo)...)) == + # :(import Foo) + # @test Expr(:import, Expr(Symbol("."), importable(:(Package.Foo))...)) == + # :(import Package.Foo) + # @test Expr(:import, Expr(Symbol("."), importable(:(Package.Subpackage.Foo))...)) == + # :(import Package.Subpackage.Foo) + + @test import_expr(DLA.DiffOp(:Foo, :(Tuple{Foo}), [true])) == + :(import Foo) + @test import_expr(DLA.DiffOp(:(Package.Foo), :(Tuple{Foo}), [true])) == + :(import Package.Foo) + @test import_expr(DLA.DiffOp(:(Package.Subpackage.Foo), :(Tuple{Foo}), [true])) == + :(import Package.Subpackage.Foo) +end diff --git a/test/runtests.jl b/test/runtests.jl index 51455db..4c151a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,59 +1,26 @@ -if VERSION < v"0.7-" - using Base.Test - srand(1) -else - using Test - import Random - Random.seed!(1) -end -import SpecialFunctions, NaNMath -using DiffRules +using Test, DiffRules, Random, FDM +import Random, SpecialFunctions, NaNMath +Random.seed!(1) function finitediff(f, x) ϵ = cbrt(eps(typeof(x))) * max(one(typeof(x)), abs(x)) return (f(x + ϵ) - f(x - ϵ)) / (ϵ + ϵ) end +@testset "DiffRules" begin -non_numeric_arg_functions = [(:Base, :rem2pi, 2)] + @testset "diffrules" begin + include("diffrules/rules.jl") + end -for (M, f, arity) in DiffRules.diffrules() - (M, 
f, arity) ∈ non_numeric_arg_functions && continue - if arity == 1 - @test DiffRules.hasdiffrule(M, f, 1) - deriv = DiffRules.diffrule(M, f, :goo) - modifier = in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth)) ? 1 : 0 - @eval begin - goo = rand() + $modifier - @test isapprox($deriv, finitediff($M.$f, goo), rtol=0.05) - end - elseif arity == 2 - @test DiffRules.hasdiffrule(M, f, 2) - derivs = DiffRules.diffrule(M, f, :foo, :bar) - @eval begin - foo, bar = rand(1:10), rand() - dx, dy = $(derivs[1]), $(derivs[2]) - if !(isnan(dx)) - @test isapprox(dx, finitediff(z -> $M.$f(z, bar), float(foo)), rtol=0.05) - end - if !(isnan(dy)) - @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) - end - end + @testset "forward" begin + include("forward/api.jl") end -end -# Treat rem2pi separately because of its non-numeric second argument: -derivs = DiffRules.diffrule(:Base, :rem2pi, :x, :y) -for xtype in [:Float64, :BigFloat, :Int64] - for mode in [:RoundUp, :RoundDown, :RoundToZero, :RoundNearest] - @eval begin - x = $xtype(rand(1 : 10)) - y = $mode - dx, dy = $(derivs[1]), $(derivs[2]) - @test isapprox(dx, finitediff(z -> rem2pi(z, y), float(x)), rtol=0.05) - @test isnan(dy) - end + @testset "reverse" begin + include("reverse/api.jl") + include("reverse/test_util.jl") + include("reverse/generic.jl") end end