## 🚀 Feature
The program below fails due to the use of `cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES`.
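As a point of reference, the option is just passed straight through `thunderfx`; a minimal sketch of the same usage pattern is below (the `Linear` module and the shapes are arbitrary stand-ins, not from the failing run, and whether this trivial case reproduces the crash is untested):

```python
import torch
import thunder
from thunder.dynamo import thunderfx

# Arbitrary stand-in module: any module with parameters exercises the
# forward/backward split where the crash below occurs.
model = torch.nn.Linear(16, 16).to(device="cuda", dtype=torch.bfloat16)
cmodel = thunderfx(model, cache=thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES)

for seq_len in (8, 12):  # varying shapes are the point of SYMBOLIC_VALUES
    x = torch.randn(1, seq_len, 16, device="cuda", dtype=torch.bfloat16)
    cmodel(x)
```

With the full script (under *Additional context* below), the failure is: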
```
Traceback (most recent call last):
File "/home/tfogal/dev/ak-bench/bench_targets/llm_peft/sample.py", line 146, in <module>
cmodel(**d)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 569, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/transformers/models/phi3/modeling_phi3.py", line 1193, in forward
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 738, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 822, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
raise e
File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 387, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.14", line 5, in forward
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/module.py", line 80, in forward
res = self._forward_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/__init__.py", line 742, in wrapped
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/__init__.py", line 777, in fn_
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/__init__.py", line 724, in wrapped
cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/langctxs.py", line 136, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/__init__.py", line 236, in cache_info_wrapper
res = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/__init__.py", line 630, in get_computation_and_inputs
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/executors/torch_autograd.py", line 156, in split_forward_backward
fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 3053, in forward_and_backward_from_trace
forward_trace, result, env = augmented_forward_pass_trace(trace, *trace.args, **trace.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2669, in augmented_forward_pass_trace
trace, result, env = interpret_trace_to_trace(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/trace_interpreter.py", line 169, in interpret_trace_to_trace
result = prim_func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2582, in _vjp_impl
out_primal, out_residuals = vjp_impl(*primals, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2357, in decomposed_fn_aug_fwd_rule
saved_for_backward = deconstruct_forward_env_for_backward(trace, env)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2323, in deconstruct_forward_env_for_backward
saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tfogal/dev/thunder/thunder/core/transforms.py", line 2323, in <genexpr>
saved_for_backward = tuple(env[sequencify(symbol.output)[0].name].residuals for symbol in bound_symbols)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'name'
```
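The bottom frame is the interesting one: `deconstruct_forward_env_for_backward` looks up residuals via `sequencify(symbol.output)[0].name`, which assumes every bound symbol's first output is a proxy carrying a `.name`. Under SYMBOLIC_VALUES, some outputs evidently arrive as plain Python ints (shape values, presumably), hence the AttributeError. A guess at the shape of a guard, reusing the names from the traceback's frame (not a vetted fix; whether these entries can simply be skipped is a question for the Thunder devs):

```python
# Hypothetical guard for the genexpr in thunder/core/transforms.py (untested):
# skip outputs that are plain Python values and so have no proxy name to look
# up in the VJP environment.
saved_for_backward = tuple(
    env[out.name].residuals
    for symbol in bound_symbols
    for out in (sequencify(symbol.output)[0],)
    if hasattr(out, "name")  # bare ints under SYMBOLIC_VALUES have no .name
)
```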
## Motivation

With NeMo, we are starting to test fine-tuning with varying sequence lengths, so the tensor shapes change every step.
## Pitch
We do not give an error :-)
## Alternatives

The alternative is probably to pad the tensors up to the next power of two and compile that.
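Concretely, that workaround might look like the sketch below. The helper name is ours and the key handling mirrors the `collate_fn` in the repro script; labels are filled with -100 (the Hugging Face ignore index) rather than the pad token so the filler positions do not contribute to the loss:

```python
import torch
import torch.nn.functional as F

# Hypothetical helper (not part of the repro): pad every tensor's sequence
# dimension up to the next power of two, so the compiler sees O(log max_len)
# distinct shapes instead of a new one per step.
def pad_to_pow2(batch: dict[str, torch.Tensor], pad_token_id: int = 0) -> dict[str, torch.Tensor]:
    seq_len = batch["input_ids"].shape[-1]
    target = 1 << max(seq_len - 1, 0).bit_length()  # smallest power of two >= seq_len
    pad = target - seq_len
    fill = {"attention_mask": 0, "labels": -100}  # mask out / ignore the filler
    return {k: F.pad(v, (0, pad), value=fill.get(k, pad_token_id)) for k, v in batch.items()}
```

Calling this on each batch (e.g. `d = pad_to_pow2(d)` in the training loop) caps the number of recompilations; whether the padded positions are truly ignored end-to-end has not been verified.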
## Additional context
```python
import math
import datasets
import torch
import thunder
from thunder.dynamo import thunderfx
import transformers
import nemo
import nvtx
nvtx.push_range("startup") # force nvtx initialization
nvtx.pop_range()
m_id = "microsoft/Phi-3-mini-128k-instruct"
nvtx.push_range("loading")
cfg = transformers.AutoConfig.from_pretrained(
    m_id,
    torch_dtype=torch.bfloat16,
    num_hidden_layers=2,  # scale down for testing
)
cfg.hidden_size = cfg.num_attention_heads
with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_config(cfg).to(torch.bfloat16)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    m_id,
    torch_dtype='auto',
    trust_remote_code=False,
)
symbolic = thunder.core.options.CACHE_OPTIONS.SYMBOLIC_VALUES
nvtx.pop_range() # loading
cmodel = thunderfx(model, cache=symbolic)
def argument_details(args: list[torch.Tensor]):
    for a in args:
        if isinstance(a, torch.Tensor):
            print(f"arg {a.shape=}")
        else:
            print(f"arg {a=}")
@nvtx.annotate()
def make_squad_hf_dataset(tokenizer):
    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
    def formatting_prompts_func(examples):
        alpaca_prompt = """Below is an instruction that describes a task,
### Instruction:
{}
### Input:
{}
### Response:
{}"""
        print("-- FORMATTING PROMPTS!")
        instruction = examples["context"]
        input = examples["question"]
        output = examples["answers"]['text']
        if isinstance(output, list):
            output = output[0]
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        ans = tokenizer(text)
        ans['labels'] = list(ans['input_ids'][1:])
        ans['input_ids'] = ans['input_ids'][:-1]
        ans['attention_mask'] = ans['attention_mask'][:-1]
        print("answer:", ans)
        return ans
    datamodule = datasets.load_dataset("rajpurkar/squad", split="train[:100]")
    return datamodule.map(
        formatting_prompts_func,
        batched=False,
        batch_size=2,
        remove_columns=["id", "title", "context", "question", 'answers'],
    )
def collate_fn(batch, pad_token_id=0):
    def batchify(tensor):
        if tensor.ndim == 1:
            return tensor.unsqueeze_(0)
        return tensor
    def extract_key_from_dicts(batch, key):
        return list(map(lambda x: x[key], batch))
    def pad_within_micro(batch, pad_token_id):
        max_len = max(map(len, batch))
        return [item + [pad_token_id] * (max_len - len(item)) for item in batch]
    return {
        key: batchify(
            torch.LongTensor(
                pad_within_micro(
                    extract_key_from_dicts(batch, key),
                    pad_token_id,
                )
            )
        )
        for key in batch[0].keys()
    }
#ds = datasets.load_dataset("rajpurkar/squad")
dm = make_squad_hf_dataset(tokenizer)
loader = torch.utils.data.DataLoader(dm, collate_fn=collate_fn, num_workers=1, pin_memory=True)
print(dm)
counter = 0
for d in loader:
    #print(d)
    nvtx.push_range("moving to GPU")
    d = {k: v.cuda() for k, v in d.items()}
    nvtx.pop_range()
    for k in d:
        print(f"{k=}:", d[k].shape)
    nvtx.push_range(f"run model {d['input_ids'].shape}")
    cmodel(**d)
    nvtx.pop_range()
    #print(thunder.last_traces(cmodel)[-1])
    counter = counter + 1
    if counter >= 15:
        break
```
cc @tfogal