
Accelerate + Dynamo broken in 4.46.0 due to model loss functions refactor #34402

@AbrahamSanders

Description

System Info

  • transformers version: 4.46.0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.9.16
  • Huggingface_hub version: 0.24.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: NO
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: 0
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: True
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
    - dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • PyTorch version (GPU?): 2.5.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@muellerzr @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

#34191 introduced custom loss functions to the model classes. This appears to have broken training with accelerate + torch dynamo.

To reproduce, use run_clm.py with the following accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
enable_cpu_affinity: true
gpu_ids: '0'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Then launch training with:

accelerate launch run_clm.py \
    --log_level info \
    --model_name_or_path=meta-llama/Llama-3.2-1B \
    --dataset_name=Salesforce/wikitext \
    --dataset_config_name=wikitext-2-raw-v1 \
    --block_size=1024 \
    --per_device_train_batch_size=4 \
    --do_train \
    --bf16 \
    --output_dir=Llama-3.2-1B-wikitext-2-raw-v1 \
    --overwrite_output_dir \
    --seed=42 \
    --logging_steps=10 \
    --lr_scheduler_type=cosine \
    --num_train_epochs=3 \
    --learning_rate=5e-05 \
    --warmup_ratio=0.03 \
    --dataloader_drop_last
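
The failure does not seem specific to run_clm.py. A minimal sketch along the following lines (untested as written, assuming accelerate applies torch.compile to the model when dynamo_backend is set in prepare()) should hit the same tracing error:

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

# dynamo_backend="inductor" makes accelerator.prepare() wrap the model with torch.compile
accelerator = Accelerator(mixed_precision="bf16", dynamo_backend="inductor")

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model = accelerator.prepare(model)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
batch = tokenizer("The quick brown fox", return_tensors="pt").to(accelerator.device)

# Passing labels takes the `loss = self.loss_function(...)` branch shown below,
# which is where dynamo raises the assertion.
outputs = model(**batch, labels=batch["input_ids"])
accelerator.backward(outputs.loss)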

This produces an error from dynamo relating to the new model_cls.loss_function attribute added in #34191:

loss = None
if labels is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

Important part of the traceback:

  File "/anaconda3/envs/dev/lib/python3.9/site-packages/torch/_dynamo/variables/functions.py", line 152, in __init__
    assert isinstance(
AssertionError: expected FunctionType found _lru_cache_wrapper <functools._lru_cache_wrapper object at 0x7f1091109a40>

from user code:
   File "/anaconda3/envs/dev/lib/python3.9/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/anaconda3/envs/dev/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1214, in forward
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
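
The assertion seems to come from dynamo refusing to treat a functools.lru_cache-wrapped callable as a plain function when it is looked up as a module attribute and called in forward. A stripped-down sketch of that pattern (plain PyTorch, no transformers; whether it reproduces the exact assertion may depend on the torch version):

import functools
import torch

@functools.lru_cache
def cached_loss(logits, labels):
    # stand-in for the model's loss_function attribute; the wrapper type is what matters
    return torch.nn.functional.cross_entropy(logits, labels)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)
        # stored as a _lru_cache_wrapper, not a types.FunctionType
        self.loss_function = cached_loss

    def forward(self, x, labels):
        logits = self.linear(x)
        return self.loss_function(logits, labels)

model = torch.compile(ToyModel(), backend="inductor")
out = model(torch.randn(2, 8), torch.tensor([0, 1]))  # expected to fail during dynamo tracing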

Now, if dynamo is removed from the accelerate config, training runs fine:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: '0'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Expected behavior

Training with accelerate + torch dynamo should work in 4.46.0 as it did before the loss functions refactor; dynamo should be able to trace the model's loss_function without raising this assertion error.
