torch.compile fails during inference with meta-llama/Meta-Llama-3.1-8B-Instruct #34604

@prasiyer

System Info

  • transformers version: 4.43.3
  • Platform: Linux-5.15.0-1074-azure-x86_64-with-glibc2.31
  • Python version: 3.11.9
  • Huggingface_hub version: 0.23.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.31.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: using device_map = "auto" in AutoModelForCausalLM.from_pretrained
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100 80GB PCIe
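
The list above follows the format printed by transformers-cli env. Below is a minimal standard-library sketch for collecting comparable package versions from the active environment. The helper name collect_versions and the PACKAGES list are hypothetical. It reads installed distribution metadata, so it reports what the package manager recorded, not necessarily the files that import actually loads.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names to report (hypothetical selection; adjust as needed)
PACKAGES = ["transformers", "huggingface_hub", "safetensors", "accelerate", "torch"]


def collect_versions(packages=PACKAGES):
    """Return {package: installed version or "not installed"} for each name."""
    info = {}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver}")
```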

Who can help?

@gante, @ArthurZucker
While using torch.compile(), I get the following error. The sample code is included in the "Reproduction" section below.

Error:
Traceback (most recent call last):
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/route_utils.py", line 276, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1923, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/blocks.py", line 1506, in call_function
    prediction = await fn(*processed_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/utils.py", line 785, in async_wrapper
    response = await f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/gradio/chat_interface.py", line 607, in _submit_fn
    response = await anyio.to_thread.run_sync(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2134, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 851, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/vp899/projects/Agent_System/Code/Agent_Launch_UI_v2_Experiments.py", line 253, in contract_analyst_chat
    outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 703, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 482, in transform
    tracer = InstructionTranslator(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2085, in __init__
    self._throw_if_in_functorch()
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2126, in _throw_if_in_functorch
    eager = torch._dynamo.lookup_backend("eager")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 58, in lookup_backend
    _lazy_import()
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/registry.py", line 91, in _lazy_import
    import_submodule(backends)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1866, in import_submodule
    importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
  File "/anaconda/envs/pi2_py311/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/backends/cudagraphs.py", line 10, in <module>
    from torch._inductor.cudagraph_trees import cudagraphify_impl
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 71, in <module>
    from torch._inductor.compile_fx import (
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 57, in <module>
    from .fx_passes.joint_graph import joint_graph_passes
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/fx_passes/joint_graph.py", line 12, in <module>
    from ..pattern_matcher import (
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/pattern_matcher.py", line 46, in <module>
    from .lowering import fallback_node_due_to_unsupported_type
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/lowering.py", line 6002, in <module>
    import_submodule(kernel)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1866, in import_submodule
    importlib.import_module(f"{mod.__name__}.{filename[:-3]}")
  File "/anaconda/envs/pi2_py311/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/kernel/flex_attention.py", line 155, in <module>
    flex_attention_template = TritonTemplate(
                              ^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/select_algorithm.py", line 453, in __init__
    self.template = self._template_from_string(source)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/torch/_inductor/codegen/common.py", line 1720, in _template_from_string
    return env.from_string(source)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 1108, in from_string
    return cls.from_code(self, self.compile(source), gs, None)
                               ^^^^^^^^^^^^^^^^^^^^
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 768, in compile
    self.handle_exception(source=source_hint)
  File "/anaconda/envs/pi2_py311/lib/python3.11/site-packages/jinja2/environment.py", line 939, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<unknown>", line 104, in template
torch._dynamo.exc.InternalTorchDynamoError: No filter named 'indent_except_first'.


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
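
Per the message above, setting torch._dynamo.config.suppress_errors = True makes Dynamo fall back to eager execution instead of raising. The plain-Python sketch below illustrates that "try the compiled path, fall back to eager" pattern for a single function. It is a hypothetical illustration (with_eager_fallback, eager_forward and broken_compile are invented names), not how torch._dynamo implements the flag. Like the flag, it hides the underlying error rather than fixing it, so any code that falls back loses the compiled speedup.

```python
import functools
import logging

logger = logging.getLogger(__name__)


def with_eager_fallback(compiled_fn, eager_fn):
    """Hypothetical helper: try compiled_fn; after its first failure, use eager_fn only."""
    state = {"use_compiled": True}

    @functools.wraps(eager_fn)
    def wrapper(*args, **kwargs):
        if state["use_compiled"]:
            try:
                return compiled_fn(*args, **kwargs)
            except Exception as exc:
                # Log instead of raising, then stop retrying the compiled path
                logger.warning("compiled path failed (%s); falling back to eager", exc)
                state["use_compiled"] = False
        return eager_fn(*args, **kwargs)

    return wrapper


def eager_forward(x):
    # Stand-in for the uncompiled model.forward
    return x * 2


def broken_compile(fn):
    # Stand-in for a compiler whose output fails, like the Jinja filter error above
    def compiled(*args, **kwargs):
        raise RuntimeError("No filter named 'indent_except_first'.")

    return compiled


forward = with_eager_fallback(broken_compile(eager_forward), eager_forward)
print(forward(21))  # prints 42
```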

Reproduction

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# llama31_hf_token holds the Hugging Face access token (defined elsewhere in the script)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=llama31_hf_token)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto", token=llama31_hf_token, attn_implementation="flash_attention_2")
# Use a static (preallocated) KV cache so the compiled forward sees fixed tensor shapes
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

...
# conversation is the chat history built in the elided code above
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
# Stop on either the default EOS token or Llama 3.1's end-of-turn token
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
outputs = model.generate(input_ids, max_new_tokens=500, eos_token_id=terminators, do_sample=True, temperature=0.6, top_p=0.9)
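
With do_sample=True, the temperature=0.6 and top_p=0.9 arguments in the generate call rescale and then truncate the next-token distribution before a token is drawn. Below is a minimal pure-Python sketch of those two steps on a toy list of logits. The function names and the toy logits are hypothetical. transformers performs roughly the equivalent steps on tensors through its logits warpers (TemperatureLogitsWarper, TopPLogitsWarper), not with this code.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose mass reaches top_p.

    Returns {token_index: renormalized probability}.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


def sample_top_p(logits, temperature=0.6, top_p=0.9, rng=random):
    """Temperature scaling, then nucleus (top-p) filtering, then one draw."""
    candidates = top_p_filter(softmax(logits, temperature), top_p)
    tokens = list(candidates)
    weights = [candidates[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for a 4-token vocabulary
print(sample_top_p(toy_logits, rng=random.Random(0)))
```

With these toy logits, temperature=0.6 leaves only the two highest-scoring tokens inside the 0.9 nucleus, whereas temperature=1.0 keeps three. This shows how a temperature below 1 sharpens the distribution.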

Expected behavior

The model should compile, and model.generate should return the generated answer.
