@@ -1,6 +1,6 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
-from hls4ml.model.layers import GRU, LSTM
+from hls4ml.model.layers import GRU, LSTM, TimeDistributed

 # recurrent multiplication template

@@ -237,3 +237,69 @@ def format(self, node):
             template = recr_function_template

         return template.format(**params)
+
+
+time_distributed_config_template = """struct config{index} : nnet::time_distributed_config {{
+    static const unsigned dim = {dim};
+
+    static const unsigned n_time_steps = {n_time_steps};
+    static const unsigned in_height = {in_height};
+    static const unsigned in_width = {in_width};
+    static const unsigned n_chan = {n_chan};
+}};\n"""
+
+time_distributed_loop_start_template = """for (int ts = 0; ts < config{index}::n_time_steps; ts++) {{
+    {loop_mode}
+    nnet::read_time_step_{dim}d<{input_t}, {config}>(ts, {input}, {output});"""
+
+time_distributed_loop_end_template = """    nnet::write_time_step_{dim}d<{output_t}, {config}>(ts, {input}, {output});
+}}"""
+
+time_distributed_include_list = ['nnet_utils/nnet_time_distributed.h']
+
+
+class TimeDistributedConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(TimeDistributed)
+        self.template = time_distributed_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+
+        input_shape = node.get_input_variable().shape
+        params['dim'] = len(input_shape)
+        if node.name.endswith('_end'):
+            params['dim'] += 1  # The input variable will be from the wrapped layer, without time dimension
+        params['in_height'] = input_shape[-3] if params['dim'] == 4 else 1
+        params['in_width'] = input_shape[-2] if params['dim'] >= 3 else 1
+        params['n_chan'] = input_shape[-1]
+
+        return self.template.format(**params)
+
+
+class TimeDistributedFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__((TimeDistributed), include_header=time_distributed_include_list)
+        self.template_start = time_distributed_loop_start_template
+        self.template_end = time_distributed_loop_end_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+
+        input_shape = node.get_input_variable().shape
+        params['dim'] = len(input_shape)
+        if node.name.endswith('_end'):
+            params['dim'] += 1  # The input variable will be from the wrapped layer, without time dimension
+
+        loop_mode = node.get_attr('time_step_loop_parallelism')
+        if loop_mode == 'unroll':
+            params['loop_mode'] = '#pragma HLS UNROLL'
+        elif loop_mode == 'pipeline':
+            params['loop_mode'] = '#pragma HLS PIPELINE'
+        else:
+            params['loop_mode'] = ''
+
+        if node.attributes['wrapped_layer'].name == node.name + '_end':
+            return self.template_start.format(**params)
+        else:
+            return self.template_end.format(**params)
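
For illustration (not part of the change above): the snippet below renders the two loop templates with made-up placeholder values, assuming the template strings defined in this file are in scope, to show the HLS code they expand to. In the generated project these values come from TimeDistributedFunctionTemplate.format() on the '_start' and '_end' nodes, the wrapped layer's own function call is emitted between the two fragments, and the read/write helpers come from nnet_utils/nnet_time_distributed.h.

# Hypothetical illustration: the index, type names, variable names and the chosen
# pragma below are invented for the example; only the template strings are real.
example = {
    'index': 3,
    'config': 'config3',
    'dim': 3,
    'loop_mode': '#pragma HLS PIPELINE',
    'input_t': 'layer2_t',
    'output_t': 'layer3_t',
    'input': 'layer2_out',
    'output': 'step_in',
}
print(time_distributed_loop_start_template.format(**example))
print('    // ... wrapped layer call (e.g. a Dense function call) goes here ...')
print(time_distributed_loop_end_template.format(**{**example, 'input': 'step_out', 'output': 'layer3_out'}))

# Expected output:
# for (int ts = 0; ts < config3::n_time_steps; ts++) {
#     #pragma HLS PIPELINE
#     nnet::read_time_step_3d<layer2_t, config3>(ts, layer2_out, step_in);
#     // ... wrapped layer call (e.g. a Dense function call) goes here ...
#     nnet::write_time_step_3d<layer3_t, config3>(ts, step_out, layer3_out);
# }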