Skip to content

Commit 708bf37

Browse files
committed
Fix training of PARSeq model with pretrained weights
1 parent f6f072e commit 708bf37

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ The training script can train any supported model. You can override any configur
129129

130130
### Finetune using pretrained weights
131131
```bash
132-
./train.py pretrained=parseq-tiny # Not all experiments have pretrained weights
132+
./train.py +experiment=parseq-tiny pretrained=parseq-tiny # Not all experiments have pretrained weights
133133
```
134134

135135
### Train a model variant/preconfigured experiment

train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def main(config: DictConfig):
7777
model: BaseSystem = hydra.utils.instantiate(config.model)
7878
# If specified, use pretrained weights to initialize the model
7979
if config.pretrained is not None:
80-
model.load_state_dict(get_pretrained_weights(config.pretrained))
80+
m = model.model if config.model._target_.endswith('PARSeq') else model
81+
m.load_state_dict(get_pretrained_weights(config.pretrained))
8182
print(summarize(model, max_depth=2))
8283

8384
datamodule: SceneTextDataModule = hydra.utils.instantiate(config.data)

0 commit comments

Comments
 (0)