Get dynamic shapes to work with Phi-3-mini-128k-instruct #1579

Open
tfogal opened this issue Dec 20, 2024 · 0 comments
Labels: enhancement (New feature or request), nemo (Issues needed to support NVIDIA NeMo models)

tfogal commented Dec 20, 2024

🚀 Feature

The program below (see "Additional context") fails when run with cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES:

Traceback (most recent call last):
  File "/home/tfogal/dev/ak-bench/bench_targets/llm_peft/sample.py", line 146, in <module>
    cmodel(**d)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 569, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/phi3/modeling_phi3.py", line 1193, in forward
    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 738, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.14", line 5, in forward
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/module.py", line 80, in forward
    res = self._forward_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 742, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 777, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 724, in wrapped
    cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 236, in cache_info_wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/__init__.py", line 630, in get_computation_and_inputs
    computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/executors/torch_autograd.py", line 156, in split_forward_backward
    fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 3053, in forward_and_backward_from_trace
    forward_trace, result, env = augmented_forward_pass_trace(trace, *trace.args, **trace.kwargs)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2669, in augmented_forward_pass_trace
    trace, result, env = interpret_trace_to_trace(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/trace_interpreter.py", line 169, in interpret_trace_to_trace
    result = prim_func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2582, in _vjp_impl
    out_primal, out_residuals = vjp_impl(*primals, **kwargs)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2357, in decomposed_fn_aug_fwd_rule
    saved_for_backward = deconstruct_forward_env_for_backward(trace, env)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2323, in deconstruct_forward_env_for_backward
    saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2323, in <genexpr>
    saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'name'
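
For reference, the failing lookup is sequencify(symbol.output)[0].name: with SYMBOLIC_VALUES enabled, some bound symbols apparently produce plain Python ints (concretized shape values) rather than proxies, and an int carries no .name. A minimal sketch of one possible guard; the helper name and the Number check are assumptions for illustration, not Thunder's actual fix:

from numbers import Number

def first_proxy_output(outputs):
    # Hypothetical helper: skip plain-number outputs, which carry no proxy
    # .name, instead of unconditionally taking outputs[0].name.
    for out in outputs:
        if not isinstance(out, Number):
            return out
    return None  # every output was a plain number; nothing to look up in env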

Motivation

With NeMo, we are starting to test fine-tuning with varying sequence lengths, so the tensor sizes change every step.

Pitch

The program above runs without raising an error :-)

Alternatives

The alternative is probably to pad the tensors up to a power-of-two length and compile for that.
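
A rough sketch of that padding workaround (the per-key pad values are assumptions: attention_mask padding should be 0, and labels are conventionally padded with -100 so the loss ignores them):

import torch
import torch.nn.functional as F

def pad_to_pow2(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Round the sequence length up to the next power of two so the compiled
    # model only ever sees a small, fixed set of shapes.
    seq_len = batch["input_ids"].shape[-1]
    target = 1 << (seq_len - 1).bit_length()
    pad_values = {"input_ids": 0, "attention_mask": 0, "labels": -100}
    return {
        k: F.pad(v, (0, target - seq_len), value=pad_values.get(k, 0))
        for k, v in batch.items()
    }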

Additional context

import math
import datasets
import torch
import thunder
from thunder.dynamo import thunderfx
import transformers
import nemo
import nvtx

nvtx.push_range("startup") # force nvtx initialization
nvtx.pop_range()

m_id = "microsoft/Phi-3-mini-128k-instruct"
nvtx.push_range("loading")
cfg = transformers.AutoConfig.from_pretrained(
  m_id,
  torch_dtype=torch.bfloat16,
  num_hidden_layers=2, # scale down for testing
)
cfg.hidden_size = cfg.num_attention_heads # shrink the model further for testing
with torch.device("cuda"):
  model = transformers.AutoModelForCausalLM.from_config(cfg).to(torch.bfloat16)

tokenizer = transformers.AutoTokenizer.from_pretrained(
  m_id,
  torch_dtype='auto',
  trust_remote_code=False,
)

symbolic = thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES
nvtx.pop_range() # loading
cmodel = thunderfx(model, cache=symbolic)

def argument_details(args: list[torch.Tensor]):
  # Debugging helper: print the shape of each tensor argument.
  for a in args:
    if isinstance(a, torch.Tensor):
      print(f"arg {a.shape=}")
    else:
      print(f"arg {a=}")

@nvtx.annotate()
def make_squad_hf_dataset(tokenizer):
    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

    def formatting_prompts_func(examples):
        alpaca_prompt = """Below is an instruction that describes a task,
        ### Instruction:
        {}

        ### Input:
        {}

        ### Response:
        {}"""
        print("-- FORMATTING PROMPTS!")
        instruction = examples["context"]
        input = examples["question"]
        output = examples["answers"]['text']
        if isinstance(output, list):
            output = output[0]
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        ans = tokenizer(text)
        ans['labels'] = list(ans['input_ids'][1:])
        ans['input_ids'] = ans['input_ids'][:-1]
        ans['attention_mask'] = ans['attention_mask'][:-1]
        print("answer:", ans)
        return ans

    datamodule = datasets.load_dataset("rajpurkar/squad", split="train[:100]")
    return datamodule.map(
        formatting_prompts_func,
        batched=False,
        batch_size=2,
        remove_columns=["id", "title", "context", "question", 'answers'],
    )

def collate_fn(batch, pad_token_id=0):
    def batchify(tensor):
        if tensor.ndim == 1:
            return tensor.unsqueeze_(0)
        return tensor

    def extract_key_from_dicts(batch, key):
        return list(map(lambda x: x[key], batch))

    def pad_within_micro(batch, pad_token_id):
        max_len = max(map(len, batch))
        return [item + [pad_token_id] * (max_len - len(item)) for item in batch]

    return {
        key: batchify(
            torch.LongTensor(
                pad_within_micro(
                    extract_key_from_dicts(batch, key),
                    pad_token_id,
                )
            )
        )
        for key in batch[0].keys()
    }

#ds = datasets.load_dataset("rajpurkar/squad")
dm = make_squad_hf_dataset(tokenizer)
loader = torch.utils.data.DataLoader(dm, collate_fn=collate_fn, num_workers=1, pin_memory=True)

print(dm)
counter = 0
for d in loader:
  #print(d)
  nvtx.push_range("moving to GPU")
  d = {k: v.cuda() for k,v in d.items()}
  nvtx.pop_range()
  for k in d:
    print(f"{k=}:", d[k].shape)
  nvtx.push_range(f"run model {d['input_ids'].shape}")
  cmodel(**d)
  nvtx.pop_range()
  #print(thunder.last_traces(cmodel)[-1])

  counter = counter + 1
  if counter >= 15:
    break

cc @tfogal
