NO _loss_tracker on train_on_batch because compile model multiple times. Possible Bug. #20474

Closed
TheMGGdev opened this issue Nov 8, 2024 · 5 comments · Fixed by #20602


TheMGGdev commented Nov 8, 2024

The same code for a GAN works perfectly in Keras 3.3 but doesn't work in Keras 3.6. There is an error in train_on_batch, which I believe is caused by a bug introduced by a change in Keras 3.6.
The code is this:

import numpy as np
import matplotlib.pyplot as plt
import random

from keras.datasets import mnist
from keras.utils import plot_model
from keras.models import Sequential, Model
from keras.layers import (Input, Conv2D, Dense, Activation, 
                         Flatten, Reshape, Dropout,
                         UpSampling2D, MaxPooling2D,
                         BatchNormalization, LeakyReLU, Conv2DTranspose,
                         GlobalMaxPooling2D)
from keras.losses import BinaryCrossentropy
from keras.optimizers import Adam
from keras.metrics import Mean, Accuracy
from keras.backend import backend
from keras.random import SeedGenerator, uniform, normal
from keras import ops

!pip list | grep keras

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
X_train = X_train.astype('float32') / 127.5 - 1  # scale pixels to [-1, 1] for the tanh generator output
X_train = np.expand_dims(X_train, axis = 3)      # add a channel axis: (28, 28) -> (28, 28, 1)

def create_generator():
    generator = Sequential(
        [
            Input(shape = (100,)),
            Dense(7 * 7 * 128),
            LeakyReLU(0.2),
            Reshape((7, 7, 128)),
            Conv2DTranspose(128, 4, 2, "same"),
            LeakyReLU(0.2),
            Conv2DTranspose(256, 4, 2, "same"),
            LeakyReLU(0.2),
            Conv2D(1, 7, padding = "same", activation = "tanh"),
        ],
        name = "generator",
    )
    return generator

def create_discriminator():
    discriminator = Sequential(
        [
            Input(shape = (28, 28, 1)),
            Conv2D(64, 3, 2, "same"),
            LeakyReLU(0.2),
            Conv2D(128, 3, 2, "same"),
            LeakyReLU(0.2),
            Conv2D(256, 3, 2, "same"),
            LeakyReLU(0.2),
            Flatten(),
            Dropout(0.2),
            Dense(1, activation = "sigmoid"),
        ],
        name = "discriminator",
    )
    return discriminator

generator = create_generator()
discriminator = create_discriminator()

discriminator.compile(loss = 'binary_crossentropy', optimizer = Adam(), metrics = ['accuracy'])
discriminator.trainable = False


### Prints for debugging / showing the error
print('---Debugging/show the error after compiled discriminator and before compiled combined---')
print('discriminator.compiled ->', discriminator.compiled)
print('discriminator.optimizer ->', discriminator.optimizer)
print('discriminator.train_function ->', discriminator.train_function)
print('discriminator.train_step ->', discriminator.train_step)
print('discriminator.metrics ->', discriminator.metrics)
print('discriminator._loss_tracker ->', discriminator._loss_tracker)
print('discriminator._jit_compile ->', discriminator._jit_compile)
###

z = Input(shape=(100,))
img = generator(z)
validity = discriminator(img)

combined = Model(z, validity)
combined.compile(loss = 'binary_crossentropy', optimizer = Adam())

### Prints for debugging / showing the error
print('---Debugging/show the error after compiled discriminator and combined---')
print('discriminator.compiled ->', discriminator.compiled)
print('discriminator.optimizer ->', discriminator.optimizer)
print('discriminator.train_function ->', discriminator.train_function)
print('discriminator.train_step ->', discriminator.train_step)
print('discriminator.metrics ->', discriminator.metrics)
print('discriminator._loss_tracker ->', discriminator._loss_tracker)
print('discriminator._jit_compile ->', discriminator._jit_compile)
###

def train(X_train, generator, discriminator, combined, epochs, batch_size = 32, sample_interval = 100):
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    history = {
        'd_loss' : [],
        'd_acc' : [],
        'g_loss' : []
    }

    for epoch in range(epochs):
        print("----EPOCH " + str(epoch) + '-----')
        for batch in range(int(len(X_train)/batch_size)):
            #  Train the Discriminator
            noise = np.random.normal(0, 1, (batch_size, 100))
            gen_imgs = generator.predict(noise, verbose = 0)
            imgs = X_train[batch*batch_size : (batch+1)*batch_size]

# Prints for debugging / showing the error
            print('---Debugging/show the error---')
            print('discriminator.compiled ->', discriminator.compiled)
            print('discriminator.optimizer ->', discriminator.optimizer)
            print('discriminator._loss_tracker ->', discriminator._loss_tracker)
            print('discriminator._jit_compile ->', discriminator._jit_compile)


            d_loss_real = discriminator.train_on_batch(imgs, valid)
            d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the Generator
            noise = np.random.normal(0, 1, (batch_size, 100))
            g_loss = combined.train_on_batch(noise, valid)

            # Save losses
            history['d_loss'].append(d_loss[0])
            history['d_acc'].append(d_loss[1])
            history['g_loss'].append(g_loss[0])

train(X_train, generator, discriminator, combined, epochs = 2, batch_size = 256, sample_interval = 100)

The error is this:

    Cell In[13], line 128, in train(X_train, generator, discriminator, combined, epochs, batch_size, sample_interval)
        124 print('discriminator._loss_tracker ->', discriminator._loss_tracker)
        125 print('discriminator._jit_compile ->', discriminator._jit_compile)
    --> 128 d_loss_real = discriminator.train_on_batch(imgs, valid)
        129 d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        130 d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    
    File ~/.local/lib/python3.10/site-packages/keras/src/backend/torch/trainer.py:468, in TorchTrainer.train_on_batch(self, x, y, sample_weight, class_weight, return_dict)
        465 self._symbolic_build(data_batch=data)
        466 self.make_train_function()
    --> 468 logs = self.train_function([data])
        469 logs = tree.map_structure(lambda x: np.array(x), logs)
        470 if return_dict:
    
    File ~/.local/lib/python3.10/site-packages/keras/src/backend/torch/trainer.py:117, in TorchTrainer.make_train_function.<locals>.one_step_on_data(data)
        115 """Runs a single training step on a batch of data."""
        116 data = data[0]
    --> 117 return self.train_step(data)
    
    File ~/.local/lib/python3.10/site-packages/keras/src/backend/torch/trainer.py:55, in TorchTrainer.train_step(self, data)
         50 self.zero_grad()
         52 loss = self._compute_loss(
         53     x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
         54 )
    ---> 55 self._loss_tracker.update_state(
         56     loss, sample_weight=tree.flatten(x)[0].shape[0]
         57 )
         58 if self.optimizer is not None:
         59     loss = self.optimizer.scale_loss(loss)
    
    AttributeError: 'NoneType' object has no attribute 'update_state'.

The error says that self._loss_tracker is None when it should be metrics_module.Mean(name="loss"), since the model has been compiled.
The prints I added to the code show that, after compiling the discriminator and before compiling the combined model:

---Debugging/show the error after compiled discriminator and before compiled combined---
discriminator.compiled -> True
discriminator.optimizer -> <keras.src.optimizers.adam.Adam object at 0x77ecf6bb0a00>
discriminator.train_function -> None
discriminator.train_step -> <bound method TensorFlowTrainer.train_step of <Sequential name=discriminator, built=True>>
discriminator.metrics -> [<Mean name=loss>, <CompileMetrics name=compile_metrics>]
discriminator._loss_tracker -> <Mean name=loss>
discriminator._jit_compile -> True

However, after compiling the combined model:
discriminator.compiled -> True
discriminator.optimizer -> <keras.src.optimizers.adam.Adam object at 0x77ecf6bb0a00>
discriminator.train_function -> None
discriminator.train_step -> <bound method TensorFlowTrainer.train_step of >
discriminator.metrics -> []
discriminator._loss_tracker -> None
discriminator._jit_compile -> True

So the problem is that compiling the combined model erases the discriminator's metrics (loss tracker, ...), which it shouldn't, and it also leaves the discriminator marked as compiled when it shouldn't, because the compilation has effectively been undone. I believe the bug lies in a change introduced in Keras 3.6 to compile() in keras/src/trainers/trainer.py:
[screenshot of the changed compile() code]
The function self._clear_previous_trainer_metrics clears not only the metrics of combined but also those of discriminator, which leaves discriminator without proper metrics.
[screenshot of _clear_previous_trainer_metrics]
My pull request for this possible error is: #20473
I tried the code with all three backends and it happens in every one.
I hope it helps! :)
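
To reproduce this without the full GAN, here is a minimal sketch of the reported behavior (a hypothetical toy model, not from the original report; names and sizes are illustrative):

import numpy as np
from keras import Input, Model
from keras.layers import Dense

# Build and compile an inner model, then freeze it (like the discriminator).
x = Input(shape = (4,))
inner = Model(x, Dense(1, activation = 'sigmoid')(x))
inner.compile(loss = 'binary_crossentropy', optimizer = 'adam')
inner.trainable = False

# Wrap it in an outer model and compile that too (like the combined model).
z = Input(shape = (4,))
outer = Model(z, inner(z))
outer.compile(loss = 'binary_crossentropy', optimizer = 'adam')

# On Keras 3.6 the inner model's loss tracker has reportedly been cleared,
# so this prints None and the next line raises the AttributeError above.
print(inner._loss_tracker)
inner.train_on_batch(np.ones((2, 4)), np.ones((2, 1)))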

@TheMGGdev TheMGGdev changed the title NO _loss_tracker on train_on_batch becouse compile model multiple times. Possible Bug. NO _loss_tracker on train_on_batch because compile model multiple times. Possible Bug. Nov 9, 2024
@mohammad-rababah

Any update on this bug?

@mehtamansi29
Collaborator

Hi @TheMGGdev and @mohammad-rababah -

You are getting the error here because you call discriminator.compile(loss = 'binary_crossentropy', optimizer = Adam(), metrics = ['accuracy']) and then freeze the weights for training with discriminator.trainable = False.

So when you then compile the combined model with combined.compile(loss = 'binary_crossentropy', optimizer = Adam()), the discriminator's weights are frozen due to trainable=False.

That's why the discriminator's _loss_tracker is None when update_state is called.

You can compile the discriminator after compiling the combined model. That will resolve your error.

z = Input(shape=(100,))
img = generator(z)
validity = discriminator(img)
combined = Model(z, validity)
combined.compile(loss = 'binary_crossentropy', optimizer = Adam())
discriminator.compile(loss = 'binary_crossentropy', optimizer = Adam(), metrics = ['accuracy'])
discriminator.trainable = False

Attached gist for your reference.


TheMGGdev commented Nov 27, 2024

The error has nothing to do with that. There are trainings in which a model is a combination of several models and you don't want to train all of them, as in this example with GANs. Here you have the generator, the discriminator, and the combined model (generator + discriminator). When you create the combined model, which is the one you use to train the generator, you want the discriminator not to train, so you set discriminator.trainable = False. Explained more simply:

  • Discriminator training: with discriminator.train_on_batch the discriminator sees real and generated images and learns.
  • Generator training: with combined.train_on_batch the combined model (generator + discriminator) learns, but since we only want the generator to learn, we set the discriminator to not train before creating the combined model.

The code works perfectly in other versions and comes from an example in this repo: https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py. The real problem is that in the new versions of Keras, compiling the combined model clears all the trainer metrics, so combined.compile() deletes the discriminator's metrics, as explained in pull request #20473. Putting

discriminator.compile(loss = 'binary_crossentropy', optimizer = Adam(), metrics = ['accuracy'])
discriminator.trainable = False

after the combined compile instead of before does not solve the problem; it creates another one, because now in the generator training both the generator and the discriminator will be trained. A quick check of this is sketched below. Hope it helps :)
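
As a quick check of that last point (illustrative, assuming the models defined in the original code and a Keras version without the bug, e.g. 3.3), the compile order determines what each train_on_batch call updates:

import numpy as np

noise = np.random.normal(0, 1, (256, 100))
valid = np.ones((256, 1))

# The discriminator was frozen before combined was compiled, so a combined
# training step should leave the discriminator's weights untouched.
before = [w.copy() for w in discriminator.get_weights()]
combined.train_on_batch(noise, valid)
after = discriminator.get_weights()
print([np.allclose(b, a) for b, a in zip(before, after)])  # expect all True

# A discriminator training step should still change its weights, because
# discriminator was compiled while it was trainable.
imgs = X_train[:256]
before = [w.copy() for w in discriminator.get_weights()]
discriminator.train_on_batch(imgs, valid)
after = discriminator.get_weights()
print([np.allclose(b, a) for b, a in zip(before, after)])  # expect some False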

@mehtamansi29 mehtamansi29 added the keras-team-review-pending Pending review by a Keras team member. label Nov 27, 2024
@mattdangerw mattdangerw removed the keras-team-review-pending Pending review by a Keras team member. label Nov 27, 2024
@mattdangerw mattdangerw self-assigned this Nov 27, 2024
@mattdangerw
Copy link
Member

@TheMGGdev looks like we just need a unit test to go along with your change #20473
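
For reference, such a regression test could look roughly like the following (a hypothetical sketch, not necessarily the actual test in #20602):

import numpy as np
from keras import Input, Model
from keras.layers import Dense

def test_compile_does_not_clear_sub_model_metrics():
    # Mirror the GAN setup: compile an inner model, freeze it, then
    # compile an outer model that wraps it.
    x = Input(shape = (4,))
    inner = Model(x, Dense(1, activation = 'sigmoid')(x))
    inner.compile(loss = 'binary_crossentropy', optimizer = 'adam')
    inner.trainable = False

    z = Input(shape = (4,))
    outer = Model(z, inner(z))
    outer.compile(loss = 'binary_crossentropy', optimizer = 'adam')

    # Compiling outer must not wipe inner's compiled state.
    assert inner._loss_tracker is not None
    inner.train_on_batch(np.ones((2, 4)), np.ones((2, 1)))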
