
[Trainer] Eval loss depends on batch size (with solution) #39241

@ba144220

Description


System Info

  • transformers version: 4.54.0.dev0
  • Platform: Linux-5.15.0-1047-oracle-x86_64-with-glibc2.35
  • Python version: 3.12.11
  • Huggingface_hub version: 0.33.1
  • Safetensors version: 0.5.3
  • Accelerate version: 1.8.1
  • Accelerate config: not found
  • DeepSpeed version: 0.17.1
  • PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@SunMarc @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

This has actually been a known issue for several years; see https://discuss.huggingface.co/t/batch-size-during-training-vs-batch-size-during-evaluation/20827 and https://discuss.huggingface.co/t/evaluation-loss-depends-on-batch-size/112046

I’ve been evaluating a few causal LMs (e.g. Qwen/Qwen2.5-3B) on 512 samples from the togethercomputer/RedPajama-Data-1T-Sample pre-train dataset, and I noticed that eval loss consistently decreases as I increase the batch size:

| Batch size | Eval loss |
|---|---|
| 1 | 2.414 |
| 2 | 2.340 |
| 4 | 2.299 |
| 8 | 2.298 |
| 16 | 2.296 |

I saw the same trend across other models as well.
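To see why the number moves, here is a toy reproduction that needs no model. It assumes (from my reading of the Trainer, not verified against every version) that during evaluation each batch's loss is the mean over that batch's non-padding tokens, and that the reported eval loss is the batch-size-weighted mean of those per-batch losses. The per-sequence loss sums and token counts are made up for illustration, and `micro_average`, `macro_average` and `batched_eval_loss` are hypothetical helper names.

```python
# Toy model of a batch-size-dependent eval loss (made-up numbers, no real model).
# Assumption: per-batch loss = token mean over the batch; reported loss =
# batch-size-weighted mean of the per-batch losses.


def micro_average(loss_sums, token_counts):
    """Global token-level mean: total loss / total non-padding tokens."""
    return sum(loss_sums) / sum(token_counts)


def macro_average(loss_sums, token_counts):
    """Mean over sequences of each sequence's own token-mean loss."""
    return sum(s / n for s, n in zip(loss_sums, token_counts)) / len(loss_sums)


def batched_eval_loss(loss_sums, token_counts, batch_size):
    """Token mean inside each batch, then batch-size-weighted mean over batches."""
    weighted_total = 0.0
    for start in range(0, len(loss_sums), batch_size):
        batch_sums = loss_sums[start:start + batch_size]
        batch_counts = token_counts[start:start + batch_size]
        batch_mean = sum(batch_sums) / sum(batch_counts)
        weighted_total += batch_mean * len(batch_sums)
    return weighted_total / len(loss_sums)


# Four sequences: the short ones have a higher per-token loss than the long ones.
token_counts = [10, 100, 50, 200]
loss_sums = [30.0, 200.0, 125.0, 440.0]  # per-token means: 3.0, 2.0, 2.5, 2.2

for bs in (1, 2, 4):
    print("batch size", bs, round(batched_eval_loss(loss_sums, token_counts, bs), 4))
print("macro", round(macro_average(loss_sums, token_counts), 4))
print("micro", round(micro_average(loss_sums, token_counts), 4))
```

With these numbers, batch size 1 reproduces the macro average (2.425), a single batch of all four sequences reproduces the micro average (about 2.208), and batch size 2 lands somewhere else (about 2.175). The reported loss therefore depends on how sequences happen to be grouped into batches.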

This is the code I’m using:

import argparse
import os
import torch
from dotenv import load_dotenv

load_dotenv()

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-3B")
    parser.add_argument("--dataset_name", type=str, default="togethercomputer/RedPajama-Data-1T-Sample")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--max_eval_samples", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True, device_map="auto", token=os.getenv("HF_TOKEN"), torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True, token=os.getenv("HF_TOKEN"))

    # Load dataset
    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.shuffle(seed=args.seed).select(range(args.max_eval_samples))

    sft_config = SFTConfig(
        output_dir="./results",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        dataset_text_field="text",
        max_seq_length=args.max_seq_length,
    )
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        eval_dataset=dataset,
        processing_class=tokenizer,
    )
    eval_result = trainer.evaluate()

    print(eval_result)

if __name__ == "__main__":
    main()

Solution

Digging in, I found that fixed_cross_entropy (in transformers/src/transformers/loss/loss_utils.py) does a token-level sum then divides by the total non-padding token count (micro-averaging). To fix the issue, I implemented a sample-wise average (macro-averaging):

# Hugging Face: token-sum / total_tokens
loss = F.cross_entropy(..., reduction="sum") / num_items_in_batch

# My version: per-sequence average then mean across sequences
loss = F.cross_entropy(..., reduction="none")
loss = loss.view(B, -1).sum(dim=1) / token_counts_per_seq
loss = loss.mean()
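The pseudocode above can be turned into a self-contained PyTorch function. This is a minimal sketch, not the author's exact patch and not a transformers API: `macro_averaged_causal_lm_loss` is a hypothetical name, and the sketch assumes labels use `-100` for padding (the usual `ignore_index`) and that logits are shifted by one position, as in standard causal LM training. Sequences with no supervised tokens are skipped to avoid dividing by zero.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value excluded from the loss (usual PyTorch/transformers convention)


def macro_averaged_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Token-mean cross-entropy per sequence, then the mean over sequences.

    logits: (B, T, V) model outputs; labels: (B, T) with IGNORE_INDEX on padding.
    """
    # Causal LM: the logit at position t predicts the token at position t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    batch, seq_len, vocab = shift_logits.shape

    # reduction="none" keeps one loss per token; ignored positions come back as 0.
    token_loss = F.cross_entropy(
        shift_logits.view(-1, vocab).float(),
        shift_labels.view(-1),
        ignore_index=IGNORE_INDEX,
        reduction="none",
    ).view(batch, seq_len)

    # Number of supervised tokens in each sequence.
    token_counts = (shift_labels != IGNORE_INDEX).sum(dim=1)
    has_tokens = token_counts > 0  # skip sequences with nothing to predict (avoid 0/0)

    per_seq_loss = token_loss.sum(dim=1)[has_tokens] / token_counts[has_tokens]
    return per_seq_loss.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 6, 11)
    labels = torch.randint(0, 11, (4, 6))
    labels[0, 4:] = IGNORE_INDEX  # simulate right padding
    labels[2, 2:] = IGNORE_INDEX

    full = macro_averaged_causal_lm_loss(logits, labels)
    # One batch of 4 vs. two batches of 2: the values should match.
    halves = (macro_averaged_causal_lm_loss(logits[:2], labels[:2])
              + macro_averaged_causal_lm_loss(logits[2:], labels[2:])) / 2
    print(full.item(), halves.item())
```

Because every sequence counts equally, averaging the result over equal-sized batches gives the same value as one big batch. If the last batch is smaller, each batch's value still needs to be weighted by its number of sequences when combining them.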

With macro-averaging, eval loss is identical across batch sizes and input orderings, enabling a few nice benefits:

  1. We can choose the optimal batch size to speed up evaluation, which is especially useful when comparing models of different sizes.
  2. Sorting samples by length before batching reduces padding, cutting evaluation time by over 50%.
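Point 2 only works because the macro-averaged loss is order-invariant; under per-batch averaging, reordering would also change the reported loss. The padding savings can be estimated before running anything. Below is a minimal sketch with made-up sequence lengths, where `make_batches` and `padding_tokens` are hypothetical helpers and each batch is assumed to be padded to its own longest sequence (dynamic padding), not to a fixed `max_seq_length`.

```python
def make_batches(lengths, batch_size, sort_by_length=False):
    """Group sequence indices into batches, optionally longest-first."""
    order = list(range(len(lengths)))
    if sort_by_length:
        order.sort(key=lambda i: lengths[i], reverse=True)
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def padding_tokens(lengths, batches):
    """Padding needed if each batch is padded to its longest member."""
    total = 0
    for batch in batches:
        longest = max(lengths[i] for i in batch)
        total += sum(longest - lengths[i] for i in batch)
    return total


lengths = [10, 500, 20, 480, 30, 510, 15, 490]  # made-up token counts
for sort_flag in (False, True):
    batches = make_batches(lengths, batch_size=2, sort_by_length=sort_flag)
    print("sorted" if sort_flag else "original", padding_tokens(lengths, batches))
```

With these lengths and batch size 2, the original order needs 1905 padding tokens while the length-sorted order needs 35. Sorting in descending order also puts the longest batch first, so out-of-memory errors show up at the start of the run rather than near the end.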

So I'm wondering:

  1. Is the Trainer’s default (micro-averaging) behavior intentional, i.e. meant to tie the loss scale strictly to the total token count?
  2. Does this have any documented effect on training stability or convergence when you vary batch size?
  3. Are there recommended best practices for loss normalization in large-batch LLM training (e.g. should I always override this to macro-average)?

I’d love to hear from anyone who’s dug into this or has empirical experience with different loss-averaging schemes in the 🤗Trainer. Thanks in advance!
