@@ -16,22 +16,12 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_MICRO_KERNELS_BATCH_MATMUL_H_
 #define TENSORFLOW_LITE_MICRO_KERNELS_BATCH_MATMUL_H_
 
-#include <cstdint>
-
 #include "tensorflow/lite/c/builtin_op_data.h"
-#include "tensorflow/lite/kernels/internal/reference/transpose.h"
 #include "tensorflow/lite/kernels/internal/types.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-#include "tensorflow/lite/micro/kernels/kernel_util.h"
 #include "tensorflow/lite/micro/micro_common.h"
-#include "tensorflow/lite/micro/micro_log.h"
 
 namespace tflite {
 
-extern constexpr int kBatchMatmulInputLhsTensor = 0;
-extern constexpr int kBatchMatmulInputRhsTensor = 1;
-extern constexpr int kBatchMatmulOutputTensor = 0;
-
 struct QuantizationOpDataBatchMatmul {
   // The scaling factor from input to output (aka the 'real multiplier') can
   // be represented as a fixed point multiplier plus a left shift.
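
For reference, a minimal sketch (not part of this diff, helper name `DecomposeBatchMatmulMultiplier` is hypothetical) of how such a 'real multiplier' is usually computed from the input/output scales and decomposed into the fixed-point multiplier plus left shift, assuming the stock `tflite::QuantizeMultiplier` helper from `tensorflow/lite/kernels/internal/quantization_util.h`:

```cpp
#include <cstdint>

#include "tensorflow/lite/kernels/internal/quantization_util.h"

// Sketch only: derive and decompose the batch-matmul rescaling factor.
void DecomposeBatchMatmulMultiplier(float lhs_scale, float rhs_scale,
                                    float output_scale,
                                    int32_t* output_multiplier,
                                    int* output_shift) {
  // real_multiplier = (lhs_scale * rhs_scale) / output_scale; it is stored as
  // a Q31 fixed-point multiplier plus a left shift so the kernel can rescale
  // the int32 accumulator without floating point at inference time.
  const double real_multiplier =
      static_cast<double>(lhs_scale) * rhs_scale / output_scale;
  tflite::QuantizeMultiplier(real_multiplier, output_multiplier, output_shift);
}
```
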
@@ -59,98 +49,29 @@ struct OpDataBatchMatmul {
   bool rhs_is_constant_tensor;
 };
 
+extern const int kBatchMatmulInputLhsTensor;
+extern const int kBatchMatmulInputRhsTensor;
+extern const int kBatchMatmulOutputTensor;
+
 TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, TfLiteNode* node,
                                  const RuntimeShape& extended_lhs_shape,
                                  const RuntimeShape& extended_rhs_shape,
                                  bool adj_x, bool adj_y, int output_rank,
-                                 TfLiteTensor* output) {
-  int64_t orig_size = NumElements(output);
-
-  // make sure the new output dims rank does not exceed the original rank
-  TF_LITE_ENSURE(context, output_rank <= NumDimensions(output));
-
-  // make sure output tensor dims are not in the FlatBuffer
-  TfLiteEvalTensor* output_eval =
-      tflite::micro::GetEvalOutput(context, node, kBatchMatmulOutputTensor);
-  TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
-                                 context, output, output_eval));
-
-  // Fill in any broadcast dimensions.
-  for (int i = 0; i < output_rank - 2; ++i) {
-    const int lhs_dim = extended_lhs_shape.Dims(i);
-    const int rhs_dim = extended_rhs_shape.Dims(i);
-    int broadcast_dim = lhs_dim;
-    if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
-      broadcast_dim = rhs_dim;
-    }
-    output->dims->data[i] = broadcast_dim;
-  }
-  // Fill in the matmul dimensions.
-  int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
-  int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
-
-  output->dims->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
-  output->dims->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
-  output->dims->size = output_rank;
-
-  // Check that output tensor has not been resized
-  // since TFLM doesn't support tensor resizing.
-  TF_LITE_ENSURE_EQ(context, orig_size, NumElements(output));
-
-  return kTfLiteOk;
-}
+                                 TfLiteTensor* output);
 
 template <typename T>
 void TransposeRowsColumnsImpl(const TfLiteEvalTensor& tensor_in,
-                              TfLiteEvalTensor* tensor_out) {
-  const T* input = tflite::micro::GetTensorData<T>(&tensor_in);
-  T* output = tflite::micro::GetTensorData<T>(tensor_out);
-  RuntimeShape transposed_shape(tflite::micro::GetTensorShape(&tensor_in));
-  RuntimeShape shape(transposed_shape);
-  TransposeParams params;
-  const int rank = shape.DimensionsCount();
-  params.perm_count = rank;
-  for (int i = 0; i < rank - 2; ++i) {
-    params.perm[i] = i;
-  }
-  // Transpose the last two dimensions.
-  params.perm[rank - 2] = rank - 1;
-  params.perm[rank - 1] = rank - 2;
-  transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
-  transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
-  reference_ops::Transpose(params, shape, input, transposed_shape, output);
-}
+                              TfLiteEvalTensor* tensor_out);
 
 TfLiteStatus TransposeRowsColumns(const TfLiteEvalTensor& tensor_in,
-                                  TfLiteEvalTensor* tensor_out) {
-  if (tensor_in.type == kTfLiteFloat32) {
-    TransposeRowsColumnsImpl<float>(tensor_in, tensor_out);
-    return kTfLiteOk;
-  } else if (tensor_in.type == kTfLiteInt8) {
-    TransposeRowsColumnsImpl<int8_t>(tensor_in, tensor_out);
-    return kTfLiteOk;
-  } else if (tensor_in.type == kTfLiteInt16) {
-    TransposeRowsColumnsImpl<int16_t>(tensor_in, tensor_out);
-    return kTfLiteOk;
-  } else {
-    MicroPrintf(
-        "BATCH_MATMUL can only transpose tensors with FLOAT32, INT8, INT16 "
-        "type.");
-  }
-  return kTfLiteError;
-}
+                                  TfLiteEvalTensor* tensor_out);
 
-RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
-  RuntimeShape swapped_shape(shape);
-  const int32_t dims = shape.DimensionsCount();
-  swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
-  swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
-  return swapped_shape;
-}
+RuntimeShape SwapRowColumnDims(const RuntimeShape& shape);
 
 TFLMRegistration Register_BATCH_MATMUL();
 
 #if defined(CMSIS_NN)
+
 // Returns a TFLMRegistration struct for kernel variant that only supports
 // int8 matrix multiplication and uses the latency optimized
 // implementations.
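
With this change the tensor-index constants are only declared in the header; a minimal sketch (the providing .cc file is an assumption, not shown in this diff) of the out-of-line definitions the new `extern const` declarations pair with, using the values from the removed `extern constexpr` lines:

```cpp
// Hypothetical companion source file, e.g. batch_matmul_common.cc.
#include "tensorflow/lite/micro/kernels/batch_matmul.h"

namespace tflite {

// Single definitions of the tensor indices declared extern in the header.
const int kBatchMatmulInputLhsTensor = 0;
const int kBatchMatmulInputRhsTensor = 1;
const int kBatchMatmulOutputTensor = 0;

}  // namespace tflite
```
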