Support converting SymTypes Node to input proxy #2171


Merged
merged 9 commits into from
Jun 17, 2025

Conversation

kiya00
Collaborator

@kiya00 kiya00 commented Jun 2, 2025

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

While exploring the split reasons and fallback operators in issue #2169, it was found that the current implementation of get_proxy_inputs_from_node does not support SymInt types, which causes graph splits when running models like Undi95/dbrx-base and llama-moe/LLaMA-MoE-v1-3_5B-4_16.

SplitReason(reason_type=<SplitReasonType.EXCEPTION_PROXY_THUNDER_OP: 3>, info='Failed while creating proxy for node with name: expert_w1 and target: <built-in method select of type object at 0x7f10f0184760>, see exception field', exception="`make_input_proxy` received example_value which wasn't Tensor or Tuple")
from transformers import AutoConfig
import torch
from torch.testing import make_tensor
from functools import partial


model_id = "Undi95/dbrx-base"

config = AutoConfig.from_pretrained(
        model_id,
        # Scaled down for testing
        vocab_size=16,
        pad_token_id=15,
        max_position_embeddings=32,
        trust_remote_code=True,
        num_hidden_layers=1,)
batch_size = 1
seq_length = 4096
shape = (batch_size, seq_length)

from transformers import AutoModelForCausalLM

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16, trust_remote_code=True)
    model.requires_grad_(True)
    make = partial(
        make_tensor,
        low=0,
        high=config.vocab_size,
        device="cuda",
        dtype=torch.int64,
        requires_grad=False,
    )
    a = make(shape)
    inputs = {"input_ids": a, "labels": a}

    from thunder.dynamo import thunderfx
    cmodel = thunderfx(model)

cmodel(**inputs)
infos = cmodel._backend.subgraph_infos
print(f"{len(infos)} segmentations")
for info in infos:
    rs = info.split_reasons
    print(f"{len(rs)} split reasons")
    for i, r in enumerate(rs):
        print(f"    reason {i}: {r}")
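The failure above comes from make_input_proxy rejecting example values that are neither Tensor nor tuple. Below is a minimal sketch of the kind of dispatch involved (hypothetical names, not Thunder's actual implementation), where a SymInt-backed input reduces to a plain scalar case instead of raising:

```python
# Hypothetical sketch, NOT Thunder's actual make_input_proxy: illustrates
# dispatching on the example_value attached to an FX node, including the
# scalar case that a SymInt input hydrates to.
def make_input_proxy_sketch(example_value):
    if isinstance(example_value, tuple):
        # Recurse into tuples of example values.
        return tuple(make_input_proxy_sketch(v) for v in example_value)
    if isinstance(example_value, (bool, int, float)):
        # SymInt/SymFloat/SymBool example values carry a concrete hint;
        # treat them as scalar proxies instead of raising.
        return ("scalar_proxy", example_value)
    if hasattr(example_value, "shape"):
        # Stand-in for an isinstance(example_value, torch.Tensor) check.
        return ("tensor_proxy", tuple(example_value.shape))
    raise TypeError(
        "make_input_proxy received example_value which wasn't Tensor, tuple, or scalar"
    )
```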

@kiya00 kiya00 marked this pull request as ready for review June 2, 2025 12:17
@kiya00 kiya00 requested review from mruberry, lantiga and t-vi as code owners June 2, 2025 12:17
@mruberry
Collaborator

mruberry commented Jun 2, 2025

@IvanYashchuk would you like to review this?

@kiya00
Collaborator Author

kiya00 commented Jun 10, 2025

Hi @kshitij12345 @IvanYashchuk, could you take a look?

Collaborator

@kshitij12345 kshitij12345 left a comment


Thanks @kiya00, I just have one question.

Also, we should add a test, maybe something like the following. I think we should also verify that Thunder caching works correctly and that calling cfn with a different idx value also returns the correct result.

import torch

from thunder.dynamo import thunderfx

def fn(x, idx):
    return torch.select(x, 0, idx)

x = torch.randn(10, 10)
idx = 0
cfn = thunderfx(fn, dynamic=True)
cfn(x, idx)

assert cfn._backend.subgraph_infos[0].split_reasons == []
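One way to extend that check without assuming Thunder internals: use the eager results for several idx values as the reference the compiled function must match (a sketch; the thunderfx call in the comment is the assumption, the torch.select behavior is standard):

```python
import torch

def fn(x, idx):
    return torch.select(x, 0, idx)

x = torch.randn(10, 10)

# Eager results serve as the reference a Thunder-compiled version
# (e.g. cfn = thunderfx(fn, dynamic=True)) would be checked against,
# calling cfn(x, idx) for each idx below and asserting equality,
# ideally without triggering a recompile once the symbolic trace is cached.
for idx in (0, 3, 9):
    assert torch.equal(fn(x, idx), x[idx])
```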

Collaborator

@kshitij12345 kshitij12345 left a comment


LGTM, thanks @kiya00

@kiya00
Collaborator Author

kiya00 commented Jun 13, 2025

Hi @IvanYashchuk @t-vi, could you take a look? I think it's ready to merge.

@IvanYashchuk IvanYashchuk added the thunderfx for things that could be applicable to the dynamo+thunder frontend label Jun 13, 2025
Collaborator

@IvanYashchuk IvanYashchuk left a comment


Looks good!

@kiya00
Collaborator Author

kiya00 commented Jun 13, 2025

Hi @t-vi @Borda , it's ready to merge

@kiya00
Collaborator Author

kiya00 commented Jun 17, 2025

Hi @t-vi @mruberry , it's ready to merge

Collaborator

@mruberry mruberry left a comment


@mruberry mruberry enabled auto-merge (squash) June 17, 2025 14:14
@mruberry mruberry merged commit 01d312c into main Jun 17, 2025
49 checks passed
@mruberry mruberry deleted the fixsymtype branch June 17, 2025 15:08
Labels
thunderfx for things that could be applicable to the dynamo+thunder frontend