# main_minimoco.py (forked from facebookresearch/moco)
# Mini MoCo training script for CIFAR-10.
from datetime import datetime
#from functools import partial
#from PIL import Image
from torch.utils.data import DataLoader
#from torchvision import transforms
from torchvision.datasets import CIFAR10
#from torchvision.models import resnet
from tqdm import tqdm
import argparse
import json
#import math
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from moco.builder_mini import ModelMoCo
from moco.loader_mini import CIFAR10Pair, train_transform, test_transform
from moco.trainer_mini import train, test
ROOT_DIR = Path.home()
EXPE_DIR = os.path.join(ROOT_DIR, 'experiments-moco/')


def _build_parser():
    """Return the command-line parser for the MoCo CIFAR-10 training script.

    Options are registered in a fixed order: general optimisation settings,
    MoCo-specific settings, the kNN monitor, and resume/results utilities.
    """
    p = argparse.ArgumentParser(description='Train MoCo on CIFAR-10')
    add = p.add_argument

    add('-a', '--arch', default='resnet18')
    # lr: 0.06 for batch 512 (or 0.03 for batch 256)
    add('--lr', '--learning-rate', default=0.06, type=float, metavar='LR',
        help='initial learning rate', dest='lr')
    add('--epochs', default=200, type=int, metavar='N',
        help='number of total epochs to run')
    add('--schedule', default=[120, 160], nargs='*', type=int,
        help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
    add('--cos', action='store_true', help='use cosine lr schedule')
    add('--batch-size', default=512, type=int, metavar='N', help='mini-batch size')
    add('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

    # moco specific configs:
    add('--moco-dim', default=128, type=int, help='feature dimension')
    add('--moco-k', default=4096, type=int, help='queue size; number of negative keys')
    add('--moco-m', default=0.99, type=float, help='moco momentum of updating key encoder')
    add('--moco-t', default=0.1, type=float, help='softmax temperature')
    add('--bn-splits', default=8, type=int,
        help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
    add('--symmetric', action='store_true',
        help='use a symmetric loss function that backprops to both crops')

    # knn monitor
    add('--knn-k', default=200, type=int, help='k in kNN monitor')
    add('--knn-t', default=0.1, type=float,
        help='softmax temperature in kNN monitor; could be different with moco-t')

    # utils
    add('--resume', default='', type=str, metavar='PATH',
        help='path to latest checkpoint (default: none)')
    add('--results-dir', default='', type=str, metavar='PATH',
        help='path to cache (default: none)')
    return p


# Module-level parser, read by main() via parser.parse_args().
parser = _build_parser()
def main():
    """Train MoCo on CIFAR-10, evaluating with the monitor after every epoch.

    Reads options from the module-level ``parser``, builds a ``ModelMoCo`` on
    the GPU and trains it with SGD (momentum 0.9) for epochs
    ``epoch_start .. args.epochs``.

    Outputs, all written to ``args.results_dir``:
      * ``args.json``: the options, written once before training starts.
      * ``log.csv``: per-epoch train loss and the value returned by ``test``
        (logged as ``test_acc@1``, presumably kNN top-1 accuracy). Rewritten
        after every epoch.
      * ``model_last.pth``: epoch number, model state and optimizer state.
        Overwritten after every epoch.

    With ``--resume PATH``, model and optimizer state are restored from that
    checkpoint and training continues from the epoch after the saved one.
    """
    args = parser.parse_args()  # running in command line

    # Settings applied on top of the CLI (edit here when running in ipynb).
    # These always take precedence over --cos / --schedule.
    args.cos = True
    args.schedule = []  # cos in use
    # Honour --results-dir when given; otherwise use a fresh timestamped
    # directory under EXPE_DIR.
    if not args.results_dir:
        stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-moco")
        args.results_dir = os.path.join(EXPE_DIR, 'cache-' + stamp)
    print(args)

    # create model
    model = ModelMoCo(
        dim=args.moco_dim,
        K=args.moco_k,
        m=args.moco_m,
        T=args.moco_t,
        arch=args.arch,
        bn_splits=args.bn_splits,
        symmetric=args.symmetric,
    ).cuda()

    # data prepare
    data_path = os.path.join(ROOT_DIR, 'data/')
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=16, pin_memory=True)
    # drop_last keeps every training batch at exactly batch_size. Presumably
    # the MoCo queue update relies on that; confirm in moco.builder_mini.
    train_data = CIFAR10Pair(root=data_path, train=True, transform=train_transform, download=True)
    train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **loader_kwargs)
    # Monitor data: the train split with test_transform acts as the memory
    # bank, and the test split is evaluated against it.
    memory_data = CIFAR10(root=data_path, train=True, transform=test_transform, download=True)
    memory_loader = DataLoader(memory_data, shuffle=False, **loader_kwargs)
    test_data = CIFAR10(root=data_path, train=False, transform=test_transform, download=True)
    test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)

    # load model if resume (an empty --resume means "start from scratch")
    epoch_start = 1
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch'] + 1
        print('Loaded from: {}'.format(args.resume))

    # logging
    results = {'train_loss': [], 'test_acc@1': []}
    # makedirs also creates missing parents (e.g. EXPE_DIR on a first run),
    # where os.mkdir would raise FileNotFoundError.
    os.makedirs(args.results_dir, exist_ok=True)
    # dump args
    with open(os.path.join(args.results_dir, 'args.json'), 'w') as fid:
        json.dump(vars(args), fid, indent=2)

    # training loop
    log_path = os.path.join(args.results_dir, 'log.csv')
    ckpt_path = os.path.join(args.results_dir, 'model_last.pth')
    for epoch in range(epoch_start, args.epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch, args)
        results['train_loss'].append(train_loss)
        test_acc_1 = test(model.encoder_q, memory_loader, test_loader, epoch, args)
        results['test_acc@1'].append(test_acc_1)
        # save statistics; rows cover only the epochs run in this invocation
        data_frame = pd.DataFrame(data=results, index=range(epoch_start, epoch + 1))
        data_frame.to_csv(log_path, index_label='epoch')
        # save model
        torch.save(
            {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
            ckpt_path,
        )
    print("That's all Folks!")
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()