Skip to content

Commit a891666

Browse files
authored
Merge pull request #89 from a-r-j/time_splits
changed time_cutoff option
2 parents d5fbab7 + 7a3875d commit a891666

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq ID.-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
66
* Add redownload checks for already downloaded datasets and harmonise pdb download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86)
77
* Remove remaining errors from PDB dataset change
8-
* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88)
8+
* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88) as well as time-based splits [#89](https://github.com/a-r-j/ProteinWorkshop/pull/89)
99

1010
### Models
1111

proteinworkshop/config/dataset/pdb.yaml

+7-4
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ datamodule:
1010

1111
pdb_dataset:
1212
_target_: "proteinworkshop.datasets.pdb_dataset.PDBData"
13-
fraction: 1.0 # Fraction of dataset to use
13+
fraction: 0.01 # Fraction of dataset to use
1414
molecule_type: "protein" # Type of molecule for which to select
1515
experiment_types: ["diffraction", "NMR", "EM", "other"] # All experiment types
16-
max_length: 1000 # Exclude polypeptides greater than length 1000
16+
max_length: 150 # Exclude polypeptides greater than length 150
1717
min_length: 10 # Exclude peptides of length 10
1818
oligomeric_min: 1 # Include only monomeric proteins
1919
oligomeric_max: 5 # Include up to 5-meric proteins
@@ -24,6 +24,9 @@ datamodule:
2424
remove_non_standard_residues: True # Include only proteins containing standard amino acid residues
2525
remove_pdb_unavailable: True # Include only proteins that are available to download
2626
train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits
27-
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other option is "random"
28-
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random")
27+
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other options are "random" and "time_cutoff"
28+
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type!="sequence_similarity")
2929
overwrite_sequence_clusters: False # Previous clusterings at same sequence similarity are reused and not overwritten
30+
split_time_frames: null # Time-cutoffs for train, val and test set (argument is ignored if split_type!="time_cutoff") - e.g., ["2020-01-01", "2021-01-01", "2023-03-01"]
31+
32+

proteinworkshop/datasets/pdb_dataset.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import hydra
44
import omegaconf
5+
import numpy as np
56
import os
67
import pandas as pd
78
import pathlib
@@ -30,9 +31,11 @@ def __init__(
3031
remove_non_standard_residues: bool,
3132
remove_pdb_unavailable: bool,
3233
train_val_test: List[float],
33-
split_type: Literal["sequence_similarity", "random"],
34-
split_sequence_similiarity: int,
35-
overwrite_sequence_clusters: bool
34+
split_type: Literal["sequence_similarity", "time_cutoff", "random"] = "random",
35+
split_sequence_similiarity: Optional[int] = None,
36+
overwrite_sequence_clusters: Optional[bool] = False,
37+
split_time_frames: Optional[List[str]] = None,
38+
3639
):
3740
self.fraction = fraction
3841
self.molecule_type = molecule_type
@@ -52,6 +55,11 @@ def __init__(
5255
self.split_type = split_type
5356
self.split_sequence_similarity = split_sequence_similiarity
5457
self.overwrite_sequence_clusters = overwrite_sequence_clusters
58+
if self.split_type == "time_cutoff":
59+
try:
60+
self.split_time_frames = [np.datetime64(date) for date in split_time_frames]
61+
except:
62+
raise TypeError(f"{split_time_frames} does not contain valid dates for np.datetime64 format")
5563
self.splits = ["train", "val", "test"]
5664

5765
def create_dataset(self):
@@ -128,9 +136,15 @@ def create_dataset(self):
128136
elif self.split_type == "sequence_similarity":
129137
log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...")
130138
log.info(f"Using {self.split_sequence_similarity} sequence similarity for split")
131-
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True)
132-
splits = pdb_manager.split_clusters(
133-
pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters)
139+
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True,
140+
overwrite = self.overwrite_sequence_clusters)
141+
splits = pdb_manager.split_clusters(pdb_manager.df, update=True)
142+
143+
elif self.split_type == "time_cutoff":
144+
log.info(f"Splitting dataset via time_cutoff split into {self.train_val_test}...")
145+
log.info(f"Using {self.split_time_frames} dates for split")
146+
pdb_manager.split_time_frames = self.split_time_frames
147+
splits = pdb_manager.split_by_deposition_date(df=pdb_manager.df, update=True)
134148

135149
log.info(splits["train"])
136150
return splits

0 commit comments

Comments
 (0)