diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index def74d431..2913d2cd8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,13 +77,7 @@ jobs: JULIA_PKG_PRECOMPILE_AUTO: 0 run: julia --color=yes ci/matplotlib.jl - - name: Test downstream packages - if: startsWith(matrix.os, 'ubuntu') - run: | - xvfb-run julia --color=yes ci/downstream.jl GraphRecipes - xvfb-run julia --color=yes ci/downstream.jl StatsPlots - - - name: Test RecipesBase, RecipesPipeline, PlotsBase, Plots + - name: Test all Plots packages timeout-minutes: 60 run: | cmd=(julia --color=yes) @@ -93,6 +87,7 @@ jobs: echo ${cmd[@]} ${cmd[@]} -e ' using Pkg + foreach(name -> Pkg.test(name; coverage=true), ("GraphRecipes", "StatsPlots")) foreach(name -> Pkg.test(name; coverage=true), ("RecipesBase", "RecipesPipeline", "PlotsBase", "Plots")) ' diff --git a/GraphRecipes/Project.toml b/GraphRecipes/Project.toml new file mode 100644 index 000000000..4c3ddc4ad --- /dev/null +++ b/GraphRecipes/Project.toml @@ -0,0 +1,45 @@ +name = "GraphRecipes" +uuid = "bd48cda9-67a9-57be-86fa-5b3c104eda73" +version = "1.0" + +[deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +GeometryTypes = "4d00f742-c7ba-57c2-abde-4428a4b178cb" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" +PlotUtils = "995b91a9-d308-5afd-9ec6-746e21dbc043" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[compat] +AbstractTrees = "0.4" +GeometryTypes = "0.8" +Graphs = "1.7" +Interpolations = "0.13 - 0.15" +NaNMath = "1" +NetworkLayout = "0.4" +PlotUtils = "0.6.2, 1" +RecipesBase = "1" +Statistics = "1" +julia = 
"1.10" + +[extras] +Gtk = "4c0ca9eb-093a-5379-98c5-f87ac0bbbf44" +ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92" + +[targets] +test = ["Gtk", "ImageMagick", "LinearAlgebra", "Logging", "Markdown", "Plots", "Random", "SparseArrays", "StableRNGs", "Test", "VisualRegressionTests"] diff --git a/GraphRecipes/src/GraphRecipes.jl b/GraphRecipes/src/GraphRecipes.jl new file mode 100644 index 000000000..9147c81cc --- /dev/null +++ b/GraphRecipes/src/GraphRecipes.jl @@ -0,0 +1,24 @@ +module GraphRecipes + +using Graphs +using PlotUtils # ColorGradient +using RecipesBase + +using InteractiveUtils # subtypes +using LinearAlgebra +using SparseArrays +using Statistics +using NaNMath +using GeometryTypes +using Interpolations + +import NetworkLayout +import Graphs: rng_from_rng_or_seed + +include("utils.jl") +include("graph_layouts.jl") +include("graphs.jl") +include("misc.jl") +include("trees.jl") + +end diff --git a/GraphRecipes/src/graph_layouts.jl b/GraphRecipes/src/graph_layouts.jl new file mode 100644 index 000000000..d535cbbc0 --- /dev/null +++ b/GraphRecipes/src/graph_layouts.jl @@ -0,0 +1,496 @@ + +# ----------------------------------------------------- +infer_size_from(args...) 
# see: http://www.research.att.com/export/sites/att_labs/groups/infovis/res/legacy_papers/DBLP-journals-camwa-Koren05.pdf
# also: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.3.2055&rep=rep1&type=pdf

"""
    spectral_graph(adjmat; node_weights = ones(size(adjmat, 1)), kw...)
    spectral_graph(source, destiny, weights; kw...)

Compute node coordinates with `NetworkLayout.spectral`.

`node_weights` is converted to a `Vector{Float64}` and forwarded as `nodeweights`; every
other keyword is accepted and ignored. The edge-list method first builds an adjacency
matrix with `get_adjacency_matrix` and then forwards all keywords to the matrix method.

Returns a tuple `(x, y, z)` holding the first three components of every position.
"""
function spectral_graph(
    adjmat::AbstractMatrix;
    node_weights::AbstractVector = ones(size(adjmat, 1)),
    kw...,
)
    nodeweights = convert(Vector{Float64}, node_weights)
    layout = NetworkLayout.spectral(adjmat; nodeweights = nodeweights)

    coordinate(k) = [p[k] for p in layout]
    return (coordinate(1), coordinate(2), coordinate(3))
end

function spectral_graph(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    weights::AbstractVector;
    kw...,
)
    adjmat = get_adjacency_matrix(source, destiny, weights)
    return spectral_graph(adjmat; kw...)
end

"""
    spring_graph(adjmat; dim = 2, rng = nothing, x, y, z, maxiter = 100,
                 initialtemp = 2.0, C = 2.0, kw...)

Compute a 2-D or 3-D layout with `NetworkLayout.spring`.

The adjacency matrix is first passed through `make_symmetric`. `x`, `y` (and `z` when
`dim == 3`) are the initial coordinates; by default they are drawn with `rand` from
`rng_from_rng_or_seed(rng, nothing)`. `maxiter` is forwarded as `iterations`; unknown
keywords are ignored.

Returns `(x, y, nothing)` for `dim == 2` and `(x, y, z)` for `dim == 3`.
"""
function spring_graph(
    adjmat::AbstractMatrix;
    dim = 2,
    rng = nothing,
    x = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    y = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    z = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    maxiter = 100,
    initialtemp = 2.0,
    C = 2.0,
    kw...,
)
    @assert dim == 2 || dim == 3
    symmetric = make_symmetric(adjmat)

    # Starting points, one per entry of `x`, always in Float64.
    node_ids = 1:length(x)
    initialpos = if dim == 2
        [Point(Float64(x[i]), Float64(y[i])) for i in node_ids]
    else
        [Point(Float64(x[i]), Float64(y[i]), Float64(z[i])) for i in node_ids]
    end

    layout = NetworkLayout.spring(
        symmetric;
        dim = dim,
        Ptype = Float64,
        iterations = maxiter,
        initialtemp = initialtemp,
        C = C,
        initialpos = initialpos,
    )

    coordinate(k) = [p[k] for p in layout]
    xs, ys = coordinate(1), coordinate(2)
    zs = dim == 2 ? nothing : coordinate(3)
    return (xs, ys, zs)
end
"""
    sfdp_graph(adjmat; dim = 2, rng = nothing, x, y, z, maxiter = 100, tol = 1e-10,
               C = 1.0, K = 1.0, kw...)
    sfdp_graph(source, destiny, weights; kw...)

Compute a 2-D or 3-D layout with `NetworkLayout.sfdp`.

The adjacency matrix is first passed through `make_symmetric`. `x`, `y` (and `z` when
`dim == 3`) are the initial coordinates; by default they are drawn with `rand` from
`rng_from_rng_or_seed(rng, nothing)`. `maxiter` is forwarded as `iterations`; unknown
keywords are ignored. The edge-list method builds the adjacency matrix with
`get_adjacency_matrix` and forwards every keyword to the matrix method.

Returns `(x, y, nothing)` for `dim == 2` and `(x, y, z)` for `dim == 3`.
"""
function sfdp_graph(
    adjmat::AbstractMatrix;
    dim = 2,
    rng = nothing,
    x = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    y = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    z = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    maxiter = 100,
    tol = 1e-10,
    C = 1.0,
    K = 1.0,
    kw...,
)
    @assert dim == 2 || dim == 3
    adjmat = make_symmetric(adjmat)
    T = Float64
    startpositions = if dim == 2
        [Point(T(x[i]), T(y[i])) for i in 1:length(x)]
    else
        [Point(T(x[i]), T(y[i]), T(z[i])) for i in 1:length(x)]
    end

    positions = NetworkLayout.sfdp(
        adjmat;
        dim,
        Ptype = T,
        iterations = maxiter,
        tol = tol,
        C = C,
        K = K,
        initialpos = startpositions,
    )
    if dim == 2
        return ([p[1] for p in positions], [p[2] for p in positions], nothing)
    else
        return (
            [p[1] for p in positions],
            [p[2] for p in positions],
            [p[3] for p in positions],
        )
    end
end

function sfdp_graph(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    weights::AbstractVector;
    kw...,
)
    # The original called the misspelled `sfpd_graph`, which is not defined anywhere, so
    # this method threw an `UndefVarError` on every call.
    return sfdp_graph(get_adjacency_matrix(source, destiny, weights); kw...)
end

"""
    circular_graph(args...; kwargs...)

Alias for `shell_graph`; all arguments and keywords are forwarded unchanged.
"""
circular_graph(args...; kwargs...) = shell_graph(args...; kwargs...)

"""
    shell_graph(adjmat; dim = 2, nlist = Vector{Int}[], kw...)

Compute a 2-D layout with `NetworkLayout.shell`, forwarding `nlist`.

Only `dim == 2` is accepted. The `rng`, `x`, `y` and `z` keywords are accepted so the
signature matches the other layouts, but the body does not read them (their defaults are
still evaluated). Returns `(x, y, nothing)`.
"""
function shell_graph(
    adjmat::AbstractMatrix;
    dim = 2,
    rng = nothing,
    x = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    y = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    z = rand(rng_from_rng_or_seed(rng, nothing), size(adjmat)[1]),
    nlist = Vector{Int}[],
    kw...,
)
    @assert dim == 2
    positions = NetworkLayout.shell(adjmat; nlist = nlist)

    return ([p[1] for p in positions], [p[2] for p in positions], nothing)
end
+end + +# ----------------------------------------------------- + +# Axis-by-Axis Stress Minimization -- Yehuda Koren and David Harel +# See: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.437.3177&rep=rep1&type=pdf + +# # NOTES: +# # - dᵢⱼ = the "graph-theoretical distance between nodes i and j" +# # = Aᵢⱼ +# # - kᵢⱼ = dᵢⱼ⁻² +# # - b̃ᵢ = ∑ᵢ≠ⱼ ((x̃ⱼ ≤ x̃ᵢ ? 1 : -1) / dᵢⱼ) +# # - need to solve for x each iteration: Lx = b̃ + +# # Solve for one axis at a time while holding the others constant. +# # dims is 2 (2D) or 3 (3D). free_dims is a vector of the dimensions to update (for example if you fix y and solve for x) +# function by_axis_stress_graph(adjmat::AbstractMatrix, node_weights::AbstractVector = ones(size(adjmat,1)); +# dims = 2, free_dims = 1:dims, +# rng = nothing, +# x = rand(rng_from_rng_or_seed(rng, nothing), length(node_weights)), +# y = rand(rng_from_rng_or_seed(rng, nothing), length(node_weights)), +# z = rand(rng_from_rng_or_seed(rng, nothing), length(node_weights))) +# adjmat = make_symmetric(adjmat) +# L, D = compute_laplacian(adjmat, node_weights) + +# n = length(node_weights) +# maxiter = 100 # TODO: something else + +# @assert dims == 2 + +# @show adjmat L + +# for _ in 1:maxiter +# x̃ = x +# b̃ = Float64[sum(Float64[(i==j || adjmat[i,j] == 0) ? 0.0 : ((x̃[j] <= x̃[i] ? 1.0 : -1.0) / adjmat[i,j]) for j=1:n]) for i=1:n] +# @show x̃ b̃ +# x = L \ b̃ + +# xdiff = x - x̃ +# @show norm(xdiff) +# if norm(xdiff) < 1e-4 +# info("converged. 
"""
    norm_ij(X, i, j)

Euclidean distance between nodes `i` and `j`.

`X` is a collection of per-axis coordinate vectors (e.g. `(x, y)` or `(x, y, z)`), so the
distance is accumulated over the axes. The sum is done in `Float64` without allocating a
temporary array, because this is called in the innermost loop of the stress layout.
"""
norm_ij(X, i, j) = sqrt(sum(v -> Float64((v[i] - v[j])^2), X; init = 0.0))

"""
    stress(X, dist, w, i, j)
    stress(X, dist, w)

Weighted stress of a layout.

The five-argument method returns the contribution of the node pair `(i, j)`:
`w[i, j] * (norm_ij(X, i, j) - dist[i, j])^2`. The three-argument method sums that
contribution over every unordered pair of nodes.
"""
stress(X, dist, w, i, j) = w[i, j] * (norm_ij(X, i, j) - dist[i, j])^2

function stress(X, dist, w)
    isempty(X) && return 0.0
    # `X` holds one coordinate vector per axis, so the number of nodes is the length of
    # an axis vector. The original used `size(X, 1)`, which for a tuple is the number of
    # axes, so only the first 2 (or 3) nodes ever contributed to the total.
    n = length(first(X))
    tot = 0.0
    for i in 1:n, j in 1:(i - 1)
        tot += stress(X, dist, w, i, j)
    end
    return tot
end
"""
    by_axis_local_stress_graph(source, destiny, weights; kw...)

Edge-list entry point: builds the adjacency matrix with `get_adjacency_matrix` and
forwards it, together with every keyword, to the matrix method.
"""
function by_axis_local_stress_graph(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    weights::AbstractVector;
    kw...,
)
    adjmat = get_adjacency_matrix(source, destiny, weights)
    return by_axis_local_stress_graph(adjmat; kw...)
end

# -----------------------------------------------------

"""
    buchheim_graph(adjlist; node_weights = ones(length(adjlist)), kw...)

Compute a 2-D tree layout with `NetworkLayout.buchheim`.

`node_weights` is converted to a `Vector{Float64}` and forwarded as `nodesize`. The
`root`, `layers_scalar`, `layers` and `dim` keywords are accepted but not read by the
body. Returns `(x, y, nothing)` with `Float64` coordinates.
"""
function buchheim_graph(
    adjlist::AbstractVector;
    node_weights::AbstractVector = ones(length(adjlist)),
    root::Symbol = :top, # flow of tree: left, right, top, bottom
    layers_scalar = 1.0,
    layers = nothing,
    dim = 2,
    kw...,
)
    nodesize = convert(Vector{Float64}, node_weights)
    layout = NetworkLayout.buchheim(adjlist; nodesize = nodesize)
    xs = Float64[p[1] for p in layout]
    ys = Float64[p[2] for p in layout]
    return xs, ys, nothing
end

# -----------------------------------------------------

"""
    tree_graph(adjmat; kw...)

Matrix entry point: converts `adjmat` to an edge list with `get_source_destiny_weight`
and forwards it, together with every keyword, to the edge-list method.
"""
function tree_graph(adjmat::AbstractMatrix; kw...)
    edges = get_source_destiny_weight(adjmat)
    return tree_graph(edges...; kw...)
end
"""
    adjlist_and_degrees(source, destiny, n)

Build the child list and the degree counts of a directed edge list over `n` nodes.

Returns `(alist, indeg, outdeg)`: `alist[i]` lists the destinations of the edges that
start at `i` (in edge order, duplicates kept), `indeg[i]` counts the edges ending at `i`,
and `outdeg[i]` counts the edges starting at `i`.
"""
function adjlist_and_degrees(source, destiny, n)
    # build a list of children (adjacency list)
    alist = Vector{Int}[Int[] for _ in 1:n]
    indeg, outdeg = zeros(Int, n), zeros(Int, n)
    for (si, di) in zip(source, destiny)
        push!(alist[si], di)
        indeg[di] += 1
        outdeg[si] += 1
    end
    return alist, indeg, outdeg
end

"""
    compute_tree_layers(source, destiny, n)

Assign a layer to each of the `n` nodes so that children end up below their parents.

Nodes are processed in decreasing order of `outdeg - 50 * indeg`, i.e. nodes with many
outgoing and few incoming edges come first. Each node is placed at least one layer below
every already-placed parent and then pushes its not-yet-placed children at least one
layer below itself. Returns a `Vector{Float64}` of layer numbers (roots are `0.0`).
"""
function compute_tree_layers(source, destiny, n)
    alist, indeg, outdeg = adjlist_and_degrees(source, destiny, n)

    # choose root to be the node with lots going out, but few coming in
    netdeg = outdeg - 50indeg
    idxs = sortperm(netdeg, rev = true)
    placed = Int[]

    layers = zeros(n)
    for _ in 1:n
        # `shift!` was removed in Julia 1.0; `popfirst!` is its replacement.
        idx = popfirst!(idxs)

        # first, place this after its parents
        for j in placed
            if idx in alist[j]
                layers[idx] = max(layers[idx], layers[j] + 1)
            end
        end

        # next, shift its children lower
        for j in idxs
            if j in alist[idx]
                layers[j] = max(layers[j], layers[idx] + 1)
            end
        end

        push!(placed, idx)
    end
    return layers
end
"""
    compute_tree_layers2(source, destiny, n)

Alternative layer assignment for trees.

Every node without incoming edges is treated as a root (node `1` is used when there is
none). Starting from each root, `shift_children!` pushes children below their parents.
Afterwards parents are repeatedly moved down until each one sits at most one layer above
its closest child. Returns a `Vector{Int}` of layer numbers.
"""
function compute_tree_layers2(source, destiny, n)
    children, indegree, _ = adjlist_and_degrees(source, destiny, n)

    roots = [v for v in 1:n if indegree[v] == 0]
    isempty(roots) && (roots = [1])

    depth = zeros(Int, n)
    for root in roots
        # every root starts with a fresh "already visited" list
        shift_children!(depth, children, Int[], root)
    end

    # now that children have been pushed down, pull parents towards their closest child;
    # repeat until a full sweep changes nothing
    changed = true
    while changed
        changed = false
        for v in 1:n
            kids = children[v]
            isempty(kids) && continue
            closest = minimum(depth[k] for k in kids)
            if depth[v] < closest - 1
                depth[v] = closest - 1
                changed = true
            end
        end
    end

    return depth
end

"""
    shift_children!(layers, alist, placed, parent)

Recursively push the children of `parent` below it.

First every child not yet in `placed` is moved to at least `layers[parent] + 1`. Then each
such child (other than `parent` itself) is recorded in `placed` and visited in turn.
`layers` and `placed` are modified in place; nothing is returned.
"""
function shift_children!(layers, alist, placed, parent)
    kids = alist[parent]

    for kid in kids
        kid in placed && continue
        # `layers[parent]` is re-read on purpose: a self-loop may have just changed it
        if layers[kid] <= layers[parent]
            layers[kid] = layers[parent] + 1
        end
    end

    for kid in kids
        (kid == parent || kid in placed) && continue
        push!(placed, kid)
        shift_children!(layers, alist, placed, kid)
    end

    return nothing
end
# -----------------------------------------------------

"""
    arc_diagram(source, destiny, weights; kw...)

Place the nodes of an arc diagram on a line.

The number of nodes `N` is obtained from `infer_size_from(source, destiny)`. Returns
`(x, y, z)` where `x` is `1, 2, …, N` and `y`, `z` are zeros. `weights` and all keywords
are ignored.
"""
function arc_diagram(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    weights::AbstractVector;
    kw...,
)
    N = infer_size_from(source, destiny)
    X = collect(1:N)
    O = zero(X)
    # NOTE: `y` and `z` are the same array object, exactly as in the original.
    return X, O, O
end

# -----------------------------------------------------

"""
    chord_diagram(source, destiny, weights; kw...)

Place the nodes of a chord diagram evenly on the unit circle.

The number of nodes `N` is obtained from `infer_size_from(source, destiny)`. Node `i` is
put at angle `(i - 1) * 2π / N`, with `x = sin(angle)` and `y = cos(angle)`. Returns
`(x, y, z)` where `z` is all zeros. `weights` and all keywords are ignored.
"""
function chord_diagram(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    weights::AbstractVector;
    kw...,
)
    N = infer_size_from(source, destiny)
    δ = 2pi / N

    # The original also built `nodes = collect(1:N)` and never used it.
    angles = [(i - 1) * δ for i in 1:N]
    x = Float64[sin(v) for v in angles]
    y = Float64[cos(v) for v in angles]

    return x, y, zero(x)
end
# -----------------------------------------------------

"""
    get_source_destiny_weight(mat::AbstractMatrix)
    get_source_destiny_weight(source, destiny)
    get_source_destiny_weight(source, destiny, weights)
    get_source_destiny_weight(adjlist)

Normalise a graph description into an edge list `(source, destiny, weights)`.

- Matrix: rows are sources, columns are destinies, and the matrix must be square.
  Diagonal entries are never emitted. For a symmetric matrix only entries above the
  diagonal are emitted; otherwise entries on both sides are. `NaN` entries are skipped,
  and for sparse matrices zero entries are skipped as well.
- Two vectors: every edge gets weight `1.0`.
- Three vectors: returned unchanged after a length check.
- Adjacency list: `adjlist[i]` lists the destinies of node `i`; every edge gets weight
  `1.0`.

Throws an `ArgumentError` when the supplied vectors differ in length.
"""
function get_source_destiny_weight(mat::AbstractArray{T,2}) where {T}
    nrow, ncol = size(mat) # rows are sources and columns are destinies
    @assert nrow == ncol

    both_triangles = !issymmetric(mat) # symmetric matrices contribute triu only
    drop_zeros = issparse(mat)         # zeros of a sparse matrix are not edges

    source, destiny, weights = Int[], Int[], T[]
    for i in 1:nrow, j in 1:ncol
        value = mat[i, j]
        # TODO: deal with Nullable
        (isnan(value) || (drop_zeros && value == zero(T))) && continue
        if i < j || (both_triangles && i > j)
            push!(source, i)
            push!(destiny, j)
            push!(weights, value)
        end
    end
    return source, destiny, weights
end

function get_source_destiny_weight(source::AbstractVector, destiny::AbstractVector)
    length(source) == length(destiny) ||
        throw(ArgumentError("Source and destiny must have the same length."))
    return source, destiny, ones(length(source))
end

function get_source_destiny_weight(
    source::AbstractVector,
    destiny::AbstractVector,
    weights::AbstractVector,
)
    length(source) == length(destiny) == length(weights) ||
        throw(ArgumentError("Source, destiny and weights must have the same length."))
    return source, destiny, weights
end

function get_source_destiny_weight(
    adjlist::AbstractVector{V},
) where {V<:AbstractVector{T}} where {T<:Any}
    source = Int[]
    destiny = Int[]
    for (node, neighbours) in enumerate(adjlist)
        for neighbour in neighbours
            push!(source, node)
            push!(destiny, neighbour)
        end
    end
    return get_source_destiny_weight(source, destiny)
end

# -----------------------------------------------------

# An adjacency matrix is already in the requested form.
get_adjacency_matrix(mat::AbstractMatrix) = mat
# Adjacency list -> adjacency matrix, going through the edge-list representation.
get_adjacency_matrix(
    adjlist::AbstractVector{V},
) where {V<:AbstractVector{T}} where {T<:Any} =
    get_adjacency_matrix(get_source_destiny_weight(adjlist)...)

# -----------------------------------------------------

"""
    get_adjacency_list(mat::AbstractMatrix)
    get_adjacency_list(source, destiny, weights)
    get_adjacency_list(adjlist)

Return the graph as an adjacency list, where entry `i` lists the destinies of node `i`.

The matrix method converts to an edge list with `get_source_destiny_weight` first. The
edge-list method sizes the list with `infer_size_from(source, destiny)` and ignores
`weights`. An adjacency list of `Int` vectors is returned unchanged.
"""
get_adjacency_list(mat::AbstractMatrix) =
    # The edge list comes back as a 3-tuple and must be splatted; passing the tuple
    # itself (as the original did) matches no method and throws a `MethodError`.
    get_adjacency_list(get_source_destiny_weight(mat)...)

function get_adjacency_list(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    weights::AbstractVector,
)
    n = infer_size_from(source, destiny)
    adjlist = [Int[] for _ in 1:n]
    for (s, d) in zip(source, destiny)
        push!(adjlist[s], d)
    end
    return adjlist
end

get_adjacency_list(adjlist::AbstractVector{<:AbstractVector{Int}}) = adjlist

# -----------------------------------------------------

"""
    make_symmetric(A)

Return a symmetric copy of `A` in which both `[i, j]` and `[j, i]` hold
`A[i, j] + A[j, i]`. The diagonal is left untouched and `A` itself is not modified.
"""
function make_symmetric(A::AbstractMatrix)
    A = copy(A)
    for i in 1:size(A, 1), j in (i + 1):size(A, 2)
        A[i, j] = A[j, i] = A[i, j] + A[j, i]
    end
    return A
end

"""
    compute_laplacian(adjmat, node_weights)

Return `(L, D)`: the graph Laplacian and the diagonal degree matrix of a weighted graph.

Each edge weight `adjmat[i, j]` is first scaled by `sqrt(node_weights[i] *
node_weights[j])`, so that heavier nodes form stronger connections. The degree of a node
is its column sum minus its diagonal entry, `D` has the degrees on its diagonal, and `L`
has the degrees on the diagonal and the negated scaled weights elsewhere.
"""
function compute_laplacian(adjmat::AbstractMatrix, node_weights::AbstractVector)
    n, m = size(adjmat)
    @assert n == m == length(node_weights)

    # Element-wise scaling. The original called `sqrt` on the whole outer-product matrix,
    # which is the matrix square root, not the per-edge factor described above.
    adjmat = adjmat .* sqrt.(node_weights .* node_weights')

    # D is a diagonal matrix with the degrees (total weights for that node) on the diagonal
    deg = vec(sum(adjmat; dims = 1)) - diag(adjmat)
    D = diagm(0 => deg)

    # Laplacian: degrees on the diagonal, negated weights elsewhere
    L = eltype(adjmat)[i == j ? deg[i] : -adjmat[i, j] for i in 1:n, j in 1:n]

    return L, D
end
"""
    estimate_distance(adjmat)

Estimate the pairwise graph distance between all nodes of `adjmat`.

A `Graphs.Graph` is built from `adjmat` and `Graphs.dijkstra_shortest_paths` is run from
every vertex; the results are collected column by column into a `Matrix{Float64}`.
Entries above `1e10` (pairs with no finite distance) are replaced by three times the
average of the entries below `1e10`, or by `3.0` when there is no such entry.
"""
function estimate_distance(adjmat::AbstractMatrix)
    # The original also converted `adjmat` to an edge list here and never used the
    # result; that O(n^2) pass has been dropped.
    g = Graphs.Graph(adjmat)
    columns = map(v -> Graphs.dijkstra_shortest_paths(g, v).dists, Graphs.vertices(g))
    # `reduce(hcat, ...)` avoids splatting one argument per vertex
    dists = convert(Matrix{Float64}, reduce(hcat, columns))

    # average over the finite entries (sequential sum, diagonal zeros included)
    tot = 0.0
    cnt = 0
    for d in dists
        if d < 1e10
            tot += d
            cnt += 1
        end
    end
    avg = cnt > 0 ? tot / cnt : 1.0

    for i in eachindex(dists)
        if dists[i] > 1e10
            dists[i] = 3avg
        end
    end
    return dists
end

"""
    get_source_destiny_weight(g::Graphs.AbstractGraph)

Return the edges of `g` as `(source, destiny, weights)`, with every weight equal to `1.0`.
"""
function get_source_destiny_weight(g::Graphs.AbstractGraph)
    source = Vector{Int}()
    destiny = Vector{Int}()
    # one entry is pushed per edge, so reserve room for the edge count
    sizehint!(source, Graphs.ne(g))
    sizehint!(destiny, Graphs.ne(g))
    for e in Graphs.edges(g)
        push!(source, Graphs.src(e))
        push!(destiny, Graphs.dst(e))
    end
    return get_source_destiny_weight(source, destiny)
end

get_adjacency_matrix(g::Graphs.AbstractGraph) = adjacency_matrix(g)

# Unweighted edge list: every edge gets weight 1.0.
get_adjacency_matrix(
    source::AbstractVector{Int},
    destiny::AbstractVector{Int},
    n = infer_size_from(source, destiny),
) = get_adjacency_matrix(source, destiny, ones(length(source)), n)

# NOTE(review): reads the `fadjlist` field directly, so this presumably only works for
# graph types that have that field — confirm before passing other `AbstractGraph`s.
get_adjacency_list(g::Graphs.AbstractGraph) = g.fadjlist

"""
    format_nodeproperty(prop, n_edges, edge_boxes = 0; fill_value = nothing)

Pad a per-node property so that it lines up with the series of a graph plot.

When `prop` is an `Array`, the result is a one-row matrix made of `edge_boxes + n_edges`
copies of `fill_value`, followed by the entries of `prop`, followed by one more
`fill_value`. Any other `prop` is returned unchanged.
"""
function format_nodeproperty(prop, n_edges, edge_boxes = 0; fill_value = nothing)
    prop isa Array || return prop
    padding = fill(fill_value, edge_boxes + n_edges)
    return permutedims(vcat(padding, vec(prop), fill_value))
end
# -----------------------------------------------------

# a graphplot takes in either an (N x N) adjacency matrix
# note: you may want to pass node weights to markersize or marker_z
# A graph has N nodes where adj_mat[i,j] is the strength of edge i --> j.  (adj_mat[i,j]==0 implies no edge)
# Maps each graphplot attribute to the alternative keyword names accepted for it.
# `:nodestrokealpha` was listed twice in the original table; the redundant entry is gone
# (the resulting dictionary is unchanged).
const graph_aliases = Dict(
    :curvature_scalar => [:curvaturescalar, :curvature],
    :node_weights => [:nodeweights],
    :nodeshape => [:node_shape, :markershape],
    :nodesize => [:node_size, :markersize],
    :nodecolor => [:marker_color, :markercolor],
    :node_z => [:marker_z],
    :nodestrokealpha => [:markerstrokealpha],
    :nodealpha => [:markeralpha],
    :nodestrokewidth => [:markerstrokewidth],
    :nodestrokecolor => [:markerstrokecolor],
    :nodestrokestyle => [:markerstrokestyle],
    :shorten => [:shorten_edge],
    :axis_buffer => [:axisbuffer],
    :edgewidth => [:edge_width, :ew],
    :edgelabel => [:edge_label, :el],
    :edgelabel_offset => [:edgelabeloffset, :elo],
    :self_edge_size => [:selfedgesize, :ses],
    :edge_label_box => [:edgelabelbox, :edgelabel_box, :elb],
)
+""" +@userplot GraphPlot + +@recipe function f( + g::GraphPlot; + dim = 2, + free_dims = nothing, + T = Float64, + curves = true, + curvature_scalar = 0.05, + root = :top, + node_weights = nothing, + names = [], + fontsize = 7, + nodeshape = :hexagon, + nodesize = 0.1, + node_z = nothing, + nodecolor = 1, + nodestrokealpha = 1, + nodealpha = 1, + nodestrokewidth = 1, + nodestrokecolor = :black, + nodestrokestyle = :solid, + nodestroke_z = nothing, + rng = nothing, + x = nothing, + y = nothing, + z = nothing, + method = :stress, + func = get(_graph_funcs, method, by_axis_local_stress_graph), + shorten = 0.0, + axis_buffer = 0.2, + layout_kw = Dict{Symbol,Any}(), + edgewidth = (s, d, w) -> 1, + edgelabel = nothing, + edgelabel_offset = 0.0, + self_edge_size = 0.1, + edge_label_box = true, + edge_z = nothing, + edgecolor = :black, + edgestyle = :solid, + trim = false, +) + # Process the args so that they are a Graphs.Graph. + if length(g.args) <= 1 && + !(eltype(g.args[1]) <: AbstractArray) && + !(g.args[1] isa Graphs.AbstractGraph) && + method != :chorddiagram && + method != :arcdiagram + if !LinearAlgebra.issymmetric(g.args[1]) || + any(diag(g.args[1]) .!= zeros(length(diag(g.args[1])))) + g.args = (Graphs.DiGraph(g.args[1]),) + elseif LinearAlgebra.issymmetric(g.args[1]) + g.args = (Graphs.Graph(g.args[1]),) + end + end + + # To process aliases that are unique to graphplot, find aliases that are in + # plotattributes and replace the attributes with their aliases. Then delete the alias + # names from the plotattributes dictionary. + @process_aliases plotattributes graph_aliases + for arg in keys(graph_aliases) + remove_aliases!(arg, plotattributes, graph_aliases) + end + # The above process will remove all marker properties from the plotattributes + # dictionary. To ensure consistency between markers and nodes, we replace all marker + # properties with the corresponding node property. 
+ marker_node_collection = zip( + [ + :markershape, + :markersize, + :markercolor, + :marker_z, + :markerstrokealpha, + :markeralpha, + :markerstrokewidth, + :markerstrokealpha, + :markerstrokecolor, + :markerstrokestyle, + ], + [ + nodeshape, + nodesize, + nodecolor, + node_z, + nodestrokealpha, + nodealpha, + nodestrokewidth, + nodestrokealpha, + nodestrokecolor, + nodestrokestyle, + ], + ) + for (markerproperty, nodeproperty) in marker_node_collection + # Make sure that the node properties are row vectors. + nodeproperty isa Array && (nodeproperty = permutedims(vec(nodeproperty))) + plotattributes[markerproperty] = nodeproperty + end + + # If we pass a value of plotattributes[:markershape] that the backend does not + # recognize, then the backend will throw an error. The error is thrown despite the + # fact that we override the default behavior. Custom nodehapes are incompatible + # with the backend's markershapes and thus replaced. + if nodeshape isa Function || + nodeshape isa Array && any([s isa Function for s in nodeshape]) + plotattributes[:markershape] = :circle + end + + @assert dim in (2, 3) + is3d = dim == 3 + adj_mat = get_adjacency_matrix(g.args...) + nr, nc = size(adj_mat) # number of nodes == number of rows + @assert nr == nc + isdirected = + (g.args[1] isa DiGraph || !issymmetric(adj_mat)) && + !in(method, (:tree, :buchheim)) && + !(get(plotattributes, :arrow, true) == false) + if isdirected && (g.args[1] isa Matrix) + g = GraphPlot((adjacency_matrix(DiGraph(g.args[1])),)) + end + + source, destiny, weights = get_source_destiny_weight(g.args...) + if !(eltype(source) <: Integer) + names = unique(sort(vcat(source, destiny))) + source = Int[findfirst(names, si) for si in source] + destiny = Int[findfirst(names, di) for di in destiny] + end + n = infer_size_from(source, destiny) + display_n = trim ? n : nr # number of displayed nodes + n_edges = length(source) + + isnothing(node_weights) && (node_weights = ones(display_n)) + + xyz = is3d ? 
(x, y, z) : (x, y) + numnothing = count(isnothing, xyz) + + # do we want to compute coordinates? + if numnothing > 0 + isnothing(free_dims) && (free_dims = findall(isnothing, xyz)) # compute free_dims + dat = prepare_graph_inputs(method, source, destiny, weights; display_n = display_n) + x, y, z = func( + dat...; + node_weights = node_weights, + dim = dim, + free_dims = free_dims, + root = root, + rng = rng, + layout_kw..., + ) + end + + # reorient the points after root + if root in (:left, :right) + x, y = y, -x + end + if root == :left + x, y = -x, y + end + if root == :bottom + x, y = x, -y + end + + # Since we do nodehapes manually, they only work with aspect_ratio=1. + # TODO: rescale the nodeshapes based on the ranges of x,y,z. + aspect_ratio --> 1 + if length(axis_buffer) == 1 + axis_buffer = fill(axis_buffer, dim) + end + + # center and rescale to the widest of all dimensions + if method == :arcdiagram + xl, yl = arcdiagram_limits(x, source, destiny) + xlims --> xl + ylims --> yl + aspect_ratio --> :equal + elseif all(axis_buffer .< 0) # equal axes + ahw = 1.2 * 0.5 * maximum(v -> maximum(v) - minimum(v), xyz) + xcenter = mean(extrema(x)) + #xlims --> (xcenter-ahw, xcenter+ahw) + ycenter = mean(extrema(y)) + #ylims --> (ycenter-ahw, ycenter+ahw) + if is3d + zcenter = mean(extrema(z)) + #zlims --> (zcenter-ahw, zcenter+ahw) + end + else + xlims = ignorenan_extrema(x) + if method != :chorddiagram && numnothing > 0 + x .-= mean(x) + x /= (xlims[2] - xlims[1]) + y .-= mean(y) + ylims = ignorenan_extrema(y) + y /= (ylims[2] - ylims[1]) + end + xlims --> extrema_plus_buffer(x, axis_buffer[1]) + ylims --> extrema_plus_buffer(y, axis_buffer[2]) + if is3d + if method != :chorddiagram && numnothing > 0 + zlims = ignorenan_extrema(z) + z .-= mean(z) + z /= (zlims[2] - zlims[1]) + end + zlims --> extrema_plus_buffer(z, axis_buffer[3]) + end + end + xyz = is3d ? (x, y, z) : (x, y) + # Get the coordinates for the edges of the nodes. 
+ node_vec_vec_xy = [] + nodewidth = 0.0 + nodewidth_array = Vector{Float64}(undef, length(x)) + if !(nodeshape isa Array) + nodeshape = repeat([nodeshape], length(x)) + end + if !is3d + for i in eachindex(x) + node_number = + i % length(nodeshape) == 0 ? length(nodeshape) : i % length(nodeshape) + node_weight = + isnothing(node_weights) ? 1 : + (10 + 100node_weights[i] / sum(node_weights)) / 50 + xextent, yextent = if isempty(names) + [ + x[i] .+ [-0.5nodesize * node_weight, 0.5nodesize * node_weight], + y[i] .+ [-0.5nodesize * node_weight, 0.5nodesize * node_weight], + ] + else + annotation_extent( + plotattributes, + ( + x[i], + y[i], + names[ifelse( + i % length(names) == 0, + length(names), + i % length(names), + )], + fontsize * nodesize * node_weight, + ), + ) + end + nodewidth = xextent[2] - xextent[1] + nodewidth_array[i] = nodewidth + if nodeshape[node_number] == :circle + push!( + node_vec_vec_xy, + partialcircle(0, 2π, [x[i], y[i]], 80, nodewidth / 2), + ) + elseif (nodeshape[node_number] == :rect) || + (nodeshape[node_number] == :rectangle) + push!( + node_vec_vec_xy, + [ + (xextent[1], yextent[1]), + (xextent[2], yextent[1]), + (xextent[2], yextent[2]), + (xextent[1], yextent[2]), + (xextent[1], yextent[1]), + ], + ) + elseif nodeshape[node_number] == :hexagon + push!(node_vec_vec_xy, partialcircle(0, 2π, [x[i], y[i]], 7, nodewidth / 2)) + elseif nodeshape[node_number] == :ellipse + nodeheight = (yextent[2] - yextent[1]) + push!( + node_vec_vec_xy, + partialellipse(0, 2π, [x[i], y[i]], 80, nodewidth / 2, nodeheight / 2), + ) + elseif applicable(nodeshape[node_number], x[i], y[i], 0.0, 0.0) + nodeheight = (yextent[2] - yextent[1]) + push!( + node_vec_vec_xy, + nodeshape[node_number](x[i], y[i], nodewidth, nodeheight), + ) + elseif applicable(nodeshape[node_number], x[i], y[i], 0.0) + push!(node_vec_vec_xy, nodeshape[node_number](x[i], y[i], nodewidth)) + else + error( + "Unknown nodeshape: $(nodeshape[node_number]). 
Choose from :circle, ellipse, :hexagon, :rect or :rectangle or or a custom shape. Custom shapes can be passed as a function customshape such that customshape(x, y, nodeheight, nodewidth) -> nodeperimeter/ customshape(x, y, nodescale) -> nodeperimeter. nodeperimeter must be an array of 2-tuples, where each tuple is a corner of your custom shape, centered at (x, y) and with height nodeheight, width nodewidth or only a nodescale for symmetrically scaling shapes.", + ) + end + end + else + @assert is3d # TODO Make 3d work. + end + # The node_perimter_info list contains the information needed to construct the + # information in node_vec_vec_xy. For example, if (nodeshape[i]==:circle && !is3d), + # then all of the information in node_vec_vec_xy[i] can be summarised with three + # numbers describing the center and the radius of the circle. + node_perimeter_info = [] + for i in eachindex(node_vec_vec_xy) + if nodeshape[i] == :circle + push!( + node_perimeter_info, + GeometryTypes.Circle( + Point((convert(T, x[i]), convert(T, y[i]))), + nodewidth_array[i] / 2, + ), + ) + else + push!(node_perimeter_info, node_vec_vec_xy[i]) + end + end + + # generate a list of colors, one per segment + segment_colors = get(plotattributes, :linecolor, nothing) + edge_label_array = Vector{Tuple}() + edge_label_box_vertices_array = Vector{Array}() + if !isa(edgelabel, Dict) && !isnothing(edgelabel) + tmp = Dict() + if length(size(edgelabel)) < 2 + matrix_size = round(Int, sqrt(length(edgelabel))) + edgelabel = reshape(edgelabel, matrix_size, matrix_size) + end + for i in 1:size(edgelabel)[1], j in 1:size(edgelabel)[2] + if islabel(edgelabel[i, j]) + tmp[(i, j)] = edgelabel[i, j] + end + end + edgelabel = tmp + end + # If the edgelabel dictionary is full of length two tuples, then make all of the + # tuples length three with last element 1. (i.e. a multigraph that has no extra + # edges). 
+ if edgelabel isa Dict + edgelabel = convert(Dict{Any,Any}, edgelabel) + for key in keys(edgelabel) + if length(key) == 2 + edgelabel[(key..., 1)] = edgelabel[key] + end + end + end + edge_has_been_seen = Dict() + for edge in zip(source, destiny) + edge_has_been_seen[edge] = 0 + end + if length(curvature_scalar) == 1 + curvature_scalar = fill(curvature_scalar, size(adj_mat, 1), size(adj_mat, 1)) + end + + edges_list = (T[], T[], T[], T[]) + # TODO do a proper job of calculating nsegments. + nsegments = if curves && (method in (:tree, :buchheim)) + 4 + elseif method == :chorddiagram + 3 + elseif method == :arcdiagram + 30 + elseif curves + 50 + else + 2 + end + + for (edge_num, (si, di, wi)) in enumerate(zip(source, destiny, weights)) + edge_has_been_seen[(si, di)] += 1 + xseg = Float64[] + yseg = Float64[] + zseg = Float64[] + l_wg = Float64[] + + # add a line segment + xsi, ysi, xdi, ydi = shorten_segment(x[si], y[si], x[di], y[di], shorten) + θ = (edge_has_been_seen[(si, di)] - 1) * pi / 8 + if isdirected && si != di && !is3d + xpt, ypt = if method != :chorddiagram + control_point( + xsi, + xdi, + ysi, + ydi, + edge_has_been_seen[(si, di)] * curvature_scalar[si, di] * sign(si - di), + ) + else + (0.0, 0.0) + end + # For directed graphs, shorten the line segment so that the edge ends at + # the perimeter of the destiny node. + if isdirected + _, _, xdi, ydi = + nearest_intersection(xpt, ypt, x[di], y[di], node_perimeter_info[di]) + end + end + if curves + if method in (:tree, :buchheim) + # for trees, shorten should be on one axis only + # dist = sqrt((x[di]-x[si])^2 + (y[di]-y[si])^2) * shorten + dist = shorten * (root in (:left, :bottom) ? 1 : -1) + ishoriz = root in (:left, :right) + xsi, xdi = (ishoriz ? (x[si] + dist, x[di] - dist) : (x[si], x[di])) + ysi, ydi = (ishoriz ? 
(y[si], y[di]) : (y[si] + dist, y[di] - dist)) + xpts, ypts = directed_curve( + xsi, + xdi, + ysi, + ydi, + xview = get(plotattributes, :xlims, (0, 1)), + yview = get(plotattributes, :ylims, (0, 1)), + root = root, + rng = rng, + ) + append!(xseg, xpts) + append!(yseg, ypts) + append!(l_wg, [wi for i in 1:length(xpts)]) + elseif method == :arcdiagram + r = (xdi - xsi) / 2 + x₀ = (xdi + xsi) / 2 + θ = range(0, stop = π, length = 30) + xpts = x₀ .+ r .* cos.(θ) + ypts = r .* sin.(θ) .+ ysi # ysi == ydi + for x in xpts + push!(xseg, x) + push!(l_wg, wi) + end + # push!(xseg, NaN) + for y in ypts + push!(yseg, y) + end + # push!(yseg, NaN) + else + xpt, ypt = if method != :chorddiagram + control_point( + xsi, + x[di], + ysi, + y[di], + edge_has_been_seen[(si, di)] * + curvature_scalar[si, di] * + sign(si - di), + ) + else + (0.0, 0.0) + end + xpts = [xsi, xpt, xdi] + ypts = [ysi, ypt, ydi] + t = range(0, stop = 1, length = 3) + A = hcat(xpts, ypts) + itp = scale(interpolate(A, BSpline(Cubic(Natural(OnGrid())))), t, 1:2) + tfine = range(0, stop = 1, length = nsegments) + xpts, ypts = [itp(t, 1) for t in tfine], [itp(t, 2) for t in tfine] + if !isnothing(edgelabel) && + haskey(edgelabel, (si, di, edge_has_been_seen[(si, di)])) + q = control_point( + xsi, + x[di], + ysi, + y[di], + ( + edgelabel_offset + + edge_has_been_seen[(si, di)] * curvature_scalar[si, di] + ) * sign(si - di), + ) + + if !any(isnan.(q)) + push!( + edge_label_array, + ( + q..., + string(edgelabel[(si, di, edge_has_been_seen[(si, di)])]), + fontsize, + ), + ) + edge_label_box_vertices = (annotation_extent( + plotattributes, + ( + q[1], + q[2], + edgelabel[(si, di, edge_has_been_seen[(si, di)])], + 0.05fontsize, + ), + )) + push!(edge_label_box_vertices_array, edge_label_box_vertices) + end + end + if method != :chorddiagram && !is3d + append!(xseg, xpts) + append!(yseg, ypts) + push!(l_wg, wi) + else + push!(xseg, xsi, xpt, xdi) + push!(yseg, ysi, ypt, ydi) + is3d && push!(zseg, z[si], z[si], z[di]) + 
push!(l_wg, wi) + end + end + else + push!(xseg, xsi, xdi) + push!(yseg, ysi, ydi) + is3d && push!(zseg, z[si], z[di]) + if !isnothing(edgelabel) && + haskey(edgelabel, (si, di, edge_has_been_seen[(si, di)])) + q = [(xsi + xdi) / 2, (ysi + ydi) / 2] + + if !any(isnan.(q)) + push!( + edge_label_array, + ( + q..., + string(edgelabel[(si, di, edge_has_been_seen[(si, di)])]), + fontsize, + ), + ) + edge_label_box_vertices = (annotation_extent( + plotattributes, + ( + q[1], + q[2], + edgelabel[(si, di, edge_has_been_seen[(si, di)])], + 0.05fontsize, + ), + )) + push!(edge_label_box_vertices_array, edge_label_box_vertices) + end + end + end + if si == di && !is3d + inds = 1:n .!= si + self_edge_angle = pi / 8 + (edge_has_been_seen[(si, di)] - 1) * pi / 8 + θ1 = unoccupied_angle(xsi, ysi, x[inds], y[inds]) - self_edge_angle / 2 + θ2 = θ1 + self_edge_angle + nodewidth = nodewidth_array[si] + if nodeshape == :circle + xpts = [ + xsi + nodewidth * cos(θ1) / 2, + NaN, + NaN, + NaN, + xsi + nodewidth * cos(θ2) / 2, + ] + xpts[2] = + mean([xpts[1], xpts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * cos(θ1) + xpts[3] = + mean([xpts[1], xpts[end]]) + + edge_has_been_seen[(si, di)] * self_edge_size * cos((θ1 + θ2) / 2) + xpts[4] = + mean([xpts[1], xpts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * cos(θ2) + ypts = [ + ysi + nodewidth * sin(θ1) / 2, + NaN, + NaN, + NaN, + ysi + nodewidth * sin(θ2) / 2, + ] + ypts[2] = + mean([ypts[1], ypts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * sin(θ1) + ypts[3] = + mean([ypts[1], ypts[end]]) + + edge_has_been_seen[(si, di)] * self_edge_size * sin((θ1 + θ2) / 2) + ypts[4] = + mean([ypts[1], ypts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * sin(θ2) + t = range(0, stop = 1, length = 5) + A = hcat(xpts, ypts) + itp = scale(interpolate(A, BSpline(Cubic(Natural(OnGrid())))), t, 1:2) + tfine = range(0, stop = 1, length = nsegments) + xpts, 
ypts = [itp(t, 1) for t in tfine], [itp(t, 2) for t in tfine] + else + _, _, start_point1, start_point2 = nearest_intersection( + xsi, + ysi, + xsi + 2nodewidth * cos(θ1), + ysi + 2nodewidth * sin(θ1), + node_vec_vec_xy[si], + ) + _, _, end_point1, end_point2 = nearest_intersection( + xsi + + edge_has_been_seen[(si, di)] * (nodewidth + self_edge_size) * cos(θ2), + ysi + + edge_has_been_seen[(si, di)] * (nodewidth + self_edge_size) * sin(θ2), + xsi, + ysi, + node_vec_vec_xy[si], + ) + xpts = [start_point1, NaN, NaN, NaN, end_point1] + xpts[2] = + mean([xpts[1], xpts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * cos(θ1) + xpts[3] = + mean([xpts[1], xpts[end]]) + + edge_has_been_seen[(si, di)] * self_edge_size * cos((θ1 + θ2) / 2) + xpts[4] = + mean([xpts[1], xpts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * cos(θ2) + ypts = [start_point2, NaN, NaN, NaN, end_point2] + ypts[2] = + mean([ypts[1], ypts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * sin(θ1) + ypts[3] = + mean([ypts[1], ypts[end]]) + + edge_has_been_seen[(si, di)] * self_edge_size * sin((θ1 + θ2) / 2) + ypts[4] = + mean([ypts[1], ypts[end]]) + + 0.5 * (0.5 + edge_has_been_seen[(si, di)]) * self_edge_size * sin(θ2) + t = range(0, stop = 1, length = 5) + A = hcat(xpts, ypts) + itp = scale(interpolate(A, BSpline(Cubic(Natural(OnGrid())))), t, 1:2) + tfine = range(0, stop = 1, length = nsegments) + xpts, ypts = [itp(t, 1) for t in tfine], [itp(t, 2) for t in tfine] + end + append!(xseg, xpts) + append!(yseg, ypts) + mid_ind = div(length(xpts), 2) + q = [ + xpts[mid_ind] + edgelabel_offset * cos((θ1 + θ2) / 2), + ypts[mid_ind] + edgelabel_offset * sin((θ1 + θ2) / 2), + ] + if !isnothing(edgelabel) && + haskey(edgelabel, (si, di, edge_has_been_seen[(si, di)])) + push!( + edge_label_array, + ( + q..., + string(edgelabel[(si, di, edge_has_been_seen[(si, di)])]), + fontsize, + ), + ) + edge_label_box_vertices = annotation_extent( + 
plotattributes, + (q..., edgelabel[(si, di, edge_has_been_seen[(si, di)])], 0.05fontsize), + ) + if !any(isnan.(q)) + push!(edge_label_box_vertices_array, edge_label_box_vertices) + end + end + end + append!(edges_list[1], xseg[.!isnan.(xseg)]) + append!(edges_list[2], yseg[.!isnan.(yseg)]) + is3d && append!(edges_list[3], zseg[.!isnan.(zseg)]) + append!(edges_list[4], l_wg[.!isnan.(l_wg)]) + end + + if is3d + edges_list = ( + reshape(edges_list[1], 3, round(Int, length(edges_list[1]) / 3)), + reshape(edges_list[2], 3, round(Int, length(edges_list[2]) / 3)), + reshape(edges_list[3], 3, round(Int, length(edges_list[3]) / 3)), + ) + else + edges_list = ( + reshape( + edges_list[1], + nsegments, + round(Int, length(edges_list[1]) / nsegments), + ), + reshape( + edges_list[2], + nsegments, + round(Int, length(edges_list[2]) / nsegments), + ), + ) + edges_list = ( + [edges_list[1][:, j] for j in 1:size(edges_list[1], 2)], + [edges_list[2][:, j] for j in 1:size(edges_list[2], 2)], + ) + end + + @series begin + @debug num_edges_nodes := (length(edges_list[1]), length(node_vec_vec_xy)) # for debugging / tests + + seriestype := if method in (:tree, :buchheim, :chorddiagram) + :curves + else + if is3d + # TODO make curves work + if curves + :curves + end + else + :path + end + end + + colorbar_entry := true + + edge_z = process_edge_attribute(edge_z, source, destiny, weights) + edgewidth = process_edge_attribute(edgewidth, source, destiny, weights) + edgecolor = process_edge_attribute(edgecolor, source, destiny, weights) + edgestyle = process_edge_attribute(edgestyle, source, destiny, weights) + + !isnothing(edge_z) && (line_z := edge_z) + linewidthattr = get(plotattributes, :linewidth, 1) + linewidth := linewidthattr * edgewidth + fillalpha := 1 + linecolor := edgecolor + linestyle := get(plotattributes, :linestyle, edgestyle) + markershape := :none + markersize := 0 + markeralpha := 0 + markercolor := :black + marker_z := nothing + isdirected && (arrow --> :simple, :head, 
0.3, 0.3) + primary := false + + is3d ? edges_list[1:3] : edges_list[1:2] + end + # The boxes around edge labels are defined as another list of series that sits on top + # of the series for the edges. + edge_has_been_seen = Dict() + for edge in zip(source, destiny) + edge_has_been_seen[edge] = 0 + end + index = 0 + if edge_label_box && !isnothing(edgelabel) + for (edge_num, (si, di, wi)) in enumerate(zip(source, destiny, weights)) + edge_has_been_seen[(si, di)] += 1 + if haskey(edgelabel, (si, di, edge_has_been_seen[(si, di)])) + index += 1 + @series begin + seriestype := :shape + + colorbar_entry --> false + fillcolor --> get(plotattributes, :background_color, :white) + linewidth --> 0 + linealpha --> 0 + edge_label_box_vertices = edge_label_box_vertices_array[index] + ( + [ + edge_label_box_vertices[1][1], + edge_label_box_vertices[1][2], + edge_label_box_vertices[1][2], + edge_label_box_vertices[1][1], + edge_label_box_vertices[1][1], + ], + [ + edge_label_box_vertices[2][1], + edge_label_box_vertices[2][1], + edge_label_box_vertices[2][2], + edge_label_box_vertices[2][2], + edge_label_box_vertices[2][1], + ], + ) + end + end + end + end + + framestyle := :none + axis := nothing + legend --> false + + # Make sure that the node properties are row vectors. 
+ nodeshape = format_nodeproperty(nodeshape, n_edges, index) + nodesize = format_nodeproperty(nodesize, n_edges, index) + nodecolor = format_nodeproperty(nodecolor, n_edges, index) + node_z = format_nodeproperty(node_z, n_edges, index) + nodestrokealpha = format_nodeproperty(nodestrokealpha, n_edges, index) + nodealpha = format_nodeproperty(nodealpha, n_edges, index) + nodestrokewidth = format_nodeproperty(nodestrokewidth, n_edges, index) + nodestrokealpha = format_nodeproperty(nodestrokealpha, n_edges, index) + nodestrokecolor = format_nodeproperty(nodestrokecolor, n_edges, index) + nodestrokestyle = + format_nodeproperty(nodestrokestyle, n_edges, index, fill_value = :solid) + + if method == :chorddiagram + seriestype := :scatter + markersize := 0 + markeralpha := 0 + aspect_ratio --> :equal + if length(names) == length(x) + annotations := [(x[i], y[i], names[i]) for i in eachindex(x)] + end + @series begin + seriestype := :shape + N = length(x) + angles = Vector{Float64}(undef, N) + for i in 1:N + if y[i] > 0 + angles[i] = acos(x[i]) + else + angles[i] = 2pi - acos(x[i]) + end + end + δ = 0.4 * (angles[2] - angles[1]) + vec_vec_xy = [arcshape(Θ - δ, Θ + δ) for Θ in angles] # Shape + [[xy[1] for xy in vec_xy] for vec_xy in vec_vec_xy], + [[xy[2] for xy in vec_xy] for vec_xy in vec_vec_xy] + end + else + if is3d + seriestype := :scatter3d + linewidth := 0 + linealpha := 0 + markercolor := nodecolor + series_annotations --> map(string, names) + markersize --> (10 .+ (100 .* node_weights) ./ sum(node_weights)) + else + @series begin + seriestype := :shape + + colorbar_entry := true + fill_z --> node_z + fillalpha := nodealpha + fillcolor := nodecolor + markersize := 0 + markeralpha := 0 + linewidth := nodestrokewidth + linealpha := nodestrokealpha + linecolor := nodestrokecolor + linestyle := nodestrokestyle + line_z := nodestroke_z + + nodeperimeters = (Any[], Any[]) + for vec_xy in node_vec_vec_xy + push!(nodeperimeters[1], [xy[1] for xy in vec_xy]) + 
push!(nodeperimeters[2], [xy[2] for xy in vec_xy]) + end + + nodeperimeters + + # if is3d + # seriestype := :volume + # ([[xyz[1] for xyz in vec_xyz] for vec_xyz in node_vec_vec_xyz], + # [[xyz[2] for xyz in vec_xyz] for vec_xyz in node_vec_vec_xyz], + # [[xyz[3] for xyz in vec_xyz] for vec_xyz in node_vec_vec_xyz]) + # end + end + + if isempty(names) + seriestype := :scatter + + colorbar_entry --> false + markersize := 0 + markeralpha := 0 + markerstrokesize := 0 + !isnothing(edgelabel) && (annotations --> edge_label_array) + else + seriestype := :scatter + + colorbar_entry --> false + markersize := 0 + markeralpha := 0 + markerstrokesize := 0 + annotations --> [ + edge_label_array + [ + ( + x[i], + y[i], + names[ifelse( + i % length(names) == 0, + length(names), + i % length(names), + )], + fontsize, + ) for i in eachindex(x) + ] + ] + end + end + end + xyz +end + +@recipe f(g::AbstractGraph) = GraphPlot(get_source_destiny_weight(get_adjacency_list(g))) diff --git a/GraphRecipes/src/misc.jl b/GraphRecipes/src/misc.jl new file mode 100644 index 000000000..6934c7a86 --- /dev/null +++ b/GraphRecipes/src/misc.jl @@ -0,0 +1,115 @@ + +# ------------------------------------------------------------------- +# AST trees + +function add_ast(adjlist, names, depthdict, depthlists, nodetypes, ex::Expr, parent_idx) + idx = length(names) + 1 + iscall = ex.head == :call + push!(names, iscall ? string(ex.args[1]) : string(ex.head)) + push!(nodetypes, iscall ? :call : :expr) + l = Int[] + push!(adjlist, l) + + depth = parent_idx == 0 ? 1 : depthdict[parent_idx] + 1 + depthdict[idx] = depth + while length(depthlists) < depth + push!(depthlists, Int[]) + end + push!(depthlists[depth], idx) + + for arg in (iscall ? 
ex.args[2:end] : ex.args)
+        # line-number bookkeeping nodes are not turned into graph nodes
+        if isa(arg, LineNumberNode)
+            continue
+        end
+        push!(l, add_ast(adjlist, names, depthdict, depthlists, nodetypes, arg, idx))
+    end
+    idx
+end
+
+# Fallback method for anything that is not an `Expr`: records `string(x)` as a `:leaf`
+# node with an empty adjacency list, registers its depth (1 when `parent_idx == 0`,
+# otherwise one more than the parent's depth) and returns the index of the new node.
+function add_ast(adjlist, names, depthdict, depthlists, nodetypes, x, parent_idx)
+    push!(names, string(x))
+    push!(nodetypes, :leaf)
+    push!(adjlist, Int[])
+    idx = length(names)
+
+    depth = parent_idx == 0 ? 1 : depthdict[parent_idx] + 1
+    depthdict[idx] = depth
+    # grow the per-depth index lists until a list for `depth` exists
+    while length(depthlists) < depth
+        push!(depthlists, Int[])
+    end
+    push!(depthlists[depth], idx)
+
+    idx
+end
+
+# Recipe for plotting an expression: `add_ast` flattens `ex` into node names and an
+# adjacency list, which is handed to `GraphPlot`. The node strings are forced as `names`,
+# the layout method is forced to `:buchheim`, and `root` defaults to `:top`.
+@recipe function f(ex::Expr)
+    names = String[]
+    adjlist = Vector{Int}[]
+    depthdict = Dict{Int,Int}()
+    depthlists = Vector{Int}[]
+    nodetypes = Symbol[]
+    # parent index 0 marks `ex` as the root of the tree
+    add_ast(adjlist, names, depthdict, depthlists, nodetypes, ex, 0)
+    names := names
+    # method := :tree
+    method := :buchheim
+    root --> :top
+
+    # markercolor --> Symbol[(nt == :call ? :pink : nt == :leaf ? :white : :lightgreen) for nt in nodetypes]
+
+    # # compute the y-values from the depthdict dict
+    # n = length(depthlists)-1
+    # layers = Float64[(depthdict[i]-1)/n for i=1:length(names)]
+    # # add_noise --> false
+    #
+    # positions = zeros(length(names))
+    # for (depth, lst) in enumerate(depthlists)
+    #     n = length(lst)
+    #     pos = n > 1 ? linspace(0, 1, n) : [0.5]
+    #     for (i, idx) in enumerate(lst)
+    #         positions[idx] = pos[i]
+    #     end
+    # end
+    #
+    # layout_kw := Dict{Symbol,Any}(:layers => layers, :add_noise => false, :positions => positions)
+
+    GraphPlot(get_source_destiny_weight(adjlist))
+end
+
+# -------------------------------------------------------------------
+# Type trees
+
+# Append every subtype of `T` (as reported by `subtypes`) to `nodes`, record the edge
+# `supidx -> subidx` in `source`/`destiny`, and recurse into each subtype (depth first).
+function add_subs!(nodes, source, destiny, ::Type{T}, supidx) where {T}
+    for sub in subtypes(T)
+        push!(nodes, sub)
+        subidx = length(nodes)
+        push!(source, supidx)
+        push!(destiny, subidx)
+        add_subs!(nodes, source, destiny, sub, subidx)
+    end
+end
+
+# recursively build a graph of subtypes of T
+@recipe function f(
+    ::Type{T};
+    namefunc = node -> isa(node, UnionAll) ?
split(string(node), '.')[end] : node.name.name,
+) where {T}
+    # get the supertypes: walk up from T until Any is reached, prepending each one, so
+    # that `sups` is ordered from Any (first) down to T (last)
+    sups = Any[T]
+    sup = T
+    while sup != Any
+        sup = supertype(sup)
+        pushfirst!(sups, sup)
+    end
+
+    # add the subtypes: the supertype chain is linked 1 -> 2 -> ... -> n, and the
+    # subtypes of T are attached below node n (which is T itself)
+    n = length(sups)
+    nodes = copy(sups)
+    source, destiny = collect(1:(n - 1)), collect(2:n)
+    add_subs!(nodes, source, destiny, T, n)
+
+    # set up the graphplot (names are forced; method and root are only defaults)
+    names := map(namefunc, nodes)
+    method --> :buchheim
+    root --> :top
+    GraphPlot((source, destiny))
+end
diff --git a/GraphRecipes/src/trees.jl b/GraphRecipes/src/trees.jl
new file mode 100644
index 000000000..b18f0f9d1
--- /dev/null
+++ b/GraphRecipes/src/trees.jl
@@ -0,0 +1,60 @@
+import AbstractTrees
+using AbstractTrees: children
+
+export TreePlot
+
+"""
+    TreePlot(root)
+
+Wrap a tree-like object for plotting. Uses `AbstractTrees.children()` to recursively add children to the plot and `AbstractTrees.printnode()` to generate the labels.
+
+# Example
+
+```julia
+using AbstractTrees, GraphRecipes
+AbstractTrees.children(d::Dict) = [p for p in d]
+AbstractTrees.children(p::Pair) = AbstractTrees.children(p[2])
+function AbstractTrees.printnode(io::IO, p::Pair)
+    str = isempty(AbstractTrees.children(p[2])) ?
string(p[1], ": ", p[2]) : string(p[1], ": ")
+    print(io, str)
+end
+
+d = Dict(:a => 2,:d => Dict(:b => 4,:c => "Hello"),:e => 5.0)
+
+plot(TreePlot(d))
+```
+"""
+struct TreePlot{T}
+    root::T
+end
+
+# Append every child of `node` (as reported by `AbstractTrees.children`) to `nodes`,
+# record the edge `parent_idx -> child_idx` in `source`/`destiny`, and recurse into the
+# child (depth first).
+function add_children!(nodes, source, destiny, node, parent_idx)
+    for child in children(node)
+        push!(nodes, child)
+        child_idx = length(nodes)
+        push!(source, parent_idx)
+        push!(destiny, child_idx)
+        add_children!(nodes, source, destiny, child, child_idx)
+    end
+end
+
+# Default label function: whatever `AbstractTrees.printnode` prints for `node`, as a String.
+function string_from_node(node)
+    io = IOBuffer()
+    AbstractTrees.printnode(io, node)
+    String(take!(io))
+end
+
+# recursively build a graph of children of `tree_wrapper.root`
+@recipe function f(tree_wrapper::TreePlot; namefunc = string_from_node)
+    root = tree_wrapper.root
+    # recursively add children; the wrapped root is node 1
+    nodes = Any[root]
+    source, destiny = Int[], Int[]
+    add_children!(nodes, source, destiny, root, 1)
+
+    # set up the graphplot (all three attributes are defaults the caller may override)
+    names --> map(namefunc, nodes)
+    method --> :buchheim
+    root --> :top
+    GraphPlot((source, destiny))
+end
diff --git a/GraphRecipes/src/utils.jl b/GraphRecipes/src/utils.jl
new file mode 100644
index 000000000..37d0351db
--- /dev/null
+++ b/GraphRecipes/src/utils.jl
@@ -0,0 +1,378 @@
+"""
+This function builds a BezierCurve which leaves point p vertically upwards and
+arrives at point q vertically upwards. It may create a loop if necessary.
+It assumes the view is [0,1]. That can be modified using the `xview` and
+`yview` keyword arguments (default: `0:1`).
+"""
+function directed_curve(
+    x1,
+    x2,
+    y1,
+    y2;
+    xview = 0:1,
+    yview = 0:1,
+    root::Symbol = :bottom,
+    rng = nothing,
+)
+    if root in (:left, :right)
+        # flip x/y to simplify: from here on the curve is built as if it ran along y
+        x1, x2, y1, y2, xview, yview = y1, y2, x1, x2, yview, xview
+    end
+    # x and y collect the control points; both end up with the same length
+    x = Float64[x1, x1]
+    y = Float64[y1]
+
+    minx, maxx = extrema(xview)
+    miny, maxy = extrema(yview)
+    dist = sqrt((x2 - x1)^2 + (y2 - y1)^2)
+    flip = root in (:top, :right)
+    # a loop is needed when the target does not lie ahead of the source in the
+    # direction of travel
+    need_loop = (flip && y1 <= y2) || (!flip && y1 >= y2)
+
+    # these points give the initial/final "rise"
+    # note: this is a function of distance between points and axis scale
+    y_offset = if need_loop
+        0.3dist
+    else
+        min(0.3dist, 0.5 * abs(y2 - y1))
+    end
+    # never rise by less than 2% of the visible y range
+    y_offset = max(0.02 * (maxy - miny), y_offset)
+
+    if flip
+        # go the other direction
+        y_offset *= -1
+    end
+    push!(y, y1 + y_offset)
+
+    # try to figure out when to loop around vs just connecting straight
+    if need_loop
+        if abs(x2 - x1) > 0.1 * (maxx - minx)
+            # go between
+            sgn = x2 > x1 ? 1 : -1
+            x_offset = 0.5 * abs(x2 - x1)
+            append!(x, [x1 + sgn * x_offset, x2 - sgn * x_offset])
+        else
+            # add curve points which will create a loop; the side the loop bulges to
+            # is drawn at random from `rng`
+            x_offset =
+                0.3 *
+                (maxx - minx) *
+                (rand(rng_from_rng_or_seed(rng, nothing), Bool) ? 1 : -1)
+            append!(x, [x1 + x_offset, x2 + x_offset])
+        end
+        append!(y, [y1 + y_offset, y2 - y_offset])
+    end
+
+    append!(x, [x2, x2])
+    append!(y, [y2 - y_offset, y2])
+    if root in (:left, :right)
+        # flip x/y back to undo the swap made at the top
+        x, y = y, x
+    end
+    x, y
+end
+
+"""
+    shorten_segment(x1, y1, x2, y2, shorten)
+
+Move both endpoints of the segment [x1,y1] -> [x2,y2] towards each other by the fraction
+`shorten` of its extent in each coordinate and return the new `(x1, y1, x2, y2)`.
+"""
+function shorten_segment(x1, y1, x2, y2, shorten)
+    xshort = shorten * (x2 - x1)
+    yshort = shorten * (y2 - y1)
+    x1 + xshort, y1 + yshort, x2 - xshort, y2 - yshort
+end
+
+# """
+#     shorten_segment_absolute(x1, y1, x2, y2, shorten)
+#
+# Remove an amount `shorten` from the end of the line [x1,y1] -> [x2,y2].
+# """
+# function shorten_segment_absolute(x1, y1, x2, y2, shorten)
+#     if x1 == x2 && y1 == y2
+#         return x1, y1, x2, y2
+#     end
+#     t = shorten/sqrt(x1*(x1-2x2) + x2^2 + y1*(y1-2y2) + y2^2)
+#     x1, y1, (1.0-t)*x2 + t*x1, (1.0-t)*y2 + t*y1
+# end
+
+"""
+    nearest_intersection(xs, ys, xd, yd, vec_xy_d)
+
+Find where the line defined by [xs,ys] -> [xd,yd] intersects with the closed shape whose
+vertices are stored in `vec_xy_d`. Return `(xs, ys, xi, yi)`, where [xi,yi] is the
+intersection that is closest to the point [xs,ys] (the source node).
+
+If the two points coincide, or the line crosses none of the edges of the shape, the
+destination [xd,yd] is returned unchanged as the fallback.
+"""
+function nearest_intersection(xs, ys, xd, yd, vec_xy_d)
+    if xs == xd && ys == yd
+        return xs, ys, xd, yd
+    end
+    t = Vector{Float64}(undef, 2)
+    xvec = Vector{Float64}(undef, 2)
+    yvec = Vector{Float64}(undef, 2)
+    xy_d_edge = Vector{Float64}(undef, 2)
+    # Start from the destination itself so that a miss returns a well-defined point
+    # instead of the contents of an uninitialised buffer.
+    ret = Float64[xd, yd]
+    A = Array{Float64}(undef, 2, 2)
+    nearest = Inf
+    # consecutive vertices i, i + 1 form one edge of the shape
+    for i in 1:(length(vec_xy_d) - 1)
+        xvec .= [vec_xy_d[i][1], vec_xy_d[i + 1][1]]
+        yvec .= [vec_xy_d[i][2], vec_xy_d[i + 1][2]]
+        # Solve for the parameters at which the source->destination line meets the
+        # line through this edge; t[2] is the position along the edge (0 at vertex i,
+        # 1 at vertex i + 1). eps() * I keeps the solve from failing on parallel lines.
+        A .= [-xs+xd -xvec[1]+xvec[2]; -ys+yd -yvec[1]+yvec[2]]
+        t .= (A + eps() * I) \ [xs - xvec[1]; ys - yvec[1]]
+        xy_d_edge .=
+            [(1 - t[2]) * xvec[1] + t[2] * xvec[2], (1 - t[2]) * yvec[1] + t[2] * yvec[2]]
+        # only crossings that fall within the edge itself count
+        if 0 <= t[2] <= 1
+            tmp = abs2(xy_d_edge[1] - xs) + abs2(xy_d_edge[2] - ys)
+            if tmp < nearest
+                ret .= xy_d_edge
+                nearest = tmp
+            end
+        end
+    end
+    xs, ys, ret[1], ret[2]
+end
+
+# Circle method: no search is needed. The returned point lies at distance `vec_xy_d.r`
+# from [xd,yd] in the direction of the source [xs,ys], i.e. the circle is taken to be
+# centred on [xd,yd].
+function nearest_intersection(xs, ys, xd, yd, vec_xy_d::GeometryTypes.Circle)
+    if xs == xd && ys == yd
+        return xs, ys, xd, yd
+    end
+
+    α = atan(ys - yd, xs - xd)
+    xd = xd + vec_xy_d.r * cos(α)
+    yd = yd + vec_xy_d.r * sin(α)
+
+    xs, ys, xd, yd
+end
+
+# 3d placeholder: the body is empty, so this method currently returns `nothing`.
+function nearest_intersection(xs, ys, zs, xd, yd, zd, vec_xyz_d)
+    # TODO make 3d work.
+end
+
+"""
+Randomly pick a point to be the center control point of a bezier curve,
+which is both equidistant between the endpoints and uniformly distributed
+around the midpoint.
+"""
+function random_control_point(
+    xi,
+    xj,
+    yi,
+    yj,
+    curvature_scalar;
+    rng = nothing,
+)
+    # Resolve the generator inside the body. A keyword default cannot refer to the
+    # keyword it is defining, so the former default `rng_from_rng_or_seed(rng, nothing)`
+    # failed with an UndefVarError whenever `rng` was omitted. An explicitly passed
+    # generator is used as before (same call as in `directed_curve`).
+    generator = rng_from_rng_or_seed(rng, nothing)
+
+    xmid = 0.5 * (xi + xj)
+    ymid = 0.5 * (yi + yj)
+
+    # direction perpendicular to the segment: its slope angle turned by a quarter turn
+    theta = atan((yj - yi) / (xj - xi)) + 0.5pi
+
+    # calc random shift relative to dist between x and y
+    dist = sqrt((xj - xi)^2 + (yj - yi)^2)
+    dist_from_mid = curvature_scalar * (rand(generator) - 0.5) * dist
+
+    # now we have polar coords, we can compute the position, adding to the midpoint
+    (xmid + dist_from_mid * cos(theta), ymid + dist_from_mid * sin(theta))
+end
+
+"""
+    control_point(xi, xj, yi, yj, dist_from_mid)
+
+Return the point that lies `dist_from_mid` away from the midpoint of the segment
+[xi,yi] -> [xj,yj], measured perpendicular to the segment. Note the argument order
+(both x values first). The result contains `NaN` when the two endpoints coincide.
+"""
+function control_point(xi, xj, yi, yj, dist_from_mid)
+    xmid = 0.5 * (xi + xj)
+    ymid = 0.5 * (yi + yj)
+
+    # direction perpendicular to the segment: its slope angle turned by a quarter turn
+    theta = atan((yj - yi) / (xj - xi)) + 0.5pi
+
+    # dist = sqrt((xj-xi)^2 + (yj-yi)^2)
+    # dist_from_mid = curvature_scalar * 0.5dist
+
+    # now we have polar coords, we can compute the position, adding to the midpoint
+    (xmid + dist_from_mid * cos(theta), ymid + dist_from_mid * sin(theta))
+end
+
+"""
+    annotation_extent(p, annotation; width_scalar = 0.06, height_scalar = 0.096)
+
+Estimate the box occupied by `annotation = (x, y, label, fontsize)`. Returns
+`[xextent, yextent]`, each a two-element vector `[lower, upper]` centred on the annotation
+position. The half-width grows with `length(string(label))^0.8`; both half-extents scale
+with the font size and inversely with the plot size read from `p[:size]` (default
+`(600, 400)`).
+"""
+function annotation_extent(p, annotation; width_scalar = 0.06, height_scalar = 0.096)
+    str = string(annotation[3])
+    position = annotation[1:2]
+    plot_size = get(p, :size, (600, 400))
+    fontsize = annotation[4]
+    xextent_length = width_scalar * (600 / plot_size[1]) * fontsize * length(str)^0.8
+    xextent = [position[1] - xextent_length, position[1] + xextent_length]
+    yextent_length = height_scalar * (400 / plot_size[2]) * fontsize
+    yextent = [position[2] - yextent_length, position[2] + yextent_length]
+
+    [xextent, yextent]
+end
+
+# Separation of two angles measured the short way round: for |angle1 - angle2| <= 2pi
+# this is min(d, 2pi - d) with d = |angle1 - angle2|.
+clockwise_difference(angle1, angle2) = pi - abs(abs(angle1 - angle2) - pi)
+
+# Mean of the two entries of `angles`, moved by half a turn when the short-way
+# separation exceeds `angles[2] - angles[1]` (e.g. when the pair is given as
+# (larger, smaller), as done for the wrap-around gap in `unoccupied_angle`).
+function clockwise_mean(angles)
+    if clockwise_difference(angles[2], angles[1]) > angles[2] - angles[1]
+        return mean(angles) + pi
+    else
+        return mean(angles)
+    end
+end
+
+"""
+    unoccupied_angle(x1, y1, x, y)
+
+Starting from the point [x1,y1], find the angle theta such that a line leaving at an angle
+theta will have maximum distance from the
points [x[i],y[i]]. If there are no other points, every direction is free and `pi / 2` is
+returned.
+"""
+function unoccupied_angle(x1, y1, x, y)
+    @assert length(x) == length(y)
+
+    # No other points (e.g. the self-loop of a single-node graph): nothing constrains
+    # the direction, so return a fixed angle instead of indexing into an empty array.
+    if isempty(x)
+        return pi / 2
+    end
+
+    # A single other point: head directly away from it.
+    if length(x) == 1
+        return atan(y[1] - y1, x[1] - x1) + pi
+    end
+
+    max_range = zeros(2)
+    # Calculate all angles between the point [x1,y1] and all points [x[i],y[i]], make sure
+    # that all of the angles are between 0 and 2pi
+    angles = [atan(y[i] - y1, x[i] - x1) for i in 1:length(x)]
+    for i in 1:length(angles)
+        if angles[i] < 0
+            angles[i] += 2pi
+        end
+    end
+    # Sort all of the angles and calculate which two angles subtend the largest gap.
+    sort!(angles)
+    # start with the gap that wraps around from the largest angle to the smallest
+    max_range .= [angles[end], angles[1]]
+    for i in 2:length(x)
+        if (
+            clockwise_difference(angles[i], angles[i - 1]) >
+            clockwise_difference(max_range[2], max_range[1])
+        )
+            max_range .= [angles[i - 1], angles[i]]
+        end
+    end
+    # Return the angle that is in the middle of the two angles subtending the largest
+    # empty angle.
+    clockwise_mean(max_range)
+end
+
+# Turn an edge attribute into one value per edge (a 1 x n_edges row), in the order of
+# `zip(source, destiny)`. `nothing` and Symbols are passed through untouched; functions
+# are called as `attr(si, di, wi)`; Dicts are looked up by `(si, di)`; arrays with no
+# dimension of length 1 are indexed as `attr[si, di]`. Anything else is returned as is.
+# NOTE(review): the graph branch indexes `incidence_matrix(attr)` with (source, destiny)
+# vertex pairs - confirm that this is intended rather than an adjacency-style lookup.
+function process_edge_attribute(attr, source, destiny, weights)
+    if isnothing(attr) || (attr isa Symbol)
+        return attr
+    elseif attr isa Graphs.AbstractGraph
+        mat = incidence_matrix(attr)
+        attr = [mat[si, di] for (si, di) in zip(source, destiny)][:] |> permutedims
+    elseif attr isa Function
+        attr =
+            [
+                attr(si, di, wi) for
+                (i, (si, di, wi)) in enumerate(zip(source, destiny, weights))
+            ][:] |> permutedims
+    elseif attr isa Dict
+        attr = [attr[(si, di)] for (si, di) in zip(source, destiny)][:] |> permutedims
+    elseif all(size(attr) .!= 1)
+        attr = [attr[si, di] for (si, di) in zip(source, destiny)][:] |> permutedims
+    end
+    attr
+end
+# Function from Plots/src/components.jl
+"get an array of tuples of points on a circle with radius `r`"
+function partialcircle(start_θ, end_θ, n = 20, r = 1)
+    Tuple{Float64,Float64}[
+        (r * cos(u), r * sin(u)) for u in range(start_θ, stop = end_θ, length = n)
+    ]
+end
+
+# Same as above, but with the points shifted so the circle is centred on `circle_center`.
+function partialcircle(start_θ, end_θ, circle_center::Array{T,1}, n = 20, r = 1) where {T}
+    Tuple{Float64,Float64}[
+        (r * cos(u) +
circle_center[1], r * sin(u) + circle_center[2]) for
+        u in range(start_θ, stop = end_θ, length = n)
+    ]
+end
+
+# `n` points on an ellipse centred on the origin between the angles `start_θ` and
+# `end_θ`; `major_axis` scales the x coordinate and `minor_axis` the y coordinate.
+function partialellipse(start_θ, end_θ, n = 20, major_axis = 2, minor_axis = 1)
+    Tuple{Float64,Float64}[
+        (major_axis * cos(u), minor_axis * sin(u)) for
+        u in range(start_θ, stop = end_θ, length = n)
+    ]
+end
+
+# Same as above, but with the points shifted so the ellipse is centred on `ellipse_center`.
+function partialellipse(
+    start_θ,
+    end_θ,
+    ellipse_center::Array{T,1},
+    n = 20,
+    major_axis = 2,
+    minor_axis = 1,
+) where {T}
+    Tuple{Float64,Float64}[
+        (major_axis * cos(u) + ellipse_center[1], minor_axis * sin(u) + ellipse_center[2])
+        for u in range(start_θ, stop = end_θ, length = n)
+    ]
+end
+
+# for chord diagrams: outline of a ring segment between the angles θ1 and θ2, made of an
+# outer arc (radius 1.05) followed by the inner arc (radius 0.95) traversed in reverse
+function arcshape(θ1, θ2)
+    vcat(partialcircle(θ1, θ2, 15, 1.05), reverse(partialcircle(θ1, θ2, 15, 0.95)))
+end
+
+# x and y limits for an arc diagram, returned as `(xlims, ylims)`. The margin is a tenth
+# of the spacing between the first two x values, and r is half of the x range. The y
+# limits cover only the upper half when every edge has source < destiny, only the lower
+# half when none has, and both halves otherwise.
+function arcdiagram_limits(x, source, destiny)
+    @assert length(x) >= 2
+    margin = abs(0.1 * (x[2] - x[1]))
+    xmin, xmax = extrema(x)
+    r = abs(0.5 * (xmax - xmin))
+    # fraction of edges that point from a lower to a higher node index
+    mean_upside = mean(source .< destiny)
+    ylims = if mean_upside == 1.0
+        (-margin, r + margin)
+    elseif mean_upside == 0.0
+        (-r - margin, margin)
+    else
+        (-r - margin, r + margin)
+    end
+    (xmin - margin, xmax + margin), ylims
+end
+
+# Whether `item` should be shown as an edge label: `missing`, floating point NaN,
+# `nothing`, `false` and the empty string are not labels. The membership test compares
+# with `==`, so a numeric zero (which equals `false`) is rejected as well.
+function islabel(item)
+    ismissing(item) && return false
+    ((item isa AbstractFloat) && isnan(item)) && return false
+    !in(item, (nothing, false, ""))
+end
+
+# Value to use for attribute `sym`: `name` unless one of the aliases listed in
+# `graph_aliases[sym]` is present in `plotattributes`, in which case the value stored
+# under the last such alias wins.
+function replacement_kwarg(sym, name, plotattributes, graph_aliases)
+    replacement = name
+    for alias in graph_aliases[sym]
+        if haskey(plotattributes, alias)
+            replacement = plotattributes[alias]
+        end
+    end
+    replacement
+end
+
+# Expands to one assignment `sym = replacement_kwarg(:sym, sym, plotattributes,
+# graph_aliases)` per key of the alias table. `graph_aliases` must be the name of a
+# binding in the calling module, because its keys are read at macro-expansion time.
+macro process_aliases(plotattributes, graph_aliases)
+    ex = Expr(:block)
+    attributes = getfield(__module__, graph_aliases) |> keys
+    ex.args = [
+        Expr(
+            :(=),
+            esc(sym),
+            :($(esc(replacement_kwarg))(
+                $(QuoteNode(sym)),
+                $(esc(sym)),
+                $(esc(plotattributes)),
+                $(esc(graph_aliases)),
+            )),
+        ) for sym in attributes
+    ]
+    ex
+end
+
+# Delete from `plotattributes` every alias listed for `sym` in `graph_aliases`.
+remove_aliases!(sym, plotattributes, graph_aliases) =
+    for alias in graph_aliases[sym]
+        if haskey(plotattributes, alias)
+            delete!(plotattributes, alias)
+        end
+    end
+
+# From Plots/src/utils.jl
+# NOTE(review): no `import Base: isnothing` is visible here, so these two lines presumably
+# define a module-local `isnothing` rather than extending the one in Base - confirm.
+isnothing(x::Nothing) = true
+isnothing(x) = false
+
+# From Plots/src/Plots.jl
+# generic fallback: plain extrema
+ignorenan_extrema(x) = Base.extrema(x)
+# From Plots/src/utils.jl
+# floating point arrays go through NaNMath instead
+ignorenan_extrema(x::AbstractArray{F}) where {F<:AbstractFloat} = NaNMath.extrema(x)
+# From Plots/src/components.jl
+# Extrema of `v`, widened on both sides by `buffmult` times the range of `v`. When all
+# values are equal the range is treated as 1, so the result is never a zero-width interval
+# for a non-zero `buffmult`. Uses plain `extrema`, not `ignorenan_extrema`.
+function extrema_plus_buffer(v, buffmult = 0.2)
+    vmin, vmax = extrema(v)
+    vdiff = vmax - vmin
+    zero_buffer = vdiff == 0 ? 1.0 : 0.0
+    buffer = (vdiff + zero_buffer) * buffmult
+    vmin - buffer, vmax + buffer
+end
diff --git a/GraphRecipes/test/functions.jl b/GraphRecipes/test/functions.jl
new file mode 100644
index 000000000..140c6a686
--- /dev/null
+++ b/GraphRecipes/test/functions.jl
@@ -0,0 +1,284 @@
+using Plots
+using StableRNGs
+using GraphRecipes
+using GraphRecipes.Colors
+using GraphRecipes.AbstractTrees
+# Test fixture: plot of a random weighted graph on 15 labelled nodes with fixed node
+# positions. Returns the plot together with the data it was built from.
+function random_labelled_graph()
+    n = 15
+    rng = StableRNG(1)
+    A = Float64[rand(rng) < 0.5 ?
0 : rand(rng) for i in 1:n, j in 1:n] + for i in 1:n + A[i, 1:(i - 1)] = A[1:(i - 1), i] + A[i, i] = 0 + end + x = rand(rng, n) + y = rand(rng, n) + z = rand(rng, n) + p = graphplot( + A, + nodesize = 0.2, + node_weights = 1:n, + nodecolor = range(colorant"yellow", stop = colorant"red", length = n), + names = 1:n, + fontsize = 10, + linecolor = :darkgrey, + layout_kw = Dict(:x => x, :y => y), + rng = rng, + ) + p, n, A, x, y, z +end + +function random_3d_graph() + n, A, x, y, z = random_labelled_graph()[2:end] + graphplot( + A, + node_weights = 1:n, + markercolor = :darkgray, + dim = 3, + markersize = 5, + markershape = :circle, + linecolor = :darkgrey, + linealpha = 0.5, + layout_kw = Dict(:x => x, :y => y, :z => z), + rng = StableRNG(1), + ) +end + +function light_graphs() + g = wheel_graph(10) + graphplot(g, curves = false, rng = StableRNG(1)) +end + +function directed() + g = [ + 0 1 1 + 0 0 1 + 0 1 0 + ] + graphplot(g, names = 1:3, curvature_scalar = 0.1, rng = StableRNG(1)) +end + +function edgelabel() + n = 8 + g = wheel_digraph(n) + edgelabel_dict = Dict() + for i in 1:n + for j in 1:n + edgelabel_dict[(i, j)] = string("edge ", i, " to ", j) + end + end + + graphplot( + g, + names = 1:n, + edgelabel = edgelabel_dict, + curves = false, + nodeshape = :rect, + rng = StableRNG(1), + ) +end + +function selfedges() + g = [ + 1 1 1 + 0 0 1 + 0 0 1 + ] + graphplot(DiGraph(g), self_edge_size = 0.2, rng = StableRNG(1)) +end + +multigraphs() = graphplot( + [[1, 1, 2, 2], [1, 1, 1], [1]], + names = "node_" .* string.(1:3), + nodeshape = :circle, + self_edge_size = 0.25, + rng = StableRNG(1), +) + +function arc_chord_diagrams() + rng = StableRNG(1) + adjmat = Symmetric(sparse(rand(rng, 0:1, 8, 8))) + plot( + graphplot( + adjmat, + method = :chorddiagram, + names = [text(string(i), 8) for i in 1:8], + linecolor = :black, + fillcolor = :lightgray, + rng = rng, + ), + graphplot( + adjmat, + method = :arcdiagram, + markersize = 0.5, + markershape = :circle, + linecolor = 
:black, + markercolor = :black, + rng = rng, + ), + ) +end + +function marker_properties() + N = 8 + seed = 42 + rng = StableRNG(seed) + g = barabasi_albert(N, 1; rng = rng) + weights = [length(neighbors(g, i)) for i in 1:nv(g)] + graphplot( + g, + curvature_scalar = 0, + node_weights = weights, + nodesize = 0.25, + linecolor = :gray, + linewidth = 2.5, + nodeshape = :circle, + node_z = rand(rng, N), + markercolor = :viridis, + nodestrokewidth = 1.5, + markerstrokestyle = :solid, + markerstrokealpha = 1.0, + markerstrokecolor = :lightgray, + colorbar = true, + rng = rng, + ) +end + +function ast_example() + code = :(function mysum(list) + out = 0 + for value in list + out += value + end + out + end) + plot( + code, + fontsize = 10, + shorten = 0.01, + axis_buffer = 0.15, + nodeshape = :rect, + size = (1000, 1000), + rng = StableRNG(1), + ) +end + +julia_type_tree() = plot( + AbstractFloat, + method = :tree, + fontsize = 10, + nodeshape = :ellipse, + size = (1000, 1000), + rng = StableRNG(1), +) + +AbstractTrees.children(d::Dict) = [p for p in d] +AbstractTrees.children(p::Pair) = AbstractTrees.children(p[2]) +function AbstractTrees.printnode(io::IO, p::Pair) + str = + isempty(AbstractTrees.children(p[2])) ? 
string(p[1], ": ", p[2]) : + string(p[1], ": ") + print(io, str) +end + +function julia_dict_tree() + d = Dict(:a => 2, :d => Dict(:b => 4, :c => "Hello"), :e => 5.0) + plot( + TreePlot(d), + method = :tree, + fontsize = 10, + nodeshape = :ellipse, + size = (1000, 1000), + rng = StableRNG(1), + ) +end + +diamond_nodeshape(x_i, y_i, s) = [ + (x_i + 0.5s * dx, y_i + 0.5s * dy) for (dx, dy) in [(1, 1), (-1, 1), (-1, -1), (1, -1)] +] + +function diamond_nodeshape_wh(x_i, y_i, h, w) + out = Tuple{Float64,Float64}[(-0.5, 0), (0, -0.5), (0.5, 0), (0, 0.5)] + map(out) do t + x = t[1] * h + y = t[2] * w + (x + x_i, y + y_i) + end +end + +function custom_nodeshapes_single() + rng = StableRNG(1) + g = rand(rng, 5, 5) + g[g .> 0.5] .= 0 + for i in 1:5 + g[i, i] = 0 + end + graphplot(g, nodeshape = diamond_nodeshape, rng = rng) +end + +function custom_nodeshapes_various() + rng = StableRNG(1) + g = rand(rng, 5, 5) + g[g .> 0.5] .= 0 + for i in 1:5 + g[i, i] = 0 + end + graphplot( + g, + nodeshape = [ + :circle, + diamond_nodeshape, + diamond_nodeshape_wh, + :hexagon, + diamond_nodeshape_wh, + ], + rng = rng, + ) +end + +function funky_edge_and_marker_args() + n = 5 + g = SimpleDiGraph(n) + + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 3, 4) + add_edge!(g, 4, 4) + add_edge!(g, 4, 5) + + curviness_matrix = zeros(n, n) + edgewidth_matrix = zeros(n, n) + edgestyle_dict = Dict() + for e in edges(g) + curviness_matrix[e.src, e.dst] = 0.5sin(e.src) + edgewidth_matrix[e.src, e.dst] = 0.8e.dst + edgestyle_dict[(e.src, e.dst)] = e.src < 2.0 ? :solid : e.src > 3.0 ? 
:dash : :dot + end + edge_z_function = (s, d, w) -> curviness_matrix[s, d] + + graphplot( + g, + names = [" I ", " am ", " a ", "funky", "graph"], + x = [1, 2, 3, 4, 5], + y = [5, 4, 3, 2, 1], + nodesize = 0.3, + size = (1000, 1000), + axis_buffer = 0.6, + fontsize = 16, + self_edge_size = 1.3, + curvature_scalar = curviness_matrix, + edgestyle = edgestyle_dict, + edgewidth = edgewidth_matrix, + edge_z = edge_z_function, + nodeshape = :circle, + node_z = [1, 2, 3, 4, 5], + nodestroke_z = [5, 4, 3, 2, 1], + edgecolor = :viridis, + markercolor = :viridis, + nodestrokestyle = [:dash, :solid, :dot], + nodestrokewidth = 6, + linewidth = 2, + colorbar = true, + rng = StableRNG(1), + ) +end diff --git a/GraphRecipes/test/parse_readme.jl b/GraphRecipes/test/parse_readme.jl new file mode 100644 index 000000000..bcac9e179 --- /dev/null +++ b/GraphRecipes/test/parse_readme.jl @@ -0,0 +1,21 @@ +using Markdown +using GraphRecipes +using Plots + +cd(@__DIR__) + +readme = read("../README.md", String) |> Markdown.parse +content = readme.content + +code_blocks = [] +for paragraph in content + if paragraph isa Markdown.Code + push!(code_blocks, paragraph.code) + end +end + +# Parse the code examples on the README into expressions. Ignore the first one, which is +# the installation instructions. 
+readme_exprs = [Meta.parse("begin $(code_blocks[i]) end") for i in 2:length(code_blocks)] + +julia_logo_pun() = eval(readme_exprs[1]) diff --git a/GraphRecipes/test/pkg_deps.jl b/GraphRecipes/test/pkg_deps.jl new file mode 100644 index 000000000..007b89df4 --- /dev/null +++ b/GraphRecipes/test/pkg_deps.jl @@ -0,0 +1,116 @@ + +module PkgDeps + +using GraphRecipes + +# const _pkgs = Pkg.available() +# const _idxmap = Dict(p=>i for (i,p) in enumerate(_pkgs)) +# const _alist = [Int[] for i=1:length(_pkgs)] + +# for pkg in _pkgs +# i = _idxmap[pkg] +# for dep in Pkg.dependents(pkg) +# if !haskey(_idxmap, dep) +# push!(_pkgs, dep) +# push!(_alist, []) +# _idxmap[dep] = length(_pkgs) +# end +# j = _idxmap[dep] +# push!(_alist[j], i) +# end +# end + +@userplot DepsGraph +@recipe function f(g::DepsGraph) + source, destiny, names = g.args + arrow --> arrow() + markersize --> 50 + markeralpha --> 0.2 + linealpha --> 0.4 + linewidth --> 2 + names --> names + func --> :tree + root --> :left + GraphRecipes.GraphPlot((source, destiny)) +end + +# const args = (source, destiny, pkgs) + +const all_pkgs = Pkg.available() +@show all_pkgs +const deplists = Dict(pkg => Pkg.dependents(pkg) for pkg in all_pkgs) +@show deplists + +const childlists = Dict(pkg => Set{String}() for pkg in all_pkgs) +for pkg in all_pkgs + for dep in deplists[pkg] + if haskey(childlists, dep) + push!(childlists[dep], pkg) + else + warn("Package $dep wasn't in Pkg.available()") + deplists[dep] = [] + childlists[dep] = Set([pkg]) + end + end +end +@show childlists + +function add_deps(pkg, deps = Set([pkg])) + for dep in deplists[pkg] + if !(dep in deps) + push!(deps, dep) + add_deps(dep, deps) + end + end + deps +end + +function add_children(pkg, children = Set([pkg])) + for child in childlists[pkg] + if !(child in children) + push!(children, child) + add_children(child, children) + end + end + children +end + +function plotdeps(pkg) + pkgs = unique(union(add_deps(pkg), add_children(pkg))) + idxmap = Dict(p => 
i for (i, p) in enumerate(pkgs)) + + source, destiny = Int[], Int[] + for pkg in pkgs + i = idxmap[pkg] + for dep in deplists[pkg] + # if !haskey(_idxmap, dep) + # push!(pkgs, dep) + # push!(_alist, []) + # _idxmap[dep] = length(pkgs) + # end + if !haskey(idxmap, dep) + warn("missing: ", dep) + continue + end + j = idxmap[dep] + push!(source, j) + push!(destiny, i) + # push!(_alist[j], i) + end + end + depsgraph(source, destiny, pkgs, root = :bottom) +end + +# # pkgs = Set([pkg]) +# idx = _idxmap[pkg] +# source, destiny = Int[], Int[] +# for j in _alist[i] +# push!(pkgs, _pkgs[j]) +# push!(source, j) +# push!(destiny, i) +# end + +# to use: +# depsgraph(PkgDeps.args...) + +end # module diff --git a/GraphRecipes/test/runtests.jl b/GraphRecipes/test/runtests.jl new file mode 100644 index 000000000..295cabf36 --- /dev/null +++ b/GraphRecipes/test/runtests.jl @@ -0,0 +1,189 @@ +using VisualRegressionTests +using AbstractTrees +using LinearAlgebra +using Logging +using GraphRecipes +using SparseArrays +using ImageMagick +using StableRNGs +using Graphs +using Plots +using Test +using Gtk # for popup + +import Plots: PlotsBase + +isci() = get(ENV, "CI", "false") == "true" +itol(tol = nothing) = something(tol, isci() ? 
1e-3 : 1e-5) + +include("functions.jl") +include("parse_readme.jl") + +default(show = false, reuse = true) + +cd(joinpath(@__DIR__, "..", "assets")) do + @testset "FIGURES" begin + @plottest random_labelled_graph() "random_labelled_graph.png" popup = !isci() tol = + itol() + + @plottest random_3d_graph() "random_3d_graph.png" popup = !isci() tol = itol() + + @plottest light_graphs() "light_graphs.png" popup = !isci() tol = itol() + + @plottest directed() "directed.png" popup = !isci() tol = itol() + + @plottest marker_properties() "marker_properties.png" popup = !isci() tol = itol() + + @plottest edgelabel() "edgelabel.png" popup = !isci() tol = itol() + + @plottest selfedges() "selfedges.png" popup = !isci() tol = itol() + + @plottest multigraphs() "multigraphs.png" popup = !isci() tol = itol() + + @plottest arc_chord_diagrams() "arc_chord_diagrams.png" popup = !isci() tol = itol() + + @plottest ast_example() "ast_example.png" popup = !isci() tol = itol() + + if !(v"1.6" < VERSION < v"1.7") # having Static.jl in the Manifest adds another type + @plottest julia_type_tree() "julia_type_tree.png" popup = !isci() tol = itol() + end + @plottest julia_dict_tree() "julia_dict_tree.png" popup = !isci() tol = itol() + + @plottest funky_edge_and_marker_args() "funky_edge_and_marker_args.png" popup = + !isci() tol = itol() + + @plottest custom_nodeshapes_single() "custom_nodeshapes_single.png" popup = !isci() tol = + itol() + + @plottest custom_nodeshapes_various() "custom_nodeshapes_various.png" popup = + !isci() tol = itol() + end + + @testset "README" begin + @plottest julia_logo_pun() "readme_julia_logo_pun.png" popup = !isci() tol = itol() + end +end + +@testset "issues" begin + @testset "143" begin + g = SimpleGraph(7) + + add_edge!(g, 2, 3) + add_edge!(g, 3, 4) + @test g.ne == 2 + al = GraphRecipes.get_adjacency_list(g) + @test isempty(al[1]) + @test al[2] == [3] + @test al[3] == [2, 4] + @test al[4] == [3] + @test isempty(al[5]) + @test isempty(al[6]) + @test 
isempty(al[7]) + s, d, w = GraphRecipes.get_source_destiny_weight(al) + @test s == [2, 3, 3, 4] + @test d == [3, 2, 4, 3] + @test all(w .≈ 1) + + with_logger(ConsoleLogger(stderr, Logging.Debug)) do + pl = graphplot(g) + @test first(pl.series_list)[:extra_kwargs][:num_edges_nodes] == (2, 7) + + add_edge!(g, 6, 7) + @test g.ne == 3 + pl = graphplot(g) + @test first(pl.series_list)[:extra_kwargs][:num_edges_nodes] == (3, 7) + + # old behavior (see issue), can be recovered using `trim=true` + g = SimpleGraph(7) + add_edge!(g, 2, 3) + add_edge!(g, 3, 4) + pl = graphplot(g; trim = true) + @test first(pl.series_list)[:extra_kwargs][:num_edges_nodes] == (2, 4) + end + end + + @testset "180" begin + rng = StableRNG(1) + mat = Symmetric(sparse(rand(rng, 0:1, 8, 8))) + graphplot(mat, method = :arcdiagram, rng = rng) + end +end + +@testset "utils.jl" begin + rng = StableRNG(1) + @test GraphRecipes.directed_curve(0.0, 1.0, 0.0, 1.0, rng = rng) == + GraphRecipes.directed_curve(0, 1, 0, 1, rng = rng) + + @test GraphRecipes.isnothing(nothing) == PlotsBase.isnothing(nothing) + @test GraphRecipes.isnothing(missing) == PlotsBase.isnothing(missing) + @test GraphRecipes.isnothing(NaN) == PlotsBase.isnothing(NaN) + @test GraphRecipes.isnothing(0) == PlotsBase.isnothing(0) + @test GraphRecipes.isnothing(1) == PlotsBase.isnothing(1) + @test GraphRecipes.isnothing(0.0) == PlotsBase.isnothing(0.0) + @test GraphRecipes.isnothing(1.0) == PlotsBase.isnothing(1.0) + + for (s, e) in [(rand(rng), rand(rng)) for i in 1:100] + @test GraphRecipes.partialcircle(s, e) == PlotsBase.partialcircle(s, e) + end + + @testset "nearest_intersection" begin + @test GraphRecipes.nearest_intersection(0, 0, 3, 3, [(1, 0), (0, 1)]) == + (0, 0, 0.5, 0.5) + @test GraphRecipes.nearest_intersection(1, 2, 1, 2, []) == (1, 2, 1, 2) + end + + @testset "unoccupied_angle" begin + @test GraphRecipes.unoccupied_angle(1, 1, [1, 1, 1, 1], [2, 0, 3, -1]) == 2pi + end + + @testset "islabel" begin + @test 
GraphRecipes.islabel("hi") + @test GraphRecipes.islabel(1) + @test !GraphRecipes.islabel(missing) + @test !GraphRecipes.islabel(NaN) + @test !GraphRecipes.islabel(false) + @test !GraphRecipes.islabel("") + end + + @testset "control_point" begin + @test GraphRecipes.control_point(0, 0, 6, 0, 4) == (4, 3) + end + + # TODO: Actually test that the aliases produce the same plots, rather than just + # checking that they don't error. Also, test all of the different aliases. + @testset "Aliases" begin + A = [1 0 1 0; 0 0 1 1; 1 1 1 1; 0 0 1 1] + graphplot(A, markercolor = :red, markershape = :rect, markersize = 0.5, rng = rng) + graphplot(A, nodeweights = 1:4, rng = rng) + graphplot(A, curvaturescalar = 0, rng = rng) + graphplot(A, el = Dict((1, 2) => ""), elb = true, rng = rng) + graphplot(A, ew = (s, d, w) -> 3, rng = rng) + graphplot(A, ses = 0.5, rng = rng) + end +end + +# ----------------------------------------- +# marginalhist + +# using Distributions +# n = 1000 +# x = rand(RNG, Gamma(2), n) +# y = -0.5x + randn(RNG, n) +# marginalhist(x, y) + +# ----------------------------------------- +# portfolio composition map + +# # fake data +# tickers = ["IBM", "Google", "Apple", "Intel"] +# N = 10 +# D = length(tickers) +# weights = rand(RNG, N, D) +# weights ./= sum(weights, 2) +# returns = sort!((1:N) + D*randn(RNG, N)) + +# # plot it +# portfoliocomposition(weights, returns, labels = tickers') + +# ----------------------------------------- +# diff --git a/StatsPlots/Project.toml b/StatsPlots/Project.toml new file mode 100644 index 000000000..c0c4a6d62 --- /dev/null +++ b/StatsPlots/Project.toml @@ -0,0 +1,51 @@ +name = "StatsPlots" +uuid = "f3b207a7-027a-5e70-b257-86293d7955fd" +version = "1.0" + +[deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Interpolations = 
"a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +Observables = "510215fc-4207-5dde-b226-833fc4488ee2" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +RecipesPipeline = "01d81517-befc-4cb6-b9ec-a95719d0359c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +Widgets = "cc8bc4a8-27d6-5769-a93b-9d913e69aa62" + +[compat] +AbstractFFTs = "1.1" +Clustering = "0.13 - 0.15" +DataStructures = "0.17 - 0.18" +Distributions = "0.21 - 0.25" +Interpolations = "0.12 - 0.15" +KernelDensity = "0.5 - 0.6" +MultivariateStats = "0.9 - 0.10" +NaNMath = "1" +Observables = "0.3 - 0.5" +Plots = "2" +RecipesBase = "1" +RecipesPipeline = "1" +Reexport = "0.2, 1" +StatsBase = "0.32 - 0.34" +TableOperations = "1" +Tables = "1" +Widgets = "0.5 - 0.6" +julia = "1.10" + +[extras] +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test", "NaNMath", "StableRNGs"] diff --git a/StatsPlots/src/StatsPlots.jl b/StatsPlots/src/StatsPlots.jl new file mode 100644 index 000000000..bfd7daef9 --- /dev/null +++ b/StatsPlots/src/StatsPlots.jl @@ -0,0 +1,55 @@ +module StatsPlots + +using Reexport +import RecipesBase: recipetype +import Tables +import TableOperations +using RecipesPipeline +@reexport using Plots +import Plots: _cycle +using StatsBase +using Distributions +using LinearAlgebra: eigen, diagm +using Widgets, Observables +import Observables: AbstractObservable, @map, observe +import Widgets: @nodeps +import DataStructures: OrderedDict +import Clustering: 
Hclust, nnodes +using Interpolations +using MultivariateStats: MultivariateStats +using AbstractFFTs: fft, ifft +import KernelDensity +using NaNMath +@recipe f(k::KernelDensity.UnivariateKDE) = k.x, k.density +@recipe f(k::KernelDensity.BivariateKDE) = k.x, k.y, permutedims(k.density) + +@shorthands cdensity + +export @df, dataviewer + +isvertical(plotattributes) = + let val = get(plotattributes, :orientation, missing) + val === missing || val in (:vertical, :v) + end + +include("df.jl") +include("interact.jl") +include("corrplot.jl") +include("cornerplot.jl") +include("distributions.jl") +include("boxplot.jl") +include("dotplot.jl") +include("violin.jl") +include("ecdf.jl") +include("hist.jl") +include("marginalhist.jl") +include("marginalscatter.jl") +include("marginalkde.jl") +include("bar.jl") +include("dendrogram.jl") +include("andrews.jl") +include("ordinations.jl") +include("covellipse.jl") +include("errorline.jl") + +end # module diff --git a/StatsPlots/src/andrews.jl b/StatsPlots/src/andrews.jl new file mode 100644 index 000000000..fc1204ed5 --- /dev/null +++ b/StatsPlots/src/andrews.jl @@ -0,0 +1,63 @@ +@userplot AndrewsPlot + +""" + andrewsplot(args...; kw...) +Shows each row of an array (or table) as a line. The `x` argument specifies a +grouping variable. This is a way to visualize structure in high-dimensional data. 
+https://en.wikipedia.org/wiki/Andrews_plot +#Examples +```julia +using RDatasets, StatsPlots +iris = dataset("datasets", "iris") +@df iris andrewsplot(:Species, cols(1:4)) +``` +""" +andrewsplot + +@recipe function f(h::AndrewsPlot) + if length(h.args) == 2 # specify x if not given + x, y = h.args + else + y = h.args[1] + x = ones(size(y, 1)) + end + + seriestype := :andrews + + # series in a user recipe will have different colors + for g in unique(x) + @series begin + label := "$g" + range(-π, stop = π, length = 200), Surface(y[g .== x, :]) #surface needed, or the array will be split into columns + end + end + nothing +end + +# the series recipe +@recipe function f(::Type{Val{:andrews}}, x, y, z) + y = y.surf + rows, cols = size(y) + seriestype := :path + + # these series are the lines, will keep the same colors + for j = 1:rows + @series begin + primary := false + ys = zeros(length(x)) + terms = + [isodd(i) ? cos((i ÷ 2) .* ti) : sin((i ÷ 2) .* ti) for i = 2:cols, ti in x] + for ti in eachindex(x) + ys[ti] = y[j, 1] / sqrt(2) + sum(y[j, i] .* terms[i - 1, ti] for i = 2:cols) + end + + x := x + y := ys + () + end + end + + x := [] + y := [] + () +end diff --git a/StatsPlots/src/bar.jl b/StatsPlots/src/bar.jl new file mode 100644 index 000000000..dc47a3634 --- /dev/null +++ b/StatsPlots/src/bar.jl @@ -0,0 +1,97 @@ +@userplot GroupedBar + +recipetype(::Val{:groupedbar}, args...) = GroupedBar(args) + +PlotsBase.group_as_matrix(g::GroupedBar) = true + +grouped_xy(x::AbstractVector, y::AbstractArray) = x, y +grouped_xy(y::AbstractArray) = 1:size(y, 1), y + +@recipe function f(g::GroupedBar; spacing = 0) + x, y = grouped_xy(g.args...) + + nr, nc = size(y) + isstack = pop!(plotattributes, :bar_position, :dodge) === :stack + isylog = pop!(plotattributes, :yscale, :identity) ∈ (:log10, :log) + the_ylims = pop!(plotattributes, :ylims, (-Inf, Inf)) + + # extract xnums and set default bar width. 
+ # might need to set xticks as well + xnums = if eltype(x) <: Number + xdiff = length(x) > 1 ? mean(diff(x)) : 1 + bar_width --> 0.8 * xdiff + x + else + bar_width --> 0.8 + ux = unique(x) + xnums = (1:length(ux)) .- 0.5 + xticks --> (xnums, ux) + xnums + end + @assert length(xnums) == nr + + # compute the x centers. for dodge, make a matrix for each column + x = if isstack + x + else + bws = plotattributes[:bar_width] / nc + bar_width := bws * clamp(1 - spacing, 0, 1) + xmat = zeros(nr, nc) + for r = 1:nr + bw = _cycle(bws, r) + farleft = xnums[r] - 0.5 * (bw * nc) + for c = 1:nc + xmat[r, c] = farleft + 0.5bw + (c - 1) * bw + end + end + xmat + end + + fill_bottom = if isylog + if isfinite(the_ylims[1]) + min(minimum(y) / 100, the_ylims[1]) + else + minimum(y) / 100 + end + else + 0 + end + # compute fillrange + y, fr = + isstack ? groupedbar_fillrange(y) : + (y, get(plotattributes, :fillrange, [fill_bottom])) + if isylog + replace!(fr, 0 => fill_bottom) + end + fillrange := fr + + seriestype := :bar + x, y +end + +function groupedbar_fillrange(y) + nr, nc = size(y) + # bar series fills from y[nr, nc] to fr[nr, nc], y .>= fr + fr = zeros(nr, nc) + y = copy(y) + y[.!isfinite.(y)] .= 0 + for r = 1:nr + y_neg = 0 + # upper & lower bounds for positive bar + y_pos = sum([e for e in y[r, :] if e > 0]) + # division subtract towards 0 + for c = 1:nc + el = y[r, c] + if el >= 0 + y[r, c] = y_pos + y_pos -= el + fr[r, c] = y_pos + else + fr[r, c] = y_neg + y_neg += el + y[r, c] = y_neg + end + end + end + y, fr +end diff --git a/StatsPlots/src/boxplot.jl b/StatsPlots/src/boxplot.jl new file mode 100644 index 000000000..1587bbe9f --- /dev/null +++ b/StatsPlots/src/boxplot.jl @@ -0,0 +1,259 @@ + +# --------------------------------------------------------------------------- +# Box Plot + +notch_width(q2, q4, N) = 1.58 * (q4 - q2) / sqrt(N) + +@recipe function f( + ::Type{Val{:boxplot}}, + x, + y, + z; + notch = false, + whisker_range = 1.5, + outliers = true, + whisker_width 
= :half, + sort_labels_by = identity, + xshift = 0.0, +) + # if only y is provided, then x will be UnitRange 1:size(y,2) + if typeof(x) <: AbstractRange + x = if step(x) == first(x) == 1 + plotattributes[:series_plotindex] + else + [getindex(x, plotattributes[:series_plotindex])] + end + end + xsegs, ysegs = Plots.PlotsBase.Segments(), Plots.PlotsBase.Segments() + texts = String[] + glabels = sort(collect(unique(x))) + warning = false + outliers_x, outliers_y = zeros(0), zeros(0) + bw = plotattributes[:bar_width] + isnothing(bw) && (bw = 0.8) + @assert whisker_width === :match || whisker_width == :half || whisker_width >= 0 "whisker_width must be :match, :half, or a positive number" + ww = whisker_width === :match ? bw : whisker_width == :half ? bw / 2 : whisker_width + for (i, glabel) in enumerate(sort(glabels; by = sort_labels_by)) + # filter y + values = y[filter(i -> _cycle(x, i) == glabel, 1:length(y))] + + # compute quantiles + q1, q2, q3, q4, q5 = quantile(values, range(0, stop = 1, length = 5)) + + # notch + n = notch_width(q2, q4, length(values)) + + # warn on inverted notches? + if notch && !warning && ((q2 > (q3 - n)) || (q4 < (q3 + n))) + @warn("Boxplot's notch went outside hinges. 
Set notch to false.") + warning = true # Show the warning only one time + end + + # make the shape + center = PlotsBase.discrete_value!(plotattributes, :x, glabel)[1] + xshift + hw = 0.5_cycle(bw, i) # Box width + HW = 0.5_cycle(ww, i) # Whisker width + l, m, r = center - hw, center, center + hw + lw, rw = center - HW, center + HW + + # internal nodes for notches + L, R = center - 0.5 * hw, center + 0.5 * hw + + # outliers + if Float64(whisker_range) != 0.0 # if the range is 0.0, the whiskers will extend to the data + limit = whisker_range * (q4 - q2) + inside = Float64[] + for value in values + if (value < (q2 - limit)) || (value > (q4 + limit)) + if outliers + push!(outliers_y, value) + push!(outliers_x, center) + end + else + push!(inside, value) + end + end + # change q1 and q5 to show outliers + # using maximum and minimum values inside the limits + q1, q5 = PlotsBase.ignorenan_extrema(inside) + q1, q5 = (min(q1, q2), max(q4, q5)) # whiskers cannot be inside the box + end + # Box + push!(xsegs, m, lw, rw, m, m) # lower T + push!(ysegs, q1, q1, q1, q1, q2) # lower T + push!( + texts, + "Lower fence: $q1", + "Lower fence: $q1", + "Lower fence: $q1", + "Lower fence: $q1", + "Q1: $q2", + "", + ) + + if notch + push!(xsegs, r, r, R, L, l, l, r, r) # lower box + push!(xsegs, r, r, l, l, L, R, r, r) # upper box + + push!(ysegs, q2, q3 - n, q3, q3, q3 - n, q2, q2, q3 - n) # lower box + push!( + texts, + "Q1: $q2", + "Median: $q3 ± $n", + "Median: $q3 ± $n", + "Median: $q3 ± $n", + "Median: $q3 ± $n", + "Q1: $q2", + "Q1: $q2", + "Median: $q3 ± $n", + "", + ) + + push!(ysegs, q3 + n, q4, q4, q3 + n, q3, q3, q3 + n, q4) # upper box + push!( + texts, + "Median: $q3 ± $n", + "Q3: $q4", + "Q3: $q4", + "Median: $q3 ± $n", + "Median: $q3 ± $n", + "Median: $q3 ± $n", + "Median: $q3 ± $n", + "Q3: $q4", + "", + ) + else + push!(xsegs, r, r, l, l, r, r) # lower box + push!(xsegs, r, l, l, r, r, m) # upper box + push!(ysegs, q2, q3, q3, q2, q2, q3) # lower box + push!( + texts, + 
"Q1: $q2", + "Median: $q3", + "Median: $q3", + "Q1: $q2", + "Q1: $q2", + "Median: $q3", + "", + ) + push!(ysegs, q4, q4, q3, q3, q4, q4) # upper box + push!( + texts, + "Q3: $q4", + "Q3: $q4", + "Median: $q3", + "Median: $q3", + "Q3: $q4", + "Q3: $q4", + "", + ) + end + + push!(xsegs, m, lw, rw, m, m) # upper T + push!(ysegs, q5, q5, q5, q5, q4) # upper T + push!( + texts, + "Upper fence: $q5", + "Upper fence: $q5", + "Upper fence: $q5", + "Upper fence: $q5", + "Q3: $q4", + "", + ) + end + + if !isvertical(plotattributes) + # We should draw the plot horizontally! + xsegs, ysegs = ysegs, xsegs + outliers_x, outliers_y = outliers_y, outliers_x + + # Now reset the orientation, so that the axes limits are set correctly. + orientation := default(:orientation) + end + + @series begin + # To prevent linecolor equal to fillcolor (It makes the median visible) + if plotattributes[:linecolor] == plotattributes[:fillcolor] + plotattributes[:linecolor] = plotattributes[:markerstrokecolor] + end + primary := true + seriestype := :shape + x := xsegs.pts + y := ysegs.pts + () + end + + # Outliers + if outliers && !isempty(outliers) + @series begin + primary := false + seriestype := :scatter + if get!(plotattributes, :markershape, :circle) === :none + plotattributes[:markershape] = :circle + end + + fillrange := nothing + x := outliers_x + y := outliers_y + () + end + end + + # Hover + primary := false + seriestype := :path + marker := false + if PlotsBase.is_attr_supported(PlotsBase.backend(), :hover) + hover := texts + end + linewidth := 0 + x := xsegs.pts + y := ysegs.pts + () +end + +PlotsBase.@deps boxplot shape scatter + +# ------------------------------------------------------------------------------ +# Grouped Boxplot + +@userplot GroupedBoxplot + +recipetype(::Val{:groupedboxplot}, args...) = GroupedBoxplot(args) + +@recipe function f(g::GroupedBoxplot; spacing = 0.1) + x, y = grouped_xy(g.args...) + + # extract xnums and set default bar width. 
+ # might need to set xticks as well + ux = unique(x) + x = if eltype(x) <: Number + bar_width --> (0.8 * mean(diff(sort(ux)))) + float.(x) + else + bar_width --> 0.8 + xnums = [findfirst(isequal(xi), ux) for xi in x] .- 0.5 + xticks --> (eachindex(ux) .- 0.5, ux) + xnums + end + + # shift x values for each group + group = get(plotattributes, :group, nothing) + if group != nothing + gb = RecipesPipeline._extract_group_attributes(group) + labels, idxs = getfield(gb, 1), getfield(gb, 2) + n = length(labels) + bws = plotattributes[:bar_width] / n + bar_width := bws * clamp(1 - spacing, 0, 1) + for i = 1:n + groupinds = idxs[i] + Δx = _cycle(bws, i) * (i - (n + 1) / 2) + x[groupinds] .+= Δx + end + end + + seriestype := :boxplot + x, y +end + +PlotsBase.@deps groupedboxplot boxplot diff --git a/StatsPlots/src/cornerplot.jl b/StatsPlots/src/cornerplot.jl new file mode 100644 index 000000000..74838ed3e --- /dev/null +++ b/StatsPlots/src/cornerplot.jl @@ -0,0 +1,120 @@ +@userplot CornerPlot + +recipetype(::Val{:cornerplot}, args...) = CornerPlot(args) + +@recipe function f(cp::CornerPlot; compact = false, maxvariables = 30, histpct = 0.1) + mat = cp.args[1] + C = cor(mat) + @assert typeof(mat) <: AbstractMatrix + N = size(mat, 2) + if N > maxvariables + error( + "Requested to plot $N variables in $(N^2) subplots! Likely, the first input needs transposing, otherwise increase maxvariables.", + ) + end + + # k is the number of rows/columns to hide + k = compact ? 1 : 0 + + # n is the total number of rows/columns. 
hists always shown + n = N + 1 - k + + labs = pop!(plotattributes, :label, ["x$i" for i = 1:N]) + if labs != [""] && length(labs) != N + error("Number of labels not identical to number of datasets") + end + + # build a grid layout, where the histogram sizes are a fixed percentage, and we + scatterpcts = ones(n - 1) * (1 - histpct) / (n - 1) + g = grid( + n, + n, + widths = vcat(scatterpcts, histpct), + heights = vcat(histpct, scatterpcts), + ) + spidx = 1 + indices = zeros(Int, n, n) + for i = 1:n, j = 1:n + isblank = (i == 1 && j == n) || (compact && i > 1 && j < n && j >= i) + g[i, j].attr[:blank] = isblank + if !isblank + indices[i, j] = spidx + spidx += 1 + end + end + layout := g + + # some defaults + legend := false + foreground_color_border := nothing + margin --> 1mm + titlefont --> font(11) + fillcolor --> PlotsBase.fg_color(plotattributes) + linecolor --> PlotsBase.fg_color(plotattributes) + grid --> true + ticks := nothing + xformatter := x -> "" + yformatter := y -> "" + link := :both + grad = cgrad(get(plotattributes, :markercolor, :RdYlBu)) + + # figure out good defaults for scatter plot dots: + pltarea = 1 / (2n) + nsamples = size(mat, 1) + markersize --> clamp(pltarea * 800 / sqrt(nsamples), 1, 10) + markeralpha --> clamp(pltarea * 100 / nsamples^0.42, 0.005, 0.4) + + # histograms in the right column + for i = 1:N + compact && i == 1 && continue + @series begin + orientation := :h + seriestype := :histogram + subplot := indices[i + 1 - k, n] + grid := false + view(mat, :, i) + end + end + + # histograms in the top row + for j = 1:N + compact && j == N && continue + @series begin + seriestype := :histogram + subplot := indices[1, j] + grid := false + view(mat, :, j) + end + end + + # scatters + for i = 1:N + vi = view(mat, :, i) + for j = 1:N + # only the lower triangle + if compact && i <= j + continue + end + + vj = view(mat, :, j) + @series begin + ticks := :auto + if i == N + xformatter := :auto + xguide := _cycle(labs, j) + end + if j == 1 + 
"""
    update_ticks_guides(d, labs, i, j, n)

Set the axis guides of the subplot at grid position `(i, j)` in the attribute
dictionary `d`.

The x guide is taken from `labs` (cycled by column `j`) only when `i == n`, and the
y guide is taken from `labs` (cycled by row `i`) only when `j == 1`; otherwise the
guide is set to the empty string. The value stored in `d[:yguide]` is returned.
"""
function update_ticks_guides(d::KW, labs, i, j, n)
    on_last_row = i == n
    on_first_column = j == 1
    d[:xguide] = on_last_row ? _cycle(labs, j) : ""
    return d[:yguide] = on_first_column ? _cycle(labs, i) : ""
