Skip to content

InputCellState is incorrectly populated for TFLite model to MLIR conversion using flatbuffer_translate #397

Open
@gaikwadrahul8

Description

@gaikwadrahul8

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Linux Debian 11
  • TensorFlow installed from (source or binary):
    Compiled from source
  • TensorFlow version (or github SHA if from source):
    744dad26ef526690319042030f776e6f7e62dbc8

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

import numpy as np
import tensorflow as tf

# NOTE: this reproduction uses tf.keras (Keras 2) APIs -- the LSTM
# `time_major` argument and model.save(..., save_format="tf"). Under Keras 3
# (the default from TF 2.16) these are not available and the script needs
# adapting (e.g. model.export() to write a SavedModel).

# Input geometry of the reproduction. These values match the
# tensor<1x5x3xf32> input and the 10x3 input-to-gate weight tensors in the
# MLIR dumps of this issue: batch of 1, 5 time steps, 3 features per step.
BATCH_SIZE = 1
SEQUENCE_LENGTH = 5
NUM_FEATURES = 3
LSTM_UNITS = 10

# Output locations for the exported SavedModel and the converted TFLite model.
MODEL_DIR = "/tmp/lstmNet"
TFLITE_PATH = "/tmp/lstmNet.tflite"

# Define and create the model: a single batch-major LSTM layer that returns
# the output of every time step. The Input shape is (time steps, features)
# and must agree with the concrete-function signature traced below.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(SEQUENCE_LENGTH, NUM_FEATURES), name='input'),
    tf.keras.layers.LSTM(LSTM_UNITS, time_major=False, return_sequences=True),
])

model.compile(optimizer='adam',
              loss='mean_squared_error',
              metrics=['accuracy'])

# Trace the model with a fixed input signature; it becomes the serving
# signature of the exported SavedModel.
run_model = tf.function(lambda x: model(x))
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([BATCH_SIZE, SEQUENCE_LENGTH, NUM_FEATURES],
                  model.inputs[0].dtype))

# In a Sequential model the Input layer is not listed in model.layers, so
# layers[0] is the LSTM. states[0] is the hidden state and states[1] the cell
# state; `stateful` is left at its default, so no state is stored and these
# print None.
print("hidden_states: ", model.layers[0].states[0])
print("cell_states: ", model.layers[0].states[1])

# Export as a TensorFlow SavedModel with the traced signature.
model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile(TFLITE_PATH, 'wb') as f:
    f.write(tflite_model)

Any other info / logs

Using flatbuffer_translate to convert the generated TFLite model to MLIR produces:

$ flatbuffer_translate --tflite-flatbuffer-to-mlir /tmp/lstmNet.tflite

module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
  func.func @predict(%arg0: tensor<1x5x3xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x5x10xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_x:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
    %0 = "tfl.pseudo_const"() {value = dense<[[-0.130185053, -0.0151278675, 0.0130760074], [-0.258772284, 0.299689293, -0.195314541], [0.252850413, -0.259092718, -0.0803229808], [-0.220947981, 0.155216038, 0.108377606], [0.00254765153, 0.111942321, -0.219952658], [0.206842721, -0.193888605, 0.1106188], [0.0955285131, 0.157347143, 0.221373796], [-0.276973069, -0.0735740363, -2.882380e-01], [0.012721926, 0.0903562903, -0.161965311], [-0.119528085, -0.037569046, -0.362928301]]> : tensor<10x3xf32>} : () -> tensor<10x3xf32>
    %1 = "tfl.pseudo_const"() {value = dense<[[-0.0352886915, 0.21145165, -0.165998831], [0.155003309, 0.144935846, 0.217351139], [-0.351629466, 0.341497242, -0.217549637], [0.0939139425, -0.0606328547, 0.197987914], [0.339334488, -0.0430043638, -0.193897158], [0.188981593, -0.00256928802, 0.357774317], [-0.053791374, -0.159659907, -0.334026635], [-0.313022763, 0.120892107, 0.365564883], [-0.0173099339, 0.0726312696, -0.256803274], [-0.0634435713, 0.320655167, -0.342872471]]> : tensor<10x3xf32>} : () -> tensor<10x3xf32>
    %2 = "tfl.pseudo_const"() {value = dense<[[-0.0151669681, 0.246311307, -0.0844985544], [0.00762689113, 0.150569379, -0.275011361], [0.0549676716, 0.0834532678, 0.159000754], [0.0447338223, -0.339231104, 0.134988308], [0.160350919, -0.0878992974, 0.0488999486], [0.323455155, 0.345792234, 0.061250925], [0.0837553441, -0.272862256, 0.0991969704], [0.0828025043, -0.364639938, -0.144624218], [0.343984544, -0.183882296, -0.358834654], [-0.133859664, -0.0814070403, 0.36716789]]> : tensor<10x3xf32>} : () -> tensor<10x3xf32>
    %3 = "tfl.pseudo_const"() {value = dense<[[0.284724414, 0.148204803, -0.349271417], [-0.302493066, 0.158759058, -0.236835971], [0.334069431, -0.00801679491, 0.00450757146], [0.0284658968, 0.0569611788, -0.167743862], [-0.12706618, 7.704550e-02, -0.0319027305], [0.0807663202, -0.0853067636, -0.171359152], [-0.143240675, 0.320195258, -0.107624263], [0.134614825, 0.137890339, 0.220042884], [-0.0685030818, 0.0266759694, -0.279772133], [-0.123277575, -0.130606934, 0.155195773]]> : tensor<10x3xf32>} : () -> tensor<10x3xf32>
    %4 = "tfl.pseudo_const"() {value = dense<[[0.17501545, 0.386878431, -0.0484701507, -0.0765462443, -0.181067079, -0.0750025958, -0.185812324, 0.0773477107, -0.00309249177, -0.0284322314], [0.134114936, 0.176706672, 0.189387321, 0.149077475, 0.142265841, 0.263391763, 0.0858101398, -0.334216356, -0.186852977, 0.164630443], [-0.00760664651, -0.217253342, 0.201065645, 0.248932779, -0.109697223, 0.0161245521, 0.059841346, -0.260818273, -0.0882004871, -0.274861604], [0.348256409, 5.338460e-02, 0.0468219183, -0.0478721932, -0.165056169, -0.0841045827, 0.321953684, -0.164883882, -0.110893697, 0.123676449], [0.128447369, 0.0639217123, 0.224082395, -0.0548633076, 0.0286352951, 0.224867627, -0.0468161702, 0.208391294, -0.0053243367, -0.0550255962], [-8.189090e-02, 0.0486251637, 0.0460817255, 0.107632667, 0.118330166, -0.0682023465, -0.109725371, 0.125324547, -0.0441931672, 0.148130983], [0.0746245757, 0.124206074, 0.176272869, 0.0834054648, 0.173528254, -0.187932938, -0.215293854, -0.103552267, -0.141145512, 0.0601227432], [0.0378115289, 0.0337749943, 0.0378107131, 0.160787702, -0.216121212, 0.229908451, -0.0723084137, 0.226159409, 0.0131477024, 0.0372102521], [0.0926874801, -0.06026401, -0.0561813228, -0.148479104, 0.287242681, 0.0332023241, -0.220059201, 0.0408726893, 0.191863343, 0.0540938973], [0.133774817, -0.278124928, -0.113644354, -0.0739784315, -0.316670179, -0.11459551, -0.264918804, -0.00448995735, -0.0878850519, -0.133028492]]> : tensor<10x10xf32>} : () -> tensor<10x10xf32>
    %5 = "tfl.pseudo_const"() {value = dense<[[0.264959276, -0.259054482, -0.110992216, -0.0414756909, 0.0988652482, 0.33331418, -0.348624617, -0.13201724, 0.00749636581, 0.0932318419], [0.210364223, -0.0775818601, -0.0835916772, 0.21802707, 0.0432840511, -0.0722324625, -0.140951559, 0.150197655, 0.0137763629, 0.139982209], [0.0820472389, 0.230024397, -0.156220302, 0.391181529, -0.0579637811, 0.0591350347, 0.0986427441, 0.237236544, 0.187094688, 0.0165050365], [-0.25562951, 0.0437992811, -0.146220565, 0.107471354, 0.00600573421, 0.0497420132, -0.134210557, 0.0306625273, -0.00543002971, -0.0933061838], [-0.0681353137, 0.187344134, -0.0387082808, -0.0682450234, -0.184748515, 0.295936018, 0.0660243928, -0.00609975681, -0.0220854413, -0.0215770602], [0.123715289, -0.149823353, 0.200297117, -0.138535887, 0.0972650945, -0.100543253, 0.0232232288, -0.0455124453, 0.188970223, 0.0623518042], [0.051465977, -0.16720742, -0.192865178, -0.255042017, -0.216406658, 0.00651523052, 0.182573959, -0.127715543, 0.197372794, -0.222561017], [0.126112625, 0.23686114, -0.0976715758, -0.0628687739, 0.00957589969, -0.211140588, -0.284029543, -0.145181119, -0.167763934, -0.224568278], [-0.0698507354, 0.0349835455, 0.0666678548, 0.0400003493, 0.246126533, -0.107588813, 0.103801601, -0.0709155499, -0.131777883, -0.106253646], [-0.207862586, 0.016677089, -0.237795576, -0.033603251, 0.227860093, -0.00418378413, 0.0745725333, -0.1573136, 0.183234155, 0.0573116578]]> : tensor<10x10xf32>} : () -> tensor<10x10xf32>
    %6 = "tfl.pseudo_const"() {value = dense<[[0.120267294, -0.0315921716, -0.207551152, -0.0469613187, -0.0472462848, 0.136440426, 0.0656398907, 0.0643941164, -0.214579433, 0.25941658], [-0.132091984, 0.142318785, 0.158400849, 0.0916790738, -0.222117975, -0.10034211, -0.134736136, 0.0465843715, 0.0484883301, 0.262922645], [0.0632713437, -0.15643549, -0.410313338, -0.0796892642, 0.0193372741, -0.192711219, 0.084401071, -0.0132020218, 0.0214749519, 0.217391357], [-0.023331102, -0.234836608, 0.283015937, -0.27972725, -0.0461517051, -0.202469319, -0.140018716, 0.266479582, -0.199459061, 0.0869432539], [-0.173759624, -0.146702722, 0.00280229794, 0.128232807, -0.13735646, -0.156450719, -0.197670206, -0.402188092, 0.182303295, 0.127702326], [0.1447341, -0.234402686, 0.017554732, 0.1726809, 0.14695932, 0.111371085, 0.109428726, 0.268820375, -0.0866426378, -0.0679584667], [-0.0736832693, -0.108465984, 0.223148763, 0.00207955297, 0.115413204, -0.0324923247, 0.151184335, -0.0644928366, 0.0295013655, -0.0610723197], [0.104881756, 0.166022196, -0.142411351, 0.142822742, 0.12192411, -1.49629079E-4, -0.335838586, -0.156993166, -0.134479508, -0.183696106], [0.131517142, 0.164839372, 0.0969391465, -0.153943628, 0.251961887, 0.0231811684, -0.0071518924, -0.0611652061, 0.0723072588, -0.00650136918], [0.184471652, 0.0503823385, -0.0937841534, 0.118845418, 0.154655725, -0.47709918, 0.186724663, 0.162837937, -0.161638841, -0.274689823]]> : tensor<10x10xf32>} : () -> tensor<10x10xf32>
    %7 = "tfl.pseudo_const"() {value = dense<[[0.0217915159, 0.206113726, -0.190372929, -0.422592759, 0.0248149242, -0.0202117879, 7.80003611E-5, 0.0636207908, -0.129249915, 0.195001438], [-0.0265121292, 0.157053322, -0.116082504, -0.0937363803, 0.155826345, 0.10708677, 0.110328041, -0.136170894, 0.143848643, -0.217342377], [-0.182416916, 0.0980726927, 6.768720e-02, -0.134475529, 0.00583284162, -0.0480525941, 0.0232472904, -0.146422684, -0.125577226, 0.231070623], [-0.404781759, -0.0607755706, 0.0381766148, 0.060485743, 0.0692789554, -0.0851613953, -0.151829481, 0.0906099975, -0.0977554768, 0.120315634], [0.147763401, 0.0345406979, 0.121470168, -0.0796541571, -0.0689968764, 0.0199192017, -0.187039837, -0.0605593659, 0.307765067, -0.132991672], [0.128111526, 0.0928409397, 0.315211028, -0.200447187, -0.0918775648, -0.0101282364, 0.0570981055, -0.0672892854, -0.0190094449, 0.0233988091], [-0.00752805918, -0.0974235087, 0.0578115322, -0.154167339, 0.30695793, 0.159499139, -0.102361277, 0.186416253, 0.11203637, -0.187620446], [0.168520495, -0.0998088344, -0.0158915576, 0.101420805, -0.122099787, 0.0111542856, -0.049965702, -0.115302391, -0.121384457, -0.00166527322], [-0.169698805, -0.10477908, -0.141094357, -0.10161002, 0.02446693, 0.249826044, 0.0071637705, -0.052272439, -0.567903876, -0.189442679], [0.225991249, -0.16400744, -0.0723658279, 0.153675273, 0.228954688, -0.0319412872, 0.0650407076, -0.126990899, -0.0396433137, 0.326944739]]> : tensor<10x10xf32>} : () -> tensor<10x10xf32>
    %8 = "tfl.no_value"() {value} : () -> none
    %9 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32>
    %10 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32>
    %11 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
    %12 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
    %13 = "tfl.unidirectional_sequence_lstm"(%arg0, %0, %1, %2, %3, %4, %5, %6, %7, %8, %8, %8, %9, %10, %9, %9, %8, %8, %11, %12, %8, %8, %8, %8) {asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<1x5x3xf32>, tensor<10x3xf32>, tensor<10x3xf32>, tensor<10x3xf32>, tensor<10x3xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<1x10xf32>, tensor<1x10xf32>, none, none, none, none) -> tensor<1x5x10xf32>
    return %13 : tensor<1x5x10xf32>
  }
}

Since `model.layers[0].states[1]` (the LSTM cell state) is `None`, the InputCellState operand of the `tfl.unidirectional_sequence_lstm` op — `%12`, the 20th operand — should be all zeros, but it is all ones. (The InputActivationState operand, `%11`, is correctly all zeros.) The same operand is all zeros when the SavedModel is converted with the `tf_tfl_translate` command:

$ tf_tfl_translate --savedmodel-signaturedefs-to-mlir /tmp/lstmNet/ --output-mlir

module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1286 : i32}, tf_saved_model.semantics, tfl._legalize_tfl_variables = true} {
  func.func @serving_default(%arg0: tensor<1x5x3xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x5x10xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
    %cst = arith.constant dense<0.000000e+00> : tensor<1x10xf32>
    %cst_0 = arith.constant dense<[[0.17501545, 0.386878431, -0.0484701507, -0.0765462443, -0.181067079, -0.0750025958, -0.185812324, 0.0773477107, -0.00309249177, -0.0284322314], [0.134114936, 0.176706672, 0.189387321, 0.149077475, 0.142265841, 0.263391763, 0.0858101398, -0.334216356, -0.186852977, 0.164630443], [-0.00760664651, -0.217253342, 0.201065645, 0.248932779, -0.109697223, 0.0161245521, 0.059841346, -0.260818273, -0.0882004871, -0.274861604], [0.348256409, 5.338460e-02, 0.0468219183, -0.0478721932, -0.165056169, -0.0841045827, 0.321953684, -0.164883882, -0.110893697, 0.123676449], [0.128447369, 0.0639217123, 0.224082395, -0.0548633076, 0.0286352951, 0.224867627, -0.0468161702, 0.208391294, -0.0053243367, -0.0550255962], [-8.189090e-02, 0.0486251637, 0.0460817255, 0.107632667, 0.118330166, -0.0682023465, -0.109725371, 0.125324547, -0.0441931672, 0.148130983], [0.0746245757, 0.124206074, 0.176272869, 0.0834054648, 0.173528254, -0.187932938, -0.215293854, -0.103552267, -0.141145512, 0.0601227432], [0.0378115289, 0.0337749943, 0.0378107131, 0.160787702, -0.216121212, 0.229908451, -0.0723084137, 0.226159409, 0.0131477024, 0.0372102521], [0.0926874801, -0.06026401, -0.0561813228, -0.148479104, 0.287242681, 0.0332023241, -0.220059201, 0.0408726893, 0.191863343, 0.0540938973], [0.133774817, -0.278124928, -0.113644354, -0.0739784315, -0.316670179, -0.11459551, -0.264918804, -0.00448995735, -0.0878850519, -0.133028492]]> : tensor<10x10xf32>
    %cst_1 = arith.constant dense<[[0.264959276, -0.259054482, -0.110992216, -0.0414756909, 0.0988652482, 0.33331418, -0.348624617, -0.13201724, 0.00749636581, 0.0932318419], [0.210364223, -0.0775818601, -0.0835916772, 0.21802707, 0.0432840511, -0.0722324625, -0.140951559, 0.150197655, 0.0137763629, 0.139982209], [0.0820472389, 0.230024397, -0.156220302, 0.391181529, -0.0579637811, 0.0591350347, 0.0986427441, 0.237236544, 0.187094688, 0.0165050365], [-0.25562951, 0.0437992811, -0.146220565, 0.107471354, 0.00600573421, 0.0497420132, -0.134210557, 0.0306625273, -0.00543002971, -0.0933061838], [-0.0681353137, 0.187344134, -0.0387082808, -0.0682450234, -0.184748515, 0.295936018, 0.0660243928, -0.00609975681, -0.0220854413, -0.0215770602], [0.123715289, -0.149823353, 0.200297117, -0.138535887, 0.0972650945, -0.100543253, 0.0232232288, -0.0455124453, 0.188970223, 0.0623518042], [0.051465977, -0.16720742, -0.192865178, -0.255042017, -0.216406658, 0.00651523052, 0.182573959, -0.127715543, 0.197372794, -0.222561017], [0.126112625, 0.23686114, -0.0976715758, -0.0628687739, 0.00957589969, -0.211140588, -0.284029543, -0.145181119, -0.167763934, -0.224568278], [-0.0698507354, 0.0349835455, 0.0666678548, 0.0400003493, 0.246126533, -0.107588813, 0.103801601, -0.0709155499, -0.131777883, -0.106253646], [-0.207862586, 0.016677089, -0.237795576, -0.033603251, 0.227860093, -0.00418378413, 0.0745725333, -0.1573136, 0.183234155, 0.0573116578]]> : tensor<10x10xf32>
    %cst_2 = arith.constant dense<[[0.120267294, -0.0315921716, -0.207551152, -0.0469613187, -0.0472462848, 0.136440426, 0.0656398907, 0.0643941164, -0.214579433, 0.25941658], [-0.132091984, 0.142318785, 0.158400849, 0.0916790738, -0.222117975, -0.10034211, -0.134736136, 0.0465843715, 0.0484883301, 0.262922645], [0.0632713437, -0.15643549, -0.410313338, -0.0796892642, 0.0193372741, -0.192711219, 0.084401071, -0.0132020218, 0.0214749519, 0.217391357], [-0.023331102, -0.234836608, 0.283015937, -0.27972725, -0.0461517051, -0.202469319, -0.140018716, 0.266479582, -0.199459061, 0.0869432539], [-0.173759624, -0.146702722, 0.00280229794, 0.128232807, -0.13735646, -0.156450719, -0.197670206, -0.402188092, 0.182303295, 0.127702326], [0.1447341, -0.234402686, 0.017554732, 0.1726809, 0.14695932, 0.111371085, 0.109428726, 0.268820375, -0.0866426378, -0.0679584667], [-0.0736832693, -0.108465984, 0.223148763, 0.00207955297, 0.115413204, -0.0324923247, 0.151184335, -0.0644928366, 0.0295013655, -0.0610723197], [0.104881756, 0.166022196, -0.142411351, 0.142822742, 0.12192411, -1.49629079E-4, -0.335838586, -0.156993166, -0.134479508, -0.183696106], [0.131517142, 0.164839372, 0.0969391465, -0.153943628, 0.251961887, 0.0231811684, -0.0071518924, -0.0611652061, 0.0723072588, -0.00650136918], [0.184471652, 0.0503823385, -0.0937841534, 0.118845418, 0.154655725, -0.47709918, 0.186724663, 0.162837937, -0.161638841, -0.274689823]]> : tensor<10x10xf32>
    %cst_3 = arith.constant dense<[[0.0217915159, 0.206113726, -0.190372929, -0.422592759, 0.0248149242, -0.0202117879, 7.80003611E-5, 0.0636207908, -0.129249915, 0.195001438], [-0.0265121292, 0.157053322, -0.116082504, -0.0937363803, 0.155826345, 0.10708677, 0.110328041, -0.136170894, 0.143848643, -0.217342377], [-0.182416916, 0.0980726927, 6.768720e-02, -0.134475529, 0.00583284162, -0.0480525941, 0.0232472904, -0.146422684, -0.125577226, 0.231070623], [-0.404781759, -0.0607755706, 0.0381766148, 0.060485743, 0.0692789554, -0.0851613953, -0.151829481, 0.0906099975, -0.0977554768, 0.120315634], [0.147763401, 0.0345406979, 0.121470168, -0.0796541571, -0.0689968764, 0.0199192017, -0.187039837, -0.0605593659, 0.307765067, -0.132991672], [0.128111526, 0.0928409397, 0.315211028, -0.200447187, -0.0918775648, -0.0101282364, 0.0570981055, -0.0672892854, -0.0190094449, 0.0233988091], [-0.00752805918, -0.0974235087, 0.0578115322, -0.154167339, 0.30695793, 0.159499139, -0.102361277, 0.186416253, 0.11203637, -0.187620446], [0.168520495, -0.0998088344, -0.0158915576, 0.101420805, -0.122099787, 0.0111542856, -0.049965702, -0.115302391, -0.121384457, -0.00166527322], [-0.169698805, -0.10477908, -0.141094357, -0.10161002, 0.02446693, 0.249826044, 0.0071637705, -0.052272439, -0.567903876, -0.189442679], [0.225991249, -0.16400744, -0.0723658279, 0.153675273, 0.228954688, -0.0319412872, 0.0650407076, -0.126990899, -0.0396433137, 0.326944739]]> : tensor<10x10xf32>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<10xf32>
    %cst_5 = arith.constant dense<1.000000e+00> : tensor<10xf32>
    %cst_6 = arith.constant dense<[[-0.130185053, -0.0151278675, 0.0130760074], [-0.258772284, 0.299689293, -0.195314541], [0.252850413, -0.259092718, -0.0803229808], [-0.220947981, 0.155216038, 0.108377606], [0.00254765153, 0.111942321, -0.219952658], [0.206842721, -0.193888605, 0.1106188], [0.0955285131, 0.157347143, 0.221373796], [-0.276973069, -0.0735740363, -2.882380e-01], [0.012721926, 0.0903562903, -0.161965311], [-0.119528085, -0.037569046, -0.362928301]]> : tensor<10x3xf32>
    %cst_7 = arith.constant dense<[[-0.0352886915, 0.21145165, -0.165998831], [0.155003309, 0.144935846, 0.217351139], [-0.351629466, 0.341497242, -0.217549637], [0.0939139425, -0.0606328547, 0.197987914], [0.339334488, -0.0430043638, -0.193897158], [0.188981593, -0.00256928802, 0.357774317], [-0.053791374, -0.159659907, -0.334026635], [-0.313022763, 0.120892107, 0.365564883], [-0.0173099339, 0.0726312696, -0.256803274], [-0.0634435713, 0.320655167, -0.342872471]]> : tensor<10x3xf32>
    %cst_8 = arith.constant dense<[[-0.0151669681, 0.246311307, -0.0844985544], [0.00762689113, 0.150569379, -0.275011361], [0.0549676716, 0.0834532678, 0.159000754], [0.0447338223, -0.339231104, 0.134988308], [0.160350919, -0.0878992974, 0.0488999486], [0.323455155, 0.345792234, 0.061250925], [0.0837553441, -0.272862256, 0.0991969704], [0.0828025043, -0.364639938, -0.144624218], [0.343984544, -0.183882296, -0.358834654], [-0.133859664, -0.0814070403, 0.36716789]]> : tensor<10x3xf32>
    %cst_9 = arith.constant dense<[[0.284724414, 0.148204803, -0.349271417], [-0.302493066, 0.158759058, -0.236835971], [0.334069431, -0.00801679491, 0.00450757146], [0.0284658968, 0.0569611788, -0.167743862], [-0.12706618, 7.704550e-02, -0.0319027305], [0.0807663202, -0.0853067636, -0.171359152], [-0.143240675, 0.320195258, -0.107624263], [0.134614825, 0.137890339, 0.220042884], [-0.0685030818, 0.0266759694, -0.279772133], [-0.123277575, -0.130606934, 0.155195773]]> : tensor<10x3xf32>
    %0 = "tfl.no_value"() {value} : () -> none
    %cst_10 = arith.constant dense<0.000000e+00> : tensor<1x10xf32>
    %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %cst_6, %cst_7, %cst_8, %cst_9, %cst_0, %cst_1, %cst_2, %cst_3, %0, %0, %0, %cst_4, %cst_5, %cst_4, %cst_4, %0, %0, %cst, %cst_10, %0, %0, %0, %0) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<1x5x3xf32>, tensor<10x3xf32>, tensor<10x3xf32>, tensor<10x3xf32>, tensor<10x3xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<1x10xf32>, tensor<1x10xf32>, none, none, none, none) -> tensor<1x5x10xf32>
    return %1 : tensor<1x5x10xf32>
  }
}

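Here `%cst_10`, the InputCellState operand, is all zeros, so the MLIR produced directly from the SavedModel is correct. The error is therefore introduced after this stage. It arises either when the converter serializes the TFLite flatbuffer, or when `flatbuffer_translate` imports the `.tflite` file back into MLIR (the suspected culprit).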

Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions