Skip to content

Commit ca0f770

Browse files
a-r-jArian Jamasb
and
Arian Jamasb
authored
Minor hotfixes for data infrastructure #23 (#24)
* add conditional download flow logic #23 * correct references to processed files * fix reference to processed dir; update changelog * expose overwrite arg to users * fix dummy datamodule * add missing enumerate * fix explicit download loop skip * Add missing return * fix in memory dataloading bug * change how foldcomp downloads occur * inc graphein dependecy version; bump version * bump docs version * update changelog --------- Co-authored-by: Arian Jamasb <[email protected]>
1 parent 07404a3 commit ca0f770

38 files changed

+163
-24
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
### 2.0.2 (Unreleased)
2+
3+
* Fixes raw data download triggered by absence of PDB when using pre-processed datasets ([#24](https://github.com/a-r-j/ProteinWorkshop/pull/24))
4+
* Fixes bug where batches created from `in_memory=True` data were not correctly formatted ([#24](https://github.com/a-r-j/ProteinWorkshop/pull/24))
5+
* Consistently exposes the `overwrite` argument for datamodules to users ([#24](https://github.com/a-r-j/ProteinWorkshop/pull/24))
6+
* Fixes bug where downloading FoldComp datasets into directories with the same name as the dataset throws an error ([#24](https://github.com/a-r-j/ProteinWorkshop/pull/24))
7+
* Increments `graphein` dependency to `1.7.3` ([#24](https://github.com/a-r-j/ProteinWorkshop/pull/24))
8+
19
### 2.0.1 (29/08/2023)
210

311
* Fixes incorrect lookup of `DATA_PATH` env var ([#19](https://github.com/a-r-j/ProteinWorkshop/pull/19))

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
1010
project = "Protein Workshop"
1111
author = "Arian R. Jamasb"
12-
release = "0.2.1"
12+
release = "0.2.2"
1313
copyright = f"{datetime.datetime.now().year}, {author}"
1414

1515
# -- General configuration ---------------------------------------------------

poetry.lock

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

proteinworkshop/config/dataset/antibody_developability.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ datamodule:
99
format: "mmtf" # Format of the structure files
1010
obsolete_strategy: "drop" # What to do with obsolete PDB entries
1111
transforms: ${transforms} # Transforms to apply to dataset examples
12+
overwrite: False
1213
num_classes: 2 # Number of classes

proteinworkshop/config/dataset/cath.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ datamodule:
88
batch_size: 32 # Batch size for dataloader
99
dataset_fraction: 1.0 # Fraction of the dataset to use
1010
transforms: ${transforms} # Transforms to apply to dataset examples
11+
overwrite: False # Whether to overwrite the dataset if it already exists
1112
num_classes: 23 # Number of classes

proteinworkshop/config/dataset/ccpdb_ligands.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ datamodule:
1414
val_fraction: 0.1 # Fraction of the dataset to use for validation
1515
test_fraction: 0.1 # Fraction of the dataset to use for testing
1616
transforms: ${transforms}
17+
overwrite: False # Whether to overwrite the dataset if it already exists
1718

1819
num_classes: 7 # Number of classes

proteinworkshop/config/dataset/ccpdb_metal.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ datamodule:
1414
val_fraction: 0.1 # Fraction of the dataset to use for validation
1515
test_fraction: 0.1 # Fraction of the dataset to use for testing
1616
transforms: ${transforms}
17+
overwrite: False # Whether to overwrite the dataset if it already exists
1718

1819
num_classes: 7 # Number of classes

proteinworkshop/config/dataset/ccpdb_nucleic.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ datamodule:
1414
val_fraction: 0.1 # Fraction of the dataset to use for validation
1515
test_fraction: 0.1 # Fraction of the dataset to use for testing
1616
transforms: ${transforms}
17+
overwrite: False # Whether to overwrite the dataset if it already exists
1718

1819
num_classes: 2 # Number of classes

proteinworkshop/config/dataset/ccpdb_nucleotides.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ datamodule:
1414
val_fraction: 0.1 # Fraction of the dataset to use for validation
1515
test_fraction: 0.1 # Fraction of the dataset to use for testing
1616
transforms: ${transforms}
17+
overwrite: False # Whether to overwrite the dataset if it already exists
1718

1819
num_classes: 8 # Number of classes

proteinworkshop/config/dataset/deep_sea_proteins.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ datamodule:
99
obsolete_strategy: "drop"
1010
format: "mmtf" # Format of the raw PDB/MMTF files
1111
transforms: ${transforms}
12+
overwrite: False
1213
num_classes: 2

proteinworkshop/config/dataset/dummy.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ datamodule:
88
obsolete_strategy: "drop"
99
format: "mmtf.gz" # Format of the raw PDB/MMTF files
1010
transforms: ${transforms}
11+
overwrite: True
1112

1213
num_classes: 2 # Number of classes in the dataset

proteinworkshop/config/dataset/ec_reaction.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ datamodule:
99
dataset_fraction: 1.0 # Fraction of the dataset to use
1010
shuffle_labels: False # Whether to shuffle labels for permutation testing
1111
transforms: ${transforms}
12+
overwrite: False
1213
num_classes: 384

proteinworkshop/config/dataset/fold_family.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ datamodule:
88
dataset_fraction: 1.0 # Fraction of dataset to use
99
shuffle_labels: False # Whether to shuffle labels for permutation testing
1010
transforms: ${transforms} # Transforms to apply to dataset examples
11+
overwrite: False # Whether to overwrite existing dataset files
1112
num_classes: 1195 # Number of classes

proteinworkshop/config/dataset/fold_fold.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ datamodule:
88
dataset_fraction: 1.0 # Fraction of dataset to use
99
shuffle_labels: False # Whether to shuffle labels for permutation testing
1010
transforms: ${transforms} # Transforms to apply to dataset examples
11+
overwrite: False # Whether to overwrite existing dataset files
1112
num_classes: 1195 # Number of classes

proteinworkshop/config/dataset/fold_superfamily.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ datamodule:
88
dataset_fraction: 1.0 # Fraction of dataset to use
99
shuffle_labels: False # Whether to shuffle labels for permutation testing
1010
transforms: ${transforms} # Transforms to apply to dataset examples
11+
overwrite: False # Whether to overwrite existing dataset files
1112
num_classes: 1195 # Number of classes

proteinworkshop/config/dataset/go-bp.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ datamodule:
1010
num_workers: 8 # Number of workers for dataloader
1111
split: "BP" # Split of the dataset to use (`BP`, `MF`, `CC`)
1212
transforms: ${transforms} # Transforms to apply to dataset examples
13+
overwrite: False # Whether to overwrite existing dataset files
1314
num_classes: 2 # Number of classes

proteinworkshop/config/dataset/go-cc.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ datamodule:
1010
num_workers: 8 # Number of workers for dataloader
1111
split: "CC" # Split of the dataset to use (`BP`, `MF`, `CC`)
1212
transforms: ${transforms} # Transforms to apply to dataset examples
13+
overwrite: False # Whether to overwrite existing dataset files
1314
num_classes: 2 # Number of classes

proteinworkshop/config/dataset/go-mf.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ datamodule:
1010
num_workers: 8 # Number of workers for dataloader
1111
split: "MF" # Split of the dataset to use (`BP`, `MF`, `CC`)
1212
transforms: ${transforms} # Transforms to apply to dataset examples
13+
overwrite: False # Whether to overwrite existing dataset files
1314
num_classes: 2 # Number of classes

proteinworkshop/config/dataset/masif_site.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ datamodule:
99
dataset_fraction: 1.0 # Fraction of the dataset to use
1010
shuffle_labels: False # Whether to shuffle labels for permutation testing
1111
transforms: ${transforms} # Transforms to apply to dataset examples
12+
overwrite: False # Whether to overwrite existing dataset files
1213
num_classes: 2 # Number of classes

proteinworkshop/config/dataset/metal_3d.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ datamodule:
99
num_workers: 8 # Number of workers for dataloader
1010
transforms: ${transforms}
1111
obsolete_strategy: "drop" # Or replace
12+
overwrite: False
1213
num_classes: 2

proteinworkshop/config/dataset/pdb.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ datamodule:
55
num_workers: 4 # Number of workers for dataloader
66
pin_memory: True # Pin memory for dataloader
77
transforms: ${transforms} # Transforms to apply to dataset examples
8+
overwrite: False # Whether to overwrite existing dataset files
89

910
pdb_dataset:
1011
_target_: "proteinworkshop.datasets.pdb_dataset.PDBData"

proteinworkshop/config/dataset/ptm.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ datamodule:
77
pin_memory: True # Pin memory for dataloader
88
num_workers: 16 # Number of workers for dataloader
99
transforms: ${transforms} # Transforms to apply to dataset examples
10+
overwrite: False # Whether to overwrite existing dataset files
1011
num_classes: 13 # Number of classes

proteinworkshop/datasets/antibody_developability.py

+6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
format: Literal["mmtf", "pdb"] = "mmtf",
2828
obsolete_strategy: str = "drop",
2929
transforms: Optional[List[Callable]] = None,
30+
overwrite: bool = False,
3031
) -> None:
3132
"""
3233
Data module for antibody developability dataset from Chen et al.
@@ -49,6 +50,9 @@ def __init__(
4950
:type obsolete_strategy: str
5051
:param transforms: List of transforms to apply to dataset.
5152
:type transforms: Optional[List[Callable]]
53+
:param overwrite: Whether or not to overwrite existing processed data.
54+
Defaults t o ``False``.
55+
:type overwrite: bool
5256
"""
5357
super().__init__()
5458
self.root = pathlib.Path(path)
@@ -64,6 +68,7 @@ def __init__(
6468

6569
self.format = format
6670
self.obsolete_strategy = obsolete_strategy
71+
self.overwrite = overwrite
6772

6873
if transforms is not None:
6974
self.transform = self.compose_transforms(
@@ -136,6 +141,7 @@ def _get_dataset(self, split: str) -> ProteinDataset:
136141
format=self.format,
137142
transform=self.transform,
138143
in_memory=self.in_memory,
144+
overwrite=self.overwrite,
139145
)
140146

141147
def train_dataset(self) -> ProteinDataset:

proteinworkshop/datasets/astral.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pathlib
33
import random
44
import tarfile
5-
from typing import Callable, Dict, Iterable, List, Optional, Literal
5+
from typing import Callable, Dict, Iterable, List, Literal, Optional
66

77
import omegaconf
88
import pandas as pd
@@ -140,7 +140,9 @@ def parse_class_map(self) -> Dict[str, str]:
140140
def setup(self, stage: Optional[str] = None):
141141
self.download()
142142

143-
def parse_dataset(self, split: Literal["train", "val", "test"]) -> List[str]:
143+
def parse_dataset(
144+
self, split: Literal["train", "val", "test"]
145+
) -> List[str]:
144146
"""Parses ASTRAL dataset. Returns a list of IDs for each split.
145147
146148
:param split: Split to parse.

proteinworkshop/datasets/base.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,18 @@ def __init__(
296296
self.store_het = store_het
297297
self.out_names = out_names
298298

299+
# Determine whether to download raw structures
300+
if not self.overwrite and all(
301+
os.path.exists(Path(self.root) / "processed" / p)
302+
for p in self.processed_file_names
303+
):
304+
logger.info(
305+
f"All structures already processed and overwrite=False. Skipping download."
306+
)
307+
self._skip_download = True
308+
else:
309+
self._skip_download = False
310+
299311
super().__init__(root, transform, pre_transform, pre_filter, log)
300312
self.structures = pdb_codes if pdb_codes is not None else pdb_paths
301313
if self.in_memory:
@@ -319,6 +331,11 @@ def download(self):
319331
320332
Downloaded files are stored in ``self.raw_dir``.
321333
"""
334+
if self._skip_download:
335+
logger.info(
336+
"All structures already processed and overwrite=False. Skipping download."
337+
)
338+
return
322339
if self.pdb_codes is not None:
323340
to_download = (
324341
self.pdb_codes
@@ -366,6 +383,8 @@ def raw_file_names(self) -> List[str]:
366383
:return: List of raw file names.
367384
:rtype: List[str]
368385
"""
386+
if self._skip_download:
387+
return []
369388
if self.pdb_paths is None:
370389
return [f"{pdb}.{format}" for pdb in self.pdb_codes]
371390
else:
@@ -419,7 +438,7 @@ def process(self):
419438
pdb_codes = self.pdb_codes
420439

421440
raw_dir = Path(self.raw_dir)
422-
for i, pdb in tqdm(pdb_codes):
441+
for i, pdb in enumerate(tqdm(pdb_codes)):
423442
try:
424443
path = raw_dir / f"{pdb}.{self.format}"
425444
if path.exists():
@@ -473,7 +492,7 @@ def get(self, idx: int) -> Data:
473492
:return: PyTorch Geometric Data object.
474493
"""
475494
if self.in_memory:
476-
return self.data[idx]
495+
return self._batch_format(self.data[idx])
477496

478497
if self.out_names is not None:
479498
fname = f"{self.out_names[idx]}.pt"

proteinworkshop/datasets/cath.py

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
num_workers: int = 16,
2626
dataset_fraction: float = 1.0,
2727
transforms: Optional[Iterable[Callable]] = None,
28+
overwrite: bool = False,
2829
) -> None:
2930
"""Data module for CATH dataset.
3031
@@ -46,6 +47,9 @@ def __init__(
4647
:type dataset_fraction: float
4748
:param transforms: List of transforms to apply to dataset.
4849
:type transforms: Optional[List[Callable]]
50+
:param overwrite: Whether to overwrite existing data.
51+
Defaults to ``False``.
52+
:type overwrite: bool
4953
"""
5054
super().__init__()
5155

@@ -166,6 +170,7 @@ def train_dataset(self) -> ProteinDataset:
166170
transform=self.transform,
167171
format=self.format,
168172
in_memory=self.in_memory,
173+
overwrite=self.overwrite,
169174
)
170175

171176
def val_dataset(self) -> ProteinDataset:
@@ -188,6 +193,7 @@ def val_dataset(self) -> ProteinDataset:
188193
transform=self.transform,
189194
format=self.format,
190195
in_memory=self.in_memory,
196+
overwrite=self.overwrite,
191197
)
192198

193199
def test_dataset(self) -> ProteinDataset:
@@ -209,6 +215,7 @@ def test_dataset(self) -> ProteinDataset:
209215
transform=self.transform,
210216
format=self.format,
211217
in_memory=self.in_memory,
218+
overwrite=self.overwrite,
212219
)
213220

214221
def train_dataloader(self) -> ProteinDataLoader:

0 commit comments

Comments
 (0)