This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Zygote: transposing a block yields DimensionMismatch (or missing adjoint) #170

Open
vincentelfving opened this issue Nov 14, 2021 · 7 comments · Fixed by #171

Comments

@vincentelfving

vincentelfving commented Nov 14, 2021

For certain kinds of wavefunction (WF) overlap gradients, one needs to conjugate-transpose a unitary that has already appeared, so both the regular and the daggered version appear in the expression (the cost or loss function). However, I get errors when using the daggered version.

A minimal reproducing example with just one block, for now the non-daggered version, U:

using Zygote
using Yao
using YaoBlocks

N=2
psi_0 = zero_state(N)
U0 = chain(N, put(1=>Rx(0.0)), put(2=>Ry(0.0)))
C = sum([chain(N, put(k=>Z)) for k=1:N])

function loss(theta)
    U = dispatch(U0, theta)
    psi0 = copy(psi_0)
    psi1 = apply(psi0, U)
    psi2 = apply(psi1, C)
    result = real(sum(conj(state(psi1)) .* state(psi2)))
    return result
end

theta = [1.1,2.2]
println(expect'(C, copy(psi_0) => dispatch(U0, theta))[2])
grad = Zygote.gradient(theta->loss(theta), theta)[1]
println(grad)

In this case, the loss function above effectively computes an expectation value equivalent to expect(C, psi_0 => U). Therefore, expect' and Zygote.gradient yield the same result, [-0.8912073600614354, -0.8084964038195902], as expected.
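As a cross-check that does not depend on Yao or Zygote, the reported values can be reproduced with plain linear algebra. This is a minimal sketch, assuming the usual rotation convention R_P(θ) = exp(-iθP/2) and a |0⟩ start state on each qubit; under that assumption the loss reduces to cos(θ₁) + cos(θ₂), so its gradient is -sin.(theta). The helper names rot, z_expectation, loss_reference and fd_gradient are made up for this illustration and are not part of Yao.

```julia
using LinearAlgebra

# Pauli matrices and the 2x2 identity.
σx = ComplexF64[0 1; 1 0]
σy = ComplexF64[0 -im; im 0]
σz = ComplexF64[1 0; 0 -1]
I2 = Matrix{ComplexF64}(I, 2, 2)

# Assumed convention: R_P(θ) = exp(-iθP/2) = cos(θ/2) I - i sin(θ/2) P.
rot(P, θ) = cos(θ / 2) * I2 - im * sin(θ / 2) * P

# ⟨ψ|Z|ψ⟩ for a single qubit prepared as ψ = R_P(θ)|0⟩.
function z_expectation(P, θ)
    ψ = rot(P, θ) * ComplexF64[1, 0]
    return real(dot(ψ, σz * ψ))
end

# The circuit is a product of single-qubit rotations and the observable is
# Z on qubit 1 plus Z on qubit 2, so the two-qubit loss splits into a sum
# of two single-qubit expectation values.
loss_reference(θ) = z_expectation(σx, θ[1]) + z_expectation(σy, θ[2])

# Central finite differences as an independent gradient check.
function fd_gradient(f, θ; h = 1e-6)
    map(eachindex(θ)) do i
        e = zeros(length(θ))
        e[i] = h
        (f(θ + e) - f(θ - e)) / (2h)
    end
end

theta = [1.1, 2.2]
println(fd_gradient(loss_reference, theta))  # numerical gradient
println(-sin.(theta))                        # closed form under the assumed convention
```

For theta = [1.1, 2.2] the two printed vectors should agree with the expect' and Zygote result quoted above to within finite-difference accuracy.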

However, if we instead select the conjugate-transposed (daggered) version of U, the expect' version correctly returns the gradient as

julia> expect'(C, copy(psi_0) => dispatch(U0, theta)')[2]
2-element Vector{Float64}:
 0.8084964038195901
 0.8912073600614354

But when I attempt the same in Zygote by setting psi1 = apply(psi0, U'), I get an error message:

ERROR: LoadError: DimensionMismatch("variable with size(x) == (2,) cannot have a gradient with size(dx) == (1, 2)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Matrix{Float64})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/7ZiwT/src/projection.jl:226
 [2] ProjectTo
   @ ~/.julia/packages/ChainRulesCore/7ZiwT/src/projection.jl:247 [inlined]
 [3] _project
   @ ~/.julia/packages/Zygote/AlLTp/src/compiler/chainrules.jl:182 [inlined]
 [4] map(f::typeof(Zygote._project), t::Tuple{Vector{Float64}}, s::Tuple{LinearAlgebra.Adjoint{Float64, Vector{Float64}}})
   @ Base ./tuple.jl:232
 [5] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:77
 [6] top-level scope
   @ ~/zygote_bug_reproducer.jl:21
 [7] include(fname::String)
   @ Base.MainInclude ./client.jl:444
 [8] top-level scope
   @ REPL[6]:1
in expression starting at /zygote_bug_reproducer.jl:21
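The shape mismatch reported in the first stack frame can be reproduced without Yao or Zygote. This is a minimal sketch, assuming only what frame [4] shows, namely that the gradient arrives as a LinearAlgebra.Adjoint{Float64, Vector{Float64}} while the input is a plain Vector{Float64}. The numbers in g_row are made up for illustration, and the snippet says nothing about where inside the rules the adjoint is created.

```julia
using LinearAlgebra

theta = [1.1, 2.2]     # parameter vector, as in the reproducer
g_row = [0.5, -0.25]'  # hypothetical gradient stored as an adjoint (row) vector

println(size(theta))   # (2,)
println(size(g_row))   # (1, 2)
println(g_row isa Adjoint{Float64, Vector{Float64}})  # true

# Flattening the row vector restores the shape of the parameter vector.
g = vec(g_row)
println(size(g) == size(theta))  # true
```

A gradient of size (1, 2) cannot be projected onto a variable of size (2,), which is the condition the DimensionMismatch message describes.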

I have also tried copying U first by setting psi1 = apply(psi0, copy(U)') or psi1 = apply(psi0, copy(U')), in case the lazy dagger operation was causing trouble. But then I think Zygote loses track of the parameters across the copied block, because I get a different kind of error even without the dagger operation, with simply psi1 = apply(psi0, copy(U)):

ERROR: LoadError: Need an adjoint for constructor ChainBlock{2}. Gradient is of type LinearAlgebra.Adjoint{Float64, Vector{Float64}}

Any thoughts on how to enable daggered blocks with the Zygote patch? Am I taking the wrong approach, or is it simply a matter of defining the correct adjoint/rule, like @GiggleLiu did for other blocks such as Add? Is it an issue with double-daggered definitions? Thanks as always!

@GiggleLiu
Member

Thanks for the issue. This problem is caused by returning circuit gradients as a vector. I made a patch for it; can you check whether it solves your issue? #171

@vincentelfving
Author

@GiggleLiu perfect, I have tested a few cases and indeed this patch works for me! #171

Do you recommend I put the Zygote.accum method for AbstractBlock in my own modules or is it generally applicable and will also be part of chainrules_patch.jl? (now I see it is in the test file)

@GiggleLiu
Member

> @GiggleLiu perfect, I have tested a few cases and indeed this patch works for me! #171
>
> Do you recommend I put the Zygote.accum method for AbstractBlock in my own modules or is it generally applicable and will also be part of chainrules_patch.jl? (now I see it is in the test file)

We will not add this patch to YaoBlocks, because Zygote is very slow to load and sometimes has version issues. For example, tests currently break on nightly because Zygote is used in the tests. In the future, we might switch to a more correct implementation that constructs a Tangent type for the circuit.

@vincentelfving
Author

@GiggleLiu ok, understood! One remaining issue with the current rrules shows up after the following tiny modification to the code:

using Zygote
using Yao
using YaoBlocks

function Zygote.accum(a::AbstractBlock, b::AbstractBlock)
    dispatch(a, parameters(a) + parameters(b))
end

N=2
psi_0 = zero_state(N)
U0 = chain(N, put(1=>Rx(0.0)), put(2=>Ry(0.0)))

function loss(theta)
    C = sum([chain(N, put(k=>Z)) for k=1:N])
    U = dispatch(U0, theta)
    psi0 = copy(psi_0)
    psi1 = apply(psi0, U)
    psi2 = apply(psi1, C)
    result = real(sum(conj(state(psi1)) .* state(psi2)))
    return result
end

theta = [1.1,2.2]
println(expect'(C, copy(psi_0) => dispatch(U0, theta))[2])
grad = Zygote.gradient(theta->loss(theta), theta)[1]
println(grad)

Compared to the opening post of this issue, I only added the Zygote.accum method and moved C, which is an Add block of chains, into the loss function. C does not even depend on any parameters, but I get the following error:

ERROR: LoadError: Need an adjoint for constructor ChainBlock{2}. Gradient is of type Add{2}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{ChainBlock{2}, Nothing, false})(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/lib.jl:324
  [3] (::Zygote.var"#1768#back#224"{Zygote.Jnew{ChainBlock{2}, Nothing, false}})(Δ::Add{2})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:13 [inlined]
  [5] (::typeof(∂(ChainBlock{2})))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:13 [inlined]
  [7] (::typeof(∂(ChainBlock)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:17 [inlined]
  [9] (::typeof(∂(ChainBlock)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:48 [inlined]
 [11] (::typeof(∂(chain)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/YaoBlocks/O1EqK/src/composite/chain.jl:45 [inlined]
 [13] (::typeof(∂(chain)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./none:0 [inlined]
 [15] (::typeof(∂(#201)))(Δ::Add{2})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [16] #557
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:202 [inlined]
 [17] #4
    @ ./generator.jl:36 [inlined]
 [18] iterate
    @ ./generator.jl:47 [inlined]
 [19] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{Tuple{ChainBlock{2}, typeof(∂(#201))}}, FillArrays.Fill{Add{2}, 1, Tuple{Base.OneTo{Int64}}}}}, Base.var"#4#5"{Zygote.var"#557#562"}})
    @ Base ./array.jl:678
 [20] map
    @ ./abstractarray.jl:2383 [inlined]
 [21] map_back
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:202 [inlined]
 [22] (::Zygote.var"#back#591"{Zygote.var"#map_back#561"{var"#201#202", 1, Tuple{UnitRange{Int64}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{ChainBlock{2}, typeof(∂(#201))}}}})(ȳ::FillArrays.Fill{Add{2}, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/lib/array.jl:247
 [23] Pullback
    @ ~/reproducing_chainblock_bug.jl:14 [inlined]
 [24] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/reproducing_chainblock_bug.jl:25 [inlined]
 [26] (::Zygote.var"#55#56"{typeof(∂(#203))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:41
 [27] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:76
 [28] top-level scope
    @ ~/reproducing_chainblock_bug.jl:25
 [29] include(fname::String)
    @ Base.MainInclude ./client.jl:444
 [30] top-level scope
    @ REPL[7]:1
in expression starting at /reproducing_chainblock_bug.jl:25

Not only is this a problem right now, but in general I would also like to differentiate through blocks of this type, for example when a parameter (from the perspective of Zygote, not necessarily a Yao dispatched parameter) appears in them. Please let me know if you want me to open a new issue as a copy of this comment.

@GiggleLiu
Member

GiggleLiu commented Nov 16, 2021

Hi, I just used the correct tangent type for gradients. It seems to solve your problem, and you do not need the patch anymore. However, some issues are still unsettled. I am not good at debugging Zygote; if someone can help, that would be great. I posted a WIP PR in #173.

@vincentelfving
Author

Yes, this works for my problem and the accum patch is no longer needed, thanks!

@Roger-luo
Member

I haven't had time to look into the Zygote issue (just traveling Boston), e.g. the one @GiggleLiu posted:

julia> Zygote.gradient(x->x.content.theta, Daggered(Rx(0.5)))
(nothing,)

Let me take a second look at this a bit later to fully resolve the issue. But I am glad to see the main case is fixed.
