Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
non-inplace version apply, dispatch and setiparams (#162)
Browse files Browse the repository at this point in the history
* non-inplace version apply, dispatch and setiparams

* clean up methods
  • Loading branch information
GiggleLiu authored Aug 14, 2021
1 parent 726bd58 commit 2534320
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 24 deletions.
3 changes: 3 additions & 0 deletions src/YaoBlocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ export AbstractBlock,
ResetTo,
TagBlock,
apply!,
apply,
apply_back!,
chcontent,
chsubblocks,
content,
dispatch!,
dispatch,
expect,
getiparams,
iparams_eltype,
Expand All @@ -70,6 +72,7 @@ export AbstractBlock,
print_block,
render_params,
setiparams!,
setiparams,
subblocks,
ishermitian,
nparameters
Expand Down
81 changes: 61 additions & 20 deletions src/abstract_block.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using YaoBase, YaoArrayRegister, SimpleTraits
Apply a block (of quantum circuit) to a quantum register.
"""
function apply!(r::AbstractRegister, @nospecialize(b::AbstractBlock))
function apply!(r::AbstractRegister, b::AbstractBlock)
_check_size(r, b)
_apply!(r, b)
end
Expand Down Expand Up @@ -131,34 +131,41 @@ Returns the intrinsic parameters of node `block`, default is an empty tuple.
getiparams(x::AbstractBlock) = ()

"""
setiparams!(block, itr)
setiparams!(block, params...)
setiparams!([f], block, itr)
setiparams!([f], block, params...)
Set the parameters of `block`.
When `f` is provided, set parameters of `block` to the value in `collection` mapped by `f`.
`iter` can be an iterator or a symbol, the symbol can be `:zero`, `:random`.
"""
setiparams!(x::AbstractBlock, args...) =
niparams(x) == length(args) == 0 ? x : throw(NotImplementedError(:setiparams!, (x, args...)))

setiparams!(x::AbstractBlock, it::Union{Tuple,AbstractArray,Base.Generator}) = setiparams!(x, it...)
setiparams!(x::AbstractBlock, a::Number, xs::Number...) =
error("setparams!(x, θ...) is not implemented")
setiparams!(x::AbstractBlock, it::Symbol) = setiparams!(x, render_params(x, it))
function setiparams! end

"""
setiparams(f, block, collection)
setiparams([f], block, itr)
setiparams([f], block, params...)
Set parameters of `block` to the value in `collection` mapped by `f`.
Set the parameters of `block`, the non-inplace version.
When `f` is provided, set parameters of `block` to the value in `collection` mapped by `f`.
`iter` can be an iterator or a symbol, the symbol can be `:zero`, `:random`.
"""
setiparams!(f::Function, x::AbstractBlock, it) =
setiparams!(x, map(x -> f(x...), zip(getiparams(x), it)))
setiparams!(f::Nothing, x::AbstractBlock, it) = setiparams!(x, it)
function setiparams end

"""
setiparams(f, block, symbol)
for F in [:setiparams!, :setiparams]
@eval begin
$F(x::AbstractBlock, args...) =
niparams(x) == length(args) == 0 ? x : throw(NotImplementedError($(QuoteNode(F)), (x, args...)))

Set the parameters to a given symbol, which can be :zero, :random.
"""
setiparams!(f::Function, x::AbstractBlock, it::Symbol) = setiparams!(f, x, render_params(x, it))
$F(x::AbstractBlock, it::Union{Tuple,AbstractArray,Base.Generator}) = $F(x, it...)
$F(x::AbstractBlock, a::Number, xs::Number...) =
error("setparams!(x, θ...) is not implemented")
$F(x::AbstractBlock, it::Symbol) = $F(x, render_params(x, it))

$F(f::Function, x::AbstractBlock, it) =
$F(x, map(x -> f(x...), zip(getiparams(x), it)))
$F(f::Nothing, x::AbstractBlock, it) = $F(x, it)
$F(f::Function, x::AbstractBlock, it::Symbol) = $F(f, x, render_params(x, it))
end
end

"""
parameters(block)
Expand Down Expand Up @@ -363,3 +370,37 @@ function parameters_range!(out::Vector{Tuple{T,T}}, block::AbstractBlock) where
parameters_range!(out, subblock)
end
end

# non-inplace versions
"""
apply(register, block)
The non-inplace version of applying a block (of quantum circuit) to a quantum register.
Check `apply!` for the faster inplace version.
"""
apply(r::AbstractRegister, b) = apply!(copy(r), b)

function generic_dispatch!(f::Union{Function,Nothing}, x::AbstractBlock, it::Dispatcher)
x = setiparams(f, x, consume!(it, niparams(x)))
chsubblocks(x, map(subblocks(x)) do each
generic_dispatch!(f, each, it)
end)
end

"""
dispatch(x::AbstractBlock, collection)
Dispatch parameters in collection to block tree `x`, the generic non-inplace version.
!!! note
it will try to dispatch the parameters in collection first.
"""
function dispatch(f::Union{Function,Nothing}, x::AbstractBlock, it)
dp = Dispatcher(it)
res = generic_dispatch!(f, x, dp)
@assert (it isa Symbol || length(it) == dp.loc) "expect $(dp.loc) parameters, got $(length(it))"
return res
end

dispatch(x::AbstractBlock, it) = dispatch(nothing, x, it)
1 change: 1 addition & 0 deletions src/primitive/phase_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ mat(::Type{T}, gate::PhaseGate) where {T} = exp(T(im * gate.theta)) * IMatrix{2,
niparams(::Type{<:PhaseGate}) = 1
getiparams(x::PhaseGate) = x.theta
setiparams!(r::PhaseGate, param::Number) = (r.theta = param; r)
setiparams(r::PhaseGate, param::Number) = PhaseGate(param)

# fallback to matrix method if it is not real
YaoBase.isunitary(r::PhaseGate{<:Real}) = true
Expand Down
1 change: 1 addition & 0 deletions src/primitive/rotation_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ niparams(::Type{<:RotationGate}) = 1
getiparams(x::RotationGate) = x.theta
# no need to specify the type of param, Julia will try to do the conversion
setiparams!(r::RotationGate, param::Number) where {N,T} = (r.theta = param; r)
setiparams(r::RotationGate, param::Number) where {N,T} = RotationGate(r.block, param)

# fallback to matrix methods if it is not real
YaoBase.isunitary(r::RotationGate{N,<:Real}) where {N} = true
Expand Down
1 change: 1 addition & 0 deletions src/primitive/shift_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ cache_key(gate::ShiftGate) = gate.theta
niparams(::Type{<:ShiftGate}) = 1
getiparams(x::ShiftGate) = x.theta
setiparams!(r::ShiftGate, param::Number) = (r.theta = param; r)
setiparams(r::ShiftGate, param::Number) = ShiftGate(param)


Base.adjoint(blk::ShiftGate) = ShiftGate(-blk.theta)
Expand Down
1 change: 1 addition & 0 deletions src/primitive/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ cache_key(te::TimeEvolution) = (te.dt, cache_key(te.H))
niparams(::Type{<:TimeEvolution}) = 1
getiparams(x::TimeEvolution) = x.dt
setiparams!(r::TimeEvolution, param::Number) = (r.dt = param; r)
setiparams(r::TimeEvolution, param::Number) = TimeEvolution(r.H, param; tol=r.tol, check_hermicity=false)

function Base.:(==)(lhs::TimeEvolution, rhs::TimeEvolution)
return lhs.H == rhs.H && lhs.dt == rhs.dt
Expand Down
8 changes: 4 additions & 4 deletions src/routines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ set specific col of a CSC matrix
end

"""
num_nonzero(nbits, nctrls, U)
num_nonzero(nbits, nctrls, U, [N])
Return number of nonzero entries of the matrix form of control-U gate. `nbits`
is the number of qubits, and `nctrls` is the number of control qubits.
Expand All @@ -43,8 +43,8 @@ is the number of qubits, and `nctrls` is the number of control qubits.
return N + (1 << (nbits - nctrls - log2dim1(U))) * (length(U) - size(U, 2))
end

@inline function num_nonzero(nbits::Int, U, N::Int = 1 << nbits)
return N + (1 << (nbits - log2dim1(U))) * (length(U) - size(U, 2))
@inline function num_nonzero(nbits::Int, nctrls::Int, U::SDSparseMatrixCSC, N::Int = 1 << nbits)
return N + (1 << (nbits - nctrls - log2dim1(U))) * (nnz(U) - size(U, 2))
end

"""
Expand Down Expand Up @@ -121,7 +121,7 @@ function u1mat(nbits::Int, U1::SDMatrix, ibit::Int)
a, c, b, d = U1
step = 1 << (ibit - 1)
step_2 = 1 << ibit
NNZ = num_nonzero(nbits, U1, N)
NNZ = num_nonzero(nbits, 0, U1, N)

colptr = Vector{Int}(1:2:2*N+1)
rowval = Vector{Int}(undef, NNZ)
Expand Down
4 changes: 4 additions & 0 deletions test/abstract_blocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using YaoBase
@test getiparams(phase(0.1)) == 0.1
@test getiparams(2 * phase(0.1)) == ()
@test_throws NotImplementedError setiparams!(rot(X, 0.5), :nothing)
@test_throws NotImplementedError setiparams(rot(X, 0.5), :nothing)
end

@testset "block to matrix conversion" begin
Expand All @@ -18,6 +19,9 @@ end
@testset "apply lambda" begin
r = rand_state(3)
@test apply!(copy(r), put(1 => X)) apply!(copy(r), put(3, 1 => X))
r2 = copy(r)
@test apply(r, put(1 => X)) apply!(copy(r), put(3, 1 => X))
@test r2.state == r.state
f(x::Float32) = x
@test_throws ErrorException apply!(r, f)
end
Expand Down
6 changes: 6 additions & 0 deletions test/composite/swap_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ end
dispatch!(pb, π)
@test copy(reg) |> pb -im * (copy(reg) |> swap(6, 2, 4))
@test copy(reg) |> pb |> isnormalized
pb = dispatch(pb, π)
@test copy(reg) |> pb -im * (copy(reg) |> swap(6, 2, 4))
@test copy(reg) |> pb |> isnormalized

pb = pswap(6, 2, 4, 0.0)
dispatch!(pb, :random)
@test copy(reg) |> pb invoke(apply!, Tuple{ArrayReg,PutBlock}, copy(reg), pb)
pb = dispatch(pb, :random)
@test copy(reg) |> pb invoke(apply!, Tuple{ArrayReg,PutBlock}, copy(reg), pb)
end
1 change: 1 addition & 0 deletions test/primitive/phase_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ end
@test nparameters(phase(0.1)) == 1
@test adjoint(phase(0.1)) == phase(-0.1)
@test dispatch!(phase(0.1), 0.3) == phase(0.3)
@test dispatch(phase(0.1), 3) == phase(3) && eltype(getiparams(dispatch(phase(0.1), 3))) == Int

@testset "test $op" for op in [+, -, *, /]
@test dispatch!(op, phase(0.1), π) == phase(op(0.1, π))
Expand Down
4 changes: 4 additions & 0 deletions test/primitive/rotation_gate.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test, YaoBlocks, YaoArrayRegister
using YaoBlocks.ConstGate: CNOT

@testset "test constructor" for T in [Float16, Float32, Float64]
# NOTE: type should follow the axis
Expand Down Expand Up @@ -30,6 +31,9 @@ end

@testset "test dispatch" begin
@test dispatch!(Rx(0.1), 0.3) == Rx(0.3)
x = Rx(0.1)
@test dispatch(x, 3f0) == Rx(3f0) && eltype(getiparams(dispatch(x, 3f0))) == Float32
@test x == Rx(0.1)
@test nparameters(Rx(0.1)) == 1

@testset "test $op" for op in [+, -, *, /]
Expand Down
1 change: 1 addition & 0 deletions test/primitive/shift_gate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ end
@test nparameters(shift(0.1)) == 1
@test parameters(shift(0.1)) == [0.1]
@test parameters(dispatch!(shift(0.1), 0.2)) == [0.2]
@test parameters(dispatch(shift(0.1), 2)) == [2]
end
3 changes: 3 additions & 0 deletions test/primitive/time_evolution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ const hm = heisenberg(4)
@test setiparams!(cte, 0.5).dt == 0.5
@test setiparams!(cte, :random).dt != 0.5
@test setiparams!(cte, :zero).dt == 0.0
@test setiparams(cte, 0.5).dt == 0.5
@test setiparams(cte, :random).dt != 0.5
@test setiparams(cte, :zero).dt == 0.0
end

@testset "test imaginary time evolution" begin
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ end
g = dispatch!(chain(Rx(0.1), Rx(0.2)), [0.3, 0])
@test getiparams(g[1]) == 0.3
@test getiparams(g[2]) == 0.0

g = dispatch(chain(Rx(0.1), Rx(0.2)), [0f3, 0f0])
@test getiparams(g[1]) === 0f3
@test getiparams(g[2]) === 0f0
end

@testset "abstract blocks" begin
Expand Down

0 comments on commit 2534320

Please sign in to comment.