TypeError when converting LayerNorm with dynamic shape #711

Closed
@gsigms

Description of the bug:

Converting a LayerNorm module with one dynamic shape dimension leads to a TypeError.

import torch
import ai_edge_torch

batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = torch.nn.LayerNorm(embedding_dim)

# Mark dimension 0 of the `input` argument as dynamic.
batch = torch.export.Dim("batch")
dynamic_shapes = {"input": {0: batch}}

# model = ai_edge_torch.convert(layer_norm.eval(), (embedding,))  # static shapes
# With a dynamic batch dimension, convert raises a TypeError:
model = ai_edge_torch.convert(layer_norm.eval(), (embedding,), dynamic_shapes=dynamic_shapes)
model.export("/tmp/layer_norm.tflite")
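Note that `torch.export` itself handles the dynamic batch dimension (the exported graph in the traceback below shows the input as `f32[s0, 5, 10]`), so the failure seems to happen during the MLIR lowering rather than during export. A minimal check, reusing `layer_norm`, `embedding`, and `dynamic_shapes` from the repro above:

# Sketch: torch.export alone accepts the dynamic batch dimension.
exported = torch.export.export(
    layer_norm.eval(), (embedding,), dynamic_shapes=dynamic_shapes
)
print(exported)  # the input shows up as f32[s0, 5, 10]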

Actual vs expected behavior:

Expected: a valid export.
Actual: a TypeError (full traceback below).

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 9
      6 dynamic_shapes = {"input": {0: batch}}
      8 # model = ai_edge_torch.convert(layer_norm.eval(), (embedding,))
----> 9 model = ai_edge_torch.convert(layer_norm.eval(), (embedding,), dynamic_shapes=dynamic_shapes)
     10 model.export("/tmp/layer_norm.tflite")

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/_convert/converter.py:315, in convert(module, sample_args, sample_kwargs, strict_export, quant_config, dynamic_shapes, _ai_edge_converter_flags, _saved_model_dir)
    312 if _ai_edge_converter_flags is None:
    313   _ai_edge_converter_flags = {}
--> 315 return Converter().convert(
    316     module,
    317     sample_args,
    318     sample_kwargs,
    319     strict_export=strict_export,
    320     quant_config=quant_config,
    321     dynamic_shapes=dynamic_shapes,
    322     _ai_edge_converter_flags=_ai_edge_converter_flags,
    323     _saved_model_dir=_saved_model_dir,
    324 )

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/_convert/converter.py:203, in Converter.convert(self, module, sample_args, sample_kwargs, strict_export, quant_config, dynamic_shapes, _ai_edge_converter_flags, _saved_model_dir)
    198   else:  # module is provided but not args
    199     raise ValueError(
    200         "sample_args or sample_kwargs must be provided if a module is"
    201         " specified."
    202     )
--> 203 converted_model = conversion.convert_signatures(
    204     self._signatures,
    205     strict_export=strict_export,
    206     quant_config=quant_config,
    207     _tfl_converter_flags=_ai_edge_converter_flags,
    208     _saved_model_dir=_saved_model_dir,
    209 )
    210 if self._compilation_configs:
    211   return conversion.aot_compile(self._compilation_configs, converted_model)

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/_convert/conversion.py:151, in convert_signatures(signatures, strict_export, quant_config, _tfl_converter_flags, _saved_model_dir)
    149 # Apply default fx passes
    150 exported_programs = list(map(_run_convert_passes, exported_programs))
--> 151 tflite_model = lowertools.exported_programs_to_tflite(
    152     exported_programs,
    153     signatures,
    154     quant_config=quant_config,
    155     _tfl_converter_flags=_tfl_converter_flags,
    156     _saved_model_dir=_saved_model_dir,
    157 )
    159 return model.TfLiteModel(tflite_model)

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/lowertools/_shim.py:72, in exported_programs_to_tflite(exported_programs, signatures, quant_config, _tfl_converter_flags, _saved_model_dir)
     68 if _tfl_converter_flags is None:
     69   _tfl_converter_flags = {}
     71 bundles: list[utils.MlirBundle] = [
---> 72     utils.exported_program_to_mlir(exported, sig.flat_args)
     73     for exported, sig in zip(exported_programs, signatures)
     74 ]
     76 merged_bundle: utils.MergedBundle = utils.merge_mlir_bundles(
     77     bundles, signatures, exported_programs
     78 )
     80 return utils.merged_bundle_to_tfl_model(
     81     merged_bundle,
     82     signatures,
   (...)
     85     _saved_model_dir=_saved_model_dir,
     86 )

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/lowertools/odml_torch_utils.py:236, in exported_program_to_mlir(exported_program, sample_args)
    231 def exported_program_to_mlir(
    232     exported_program: torch.export.ExportedProgram,
    233     sample_args: tuple[torch.Tensor],
    234 ) -> export.MlirLowered:
    235   """Converts a ExportedProgram to a MlirLowered."""
--> 236   return odml_torch.export.exported_program_to_mlir(exported_program)

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/export.py:399, in exported_program_to_mlir(exported_program, ir_context, _pre_lower_pass)
    393 temp_func = func.FuncOp(
    394     "temp",
    395     ir.FunctionType.get(ir_flat_inputs, []),
    396     ip=ir.InsertionPoint.at_block_begin(module.body),
    397 )
    398 with ir.InsertionPoint(temp_func.add_entry_block()):
--> 399   interpreter.run(*temp_func.arguments, enable_io_processing=False)
    400   num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
    401   outputs = interpreter.outputs[num_mutations:]

File ~/dev/dynamo/venv/lib/python3.12/site-packages/torch/fx/interpreter.py:171, in Interpreter.run(self, initial_env, enable_io_processing, *args)
    168     continue
    170 try:
--> 171     self.env[node] = self.run_node(node)
    172 except Exception as e:
    173     if self.extra_traceback:

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/export.py:127, in LoweringInterpreter.run_node(self, node)
    125 with loc:
    126   self.lctx = self.lctx.replace(ir_location=loc, node=node)
--> 127   res = super().run_node(node)
    128   self.lctx = self.lctx.replace(ir_location=None, node=None)
    129 return res

File ~/dev/dynamo/venv/lib/python3.12/site-packages/torch/fx/interpreter.py:240, in Interpreter.run_node(self, n)
    238 assert isinstance(args, tuple)
    239 assert isinstance(kwargs, dict)
--> 240 return getattr(self, n.op)(n.target, args, kwargs)

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/export.py:150, in LoweringInterpreter.call_function(self, target, args, kwargs)
    148 if lowering is None:
    149   raise RuntimeError(f"Lowering not found: {target}")
--> 150 return lowering(self.lctx, *args, **kwargs)

File ~/dev/dynamo/venv/lib/python3.12/site-packages/ai_edge_torch/odml_torch/lowerings/_layer_norm.py:45, in _aten_native_layer_norm(lctx, data, normalized_shape, weight, bias, eps)
     39 unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
     40 dest_shape = [
     41     1,
     42     unnormalized_count,
     43     math.prod(normalized_shape),
     44 ]
---> 45 dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
     47 reshaped_data = stablehlo.reshape(dest_type, data)
     49 one = utils.splat(1, data_type.element_type, [unnormalized_count])

TypeError: get(): incompatible function arguments. The following argument types are supported:
    1. get(shape: collections.abc.Sequence[int], element_type: jaxlib.mlir._mlir_libs._mlir.ir.Type, encoding: jaxlib.mlir._mlir_libs._mlir.ir.Attribute | None = None, loc: mlir.ir.Location | None = None) -> jaxlib.mlir._mlir_libs._mlir.ir.RankedTensorType

Invoked with types: list, jaxlib.mlir._mlir_libs._mlir.ir.F32Type

While executing %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%input, [10], %p_weight, %p_bias, 1e-05), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, p_weight: "f32[10][1]", p_bias: "f32[10][1]", input: "f32[s0, 5, 10][50, 10, 1]"):
        input_1 = input
        
         # File: /Users/nloriant/dev/dynamo/venv/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:70 in inner, code: return fn(*args, **kwargs)
        native_layer_norm = torch.ops.aten.native_layer_norm.default(input_1, [10], p_weight, p_bias, 1e-05);  input_1 = p_weight = p_bias = None
        getitem: "f32[s0, 5, 10][50, 10, 1]" = native_layer_norm[0];  native_layer_norm = None
        return (getitem,)
        

Original traceback:
  File "[/Users/nloriant/dev/dynamo/venv/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 70](http://localhost:8888/lab/tree/venv/lib/python3.12/site-packages/torch/_dynamo/external_utils.py#line=69), in inner
    return fn(*args, **kwargs)
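
A guess at the root cause, from the failing frame in `_layer_norm.py`: with a dynamic batch dimension, `data_type.shape` contains MLIR's dynamic-size sentinel (INT64_MIN) instead of a concrete size, so `math.prod(data_type.shape)` falls outside the int64 range and pybind rejects the resulting `dest_shape` list with the "incompatible function arguments" error above. A minimal sketch of the suspected failure mode, using jaxlib's MLIR bindings directly (illustrative only, not the actual ai_edge_torch code, and assuming `ShapedType.get_dynamic_size()` behaves as in upstream MLIR):

import math

from jaxlib.mlir import ir  # the same bindings that appear in the traceback

with ir.Context():
    dynamic = ir.ShapedType.get_dynamic_size()  # INT64_MIN sentinel for a dynamic dim
    shape = [dynamic, 5, 10]                    # shape of f32[s0, 5, 10]
    unnormalized_count = math.prod(shape) // 10  # far outside the int64 range
    dest_shape = [1, unnormalized_count, 10]
    # pybind cannot convert the out-of-range element, hence the TypeError:
    ir.RankedTensorType.get(dest_shape, ir.F32Type.get())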

Any other information you'd like to share?

Python 3.12
torch 2.7.1
