
changed time_cutoff option #89

Merged (7 commits) on Mar 26, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -5,7 +5,7 @@
* Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq ID.-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Add redownload checks for already downloaded datasets and harmonise pdb download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86)
* Remove remaining errors from PDB dataset change
* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88)
* Add option to create pdb datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88) as well as time-based splits [#89](https://github.com/a-r-j/ProteinWorkshop/pull/89)

### Models

11 changes: 7 additions & 4 deletions proteinworkshop/config/dataset/pdb.yaml
@@ -10,10 +10,10 @@ datamodule:

pdb_dataset:
_target_: "proteinworkshop.datasets.pdb_dataset.PDBData"
fraction: 1.0 # Fraction of dataset to use
fraction: 0.01 # Fraction of dataset to use
molecule_type: "protein" # Type of molecule for which to select
experiment_types: ["diffraction", "NMR", "EM", "other"] # All experiment types
max_length: 1000 # Exclude polypeptides greater than length 1000
max_length: 150 # Exclude polypeptides greater than length 150
min_length: 10 # Exclude peptides of length 10
oligomeric_min: 1 # Include only monomeric proteins
oligomeric_max: 5 # Include up to 5-meric proteins
@@ -24,6 +24,9 @@
remove_non_standard_residues: True # Include only proteins containing standard amino acid residues
remove_pdb_unavailable: True # Include only proteins that are available to download
train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other option is "random"
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random")
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other options are "random" and "time_cutoff"
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type!="sequence_similarity")
overwrite_sequence_clusters: False # Previous clusterings at same sequence similarity are reused and not overwritten
split_time_frames: null # Time-cutoffs for train, val and test set (argument is ignored if split_type!="time_cutoff") - e.g., ["2020-01-01", "2021-01-01", "2023-03-01"]
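
For clarity, these keys select one of three split strategies: "random", "sequence_similarity" (cluster-based, with `split_sequence_similiarity` as the clustering threshold), or the "time_cutoff" option added in this PR (with `split_time_frames` as the cutoff dates). Below is a minimal sketch of composing the relevant keys programmatically with OmegaConf; the dates are illustrative values taken from the comment above, not a prescribed configuration.

```python
from omegaconf import OmegaConf

# Illustrative override only: switch the pdb dataset to the new time-based split.
# The dates are example values; real cutoffs depend on the experiment.
overrides = OmegaConf.create(
    {
        "split_type": "time_cutoff",
        "split_time_frames": ["2020-01-01", "2021-01-01", "2023-03-01"],
    }
)
print(OmegaConf.to_yaml(overrides))
```

In practice these keys sit under `datamodule.pdb_dataset` in `pdb.yaml`, so an equivalent Hydra command-line override would target that path.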


26 changes: 20 additions & 6 deletions proteinworkshop/datasets/pdb_dataset.py
Expand Up @@ -2,6 +2,7 @@

import hydra
import omegaconf
import numpy as np
import os
import pandas as pd
import pathlib
@@ -30,9 +31,11 @@ def __init__(
remove_non_standard_residues: bool,
remove_pdb_unavailable: bool,
train_val_test: List[float],
split_type: Literal["sequence_similarity", "random"],
split_sequence_similiarity: int,
overwrite_sequence_clusters: bool
split_type: Literal["sequence_similarity", "time_cutoff", "random"] = "random",
split_sequence_similiarity: Optional[int] = None,
overwrite_sequence_clusters: Optional[bool] = False,
split_time_frames: Optional[List[str]] = None,

):
self.fraction = fraction
self.molecule_type = molecule_type
@@ -52,6 +55,11 @@ def __init__(
self.split_type = split_type
self.split_sequence_similarity = split_sequence_similiarity
self.overwrite_sequence_clusters = overwrite_sequence_clusters
if self.split_type == "time_cutoff":
try:
self.split_time_frames = [np.datetime64(date) for date in split_time_frames]
except:
raise TypeError(f"{split_time_frames} does not contain valid dates for np.datetime64 format")
self.splits = ["train", "val", "test"]

def create_dataset(self):
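
A quick illustration of the validation above (a sketch, not part of the PR): well-formed ISO-8601 strings convert cleanly to `np.datetime64`, while malformed strings raise a `ValueError` that the bare `except` re-raises as the more descriptive `TypeError`; passing `split_time_frames=None` together with `split_type="time_cutoff"` fails in the same way, since `None` is not iterable.

```python
import numpy as np

# Well-formed date strings parse to numpy datetimes (example values):
frames = ["2020-01-01", "2021-01-01", "2023-03-01"]
cutoffs = [np.datetime64(d) for d in frames]
print(cutoffs)  # [numpy.datetime64('2020-01-01'), ...]

# A malformed string raises ValueError, which __init__ above converts
# into a TypeError naming the offending split_time_frames value.
try:
    np.datetime64("01/01/2020")
except ValueError as err:
    print(f"rejected: {err}")
```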
@@ -128,9 +136,15 @@ def create_dataset(self):
elif self.split_type == "sequence_similarity":
log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...")
log.info(f"Using {self.split_sequence_similarity} sequence similarity for split")
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True)
splits = pdb_manager.split_clusters(
pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters)
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True,
overwrite = self.overwrite_sequence_clusters)
splits = pdb_manager.split_clusters(pdb_manager.df, update=True)

elif self.split_type == "time_cutoff":
log.info(f"Splitting dataset via time_cutoff split into {self.train_val_test}...")
log.info(f"Using {self.split_time_frames} dates for split")
pdb_manager.split_time_frames = self.split_time_frames
splits = pdb_manager.split_by_deposition_date(df=pdb_manager.df, update=True)
Collaborator commented: Very nice!


log.info(splits["train"])
return splits
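
The time_cutoff branch delegates the actual assignment to `pdb_manager.split_by_deposition_date`. For readers who want the gist without digging into that method, here is a rough, self-contained sketch of a deposition-date split; it assumes the metadata frame has a `deposition_date` column and that the three configured cutoffs bound the train, val, and test windows, and it is not the library's implementation.

```python
from typing import Dict, List

import numpy as np
import pandas as pd


def split_by_time_cutoff(
    df: pd.DataFrame, time_frames: List[np.datetime64]
) -> Dict[str, pd.DataFrame]:
    """Illustrative only: bucket structures into train/val/test by deposition date.

    Assumes ``df`` has a ``deposition_date`` column and ``time_frames`` holds
    three increasing cutoff dates, e.g. the configured ``split_time_frames``.
    """
    dates = pd.to_datetime(df["deposition_date"])
    train_end, val_end, test_end = time_frames
    return {
        "train": df[dates < train_end],
        "val": df[(dates >= train_end) & (dates < val_end)],
        "test": df[(dates >= val_end) & (dates < test_end)],
    }
```

Structures deposited after the last cutoff are simply dropped in this sketch; the point of a hard time-based holdout is that nothing deposited after the training cutoff can leak into the training set.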