Commit addef4e

brian-dellabetta authored and kylesayrs committed
bugfix AWQ with Llama models and python 3.9 (#1384)
SUMMARY:
`LlamaAttention.forward` has an optional `attention_mask` parameter that has no default (see [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L246)). So `attention_mask=None` must be passed in, otherwise AWQ will error out. The previous check only worked for Python 3.10 and 3.11. This replaces it with a more general check that also works with Python 3.9.

```python
from transformers.models.llama.modeling_llama import LlamaAttention
import inspect
import typing

params = inspect.signature(LlamaAttention.forward).parameters

# old check
old_check = (params["attention_mask"].annotation._name == "Optional")
# new check
new_check = (params["attention_mask"].default is inspect.Parameter.empty)
print(f"OLD {old_check}, NEW {new_check}")
# Python 3.9: OLD False, NEW True
# Python 3.11: OLD True, NEW True
```

TEST PLAN:
This will resolve the failing e2e test at https://github.com/neuralmagic/llm-compressor-testing/actions/runs/14654995202/job/41128588916#step:15:33208

---------

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d3c0f0a commit addef4e

1 file changed: +2 -2 lines changed
src/llmcompressor/modifiers/awq/base.py

@@ -617,12 +617,12 @@ def _sanitize_kwargs(inputs_kwargs, module):
     # In case forward pass has optional dependencies that don't default to None.
     # This is the case for `LlamaAttention.forward` which has input
     # `attention_mask: Optional[torch.Tensor],` (with no `= None` default)
-    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L269
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L246
     for k, v in params.items():
         if (
             k not in sanitized_kwargs
             and k != "use_cache"
-            and getattr(v.annotation, "_name", "") == "Optional"
+            and v.default is inspect.Parameter.empty
         ):
             sanitized_kwargs[k] = None
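The effect of the new condition can be tried without `transformers` installed. The following is a minimal sketch, not the code in `base.py`: `ToyAttention` and `fill_missing_kwargs` are hypothetical names made up for illustration, and the filtering step before the loop is an assumption about what the surrounding function does, since the diff only shows the loop.

```python
import inspect
from typing import Any, Dict, Optional


class ToyAttention:
    """Hypothetical stand-in for a module such as LlamaAttention."""

    def forward(
        self,
        hidden_states: list,
        attention_mask: Optional[list],  # annotated Optional, but no `= None` default
        position_ids: Optional[list] = None,
        use_cache: bool = False,
    ) -> list:
        return hidden_states


def fill_missing_kwargs(inputs_kwargs: Dict[str, Any], module: Any) -> Dict[str, Any]:
    """Illustrative helper (assumed name): drop unknown kwargs, then pass None
    for every parameter of module.forward that has no default and was not supplied."""
    params = inspect.signature(module.forward).parameters

    # Assumption: keep only the kwargs that forward actually accepts.
    sanitized = {
        k: v for k, v in inputs_kwargs.items() if k in params and k != "use_cache"
    }

    # Same condition as the `+` line in the diff: the check looks at the
    # default value, not at how the annotation is represented by `typing`.
    for k, v in params.items():
        if (
            k not in sanitized
            and k != "use_cache"
            and v.default is inspect.Parameter.empty
        ):
            sanitized[k] = None
    return sanitized


module = ToyAttention()
kwargs = fill_missing_kwargs({"hidden_states": [1, 2], "unused": 0}, module)
print(kwargs)
```

Because the condition no longer looks at the annotation, it applies to any parameter without a default, whether or not it is annotated as `Optional`. In this sketch `hidden_states` is left alone only because the caller supplied it.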
