Skip to content

Commit a914614 — "fix reshaping" (1 parent: e871ace)

File tree

1 file changed: +10 −4 lines changed

src/pie_modules/models/components/seq2seq_encoder.py

+10-4
Original file line numberDiff line numberDiff line change
```diff
@@ -47,20 +47,26 @@ def __init__(self, module: nn.Module, module_output_size: int):
         self.module = module
         self.module_output_size = module_output_size
 
-    def forward(self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs) -> Tensor:
+    def forward(
+        self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs
+    ) -> FloatTensor:
         results = torch.zeros(
             values.size(0), values.size(1), self.module_output_size, device=values.device
         )
         for seq_idx in torch.unique(sequence_ids):
             # get values for the current sequence (from multiple batch entries)
             mask = sequence_ids == seq_idx
+            # shape: (num_selected, sequence_length, input_size)
             selected_values = values[mask]
             # flatten the batch dimension
             concatenated_sequence = selected_values.view(-1, selected_values.size(-1))
-            processed_sequence = self.module(concatenated_sequence.unsqueeze(0), *args, **kwargs)
-            # restore the batch dimension
+            # (num_selected * sequence_length, input_size) -> (num_selected * sequence_length, output_size)
+            processed_sequence = self.module(
+                concatenated_sequence.unsqueeze(0), *args, **kwargs
+            ).squeeze(0)
+            # restore the batch dimension: (num_selected, sequence_length, output_size)
             reconstructed_sequence = processed_sequence.view(
-                selected_values.size(), processed_sequence.size(-1)
+                selected_values.size(0), selected_values.size(1), processed_sequence.size(-1)
             )
             # store the processed sequence back to the results tensor at the correct batch indices
             results[mask] = reconstructed_sequence
```

0 commit comments

Comments
 (0)