
[question]: can I use thunder for pytorch lightning modules to make it lightning fast? #1491

Open
2catycm opened this issue Nov 30, 2024 · 3 comments
Labels: documentation

2catycm commented Nov 30, 2024

📚 Documentation

It seems the documentation doesn't cover this, so I am not sure whether it works with PyTorch Lightning.

The MNIST example in the Thunder documentation uses raw PyTorch.

cc @Borda @apaz-cli

2catycm (Author) commented Dec 2, 2024

Well, it seems we cannot apply thunder to a LightningModule directly. The Lightning Trainer raises an exception telling me that it requires a LightningModule type, not a Thunder type.
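
For reference, a minimal sketch of what I tried (the module and data are made up for illustration); wrapping the whole LightningModule fails because the jitted object is no longer a LightningModule:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import lightning as L
import thunder


# Hypothetical minimal module, just to illustrate the failure mode.
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)


loader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
    batch_size=8,
)

# thunder.jit returns a ThunderModule wrapper, which is no longer a
# LightningModule, so the Trainer rejects it.
compiled = thunder.jit(LitModel())

trainer = L.Trainer(max_epochs=1)
trainer.fit(compiled, train_dataloaders=loader)  # raises: the Trainer requires a LightningModule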

t-vi (Collaborator) commented Dec 2, 2024

Hey, great question!

We are working on this from both ends:

  • after doing lots of enablement, we are looking to make Thunder more end-user friendly,
  • we are systematically looking into how to best work with Thunder (and other compilers) in PyTorch Lightning.

Stay tuned, we want to have something soon, at the latest by the end of January.
(cc @lantiga )

lantiga (Collaborator) commented Dec 2, 2024

Hey @2catycm, good question.

TL;DR: instead of compiling the module as a whole, for now you can defer compilation to the point where the model is instantiated (which happens in the child processes).

That is, you can call self.model = thunder.jit(model) in configure_model, like this:

import lightning as L
import torch
import torch.nn.functional as F

import thunder

class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        # Called by the Trainer in each process before training starts,
        # so the model is built (and compiled) lazily here.
        if self.model is not None:
            return

        # Instantiate on the meta device so the full weights are not
        # allocated up front; Transformer stands in for your own nn.Module.
        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        # Compile the instantiated model with Thunder.
        self.model = thunder.jit(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

In the future we will also want to capture the optimizer as part of compilation, but the above should be more than enough for now.
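
For completeness, a minimal sketch of driving the module above with the Trainer (the vocab size and train_loader are placeholders for your own data):

import lightning as L

# train_loader is assumed to yield (input, target) batches for the model above.
model = LanguageModel(vocab_size=32000)

trainer = L.Trainer(accelerator="gpu", devices=1, max_epochs=1)
# configure_model() is invoked by the Trainer in each process, so the
# Transformer is built and thunder.jit is applied lazily there.
trainer.fit(model, train_dataloaders=train_loader)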
