-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wenet] nn context biasing #1982
base: main
Are you sure you want to change the base?
Conversation
…tion problem due to context mismatch.
wenet/transformer/context_module.py
Outdated
_, last_state = self.sen_rnn(pack_seq) | ||
laste_h = last_state[0] | ||
laste_c = last_state[1] | ||
state = torch.cat([laste_h[-1, :, :], laste_h[0, :, :], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi,这里的实现是最后一层BLSTM的reverse last_h_state和第一层的forward last_h_state?
torch.nn.LSTM
**h_n**: tensor of shape :math:
(D * \text{num_layers}, H_{out}) for unbatched input or :math:
(D * \text{num_layers}, N, H_{out})containing the final hidden state for each element in the sequence. When ``bidirectional=True``,
h_n will contain a concatenation of the final forward and reverse hidden states, respectively.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是我写错了,0应该改成-2,感谢指正
…hunk during bias module training
for utt_label in batch_label: | ||
st_index_list = [] | ||
for i in range(len(utt_label)): | ||
if '▁' not in symbol_table: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我想请问下,这里如果我的建模单元是中文汉字+英文bpe,这里是不是不太适用,需要改下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,我自己训练的时候都是纯中文或者纯英文,英文在热词采样的时候对下划线特殊处理了下保证不会采样出半个词的情况,如果同时有中文和英文这部分最好是改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
谢谢~
可以提供一些模型训练时候的conf.yaml参数设置吗?谢谢 |
上面的模型链接中有我用的yaml文件,可以直接下载 |
我想请问下,我在aishell170小时上训练了deep biasing的模型,但是在解码的时候如果设置deep biasing,会出现很多的漏字现象,这个会是什么原因呀? |
漏字的现象很严重吗,使用的热词列表大小多大?我这边也有做过aishell1的实验,结果比较正常,没有观察到漏字的现象 |
很严重,就是一段一段的漏,原始设置的热词表大小是187,modelscope上开源的热词测试集,然后是设置了context_filtering参数进行过滤,如果过滤后热词表只有【0】的话,基本上就整句话漏了,如果是有热词的情况,也会出现成片漏掉的情况,设置的deep_score=1,filter_threshold=-4。目前训练迭代了17个epoch,loss_bias在10左右 |
那确实很奇怪,总体loss的情况正常吗,正常情况下收敛到差不多的时候,bias loss应该是和ctc loss差不多,总体的loss应该会比没有训练热词模块之前更低一些,在aishell上大概是3.4左右。你用的热词相关的yaml配置是否都和我上面给出的一致 |
还有就是我在做aishell1实验的时候发现对于aishell1这种句子大部分都很短的数据集,热词采样的代码需要去掉那个判断采样热词不能交叉的逻辑,不然很容易一句话只能采样出一个热词,这样训出来热词增强的效果会差一些,不过这个问题并不会导致漏字的情况。 |
目前训练出来整体的loss还算是正常,从3.1下降到了2.5,bias loss会比ctc loss高一些。我现在的热词配置就是您给的这个哈 |
会不会是你修改的热词采样部分的代码有点问题,我这边确实没遇到过你描述的状况,也想不出是什么原因,漏字而且还和传入的热词数量有关,理论上来说热词列表只剩个0应该对于正常解码的影响是最小的 |
您好,我尝试复现您在librispeech的结果,但是在训练热词增强模型时,出现cv loss值不下降的情况(保持在160多),并且train loss也是下降到四五十就不太下降了。 另外,我发现每次训练几个batch时,都会花五六分钟去训练下一个batch,正常情况我的显卡每训练一个batch的时间是30s左右,下面是一小段训练日志。。。 我没修改任何代码,训练conf文件也是您提供那个train_bias, 能大概分析下出现问题的原因吗? 谢谢! 2023-10-17 18:48:14,596 DEBUG TRAIN Batch 0/300 loss 77.121582 loss_att 68.588936 loss_ctc 90.912209 loss_bias 61.188702 lr 0.00001204 rank 3 |
你是不是直接从头开始训练了,为了减少对原本asr性能的影响,我写的是从一个预训练好的asr模型开始训,除了热词模块之外的参数都给冻结了。从头开始训应该也能够收敛,但是至少得把冻结的参数先解冻。 |
没有,也是用的之前在librispeech上预训练好的asr模型,做了参数冻结 |
我这边试着调整了下导出模块代码,我把forward_context_emb放到forward里面了,然后直接传热词列表,目前看导出是成功的:
|
导出肯定是没问题的,你试试onnx.run ,即导出代码里的 ort_outs = ort_session.run(None, ort_inputs) ,验证torch 和onnx一致性的那部分,然后你调整不同的context_list 的size,不一定是(200,10) ,当然如果你没用两阶段过滤算法,然后所有的音频都用同样的热词,那是没问题的,因为这个时候context的size是固定的。如果你的context 的size一直在变的话,你再试试看。 |
torch导出不是支持dynamic_axes吗,我这里是给了动态size的,后面的测试目前看是过了的 |
大佬,你尝试过改runtime代码吗,很难对原始代码不侵入,需要调整ForwardEncoderFunc, 加一个热词输入处理以及热词模块前向,主要ForwardEncoderFunc这个函数把ctc_prob计算也加进去了,其实每个模块抽出来会比较好改 |
你可以看下我仓库里的nn_bias 分支,我已经把代码合进来了,我做了如下改动:
|
好的,感谢大佬分享哈,我正在尝试改runtime,改完了也分享一下 |
@dahu1 大佬这个模型导出确实是有维度问题的,不过torch导出是支持dynamic_axes,这个确实是有效的;我目前定位到cppn模型内部使用了torch.nn.utils.rnn.pack_padded_sequence,把数据处理成变长的了,这个代码删了,动态维度导出是正常的,但是会对效果有明显影响,我理解可能是引入了padding的数据做embedding计算了: wenet/wenet/transformer/context_module.py Line 48 in 762e199
|
找到一个讨论的帖子,按帖子里这么做确实可以正常导出,onnxruntime可以跑通,就是会报warning;python代码最后导出的检查部分精度差距会比较大:17%,我看了下cer也会损失绝对0.1%左右;我在尝试用mask的形式去处理数据 |
@dahu1 大佬,我这尝试支持了下cppn的onnxruntime和模型导出,可以做个参考:https://github.com/fclearner/wenet/tree/nn_bias |
这个问题把lstm改成单向,然后不使用torch.nn.utils.rnn.pack_padded_sequence就行了,维度要做下调整,我看了下funasr就是这么做的,直接根据热词列表长度去lstm状态 |
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
好的,感谢大佬分享,我学习一下。确实单个推理会比较慢,如果能做到转onnx进行batch推理,会快很多。另外你说的把lstm改成单向的,那是不是性能会有损啊? |
我已经把新改好的代码提交了,
我这边刚做完测试,lstm单向相较双向确实热词召回率变差了六七个点,目前来看解决变长的onnx导出问题有两种策略:1、参考funasr,使用单向的lstm,状态索引用热词列表长度;2、使用torch.jit.script先转成静态图,但是运行的时候会报warning,需要调整onnx的日志屏蔽级别;我刚更新了第一种的代码,稍后我把第二种也commit一下;如果有大佬有更好的方法也可以回复下 |
更改为单向lstm后,是不是context 模型需要重新训练?毕竟训练和推理要保持一致。 |
是的,需要重新训练,主要是因为lstm的输出需要接一个context_encoder映射成embedding_size,lstm的单向状态仍然是可以通过热词列表长度索引的,我已经更新了双向的模型导出代码了,你先试试吧,因为我是直接copy到github的,没有做测试,可能会有问题:https://github.com/fclearner/wenet/blob/nn_bias/wenet/bin/export_onnx_cpu.py |
记录下问题, 问题二: 问题三: |
请问您跑这个huggingface上的bwer计算脚本跑通了吗,我跑出来一直不对,wer非常高 |
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
你好请问您漏字问题解决了吗,我也碰到了同样的问题 |
您好,您的邮件我已经收到,会尽快给您回复!祝您生活愉快,工作顺利!
|
是不是没有用全量数据微调,我之前试过小数据量会引入通用效果损失 |
我也是在aishell170小时数据集上训练了30轮进行测试的 |
初始模型也是aishell训练的嘛,把deep biasing的权重调小看看 |
您好,我使用wenetspeech预训练的模型进行微调,但是wenetspeech数据集比较大我没有使用全量数据,只下载了大概40G的数据,目前训练了60轮,但是感觉热词增强的效果一般,较训练前没有什么提升,是训练的不够吗还是其他原因,下面截取了一小部分训练日志,可以帮忙看一下损失值都正常吗,cv_loss也基本在12左右波动,目前还在继续训练观察。 |
全量数据微调,轮次不需要太大,越靠前通用效果损失越小,但是热词效果越差,我一般都是取前七轮做个avg的,你可以拿靠前的轮次测测看热词的效果,测的时候尝试对比下不同热词权重的差异 |
好的,谢谢回复,我去试试 |
请问有没有遇到过导出整个识别+偏置模型的onnx时,onnx图中LSTM层输入的热词表大小会固定不变的问题 |
建议热词模块拆出来导,这样还能节省推理资源,lstm那块需要设置动态维度,而且建议用单向lstm,通过热词列表长度索引状态 |
好的,谢谢您的建议 :) |
The Deep biasing method comes from: https://arxiv.org/abs/2305.12493
The pre-trained ASR model is fine-tuned to achieve biasing. During the training process, the original parameters of the ASR model are frozen, and only the parameters related to deep biasing are trained. use_dynamic_chunk cannot be enabled during fine-tuning (the biasing effect will decrease), but the biasing effects of streaming and non-streaming inference are basically the same.
RESULT:
Model link: https://huggingface.co/kxhuang/Wenet_Librispeech_deep_biasing/tree/main
(I used the BLSTM forward state incorrectly when training this model, so to test this model you need to change the -2 to 0 in the forward function of the BLSTM class in wenet/transformer/context_module.py)
Using the Wenet Librispeech pre-trained AED model, after fine-tuning for 30 epochs, the final model was obtained with an average of 3 epochs. The following are the test results of the Librispeech test other.
The context list for the test set is sourced from: https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias
Non-streaming inference:
+ deep biasing
+ deep biasing
Streaming inference (chunk 16):
+ deep biasing