Torch.compile graph break introduced due to new loss function API #34615

Closed
@ChanderG

Description

System Info

  • transformers version: 4.47.0.dev0
  • Platform: Linux-5.14.0-284.73.1.el9_2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

PR #34191 introduces a new loss API. After that PR, Dynamo was broken; this was identified and fixed in issue #34402. With that fix (on master), Dynamo runs without errors.

However, in the process a new graph break has been introduced by this line:

loss_type = re.findall(loss_groups, self.__class__.__name__)

This is due to the new regex check.
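
One way to surface the break is to run the forward pass through torch._dynamo.explain, which reports the number of graph breaks and their reasons. This is a minimal sketch; the checkpoint, input shape, and function name are assumptions for illustration, not taken from the issue:

import torch
from transformers import AutoModelForCausalLM

# Any causal-LM checkpoint should do; "gpt2" is just an assumption here.
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = torch.randint(0, model.config.vocab_size, (1, 32))

def forward_with_loss(ids):
    # Passing labels exercises the new loss dispatch path.
    return model(input_ids=ids, labels=ids).loss

explanation = torch._dynamo.explain(forward_with_loss)(input_ids)
print(explanation.graph_break_count)  # non-zero if the regex check breaks the graph
print(explanation.break_reasons)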

Since the dispatch function actually checks for an attribute on the config, the fix is quite simple: set loss_type at model init time itself, as sketched below.
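
A sketch of that fix, under the assumption that dispatch keys off config.loss_type; the class, regex, and mapping below are illustrative stand-ins, not the exact transformers internals:

import re

# Illustrative stand-ins for the real transformers internals.
LOSS_GROUPS = r"ForCausalLM|ForMaskedLM|ForSequenceClassification"
LOSS_MAPPING = {"ForCausalLM": "causal_lm_loss_fn", "default": "default_loss_fn"}

class PreTrainedModelSketch:
    def __init__(self, config):
        self.config = config
        # Run the regex once, eagerly, at init time instead of on every
        # forward call, so it never executes inside a compiled region.
        if getattr(config, "loss_type", None) is None:
            matches = re.findall(LOSS_GROUPS, self.__class__.__name__)
            config.loss_type = matches[0] if matches else "default"

    @property
    def loss_function(self):
        # Dispatch now reads a plain config attribute, which Dynamo
        # can trace without a graph break.
        return LOSS_MAPPING[self.config.loss_type]

The key point is that re.findall is opaque to Dynamo tracing, so hoisting it out of the traced forward path removes the break.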

Expected behavior

No additional graph breaks.
