Description of the bug:
Like several other users have found, exporting with dynamic shapes doesn't seem to work. I'm filing a separate issue because the errors I'm getting aren't quite the same.
I've tested this on 0.7.0 as well as on the nightly:
pip list | grep ai-edge
ai-edge-litert 2.0.3
ai-edge-litert-nightly 2.1.0rc2.dev20251205
ai-edge-quantizer 0.4.0
ai-edge-quantizer-nightly 0.5.0.dev20251207
ai-edge-torch 0.8.0
I'm trying to export an ALBERT model (implementation here: https://github.com/huggingface/transformers/tree/main/src/transformers/models/albert).
I've verified that dynamic sequence length and batch size work fine both with an exported torch snapshot and with my ONNX export of the same model.
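For reference, the torch.export check was roughly the following (a minimal sketch, not the exact script I used; the Dim bounds match the repro below):

import torch
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(
    "uer/albert-base-chinese-cluecorpussmall"
).eval()

batch = torch.export.Dim("batch", min=1, max=16)
seq_len = torch.export.Dim("sequence_length", min=1, max=512)

# Plain torch.export with the same dynamic dims; printing the ExportedProgram
# shows the symbolic batch/sequence dims carried through the graph.
ep = torch.export.export(
    model,
    args=(),
    kwargs={
        "input_ids": torch.randint(0, 21128, (1, 128)),
        "attention_mask": torch.ones((1, 128), dtype=torch.int64),
    },
    dynamic_shapes={
        "input_ids": {0: batch, 1: seq_len},
        "attention_mask": {0: batch, 1: seq_len},
    },
)
print(ep)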
If I try to export this model with only the sequence length dynamic, I get:
✗ Export to out/albert-zh-ai-edge-only_seq_len_dynamic/albert-zh.tflite failed: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function lower_wrapper at /home/landow/gamedev/chinese_app/data/word_seg/autopinyin/venvagain/lib/python3.11/site-packages/ai_edge_torch/odml_torch/jax_bridge/_wrap.py:78 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
While executing %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_albert_embeddings_token_type_ids, 1, 0, %sym_size_int_83), kwargs = {})
Original traceback:
File "/home/landow/gamedev/chinese_app/data/word_seg/autopinyin/venvagain/lib/python3.11/site-packages/transformers/models/albert/modeling_albert.py", line 1124, in forward
outputs = self.albert(
File "/home/landow/gamedev/chinese_app/data/word_seg/autopinyin/venvagain/lib/python3.11/site-packages/transformers/models/albert/modeling_albert.py", line 684, in forward
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
or, if I also make the batch size dynamic:
✗ Export to out/albert-zh-ai-edge-only_seq_len_dynamic/albert-zh.tflite failed: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function lower_wrapper at /home/landow/gamedev/chinese_app/data/word_seg/autopinyin/venvagain/lib/python3.11/site-packages/ai_edge_torch/odml_torch/jax_bridge/_wrap.py:78 for jit. This concrete value was not available in Python because it depends on the value of the argument args[1].
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
While executing %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_albert_embeddings_token_type_ids, 1, 0, %sym_size_int_83), kwargs = {})
Original traceback:
File "/home/landow/gamedev/chinese_app/data/word_seg/autopinyin/venvagain/lib/python3.11/site-packages/transformers/models/albert/modeling_albert.py", line 1124, in forward
outputs = self.albert(
File "/home/landow/gamedev/chinese_app/data/word_seg/autopinyin/venvagain/lib/python3.11/site-packages/transformers/models/albert/modeling_albert.py", line 684, in forward
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
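Both traces point at the registered token_type_ids buffer being sliced with the symbolic sequence length in modeling_albert.py. If I read the transformers code correctly, that branch only runs when token_type_ids isn't passed, so one untested workaround sketch would be to supply token_type_ids explicitly as a sample input (the model/batch/seq_len/ai_edge_torch names follow the repro script below):

sample_ids = torch.randint(0, 21128, (1, 128))
sample_mask = torch.ones((1, 128), dtype=torch.int64)
sample_token_types = torch.zeros((1, 128), dtype=torch.int64)

# Untested idea: passing token_type_ids explicitly should skip the
# buffered_token_type_ids slice that the traceback points at.
edge_model = ai_edge_torch.convert(
    model,
    sample_kwargs={
        "input_ids": sample_ids,
        "attention_mask": sample_mask,
        "token_type_ids": sample_token_types,
    },
    dynamic_shapes={
        "input_ids": {0: batch, 1: seq_len},
        "attention_mask": {0: batch, 1: seq_len},
        "token_type_ids": {0: batch, 1: seq_len},
    },
)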
Reproduction:
from pathlib import Path

import torch
from transformers import AutoModelForTokenClassification

OUT_DIR = "./out"
BASE_MODEL = "uer/albert-base-chinese-cluecorpussmall"

batch = torch.export.Dim("batch", min=1, max=16)
seq_len = torch.export.Dim("sequence_length", min=1, max=512)

only_seq_len_dynamic = {
    "input_ids": {
        0: 1,
        1: seq_len,
    },
    "attention_mask": {
        0: 1,
        1: seq_len,
    },
}
both_dynamic = {
    "input_ids": {
        0: batch,
        1: seq_len,
    },
    "attention_mask": {
        0: batch,
        1: seq_len,
    },
}


def export_tflite_ai_edge():
    """Convert models directly to TFLite using ai-edge-torch"""
    import ai_edge_torch

    model = AutoModelForTokenClassification.from_pretrained(BASE_MODEL)
    model = model.cpu()
    model.eval()

    sample_ids = torch.randint(0, 21128, (1, 128))
    sample_mask = torch.ones((1, 128), dtype=torch.int64)

    test_cases = [
        ("both_dynamic", both_dynamic),
        ("only_seq_len_dynamic", only_seq_len_dynamic),
        ("static", None),
    ]
    for name, tc in test_cases:
        out_dir = Path(f"{OUT_DIR}/albert-zh-ai-edge-{name}")
        out_dir.mkdir(parents=True, exist_ok=True)
        try:
            edge_model = ai_edge_torch.convert(
                model,
                sample_kwargs={
                    "input_ids": sample_ids,
                    "attention_mask": sample_mask,
                },
                dynamic_shapes=tc,
            )
            edge_model.export(f"{out_dir}/albert-zh.tflite")
        except Exception as e:
            print(f"✗ Export to {out_dir}/albert-zh.tflite failed: {e}")
            continue
        print(f"✓ Exported to {out_dir}/albert-zh.tflite")


if __name__ == "__main__":
    export_tflite_ai_edge()

Actual vs expected behavior:
Simply setting dynamic_shapes should just work, similar to other conversion packages:
model = AutoModelForTokenClassification.from_pretrained("uer/albert-base-chinese-cluecorpussmall")
model = model.cpu()
model.eval()

edge_model = ai_edge_torch.convert(
    model,
    sample_kwargs={
        "input_ids": sample_ids,
        "attention_mask": sample_mask,
    },
    dynamic_shapes={
        "input_ids": {
            0: batch,
            1: seq_len,
        },
        "attention_mask": {
            0: batch,
            1: seq_len,
        },
    },
)

For example, the following works with the ONNX exporter:
dummy_input = torch.randint(0, 21128, (1, 128))

# Export model to ONNX with legacy exporter for reliable dynamic axes
torch.onnx.export(
    model,
    dummy_input,
    f"{OUT_DIR}/word_segmenter.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=17,
    dynamo=False,
)

One interesting difference is that there's nowhere to pass the output shapes: "logits", if included in the dynamic_shapes dict, causes
When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['input_ids', 'attention_mask'] of `inputs`, but here they are ['input_ids', 'attention_mask', 'logits'].
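My assumption (not something I found in the ai-edge-torch docs) is that dynamic_shapes follows torch.export semantics and only describes inputs, with the output dims inferred from the traced graph. For what it's worth, this is how I'd check whether "logits" ended up dynamic in a successfully converted model, using the LiteRT interpreter:

from ai_edge_litert.interpreter import Interpreter

interpreter = Interpreter(model_path="out/albert-zh-ai-edge-both_dynamic/albert-zh.tflite")
for detail in interpreter.get_output_details():
    # shape_signature uses -1 for dynamic dims, so the batch/sequence dims
    # of the logits output should show up as -1 here.
    print(detail["name"], detail["shape_signature"])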
Any other information you'd like to share?
No response