
self.manual_backward() makes all gradients gone #20685

Open
samsara-ku opened this issue Mar 31, 2025 · 1 comment
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.5.x

Comments


samsara-ku commented Mar 31, 2025

Bug description

If you try to train a GAN with a LightningModule on multiple GPUs, you run into errors like this:

[rank0]: RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, either by setting the string value `strategy='ddp_find_unused_parameters_true'` or by setting the flag in the strategy with `strategy=DDPStrategy(find_unused_parameters=True)`.
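
For reference, a minimal sketch of applying the workaround the error message itself suggests (the accelerator/devices values below are placeholders, not my actual setup):

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Either of the two spellings from the error message:
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp_find_unused_parameters_true")
# or, equivalently, via the strategy object:
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=DDPStrategy(find_unused_parameters=True))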

For inspection, I ran the GAN training code with this kind of snippet at the top of main.py:

import os

os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

and then you see errors like these:

[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.norm.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.norm.weight did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.3.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.norm.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.norm.weight did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.2.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.norm.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.norm.weight did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.1.conv.bias did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.0.conv.weight_orig did not get gradient in backwards pass.
[rank1]:[I reducer.cpp:1949] [Rank 1] Parameter: discriminator.discs.1.down_blocks.0.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.norm.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.norm.weight did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.3.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.norm.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.norm.weight did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.2.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.norm.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.norm.weight did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.1.conv.bias did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.0.conv.weight_orig did not get gradient in backwards pass.
[rank0]:[I reducer.cpp:1949] [Rank 0] Parameter: discriminator.discs.1.down_blocks.0.conv.bias did not get gradient in backwards pass.

I think this problem comes from something wrong in the LightningModule code path: if one manual_backward() call runs successfully, the other manual_backward() call cannot run correctly, because the first call seems to remove all the gradients of the other module.
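
My guess at the mechanism (based on my reading of the Lightning docs, not verified against the internals): toggle_optimizer() sets requires_grad=False on every parameter that does not belong to the toggled optimizer, so the generator backward flows through a frozen discriminator and leaves it without gradients, which DDP's reducer then reports as unused. A minimal sketch of what I mean:

# Sketch, inside training_step with self.automatic_optimization = False.
# While opt_G is toggled, every discriminator parameter is frozen:
self.toggle_optimizer(opt_G)
assert all(not p.requires_grad for p in self.discriminator.parameters())

# The generator backward passes through the frozen discriminator, so the
# discriminator receives no grads and the DDP reducer logs them as
# "did not get gradient in backwards pass".
self.manual_backward(gan_g_loss)
self.untoggle_optimizer(opt_G)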

On this point, someone suggested calling manual_backward() only once, but I think that differs from the normal training strategy of a GAN:

self.manual_backward(d_loss + g_loss) --> I think this would be a problem, but I cannot find any other way to solve the unused-parameter issue

self.toggle_optimizer(optimizer_d)
optimizer_d.step()
optimizer_d.zero_grad()
self.untoggle_optimizer(optimizer_d)
self.toggle_optimizer(optimizer_g)
optimizer_g.step()
optimizer_g.zero_grad()
self.untoggle_optimizer(optimizer_g)
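
If one actually went the combined-backward route, I believe the surrounding steps would have to look like the sketch below (a sketch, not a recommendation: unless the fakes fed to d_loss are detached, the discriminator loss would also push gradients into the generator, which is not the standard GAN update):

opt_G.zero_grad()
opt_D.zero_grad()
# Single backward over the summed losses; d_loss must be computed on
# detached fakes, or its gradient will also flow into the generator.
self.manual_backward(d_loss + g_loss)
opt_G.step()
opt_D.step()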

Has anyone solved this problem?

What version are you seeing the problem on?

v2.5

How to reproduce the bug

class G_base(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.gen_scales = cfg.generator.scales
        self.dis_scales = cfg.discriminator.scales

        self.generator = Generator(cfg)
        self.discriminator = Discriminator(**cfg.discriminator)

        self.automatic_optimization = False

    def training_step(self, batch):
        src_img, drv_img = batch["src"], batch["drv"]

        opt_G, opt_D = self.optimizers()

        output = self.generator(src_img, drv_img)

        ### Gen ###
        self.toggle_optimizer(opt_G)
        gan_g_loss = 0

        pyramid_generated_gen = {"prediction_1": output["gen_img"]}

        disc_map_generated_gen = self.discriminator(pyramid_generated_gen)

        for scale in self.dis_scales:
            key = "prediction_map_%s" % scale
            value = (1 - disc_map_generated_gen[key]) ** 2
            gan_g_loss += value.mean()

        opt_G.zero_grad()
        self.manual_backward(gan_g_loss)
        opt_G.step()
        self.untoggle_optimizer(opt_G)

        self.log_dict({"gan_g_loss": gan_g_loss}, prog_bar=True)

        ### Dis ###
        self.toggle_optimizer(opt_D)
        gan_d_loss = 0

        pyramid_real = {"prediction_1": drv_img}
        pyramid_generated = {"prediction_1": output["gen_img"].detach()}

        disc_map_real = self.discriminator(pyramid_real)
        disc_map_generated = self.discriminator(pyramid_generated)

        for scale in self.dis_scales:
            key = "prediction_map_%s" % scale
            value = (1 - disc_map_real[key]) ** 2 + disc_map_generated[key] ** 2
            gan_d_loss += value.mean()

        opt_D.zero_grad()
        self.manual_backward(gan_d_loss)
        opt_D.step()
        self.untoggle_optimizer(opt_D)

        self.log_dict({"gan_d_loss": gan_d_loss}, prog_bar=True)

Error messages and logs

See the RuntimeError and the reducer logs in the bug description above.

Environment

Current environment
aiohappyeyeballs         2.6.1
aiohttp                  3.11.14
aiosignal                1.3.2
annotated-types          0.7.0
antlr4-python3-runtime   4.9.3
anykeystore              0.2
apex                     0.9.10.dev0
attrs                    25.3.0
certifi                  2025.1.31
charset-normalizer       3.4.1
click                    8.1.8
cryptacular              1.6.2
decorator                4.4.2
defusedxml               0.7.1
docker-pycreds           0.4.0
facenet-pytorch          2.6.0
filelock                 3.18.0
flow-vis                 0.1
frozenlist               1.5.0
fsspec                   2025.3.0
gitdb                    4.0.12
GitPython                3.1.44
greenlet                 3.0.3
h5py                     3.11.0
hupper                   1.12.1
idna                     3.10
Jinja2                   3.1.6
lightning-utilities      0.14.1
MarkupSafe               3.0.2
mpmath                   1.3.0
multidict                6.2.0
munkres                  1.1.4
natsort                  8.4.0
networkx                 3.4.2
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-cusparselt-cu12   0.6.2
nvidia-nccl-cu12         2.19.3
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
oauthlib                 3.2.2
omegaconf                2.3.0
opencv-python            4.11.0.86
packaging                24.2
PasteDeploy              3.1.0
pbkdf2                   1.3
pillow                   10.2.0
pip                      25.0
plaster                  1.1.2
plaster-pastedeploy      1.0.1
platformdirs             4.3.6
proglog                  0.1.10
propcache                0.3.0
protobuf                 5.29.3
psutil                   7.0.0
pyav                     11.4.1
pycocotools              2.0.8
pydantic                 2.10.6
pydantic_core            2.27.2
pyramid                  2.0.2
pyramid-mailer           0.15.1
python3-openid           3.2.0
pytorch-lightning        2.5.0.post0
PyYAML                   6.0.2
repoze.sendmail          4.4.1
requests                 2.32.3
requests-oauthlib        2.0.0
sentry-sdk               2.23.1
setproctitle             1.3.5
setuptools               75.8.0
six                      1.17.0
slack_sdk                3.35.0
smmap                    5.0.2
SQLAlchemy               2.0.30
sympy                    1.13.1
tensorboardX             2.6.2.2
torch                    2.2.2
torchaudio               2.2.2
torchmetrics             1.6.3
torchvision              0.17.2
tqdm                     4.67.1
transaction              4.0
translationstring        1.4
triton                   2.2.0
typing_extensions        4.12.2
urllib3                  2.3.0
velruse                  1.1.1
venusian                 3.1.0
wandb                    0.19.8
WebOb                    1.8.7
wheel                    0.45.1
WTForms                  3.1.2
wtforms-recaptcha        0.3.2
yarl                     1.18.3
zope.deprecation         5.0
zope.interface           6.4.post2
zope.sqlalchemy          3.1

More info

No response

@samsara-ku samsara-ku added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Mar 31, 2025

samsara-ku commented Mar 31, 2025

Interestingly, if I swap the order of the GAN training steps (e.g. from "disc first, then gen" to "gen first, then disc"), I get the same error: there are no gradients in the module that is trained second.

@samsara-ku samsara-ku changed the title RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. self.manual_backward() makes all gradients gone Apr 1, 2025