Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
[WIP] Fix 170 - an improved patch (#173)
Browse files Browse the repository at this point in the history
* fix issue 170

* fix chainrules patch

* fix circuit gradient type

* merge master

* fix test for latest zygote
  • Loading branch information
GiggleLiu authored Nov 16, 2021
1 parent f3fede2 commit 0c9e1e5
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 12 deletions.
100 changes: 91 additions & 9 deletions src/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,90 @@
import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent
import ChainRulesCore: rrule, @non_differentiable, NoTangent, Tangent, backing, AbstractTangent, ZeroTangent

"""
    create_circuit_tangent(circuit, params)

Call `dispatch(circuit, params)` and pass the resulting block to
`recursive_create_tangent`, returning whatever that produces.
"""
function create_circuit_tangent(circuit, params)
    return recursive_create_tangent(dispatch(circuit, params))
end

# fallback
"""
    recursive_create_tangent(c::AbstractBlock)

Generic method used when no more specific method exists for `typeof(c)`.
Returns `NoTangent()` when `nparameters(c)` is zero and raises an error otherwise.
"""
function recursive_create_tangent(c::AbstractBlock)
    # Blocks that report no parameters carry no gradient information.
    nparameters(c) == 0 && return NoTangent()
    error("`Tangent` for type $(typeof(c)) is not defined!")
end
# primitive blocks
# Field-level tangent used for primitive gates: a numeric field is passed through
# unchanged, any other field is mapped to `NoTangent()`.
function unsafe_primitive_tangent(::Any)
    return NoTangent()
end
function unsafe_primitive_tangent(x::Number)
    return x
end
# Generate one `recursive_create_tangent` method per primitive gate type. Each method
# builds a `Tangent` whose backing `NamedTuple` has one entry per struct field, with the
# entry value produced by `unsafe_primitive_tangent`.
# NOTE(review): every `Number`-typed field ends up in the tangent, not only the gate's
# variational parameter — confirm this is intended for gates with extra numeric fields.
for GT in (:RotationGate, :ShiftGate, :TimeEvolution, :PhaseGate)
    @eval function recursive_create_tangent(c::$GT)
        names = fieldnames(typeof(c))
        vals = map(name -> unsafe_primitive_tangent(getfield(c, name)), names)
        nt = NamedTuple{names}(vals)
        return Tangent{typeof(c),typeof(nt)}(nt)
    end
end
# composite blocks
# Field-level tangent used for composite blocks: a block field (or a vector of blocks)
# is converted recursively, any other field is mapped to `NoTangent()`.
function unsafe_composite_tangent(::Any)
    return NoTangent()
end
function unsafe_composite_tangent(c::AbstractVector{<:AbstractBlock})
    return broadcast(recursive_create_tangent, c)
end
function unsafe_composite_tangent(c::AbstractBlock)
    return recursive_create_tangent(c)
end
# Generate one `recursive_create_tangent` method per composite block type. Each method
# builds a `Tangent` whose backing `NamedTuple` has one entry per struct field, with the
# entry value produced by `unsafe_composite_tangent`.
for GT in (
    :ChainBlock,
    :Add,
    :KronBlock,
    :RepeatedBlock,
    :PutBlock,
    :Subroutine,
    :CachedBlock,
    :Daggered,
    :Scale,
)
    @eval function recursive_create_tangent(c::$GT)
        names = fieldnames(typeof(c))
        vals = map(name -> unsafe_composite_tangent(getfield(c, name)), names)
        nt = NamedTuple{names}(vals)
        return Tangent{typeof(c),typeof(nt)}(nt)
    end
end

# `extract_circuit_gradients!(x, output)` walks a nested gradient structure and appends
# every number it finds to `output`, which is also the return value.

# A number is a leaf: append it.
function extract_circuit_gradients!(c::Number, output)
    return push!(output, c)
end

# Empty leaves contribute nothing.
extract_circuit_gradients!(::Union{Nothing,NoTangent,ZeroTangent}, output) = output

# Containers: visit every element in iteration order.
function extract_circuit_gradients!(c::Union{AbstractVector,NamedTuple}, output)
    foreach(item -> extract_circuit_gradients!(item, output), c)
    return output
end

# Structural tangents: visit every entry of the `backing` field.
function extract_circuit_gradients!(c::Tangent, output)
    foreach(field -> extract_circuit_gradients!(field, output), getfield(c, :backing))
    return output
end

# Reverse rule for `apply(reg, block)`. The forward pass calls `apply`; the pullback feeds
# a copy of the output together with the output cotangent `outδ` to `apply_back` and
# returns a 3-tuple: `NoTangent()` for the function, `inδ` for the register, and the
# parameter gradients `paramsδ` wrapped for the block.
# NOTE(review): the two consecutive `return` lines look like the before/after pair of a
# diff rendering (first wraps `paramsδ` with `dispatch`, second with
# `create_circuit_tangent`); presumably only one is meant to be live — confirm.
function rrule(::typeof(apply), reg::ArrayReg, block::AbstractBlock)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block)
return (NoTangent(), inδ, dispatch(block, paramsδ))
return (NoTangent(), inδ, create_circuit_tangent(block, paramsδ))
end
end

