Transfer Learning Shootout for PyTorch's model zoo (torchvision)

silakanveli/pytorch-retraining

pytorch-retraining

  • Load any pretrained model from PyTorch's model zoo with a custom final layer (num_classes outputs) in one line:
model_pretrained, diff = load_model_merged('inception_v3', num_classes)
  • Retrain either the minimal set of layers (as inferred on load) or a custom number of layers on multiple GPUs, optionally with a Cyclical Learning Rate (Smith 2017):
# retrain only the parameters listed in diff (inferred on load)
final_param_names = [d[0] for d in diff]
stats = train_eval(model_pretrained, trainloader, testloader, final_param_names)
  • Chart training_time, evaluation_time (fps) and top-1 accuracy for varying retraining depths (shallow, deep and from scratch):
[Chart: transfer learning on the example dataset Bees vs. Ants with 2x K80 GPUs]
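
The second bullet retrains only the layers inferred on load, and the snippet above takes parameter names from the first element of each diff entry. One plausible way to infer such a diff, sketched here as an assumption rather than the repository's load_model_merged, is to compare the pretrained checkpoint with the new model parameter by parameter and flag every name that is missing or whose shape changed, such as a final layer resized to num_classes. The helper diff_param_shapes and the name-to-shape dicts are hypothetical; with PyTorch they could be built from {k: tuple(v.shape) for k, v in model.state_dict().items()}.

```python
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]
# One diff entry: (parameter name, shape in pretrained checkpoint or None, shape in new model)
DiffEntry = Tuple[str, Optional[Shape], Shape]


def diff_param_shapes(pretrained: Dict[str, Shape], target: Dict[str, Shape]) -> List[DiffEntry]:
    """Hypothetical helper: list target parameters that cannot be copied
    from the pretrained checkpoint because they are missing or reshaped."""
    diff: List[DiffEntry] = []
    for name, shape in target.items():
        old_shape = pretrained.get(name)  # None if the name does not exist
        if old_shape != shape:
            diff.append((name, old_shape, shape))
    return diff


# Toy example: an ImageNet-style head (1000 classes) swapped for 2 classes.
pretrained_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "fc.weight": (1000, 512),
    "fc.bias": (1000,),
}
new_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "fc.weight": (2, 512),
    "fc.bias": (2,),
}

diff = diff_param_shapes(pretrained_shapes, new_shapes)
final_param_names = [d[0] for d in diff]  # same pattern as in the README snippet
print(final_param_names)  # ['fc.weight', 'fc.bias']
```

In this sketch, parameters whose names and shapes match would be copied from the checkpoint, and only the names in final_param_names would be left trainable for the shallowest retraining depth.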
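
The third bullet charts top-1 accuracy and evaluation speed in frames per second. A minimal sketch of how both numbers can be computed, assuming the model is wrapped as a plain predict callable that maps a batch to per-class scores; top1_accuracy and evaluate are hypothetical helpers, not the repository's train_eval.

```python
import time
from typing import Callable, List, Sequence, Tuple

Scores = List[float]  # one score per class for a single image


def top1_accuracy(scores: Sequence[Scores], labels: Sequence[int]) -> float:
    """Fraction of images whose highest-scoring class equals the true label."""
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=lambda i: row[i])  # argmax
        correct += predicted == label
    return correct / len(labels)


def evaluate(
    predict: Callable[[List[Scores]], List[Scores]],
    batches: Sequence[Tuple[List[Scores], List[int]]],
) -> Tuple[float, float]:
    """Return (top-1 accuracy, images per second) over all batches.

    Only the predict calls are timed; accuracy is computed afterwards."""
    all_scores: List[Scores] = []
    all_labels: List[int] = []
    start = time.perf_counter()
    for inputs, labels in batches:
        all_scores.extend(predict(inputs))
        all_labels.extend(labels)
    elapsed = time.perf_counter() - start
    fps = len(all_labels) / elapsed if elapsed > 0 else float("inf")
    return top1_accuracy(all_scores, all_labels), fps


# Toy usage: inputs are already class scores, so predict is the identity.
batches = [([[0.1, 0.9], [0.8, 0.2]], [1, 0]), ([[0.3, 0.7]], [0])]
acc, fps = evaluate(lambda x: x, batches)
print(round(acc, 3))  # 0.667
```

With a real PyTorch model on a GPU, kernels run asynchronously, so torch.cuda.synchronize() should be called before reading the clock for the fps figure to reflect the actual work.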

Results on a more elaborate dataset

num_classes = 23; the dataset is slightly unbalanced, with high variance in rotation and motion-blur artifacts. Run on 1x GTX 1080 Ti.

[Chart: constant LR with momentum]

[Chart: Cyclical Learning Rate]
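
The Cyclical Learning Rate compared above against a constant learning rate with momentum is described in Smith (2017). A minimal sketch of the paper's triangular policy, in which the rate ramps linearly from base_lr up to max_lr over step_size iterations and back down again; the parameter values are illustrative assumptions, and the repository's own CLR code may differ (for example in the policy or step size). PyTorch also provides torch.optim.lr_scheduler.CyclicLR, which supports a "triangular" mode.

```python
import math


def triangular_clr(iteration: int, base_lr: float, max_lr: float, step_size: int) -> float:
    """Triangular cyclical learning rate (Smith, 2017).

    One full cycle lasts 2 * step_size iterations: the rate rises linearly
    from base_lr to max_lr, then falls linearly back to base_lr.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


# Illustrative values only (assumed, not taken from the repository).
base_lr, max_lr, step_size = 0.001, 0.006, 2000
# Rises to max_lr at iteration 2000 and returns to base_lr at 4000.
for it in (0, 1000, 2000, 3000, 4000):
    print(it, round(triangular_clr(it, base_lr, max_lr, step_size), 6))
```

With a PyTorch optimizer, the value for the current iteration would be written into each param_group["lr"] before optimizer.step(); a constant-LR baseline simply keeps that value fixed.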

Languages

  • Jupyter Notebook 77.5%
  • Python 22.5%