Skip to content

Commit ef64591

Browse files
authored
Fix quant specific op registration for some ops (#2770)
BUG=Quantization-specific registration for BatchMatmul, SVDF, and LSTM was not working correctly.
1 parent 740cef3 commit ef64591

File tree

6 files changed

+140
-91
lines changed

6 files changed

+140
-91
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ tflm_kernel_cc_library(
222222
"arg_min_max.cc",
223223
"assign_variable.cc",
224224
"batch_matmul.cc",
225+
"batch_matmul_common.cc",
225226
"batch_to_space_nd.cc",
226227
"broadcast_args.cc",
227228
"broadcast_to.cc",
@@ -347,6 +348,7 @@ tflm_kernel_cc_library(
347348
"sub.h",
348349
"svdf.h",
349350
"transpose_conv.h",
351+
"unidirectional_sequence_lstm.h",
350352
] + select({
351353
xtensa_fusion_f1_config(): glob(["xtensa/**/*.h"]),
352354
xtensa_hifi_3_config(): glob(["xtensa/**/*.h"]),

tensorflow/lite/micro/kernels/batch_matmul.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ limitations under the License.
2424
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
2525
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
2626
#include "tensorflow/lite/kernels/internal/types.h"
27+
#include "tensorflow/lite/kernels/kernel_util.h"
2728
#include "tensorflow/lite/micro/kernels/batch_matmul.h"
29+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
2830
#include "tensorflow/lite/micro/micro_log.h"
2931

3032
namespace tflite {

tensorflow/lite/micro/kernels/batch_matmul.h

Lines changed: 9 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,12 @@ limitations under the License.
1616
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_BATCH_MATMUL_H_
1717
#define TENSORFLOW_LITE_MICRO_KERNELS_BATCH_MATMUL_H_
1818

19-
#include <cstdint>
20-
2119
#include "tensorflow/lite/c/builtin_op_data.h"
22-
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
2320
#include "tensorflow/lite/kernels/internal/types.h"
24-
#include "tensorflow/lite/kernels/kernel_util.h"
25-
#include "tensorflow/lite/micro/kernels/kernel_util.h"
2621
#include "tensorflow/lite/micro/micro_common.h"
27-
#include "tensorflow/lite/micro/micro_log.h"
2822

2923
namespace tflite {
3024

31-
extern constexpr int kBatchMatmulInputLhsTensor = 0;
32-
extern constexpr int kBatchMatmulInputRhsTensor = 1;
33-
extern constexpr int kBatchMatmulOutputTensor = 0;
34-
3525
struct QuantizationOpDataBatchMatmul {
3626
// The scaling factor from input to output (aka the 'real multiplier') can
3727
// be represented as a fixed point multiplier plus a left shift.
@@ -59,98 +49,29 @@ struct OpDataBatchMatmul {
5949
bool rhs_is_constant_tensor;
6050
};
6151

52+
extern const int kBatchMatmulInputLhsTensor;
53+
extern const int kBatchMatmulInputRhsTensor;
54+
extern const int kBatchMatmulOutputTensor;
55+
6256
TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
6357
const RuntimeShape& extended_lhs_shape,
6458
const RuntimeShape& extended_rhs_shape,
6559
bool adj_x, bool adj_y, int output_rank,
66-
TfLiteTensor* output) {
67-
int64_t orig_size = NumElements(output);
68-
69-
// make sure the new output dims rank does not exceed the original rank
70-
TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));
71-
72-
// make sure output tensor dims are not in the FlatBuffer
73-
TfLiteEvalTensor* output_eval =
74-
tflite::micro::GetEvalOutput(context, node, kBatchMatmulOutputTensor);
75-
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
76-
context, output, output_eval));
77-
78-
// Fill in any broadcast dimensions.
79-
for (int i = 0; i < output_rank - 2; ++i) {
80-
const int lhs_dim = extended_lhs_shape.Dims(i);
81-
const int rhs_dim = extended_rhs_shape.Dims(i);
82-
int broadcast_dim = lhs_dim;
83-
if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
84-
broadcast_dim = rhs_dim;
85-
}
86-
output->dims->data[i] = broadcast_dim;
87-
}
88-
// Fill in the matmul dimensions.
89-
int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
90-
int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
91-
92-
output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
93-
output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
94-
output->dims->size = output_rank;
95-
96-
// Check that output tensor has not been resized
97-
// since TFLM doesn't support tensor resizing.
98-
TF_LITE_ENSURE_EQ(context, orig_size, NumElements(output));
99-
100-
return kTfLiteOk;
101-
}
60+
TfLiteTensor* output);
10261

