Merge pull request #112 from slimgroup/fix_odd_cond
Fix odd channel count for the conditional network
rafaelorozco authored Aug 2, 2024
2 parents e165c4f + f72cd6f commit 7055a9f
Showing 2 changed files with 50 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/conditional_layers/conditional_layer_glow.jl
@@ -78,7 +78,12 @@ function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze

# 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in; freeze=freeze_conv)
-RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
+
+split_num = Int(round(n_in/2))
+in_split = n_in-split_num
+out_chan = 2*split_num
+
+RB = ResidualBlock(in_split+n_cond, n_hidden; n_out=out_chan, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

return ConditionalLayerGlow(C, RB, logdet, activation)
end
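
With an odd n_in, the old width Int(n_in/2) throws an InexactError in Julia (Int(1.5) is not an exact conversion), so odd-channel conditional layers could not even be constructed. A minimal sketch of the new channel bookkeeping, in plain base Julia with n_in = 3 as in the test added below (a sketch of the arithmetic, not library code):

n_in, n_cond = 3, 3
split_num = Int(round(n_in/2))   # 2: channels of X1, and of each of S and T
in_split  = n_in - split_num     # 1: channels of X2, concatenated with the condition
out_chan  = 2*split_num          # 4: residual block output, later split into S and T
@assert in_split + n_cond == 4   # residual block input width
# Old code: Int(n_in/2) == Int(1.5) throws InexactError for odd n_in.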
@@ -143,7 +148,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA

# Backpropagate RB
ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
-ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
+ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=size(ΔY2)[N-1])
ΔX2 += ΔY2

# Backpropagate 1x1 conv
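The backward fix follows the same logic: for odd n_in the two halves of ΔY carry unequal channel counts, so the split index must be read off ΔY2 itself rather than computed as half of ΔY (which, at Int(3/2), also raised an InexactError). A rough shape check, assuming the tensor_split/tensor_cat utilities exported by InvertibleNetworks behave as used in the hunk above; the sizes are toy values, not the library's internals:

using InvertibleNetworks

n_in, n_cond, batch = 3, 3, 2
split_num = Int(round(n_in/2))     # ΔY1 carries 2 channels
in_split  = n_in - split_num       # ΔY2 (and hence ΔX2) carries 1 channel
ΔX2_ΔC = rand(Float32, 16, 16, in_split + n_cond, batch)  # gradient from the residual block
ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=in_split)      # split into 1 + 3 channels
@assert size(ΔX2, 3) == in_split && size(ΔC, 3) == n_cond
# Old code: split_index=Int(size(ΔY)[N-1]/2) == Int(3/2) throws for odd n_in.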
43 changes: 43 additions & 0 deletions test/test_networks/test_conditional_glow_network.jl
@@ -9,6 +9,49 @@ device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
# Random seed
Random.seed!(3);

+# Define network
+nx = 32; ny = 32; nz = 32
+n_in = 3
+n_cond = 3
+n_hidden = 4
+batchsize = 2
+L = 2
+K = 2
+split_scales = false
+N = (nx,ny)
+
+########################################### Test with split_scales = false N = (nx,ny) #########################
+# Invertibility
+
+# Network and input
+G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
+X = rand(Float32, N..., n_in, batchsize) |> device
+Cond = rand(Float32, N..., n_cond, batchsize) |> device
+
+Y, Cond = G.forward(X,Cond)
+X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
+
+@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
+
+# Test gradients are set and cleared
+G.backward(Y, Y, Cond)
+
+P = get_params(G)
+gsum = 0
+for p in P
+    ~isnothing(p.grad) && (global gsum += 1)
+end
+@test isequal(gsum, L*K*10+2)
+
+clear_grad!(G)
+gsum = 0
+for p in P
+    ~isnothing(p.grad) && (global gsum += 1)
+end
+@test isequal(gsum, 0)
+
+
+Random.seed!(3);
# Define network
nx = 32; ny = 32; nz = 32
n_in = 2
