Skip to content

Commit

Permalink
Better support for AbstractString (#397)
Browse files Browse the repository at this point in the history
* Support `AbstractString`

* Add tests

* Bump version
  • Loading branch information
devmotion authored Jan 30, 2023
1 parent 8c3f5ed commit 711a298
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "5.6.1"
version = "5.7.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
22 changes: 11 additions & 11 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function Chains(
end

"""
Chains(c::Chains, section::Union{Symbol,String})
Chains(c::Chains, section::Union{Symbol,AbstractString})
Chains(c::Chains, sections)
Return a new chain with only a specific `section` or multiple `sections` pulled out.
Expand All @@ -101,7 +101,7 @@ julia> names(chn2)
:a
```
"""
Chains(c::Chains, section::Union{Symbol,String}) = Chains(c, (section,))
Chains(c::Chains, section::Union{Symbol,AbstractString}) = Chains(c, (section,))
function Chains(chn::Chains, sections)
# Make sure the sections exist first.
all(haskey(chn.name_map, Symbol(x)) for x in sections) ||
Expand All @@ -121,7 +121,7 @@ Chains(chain::Chains, ::Nothing) = chain
# Groups of parameters

"""
namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
namesingroup(chains::Chains, sym::Union{AbstractString,Symbol}; index_type::Symbol=:bracket)
Return the parameters with the same name `sym`, but have a different index. Bracket indexing format
in the form of `:sym[index]` is assumed by default. Use `index_type=:dot` for parameters with dot
Expand All @@ -147,7 +147,7 @@ julia> namesingroup(chn, :A; index_type=:dot)
Symbol("A.2")
```
"""
namesingroup(chains::Chains, sym::String; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...)
namesingroup(chains::Chains, sym::AbstractString; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...)
function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
if index_type !== :bracket && index_type !== :dot
error("index_type must be :bracket or :dot")
Expand All @@ -161,14 +161,14 @@ function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket)
end

"""
group(chains::Chains, name::Union{String,Symbol}; index_type::Symbol=:bracket)
group(chains::Chains, name::Union{AbstractString,Symbol}; index_type::Symbol=:bracket)
Return a subset of the chain containing parameters with the same `name`, but a different index.
Bracket indexing format in the form of `:name[index]` is assumed by default. Use `index_type=:dot` for parameters with dot
indexing, i.e. `:sym.index`.
"""
function group(chains::Chains, name::Union{String,Symbol}; kwargs...)
function group(chains::Chains, name::Union{AbstractString,Symbol}; kwargs...)
return chains[:, namesingroup(chains, name; kwargs...), :]
end

Expand All @@ -177,8 +177,8 @@ end
Base.getindex(c::Chains, i::Integer) = c[i, :, :]
Base.getindex(c::Chains, i::AbstractVector{<:Integer}) = c[i, :, :]

Base.getindex(c::Chains, v::String) = c[:, Symbol(v), :]
Base.getindex(c::Chains, v::AbstractVector{String}) = c[:, Symbol.(v), :]
Base.getindex(c::Chains, v::AbstractString) = c[:, Symbol(v), :]
Base.getindex(c::Chains, v::AbstractVector{<:AbstractString}) = c[:, Symbol.(v), :]

Base.getindex(c::Chains, v::Symbol) = c[:, v, :]
Base.getindex(c::Chains, v::AbstractVector{Symbol}) = c[:, v, :]
Expand All @@ -199,7 +199,7 @@ _toindex(i, j, k::Integer) = (i, string2symbol(j), k:k)
_toindex(i::Integer, j, k::Integer) = (i:i, string2symbol(j), k:k)

# return an array or a number if a single parameter is specified
const SingleIndex = Union{Symbol,String,Integer}
const SingleIndex = Union{Symbol,AbstractString,Integer}
_toindex(i, j::SingleIndex, k) = (i, string2symbol(j), k)
_toindex(i::Integer, j::SingleIndex, k) = (i, string2symbol(j), k)
_toindex(i, j::SingleIndex, k::Integer) = (i, string2symbol(j), k)
Expand Down Expand Up @@ -542,7 +542,7 @@ Return multiple `Chains` objects, each containing only a single section.
function get_sections(chains::Chains, sections = keys(chains.name_map))
return [Chains(chains, section) for section in sections]
end
get_sections(chains::Chains, section::Union{Symbol, String}) = Chains(chains, section)
get_sections(chains::Chains, section::Union{Symbol, AbstractString}) = Chains(chains, section)

"""
sections(c::Chains)
Expand Down Expand Up @@ -727,7 +727,7 @@ function _clean_sections(chains::Chains, sections)
haskey(chains.name_map, Symbol(section))
end
end
function _clean_sections(chains::Chains, section::Union{String,Symbol})
function _clean_sections(chains::Chains, section::Union{AbstractString,Symbol})
return haskey(chains.name_map, Symbol(section)) ? section : ()
end
_clean_sections(::Chains, ::Nothing) = nothing
Expand Down
10 changes: 5 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ end
Convert strings to symbols.
If `x isa String`, the corresponding `Symbol` is returned. Likewise, if
`x isa AbstractVector{String}`, the corresponding vector of `Symbol`s is returned. In all
other cases, input `x` is returned.
If `x isa AbstractString`, the corresponding `Symbol` is returned.
Likewise, if `x isa AbstractVector{<:AbstractString}`, the corresponding vector of `Symbol`s is returned.
In all other cases, input `x` is returned.
"""
string2symbol(x) = x
string2symbol(x::String) = Symbol(x)
string2symbol(x::AbstractVector{String}) = Symbol.(x)
string2symbol(x::AbstractString) = Symbol(x)
string2symbol(x::AbstractVector{<:AbstractString}) = Symbol.(x)

#################### Mathematical Operators ####################
function cummean(x::AbstractArray)
Expand Down
66 changes: 50 additions & 16 deletions test/diagnostic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,42 @@ end
end

@testset "indexing tests" begin
@test chn[:,1,:] isa AbstractMatrix
@test chn[200:300, "param_1", :] isa AbstractMatrix
@test chn[200:300, ["param_1", "param_3"], :] isa Chains
@test chn[200:300, "param_1", 1] isa AbstractVector
@test size(chn[:,1,:]) == (niter, nchains)
@test chn[:,1,1] == val[:,1,1]
@test chn[:,1,2] == val[:,1,2]
c = chn[:, 1, :]
@test c isa AbstractMatrix
@test size(c) == (niter, nchains)
@test c == val[:, 1, :]

for i in 1:2
c = chn[:, 1, i]
@test c isa AbstractVector
@test length(c) == niter
@test c == val[:, 1, i]
end

for p in (:param_1, "param_1", SubString("param_1", 1))
c = chn[200:300, p, :]
@test c isa AbstractMatrix
@test size(c) == (101, size(chn, 3))
@test c == val[200:300, 1, :]

c = chn[200:300, p, 1]
@test c isa AbstractVector
@test length(c) == 101
@test c == val[200:300, 1, 1]
end

for ps in (
[:param_1, :param_3],
["param_1", "param_3"],
[SubString("param_1", 1), "param_3"],
["param_1", SubString("param_3", 1)],
[SubString("param_1", 1), SubString("param_3", 1)],
)
c = chn[200:300, ps, :]
@test c isa Chains
@test size(c) == (101, 2, nchains)
@test c.value.data == val[200:300, [1, 3], :]
end
end

@testset "names and groups tests" begin
Expand All @@ -116,18 +145,23 @@ end
(@inferred replacenames(chn, Dict("param_2" => "param[2]",
"param_3" => "param[3]"))).value
@test names(chn2) == [:param_1, Symbol("param[2]"), Symbol("param[3]"), :param_4]
@test namesingroup(chn2, "param") == Symbol.(["param[2]", "param[3]"])
for p in (:param, "param", SubString("param", 1))
@test namesingroup(chn2, p) == Symbol.(["param[2]", "param[3]"])
end

chn3 = group(chn2, "param")
@test names(chn3) == Symbol.(["param[2]", "param[3]"])
@test chn3.value == chn[:, [:param_2, :param_3], :].value
for p in (:param, "param", SubString("param", 1))
chn3 = group(chn2, p)
@test names(chn3) == Symbol.(["param[2]", "param[3]"])
@test chn3.value == chn[:, [:param_2, :param_3], :].value
end

stan_chn = Chains(rand(100, 3, 1), ["a.1", "a[2]", "b"])
@test namesingroup(stan_chn, "a"; index_type=:dot) == [Symbol("a.1")]
@test namesingroup(stan_chn, :a; index_type=:dot) == [Symbol("a.1")]
@test names(group(stan_chn, :a; index_type=:dot)) == [Symbol("a.1")]
@test_throws Exception namesingroup(stan_chn, :a; index_type=:x)
@test_throws Exception group(stan_chn, :a; index_type=:x)
for p in (:a, "a", SubString("a", 1))
@test namesingroup(stan_chn, p; index_type=:dot) == [Symbol("a.1")]
@test names(group(stan_chn, p; index_type=:dot)) == [Symbol("a.1")]
@test_throws Exception namesingroup(stan_chn, p; index_type=:x)
@test_throws Exception group(stan_chn, p; index_type=:x)
end
end

@testset "function tests" begin
Expand Down

2 comments on commit 711a298

@cpfiffer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/76661

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v5.7.0 -m "<description of version>" 711a298833705df92821296d13de376abb07ea6a
git push origin v5.7.0

Please sign in to comment.