From 0c9e1e559ba79104c5280dff3482bb6a6ed69001 Mon Sep 17 00:00:00 2001 From: Leo Date: Tue, 16 Nov 2021 10:21:06 -0500 Subject: [PATCH] [WIP] Fix 170 - an improved patch (#173) * fix issue 170 * fix chainrules patch * fix circuit gradient type * merge master * fix test for latest zygote --- src/autodiff/chainrules_patch.jl | 100 +++++++++++++++++++++++++++--- test/autodiff/chainrules_patch.jl | 14 ++++- 2 files changed, 102 insertions(+), 12 deletions(-) diff --git a/src/autodiff/chainrules_patch.jl b/src/autodiff/chainrules_patch.jl index 87c9785..d636785 100644 --- a/src/autodiff/chainrules_patch.jl +++ b/src/autodiff/chainrules_patch.jl @@ -1,25 +1,90 @@ -import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent +import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent, backing, AbstractTangent, ZeroTangent + +function create_circuit_tangent(circuit, params) + gc = dispatch(circuit, params) + res = recursive_create_tangent(gc) + return res +end + +# fallback +function recursive_create_tangent(c::AbstractBlock) + if nparameters(c) == 0 + return NoTangent() + else + error("`Tangent` for type $(typeof(c)) is not defined!") + end +end +# primitive blocks +unsafe_primitive_tangent(::Any) = NoTangent() +unsafe_primitive_tangent(x::Number) = x +for GT in [:RotationGate, :ShiftGate, :TimeEvolution, :PhaseGate] + @eval function recursive_create_tangent(c::$GT) + lst = map(fieldnames(typeof(c))) do fn + fn => unsafe_primitive_tangent(getfield(c, fn)) + end + nt = NamedTuple(lst) + Tangent{typeof(c), typeof(nt)}(nt) + end +end +# composite blocks +unsafe_composite_tangent(::Any) = NoTangent() +unsafe_composite_tangent(c::AbstractVector{<:AbstractBlock}) = recursive_create_tangent.(c) +unsafe_composite_tangent(c::AbstractBlock) = recursive_create_tangent(c) +for GT in [:ChainBlock, :Add, :KronBlock, :RepeatedBlock, :PutBlock, :Subroutine, :CachedBlock, :Daggered, :Scale] + @eval function recursive_create_tangent(c::$GT) + lst = map(fieldnames(typeof(c))) do fn + fn => unsafe_composite_tangent(getfield(c, fn)) + end + nt = NamedTuple(lst) + Tangent{typeof(c), typeof(nt)}(nt) + end +end + +extract_circuit_gradients!(c::Number, output) = push!(output, c) +extract_circuit_gradients!(::Nothing, output) = output +extract_circuit_gradients!(::NoTangent, output) = output +extract_circuit_gradients!(::ZeroTangent, output) = output +function extract_circuit_gradients!(c::AbstractVector, output) + for ci in c + extract_circuit_gradients!(ci, output) + end + return output +end +function extract_circuit_gradients!(c::Tangent, output) + for fn in getfield(c, :backing) + extract_circuit_gradients!(fn, output) + end + return output +end +function extract_circuit_gradients!(c::NamedTuple, output) + for fn in c + extract_circuit_gradients!(fn, output) + end + return output +end function rrule(::typeof(apply), reg::ArrayReg, block::AbstractBlock) out = apply(reg, block) out, function (outδ) (in, inδ), paramsδ = apply_back((copy(out), outδ), block) - return (NoTangent(), inδ, dispatch(block, paramsδ)) + return (NoTangent(), inδ, create_circuit_tangent(block, paramsδ)) end end - function rrule(::typeof(apply), reg::ArrayReg, block::Add) out = apply(reg, block) out, function (outδ) (in, inδ), paramsδ = apply_back((copy(out), outδ), block; in = reg) - return (NoTangent(), inδ, dispatch(block, paramsδ)) + return (NoTangent(), inδ, create_circuit_tangent(block, paramsδ)) end end + function rrule(::typeof(dispatch), block::AbstractBlock, params) out = dispatch(block, params) - out, function (outδ) - (NoTangent(), NoTangent(), parameters(outδ)) + out, function (outδ::AbstractTangent) + g = extract_circuit_gradients!(outδ, empty(params)) + res = (NoTangent(), NoTangent(), g) + return res end end @@ -41,7 +106,7 @@ function rrule(::typeof(expect), op::AbstractBlock, reg_and_circuit::Pair{<:Arra for b in 1:B viewbatch(greg, b).state .*= 2 * outδ[b] end - return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=dispatch(reg_and_circuit.second, gcircuit))) + return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=create_circuit_tangent(reg_and_circuit.second, gcircuit))) end end @@ -49,7 +114,7 @@ function rrule(::Type{T}, block::AbstractBlock) where T<:Matrix out = T(block) out, function (outδ) paramsδ = mat_back(block, outδ) - return (NoTangent(), dispatch(block, paramsδ)) + return (NoTangent(), create_circuit_tangent(block, paramsδ)) end end @@ -57,7 +122,7 @@ function rrule(::typeof(mat), ::Type{T}, block::AbstractBlock) where T out = mat(T, block) out, function (outδ) paramsδ = mat_back(block, outδ) - return (NoTangent(), NoTangent(), dispatch(block, paramsδ)) + return (NoTangent(), NoTangent(), create_circuit_tangent(block, paramsδ)) end end @@ -73,6 +138,23 @@ function rrule(::typeof(copy), reg::ArrayReg) where {B} copy(reg), adjy -> (NoTangent(), adjy) end +for (BT, BLOCKS) in [(:Add, :(outδ.list)) (:ChainBlock, :(outδ.blocks))] + for ST in [:AbstractVector, :Tuple] + @eval function rrule(::Type{BT}, source::$ST) where {N, BT<:$BT} + out = BT(source) + out, function (outδ) + return (NoTangent(), $ST($BLOCKS)) + end + end + end + @eval function rrule(::Type{BT}, args::AbstractBlock...) where {N, BT <: $BT} + out = BT(args...) + out, function (outδ) + return (NoTangent(), $BLOCKS...) + end + end +end + _totype(::Type{T}, x::AbstractArray{T}) where {T} = x _totype(::Type{T}, x::AbstractArray{T2}) where {T,T2} = convert.(T, x) rrule(::typeof(state), reg::ArrayReg{B,T}) where {B,T} = diff --git a/test/autodiff/chainrules_patch.jl b/test/autodiff/chainrules_patch.jl index 39f07a6..83354d0 100644 --- a/test/autodiff/chainrules_patch.jl +++ b/test/autodiff/chainrules_patch.jl @@ -1,9 +1,17 @@ import Zygote, ForwardDiff using Random, Test using YaoBlocks, YaoArrayRegister +using ChainRulesCore: Tangent -function Zygote.accum(a::AbstractBlock, b::AbstractBlock) - dispatch(a, parameters(a) + parameters(b)) +@testset "recursive_create_tangent" begin + c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5))) + tc = YaoBlocks.AD.recursive_create_tangent(c) + @test tc isa Tangent +end + +@testset "construtors" begin + @test Zygote.gradient(x->x.list[1].blocks[1].theta, sum([chain(1, Rz(0.3))]))[1] == (list = NamedTuple{(:blocks,), Tuple{Vector{NamedTuple{(:block, :theta), Tuple{Nothing, Float64}}}}}[(blocks = [(block = nothing, theta = 1.0)],)],) + @test_broken Zygote.gradient(x->getfield(getfield(x,:content), :theta), Daggered(Rx(0.5)))[1] == (content = (block = nothing, theta = 1.0),) end @testset "rules" begin @@ -35,7 +43,7 @@ end @test Zygote.gradient(x -> real(sum(abs2, statevec(x'))), r)[1].state ≈ g1 # zygote does not work if `sin` is not here, # because it gives an adjoint of different type as the output matrix type. - @test parameters(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1]) ≈ + @test AD.extract_circuit_gradients!(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1].blocks, Float64[]) ≈ ForwardDiff.gradient(x -> real(sum(sin, Matrix(dispatch(c, x)))), parameters(c)) end