@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -36,39 +36,76 @@ void* SubInit(TfLiteContext* context, const char* buffer, size_t length) {
   return context->AllocatePersistentBuffer(context, sizeof(OpDataSub));
 }
 
-void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
-             const OpDataSub* data, const TfLiteEvalTensor* input1,
-             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-  tflite::ArithmeticParams op_params;
-  SetActivationParams(output_activation_min, output_activation_max, &op_params);
-  if (data->requires_broadcast) {
-    tflite::reference_ops::BroadcastSubSlow(
-        op_params, tflite::micro::GetTensorShape(input1),
-        tflite::micro::GetTensorData<float>(input1),
-        tflite::micro::GetTensorShape(input2),
-        tflite::micro::GetTensorData<float>(input2),
-        tflite::micro::GetTensorShape(output),
-        tflite::micro::GetTensorData<float>(output));
-  } else {
-    tflite::reference_ops::SubWithActivation(
-        op_params, tflite::micro::GetTensorShape(input1),
-        tflite::micro::GetTensorData<float>(input1),
-        tflite::micro::GetTensorShape(input2),
-        tflite::micro::GetTensorData<float>(input2),
-        tflite::micro::GetTensorShape(output),
-        tflite::micro::GetTensorData<float>(output));
+TfLiteStatus EvalSub(TfLiteContext* context, TfLiteNode* node,
+                     TfLiteSubParams* params, const OpDataSub* data,
+                     const TfLiteEvalTensor* input1,
+                     const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
+  switch (output->type) {
+    case kTfLiteFloat32: {
+      float output_activation_min, output_activation_max;
+      CalculateActivationRange(params->activation, &output_activation_min,
+                               &output_activation_max);
+      tflite::ArithmeticParams op_params;
+      SetActivationParams(output_activation_min, output_activation_max,
+                          &op_params);
+      if (data->requires_broadcast) {
+        tflite::reference_ops::BroadcastSubSlow(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<float>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<float>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<float>(output));
+      } else {
+        tflite::reference_ops::SubWithActivation(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<float>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<float>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<float>(output));
+      }
+    } break;
+    case kTfLiteInt32: {
+      int32_t output_activation_min, output_activation_max;
+      CalculateActivationRange(params->activation, &output_activation_min,
+                               &output_activation_max);
+      tflite::ArithmeticParams op_params;
+      SetActivationParams(output_activation_min, output_activation_max,
+                          &op_params);
+      if (data->requires_broadcast) {
+        tflite::reference_ops::BroadcastSubSlow(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<int32_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<int32_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int32_t>(output));
+      } else {
+        tflite::reference_ops::SubWithActivation(
+            op_params, tflite::micro::GetTensorShape(input1),
+            tflite::micro::GetTensorData<int32_t>(input1),
+            tflite::micro::GetTensorShape(input2),
+            tflite::micro::GetTensorData<int32_t>(input2),
+            tflite::micro::GetTensorShape(output),
+            tflite::micro::GetTensorData<int32_t>(output));
+      }
+    } break;
+    default:
+      MicroPrintf("Type %s (%d) not supported.",
+                  TfLiteTypeGetName(output->type), output->type);
+      return kTfLiteError;
   }
+
+  return kTfLiteOk;
 }
 
 TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
                               TfLiteSubParams* params, const OpDataSub* data,
                               const TfLiteEvalTensor* input1,
                               const TfLiteEvalTensor* input2,
                               TfLiteEvalTensor* output) {
-  tflite::ArithmeticParams op_params;
+  tflite::ArithmeticParams op_params = {};
   op_params.left_shift = data->left_shift;
   op_params.input1_offset = data->input1_offset;
   op_params.input1_multiplier = data->input1_multiplier;
@@ -147,8 +184,9 @@ TfLiteStatus SubEval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpDataSub& data = *(static_cast<const OpDataSub*>(node->user_data));
 
-  if (output->type == kTfLiteFloat32) {
-    EvalSub(context, node, params, &data, input1, input2, output);
+  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+    TF_LITE_ENSURE_OK(
+        context, EvalSub(context, node, params, &data, input1, input2, output));
   } else if (output->type == kTfLiteInt8 || output->type == kTfLiteInt16) {
     TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                                 input1, input2, output));
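
For context, the new kTfLiteInt32 case above dispatches to the same templated reference ops as the float path, just instantiated for int32_t, and EvalSub now returns a TfLiteStatus so unsupported output types can be rejected with kTfLiteError. Below is a minimal, hypothetical sketch (not part of this change) of exercising the int32 reference path directly; the include paths, the RuntimeShape initializer-list constructor, and the int32 SetActivationParams overload are assumptions based on the TFLite reference-ops headers rather than anything introduced by this patch.

#include <cstdint>
#include <limits>

#include "tensorflow/lite/kernels/internal/common.h"          // SetActivationParams (assumed path)
#include "tensorflow/lite/kernels/internal/reference/sub.h"   // reference_ops::SubWithActivation (assumed path)
#include "tensorflow/lite/kernels/internal/types.h"           // ArithmeticParams, RuntimeShape (assumed path)

// Hypothetical example: element-wise int32 subtraction with no fused
// activation, mirroring the non-broadcast branch of the new kTfLiteInt32 case.
void Int32SubExample() {
  const int32_t input1[4] = {10, 20, 30, 40};
  const int32_t input2[4] = {1, 2, 3, 4};
  int32_t output[4] = {};

  const tflite::RuntimeShape shape({1, 1, 1, 4});

  tflite::ArithmeticParams op_params = {};
  // kTfLiteActNone corresponds to the full int32 range, which is what
  // CalculateActivationRange produces inside the kernel.
  tflite::SetActivationParams(std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max(), &op_params);

  tflite::reference_ops::SubWithActivation(op_params, shape, input1, shape,
                                           input2, shape, output);
  // output now holds {9, 18, 27, 36}.
}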