10362
template <typename T>
10463
void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
105-
TfLiteEvalTensor* tensor_out) {
106-
const T* input = tflite::micro::GetTensorData<T>(&tensor_in);
107-
T* output = tflite::micro::GetTensorData<T>(tensor_out);
108-
RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
109-
RuntimeShape shape(transposed_shape);
110-
TransposeParams params;
111-
const int rank = shape.DimensionsCount();
112-
params.perm_count = rank;
113-
for (int i = 0; i < rank - 2; ++i) {
114-
params.perm[i] = i;
115-
}
116-
// Transpose the last two dimensions.
117-
params.perm[rank - 2] = rank - 1;
118-
params.perm[rank - 1] = rank - 2;
119-
transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
120-
transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
121-
reference_ops::Transpose(params, shape, input, transposed_shape, output);
122-
}
64+
TfLiteEvalTensor* tensor_out);
12365

12466
TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
125-
TfLiteEvalTensor* tensor_out) {
126-
if (tensor_in.type == kTfLiteFloat32) {
127-
TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
128-
return kTfLiteOk;
129-
} else if (tensor_in.type == kTfLiteInt8) {
130-
TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
131-
return kTfLiteOk;
132-
} else if (tensor_in.type == kTfLiteInt16) {
133-
TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
134-
return kTfLiteOk;
135-
} else {
136-
MicroPrintf(
137-
"BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 "
138-
"type.");
139-
}
140-
return kTfLiteError;
141-
}
67+
TfLiteEvalTensor* tensor_out);
14268

143-
RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
144-
RuntimeShape swapped_shape(shape);
145-
const int32_t dims = shape.DimensionsCount();
146-
swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
147-
swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
148-
return swapped_shape;
149-
}
69+
RuntimeShape SwapRowColumnDims(const RuntimeShape& shape);
15070

15171
TFLMRegistration Register_BATCH_MATMUL();
15272

15373
#if defined(CMSIS_NN)
74+
15475
// Returns a TFLMRegistration struct for kernel variant that only supports
15576
// int8 matrix multiplication and uses the latency optimized
15677
// implementations.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cstdint>
17+
18+
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
19+
#include "tensorflow/lite/kernels/kernel_util.h"
20+
#include "tensorflow/lite/micro/kernels/batch_matmul.h"
21+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
22+
#include "tensorflow/lite/micro/micro_log.h"
23+
24+
namespace tflite {
25+
26+
const int kBatchMatmulInputLhsTensor = 0;
27+
const int kBatchMatmulInputRhsTensor = 1;
28+
const int kBatchMatmulOutputTensor = 0;
29+
30+
// Recomputes the output tensor's dimensions for BATCH_MATMUL from the
// (rank-extended) operand shapes, honoring the adjoint flags. TFLM does not
// support tensor resizing, so the total element count must stay unchanged.
TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
                                 const RuntimeShape& extended_lhs_shape,
                                 const RuntimeShape& extended_rhs_shape,
                                 bool adj_x, bool adj_y, int output_rank,
                                 TfLiteTensor* output) {
  // Element count before reshaping; verified unchanged at the end.
  const int64_t original_element_count = NumElements(output);

  // The new rank must not exceed the rank the output was allocated with.
  TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));

  // Make sure the output dims are writable (i.e. not backed by the
  // FlatBuffer), copying them out if necessary.
  TfLiteEvalTensor* output_eval =
      tflite::micro::GetEvalOutput(context, node, kBatchMatmulOutputTensor);
  TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
                                 context, output, output_eval));

  // Broadcast every leading (batch) dimension: where the two sides differ and
  // the LHS dim is 1, the RHS dim wins; otherwise the LHS dim is kept.
  for (int i = 0; i < output_rank - 2; ++i) {
    const int lhs_dim = extended_lhs_shape.Dims(i);
    const int rhs_dim = extended_rhs_shape.Dims(i);
    output->dims->data[i] =
        ((lhs_dim != rhs_dim) && (lhs_dim == 1)) ? rhs_dim : lhs_dim;
  }

  // The trailing two dimensions are the matmul result: LHS rows by RHS
  // columns, with adj_x/adj_y selecting the transposed axis.
  const int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
  const int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
  output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
  output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
  output->dims->size = output_rank;

  // Reject any reshape that changed the element count, since TFLM cannot
  // resize tensors.
  TF_LITE_ENSURE_EQ(context, original_element_count, NumElements(output));

  return kTfLiteOk;
}
70+
71+
template <typename T>
72+
void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
73+
TfLiteEvalTensor* tensor_out) {
74+
const T* input = tflite::micro::GetTensorData<T>(&tensor_in);
75+
T* output = tflite::micro::GetTensorData<T>(tensor_out);
76+
RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
77+
RuntimeShape shape(transposed_shape);
78+
TransposeParams params;
79+
const int rank = shape.DimensionsCount();
80+
params.perm_count = rank;
81+
for (int i = 0; i < rank - 2; ++i) {
82+
params.perm[i] = i;
83+
}
84+
// Transpose the last two dimensions.
85+
params.perm[rank - 2] = rank - 1;
86+
params.perm[rank - 1] = rank - 2;
87+
transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
88+
transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
89+
reference_ops::Transpose(params, shape, input, transposed_shape, output);
90+
}
91+
92+
// Dispatches TransposeRowsColumnsImpl on the tensor's element type.
// Returns kTfLiteError (after logging) for unsupported types.
TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
                                  TfLiteEvalTensor* tensor_out) {
  switch (tensor_in.type) {
    case kTfLiteFloat32:
      TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
      return kTfLiteOk;
    case kTfLiteInt8:
      TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
      return kTfLiteOk;
    case kTfLiteInt16:
      TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
      return kTfLiteOk;
    default:
      MicroPrintf(
          "BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 "
          "type.");
      break;
  }
  return kTfLiteError;
}
110+
111+
// Returns a copy of `shape` with its last two dimensions exchanged.
RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
  RuntimeShape swapped(shape);
  const int32_t rank = shape.DimensionsCount();
  swapped.SetDim(rank - 2, shape.Dims(rank - 1));
  swapped.SetDim(rank - 1, shape.Dims(rank - 2));
  return swapped;
}
118+
119+
} // namespace tflite