end
"""
    _covellipse_args((μ, Σ); n_std)

Validate the inputs of `covellipse` and return `(μ, S)`, where `S` is the matrix of
eigenvectors of `Σ` with each column scaled by `n_std` times the square root of the
matching eigenvalue.

Throws an error unless `μ` has size `(2,)` and `Σ` has size `(2, 2)`. Any other kind
of argument is handled by the fallback method, which always throws.
"""
function _covellipse_args(
    (μ, Σ)::Tuple{AbstractVector{<:Real},AbstractMatrix{<:Real}};
    n_std::Real,
)
    if size(μ) != (2,) || size(Σ) != (2, 2)
        error("covellipse requires mean of length 2 and covariance of size 2×2.")
    end
    decomposition = eigen(Σ)
    axis_scales = sqrt.(decomposition.values)
    return μ, n_std * decomposition.vectors * diagm(axis_scales)
end

# Fallback for anything that is not a (real vector, real matrix) tuple.
function _covellipse_args(args; n_std)
    return error(
        "Wrong inputs for covellipse: $(typeof.(args)). " *
        "Expected real-valued vector μ, real-valued matrix Σ.",
    )
end
# Build the expression that `@df` expands to for the table expression `d` and the
# plot expression `x`.
#
# * A `begin ... end` block is handled statement by statement (line-number nodes are
#   dropped) and the results are wrapped in a new block.
# * A function call is rewritten into two statements: one that binds the requested
#   columns (and the column names) via `extract_columns_and_names`, and one that
#   forwards the call through `add_label`.
# * Anything else is an error.
function df_helper(d, x)
    is_expr = x isa Expr

    # Several plot commands: recurse into each real statement.
    if is_expr && x.head === :block
        statements = filter(x.args) do stmt
            !(stmt isa LineNumberNode || (stmt isa Expr && stmt.head === :line))
        end
        return Expr(:block, map(stmt -> df_helper(d, stmt), statements)...)
    end

    if !(is_expr && x.head === :call)
        error("Second argument ($x) can only be a block or function call")
    end

    # A single function call.
    syms = Any[]
    vars = Symbol[]
    plot_call = parse_table_call!(d, x, syms, vars)
    names_var = gensym()
    here = @__MODULE__

    extraction = Expr(:call, :($here.extract_columns_and_names), d, syms...)
    compute_vars = Expr(:(=), Expr(:tuple, Expr(:tuple, vars...), names_var), extraction)

    # Must run after `parse_table_call!`, exactly as in the original ordering.
    argnames = _argnames(names_var, x)

    call_args = plot_call.args
    has_parameters =
        length(call_args) >= 2 &&
        call_args[2] isa Expr &&
        call_args[2].head === :parameters

    label_plot_call = if has_parameters
        # Keep the `; kwargs...` node directly after the callee `add_label`.
        Expr(
            :call,
            :($here.add_label),
            call_args[2],
            argnames,
            call_args[1],
            call_args[3:end]...,
        )
    else
        Expr(:call, :($here.add_label), argnames, call_args...)
    end

    return Expr(:block, compute_vars, label_plot_call)
