Skip to content
This repository was archived by the owner on Dec 18, 2021. It is now read-only.

Commit 2534320

Browse files
authored
non-inplace version apply, dispatch and setiparams (#162)
* non-inplace version apply, dispatch and setiparams * clean up methods
1 parent 726bd58 commit 2534320

14 files changed

+95
-24
lines changed

src/YaoBlocks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ export AbstractBlock,
4848
ResetTo,
4949
TagBlock,
5050
apply!,
51+
apply,
5152
apply_back!,
5253
chcontent,
5354
chsubblocks,
5455
content,
5556
dispatch!,
57+
dispatch,
5658
expect,
5759
getiparams,
5860
iparams_eltype,
@@ -70,6 +72,7 @@ export AbstractBlock,
7072
print_block,
7173
render_params,
7274
setiparams!,
75+
setiparams,
7376
subblocks,
7477
ishermitian,
7578
nparameters

src/abstract_block.jl

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using YaoBase, YaoArrayRegister, SimpleTraits
55
66
Apply a block (of quantum circuit) to a quantum register.
77
"""
8-
function apply!(r::AbstractRegister, @nospecialize(b::AbstractBlock))
8+
function apply!(r::AbstractRegister, b::AbstractBlock)
99
_check_size(r, b)
1010
_apply!(r, b)
1111
end
@@ -131,34 +131,41 @@ Returns the intrinsic parameters of node `block`, default is an empty tuple.
131131
getiparams(x::AbstractBlock) = ()
132132

133133
"""
134-
setiparams!(block, itr)
135-
setiparams!(block, params...)
134+
setiparams!([f], block, itr)
135+
setiparams!([f], block, params...)
136136
137137
Set the parameters of `block`.
138+
When `f` is provided, set parameters of `block` to the value in `collection` mapped by `f`.
139+
`iter` can be an iterator or a symbol, the symbol can be `:zero`, `:random`.
138140
"""
139-
setiparams!(x::AbstractBlock, args...) =
140-
niparams(x) == length(args) == 0 ? x : throw(NotImplementedError(:setiparams!, (x, args...)))
141-
142-
setiparams!(x::AbstractBlock, it::Union{Tuple,AbstractArray,Base.Generator}) = setiparams!(x, it...)
143-
setiparams!(x::AbstractBlock, a::Number, xs::Number...) =
144-
error("setparams!(x, θ...) is not implemented")
145-
setiparams!(x::AbstractBlock, it::Symbol) = setiparams!(x, render_params(x, it))
141+
function setiparams! end
146142

147143
"""
148-
setiparams(f, block, collection)
144+
setiparams([f], block, itr)
145+
setiparams([f], block, params...)
149146
150-
Set parameters of `block` to the value in `collection` mapped by `f`.
147+
Set the parameters of `block`, the non-inplace version.
148+
When `f` is provided, set parameters of `block` to the value in `collection` mapped by `f`.
149+
`iter` can be an iterator or a symbol, the symbol can be `:zero`, `:random`.
151150
"""
152-
setiparams!(f::Function, x::AbstractBlock, it) =
153-
setiparams!(x, map(x -> f(x...), zip(getiparams(x), it)))
154-
setiparams!(f::Nothing, x::AbstractBlock, it) = setiparams!(x, it)
151+
function setiparams end
155152

156-
"""
157-
setiparams(f, block, symbol)
153+
for F in [:setiparams!, :setiparams]
154+
@eval begin
155+
$F(x::AbstractBlock, args...) =
156+
niparams(x) == length(args) == 0 ? x : throw(NotImplementedError($(QuoteNode(F)), (x, args...)))
158157

159-
Set the parameters to a given symbol, which can be :zero, :random.
160-
"""
161-
setiparams!(f::Function, x::AbstractBlock, it::Symbol) = setiparams!(f, x, render_params(x, it))
158+
$F(x::AbstractBlock, it::Union{Tuple,AbstractArray,Base.Generator}) = $F(x, it...)
159+
$F(x::AbstractBlock, a::Number, xs::Number...) =
160+
error("setparams!(x, θ...) is not implemented")
161+
$F(x::AbstractBlock, it::Symbol) = $F(x, render_params(x, it))
162+
163+
$F(f::Function, x::AbstractBlock, it) =
164+
$F(x, map(x -> f(x...), zip(getiparams(x), it)))
165+
$F(f::Nothing, x::AbstractBlock, it) = $F(x, it)
166+
$F(f::Function, x::AbstractBlock, it::Symbol) = $F(f, x, render_params(x, it))
167+
end
168+
end
162169

163170
"""
164171
parameters(block)
@@ -363,3 +370,37 @@ function parameters_range!(out::Vector{Tuple{T,T}}, block::AbstractBlock) where
363370
parameters_range!(out, subblock)
364371
end
365372
end
373+
374+
# non-inplace versions
375+
"""
376+
apply(register, block)
377+
378+
The non-inplace version of applying a block (of quantum circuit) to a quantum register.
379+
Check `apply!` for the faster inplace version.
380+
"""
381+
apply(r::AbstractRegister, b) = apply!(copy(r), b)
382+
383+
function generic_dispatch!(f::Union{Function,Nothing}, x::AbstractBlock, it::Dispatcher)
384+
x = setiparams(f, x, consume!(it, niparams(x)))
385+
chsubblocks(x, map(subblocks(x)) do each
386+
generic_dispatch!(f, each, it)
387+
end)
388+
end
389+
390+
"""
391+
dispatch(x::AbstractBlock, collection)
392+
393+
Dispatch parameters in collection to block tree `x`, the generic non-inplace version.
394+
395+
!!! note
396+
397+
it will try to dispatch the parameters in collection first.
398+
"""
399+
function dispatch(f::Union{Function,Nothing}, x::AbstractBlock, it)
400+
dp = Dispatcher(it)
401+
res = generic_dispatch!(f, x, dp)
402+
@assert (it isa Symbol || length(it) == dp.loc) "expect $(dp.loc) parameters, got $(length(it))"
403+
return res
404+
end
405+
406+
dispatch(x::AbstractBlock, it) = dispatch(nothing, x, it)

src/primitive/phase_gate.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ mat(::Type{T}, gate::PhaseGate) where {T} = exp(T(im * gate.theta)) * IMatrix{2,
3737
niparams(::Type{<:PhaseGate}) = 1
3838
getiparams(x::PhaseGate) = x.theta
3939
setiparams!(r::PhaseGate, param::Number) = (r.theta = param; r)
40+
setiparams(r::PhaseGate, param::Number) = PhaseGate(param)
4041

4142
# fallback to matrix method if it is not real
4243
YaoBase.isunitary(r::PhaseGate{<:Real}) = true

src/primitive/rotation_gate.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ niparams(::Type{<:RotationGate}) = 1
108108
getiparams(x::RotationGate) = x.theta
109109
# no need to specify the type of param, Julia will try to do the conversion
110110
setiparams!(r::RotationGate, param::Number) where {N,T} = (r.theta = param; r)
111+
setiparams(r::RotationGate, param::Number) where {N,T} = RotationGate(r.block, param)
111112

112113
# fallback to matrix methods if it is not real
113114
YaoBase.isunitary(r::RotationGate{N,<:Real}) where {N} = true

src/primitive/shift_gate.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ cache_key(gate::ShiftGate) = gate.theta
4343
niparams(::Type{<:ShiftGate}) = 1
4444
getiparams(x::ShiftGate) = x.theta
4545
setiparams!(r::ShiftGate, param::Number) = (r.theta = param; r)
46+
setiparams(r::ShiftGate, param::Number) = ShiftGate(param)
4647

4748

4849
Base.adjoint(blk::ShiftGate) = ShiftGate(-blk.theta)

src/primitive/time_evolution.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ cache_key(te::TimeEvolution) = (te.dt, cache_key(te.H))
8181
niparams(::Type{<:TimeEvolution}) = 1
8282
getiparams(x::TimeEvolution) = x.dt
8383
setiparams!(r::TimeEvolution, param::Number) = (r.dt = param; r)
84+
setiparams(r::TimeEvolution, param::Number) = TimeEvolution(r.H, param; tol=r.tol, check_hermicity=false)
8485

8586
function Base.:(==)(lhs::TimeEvolution, rhs::TimeEvolution)
8687
return lhs.H == rhs.H && lhs.dt == rhs.dt

src/routines.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ set specific col of a CSC matrix
3434
end
3535

3636
"""
37-
num_nonzero(nbits, nctrls, U)
37+
num_nonzero(nbits, nctrls, U, [N])
3838
3939
Return number of nonzero entries of the matrix form of control-U gate. `nbits`
4040
is the number of qubits, and `nctrls` is the number of control qubits.
@@ -43,8 +43,8 @@ is the number of qubits, and `nctrls` is the number of control qubits.
4343
return N + (1 << (nbits - nctrls - log2dim1(U))) * (length(U) - size(U, 2))
4444
end
4545

46-
@inline function num_nonzero(nbits::Int, U, N::Int = 1 << nbits)
47-
return N + (1 << (nbits - log2dim1(U))) * (length(U) - size(U, 2))
46+
@inline function num_nonzero(nbits::Int, nctrls::Int, U::SDSparseMatrixCSC, N::Int = 1 << nbits)
47+
return N + (1 << (nbits - nctrls - log2dim1(U))) * (nnz(U) - size(U, 2))
4848
end
4949

5050
"""
@@ -121,7 +121,7 @@ function u1mat(nbits::Int, U1::SDMatrix, ibit::Int)
121121
a, c, b, d = U1
122122
step = 1 << (ibit - 1)
123123
step_2 = 1 << ibit
124-
NNZ = num_nonzero(nbits, U1, N)
124+
NNZ = num_nonzero(nbits, 0, U1, N)
125125

126126
colptr = Vector{Int}(1:2:2*N+1)
127127
rowval = Vector{Int}(undef, NNZ)

test/abstract_blocks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using YaoBase
55
@test getiparams(phase(0.1)) == 0.1
66
@test getiparams(2 * phase(0.1)) == ()
77
@test_throws NotImplementedError setiparams!(rot(X, 0.5), :nothing)
8+
@test_throws NotImplementedError setiparams(rot(X, 0.5), :nothing)
89
end
910

1011
@testset "block to matrix conversion" begin
@@ -18,6 +19,9 @@ end
1819
@testset "apply lambda" begin
1920
r = rand_state(3)
2021
@test apply!(copy(r), put(1 => X)) apply!(copy(r), put(3, 1 => X))
22+
r2 = copy(r)
23+
@test apply(r, put(1 => X)) apply!(copy(r), put(3, 1 => X))
24+
@test r2.state == r.state
2125
f(x::Float32) = x
2226
@test_throws ErrorException apply!(r, f)
2327
end

test/composite/swap_gate.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ end
2828
dispatch!(pb, π)
2929
@test copy(reg) |> pb -im * (copy(reg) |> swap(6, 2, 4))
3030
@test copy(reg) |> pb |> isnormalized
31+
pb = dispatch(pb, π)
32+
@test copy(reg) |> pb -im * (copy(reg) |> swap(6, 2, 4))
33+
@test copy(reg) |> pb |> isnormalized
3134

35+
pb = pswap(6, 2, 4, 0.0)
3236
dispatch!(pb, :random)
3337
@test copy(reg) |> pb invoke(apply!, Tuple{ArrayReg,PutBlock}, copy(reg), pb)
38+
pb = dispatch(pb, :random)
39+
@test copy(reg) |> pb invoke(apply!, Tuple{ArrayReg,PutBlock}, copy(reg), pb)
3440
end

test/primitive/phase_gate.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ end
3030
@test nparameters(phase(0.1)) == 1
3131
@test adjoint(phase(0.1)) == phase(-0.1)
3232
@test dispatch!(phase(0.1), 0.3) == phase(0.3)
33+
@test dispatch(phase(0.1), 3) == phase(3) && eltype(getiparams(dispatch(phase(0.1), 3))) == Int
3334

3435
@testset "test $op" for op in [+, -, *, /]
3536
@test dispatch!(op, phase(0.1), π) == phase(op(0.1, π))

0 commit comments

Comments
 (0)