grad wrt sum net with summarized
rafael orozco committed Oct 4, 2023
1 parent db31c2c commit 7a2188e
Showing 2 changed files with 13 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/networks/summarized_net.jl
@@ -49,8 +49,8 @@ function inverse(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNe
 end
 
 # Backward pass and compute gradients
-function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
-    ΔX, X, ΔY = S.cond_net.backward(ΔX,X,similar(Y),Y)
+function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
+    ΔX, X, ΔY = S.cond_net.backward(ΔX,X,ΔY,Y)
     ΔY = S.sum_net.backward(ΔY, Y_save)
     return ΔX, X, ΔY
 end
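The change above makes the SummarizedNet backward pass take the gradient with respect to the summarized condition (ΔY) as an explicit argument instead of allocating a placeholder with similar(Y); that gradient is then pushed through the conditional network and the summary network, and the gradient with respect to the raw condition comes back as the third output. A minimal sketch of the new call pattern, using only names that appear in this commit (assumption: G is a SummarizedNet, and X and Cond are data and condition arrays created elsewhere):

    # Sketch only; G, X, Cond are assumed to exist. Names mirror loss_sum in the test file.
    Y, ZC, logdet = G.forward(X, Cond)          # ZC is the summarized condition
    ΔY  = -∇log_likelihood(Y)                   # gradient of the likelihood term on Y
    ΔZC = -∇log_likelihood(ZC)                  # gradient of the new likelihood term on ZC
    ΔX, X_, ΔC = G.backward(ΔY, Y, ΔZC, ZC; Y_save=Cond)   # ΔC is the gradient w.r.t. Cond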
23 changes: 11 additions & 12 deletions test/test_networks/test_conditional_glow_network.jl
@@ -95,8 +95,6 @@ end
 
 
 
-
-
 function loss(G, X, Cond)
     Y, ZC, logdet = G.forward(X, Cond)
     f = -log_likelihood(Y) - logdet
@@ -211,10 +209,11 @@ end
 # Gradient test
 function loss_sum(G, X, Cond)
     Y, ZC, logdet = G.forward(X, Cond)
-    f = -log_likelihood(Y) - logdet
+    f = -log_likelihood(Y) -log_likelihood(ZC) - logdet
     ΔY = -∇log_likelihood(Y)
-    ΔX, X_ = G.backward(ΔY, Y, ZC; Y_save=Cond)
-    return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
+    ΔZC = -∇log_likelihood(ZC)
+    ΔX, X_, ΔC = G.backward(ΔY, Y, ΔZC, ZC; Y_save=Cond)
+    return f, ΔX, ΔC, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
 end
 
 # Gradient test w.r.t. input
@@ -223,19 +222,19 @@ Cond = rand(Float32, N..., n_cond, batchsize);
 X0 = rand(Float32, N..., n_in, batchsize);
 Cond0 = rand(Float32, N..., n_cond, batchsize);
 
-dX = X - X0
+dCond = Cond - Cond0
 
-f0, ΔX = loss_sum(G, X0, Cond0)[1:2]
+f0, ΔX, ΔC = loss_sum(G, X0, Cond0)[1:3]
 h = 0.1f0
 maxiter = 4
 err1 = zeros(Float32, maxiter)
 err2 = zeros(Float32, maxiter)
 
 print("\nGradient test glow: input\n")
 for j=1:maxiter
-    f = loss_sum(G, X0 + h*dX, Cond0)[1]
+    f = loss_sum(G, X0, Cond0 + h*dCond)[1]
     err1[j] = abs(f - f0)
-    err2[j] = abs(f - f0 - h*dot(dX, ΔX))
+    err2[j] = abs(f - f0 - h*dot(dCond, ΔC))
     print(err1[j], "; ", err2[j], "\n")
     global h = h/2f0
 end
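What the loop above verifies: with ΔC the gradient returned by loss_sum, the first-order error abs(f - f0) should shrink roughly like O(h) (halving as h halves), while the gradient-corrected error abs(f - f0 - h*dot(dCond, ΔC)) should shrink like O(h^2) (quartering). A self-contained toy version of the same Taylor-series check, independent of the package, on f(x) = 0.5*||x||^2 whose gradient is x:

    using LinearAlgebra

    f(x) = 0.5f0 * norm(x)^2                 # toy objective; its gradient at x is x
    x0 = rand(Float32, 16); dx = rand(Float32, 16)
    f0, g0 = f(x0), x0                       # value and analytic gradient at x0
    h = 0.1f0
    for j = 1:4
        fh = f(x0 + h*dx)
        # first-order error ~ O(h); gradient-corrected error ~ O(h^2)
        print(abs(fh - f0), "; ", abs(fh - f0 - h*dot(dx, g0)), "\n")
        global h = h/2f0
    end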
@@ -254,7 +253,7 @@ Gini = deepcopy(G0)
 dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
 dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
 
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond)
+f0, ΔX, ΔC, ΔW, Δv = loss_sum(G0, X, Cond)
 h = 0.1f0
 maxiter = 4
 err3 = zeros(Float32, maxiter)
@@ -388,7 +387,7 @@ X_ = G.inverse(Y,ZCond); # saving the cond is important in split scales because
 @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
 
 # Test gradients are set and cleared
-G.backward(Y, Y,ZCond; Y_save=Cond)
+G.backward(Y, Y,ZCond,ZCond; Y_save=Cond)
 
 P = get_params(G)
 gsum = 0
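The lines above start the usual set-and-cleared check: after the backward call, the parameters returned by get_params(G) should all carry gradients, and after resetting them the count should drop back to zero. A hedged sketch of that counting pattern (assumptions: clear_grad! is the package's gradient-reset helper, and the exact expected count depends on the network configuration):

    P = get_params(G)
    n_set = count(p -> !isnothing(p.grad), P)   # populated by the backward call above
    @test n_set > 0
    clear_grad!(G)                              # assumed reset utility
    @test count(p -> !isnothing(p.grad), P) == 0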
@@ -444,7 +443,7 @@ Gini = deepcopy(G0)
 dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
 dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
 
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond);
+f0, ΔX,ΔC, ΔW, Δv = loss_sum(G0, X, Cond);
 h = 0.1f0
 maxiter = 4
 err3 = zeros(Float32, maxiter)