end
"""
    insert_kw!(x::Expr, s::Symbol, v)

Insert the keyword argument `s = v` (as an `Expr(:kw, s, v)` node) into the call
expression `x`, directly after the callee, or after the `:parameters` node when the
call has one in second position.

`x.args` is replaced by a new vector, which is also the return value. Calls without
any argument (e.g. `f()`) are supported.
"""
function insert_kw!(x::Expr, s::Symbol, v)
    # Guard the lookup of `x.args[2]`: a call like `f()` only has the callee.
    has_parameters =
        length(x.args) >= 2 && x.args[2] isa Expr && x.args[2].head === :parameters
    index = has_parameters ? 3 : 2
    x.args = vcat(x.args[1:(index - 1)], Expr(:kw, s, v), x.args[index:end])
    return x.args
end
"""
    add_label(argnames, f, args...; kwargs...)

Call `f(args...; kwargs...)`, adding a `label` keyword when one can be derived.

The label is built from the last element of `argnames` that is an `Expr` or an
`AbstractArray`. When there is no such element, `f` is called exactly once with the
given arguments and nothing else happens.

When a label was added and the call fails with a `MethodError` (or with an
`ErrorException` mentioning "does not accept keyword arguments"):

- if the caller passed no keyword arguments, `f(args...)` is tried without the label;
- otherwise the error is rethrown.

Every other exception is rethrown unchanged.
"""
function add_label(argnames, f, args...; kwargs...)
    i = findlast(t -> t isa Expr || t isa AbstractArray, argnames)

    # Nothing to add: retrying would just repeat the same call (and its side effects).
    i === nothing && return f(args...; kwargs...)

    try
        return f(label = stringify.(argnames[i]), args...; kwargs...)
    catch e
        rejects_keywords =
            e isa MethodError ||
            (e isa ErrorException && occursin("does not accept keyword arguments", e.msg))
        # Only fall back when the `label` we injected is the sole keyword in play.
        (rejects_keywords && isempty(kwargs)) || rethrow()
        return f(args...)
    end
