From 7a2188e5847ec474490c38040961b6dc0aefc50f Mon Sep 17 00:00:00 2001
From: rafael orozco
Date: Wed, 4 Oct 2023 18:24:21 -0400
Subject: [PATCH] grad wrt sum net with summarized

---
 src/networks/summarized_net.jl       |  4 ++--
 .../test_conditional_glow_network.jl | 23 +++++++++----------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/src/networks/summarized_net.jl b/src/networks/summarized_net.jl
index 8b616ccd..ac0fd547 100644
--- a/src/networks/summarized_net.jl
+++ b/src/networks/summarized_net.jl
@@ -49,8 +49,8 @@ function inverse(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNe
 end
 
 # Backward pass and compute gradients
-function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
-    ΔX, X, ΔY = S.cond_net.backward(ΔX,X,similar(Y),Y)
+function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, S::SummarizedNet; Y_save=nothing) where {T, N}
+    ΔX, X, ΔY = S.cond_net.backward(ΔX, X, ΔY, Y)
     ΔY = S.sum_net.backward(ΔY, Y_save)
     return ΔX, X, ΔY
 end
diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl
index aca9bc66..a7bd4627 100644
--- a/test/test_networks/test_conditional_glow_network.jl
+++ b/test/test_networks/test_conditional_glow_network.jl
@@ -95,8 +95,6 @@ end
 
 
 
-
-
 function loss(G, X, Cond)
     Y, ZC, logdet = G.forward(X, Cond)
     f = -log_likelihood(Y) - logdet
@@ -211,10 +209,11 @@ end
 # Gradient test
 function loss_sum(G, X, Cond)
     Y, ZC, logdet = G.forward(X, Cond)
-    f = -log_likelihood(Y) - logdet
+    f = -log_likelihood(Y) - log_likelihood(ZC) - logdet
     ΔY = -∇log_likelihood(Y)
-    ΔX, X_ = G.backward(ΔY, Y, ZC; Y_save=Cond)
-    return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
+    ΔZC = -∇log_likelihood(ZC)
+    ΔX, X_, ΔC = G.backward(ΔY, Y, ΔZC, ZC; Y_save=Cond)
+    return f, ΔX, ΔC, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
 end
 
 # Gradient test w.r.t. input
@@ -223,9 +222,9 @@ Cond = rand(Float32, N..., n_cond, batchsize);
 X0 = rand(Float32, N..., n_in, batchsize);
 Cond0 = rand(Float32, N..., n_cond, batchsize);
-dX = X - X0
+dCond = Cond - Cond0
 
-f0, ΔX = loss_sum(G, X0, Cond0)[1:2]
+f0, ΔX, ΔC = loss_sum(G, X0, Cond0)[1:3]
 h = 0.1f0
 maxiter = 4
 err1 = zeros(Float32, maxiter)
 
@@ -233,9 +232,9 @@ err2 = zeros(Float32, maxiter)
 
 print("\nGradient test glow: input\n")
 for j=1:maxiter
-    f = loss_sum(G, X0 + h*dX, Cond0)[1]
+    f = loss_sum(G, X0, Cond0 + h*dCond)[1]
     err1[j] = abs(f - f0)
-    err2[j] = abs(f - f0 - h*dot(dX, ΔX))
+    err2[j] = abs(f - f0 - h*dot(dCond, ΔC))
     print(err1[j], "; ", err2[j], "\n")
     global h = h/2f0
 end
@@ -254,7 +253,7 @@ Gini = deepcopy(G0)
 
 dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
 dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond)
+f0, ΔX, ΔC, ΔW, Δv = loss_sum(G0, X, Cond)
 h = 0.1f0
 maxiter = 4
 err3 = zeros(Float32, maxiter)
@@ -388,7 +387,7 @@ X_ = G.inverse(Y,ZCond); # saving the cond is important in split scales because
 @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
 
 # Test gradients are set and cleared
-G.backward(Y, Y,ZCond; Y_save=Cond)
+G.backward(Y, Y, ZCond, ZCond; Y_save=Cond)
 
 P = get_params(G)
 gsum = 0
@@ -444,7 +443,7 @@ Gini = deepcopy(G0)
 
 dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
 dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond);
+f0, ΔX, ΔC, ΔW, Δv = loss_sum(G0, X, Cond);
 h = 0.1f0
 maxiter = 4
 err3 = zeros(Float32, maxiter)
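
Note: for reference, a minimal sketch of how the updated five-argument
backward is driven from a forward pass. It assumes G is a SummarizedNet and
X, Cond are Float32 arrays shaped as in the tests above (e.g.
X = rand(Float32, N..., n_in, batchsize)); only the forward/backward calls
and the ∇log_likelihood usage mirror this patch:

    # Forward pass returns the latent Y, the summarized condition ZC, and logdet.
    Y, ZC, logdet = G.forward(X, Cond)

    # The objective gradient w.r.t. both outputs now enters backward explicitly
    # (the old code substituted a similar(Y) placeholder for the condition
    # gradient), so the gradient ΔC w.r.t. the raw condition is returned
    # alongside ΔX after propagating back through the summary network.
    ΔY  = -∇log_likelihood(Y)
    ΔZC = -∇log_likelihood(ZC)
    ΔX, X_, ΔC = G.backward(ΔY, Y, ΔZC, ZC; Y_save=Cond)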