-
-
Notifications
You must be signed in to change notification settings - Fork 333
/
Copy pathvae_plot.jl
45 lines (40 loc) · 1.39 KB
/
vae_plot.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
include("vae_mnist.jl")
using Plots
"""
    plot_result(; outdir="output", nbatches=19)

Load a trained VAE checkpoint from `joinpath(outdir, "checkpoint.jld2")` and write two
figures into `outdir`:

- `clustering.png`: scatter plot of the encoder means `μ` for the first `nbatches`
  batches yielded by `get_data(args.batch_size)`, coloured by each batch's `y`
  (presumably the digit labels — confirm against `get_data`).
- `manifold.png`: `sigmoid` of the decoder output over an 11×11 grid of the first two
  latent coordinates in `[-2, 2]` (remaining coordinates zero), passed through
  `convert_to_image`.

The defaults reproduce the previous behaviour. The old loop (`i < 20 || break`) drew 19
batches, which is why `nbatches` defaults to 19.

Returns the value of the final `save` call.

# Throws
- `ArgumentError` if the encoder mean `μ` does not have exactly 2 rows. The clustering
  plot draws the latent coordinates directly; higher-dimensional latents would need a
  projection (e.g. PCA or t-SNE) first.
"""
function plot_result(; outdir::AbstractString="output", nbatches::Integer=19)
    encoder, decoder, args = _load_vae(joinpath(outdir, "checkpoint.jld2"))
    loader = get_data(args.batch_size)

    # Clustering in the latent space: visualize the first two dims.
    plt = _plot_latent_clusters(encoder, loader, nbatches)
    savefig(plt, joinpath(outdir, "clustering.png"))

    # Manifold: decode a regular grid over the first two latent dims.
    image = _decode_latent_grid(decoder, args.latent_dim)
    return save(joinpath(outdir, "manifold.png"), image)
end

# Rebuild the encoder/decoder from a JLD2 checkpoint; return them with the saved `Args`.
function _load_vae(path::AbstractString)
    checkpoint = JLD2.load(path)
    args = Args(; checkpoint["args"]...)
    encoder = Encoder(args.input_dim, args.latent_dim, args.hidden_dim)
    decoder = Decoder(args.input_dim, args.latent_dim, args.hidden_dim)
    Flux.loadmodel!(encoder, checkpoint["encoder"])
    Flux.loadmodel!(decoder, checkpoint["decoder"])
    return encoder, decoder, args
end

# Scatter the encoder means of the first `nbatches` batches of `loader` onto one plot.
function _plot_latent_clusters(encoder, loader, nbatches::Integer)
    plt = scatter(palette=:rainbow)
    for (x, y) in Iterators.take(loader, nbatches)
        μ, _ = encoder(x)  # the second output (logσ) is not needed for plotting
        # Validate explicitly: `@assert` may be compiled out and is not for input checks.
        size(μ, 1) == 2 || throw(ArgumentError(
            "latent_dim has to be 2 for direct visualization, got $(size(μ, 1)); " *
            "otherwise use PCA or t-SNE"))
        # Pass `plt` explicitly rather than relying on Plots' implicit "current plot".
        scatter!(plt, μ[1, :], μ[2, :];
                 markerstrokewidth=0, markeralpha=0.8,
                 aspect_ratio=1,
                 markercolor=y, label="")
    end
    return plt
end

# Decode a `len`×`len` grid over latent dims 1 and 2 (others zero) into one image.
function _decode_latent_grid(decoder, latent_dim::Integer; lim::Real=2.0, len::Integer=11)
    z = range(-lim, stop=lim, length=len)
    x = zeros(Float32, latent_dim, len^2)
    # Column k holds (z1[k], z2[k], 0, …): dim 1 cycles fastest, dim 2 slowest.
    # `repeat(z; inner=len)` equals the previous `sort(repeat(z, len))` because `z` is
    # ascending, but avoids the sort.
    x[1, :] .= repeat(z, len)
    x[2, :] .= repeat(z; inner=len)
    # Decoder output is squashed with sigmoid (presumably logits → pixel intensities).
    samples = sigmoid.(decoder(x))
    return convert_to_image(samples, len)
end
# Entry point: produce the plots only when this file is executed directly
# (e.g. `julia vae_plot.jl`), not when it is `include`d from another file.
let running_as_script = @__FILE__() == abspath(PROGRAM_FILE)
    if running_as_script
        plot_result()
    end
end