Skip to content

Commit e1cffcf

Browse files
jkosataoameye
andauthored
plotting branches, optimized sorting, deps removed (#88)
* align_pair ducktyped to simplify 1D sorting * get_distance_matrix optimized for 2D * Julia rolled back to 1.8.2 * fix TD plot bug * Plots recipes (#87) * simplified Base MD * branches kw for transform_solutions * transform_solutions refactored * branches keyword for plotting * indexing fix, version bump * bugfixes for arrays * realify debugged * plotting tests to `sorting_optimization` branch (#93) * fixed Spaghetti_plot and 2D_cut bug * adding plotting tests Co-authored-by: Orjan Ameye <[email protected]>
1 parent 649916d commit e1cffcf

File tree

14 files changed

+146
-126
lines changed

14 files changed

+146
-126
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HarmonicBalance"
22
uuid = "e13b9ff6-59c3-11ec-14b1-f3d2cc6c135e"
33
authors = ["Jan Kosata <[email protected]>", "Javier del Pino <[email protected]>"]
4-
version = "0.6.3"
4+
version = "0.6.4"
55

66
[deps]
77
BijectiveHilbert = "91e7fc40-53cd-4118-bd19-d7fcd1de2a54"
@@ -37,7 +37,7 @@ Peaks = "0.4.1"
3737
Plots = "1.36.4"
3838
ProgressMeter = "1.7.2"
3939
Symbolics = "4.13.0"
40-
julia = "1.8.3"
40+
julia = "1.8.2"
4141

4242
[extras]
4343
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

src/HarmonicBalance.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ module HarmonicBalance
2626

2727
export is_real
2828
is_real(x) = abs(imag(x)) / abs(real(x)) < IM_TOL::Float64 || abs(x) < 1e-70
29-
is_real(x::Array) = any(is_real.(x))
29+
is_real(x::Array) = is_real.(x)
3030

3131
# Symbolics does not natively support complex exponentials of variables
3232
import Base: exp

src/HarmonicEquation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ end
134134

135135

136136
get_variables(p::Problem) = get_variables(p.eom)
137+
get_variables(res::Result) = get_variables(res.problem)
137138

138139

139140
"Get the parameters (not time nor variables) of a HarmonicEquation"

src/HarmonicVariable.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import Symbolics: get_variables; export get_variables
2-
import Base.isequal; export isequal
32

43
# pretty-printing
54
display(var::HarmonicVariable) = display(var.name)
@@ -46,7 +45,7 @@ get_variables(vars::Vector{Num}) = unique(flatten([Num.(get_variables(x)) for x
4645

4746
get_variables(var::HarmonicVariable) = Num.(get_variables(var.symbol))
4847

49-
isequal(v1::HarmonicVariable, v2::HarmonicVariable) = isequal(v1.symbol, v2.symbol)
48+
Base.isequal(v1::HarmonicVariable, v2::HarmonicVariable) = isequal(v1.symbol, v2.symbol)
5049

5150

5251

src/modules/LinearResponse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module LinearResponse
99
using ..HC_wrapper
1010
using DocStringExtensions
1111

12-
import Base: *, show; export *, show
12+
import Base: show; export show
1313

1414
include("LinearResponse/types.jl")
1515
include("LinearResponse/utils.jl")

src/modules/LinearResponse/Lorentzian_spectrum.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
Here the methods to find a
33
"""
44
# multiply a peak by a number.
5-
function *(number::Float64, peak::Lorentzian) # multiplication operation
5+
function Base.:*(number::Float64, peak::Lorentzian) # multiplication operation
66
Lorentzian(peak.ω0, peak.Γ, peak.A*number)
77
end
88

99

10-
*(number::Float64, s::JacobianSpectrum) = JacobianSpectrum([number * peak for peak in s.peaks])
10+
Base.:*(number::Float64, s::JacobianSpectrum) = JacobianSpectrum([number * peak for peak in s.peaks])
1111

1212

1313
function show(io::IO, s::JacobianSpectrum)

src/modules/TimeEvolution/ODEProblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868

6969

7070
transform_solutions(soln::OrdinaryDiffEq.ODESolution, f::String, harm_eq::HarmonicEquation) = transform_solutions(soln.u, f, harm_eq)
71-
transform_solutions(s::OrdinaryDiffEq.ODESolution, funcs::Vector{String}, harm_eq::HarmonicEquation) = [transform_solutions(s, f, he) for f in funcs]
71+
transform_solutions(s::OrdinaryDiffEq.ODESolution, funcs::Vector{String}, harm_eq::HarmonicEquation) = [transform_solutions(s, f, harm_eq) for f in funcs]
7272

7373

7474

src/modules/TimeEvolution/types.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import Base: keys, getindex, +
2-
export +, getindex
3-
41
"""
52
63
Represents a sweep of one or more parameters of a `HarmonicEquation`.
@@ -36,12 +33,12 @@ end
3633

3734

3835
# overload so that ParameterSweep can be accessed like a Dict
39-
keys(s::ParameterSweep) = keys(s.functions)
40-
getindex(s::ParameterSweep, i) = getindex(s.functions, i)
36+
Base.keys(s::ParameterSweep) = keys(s.functions)
37+
Base.getindex(s::ParameterSweep, i) = getindex(s.functions, i)
4138

4239

4340
# overload +
44-
function +(s1::ParameterSweep, s2::ParameterSweep)
41+
function Base.:+(s1::ParameterSweep, s2::ParameterSweep)
4542
common_params = intersect(keys(s1), keys(s2))
4643
!isempty(common_params) && error("cannot combine sweeps of the same parameter")
4744
return ParameterSweep(merge(s1.functions, s2.functions))

src/plotting_Plots.jl

Lines changed: 51 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,15 @@ end
7474
"""
7575
$(TYPEDSIGNATURES)
7676
77-
Return an array of bools to mark solutions in `res` which fall into `classes` but not `not_classes`
77+
Return an array of bools to mark solutions in `res` which fall into `classes` but not `not_classes`.
78+
Only `branches` are considered.
7879
"""
79-
function _get_mask(res, classes, not_classes=[])
80-
classes == "all" && return fill(trues(length(res.solutions[1])), size(res.solutions))
81-
bools = vcat([res.classes[c] for c in _vectorise(classes)], [map(.!, res.classes[c]) for c in _vectorise(not_classes)])
82-
map(.*, bools...)
80+
function _get_mask(res, classes, not_classes=[]; branches=1:branch_count(res))
81+
classes == "all" && return fill(trues(length(branches)), size(res.solutions))
82+
bools = vcat([res.classes[c] for c in _str_to_vec(classes)], [map(.!, res.classes[c]) for c in _str_to_vec(not_classes)])
83+
#m = map( x -> [getindex(x, b) for b in [branches...]], map(.*, bools...))
84+
85+
m = map( x -> x[[branches...]], map(.*, bools...))
8386
end
8487

8588

@@ -96,85 +99,62 @@ function _apply_mask(solns::Array{Vector{ComplexF64}}, booleans)
9699
end
97100

98101

99-
# convert x to Float64, raising a warning if complex
100-
function _realify(x, warn=true)
101-
!is_real(x) && !isnan(x) && warn ? (@warn "Values with non-negligible complex parts have been projected on the real axis!", x) : nothing
102-
real(x)
103-
end
102+
""" Project the array `a` into the real axis, warning if its contents are complex. """
103+
function _realify(a::Array{T} where T <: Number; warning="")
104104

105-
function _realify(v::Vector, warn=true)
106-
v_real = Vector{Vector{Float64}}(undef, length(v[1]))
107-
for branch in eachindex(v_real)
108-
values = getindex.(v, branch)
109-
values_real = Vector{Float64}(undef, length(values))
110-
for (j,x) in pairs(values)
111-
if !is_real(x) && !isnan(x) && warn
112-
(@warn "Values with non-negligible complex parts have been projected on the real axis!", x)
113-
warn = false
114-
end
115-
values_real[j] = real(x)
105+
warned = false
106+
a_real = similar(a, Float64)
107+
for i in eachindex(a)
108+
if !isnan(a[i]) && !warned && !is_real(a[i])
109+
@warn "Values with non-negligible complex parts have
110+
been projected on the real axis! " * warning
111+
warned = true
116112
end
117-
v_real[branch] = values_real
118-
end
119-
return v_real
120-
end
121-
122-
function _realify(a::Matrix, warn=true)
123-
a_real = Array{Vector{Vector{Float64}}}(undef, size(a)...)
124-
for idx in CartesianIndices(a)
125-
v = a[idx]
126-
v_real = Vector{Vector{Float64}}(undef, length(v[1]))
127-
for branch in eachindex(v_real)
128-
values = getindex.(v, branch)
129-
values_real = Vector{Float64}(undef, length(values))
130-
for (j,x) in pairs(values)
131-
if !is_real(x) && !isnan(x) && warn
132-
(@warn "Values with non-negligible complex parts have been projected on the real axis!", x)
133-
warn = false
134-
end
135-
values_real[j] = real(x)
136-
end
137-
v_real[branch] = values_real
138-
end
139-
a_real[idx] = v_real
113+
a_real[i] = real(a[i])
140114
end
141115
return a_real
142116
end
143117

144-
_vectorise(s::Vector) = s
145-
_vectorise(s) = [s]
118+
119+
_str_to_vec(s::Vector) = s
120+
_str_to_vec(s) = [s]
146121

147122

148123
# return true if p already has a label for branch index idx
149124
_is_labeled(p::Plots.Plot, idx::Int64) = in(string(idx), [sub[:label] for sub in p.series_list])
150125

151126

152-
function plot1D(res::Result; x::String="default", y::String, class="default", not_class=[], add=false, kwargs...)
127+
function plot1D(res::Result; x::String="default", y::String, class="default", not_class=[], branches=1:branch_count(res), add=false, kwargs...)
153128

154129
if class == "default"
130+
args = [:x => x, :y => y, :branches => branches]
155131
if not_class == [] # plot stable full, unstable dashed
156-
p = plot1D(res; x=x, y=y, class=["physical", "stable"], add=add, kwargs...)
157-
plot1D(res; x=x, y=y, class="physical", not_class="stable", add=true, style=:dash, kwargs...)
132+
p = plot1D(res; args..., class=["physical", "stable"], add=add, kwargs...)
133+
plot1D(res; args..., class="physical", not_class="stable", add=true, style=:dash, kwargs...)
158134
return p
159135
else
160-
p = plot1D(res; x=x, y=y, not_class=not_class, class="physical", add=add, kwargs...)
136+
p = plot1D(res; args..., not_class=not_class, class="physical", add=add, kwargs...)
161137
return p
162138
end
163139
end
164140

165141
dim(res) != 1 && error("1D plots of not-1D datasets are usually a bad idea.")
166142
x = x == "default" ? string(first(keys(res.swept_parameters))) : x
167-
X, Y = transform_solutions(res, [x,y]) # first transform, then filter
168-
Y = _apply_mask(Y, _get_mask(res, class, not_class))
169-
branches = _realify(Y)
143+
X = transform_solutions(res, x, branches=branches)
144+
Y = transform_solutions(res, y, branches=branches)
145+
Y = _apply_mask(Y, _get_mask(res, class, not_class, branches=branches))
146+
147+
# reformat and project onto real, warning if needed
148+
branch_data = [_realify( getindex.(Y, i), warning= "branch " * string(k) ) for (i,k) in enumerate(branches)]
170149

171150
# start a new plot if needed
172151
p = add ? Plots.plot!() : Plots.plot()
173152

174153
# colouring is matched to branch index - matched across plots
175-
for k in findall(x -> !all(isnan.(x)), branches[1:end]) # skip NaN branches but keep indices
176-
l = _is_labeled(p, k) ? nothing : k
177-
Plots.plot!(_realify.(getindex.(X, k)), branches[k]; color=k, label=l, xlabel=latexify(x), ylabel=latexify(y), kwargs...)
154+
for k in findall(x -> !all(isnan.(x)), branch_data) # skip NaN branch_data
155+
global_index = branches[k]
156+
lab = _is_labeled(p, global_index) ? nothing : global_index
157+
Plots.plot!(_realify(getindex.(X, k)), branch_data[k]; color=k, label=lab, xlabel=latexify(x), ylabel=latexify(y), kwargs...)
178158
end
179159

180160
return p
@@ -185,8 +165,7 @@ plot1D(res::Result, y::String; kwargs...) = plot1D(res; y=y, kwargs...)
185165

186166
function plot2D(res::Result; z::String, branch::Int64, class="physical", not_class=[], add=false, kwargs...)
187167
X, Y = values(res.swept_parameters)
188-
Z = getindex.(_apply_mask(transform_solutions(res, z), _get_mask(res, class, not_class)), branch) # first transform, then filter
189-
168+
Z = getindex.(_apply_mask(transform_solutions(res, z, branches=branch), _get_mask(res, class, not_class, branches=branch)), 1) # there is only one branch
190169
p = add ? Plots.plot!() : Plots.plot() # start a new plot if needed
191170

192171
ylab, xlab = latexify.(string.(keys(res.swept_parameters)))
@@ -219,15 +198,17 @@ function plot2D_cut(res::Result; y::String, cut::Pair, class="default", not_clas
219198

220199
X = res.swept_parameters[x]
221200
Y =_apply_mask(transform_solutions(res, y), _get_mask(res, class, not_class)) # first transform, then filter
222-
branches = _realify(x_index==1 ? Y[:, cut_par_index] : Y[cut_par_index, :])
201+
branches = x_index==1 ? Y[:, cut_par_index] : Y[cut_par_index, :]
202+
203+
branch_data = [_realify( getindex.(branches, i), warning= "branch " * string(k) ) for (i,k) in enumerate(1:branch_count(res))]
223204

224205
# start a new plot if needed
225206
p = add ? Plots.plot!() : Plots.plot()
226207

227208
# colouring is matched to branch index - matched across plots
228-
for k in findall(branch -> !all(isnan.(branch)), branches[1:end]) # skip NaN branches but keep indices
209+
for k in findall(branch -> !all(isnan.(branch)), branch_data) # skip NaN branches but keep indices
229210
l = _is_labeled(p, k) ? nothing : k
230-
Plots.plot!(X, branches[k]; color=k, label=l, xlabel=latexify(string(x)), ylabel=latexify(y), kwargs...)
211+
Plots.plot!(X, _realify(getindex.(branches, k)); color=k, label=l, xlabel=latexify(string(x)), ylabel=latexify(y), kwargs...)
231212
end
232213

233214
return p
@@ -302,6 +283,9 @@ Class selection done by passing `String` or `Vector{String}` as kwarg:
302283
Other kwargs are passed onto Plots.gr()
303284
"""
304285
function plot_spaghetti(res::Result; x::String, y::String, z::String, class="default", not_class=[], add=false, kwargs...)::Plots.Plot
286+
if dim(res) == 2
287+
error("Data dimension ", dim(res), " not supported")
288+
end
305289

306290
if class == "default"
307291
if not_class == [] # plot stable full, unstable dashed
@@ -325,16 +309,18 @@ function plot_spaghetti(res::Result; x::String, y::String, z::String, class="def
325309
isnothing(z_index) && error("The variable $z was not swept over.")
326310

327311
Z = res.swept_parameters.vals[z_index]
328-
X = _apply_mask(transform_solutions(res, x), _get_mask(res, class, not_class)) |> _realify
329-
Y = _apply_mask(transform_solutions(res, y), _get_mask(res, class, not_class)) |> _realify
312+
X = _apply_mask(transform_solutions(res, x), _get_mask(res, class, not_class))
313+
Y = _apply_mask(transform_solutions(res, y), _get_mask(res, class, not_class))
330314

331315
# start a new plot if needed
332316
p = add ? Plots.plot!() : Plots.plot()
333317

318+
branch_data = [_realify( getindex.(X, i), warning= "branch " * string(k) ) for (i,k) in enumerate(1:branch_count(res))]
319+
334320
# colouring is matched to branch index - matched across plots
335-
for k in findall(x -> !all(isnan.(x)), X[1:end]) # skip NaN branches but keep indices
321+
for k in findall(x -> !all(isnan.(x)), branch_data) # skip NaN branches but keep indices
336322
l = _is_labeled(p, k) ? nothing : k
337-
Plots.plot!(X[k], Y[k], Z; _set_Plots_default...,
323+
Plots.plot!(_realify(getindex.(X, k)), _realify(getindex.(Y, k)), Z; _set_Plots_default...,
338324
color=k, label=l, xlabel=latexify(x), ylabel=latexify(y), zlabel=latexify(z), xlim=:symmetric, ylim=:symmetric, kwargs...)
339325
end
340326
return p

src/sorting.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,39 +66,38 @@ end
6666

6767
"Match each solution from to_sort to a closest partner from refs"
6868
function get_distance_matrix(refs::Vector{Vector{SteadyState}}, to_sort::Vector{SteadyState})
69-
distances = [get_distance_matrix(ref, to_sort) for ref in refs]
69+
distances = map( ref -> get_distance_matrix(ref, to_sort), refs)
7070
lowest_distances = similar(distances[1])
71-
for (i, el) in enumerate(lowest_distances)
72-
lowest_distances[i] = min([d[i] for d in distances]...)
71+
for idx in CartesianIndices(lowest_distances)
72+
lowest_distances[idx] = minimum( x[idx] for x in distances )
7373
end
7474
lowest_distances
7575
end
7676

7777

7878

7979
"""
80-
Match a to_sort vector of solutions to a set of reference vector of solutions.
80+
Match a to_sort vector of solutions to a set of reference vectors of solutions.
8181
Returns a list of Tuples of the form (1, i1), (2, i2), ... such that
8282
reference[1] and to_sort[i1] belong to the same branch
8383
"""
84-
function align_pair(reference::Vector{Vector{SteadyState}}, to_sort::Vector{SteadyState})
85-
84+
function align_pair(reference, to_sort::Vector{SteadyState})
85+
8686
distances = get_distance_matrix(reference, to_sort)
8787
n = length(to_sort)
88-
cartesians = [(j,i) for i in 1:n for j in 1:n]
89-
sorted_cartesians = cartesians[sortperm(vec(distances))]
88+
sorted_cartesians = CartesianIndices(distances)[sortperm(vec(distances))]
9089

9190
matched = falses(n)
9291
matched_ref = falses(n)
9392

94-
sorted = Array{Tuple{Int64, Int64}, 1}(undef, n)
93+
sorted = Vector{CartesianIndex}(undef, n)
9594

96-
for i in 1:length(cartesians)
97-
j,k = sorted_cartesians[i]
95+
for idx in sorted_cartesians
96+
j,k = idx[1], idx[2]
9897
if !matched[k] && !matched_ref[j]
9998
matched[k] = true
10099
matched_ref[j] = true
101-
sorted[j] = (j,k)
100+
sorted[j] = idx
102101
end
103102
end
104103

@@ -107,9 +106,6 @@ function align_pair(reference::Vector{Vector{SteadyState}}, to_sort::Vector{Stea
107106
end
108107

109108

110-
align_pair(ref::Vector{SteadyState}, to_sort::Vector{SteadyState}) = align_pair([ref], to_sort)
111-
112-
113109
"""
114110
Go through a vector of solution and sort each according to Euclidean norm.
115111
"""

0 commit comments

Comments
 (0)