end
# Pick a default plotting range for `dist`: each finite support bound is used as is,
# while a non-finite lower (upper) bound is replaced by the `alpha`
# (`1 - alpha`) quantile of the distribution.
function default_range(dist::Distribution, alpha = 0.0001)
    lo, hi = minimum(dist), maximum(dist)
    isfinite(lo) || (lo = quantile(dist, alpha))
    isfinite(hi) || (hi = quantile(dist, 1 - alpha))
    return lo, hi
end
# Recipe for a quantile pair `h`: a scatter of `h.qx` against `h.qy`, optionally
# preceded by a non-primary reference line selected by `qqline`:
#   :identity         -> the line y = x
#   :fit              -> least-squares line through the quantile pairs
#   :quantile or :R   -> line through the 25% and 75% quantiles
# Any other value of `qqline` draws no reference line.
@recipe function f(h::QQPair; qqline = :identity)
    if qqline in (:fit, :quantile, :identity, :R)
        line_x = [extrema(h.qx)...]
        line_y = if qqline === :identity
            line_x
        elseif qqline === :fit
            design = hcat(fill!(similar(h.qx), 1), h.qx)
            intercept, slope = design \ h.qy
            slope .* line_x .+ intercept
        else # :quantile or :R
            quartiles_x = quantile(h.qx, [0.25, 0.75])
            quartiles_y = quantile(h.qy, [0.25, 0.75])
            slope = diff(quartiles_y) ./ diff(quartiles_x)
            quartiles_y .+ slope .* (line_x .- quartiles_x)
        end

        @series begin
            primary := false
            seriestype := :path
            line_x, line_y
        end
    end

    seriestype --> :scatter
    legend --> false
    h.qx, h.qy
