Description of the bug:
Converting a LayerNorm module with one dynamic dimension raises a TypeError.
import torch
import ai_edge_torch

batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = torch.nn.LayerNorm(embedding_dim)

batch = torch.export.Dim("batch")  # rebind `batch` as a symbolic dimension
dynamic_shapes = {"input": {0: batch}}  # "input" is the name of LayerNorm.forward's argument

# model = ai_edge_torch.convert(layer_norm.eval(), (embedding,))  # without dynamic_shapes, conversion succeeds
model = ai_edge_torch.convert(layer_norm.eval(), (embedding,), dynamic_shapes=dynamic_shapes)  # raises TypeError
model.export("/tmp/layer_norm.tflite")
Actual vs expected behavior:
Expected: a valid TFLite export.
Actual: a TypeError (full traceback below).
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[3], line 9
6 dynamic_shapes = {"input": {0: batch}}
8 # model = ai_edge_torch.convert(layer_norm.eval(), (embedding,))
----> 9 model = ai_edge_torch.convert(layer_norm.eval(), (embedding,), dynamic_shapes=dynamic_shapes)
10 model.export("/tmp/layer_norm.tflite")
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/_convert/converter.py:315, in convert(module, sample_args, sample_kwargs, strict_export, quant_config, dynamic_shapes, _ai_edge_converter_flags, _saved_model_dir)
312 if _ai_edge_converter_flags is None:
313 _ai_edge_converter_flags = {}
--> 315 return Converter().convert(
316 module,
317 sample_args,
318 sample_kwargs,
319 strict_export=strict_export,
320 quant_config=quant_config,
321 dynamic_shapes=dynamic_shapes,
322 _ai_edge_converter_flags=_ai_edge_converter_flags,
323 _saved_model_dir=_saved_model_dir,
324 )
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/_convert/converter.py:203, in Converter.convert(self, module, sample_args, sample_kwargs, strict_export, quant_config, dynamic_shapes, _ai_edge_converter_flags, _saved_model_dir)
198 else: # module is provided but not args
199 raise ValueError(
200 "sample_args or sample_kwargs must be provided if a module is"
201 " specified."
202 )
--> 203 converted_model = conversion.convert_signatures(
204 self._signatures,
205 strict_export=strict_export,
206 quant_config=quant_config,
207 _tfl_converter_flags=_ai_edge_converter_flags,
208 _saved_model_dir=_saved_model_dir,
209 )
210 if self._compilation_configs:
211 return conversion.aot_compile(self._compilation_configs, converted_model)
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/_convert/conversion.py:151, in convert_signatures(signatures, strict_export, quant_config, _tfl_converter_flags, _saved_model_dir)
149 # Apply default fx passes
150 exported_programs = list(map(_run_convert_passes, exported_programs))
--> 151 tflite_model = lowertools.exported_programs_to_tflite(
152 exported_programs,
153 signatures,
154 quant_config=quant_config,
155 _tfl_converter_flags=_tfl_converter_flags,
156 _saved_model_dir=_saved_model_dir,
157 )
159 return model.TfLiteModel(tflite_model)
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/lowertools/_shim.py:72, in exported_programs_to_tflite(exported_programs, signatures, quant_config, _tfl_converter_flags, _saved_model_dir)
68 if _tfl_converter_flags is None:
69 _tfl_converter_flags = {}
71 bundles: list[utils.MlirBundle] = [
---> 72 utils.exported_program_to_mlir(exported, sig.flat_args)
73 for exported, sig in zip(exported_programs, signatures)
74 ]
76 merged_bundle: utils.MergedBundle = utils.merge_mlir_bundles(
77 bundles, signatures, exported_programs
78 )
80 return utils.merged_bundle_to_tfl_model(
81 merged_bundle,
82 signatures,
(...)
85 _saved_model_dir=_saved_model_dir,
86 )
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/lowertools/odml_torch_utils.py:236, in exported_program_to_mlir(exported_program, sample_args)
231 def exported_program_to_mlir(
232 exported_program: torch.export.ExportedProgram,
233 sample_args: tuple[torch.Tensor],
234 ) -> export.MlirLowered:
235 """Converts a ExportedProgram to a MlirLowered."""
--> 236 return odml_torch.export.exported_program_to_mlir(exported_program)
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/export.py:399, in exported_program_to_mlir(exported_program, ir_context, _pre_lower_pass)
393 temp_func = func.FuncOp(
394 "temp",
395 ir.FunctionType.get(ir_flat_inputs, []),
396 ip=ir.InsertionPoint.at_block_begin(module.body),
397 )
398 with ir.InsertionPoint(temp_func.add_entry_block()):
--> 399 interpreter.run(*temp_func.arguments, enable_io_processing=False)
400 num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
401 outputs = interpreter.outputs[num_mutations:]
File ~/dev/dynamo/venv/lib/python3.12/site-packages/torch/fx/interpreter.py:171, in Interpreter.run(self, initial_env, enable_io_processing, *args)
168 continue
170 try:
--> 171 self.env[node] = self.run_node(node)
172 except Exception as e:
173 if self.extra_traceback:
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/export.py:127, in LoweringInterpreter.run_node(self, node)
125 with loc:
126 self.lctx = self.lctx.replace(ir_location=loc, node=node)
--> 127 res = super().run_node(node)
128 self.lctx = self.lctx.replace(ir_location=None, node=None)
129 return res
File ~/dev/dynamo/venv/lib/python3.12/site-packages/torch/fx/interpreter.py:240, in Interpreter.run_node(self, n)
238 assert isinstance(args, tuple)
239 assert isinstance(kwargs, dict)
--> 240 return getattr(self, n.op)(n.target, args, kwargs)
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/export.py:150, in LoweringInterpreter.call_function(self, target, args, kwargs)
148 if lowering is None:
149 raise RuntimeError(f"Lowering not found: {target}")
--> 150 return lowering(self.lctx, *args, **kwargs)
File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/lowerings/_layer_norm.py:45, in _aten_native_layer_norm(lctx, data, normalized_shape, weight, bias, eps)
39 unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
40 dest_shape = [
41 1,
42 unnormalized_count,
43 math.prod(normalized_shape),
44 ]
---> 45 dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
47 reshaped_data = stablehlo.reshape(dest_type, data)
49 one = utils.splat(1, data_type.element_type, [unnormalized_count])
TypeError: get(): incompatible function arguments. The following argument types are supported:
1. get(shape: collections.abc.Sequence[int], element_type: jaxlib.mlir._mlir_libs._mlir.ir.Type, encoding: jaxlib.mlir._mlir_libs._mlir.ir.Attribute | None = None, loc: mlir.ir.Location | None = None) -> jaxlib.mlir._mlir_libs._mlir.ir.RankedTensorType
Invoked with types: list, jaxlib.mlir._mlir_libs._mlir.ir.F32Type
While executing %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input, [10], %p_weight, %p_bias, 1e-05), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
def forward(self, p_weight: "f32[10][1]", p_bias: "f32[10][1]", input: "f32[s0, 5, 10][50, 10, 1]"):
input_1 = input
# File: /Users/nloriant/dev/dynamo/venv/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:70 in inner, code: return fn(*args, **kwargs)
native_layer_norm = torch.ops.aten.native_layer_norm.default(input_1, [10], p_weight, p_bias, 1e-05); input_1 = p_weight = p_bias = None
getitem: "f32[s0, 5, 10][50, 10, 1]" = native_layer_norm[0]; native_layer_norm = None
return (getitem,)
Original traceback:
File "[/Users/nloriant/dev/dynamo/venv/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 70](http://localhost:8888/lab/tree/venv/lib/python3.12/site-packages/torch/_dynamo/external_utils.py#line=69), in inner
return fn(*args, **kwargs)
Any other information you'd like to share?
Python 3.12
torch 2.7.1

From the traceback, the failure appears to come from the lowering in ai_edge_torch/odml_torch/lowerings/_layer_norm.py:45: _aten_native_layer_norm builds dest_shape from math.prod(data_type.shape), which assumes every dimension is a static int. With a dynamic batch dimension, that computation produces a value that ir.RankedTensorType.get rejects, hence the TypeError.
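For anyone else hitting this, a possible workaround (untested sketch, and it assumes the failure is specific to the aten.native_layer_norm lowering): compute the layer norm manually so the exported graph avoids that op. The forward argument is still named input, so the same dynamic_shapes dict applies:

class ManualLayerNorm(torch.nn.Module):
    """Same math as nn.LayerNorm over the last dimension."""

    def __init__(self, ln: torch.nn.LayerNorm):
        super().__init__()
        self.weight = ln.weight
        self.bias = ln.bias
        self.eps = ln.eps

    def forward(self, input):
        mean = input.mean(-1, keepdim=True)
        var = input.var(-1, keepdim=True, correction=0)  # biased variance, matching LayerNorm
        return (input - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias

# Continuing from the repro above:
model = ai_edge_torch.convert(
    ManualLayerNorm(layer_norm).eval(), (embedding,), dynamic_shapes=dynamic_shapes
)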