Skip to content

Commit

Permalink
Improve API documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Olga Lyashevska committed Oct 31, 2023
1 parent 5cd4169 commit 24a6a1e
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 23 deletions.
23 changes: 16 additions & 7 deletions bird_cloud_gnn/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def __call__(self, epoch_values):


class EarlyStopperCallback:
"""Callback to check early stopping."""
"""
Callback to check early stopping.
This callback is used to check if the training should be stopped early based on the validation loss.
"""

def __init__(self, **kwargs):
"""Input arguments are passed to EarlyStopper."""
Expand All @@ -62,14 +66,19 @@ def __call__(self, epoch_values):


class CombinedCallback:
"""Helper to combine multiple callbacks."""
"""Helper to combine multiple callbacks.
This class allows multiple callbacks to be combined into a single callback. The callbacks are called in the given
sequence and if one of them returns True, the subsequent callbacks are not called.
Args:
callbacks (iterable): List of callbacks to be combined.
Returns:
bool: True if any of the callbacks return True, False otherwise.
"""

def __init__(self, callbacks):
"""
Args:
callbacks (iterable): List of callbacks. These are called in the given sequence and
if one of them returns True, the subsequents are not called.
"""
self.callbacks = callbacks

def __call__(self, epoch_values):
Expand Down
16 changes: 15 additions & 1 deletion bird_cloud_gnn/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@


def get_dataloaders(dataset, train_idx, test_idx, batch_size):
"""
Returns train and test dataloaders for a given dataset, train indices, test indices, and batch size.
Args:
dataset (torch_geometric.datasets): The dataset to use for creating dataloaders.
train_idx (list): The indices to use for training.
test_idx (list): The indices to use for testing.
batch_size (int): The batch size to use for the dataloaders.
Returns:
tuple: A tuple containing the train and test dataloaders.
"""
train_sampler = SubsetRandomSampler(train_idx)
test_sampler = SubsetRandomSampler(test_idx)

Expand Down Expand Up @@ -48,8 +60,10 @@ def kfold_evaluate(
learning_rate (float, optional): Learning rate. Defaults to 0.01.
num_epochs (int, optional): Training epochs. Defaults to 20.
batch_size (int, optional): Batch size used in the data loaders. Defaults to 512.
"""
Returns:
None
"""
labels = np.array(dataset.labels)
# Initialize a stratified k-fold splitter
kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
Expand Down
42 changes: 31 additions & 11 deletions bird_cloud_gnn/early_stopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,42 @@


class EarlyStopper:
"""Early stopper check."""
"""Early stopper check.
def __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0):
"""EarlyStopper. Use to stop if the validation loss starts increasing.
The validation loss is increasing if
This class is used to stop the training process if the validation loss starts increasing. The validation loss is
considered to be increasing if it is greater than the minimum validation loss found so far plus an absolute and/or
relative tolerance. The class keeps track of the minimum validation loss found so far and the number of consecutive
iterations where the validation loss has increased. If the number of consecutive iterations where the validation
loss has increased exceeds a certain threshold, the training process is stopped.
Attributes:
patience (int): How many consecutive iterations to wait before stopping.
min_abs_delta (float): Absolute tolerance to the increase.
min_rel_delta (float): Relative tolerance to the increase.
counter (int): Number of consecutive iterations where the validation loss has increased.
min_validation_loss (float): Minimum validation loss found so far.
L > Lmin + abs_delta + rel_delta * |Lmin|,
Methods:
__init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0): Initializes the EarlyStopper object.
early_stop(self, validation_loss): Checks whether it is time to stop, and updates the internal state of the
EarlyStopper object.
where `L` is the current validation loss, `Lmin` is the minimum validation loss found so
far, and `abs_delta` and `rel_delta` are absolute and relative tolerances to the increase,
respectively.
Example usage:
early_stopper = EarlyStopper(patience=5, min_abs_delta=0.1, min_rel_delta=0.01)
for epoch in range(num_epochs):
train_loss = train(model, train_loader)
val_loss = validate(model, val_loader)
if early_stopper.early_stop(val_loss):
print(f"Validation loss has been increasing for {early_stopper.patience} consecutive epochs. "
f"Training stopped.")
break
"""

def __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0):
"""Initializes the EarlyStopper object.
Args:
patience (int, optional): How many consecutive iterations to wait before stopping.
Defaults to 3.
patience (int, optional): How many consecutive iterations to wait before stopping. Defaults to 3.
min_abs_delta (float, optional): Absolute tolerance to the increase. Defaults to 1e-2.
min_rel_delta (float, optional): Relative tolerance to the increase. Defaults to 0.0.
"""
Expand All @@ -30,7 +50,7 @@ def __init__(self, patience=3, min_abs_delta=1e-2, min_rel_delta=0.0):
self.min_validation_loss = np.inf

def early_stop(self, validation_loss):
"""Check whether it is time to stop, and update the internal of EarlyStopper.
"""Checks whether it is time to stop, and updates the internal state of the EarlyStopper object.
Args:
validation_loss (float): Current validation loss
Expand Down
2 changes: 1 addition & 1 deletion bird_cloud_gnn/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def generate_data(
used internally for predicting the target class. Defaults to 300.0.
Returns:
pandas.DataFrames: Generated data. It was also saved to `filename` if that argument was
pandas.DataFrame: Generated data. It is also saved to `filename` if that argument is
passed.
"""

Expand Down
22 changes: 19 additions & 3 deletions bird_cloud_gnn/gnn_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Module for creating GCN class"""

import os

import dgl
import numpy as np
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch.conv import GraphConv
from torch import nn
from torch import optim
from torch import nn, optim
from torch.nn.modules import Module
from tqdm import tqdm


os.environ["DGLBACKEND"] = "pytorch"


Expand All @@ -19,6 +18,23 @@ class GCN(nn.Module):
An n-layer GCN is constructed from input features and a list of layers.
Each layer computes new node representations by aggregating neighbour information.
Args:
in_feats (int): the number of input features
layers_data (list): list of tuples, each holding the size of a hidden layer and its activation function
Attributes:
in_feats (int): the number of input features
layers (nn.ModuleList): list of layers
name (str): name of the model
num_classes (int): the last size should correspond to the number of classes we're predicting
Methods:
oneline_description(): Description of the model to uniquely identify it in logs
forward(g, in_feats): Computes the output of the model.
fit(train_dataloader, learning_rate=0.01, num_epochs=20): Train the model.
evaluate(test_dataloader): Evaluate model.
fit_and_evaluate(train_dataloader, test_dataloader, callback=None, learning_rate=0.01, num_epochs=20, sch_explr_gamma=0.99, sch_multisteplr_milestones=None, sch_multisteplr_gamma=0.1): Fit the model while evaluating every iteration.
"""

def __init__(self, in_feats: int, layers_data: list):
Expand Down

0 comments on commit 24a6a1e

Please sign in to comment.