end
# Series recipe for `:dotplot` (strip plot / beeswarm).
#
# The `y` values are grouped by their (cycled) `x` label and drawn as a scatter
# around each group's discrete x position. The horizontal offset of every point is
# chosen by `mode`:
#   :uniform -> uniformly random within the half bar width
#   :density -> random within the local violin width (see `violinoffsets`)
#   other    -> no offset
# `side = :left` / `:right` forces all offsets to one side of the centre.
@recipe function f(::Type{Val{:dotplot}}, x, y, z; mode = :density, side = :both)
    # if only y is provided, then x will be UnitRange 1:size(y, 2)
    if x isa AbstractRange
        if step(x) == first(x) == 1
            x = plotattributes[:series_plotindex]
        else
            x = [getindex(x, plotattributes[:series_plotindex])]
        end
    end

    grouplabels = sort(collect(unique(x)))
    barwidth = plotattributes[:bar_width]
    # `===` is the correct identity test for `nothing`.
    barwidth === nothing && (barwidth = 0.8)

    getoffsets(halfwidth, vals) =
        mode === :uniform ? (rand(length(vals)) .* 2 .- 1) .* halfwidth :
        mode === :density ? violinoffsets(halfwidth, vals) : zeros(length(vals))

    points_x, points_y = zeros(0), zeros(0)

    for (i, grouplabel) in enumerate(grouplabels)
        # keep the y values whose (cycled) x label matches this group; the closure
        # argument is deliberately not called `i` to avoid shadowing the loop index
        groupy = y[filter(k -> _cycle(x, k) == grouplabel, eachindex(y))]

        center = PlotsBase.discrete_value!(plotattributes, :x, grouplabel)[1]
        halfwidth = 0.5_cycle(barwidth, i)

        offsets = getoffsets(halfwidth, groupy)

        if side === :left
            offsets = -abs.(offsets)
        elseif side === :right
            offsets = abs.(offsets)
        end

        append!(points_y, groupy)
        append!(points_x, center .+ offsets)
    end

    seriestype := :scatter
    x := points_x
    y := points_y
    ()
