
Unbreak optimum-executorch #38646


Open · guangy10 wants to merge 3 commits into main from unbreak_optimum_et

Conversation

@guangy10 (Contributor) commented Jun 6, 2025

What does this PR do?

Revert the minimal changes from #37866 that break export to ExecuTorch in huggingface/optimum-executorch when developing against the latest transformers trunk.

TODO: Will update with tests shortly

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case. I surfaced the issue in Slack.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @Cyrilvallez @ydshieh

@guangy10 (Contributor Author) commented Jun 6, 2025

cc @kimishpatel to unblock the work in optimum-et

@ydshieh (Collaborator) commented Jun 6, 2025

For the ONNX job, rebasing on main will work.

@ydshieh (Collaborator) commented Jun 6, 2025

It might be better to wait until @Cyrilvallez is back.

Some explanation would be nice for them.

@guangy10 (Contributor Author) commented:

It might be better to wait until @Cyrilvallez is back.
Some explanation would be nice for them.

@ydshieh The blamed PR messed up which recipe is used to export the model. For example, due to the changes, models with a static cache are exported using the recipe for a hybrid cache. This PR makes minimal changes to revert only the code that selects the recipe based on the cache type. Can we prioritize getting this PR reviewed? We need this fix to unblock downstream work in optimum-executorch.

@Cyrilvallez (Member) commented:

Hey @guangy10! Sorry for the delay, I was on vacation! At a quick glance, checking the layer_types attribute should be correct, no? Which model does not export with the correct cache?

@guangy10 (Contributor Author) commented Jun 10, 2025

Hey @Cyrilvallez, adding layer_types to some models makes it impossible to reach the StaticCache branch in the following block:

if not hasattr(model.config, "layer_types"):
    # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
    # export will use `StaticCache` by default.
    logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
    self.model = TorchExportableModuleWithStaticCache(model)
else:
    self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)

I think it's because you added layer_types for Qwen3 here:

if self.layer_types is None:
    self.layer_types = [
        "sliding_attention"
        if self.sliding_window is not None and i >= self.max_window_layers
        else "full_attention"
        for i in range(self.num_hidden_layers)
    ]
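
For a Qwen3 checkpoint whose sliding_window is None (as in the config pasted further down), that default tags every layer as "full_attention", so the hasattr check above can never take the StaticCache branch. A small self-contained sketch, using assumed values that mirror that config:

from types import SimpleNamespace

# Assumed values mirroring the Qwen3 config shown later in this thread.
sliding_window = None
max_window_layers = 28
num_hidden_layers = 28

# Same comprehension as the Qwen3Config default above: with no sliding window,
# every layer ends up tagged "full_attention".
layer_types = [
    "sliding_attention"
    if sliding_window is not None and i >= max_window_layers
    else "full_attention"
    for i in range(num_hidden_layers)
]
assert layer_types == ["full_attention"] * num_hidden_layers

# Stand-in for model.config: because `layer_types` now always exists,
# `not hasattr(model.config, "layer_types")` is always False, so export
# falls through to the TorchExportableModuleWithHybridCache branch.
config = SimpleNamespace(layer_types=layer_types, sliding_window=sliding_window)
assert hasattr(config, "layer_types")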

So downstream, when I call the export to ExecuTorch in optimum-executorch, you can see it goes down the wrong path. Here is the call stack:

Traceback (most recent call last):
  File "/opt/anaconda3/envs/huggingface/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/optimum/commands/export/executorch.py", line 104, in run
    main_export(
  File "/opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/optimum/exporters/executorch/__main__.py", line 138, in main_export
    return export_to_executorch(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/optimum/exporters/executorch/convert.py", line 77, in export_to_executorch
    executorch_progs = recipe_func(model, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/optimum/exporters/executorch/recipes/xnnpack.py", line 98, in export_to_executorch_with_xnnpack
    exported_progs = model.export()
                     ^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/optimum/exporters/executorch/integrations.py", line 58, in export
    exportable_module = TorchExportableModuleForDecoderOnlyLM(self.model, max_batch_size, max_cache_len)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guangyang/transformers/src/transformers/integrations/executorch.py", line 65, in __init__
    self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guangyang/transformers/src/transformers/integrations/executorch.py", line 407, in __init__
    raise AssertionError("Model must use 'hybrid' cache implementation")
AssertionError: Model must use 'hybrid' cache implementation
FAILED

Pretty much all models that use a static cache will fail during export due to this issue.

@guangy10 force-pushed the unbreak_optimum_et branch from 8f11b60 to 7046b2a on June 10, 2025 18:37
@guangy10 (Contributor Author) commented:

@Cyrilvallez I updated the PR with enhanced tests. That is, without reverting the changes in executorch.py, running test_export for these models would fail in CI. If CI is green, can we get this fix merged to unblock the downstream optimum-executorch work that @kimishpatel and I are working on?

@guangy10 force-pushed the unbreak_optimum_et branch from 7046b2a to aefca28 on June 10, 2025 18:53
@guangy10 (Contributor Author) commented:

The failure in test_onnx is unrelated to this PR.

@Cyrilvallez (Member) commented Jun 10, 2025

Hmm, but models with layer_types should always be hybrid, and the others should not. What you are experiencing is that we removed the default cache_implementation="hybrid" from the config (to default back to DynamicCache), not that we export with the wrong cache.
So we should just remove this check IMO.

@guangy10 (Contributor Author) commented Jun 10, 2025

Hmm, but models with layer_types should always be hybrid, and the others should not.

Is Qwen3 hybrid? Some models could work with both hybrid and static caches; I think the check, together with the added layer_types, will force export to always use the hybrid cache.

What you are experiencing is that we removed the default cache_implementation="hybrid" from the config (to default back to DynamicCache), not that we export with the wrong cache. So we should just remove this check IMO.

Which check are you suggesting to remove? Maybe it would be clearer if you commented on the code inline?

@Cyrilvallez (Member) commented Jun 10, 2025

Well I'm simply talking about the check in your stacktrace here https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L403-L407

And yes, Qwen3 is hybrid in general, though it does not always have sliding layers (in which case Hybrid and Static caches are equivalent)

@guangy10 force-pushed the unbreak_optimum_et branch from aefca28 to c5d9f34 on June 10, 2025 22:39
@guangy10 (Contributor Author) commented:

Well I'm simply talking about the check in your stacktrace here https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L403-L407

And yes, Qwen3 is hybrid in general, though it does not always have sliding layers (in which case Hybrid and Static caches are equivalent)

@Cyrilvallez Looks like additional work is needed to treat HybridCache and StaticCache (hybrid without a sliding window) in a unified way. If I just remove the check you mentioned, the export test still fails due to the missing sliding window config.

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

>       exportable_module = TorchExportableModuleForDecoderOnlyLM(model)

tests/models/qwen3/test_modeling_qwen3.py:288:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/integrations/executorch.py:65: in __init__
    self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
src/transformers/integrations/executorch.py:404: in __init__
    self.cache = HybridCache(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <transformers.cache_utils.HybridCache object at 0x31ae42310>
config = Qwen3Config {
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention"
  ],
  "max_position_embeddings": 32768,
  "max_window_layers": 28,
  "model_type": "qwen3",
  "num_attention_heads": 16,
  "num_hidden_layers": 28,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "sliding_window": null,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.53.0.dev0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

max_batch_size = 1, max_cache_len = 4096, device = device(type='cpu'), dtype = torch.bfloat16, layer_device_map = None

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: Optional[int] = None,
        device: Union[torch.device, str, None] = None,
        dtype: torch.dtype = torch.float32,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
>           raise ValueError(
                "Setting `cache_implementation` to 'hybrid' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
E           ValueError: Setting `cache_implementation` to 'hybrid' requires the model config supporting sliding window attention, please check if there is a `sliding_window` field in the model config and it's not set to None.

src/transformers/cache_utils.py:1610: ValueError
================================================================================================================ warnings summary ================================================================================================================
../../../opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/_pytest/config/__init__.py:1441
  /opt/anaconda3/envs/huggingface/lib/python3.11/site-packages/_pytest/config/__init__.py:1441: PytestConfigWarning: Unknown config option: asyncio_default_fixture_loop_scope

    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================================================ short test summary info =============================================================================================================
FAILED tests/models/qwen3/test_modeling_qwen3.py::Qwen3IntegrationTest::test_export_static_cache - ValueError: Setting `cache_implementation` to 'hybrid' requires the model config supporting sliding window attention, please check if there is a `sliding_window` field in the model config and it's not set to None.

I guess my main motivation for this PR is to restore the previous behavior and unbreak the downstream work in optimum. We have two export recipes: one uses StaticCache and the other uses HybridCache. The check on layer_types in the export recipe here https://github.com/guangy10/transformers/blob/1094dd34f73dae1d9a91a6632635934516612490/src/transformers/integrations/executorch.py#L59
will now always evaluate to false, forcing Qwen3 (and similar models) to use the hybrid cache as a "static cache" without a sliding window, which does not work, as shown above. If that's easy to fix, I'm happy to get it corrected in this PR 😄
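
For what it's worth, here is a minimal sketch (not what this PR implements, just for discussion) of routing on the content of layer_types rather than its mere presence. The names mirror the snippet quoted earlier in this thread, and it assumes the config exposes both layer_types and sliding_window:

# Hypothetical selection logic, sketched for discussion only.
layer_types = getattr(model.config, "layer_types", None)
has_sliding_layers = (
    layer_types is not None
    and "sliding_attention" in layer_types
    and getattr(model.config, "sliding_window", None) is not None
)

if has_sliding_layers:
    # Only genuinely hybrid models (with some sliding-window layers) take the hybrid path.
    self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
    # Qwen3-style checkpoints with `sliding_window: null` keep using the static path.
    self.model = TorchExportableModuleWithStaticCache(model)

With something like that, models whose layer_types contains only "full_attention" would keep exporting with StaticCache, and only configs that actually declare sliding layers would go through HybridCache.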

@Cyrilvallez (Member) left a comment:

Indeed, there is a check in the Cache directly as well

@guangy10 force-pushed the unbreak_optimum_et branch from 1501526 to c894b9e on June 11, 2025 17:31
@Cyrilvallez (Member) left a comment:

Nice, the change works for me! However, I'm just concerned about having several different APIs (the function and the class) to do seemingly the same thing. IMO we should pick one and standardize, especially since a lot is redundant; that could be a future PR though, WDYT?

@guangy10 force-pushed the unbreak_optimum_et branch from c894b9e to 76fd034 on June 11, 2025 18:23
@guangy10 (Contributor Author) commented:

Fixed the linter issues. @Cyrilvallez let me know if it's good to go.

@guangy10 requested a review from Cyrilvallez on June 11, 2025 18:24
@guangy10 force-pushed the unbreak_optimum_et branch from 76fd034 to bb2cff8 on June 12, 2025 00:32
@Cyrilvallez (Member) left a comment:

Alright, last small detail: we don't need to add the view op! Let's remove it, then I'll merge 🤗 Sorry for being annoying on this one 😬

@guangy10 requested a review from Cyrilvallez on June 12, 2025 17:22
@guangy10 (Contributor Author) commented:

Alright, last small detail: we don't need to add the view op! Let's remove it, then I'll merge 🤗 Sorry for being annoying on this one 😬

Updated the PR. Should be good to go.