# Reverse rule for `apply(reg, block)` specialised to `Add` blocks. It differs from the
# generic rule only in passing the original input register to `apply_back` through the
# `in = reg` keyword.
# NOTE(review): the two consecutive `return` lines look like the before/after pair of a
# diff rendering (`dispatch` vs `create_circuit_tangent`); presumably only one is meant
# to be live — confirm.
function rrule(::typeof(apply), reg::ArrayReg, block::Add)
out = apply(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((copy(out), outδ), block; in = reg)
return (NoTangent(), inδ, dispatch(block, paramsδ))
return (NoTangent(), inδ, create_circuit_tangent(block, paramsδ))
end
end


# Reverse rule for `dispatch(block, params)`. The forward pass calls `dispatch`; the
# pullback returns `NoTangent()` for the function and for `block`, plus a gradient for
# `params`.
# NOTE(review): this span appears to interleave two versions of the pullback from a diff
# rendering: one returning `parameters(outδ)`, and one restricted to
# `outδ::AbstractTangent` that flattens the tangent with `extract_circuit_gradients!`
# into `empty(params)`. As rendered the `function`/`end` pairs do not balance — confirm
# against the actual file.
function rrule(::typeof(dispatch), block::AbstractBlock, params)
out = dispatch(block, params)
out, function (outδ)
(NoTangent(), NoTangent(), parameters(outδ))
out, function (outδ::AbstractTangent)
g = extract_circuit_gradients!(outδ, empty(params))
res = (NoTangent(), NoTangent(), g)
return res
end
end

Expand All @@ -41,23 +106,23 @@ function rrule(::typeof(expect), op::AbstractBlock, reg_and_circuit::Pair{<:Arra
for b in 1:B
viewbatch(greg, b).state .*= 2 * outδ[b]
end
return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=dispatch(reg_and_circuit.second, gcircuit)))
return (NoTangent(), NoTangent(), Tangent{typeof(reg_and_circuit)}(; first=greg, second=create_circuit_tangent(reg_and_circuit.second, gcircuit)))
end
end

# Reverse rule for constructing a `Matrix` subtype `T` from a block. The forward pass
# calls `T(block)`; the pullback obtains parameter gradients from `mat_back(block, outδ)`
# and returns them, wrapped for the block, after a `NoTangent()` for the type argument.
# NOTE(review): the two consecutive `return` lines look like the before/after pair of a
# diff rendering (`dispatch` vs `create_circuit_tangent`); presumably only one is meant
# to be live — confirm.
function rrule(::Type{T}, block::AbstractBlock) where T<:Matrix
out = T(block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (NoTangent(), dispatch(block, paramsδ))
return (NoTangent(), create_circuit_tangent(block, paramsδ))
end
end

# Reverse rule for `mat(T, block)`. The forward pass calls `mat`; the pullback obtains
# parameter gradients from `mat_back(block, outδ)` and returns `NoTangent()` for the
# function and for the type argument, followed by the gradients wrapped for the block.
# NOTE(review): the two consecutive `return` lines look like the before/after pair of a
# diff rendering (`dispatch` vs `create_circuit_tangent`); presumably only one is meant
# to be live — confirm.
function rrule(::typeof(mat), ::Type{T}, block::AbstractBlock) where T
out = mat(T, block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (NoTangent(), NoTangent(), dispatch(block, paramsδ))
return (NoTangent(), NoTangent(), create_circuit_tangent(block, paramsδ))
end
end

Expand All @@ -73,6 +138,23 @@ function rrule(::typeof(copy), reg::ArrayReg) where {B}
copy(reg), adjy -> (NoTangent(), adjy)
end

# Reverse rules for the `Add` and `ChainBlock` constructors. For each block type three
# methods are generated: construction from an `AbstractVector`, from a `Tuple`, and from
# splatted `AbstractBlock` arguments. Every pullback reads the field of the output
# cotangent that holds the sub-blocks (`list` for `Add`, `blocks` for `ChainBlock`) and
# hands it back in the same shape as the constructor argument(s).
#
# The `where` clauses declare only `BT`: an extra, unused type variable makes Julia emit
# a "declares type variable ... but does not use it" warning at definition time.
# NOTE(review): the pullbacks access `outδ.list` / `outδ.blocks` unconditionally and, for
# the vector case, call `AbstractVector(...)` on the result — confirm both are valid for
# every cotangent type that can reach these rules.
for (BT, BLOCKS) in [(:Add, :(outδ.list)), (:ChainBlock, :(outδ.blocks))]
    for ST in [:AbstractVector, :Tuple]
        @eval function rrule(::Type{BT}, source::$ST) where {BT<:$BT}
            out = BT(source)
            function constructor_pullback(outδ)
                return (NoTangent(), $ST($BLOCKS))
            end
            return out, constructor_pullback
        end
    end
    @eval function rrule(::Type{BT}, args::AbstractBlock...) where {BT<:$BT}
        out = BT(args...)
        function constructor_pullback(outδ)
            # One cotangent per positional block argument.
            return (NoTangent(), $BLOCKS...)
        end
        return out, constructor_pullback
    end
end

# Return `x` with element type `T`: the array itself when its element type already is
# `T`, otherwise the result of converting every element.
function _totype(::Type{T}, x::AbstractArray{T}) where {T}
    return x
end
function _totype(::Type{T}, x::AbstractArray{T2}) where {T,T2}
    return convert.(T, x)
end
rrule(::typeof(state), reg::ArrayReg{B,T}) where {B,T} =
Expand Down
14 changes: 11 additions & 3 deletions test/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import Zygote, ForwardDiff
using Random, Test
using YaoBlocks, YaoArrayRegister
using ChainRulesCore: Tangent

# NOTE(review): the next two lines look like the removed side of a diff rendering (an
# `accum` method for blocks, without a closing `end` of its own here) — confirm against
# the actual test file.
function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
dispatch(a, parameters(a) + parameters(b))
# Builds a nested 5-qubit circuit and checks that `recursive_create_tangent` returns a
# `Tangent` for it.
@testset "recursive_create_tangent" begin
c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
tc = YaoBlocks.AD.recursive_create_tangent(c)
@test tc isa Tangent
end

# Gradients taken through block constructors (`sum` of chains, `Daggered`).
@testset "constructors" begin
    # Expected gradient containers, spelled out so the comparison below stays readable.
    GateGrad = NamedTuple{(:block, :theta),Tuple{Nothing,Float64}}
    ChainGrad = NamedTuple{(:blocks,),Tuple{Vector{GateGrad}}}
    expected = (list = ChainGrad[(blocks = [(block = nothing, theta = 1.0)],)],)

    grad = Zygote.gradient(x -> x.list[1].blocks[1].theta, sum([chain(1, Rz(0.3))]))[1]
    @test grad == expected

    # Field access through `Daggered` is recorded as a known-broken case.
    @test_broken Zygote.gradient(x->getfield(getfield(x,:content), :theta), Daggered(Rx(0.5)))[1] == (content = (block = nothing, theta = 1.0),)
end

@testset "rules" begin
Expand Down Expand Up @@ -35,7 +43,7 @@ end
@test Zygote.gradient(x -> real(sum(abs2, statevec(x'))), r)[1].state g1
# Zygote does not work if `sin` is not here,
# because it gives an adjoint of a different type than the output matrix type.
@test parameters(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1])
@test AD.extract_circuit_gradients!(Zygote.gradient(x -> real(sum(sin, Matrix(x))), c)[1].blocks, Float64[])
ForwardDiff.gradient(x -> real(sum(sin, Matrix(dispatch(c, x)))), parameters(c))
end

Expand Down

0 comments on commit 0c9e1e5

Please sign in to comment.