end
# User recipe for `groupeddotplot`: a `:dotplot` series in which the points of each
# group are shifted sideways so that groups sharing an x value sit next to each other.
# `spacing` (clamped to [0, 1]) is the fraction of each group's slot left empty.
@recipe function f(g::GroupedDotplot; spacing = 0.1)
    x, y = grouped_xy(g.args...)

    # extract xnums and set default bar width.
    # might need to set xticks as well
    ux = unique(x)
    x = if eltype(x) <: Number
        bar_width --> (0.8 * mean(diff(sort(ux))))
        float.(x)
    else
        bar_width --> 0.8
        # non-numeric x: place category k at k - 0.5 and label the ticks with it
        xnums = [findfirst(isequal(xi), ux) for xi in x] .- 0.5
        xticks --> (eachindex(ux) .- 0.5, ux)
        xnums
    end

    # shift x values for each group
    group = get(plotattributes, :group, nothing)
    if group !== nothing  # identity comparison is the correct test for `nothing`
        gb = RecipesPipeline._extract_group_attributes(group)
        labels, idxs = getfield(gb, 1), getfield(gb, 2)
        n = length(labels)
        bws = plotattributes[:bar_width] / n
        bar_width := bws * clamp(1 - spacing, 0, 1)
        for i in 1:n
            groupinds = idxs[i]
            # centre the n slots around the original x position
            Δx = _cycle(bws, i) * (i - (n + 1) / 2)
            x[groupinds] .+= Δx
        end
    end

    seriestype := :dotplot
    x, y
end
# User recipe for `ecdfplot`: pass an existing `StatsBase.ECDF` through unchanged,
# otherwise build one from the first argument with `StatsBase.ecdf`.
@recipe function f(p::ECDFPlot)
    data = first(p.args)
    data isa StatsBase.ECDF ? data : StatsBase.ecdf(data)
