
Commit 686df2d

Add INT32 support to SUB (#3037)
- Add INT32 support in sub
- Add TFLite tests in sub_test.cc

bug=fixes #2720
1 parent 0bf2956 commit 686df2d

File tree

tensorflow/lite/micro/kernels/sub.cc
tensorflow/lite/micro/kernels/sub_common.cc
tensorflow/lite/micro/kernels/sub_test.cc

3 files changed: +115, -30 lines changed

tensorflow/lite/micro/kernels/sub.cc

Lines changed: 66 additions & 28 deletions
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -36,39 +36,76 @@ void* SubInit(TfLiteContext* context, const char* buffer, size_t length) {
   return context->AllocatePersistentBuffer(context, sizeof(OpDataSub));
 }
 
-void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
-             const OpDataSub* data, const TfLiteEvalTensor* input1,
-             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-  tflite::ArithmeticParams op_params;
-  SetActivationParams(output_activation_min, output_activation_max, &op_params);
-  if (data->requires_broadcast) {
-    tflite::reference_ops::BroadcastSubSlow(
-        op_params, tflite::micro::GetTensorShape(input1),
-        tflite::micro::GetTensorData<float>(input1),
-        tflite::micro::GetTensorShape(input2),
-        tflite::micro::GetTensorData<float>(input2),
-        tflite::micro::GetTensorShape(output),
-        tflite::micro::GetTensorData<float>(output));
-  } else {
-    tflite::reference_ops::SubWithActivation(
-        op_params, tflite::micro::GetTensorShape(input1),
-        tflite::micro::GetTensorData<float>(input1),
-        tflite::micro::GetTensorShape(input2),
-        tflite::micro::GetTensorData<float>(input2),
-        tflite::micro::GetTensorShape(output),
-        tflite::micro::GetTensorData<float>(output));
+TfLiteStatus EvalSub(TfLiteContext* context, TfLiteNode* node,
+                     TfLiteSubParams* params, const OpDataSub* data,
+                     const TfLiteEvalTensor* input1,
+                     const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
+  switch (output->type) {
+    case kTfLiteFloat32: {
+      float output_activation_min, output_activation_max;
+      CalculateActivationRange(params->activation, &output_activation_min,
+                               &output_activation_max);
+      tflite::ArithmeticParams op_params;
+      SetActivationParams(output_activation_min, output_activation_max,
+                          &op_params);
+      if (data->requires_broadcast) {
+        tflite::reference_ops::BroadcastSubSlow(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<float>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<float>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<float>(output));
+      } else {
+        tflite::reference_ops::SubWithActivation(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<float>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<float>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<float>(output));
+      }
+    } break;
+    case kTfLiteInt32: {
+      int32_t output_activation_min, output_activation_max;
+      CalculateActivationRange(params->activation, &output_activation_min,
+                               &output_activation_max);
+      tflite::ArithmeticParams op_params;
+      SetActivationParams(output_activation_min, output_activation_max,
+                          &op_params);
+      if (data->requires_broadcast) {
+        tflite::reference_ops::BroadcastSubSlow(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<int32_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<int32_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int32_t>(output));
+      } else {
+        tflite::reference_ops::SubWithActivation(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<int32_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<int32_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int32_t>(output));
+      }
+    } break;
+    default:
+      MicroPrintf("Type %s (%d) not supported.",
+                  TfLiteTypeGetName(output->type), output->type);
+      return kTfLiteError;
   }
+
+  return kTfLiteOk;
 }
 
 TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
                               TfLiteSubParams* params, const OpDataSub* data,
                               const TfLiteEvalTensor* input1,
                               const TfLiteEvalTensor* input2,
                               TfLiteEvalTensor* output) {
-  tflite::ArithmeticParams op_params;
+  tflite::ArithmeticParams op_params = {};
   op_params.left_shift = data->left_shift;
   op_params.input1_offset = data->input1_offset;
   op_params.input1_multiplier = data->input1_multiplier;
@@ -147,8 +184,9 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpDataSub& data = *(static_cast<const OpDataSub*>(node->user_data));
 
-  if (output->type == kTfLiteFloat32) {
-    EvalSub(context, node, params, &data, input1, input2, output);
+  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+    TF_LITE_ENSURE_OK(
+        context, EvalSub(context, node, params, &data, input1, input2, output));
   } else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
     TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                                 input1, input2, output));
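Note on the new EvalSub: the float and INT32 branches are structurally identical; only the template argument to tflite::micro::GetTensorData<T>() changes, and both branches feed the same reference kernels. As a rough illustration of what reference_ops::SubWithActivation computes on the INT32 path, here is a minimal standalone sketch (the helper name is made up; this is not the TFLM implementation):

#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstdio>

// Sketch of the elementwise loop behind an int32_t sub-with-activation:
// output[i] = clamp(input1[i] - input2[i], activation_min, activation_max).
void SubWithActivationSketch(const int32_t* input1, const int32_t* input2,
                             int32_t* output, int flat_size,
                             int32_t activation_min, int32_t activation_max) {
  for (int i = 0; i < flat_size; ++i) {
    output[i] = std::min(std::max(input1[i] - input2[i], activation_min),
                         activation_max);
  }
}

int main() {
  const int32_t input1[] = {-2, 10, 7};
  const int32_t input2[] = {3, 4, 7};
  int32_t output[3];
  // kTfLiteActNone corresponds to the full int32_t range, so no clamping.
  SubWithActivationSketch(input1, input2, output, 3, INT32_MIN, INT32_MAX);
  for (int32_t v : output) {
    printf("%" PRId32 "\n", v);  // prints -5, 6, 0
  }
  return 0;
}

For kTfLiteActNone, CalculateActivationRange produces the widest representable range for the type, so the clamp above is effectively a no-op; fused activations such as kTfLiteActRelu would narrow it.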

tensorflow/lite/micro/kernels/sub_common.cc

Lines changed: 11 additions & 1 deletion
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -98,6 +98,16 @@ TfLiteStatus SubPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(
       CalculateOpDataSub(context, params, input1, input2, output, data));
 
+  if (output->type == kTfLiteInt32) {
+    // Only support INT32 unquantized SUB for now.
+    TF_LITE_ENSURE_EQ(context, input1->quantization.type,
+                      kTfLiteNoQuantization);
+    TF_LITE_ENSURE_EQ(context, input2->quantization.type,
+                      kTfLiteNoQuantization);
+    TF_LITE_ENSURE_EQ(context, output->quantization.type,
+                      kTfLiteNoQuantization);
+  }
+
   micro_context->DeallocateTempTfLiteTensor(input1);
   micro_context->DeallocateTempTfLiteTensor(input2);
   micro_context->DeallocateTempTfLiteTensor(output);
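Why this guard matters: an INT32 tensor in a .tflite model can still carry quantization parameters, but the new EvalSub path performs plain integer arithmetic with no rescaling, so quantized INT32 tensors must be rejected at prepare time. A hedged sketch of the same check factored into a helper (hypothetical function, not part of the TFLM API):

#include "tensorflow/lite/c/common.h"

// Hypothetical helper mirroring the guard in SubPrepare: succeed only if
// every tensor is unquantized. On a mismatch, TF_LITE_ENSURE_EQ reports the
// failure and returns kTfLiteError from this function.
TfLiteStatus EnsureAllUnquantized(TfLiteContext* context,
                                  const TfLiteTensor* const* tensors,
                                  int count) {
  for (int i = 0; i < count; ++i) {
    TF_LITE_ENSURE_EQ(context, tensors[i]->quantization.type,
                      kTfLiteNoQuantization);
  }
  return kTfLiteOk;
}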

tensorflow/lite/micro/kernels/sub_test.cc

Lines changed: 38 additions & 1 deletion
@@ -1,4 +1,4 @@
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -105,6 +105,29 @@ void TestSubFloat(int* input1_dims_data, const float* input1_data,
                      ElementCount(*output_dims), activation);
 }
 
+#if !defined(XTENSA)
+void TestSubInt32(int* input1_dims_data, const int32_t* input1_data,
+                  int* input2_dims_data, const int32_t* input2_data,
+                  int* output_dims_data, const int32_t* expected_output,
+                  TfLiteFusedActivation activation, int32_t* output_data) {
+  TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
+  TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  constexpr int inputs_size = 2;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(input1_data, input1_dims),
+      CreateTensor(input2_data, input2_dims),
+      CreateTensor(output_data, output_dims),
+  };
+
+  ValidateSubGoldens(tensors, tensors_size, expected_output, output_data,
+                     ElementCount(*output_dims), activation);
+}
+#endif
+
 template <typename T>
 void TestSubQuantized(int* input1_dims_data, const float* input1_data,
                       T* input1_quantized, float input1_scale,
@@ -219,6 +242,20 @@ TF_LITE_MICRO_TEST(FloatSubWithScalarBroadcast) {
   }
 }
 
+#if !defined(XTENSA)
+TF_LITE_MICRO_TEST(Int32SubNoActivation) {
+  int inout_shape[] = {4, 1, 2, 2, 1};
+  const int32_t input1_values[] = {-2, 2147483646, -1, 1146622854};
+  const int32_t input2_values[] = {3, 1, -2147483647, -726978367};
+  const int32_t golden_values[] = {-5, 2147483645, 2147483646, 1873601221};
+  const int kOutputDimsCount = 4;
+  int32_t output_data[kOutputDimsCount];
+  tflite::testing::TestSubInt32(inout_shape, input1_values, inout_shape,
+                                input2_values, inout_shape, golden_values,
+                                kTfLiteActNone, output_data);
+}
+#endif
+
 TF_LITE_MICRO_TEST(QuantizedSubNoActivationInt8) {
   const float scales[] = {0.25, 0.5, 1.0};
   const int zero_points[] = {-10, 4, 13};
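A detail worth flagging in the new test: IntArrayFromInts reads the leading element of the array as the rank, so inout_shape[] = {4, 1, 2, 2, 1} describes a rank-4 tensor of shape 1x2x2x1, i.e. four elements, matching the four values in each input and golden array. A minimal sketch of that encoding (illustrative helper using the same layout as TfLiteIntArray, not a TFLM function):

#include <cstdio>

// TfLiteIntArray-style layout: element 0 holds the rank, elements 1..rank
// hold the dimensions.
int FlatSizeFromDims(const int* dims_with_rank) {
  int count = 1;
  for (int i = 1; i <= dims_with_rank[0]; ++i) {
    count *= dims_with_rank[i];
  }
  return count;
}

int main() {
  int inout_shape[] = {4, 1, 2, 2, 1};  // rank 4, shape 1x2x2x1
  printf("flat size = %d\n", FlatSizeFromDims(inout_shape));  // flat size = 4
  return 0;
}

Note also that the large test values are chosen to land just inside the int32_t range: for example, -1 - (-2147483647) = 2147483646, which fits without overflow.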
