Skip to content

Commit

Permalink
[EDIT] add finetune & fix GA model wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
thanhtvt committed May 1, 2023
1 parent 21d072c commit 2790955
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 21 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,5 @@ docker run -it --name uetasr --gpus all -v <workspace_dir>:/workspace uetasr:v1.
2. [TensorFlowASR: Almost State-of-the-art Automatic Speech Recognition in Tensorflow 2](https://github.com/TensorSpeech/TensorFlowASR)
3. [ESPNet: End-to-End Speech Processing Toolkit](https://github.com/espnet/espnet)
4. [SpeechBrain: A PyTorch-based Speech Toolkit](https://github.com/speechbrain/speechbrain)
5. [Python module for evaluating ASR hypotheses](https://github.com/belambert/asr-evaluation)
5. [Python module for evaluating ASR hypotheses](https://github.com/belambert/asr-evaluation)
6. [Accumulated Gradients for TensorFlow 2](https://github.com/andreped/GradientAccumulator)
11 changes: 0 additions & 11 deletions egs/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@
help="Enable to evaluate loss in test data.",
)

parser.add_argument(
"--summary",
action="store_true",
help="Enable to print summary.",
)

parser.add_argument(
"--verbose",
action="store_true",
Expand All @@ -74,11 +68,6 @@
def test(config_file):
with open(config_file) as fin:
modules = load_hyperpyyaml(fin)

if args.summary:
model = modules['model']
model.summary()

test_loader = modules['test_loader']
trainer = modules['trainer']

Expand Down
1 change: 0 additions & 1 deletion egs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def train(config_file):
with open(config_file) as fin:
modules = load_hyperpyyaml(fin)
model = modules['model']
model.summary()
train_loader = modules['train_loader']
dev_loader = modules['dev_loader']
cmvn_loader = None
Expand Down
1 change: 0 additions & 1 deletion egs/vlsp2022/conformer/v3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ callbacks:
## =================== TRAINER =================== ##
trainer: !new:uetasr.trainers.trainer.ASRTrainer
model: !ref <model>
learning_rate: !ref <lr>
beam_decoder: !ref <decoder>
optimizer: !ref <optimizer>
losses: [!ref <rnnt_loss>]
Expand Down
24 changes: 17 additions & 7 deletions uetasr/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class ASRTrainer(BaseTrainer):
def __init__(
self,
model: tf.keras.Model,
learning_rate: Union[float, LearningRateSchedule],
beam_decoder: tf.keras.layers.Layer,
optimizer: tf.keras.optimizers.Optimizer,
log_append: bool = False,
Expand All @@ -27,6 +26,7 @@ def __init__(
loss_weights: List[float] = [],
metrics: List[tf.keras.metrics.Metric] = [],
num_epochs: int = 1,
finetune: bool = False,
jit_compile: bool = False,
steps_per_execution: int = 1,
callbacks: List[tf.keras.callbacks.Callback] = [],
Expand All @@ -38,19 +38,29 @@ def __init__(
if accum_steps > 1 and has_devices("GPU"):
if get_num_devices("GPU") > 1:
optimizer = GradientAccumulator(optimizer, accum_steps)
# elif get_num_devices("GPU") == 1: # GA model is not stable multi-gpus
# model = GradientAccumulateModel(accum_steps=accum_steps,
# mixed_precision=False,
# use_agc=True,
# inputs=model.input,
# outputs=model.output)
elif get_num_devices("GPU") == 1: # GA model is not stable multi-gpus
model.summary() # this is necessary to build model
model = GradientAccumulateModel(accum_steps=accum_steps,
mixed_precision=False,
use_agc=True,
inputs=model.input,
outputs=model.output)

self.optimizer = optimizer
self.model = model

if pretrained_model:
self.load_model(pretrained_model)

if finetune:
# freeze every layer that precedes "rnnt_jointer"; that layer is kept trainable and layers after it are not touched
for layer in self.model.layers:
layer.trainable = False
if layer.name == "rnnt_jointer":
layer.trainable = True
break

self.model.summary()
self.model.compile(loss=losses,
loss_weights=loss_weights,
optimizer=optimizer,
Expand Down

0 comments on commit 2790955

Please sign in to comment.