tensorflow/lite/micro/micro_mutable_op_resolver.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "tensorflow/lite/kernels/op_macros.h"
2525
#include "tensorflow/lite/micro/compatibility.h"
2626
#include "tensorflow/lite/micro/kernels/add.h"
27+
#include "tensorflow/lite/micro/kernels/batch_matmul.h"
2728
#include "tensorflow/lite/micro/kernels/conv.h"
2829
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
2930
#include "tensorflow/lite/micro/kernels/ethosu.h"
@@ -34,7 +35,9 @@ limitations under the License.
3435
#include "tensorflow/lite/micro/kernels/pooling.h"
3536
#include "tensorflow/lite/micro/kernels/reduce.h"
3637
#include "tensorflow/lite/micro/kernels/softmax.h"
38+
#include "tensorflow/lite/micro/kernels/svdf.h"
3739
#include "tensorflow/lite/micro/kernels/transpose_conv.h"
40+
#include "tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.h"
3841
#include "tensorflow/lite/micro/micro_log.h"
3942
#include "tensorflow/lite/micro/micro_op_resolver.h"
4043
#include "tensorflow/lite/schema/schema_generated.h"
@@ -146,9 +149,10 @@ class MicroMutableOpResolver : public MicroOpResolver {
146149
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, registration, ParsePool);
147150
}
148151

149-
TfLiteStatus AddBatchMatMul() {
150-
return AddBuiltin(BuiltinOperator_BATCH_MATMUL,
151-
tflite::Register_BATCH_MATMUL(), ParseBatchMatMul);
152+
TfLiteStatus AddBatchMatMul(
153+
const TFLMRegistration& registration = Register_BATCH_MATMUL()) {
154+
return AddBuiltin(BuiltinOperator_BATCH_MATMUL, registration,
155+
ParseBatchMatMul);
152156
}
153157

154158
TfLiteStatus AddBatchToSpaceNd() {

tensorflow/lite/micro/tools/make/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/add_n.cc \
365365
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/arg_min_max.cc \
366366
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/assign_variable.cc \
367367
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_matmul.cc \
368+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_matmul_common.cc \
368369
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/batch_to_space_nd.cc \
369370
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/broadcast_args.cc \
370371
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/broadcast_to.cc \

0 commit comments

Comments
 (0)