Skip to content

Commit

Permalink
Merge pull request #36 from jmisilo/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
jmisilo authored Oct 30, 2022
2 parents 9d5dd11 + c62905a commit 6b728f4
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 12 deletions.
1 change: 0 additions & 1 deletion .gitattributes

This file was deleted.

2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
/data/
__pycache__/
.vscode/
/weights/arch/
/weights/
/wandb/
test_generation.ipynb
13 changes: 9 additions & 4 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from model.loops import evaluate_dataset
from model.model import Net
from utils.config import Config
from utils.load_ckp import load_ckp
from utils.load_ckp import download_weights

config = Config()
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -48,9 +48,7 @@

ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)

assert os.path.isfile(ckp_path), 'Checkpoint does not exist'
assert os.path.exists(args.img_path), 'Path to the test image folder does not exist'
assert os.path.exists(args.res_path), 'Path to the results folder does not exist'

# set seed
random.seed(config.seed)
Expand Down Expand Up @@ -81,7 +79,14 @@

_, _, test_dataset = random_split(dataset, [config.train_size, config.val_size, config.test_size])

load_ckp(ckp_path, model, device=device)
if not os.path.exists(config.weights_dir):
os.makedirs(config.weights_dir)

if not os.path.isfile(ckp_path):
download_weights(ckp_path)

checkpoint = torch.load(ckp_path, map_location=device)
model.load_state_dict(checkpoint)

save_path = os.path.join(args.res_path, args.checkpoint_name[:-3])

Expand Down
14 changes: 10 additions & 4 deletions src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import torch

from model.model import Net
from utils.load_ckp import download_weights
from utils.config import Config
from utils.load_ckp import load_ckp

config = Config()
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -59,7 +59,6 @@
ckp_path = os.path.join(config.weights_dir, args.checkpoint_name)

assert os.path.isfile(args.img_path), 'Image does not exist'
assert os.path.isfile(ckp_path), 'Checkpoint does not exist'

if not os.path.exists(args.res_path):
os.makedirs(args.res_path)
Expand All @@ -75,8 +74,15 @@
max_len=config.max_len,
device=device
)

load_ckp(ckp_path, model, device=device)

if not os.path.exists(config.weights_dir):
os.makedirs(config.weights_dir)

if not os.path.isfile(ckp_path):
download_weights(ckp_path)

checkpoint = torch.load(ckp_path, map_location=device)
model.load_state_dict(checkpoint)

model.eval()

Expand Down
12 changes: 10 additions & 2 deletions src/utils/load_ckp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
'''
Utility function to load checkpoint from coresponding file.
Utility functions for loading weights.
'''

import gdown
import torch

def load_ckp(checkpoint_fpath, model, optimizer=None, scheduler=None, scaler=None, device='cpu'):
Expand All @@ -21,4 +22,11 @@ def load_ckp(checkpoint_fpath, model, optimizer=None, scheduler=None, scaler=Non
if scaler is not None:
scaler.load_state_dict(checkpoint['scaler_state_dict'])

return checkpoint['epoch']
return checkpoint['epoch']

def download_weights(checkpoint_fpath):
'''
Downloads weights from Google Drive.
'''

gdown.download('https://drive.google.com/uc?id=1lEufQVOETFEIhPdFDYaez31uroq_5Lby', checkpoint_fpath, quiet=False)

0 comments on commit 6b728f4

Please sign in to comment.