
Torch.compile Graph break introduced due to new loss function api #34615

Open
2 of 4 tasks
ChanderG opened this issue Nov 5, 2024 · 0 comments · May be fixed by #34616

ChanderG commented Nov 5, 2024

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

PR #34191 introduced a new loss API. After that PR, Dynamo was broken; this was identified and fixed in issue #34402. With that fix in place (on master), Dynamo runs without errors.

However, the change also introduced a new graph break, caused by this line:

loss_type = re.findall(loss_groups, self.__class__.__name__)

The break comes from this new regex check on the model's class name, which now runs inside the forward pass that Dynamo traces.
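To make the lookup concrete, here is a minimal sketch of how a class name can be resolved to a loss group with a regex alternation. The contents of LOSS_MAPPING and the helper resolve_loss_type are hypothetical stand-ins, not the transformers implementation; the point is that a re.findall call like this one is what the issue identifies as the source of the graph break when it runs inside a compiled forward.

```python
import re

# Hypothetical stand-in for the library's loss registry, keyed by class-name suffix.
# Values are placeholder strings rather than real loss callables.
LOSS_MAPPING = {
    "ForCausalLM": "causal_lm_loss",
    "ForSequenceClassification": "sequence_classification_loss",
    "ForTokenClassification": "token_classification_loss",
}


def resolve_loss_type(class_name: str):
    """Return the first registered loss group found in class_name, or None."""
    if class_name in LOSS_MAPPING:
        return class_name
    # Build an alternation such as "(ForCausalLM|ForSequenceClassification|...)".
    loss_groups = f"({'|'.join(map(re.escape, LOSS_MAPPING))})"
    # With a single capture group, re.findall returns a list of matched strings.
    matches = re.findall(loss_groups, class_name)
    return matches[0] if matches else None


print(resolve_loss_type("LlamaForCausalLM"))  # ForCausalLM
print(resolve_loss_type("BertModel"))  # None
```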

Since the dispatch function already checks for a loss_type attribute on the config, the fix is simple: set loss_type once at model init time, so the regex is never evaluated inside the compiled region.
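The proposed fix can be sketched as follows, assuming a config object that accepts a loss_type attribute. HypotheticalPreTrainedModel, ToyForCausalLM and the LOSS_MAPPING contents are illustrative names, not the actual patch in #34616: the regex runs once in __init__ (eager Python, outside torch.compile), and the loss_function property on the hot path is reduced to an attribute read plus a dict lookup.

```python
import re
from types import SimpleNamespace

# Hypothetical loss registry keyed by class-name suffix (values are placeholders).
LOSS_MAPPING = {
    "ForCausalLM": "causal_lm_loss",
    "ForSequenceClassification": "sequence_classification_loss",
}


class HypotheticalPreTrainedModel:
    def __init__(self, config):
        self.config = config
        # Resolve the loss type once, eagerly, unless the config already sets it.
        if getattr(self.config, "loss_type", None) is None:
            pattern = f"({'|'.join(map(re.escape, LOSS_MAPPING))})"
            matches = re.findall(pattern, self.__class__.__name__)
            self.config.loss_type = matches[0] if matches else None

    @property
    def loss_function(self):
        # Hot path: no regex, just an attribute read and a dict lookup.
        # (Fallback handling for an unresolved loss_type is omitted in this sketch.)
        return LOSS_MAPPING[self.config.loss_type]


class ToyForCausalLM(HypotheticalPreTrainedModel):
    pass


model = ToyForCausalLM(SimpleNamespace())
print(model.config.loss_type)  # ForCausalLM
print(model.loss_function)  # causal_lm_loss
```

An explicitly configured loss_type is left untouched, so users can still override the name-based default.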

Expected behavior

Compiling a model that uses the new loss API should not introduce any additional graph breaks.

ChanderG added the bug label Nov 5, 2024
ChanderG added a commit to ChanderG/transformers that referenced this issue Nov 5, 2024
ensures no additional graph break introduced when torch.compile'ed

fixes huggingface#34615

Signed-off-by: ChanderG
ChanderG linked a pull request Nov 5, 2024 that will close this issue