end
"""
    compute_error(y, centertype, errortype, percentiles)

Compute, for every row of `y`, a central value and an error measure taken along the
second dimension. NaN entries are ignored.

# Arguments
- `y::AbstractMatrix`: one row per x value, one column per repeated observation.
  Integer matrices are converted with `float` first.
- `centertype::Symbol`: `:mean` or `:median`.
- `errortype::Symbol`: `:std`, `:sem` or `:percentile`.
- `percentiles::AbstractVector`: lower and upper percentile (first two entries), only
  used when `errortype === :percentile`.

# Returns
`(y_central, y_error)`. `y_central` is a `size(y, 1) × 1` matrix. `y_error` is a
matrix of the same shape for `:std`/`:sem`, and a tuple
`(y_central .- lower, upper .- y_central)` for `:percentile`.

# Throws
An error for an unknown `centertype` or `errortype`.
"""
function compute_error(
    y::AbstractMatrix,
    centertype::Symbol,
    errortype::Symbol,
    percentiles::AbstractVector,
)
    # NaNMath doesn't accept Ints so convert to AbstractFloat if necessary
    if eltype(y) <: Integer
        y = float(y)
    end

    # Central value of each row
    y_central = if centertype === :mean
        mapslices(NaNMath.mean, y, dims = 2)
    elseif centertype === :median
        mapslices(NaNMath.median, y, dims = 2)
    else
        error("Invalid center type. Valid symbols include :mean or :median")
    end

    # Error measure of each row
    y_error = if errortype === :std
        mapslices(NaNMath.std, y, dims = 2)
    elseif errortype === :sem
        # The standard deviation only uses the non-NaN observations, so the standard
        # error must be scaled by that same per-row count (equal to size(y, 2) when
        # the row has no NaN).
        n_obs = mapslices(row -> count(!isnan, row), y, dims = 2)
        mapslices(NaNMath.std, y, dims = 2) ./ sqrt.(n_obs)
    elseif errortype === :percentile
        # NaNMath does not have a percentile function, so drop the NaNs by hand;
        # a row without any valid observation yields NaN instead of throwing.
        row_percentile(row, p) = begin
            valid = filter(!isnan, row)
            isempty(valid) ? NaN : percentile(valid, p)
        end
        y_lower = mapslices(row -> row_percentile(row, percentiles[1]), y, dims = 2)
        y_upper = mapslices(row -> row_percentile(row, percentiles[2]), y, dims = 2)
        (y_central .- y_lower, y_upper .- y_central) # Difference from center value
    else
        error("Invalid error type. Valid symbols include :std, :sem, :percentile")
    end

    return y_central, y_error
end
a color palette is being used so it can be passed to secondary lines + if :color_palette ∉ keys(plotattributes) + color_palette = :default + else + color_palette = plotattributes[:color_palette] + end + + # Parse different color type + if groupcolor isa Symbol || groupcolor isa RGB{Float64} || groupcolor isa RGBA{Float64} + groupcolor = [groupcolor] + end + + # Check groupcolor format + if (groupcolor !== nothing && ndims(y) > 2) && length(groupcolor) == 1 + groupcolor = repeat(groupcolor, size(y, 3)) # Use the same color for all groups + elseif (groupcolor !== nothing && ndims(y) > 2) && length(groupcolor) < size(y, 3) + error("$(length(groupcolor)) colors given for a matrix with $(size(y,3)) groups") + elseif groupcolor === nothing + gsi_counter = 0 + for i = 1:length(plotattributes[:plot_object].series_list) + if plotattributes[:plot_object].series_list[i].plotattributes[:primary] + gsi_counter += 1 + end + end + # Start at next index and allow wrapping of indices + gsi_counter += 1 + idx = (gsi_counter:(gsi_counter + size(y, 3))) .% length(palette(color_palette)) + idx[findall(x -> x == 0, idx)] .= length(palette(color_palette)) + groupcolor = palette(color_palette)[idx] + end + + if errorstyle === :plume && numsecondarylines > size(y, 2) # Override numsecondarylines + numsecondarylines = size(y, 2) + end + + for g in axes(y, 3) # Iterate through 3rd dimension + # Compute center and distribution for each value of x + y_central, y_error = compute_error(y[:, :, g], centertype, errortype, percentiles) + + if errorstyle === :ribbon + seriestype := :path + @series begin + x := x + y := y_central + ribbon := y_error + fillalpha --> 0.1 + linecolor := groupcolor[g] + fillcolor := groupcolor[g] + () # Suppress implicit return + end + + elseif errorstyle === :stick + x_offset = diff(extrema(x) |> collect)[1] * stickwidth + seriestype := :path + for (i, xi) in enumerate(x) + # Error sticks + @series begin + primary := false + x := + [xi - x_offset, xi + x_offset, xi, xi, 
xi + x_offset, xi - x_offset] + if errortype === :percentile + y := [ + repeat([y_central[i] - y_error[1][i]], 3) + repeat([y_central[i] + y_error[2][i]], 3) + ] + else + y := [ + repeat([y_central[i] - y_error[i]], 3) + repeat([y_central[i] + y_error[i]], 3) + ] + end + # Set the stick color + if secondarycolor === nothing + linecolor := :gray60 + elseif secondarycolor === :matched + linecolor := groupcolor[g] + else + linecolor := secondarycolor + end + linewidth := secondarylinewidth + () # Suppress implicit return + end + end + + # Base line + seriestype := :line + @series begin + primary := true + x := x + y := y_central + linecolor := groupcolor[g] + () + end + + elseif errorstyle === :plume + num_obs = size(y, 2) + if num_obs > numsecondarylines + sub_sample_idx = sample(1:num_obs, numsecondarylines, replace = false) + y_sub_sample = y[:, sub_sample_idx, g] + else + y_sub_sample = y[:, :, g] + end + seriestype := :path + for i = 1:numsecondarylines + # Background paths + @series begin + primary := false + x := x + y := y_sub_sample[:, i] + # Set the stick color + if secondarycolor === nothing || secondarycolor === :matched + linecolor := groupcolor[g] + else + linecolor := secondarycolor + end + linealpha := secondarylinealpha + linewidth := secondarylinewidth + () # Suppress implicit return + end + end + + # Base line + seriestype := :line + @series begin + primary := true + x := x + y := y_central + linecolor := groupcolor[g] + linewidth --> 3 # Make it stand out against the plume better + () + end + else + error("Invalid error style. 
# Number of bins of an equal-area histogram, derived from the `bins` attribute.

# Explicit edge locations are not supported.
function ea_binnumber(y, bin::AbstractVector)
    return error("You cannot specify edge locations for equal area histogram")
end

# A real number is accepted only when it holds an integer value.
function ea_binnumber(y, bin::Real)
    if floor(bin) != bin
        error("Only integer or symbol values accepted by bins")
    end
    return Int(bin)
end

# An `Int` is used as is.
ea_binnumber(y, bin::Int) = bin

# A symbol names an automatic binning mode, resolved by PlotsBase from the data.
function ea_binnumber(y, bin::Symbol)
    return PlotsBase._auto_binning_nbins((y,), 1; mode = bin)
