
HF Qwen 2 with Thunder returns a slightly different loss function output #1407

Closed
@IvanYashchuk


🐛 Bug

We need to determine whether Thunder has a real accuracy problem when computing HF's Qwen 2 model.

The test added in #1406 might fail because the loss computed by the Thunder-generated function differs slightly from the one computed by HF's eager implementation. Here's a snippet that reproduces the problem:

import torch
from thunder.dynamo import ThunderCompiler
from transformers import Qwen2Config, Qwen2ForCausalLM
torch.manual_seed(0)

# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
configuration = Qwen2Config(
    # Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default
    # config uses Multi-Head Attention
    num_attention_heads=28,
    num_key_value_heads=4,
    # Scaled down for testing
    hidden_size=56,
    vocab_size=2,
    max_position_embeddings=32,
)
configuration.num_hidden_layers = 1
with torch.device("cuda"):
    model = Qwen2ForCausalLM(configuration).to(torch.bfloat16)

# thunder.jit doesn't work with Qwen2, so we use torch.compile
# https://github.com/Lightning-AI/lightning-thunder/issues/1405
backend = ThunderCompiler()
compiled_model = torch.compile(model, backend=backend, fullgraph=True)

input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda")
# Alternative input to try: all-zero token ids
# input_ids = torch.zeros_like(input_ids)
ref_output = model(input_ids=input_ids, labels=input_ids)
ref_loss = ref_output.loss

compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
compiled_loss = compiled_output.loss
torch.testing.assert_close(compiled_loss, ref_loss)
AssertionError: Scalars are not close!

Expected 0.7005462646484375 but got 0.7004587650299072.
Absolute difference: 8.749961853027344e-05 (up to 1e-05 allowed)
Relative difference: 0.00012490198427392128 (up to 1.3e-06 allowed)
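The tolerances in the message are `torch.testing.assert_close`'s defaults for float32 (rtol=1.3e-6, atol=1e-5); the check passes only when |actual - expected| <= atol + rtol * |expected|. The plain-Python sketch below (the helper `within_tolerance` is hypothetical and only mirrors that documented criterion) reproduces the failure with the two losses reported above, and shows that `assert_close`'s bfloat16 defaults (rtol=1.6e-2, atol=1e-5) would accept the same pair. Passing a looser check says nothing about which of the two results is more accurate, though.

```python
def within_tolerance(actual: float, expected: float, rtol: float, atol: float) -> bool:
    # Closeness criterion used by torch.testing.assert_close:
    # |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)


expected = 0.7005462646484375  # eager HF loss from the failing run
actual = 0.7004587650299072    # Thunder loss from the failing run

# float32 defaults: allowed error is about 1.09e-05, observed is about 8.75e-05
print(within_tolerance(actual, expected, rtol=1.3e-6, atol=1e-5))  # False
# bfloat16 defaults: allowed error is about 1.12e-02
print(within_tolerance(actual, expected, rtol=1.6e-2, atol=1e-5))  # True
```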

Thunder may return a different result because it upcasts to fp32 and downcasts to bf16 at different points than HF's eager implementation. Before tweaking the test's tolerances, we need to confirm that Thunder's result is indeed more accurate, i.e. closer to an fp64 result than the eager loss is.
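One way to get that fp64 result is sketched below; the helper names `fp64_reference_loss` and `distance_to_fp64` are hypothetical. The sketch deep-copies the bf16 model and upcasts the copy to float64 (every bf16 value is exactly representable in fp64, so the weights are identical), then recomputes the standard shifted causal-LM cross-entropy in float64 from the logits instead of relying on the model's own `loss`, since transformers may upcast logits only to float32 when computing it. This is an assumption-laden sketch: depending on the transformers version, parts of the forward pass (for example rotary embeddings) may still run in float32, which limits how exact the "fp64" reference is.

```python
import copy

import torch
import torch.nn.functional as F


def fp64_reference_loss(model: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    """Recompute the causal-LM loss with a float64 copy of `model` (hypothetical helper)."""
    # Deep copy so the original bf16 model used by eager/Thunder stays untouched.
    model64 = copy.deepcopy(model).to(torch.float64)
    with torch.no_grad():
        logits = model64(input_ids=input_ids).logits
    # Assumption: logits may come back as float32 in some transformers versions;
    # upcasting here cannot recover precision that was already lost.
    logits = logits.to(torch.float64)
    # Standard causal-LM shift: the logits at position t predict token t + 1.
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = input_ids[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels)


def distance_to_fp64(reference: torch.Tensor, *candidates: torch.Tensor) -> list[float]:
    """Absolute distance of each candidate loss from the fp64 reference."""
    ref = reference.detach().to(torch.float64)
    return [(c.detach().to(torch.float64) - ref).abs().item() for c in candidates]
```

With the reproduction script above, one could call `ref64 = fp64_reference_loss(model, input_ids)` and then `distance_to_fp64(ref64, ref_loss, compiled_loss)`. If the Thunder distance (second entry) is not larger than the eager one (first entry), relaxing the test tolerance looks justified; otherwise Thunder may have a genuine accuracy issue.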

cc @apaz-cli
