@@ -47,20 +47,26 @@ def __init__(self, module: nn.Module, module_output_size: int):
47
47
self .module = module
48
48
self .module_output_size = module_output_size
49
49
50
def forward(
    self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs
) -> FloatTensor:
    """Apply ``self.module`` to ``values`` grouped by sequence id.

    Batch entries that share a ``sequence_ids`` value are flattened into one
    long sequence, run through ``self.module`` together, and the outputs are
    scattered back to their original batch positions.

    Args:
        values: Input tensor; indexed as ``values[batch, step, feature]``, so
            it is assumed to be ``(batch, sequence_length, input_size)`` —
            TODO confirm against callers.
        sequence_ids: 1-D tensor of length ``batch`` assigning each batch
            entry to a sequence group (presumably — verify against caller).
        *args: Extra positional arguments forwarded to ``self.module``.
        **kwargs: Extra keyword arguments forwarded to ``self.module``.

    Returns:
        Tensor of shape ``(batch, sequence_length, self.module_output_size)``
        on the same device as ``values``.
    """
    results = torch.zeros(
        values.size(0), values.size(1), self.module_output_size, device=values.device
    )
    for seq_idx in torch.unique(sequence_ids):
        # get values for the current sequence (from multiple batch entries)
        mask = sequence_ids == seq_idx
        # shape: (num_selected, sequence_length, input_size)
        selected_values = values[mask]
        # flatten the batch dimension
        concatenated_sequence = selected_values.view(-1, selected_values.size(-1))
        # the module is called on a singleton-batch 3-D tensor:
        # (1, num_selected * sequence_length, input_size)
        # -> (1, num_selected * sequence_length, output_size), then the
        # singleton batch dim is dropped again
        processed_sequence = self.module(
            concatenated_sequence.unsqueeze(0), *args, **kwargs
        ).squeeze(0)
        # restore the batch dimension: (num_selected, sequence_length, output_size)
        reconstructed_sequence = processed_sequence.view(
            selected_values.size(0), selected_values.size(1), processed_sequence.size(-1)
        )
        # store the processed sequence back to the results tensor at the
        # correct batch indices
        results[mask] = reconstructed_sequence
    # NOTE(review): the visible diff hunk ends without a return statement even
    # though the method is annotated -> FloatTensor; returning here. If the
    # original file already returns just past the hunk boundary, this is a
    # no-op fix — confirm against the full file.
    return results
0 commit comments