end
function f(::Type{Val{:testhist}}, x, y, z) + markercolor --> :red + seriestype := :scatter + () +end +@shorthands testhist + +# --------------------------------------------------------------------------- +# grouped histogram + +@userplot GroupedHist + +PlotsBase.group_as_matrix(g::GroupedHist) = true + +@recipe function f(p::GroupedHist) + _, v = grouped_xy(p.args...) + group = get(plotattributes, :group, nothing) + bins = get(plotattributes, :bins, :auto) + normed = get(plotattributes, :normalize, false) + weights = get(plotattributes, :weights, nothing) + + # compute edges from ungrouped data + h = PlotsBase._make_hist((vec(copy(v)),), bins; normed = normed, weights = weights) + nbins = length(h.weights) + edges = h.edges[1] + bar_width --> mean(map(i -> edges[i + 1] - edges[i], 1:nbins)) + x = map(i -> (edges[i] + edges[i + 1]) / 2, 1:nbins) + + if group === nothing + y = reshape(h.weights, nbins, 1) + else + gb = RecipesPipeline._extract_group_attributes(group) + labels, idxs = getfield(gb, 1), getfield(gb, 2) + ngroups = length(labels) + ntot = count(x -> !isnan(x), v) + + # compute weights (frequencies) by group using those edges + y = fill(NaN, nbins, ngroups) + for i = 1:ngroups + groupinds = idxs[i] + v_i = filter(x -> !isnan(x), v[:, i]) + w_i = weights == nothing ? nothing : weights[groupinds] + h_i = PlotsBase._make_hist((v_i,), h.edges; normed = false, weights = w_i) + if normed + y[:, i] .= h_i.weights .* (length(v_i) / ntot / sum(h_i.weights)) + else + y[:, i] .= h_i.weights + end + end + end + + GroupedBar((x, y)) +end + +# --------------------------------------------------------------------------- +# Compute binsizes using Wand (1997)'s criterion +# Ported from R code located here https://github.com/cran/KernSmooth/tree/master/R + +"Returns optimal histogram edge positions in accordance to Wand (1995)'s criterion'" +PlotsBase.wand_edges(x::AbstractVector, args...) 
= (binwidth = wand_bins(x, args...); +(minimum(x) - binwidth):binwidth:(maximum(x) + binwidth)) + +"Returns optimal histogram bin widths in accordance to Wand (1995)'s criterion'" +function wand_bins(x, scalest = :minim, gridsize = 401, range_x = extrema(x), trun = true) + n = length(x) + minx, maxx = range_x + gpoints = range(minx, stop = maxx, length = gridsize) + gcounts = linbin(x, gpoints, trun = trun) + + scalest = if scalest === :stdev + sqrt(var(x)) + elseif scalest === :iqr + (quantile(x, 3 // 4) - quantile(x, 1 // 4)) / 1.349 + elseif scalest === :minim + min((quantile(x, 3 // 4) - quantile(x, 1 // 4)) / 1.349, sqrt(var(x))) + else + error("scalest must be one of :stdev, :iqr or :minim (default)") + end + + scalest == 0 && error("scale estimate is zero for input data") + sx = (x .- mean(x)) ./ scalest + sa = (minx - mean(x)) / scalest + sb = (maxx - mean(x)) / scalest + + gpoints = range(sa, stop = sb, length = gridsize) + gcounts = linbin(sx, gpoints, trun = trun) + + hpi = begin + alpha = ((2 / (11 * n))^(1 / 13)) * sqrt(2) + psi10hat = bkfe(gcounts, 10, alpha, [sa, sb]) + alpha = (-105 * sqrt(2 / pi) / (psi10hat * n))^(1 // 11) + psi8hat = bkfe(gcounts, 8, alpha, [sa, sb]) + alpha = (15 * sqrt(2 / pi) / (psi8hat * n))^(1 / 9) + psi6hat = bkfe(gcounts, 6, alpha, [sa, sb]) + alpha = (-3 * sqrt(2 / pi) / (psi6hat * n))^(1 / 7) + psi4hat = bkfe(gcounts, 4, alpha, [sa, sb]) + alpha = (sqrt(2 / pi) / (psi4hat * n))^(1 / 5) + psi2hat = bkfe(gcounts, 2, alpha, [sa, sb]) + (6 / (-psi2hat * n))^(1 / 3) + end + + scalest * hpi +end + +function linbin(X, gpoints; trun = true) + n, M = length(X), length(gpoints) + + a, b = gpoints[1], gpoints[M] + gcnts = zeros(M) + delta = (b - a) / (M - 1) + + for i = 1:n + lxi = ((X[i] - a) / delta) + 1 + li = floor(Int, lxi) + rem = lxi - li + + if 1 <= li < M + gcnts[li] += 1 - rem + gcnts[li + 1] += rem + end + + if !trun + if lt < 1 + gcnts[1] += 1 + end + + if li >= M + gcnts[M] += 1 + end + end + end + gcnts +end + 
+"binned kernel function estimator"
+function bkfe(gcounts, drv, bandwidth, range_x)
+    bandwidth <= 0 && error("'bandwidth' must be strictly positive")
+
+    a, b = range_x
+    h = bandwidth
+    M = length(gcounts)
+    gpoints = range(a, stop = b, length = M) # NOTE(review): not used below
+
+    ## Set the sample size and bin width
+
+    n = sum(gcounts)
+    delta = (b - a) / (M - 1)
+
+    ## Obtain kernel weights
+
+    tau = 4 + drv
+    L = min(Int(fld(tau * h, delta)), M)
+
+    lvec = 0:L
+    arg = lvec .* delta / h
+
+    kappam = pdf.(Normal(), arg) ./ h^(drv + 1)
+    hmold0, hmnew = ones(length(arg)), ones(length(arg))
+    hmold1 = arg
+
+    if drv >= 2
+        for i in (2:drv)
+            hmnew = arg .* hmold1 .- (i - 1) .* hmold0
+            hmold0 = hmold1 # Compute mth degree Hermite polynomial
+            hmold1 = hmnew # by recurrence.
+        end
+    end
+    kappam = hmnew .* kappam
+
+    ## Now combine weights and counts to obtain estimate
+    ## we need P >= 2L+1L, M: L <= M.
+    P = nextpow(2, M + L + 1) # zero-padded FFT length
+    kappam = [kappam; zeros(P - 2 * L - 1); reverse(kappam[2:end])]
+    Gcounts = [gcounts; zeros(P - M)]
+    kappam = fft(kappam)
+    Gcounts = fft(Gcounts)
+
+    sum(gcounts .* (real(ifft(kappam .* Gcounts)))[1:M]) / (n^2)
+end
diff --git a/StatsPlots/src/interact.jl b/StatsPlots/src/interact.jl
new file mode 100644
index 000000000..6afedb77d
--- /dev/null
+++ b/StatsPlots/src/interact.jl
@@ -0,0 +1,119 @@
+# Pick the plotting function: a tuple entry holds `(ungrouped, grouped)` variants.
+plot_function(plt::Function, grouped) = plt
+plot_function(plt::Tuple, grouped) = grouped ? plt[2] : plt[1]
+
+# Return the single column named in `ns`, or `hcat` several columns into a matrix.
+combine_cols(dict, ns) = length(ns) > 1 ? hcat((dict[n] for n in ns)...) : dict[ns[1]]
+
+"""
+    dataviewer(t; throttle = 0.1, nbins = 30, nbins_range = 1:100)
+
+Build an interactive widget for plotting the columns of the table `t`: dropdowns choose
+the columns, the plot type and an optional grouping, and a slider sets `nbins`.
+`t` is wrapped in an `Observable` unless it already is an `AbstractObservable`.
+"""
+function dataviewer(t; throttle = 0.1, nbins = 30, nbins_range = 1:100)
+    (t isa AbstractObservable) || (t = Observable{Any}(t))
+
+    coltable = map(Tables.columntable, t)
+
+    names = map(collect ∘ keys, coltable)
+
+    dict = @map Dict((key, val) for (key, val) in pairs(&coltable))
+    x = Widgets.dropdown(names, placeholder = "First axis", multiple = true)
+    y = Widgets.dropdown(names, placeholder = "Second axis", multiple = true)
+    y_toggle = Widgets.togglecontent(y, value = false, label = "Second axis")
+    plot_type = Widgets.dropdown(
+        OrderedDict(
+            "line" => PlotsBase.plot,
+            "scatter" => PlotsBase.scatter,
+            "bar" => (PlotsBase.bar, StatsPlots.groupedbar),
+            "boxplot" => (StatsPlots.boxplot, StatsPlots.groupedboxplot),
+            "corrplot" => StatsPlots.corrplot,
+            "cornerplot" => StatsPlots.cornerplot,
+            "density" => StatsPlots.density,
+            "cdensity" => StatsPlots.cdensity,
+            "histogram" => StatsPlots.histogram,
+            "marginalhist" => StatsPlots.marginalhist,
+            "violin" => (StatsPlots.violin, StatsPlots.groupedviolin),
+        ),
+        placeholder = "Plot type",
+    )
+
+    # Add bins if the plot allows it
+    display_nbins =
+        @map (&plot_type) in [corrplot, cornerplot, histogram, marginalhist] ? "block" :
+             "none"
+    nbins = (Widgets.slider(
+        nbins_range,
+        extra_obs = ["display" => display_nbins],
+        value = nbins,
+        label = "number of bins",
+    ))
+    nbins.scope.dom = Widgets.div(
+        nbins.scope.dom,
+        attributes = Dict("data-bind" => "style: {display: display}"),
+    )
+    nbins_throttle = Observables.throttle(throttle, nbins)
+
+    by = Widgets.dropdown(names, multiple = true, placeholder = "Group by")
+    by_toggle = Widgets.togglecontent(by, value = false, label = "Split data")
+    plt = Widgets.button("plot")
+    output = @map begin
+        if (&plt == 0)
+            plot()
+        else
+            args = Any[]
+            # add first and maybe second argument
+            push!(args, combine_cols(&dict, x[]))
+            has_y = y_toggle[] && !isempty(y[])
+            has_y && push!(args, combine_cols(&dict, y[]))
+
+            # compute automatic kwargs
+            kwargs = Dict()
+
+            # grouping kwarg
+            has_by = by_toggle[] && !isempty(by[])
+            by_tup = Tuple(getindex(&dict, b) for b in by[])
+            has_by && (kwargs[:group] = NamedTuple{Tuple(by[])}(by_tup))
+
+            # label kwarg
+            if length(x[]) > 1
+                kwargs[:label] = x[]
+            elseif y_toggle[] && length(y[]) > 1
+                kwargs[:label] = y[]
+            end
+
+            # x and y labels
+            densityplot1D = plot_type[] in [cdensity, density, histogram]
+            (length(x[]) == 1 && (densityplot1D || has_y)) && (kwargs[:xlabel] = x[][1])
+            if has_y && length(y[]) == 1
+                kwargs[:ylabel] = y[][1]
+            elseif !has_y && !densityplot1D && length(x[]) == 1
+                kwargs[:ylabel] = x[][1]
+            end
+
+            plot_func = plot_function(plot_type[], has_by)
+            plot_func(args...; nbins = &nbins_throttle, kwargs...)
+        end
+    end
+    wdg = Widget{:dataviewer}(
+        [
+            "x" => x,
+            "y" => y,
+            "y_toggle" => y_toggle,
+            "by" => by,
+            "by_toggle" => by_toggle,
+            "plot_type" => plot_type,
+            "plot_button" => plt,
+            "nbins" => nbins,
+        ],
+        output = output,
+    )
+    @layout! wdg Widgets.div(
+        Widgets.div(:x, :y_toggle, :plot_type, :by_toggle, :plot_button),
+        Widgets.div(style = Dict("width" => "3em")),
+        Widgets.div(Widgets.observe(_), :nbins),
+        style = Dict("display" => "flex", "direction" => "row"),
+    )
+end
diff --git a/StatsPlots/src/marginalhist.jl b/StatsPlots/src/marginalhist.jl
new file mode 100644
index 000000000..fe50662fa
--- /dev/null
+++ b/StatsPlots/src/marginalhist.jl
@@ -0,0 +1,77 @@
+@shorthands marginalhist
+
+# 2-D histogram of `x` and `y` (subplot 2) with the marginal distribution of each
+# variable above it and to its right; `density = true` draws the marginals as densities.
+@recipe function f(::Type{Val{:marginalhist}}, plt::AbstractPlot; density = false)
+    x, y = plotattributes[:x], plotattributes[:y]
+    i = isfinite.(x) .& isfinite.(y)
+    x, y = x[i], y[i]
+    bns = get(plotattributes, :bins, :auto)
+    scale = get(plotattributes, :scale, :identity)
+    edges1, edges2 = PlotsBase._hist_edges((x, y), bns)
+    xlims, ylims = map(
+        x -> PlotsBase.Axes.scale_lims(
+            PlotsBase.ignorenan_extrema(x)...,
+            PlotsBase.Axes.default_widen_factor,
+            scale,
+        ),
+        (x, y),
+    )
+
+    # set up the subplots
+    legend --> false
+    link := :both
+    grid --> false
+    layout --> @layout [
+        tophist _
+        hist2d{0.9w,0.9h} righthist
+    ]
+
+    # main histogram2d
+    @series begin
+        seriestype := :histogram2d
+        right_margin --> 0PlotsBase.mm
+        top_margin --> 0PlotsBase.mm
+        subplot := 2
+        bins := (edges1, edges2)
+        xlims --> xlims
+        ylims --> ylims
+    end
+
+    # these are common to both marginal histograms
+    ticks := nothing
+    xguide := ""
+    yguide := ""
+    foreground_color_border := nothing
+    fillcolor --> PlotsBase.fg_color(plotattributes)
+    linecolor --> PlotsBase.fg_color(plotattributes)
+
+    if density
+        trim := true
+        seriestype := :density
+    else
+        seriestype := :histogram
+    end
+
+    # upper histogram
+    @series begin
+        subplot := 1
+        bottom_margin --> 0PlotsBase.mm
+        bins := edges1
+        y := x
+        xlims --> xlims
+    end
+
+    # right histogram
+    @series begin
+        orientation := :h
+        subplot := 3
+        left_margin --> 0PlotsBase.mm
+        bins := edges2
+        y := y
+        ylims --> ylims
+    end
+end
+
+# # now you can plot like:
+# 
marginalhist(rand(1000), rand(1000))
diff --git a/StatsPlots/src/marginalkde.jl b/StatsPlots/src/marginalkde.jl
new file mode 100644
index 000000000..6adc62376
--- /dev/null
+++ b/StatsPlots/src/marginalkde.jl
@@ -0,0 +1,79 @@
+@userplot MarginalKDE
+
+# Contour plot of the 2-D kernel density estimate of `(x, y)` with the 1-D density of
+# each variable in the margins. `levels` controls the number of contoured quantile
+# levels; `clip` scales the 16%/84% quantile spreads used for the axis limits.
+@recipe function f(kc::MarginalKDE; levels = 10, clip = ((-3.0, 3.0), (-3.0, 3.0)))
+    x, y = kc.args
+
+    x = vec(x)
+    y = vec(y)
+
+    m_x = median(x)
+    m_y = median(y)
+
+    dx_l = m_x - quantile(x, 0.16)
+    dx_h = quantile(x, 0.84) - m_x
+
+    dy_l = m_y - quantile(y, 0.16)
+    dy_h = quantile(y, 0.84) - m_y
+
+    xmin = m_x + clip[1][1] * dx_l
+    xmax = m_x + clip[1][2] * dx_h
+
+    ymin = m_y + clip[2][1] * dy_l
+    ymax = m_y + clip[2][2] * dy_h
+
+    k = KernelDensity.kde((x, y))
+    kx = KernelDensity.kde(x)
+    ky = KernelDensity.kde(y)
+
+    ps = pdf.(Ref(k), x, y)
+
+    # contour levels: evenly spaced quantiles of the density at the observations
+    ls = [
+        quantile(ps, p) for
+        p in range(1.0 / levels, stop = 1 - 1.0 / levels, length = levels - 1)
+    ]
+
+    legend --> false
+    layout := @layout [
+        topdensity _
+        contour{0.9w,0.9h} rightdensity
+    ]
+
+    @series begin
+        seriestype := :contour
+        levels := ls
+        fill := false
+        colorbar := false
+        subplot := 2
+        xlims := (xmin, xmax)
+        ylims := (ymin, ymax)
+
+        (collect(k.x), collect(k.y), k.density')
+    end
+
+    ticks := nothing
+    xguide := ""
+    yguide := ""
+
+    @series begin
+        seriestype := :density
+        subplot := 1
+        xlims := (xmin, xmax)
+        ylims := (0, 1.1 * maximum(kx.density))
+
+        x
+    end
+
+    @series begin
+        seriestype := :density
+        subplot := 3
+        orientation := :h
+        xlims := (0, 1.1 * maximum(ky.density))
+        ylims := (ymin, ymax)
+
+        y
+    end
+end
diff --git a/StatsPlots/src/marginalscatter.jl b/StatsPlots/src/marginalscatter.jl
new file mode 100644
index 000000000..5641fca53
--- /dev/null
+++ b/StatsPlots/src/marginalscatter.jl
@@ -0,0 +1,76 @@
+@shorthands marginalscatter
+
+# Scatter plot of `x` against `y` (subplot 2) with a one-dimensional strip of each
+# variable in the margins; `density = true` draws the marginals as densities instead.
+@recipe function f(::Type{Val{:marginalscatter}}, plt::AbstractPlot; density = false)
+    x, y = plotattributes[:x], plotattributes[:y]
+    i = isfinite.(x) .& isfinite.(y)
+    x, y = x[i], y[i]
+    scale = get(plotattributes, :scale, :identity)
+    xlims, ylims = map(
+        x -> PlotsBase.Axes.scale_lims(
+            PlotsBase.ignorenan_extrema(x)...,
+            PlotsBase.Axes.default_widen_factor,
+            scale,
+        ),
+        (x, y),
+    )
+
+    # set up the subplots
+    legend --> false
+    link := :both
+    grid --> false
+    layout --> @layout [
+        topscatter _
+        scatter2d{0.9w,0.9h} rightscatter
+    ]
+
+    # main scatter2d
+    @series begin
+        seriestype := :scatter
+        right_margin --> 0PlotsBase.mm
+        top_margin --> 0PlotsBase.mm
+        subplot := 2
+        xlims --> xlims
+        ylims --> ylims
+    end
+
+    # these are common to both marginal scatter
+    ticks := nothing
+    xguide := ""
+    yguide := ""
+    fillcolor --> PlotsBase.fg_color(plotattributes)
+    linecolor --> PlotsBase.fg_color(plotattributes)
+
+    if density
+        trim := true
+        seriestype := :density
+    else
+        seriestype := :scatter
+    end
+
+    # upper scatter
+    @series begin
+        subplot := 1
+        bottom_margin --> 0PlotsBase.mm
+        showaxis := :x
+        x := x
+        y := ones(y |> size)
+        xlims --> xlims
+        ylims --> (0.95, 1.05)
+    end
+
+    # right scatter
+    @series begin
+        orientation := :h
+        showaxis := :y
+        subplot := 3
+        left_margin --> 0PlotsBase.mm
+        # bins := edges2
+        y := y
+        x := ones(x |> size)
+    end
+end
+
+# # now you can plot like:
+# marginalscatter(rand(1000), rand(1000))
diff --git a/StatsPlots/src/ordinations.jl b/StatsPlots/src/ordinations.jl
new file mode 100644
index 000000000..5615b6424
--- /dev/null
+++ b/StatsPlots/src/ordinations.jl
@@ -0,0 +1,26 @@
+# Scatter plot of the coordinates returned by `predict(mds)`; `mds_axes` picks which
+# two or three MDS dimensions go on the x, y (and z) axes.
+@recipe function f(mds::MultivariateStats.MDS{<:Real}; mds_axes = (1, 2))
+    length(mds_axes) in [2, 3] || throw(ArgumentError("Can only accept 2 or 3 mds axes"))
+    xax = mds_axes[1]
+    yax = mds_axes[2]
+    tfm = collect(MultivariateStats.predict(mds)')
+
+    xlabel --> "MDS$xax"
+    ylabel --> "MDS$yax"
+    seriestype := :scatter
+    aspect_ratio --> 1
+
+    if length(mds_axes) == 3
+        zax = mds_axes[3]
+        zlabel --> "MDS$zax"
+        tfm[:, xax], tfm[:, yax], tfm[:, zax]
+    else
+        tfm[:, xax], tfm[:, yax]
+    end
+end
+
+#= This needs 
to wait on a different PCA API in MultivariateStats.jl
+@recipe function f(pca::PCA{<:Real}; pca_axes=(1,2))
+end
+=#
diff --git a/StatsPlots/src/violin.jl b/StatsPlots/src/violin.jl
new file mode 100644
index 000000000..1ee41277a
--- /dev/null
+++ b/StatsPlots/src/violin.jl
@@ -0,0 +1,222 @@
+
+# ---------------------------------------------------------------------------
+# Violin Plot
+
+const _violin_warned = [false]
+
+# Kernel density estimate of `y` on 200 points, optionally weighted by `wts`.
+# Returns `(density, positions)`; with `trim = true` only the positions inside
+# `ignorenan_extrema(y)` are kept.
+function violin_coords(
+    y;
+    wts = nothing,
+    trim::Bool = false,
+    bandwidth = KernelDensity.default_bandwidth(y),
+)
+    kd =
+        wts === nothing ? KernelDensity.kde(y, npoints = 200, bandwidth = bandwidth) :
+        KernelDensity.kde(y, weights = weights(wts), npoints = 200, bandwidth = bandwidth)
+    if trim
+        xmin, xmax = PlotsBase.ignorenan_extrema(y)
+        inside = Bool[xmin <= x <= xmax for x in kd.x]
+        return (kd.density[inside], kd.x[inside])
+    end
+    kd.density, kd.x
+end
+
+# Normalise the `quantiles` keyword of the violin recipe to a collection of probabilities.
+get_quantiles(quantiles::AbstractVector) = quantiles
+get_quantiles(x::Real) = [x]
+get_quantiles(b::Bool) = b ? [0.5] : Float64[]
+get_quantiles(n::Int) = range(0, 1, length = n + 2)[2:(end - 1)]
+
+# Series recipe: one violin shape per distinct `x` value, plus optional mean, median
+# and quantile marker segments; `side` selects the left half, right half or both.
+@recipe function f(
+    ::Type{Val{:violin}},
+    x,
+    y,
+    z;
+    trim = true,
+    side = :both,
+    show_mean = false,
+    show_median = false,
+    quantiles = Float64[],
+    bandwidth = KernelDensity.default_bandwidth(y),
+)
+    # if only y is provided, then x will be UnitRange 1:size(y,2)
+    if typeof(x) <: AbstractRange
+        x = if step(x) == first(x) == 1
+            plotattributes[:series_plotindex]
+        else
+            [getindex(x, plotattributes[:series_plotindex])]
+        end
+    end
+    xsegs, ysegs = PlotsBase.Segments(), PlotsBase.Segments()
+    qxsegs, qysegs = PlotsBase.Segments(), PlotsBase.Segments()
+    mxsegs, mysegs = PlotsBase.Segments(), PlotsBase.Segments()
+    glabels = sort(collect(unique(x)))
+    bw = plotattributes[:bar_width]
+    bw === nothing && (bw = 0.8)
+    msc = plotattributes[:markerstrokecolor]
+    for (i, glabel) in enumerate(glabels)
+        fy = y[filter(i -> _cycle(x, i) == glabel, 1:length(y))]
+        widths, centers = violin_coords(
+            fy,
+            trim = trim,
+            wts = plotattributes[:weights],
+            bandwidth = bandwidth,
+        )
+        isempty(widths) && continue
+
+        # normalize
+        hw = 0.5 * _cycle(bw, i)
+        widths = hw * widths / PlotsBase.ignorenan_maximum(widths)
+
+        # make the violin
+        xcenter = PlotsBase.discrete_value!(plotattributes, :x, glabel)[1]
+        xcoords = if (side === :right)
+            vcat(widths, zeros(length(widths))) .+ xcenter
+        elseif (side === :left)
+            vcat(zeros(length(widths)), -reverse(widths)) .+ xcenter
+        else
+            vcat(widths, -reverse(widths)) .+ xcenter
+        end
+        ycoords = vcat(centers, reverse(centers))
+
+        push!(xsegs, xcoords)
+        push!(ysegs, ycoords)
+
+        if show_mean
+            mea = StatsBase.mean(fy)
+            mw = maximum(widths)
+            mx = xcenter .+ [-mw, mw] * 0.75
+            my = [mea, mea]
+            if side === :right
+                mx[1] = xcenter
+            elseif side === :left
+                mx[2] = xcenter
+            end
+
+            push!(mxsegs, mx)
+            push!(mysegs, my)
+        end
+
+        if show_median
+            med = StatsBase.median(fy)
+            mw = maximum(widths)
+            mx = xcenter .+ [-mw, mw] / 2
+            my = [med, med]
+            if side === :right
+                mx[1] = xcenter
+            elseif side === :left
+                mx[2] = xcenter
+            end
+
+            push!(qxsegs, mx)
+            push!(qysegs, my)
+        end
+
+        quantiles = get_quantiles(quantiles)
+        if !isempty(quantiles)
+            qy = quantile(fy, quantiles)
+            maxw = maximum(widths)
+
+            for i in eachindex(qy)
+                qxi = xcenter .+ [-maxw, maxw] * (0.5 - abs(0.5 - quantiles[i]))
+                qyi = [qy[i], qy[i]]
+                if side === :right
+                    qxi[1] = xcenter
+                elseif side === :left
+                    qxi[2] = xcenter
+                end
+
+                push!(qxsegs, qxi)
+                push!(qysegs, qyi)
+            end
+
+            push!(qxsegs, [xcenter, xcenter])
+            push!(qysegs, [extrema(qy)...])
+        end
+    end
+
+    @series begin
+        seriestype := :shape
+        x := xsegs.pts
+        y := ysegs.pts
+        ()
+    end
+
+    if !isempty(mxsegs.pts)
+        @series begin
+            primary := false
+            seriestype := :shape
+            linestyle := :dot
+            x := mxsegs.pts
+            y := mysegs.pts
+            ()
+        end
+    end
+
+    if !isempty(qxsegs.pts)
+        @series begin
+            primary := false
+            seriestype := :shape
+            x := qxsegs.pts
+            y := qysegs.pts
+            ()
+        end
+    end
+
+    seriestype := :shape
+    primary := false
+    x := []
+    y := []
+    ()
+end
+PlotsBase.@deps violin shape
+
+# ------------------------------------------------------------------------------
+# Grouped Violin
+
+@userplot GroupedViolin
+
+recipetype(::Val{:groupedviolin}, args...) = GroupedViolin(args)
+
+# Grouped violins: shift the x position of each group, then defer to `:violin`.
+@recipe function f(g::GroupedViolin; spacing = 0.1)
+    x, y = grouped_xy(g.args...)
+
+    # extract xnums and set default bar width.
+    # might need to set xticks as well
+    ux = unique(x)
+    x = if eltype(x) <: Number
+        bar_width --> (0.8 * mean(diff(sort(ux))))
+        float.(x)
+    else
+        bar_width --> 0.8
+        xnums = [findfirst(isequal(xi), ux) for xi in x] .- 0.5
+        xticks --> (eachindex(ux) .- 0.5, ux)
+        xnums
+    end
+
+    # shift x values for each group
+    group = get(plotattributes, :group, nothing)
+    if group !== nothing
+        gb = RecipesPipeline._extract_group_attributes(group)
+        labels, idxs = getfield(gb, 1), getfield(gb, 2)
+        n = length(labels)
+        bws = plotattributes[:bar_width] / n
+        bar_width := bws * clamp(1 - spacing, 0, 1)
+        for i = 1:n
+            groupinds = idxs[i]
+            Δx = _cycle(bws, i) * (i - (n + 1) / 2)
+            x[groupinds] .+= Δx
+        end
+    end
+
+    seriestype := :violin
+    x, y
+end
+
+PlotsBase.@deps groupedviolin violin
diff --git a/StatsPlots/test/runtests.jl b/StatsPlots/test/runtests.jl
new file mode 100644
index 000000000..831191285
--- /dev/null
+++ b/StatsPlots/test/runtests.jl
@@ -0,0 +1,494 @@
+using MultivariateStats
+using Distributions
+using StatsPlots
+using StableRNGs
+using Clustering
+using NaNMath
+using Plots
+using Test
+
+import Plots: PlotsBase
+
+@testset "Grouped histogram" begin
+    rng = StableRNG(1337)
+    gpl = groupedhist(
+        rand(rng, 1000),
+        yscale = :log10,
+        ylims = (1e-2, 1e4),
+        bar_position = :stack,
+    )
+    @test NaNMath.minimum(gpl[1][1][:y]) <= 1e-2
+    @test NaNMath.minimum(gpl[1][1][:y]) > 0
+    rng = StableRNG(1337)
+    gpl = groupedhist(
+        rand(rng, 1000),
+        yscale = :log10,
+        ylims = (1e-2, 1e4),
+        bar_position = :dodge,
+    )
+    @test NaNMath.minimum(gpl[1][1][:y]) <= 1e-2
+    @test NaNMath.minimum(gpl[1][1][:y]) > 0
+
+    data = [1, 1, 1, 1, 2, 1]
+    mask = (collect(1:6) .< 5)
+    gpl1 = groupedhist(data[mask], group = mask[mask], color = 1)
+    gpl2 = groupedhist(data[.!mask], group = mask[.!mask], color = 2)
+    gpl12 = groupedhist(data, group = mask, nbins = 5, bar_position = :stack)
+    @test NaNMath.maximum(gpl12[1][end][:y]) == NaNMath.maximum(gpl1[1][1][:y])
+    data = [10 12; 1 1; 0.25 0.25]
+    gplr = groupedbar(data)
+    @test NaNMath.maximum(gplr[1][1][:y]) == 10
+    @test NaNMath.maximum(gplr[1][end][:y]) == 12
+    gplr = groupedbar(data, bar_position = :stack)
+    @test NaNMath.maximum(gplr[1][1][:y]) == 22
+    @test NaNMath.maximum(gplr[1][end][:y]) == 12
+end # testset
+
+@testset "dendrogram" begin
+    # Example from https://en.wikipedia.org/wiki/Complete-linkage_clustering
+    wiki_example = [
+        0 17 21 31 23
+        17 0 30 34 21
+        21 30 0 28 39
+        31 34 28 0 43
+        23 21 39 43 0
+    ]
+    clustering = hclust(wiki_example, linkage = :complete)
+
+    xs, ys = StatsPlots.treepositions(clustering, true, :vertical)
+
+    @test xs == [
+        2.0 1.0 4.0 1.75
+        2.0 1.0 4.0 1.75
+        3.0 2.5 5.0 4.5
+        3.0 2.5 5.0 4.5
+    ]
+
+    @test ys == [
+        0.0 0.0 0.0 23.0
+        17.0 23.0 28.0 43.0
+        17.0 23.0 28.0 43.0
+        0.0 17.0 0.0 28.0
+    ]
+end
+
+@testset "Histogram" begin
+    data = randn(1000)
+    @test 0.2 < StatsPlots.wand_bins(data) < 0.4
+end
+
+@testset "Distributions" begin
+    @testset "univariate" begin
+        @testset "discrete" begin
+            pbern = plot(Bernoulli(0.25))
+            @test pbern[1][1][:x][1:2] == zeros(2)
+            @test pbern[1][1][:x][4:5] == ones(2)
+            @test pbern[1][1][:y][[1, 4]] == zeros(2)
+            @test pbern[1][1][:y][[2, 5]] == [0.75, 0.25]
+
+            pdirac = plot(Dirac(0.25))
+            @test pdirac[1][1][:x][1:2] == [0.25, 0.25]
+            @test pdirac[1][1][:y][1:2] == [0, 1]
+
+            ppois_unbounded = plot(Poisson(1))
+            @test ppois_unbounded[1][1][:x] isa AbstractVector
+            @test ppois_unbounded[1][1][:x][1:2] == zeros(2)
+            @test ppois_unbounded[1][1][:x][4:5] == ones(2)
+            @test ppois_unbounded[1][1][:y][[1, 4]] == zeros(2)
+            @test ppois_unbounded[1][1][:y][[2, 5]] ==
+                  pdf.(Poisson(1), ppois_unbounded[1][1][:x][[1, 4]])
+
+            pnonint = plot(Bernoulli(0.75) - 1 // 2)
+            @test pnonint[1][1][:x][1:2] == [-1 // 2, -1 // 2]
+            @test pnonint[1][1][:x][4:5] == [1 // 2, 1 // 2]
+            @test pnonint[1][1][:y][[1, 4]] == zeros(2)
+            @test pnonint[1][1][:y][[2, 5]] == [0.25, 0.75]
+
+            pmix = plot(
+                MixtureModel([Bernoulli(0.75), Bernoulli(0.5)], [0.5, 0.5]);
+                components = false,
+            )
+            @test pmix[1][1][:x][1:2] == zeros(2)
+            @test pmix[1][1][:x][4:5] == ones(2)
+            @test pmix[1][1][:y][[1, 4]] == zeros(2)
+            @test pmix[1][1][:y][[2, 5]] == [0.375, 0.625]
+
+            dzip = MixtureModel([Dirac(0), Poisson(1)], [0.1, 0.9])
+            pzip = plot(dzip; components = false)
+            @test pzip[1][1][:x] isa AbstractVector
+            @test pzip[1][1][:y][2:3:end] == pdf.(dzip, Int.(pzip[1][1][:x][1:3:end]))
+        end
+    end
+end
+
+@testset "ordinations" begin
+    @testset "MDS" begin
+        X = randn(4, 100)
+        M = fit(MultivariateStats.MDS, X; maxoutdim = 3, distances = false)
+        Y = MultivariateStats.predict(M)'
+
+        mds_plt = plot(M)
+        @test mds_plt[1][1][:x] == Y[:, 1]
+        @test mds_plt[1][1][:y] == Y[:, 2]
+        @test mds_plt[1][:xaxis][:guide] == "MDS1"
+        @test mds_plt[1][:yaxis][:guide] == "MDS2"
+
+        mds_plt2 = plot(M; mds_axes = (3, 1, 2))
+        @test mds_plt2[1][1][:x] == Y[:, 3]
+        @test mds_plt2[1][1][:y] == Y[:, 1]
+        @test mds_plt2[1][1][:z] == Y[:, 2]
+        @test mds_plt2[1][:xaxis][:guide] == "MDS3"
+        @test mds_plt2[1][:yaxis][:guide] == "MDS1"
+        @test mds_plt2[1][:zaxis][:guide] == "MDS2"
+    end
+end
+
+@testset "errorline" begin
+    rng = StableRNG(1337)
+    x = 1:10
+    # Test for floats
+    y = rand(rng, 10, 100) .* collect(1:2:20)
+    @test errorline(1:10, y)[1][1][:x] == x # x-input
+    @test all(
+        round.(errorline(1:10, y)[1][1][:y], digits = 3) .==
+        round.(mean(y, dims = 2), digits = 3),
+    ) # mean of y
+    @test all(
+        round.(errorline(1:10, y)[1][1][:ribbon], digits = 3) .==
+        round.(std(y, dims = 2), digits = 3),
+    ) # std of y
+    # Test for ints
+    y = reshape(1:100, 10, 10)
+    @test all(errorline(1:10, y)[1][1][:y] .== mean(y, dims = 2))
+    @test all(
+        round.(errorline(1:10, y)[1][1][:ribbon], digits = 3) .==
+        round.(std(y, dims = 2), digits = 3),
+    )
+    # Test colors
+    y = rand(rng, 10, 100, 3) .* collect(1:2:20)
+    c = palette(:default)
+    e = errorline(1:10, y)
+    @test colordiff(c[1], e[1][1][:linecolor]) == 0.0
+    @test colordiff(c[2], e[1][2][:linecolor]) == 0.0
+    @test colordiff(c[3], e[1][3][:linecolor]) == 0.0
+end
+
+@testset "marginalhist" begin
+    rng = StableRNG(1337)
+    pl = marginalhist(rand(rng, 100), rand(rng, 100))
+    @test show(devnull, pl) isa Nothing
+end
+
+@testset "marginalscatter" begin
+    rng = StableRNG(1337)
+    pl = marginalscatter(rand(rng, 100), rand(rng, 100))
+    @test show(devnull, pl) isa Nothing
+end
+
+@testset "violin" begin
+    rng = StableRNG(1337)
+    pl = violin(repeat([0.1, 0.2, 0.3], outer = 100), randn(rng, 300), side = :right)
+    @test show(devnull, pl) isa Nothing
+end
+
+@testset "density" begin
+    rng = StableRNG(1337)
+    pl = density(rand(rng, 100_000), label = "density(rand())")
+    @test show(devnull, pl) isa Nothing
+end
+
+@testset "boxplot" begin
+    # credits to stackoverflow.com/a/71467031
+    boxed = [
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            3,
+            7,
+            26,
+            80,
+            170,
+            322,
+            486,
+            688,
+            817,
+            888,
+            849,
+            783,
+            732,
+            624,
+            500,
+            349,
+            232,
+            130,
+            49,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            2,
+            28,
+            83,
+            181,
+            318,
+            491,
+            670,
+            761,
+            849,
+            843,
+            862,
+            799,
+            646,
+            481,
+            361,
+            225,
+            98,
+            50,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            2,
+            8,
+            28,
+            80,
+            179,
+            322,
+            493,
+            660,
+            753,
+            803,
+            832,
+            823,
+            783,
+            657,
+            541,
+            367,
+            223,
+            121,
+            62,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            1,
+            7,
+            23,
+            84,
+            171,
+            312,
+            463,
+            640,
+            778,
+            834,
+            820,
+            763,
+            752,
+            655,
+            518,
+            374,
+            244,
+            133,
+            52,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            1,
+            21,
+            70,
+            169,
+            342,
+            527,
+            725,
+            808,
+            861,
+            857,
+            799,
+            688,
+            622,
+            523,
+            369,
+            232,
+            115,
+            41,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            2,
+            9,
+            28,
+            76,
+            150,
+            301,
+            492,
+            660,
+            760,
+            823,
+            862,
+            790,
+            749,
+            646,
+            525,
+            352,
+            223,
+            116,
+            54,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            1,
+            6,
+            21,
+            64,
+            165,
+            290,
+            434,
+            585,
+            771,
+            852,
+            847,
+            785,
+            739,
+            630,
+            535,
+            354,
+            230,
+            114,
+            42,
+        ],
+        [
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            0,
+            1,
+            2,
+            4,
+            19,
+            76,
+            190,
+            337,
+            506,
+            680,
+            775,
+            851,
+            853,
+            816,
+            705,
+            588,
+            496,
+            388,
+            232,
+            127,
+            54,
+        ],
+    ]
+
+    boxes = -0.002:0.0001:0.0012
+
+    xx = repeat(boxes, outer = length(boxed))
+    yy = collect(Iterators.flatten(boxed))
+
+    xtick = collect(-0.002:0.0005:0.0012)
+
+    pl = boxplot(xx * 20_000, yy, xticks = (xtick * 20_000, xtick))
+    @test show(devnull, pl) isa Nothing
+end