1818 * @file Header definitions to include for esp_nn reference functions
1919 */
2020
21- #include "esp_nn_defs.h"
21+ #include <stdint.h>
22+
2223/************************** Basic math functions ****************************/
2324
2425/**
@@ -80,15 +81,28 @@ void esp_nn_mul_elementwise_s8_ansi(const int8_t *input1_data,
8081 * optimization notes: Though input_offset is int32 type,
8182 * offset values are contained in 8 bits [-128, 127]
8283 */
83- void esp_nn_depthwise_conv_s8_ansi (const data_dims_t * input_dims ,
84- const int8_t * input_data ,
85- const data_dims_t * filter_dims ,
84+ void esp_nn_depthwise_conv_s8_ansi (const int8_t * input_data ,
85+ const uint16_t input_wd ,
86+ const uint16_t input_ht ,
87+ const uint16_t channels ,
88+ const int32_t input_offset ,
89+ const uint16_t pad_wd ,
90+ const uint16_t pad_ht ,
91+ const uint16_t stride_wd ,
92+ const uint16_t stride_ht ,
93+ const uint16_t ch_mult ,
8694 const int8_t * filter_data ,
95+ const uint16_t filter_wd ,
96+ const uint16_t filter_ht ,
8797 const int32_t * bias ,
88- const data_dims_t * output_dims ,
8998 int8_t * out_data ,
90- const dw_conv_params_t * conv_params ,
91- const quant_data_t * quant_data );
99+ const uint16_t out_wd ,
100+ const uint16_t out_ht ,
101+ const int32_t out_offset ,
102+ const int32_t * out_shift ,
103+ const int32_t * out_mult ,
104+ const int32_t activation_min ,
105+ const int32_t activation_max );
92106
93107/**
94108 * @brief 2d-convolution channelwise
@@ -98,26 +112,43 @@ void esp_nn_depthwise_conv_s8_ansi(const data_dims_t *input_dims,
98112 * inputs type: int8_t, output: int8_t
99113 * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
100114 */
101- void esp_nn_conv_s8_ansi (const data_dims_t * input_dims ,
102- const int8_t * input_data ,
103- const data_dims_t * filter_dims ,
115+ void esp_nn_conv_s8_ansi (const int8_t * input_data ,
116+ const uint16_t input_wd ,
117+ const uint16_t input_ht ,
118+ const uint16_t in_channels ,
119+ const int32_t input_offset ,
120+ const uint16_t pad_wd ,
121+ const uint16_t pad_ht ,
122+ const uint16_t stride_wd ,
123+ const uint16_t stride_ht ,
104124 const int8_t * filter_data ,
125+ const uint16_t filter_wd ,
126+ const uint16_t filter_ht ,
105127 const int32_t * bias ,
106- const data_dims_t * output_dims ,
107128 int8_t * out_data ,
108- const conv_params_t * conv_params ,
109- const quant_data_t * quant_data );
110-
111- int esp_nn_get_conv_scratch_size_ansi (const data_dims_t * input_dims ,
112- const data_dims_t * filter_dims ,
113- const data_dims_t * output_dims ,
114- const conv_params_t * conv_params );
129+ const uint16_t out_wd ,
130+ const uint16_t out_ht ,
131+ const uint16_t out_channels ,
132+ const int32_t out_offset ,
133+ const int32_t * out_shift ,
134+ const int32_t * out_mult ,
135+ const int32_t activation_min ,
136+ const int32_t activation_max );
137+
138+ int esp_nn_get_conv_scratch_size_ansi (const uint16_t input_wd ,
139+ const uint16_t input_ht ,
140+ const uint16_t in_ch ,
141+ const uint16_t out_ch ,
142+ const uint16_t filter_wd ,
143+ const uint16_t filter_ht );
115144void esp_nn_set_conv_scratch_buf_ansi (const void * buf );
116145
117- int esp_nn_get_depthwise_conv_scratch_size_ansi (const data_dims_t * input_dims ,
118- const data_dims_t * filter_dims ,
119- const data_dims_t * output_dims ,
120- const dw_conv_params_t * conv_params );
146+ int esp_nn_get_depthwise_conv_scratch_size_ansi (const uint16_t input_wd ,
147+ const uint16_t input_ht ,
148+ const uint16_t channels ,
149+ const uint16_t ch_mult ,
150+ const uint16_t filter_wd ,
151+ const uint16_t filter_ht );
121152void esp_nn_set_depthwise_conv_scratch_buf_ansi (const void * buf );
122153
123154/************************** Activation functions *****************************/
@@ -221,6 +252,9 @@ int32_t esp_nn_get_softmax_scratch_size_opt(const int32_t width, const int32_t h
221252 */
222253void esp_nn_set_softmax_scratch_buf_ansi (void * buffer );
223254
255+ /* ANSI C function to be hooked up when optimised version needed */
256+ void esp_nn_set_softmax_scratch_buf_opt (void * buffer );
257+
224258/**
225259 * @brief reference softmax function
226260 *
@@ -234,66 +268,6 @@ void esp_nn_softmax_s8_ansi(const int8_t *input_data,
234268 const int32_t diff_min ,
235269 int8_t * output_data );
236270
237-
238- //////////////////////////// Generic optimisations /////////////////////////////
239-
240- /************************** Convolution functions *****************************/
241-
242- /**
243- * @brief 2d-convolution channelwise optimized version
244- *
245- * @note operation: result += (input + offset) * filter
246- *
247- * inputs type: int8_t, output: int8_t
248- * input offsets: although int32_t, they are contained in 8 bits [-128, 127]
249- */
250- void esp_nn_conv_s8_opt (const data_dims_t * input_dims ,
251- const int8_t * input_data ,
252- const data_dims_t * filter_dims ,
253- const int8_t * filter_data ,
254- const int32_t * bias ,
255- const data_dims_t * output_dims ,
256- int8_t * out_data ,
257- const conv_params_t * conv_params ,
258- const quant_data_t * quant_data );
259-
260- /**
261- * @brief depthwise convolution per channel optimized version
262- *
263- * @note inputs type: int8_t, output: int8_t
264- * Version used in tflite is per channel.
265- * This version follows the same footsprints.
266- * Meaning, it has per out_channel shift and multiplier for
267- * requantization
268- *
269- * optimization notes: Though input_offset is int32 type,
270- * offset values are contained in 8 bits [-128, 127]
271- */
272- void esp_nn_depthwise_conv_s8_opt (const data_dims_t * input_dims ,
273- const int8_t * input_data ,
274- const data_dims_t * filter_dims ,
275- const int8_t * filter_data ,
276- const int32_t * bias ,
277- const data_dims_t * output_dims ,
278- int8_t * out_data ,
279- const dw_conv_params_t * conv_params ,
280- const quant_data_t * quant_data );
281-
282- int esp_nn_get_conv_scratch_size_opt (const data_dims_t * input_dims ,
283- const data_dims_t * filter_dims ,
284- const data_dims_t * output_dims ,
285- const conv_params_t * conv_params );
286- void esp_nn_set_conv_scratch_buf_opt (const void * buf );
287-
288- int esp_nn_get_depthwise_conv_scratch_size_opt (const data_dims_t * input_dims ,
289- const data_dims_t * filter_dims ,
290- const data_dims_t * output_dims ,
291- const dw_conv_params_t * conv_params );
292- void esp_nn_set_depthwise_conv_scratch_buf_opt (const void * buf );
293-
294- /* ANSI C function to be hooked up when optimised version needed */
295- void esp_nn_set_softmax_scratch_buf_opt (void * buffer );
296-
297271/**
298272 * @brief optimised version of softmax function
299273 *
0 commit comments