Skip to content

ASR,ST and CS recipies #1307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
45 changes: 33 additions & 12 deletions lhotse/dataset/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
lid: bool = False,
):
"""
k2 ASR IterableDataset constructor.
Expand All @@ -78,13 +79,15 @@ def __init__(
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
:param lid: adding lid information to the batch.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
self.lid = lid

# This attribute is a workaround to constantly growing HDF5 memory
# throughout the epoch. It regularly closes open file handles to
Expand Down Expand Up @@ -132,19 +135,37 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)

batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can just add ”language”: supervision.language, in the line below and always return it to get rid of the extra option and code duplication.

}
for sequence_idx, cut in enumerate(cuts)
if self.lid == True:
batch = {
"inputs": inputs,
"lids": [
supervision.language
for _, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
],
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
else:
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
Expand Down