Skip to content

Commit 8ba8206

Browse files
committed
remove ctc
1 parent 8613abe commit 8ba8206

File tree

3 files changed

+31
-47
lines changed

3 files changed

+31
-47
lines changed

main.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
@hydra.main(config_path=args.cp, config_name=args.cn)
1919
def main(cfg: DictConfig):
2020
text_process = TextProcess(**cfg.text_process)
21-
if cfg.decoder.type == "beamsearch":
22-
ctc_decoder = CTCDecoder(text_process=text_process, **cfg.ctcdecoder)
23-
else:
24-
ctc_decoder = None
2521

2622
trainset = VivosDataset(**cfg.dataset, subset="train")
2723
testset = VivosDataset(**cfg.dataset, subset="test")
@@ -32,7 +28,6 @@ def main(cfg: DictConfig):
3228
model = DeepSpeechModule(
3329
n_class=n_class,
3430
text_process=text_process,
35-
ctc_decoder=ctc_decoder,
3631
cfg_optim=cfg.optimizer,
3732
**cfg.model
3833
)

model.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def __init__(
1616
n_class: int,
1717
lr: float,
1818
text_process: TextProcess,
19-
ctc_decoder: CTCDecoder,
2019
cfg_optim: dict,
2120
):
2221
super().__init__()
@@ -25,7 +24,6 @@ def __init__(
2524
)
2625
self.lr = lr
2726
self.text_process = text_process
28-
self.ctc_decoder = ctc_decoder
2927
self.cal_wer = torchmetrics.WordErrorRate()
3028
self.cfg_optim = cfg_optim
3129
self.criterion = nn.CTCLoss(zero_infinity=True)
@@ -63,12 +61,8 @@ def validation_step(self, batch, batch_idx):
6361
outputs.permute(1, 0, 2), targets, input_lengths, target_lengths
6462
)
6563

66-
if self.ctc_decoder:
67-
# unsqueeze for batchsize 1
68-
predicts = [self.ctc_decoder(sent.unsqueeze(0)) for sent in outputs]
69-
else:
70-
decode = outputs.argmax(dim=-1)
71-
predicts = [self.text_process.decode(sent) for sent in decode]
64+
decode = outputs.argmax(dim=-1)
65+
predicts = [self.text_process.decode(sent) for sent in decode]
7266

7367
targets = [self.text_process.int2text(sent) for sent in targets]
7468

@@ -92,12 +86,8 @@ def test_step(self, batch, batch_idx):
9286
outputs.permute(1, 0, 2), targets, input_lengths, target_lengths
9387
)
9488

95-
if self.ctc_decoder:
96-
# unsqueeze for batchsize 1
97-
predicts = [self.ctc_decoder(sent.unsqueeze(0)) for sent in outputs]
98-
else:
99-
decode = outputs.argmax(dim=-1)
100-
predicts = [self.text_process.decode(sent) for sent in decode]
89+
decode = outputs.argmax(dim=-1)
90+
predicts = [self.text_process.decode(sent) for sent in decode]
10191

10292
targets = [self.text_process.int2text(sent) for sent in targets]
10393

utils.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import ctcdecode
32

43

54
class TextProcess:
@@ -45,32 +44,32 @@ def int2text(self, s: torch.Tensor) -> str:
4544
return "".join([self.list_vocab[i] for i in s if i > 2])
4645

4746

48-
class CTCDecoder:
49-
def __init__(
50-
self,
51-
alpha: float = 0.5,
52-
beta: float = 0.96,
53-
beam_size: int = 100,
54-
kenlm_path: str = None,
55-
text_process: TextProcess = None,
56-
):
57-
self.text_process = text_process
58-
labels = text_process.list_vocab
59-
blank_id = labels.index("<p>")
47+
# class CTCDecoder:
48+
# def __init__(
49+
# self,
50+
# alpha: float = 0.5,
51+
# beta: float = 0.96,
52+
# beam_size: int = 100,
53+
# kenlm_path: str = None,
54+
# text_process: TextProcess = None,
55+
# ):
56+
# self.text_process = text_process
57+
# labels = text_process.list_vocab
58+
# blank_id = labels.index("<p>")
6059

61-
print("loading beam search with lm...")
62-
self.decoder = ctcdecode.CTCBeamDecoder(
63-
labels,
64-
alpha=alpha,
65-
beta=beta,
66-
beam_width=beam_size,
67-
blank_id=blank_id,
68-
model_path=kenlm_path,
69-
)
70-
print("finished loading beam search")
60+
# print("loading beam search with lm...")
61+
# self.decoder = ctcdecode.CTCBeamDecoder(
62+
# labels,
63+
# alpha=alpha,
64+
# beta=beta,
65+
# beam_width=beam_size,
66+
# blank_id=blank_id,
67+
# model_path=kenlm_path,
68+
# )
69+
# print("finished loading beam search")
7170

72-
def __call__(self, output: torch.Tensor) -> str:
73-
beam_result, beam_scores, timesteps, out_seq_len = self.decoder.decode(output)
74-
tokens = beam_result[0][0]
75-
seq_len = out_seq_len[0][0]
76-
return self.text_process.int2text(tokens[:seq_len])
71+
# def __call__(self, output: torch.Tensor) -> str:
72+
# beam_result, beam_scores, timesteps, out_seq_len = self.decoder.decode(output)
73+
# tokens = beam_result[0][0]
74+
# seq_len = out_seq_len[0][0]
75+
# return self.text_process.int2text(tokens[:seq_len])

0 commit comments

Comments
 (0)