Skip to content

Commit 340885e

Browse files
authored
Merge pull request #101 from a-r-j/amorehead-decoder-disable
Add ability for `BenchmarkModel` to have its decoder disabled
2 parents 61294d4 + e536f72 commit 340885e

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* Add support for handling training/validation OOMs gracefully [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)
1616
* Add support for handling backward OOMs gracefully [#83](https://github.com/a-r-j/ProteinWorkshop/pull/83)
1717
* Update GCPNet paper link [#85](https://github.com/a-r-j/ProteinWorkshop/pull/85)
18+
* Add ability for `BenchmarkModel` to have its decoder disabled [#101](https://github.com/a-r-j/ProteinWorkshop/pull/101)
1819

1920
### Framework
2021

proteinworkshop/models/base.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,13 @@ def __init__(self, cfg: DictConfig) -> None:
412412
self.encoder: nn.Module = hydra.utils.instantiate(cfg.encoder)
413413
logger.info(self.encoder)
414414

415-
logger.info("Instantiating decoders...")
416-
self.decoder: nn.ModuleDict = self._build_output_decoders()
417-
logger.info(self.decoder)
415+
if hasattr(cfg.decoder, "disable") and cfg.decoder.disable:
416+
logger.info("Disabling decoder as requested")
417+
self.decoder = None
418+
else:
419+
logger.info("Instantiating decoders...")
420+
self.decoder: nn.ModuleDict = self._build_output_decoders()
421+
logger.info(self.decoder)
418422

419423
logger.info("Instantiating losses...")
420424
self.losses = self.configure_losses(cfg.task.losses)

0 commit comments

Comments
 (0)