-
Notifications
You must be signed in to change notification settings - Fork 4
/
quick_start1.py
98 lines (84 loc) · 4.35 KB
/
quick_start1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# A toy example to show how to train SiamParseNet in a fully-supervised way
import os
import torch
from SPN.SPNet import SPNet, resnet_feature_layers
from ProUtils.misc import vl2ch, get_1x_lr_params, get_10x_lr_params, loss_calc, dist_loss_calc
import torch.optim as optim
# Expose only GPU 0 to this process. The variable only takes effect if CUDA has
# not been initialised yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Build the fully-supervised SiamParseNet and move its parameters to the GPU.
model = SPNet().cuda()
# Finetuning from a pretrained DeepLab model is important for training SPN.
# Set use_finetune = True once pretrained_deeplab_path points at the checkpoint,
# which can be downloaded from
# https://drive.google.com/file/d/0BxhUwxvLPO7TVFJQU1dwbXhHdEk/view?resourcekey=0-7UxnHrm5eDCyvz2G35aKgA


def _copy_renamed_weights(module, checkpoint, rename):
    """Overwrite ``module``'s state in place with tensors from ``checkpoint``.

    Every entry of ``module.state_dict()`` except ``num_batches_tracked``
    counters is replaced by ``checkpoint[rename(name)]``, so the checkpoint
    must contain every renamed key.
    """
    # state_dict() builds a fresh dict on every call, so fetch it once. Its
    # tensors share storage with the module, so copy_ updates the module itself.
    target_state = module.state_dict()
    for name, tensor in target_state.items():
        if "num_batches_tracked" in name.split('.'):
            continue
        tensor.copy_(checkpoint[rename(name)])


def _common_ckpt_name(name):
    """Map a CommonFeatureExtract key ``model.<i>....`` to ``Scale.<resnet layer i>....``."""
    layer_idx = name.split('.')[1]
    return name.replace("model." + layer_idx,
                        "Scale." + resnet_feature_layers[int(layer_idx)])


use_finetune = False
if use_finetune:
    pretrained_deeplab_path = ""
    if not pretrained_deeplab_path:
        raise ValueError("use_finetune is True but pretrained_deeplab_path is empty; "
                         "set it to the downloaded DeepLab checkpoint")
    # PropNet and SegNet are both initialised from the same DeepLab checkpoint,
    # so read it from disk only once. Loading onto the CPU avoids depending on
    # the device the checkpoint was saved from; copy_ transfers the data to the
    # model's device.
    deeplab_state_dict = torch.load(pretrained_deeplab_path, map_location="cpu")
    # load pretrained PropNet Model
    _copy_renamed_weights(model.CommonFeatureExtract, deeplab_state_dict,
                          _common_ckpt_name)
    _copy_renamed_weights(model.MatFeatureExtract, deeplab_state_dict,
                          lambda name: name.replace("model.", "Scale.layer4."))
    # load pretrained SegNet Model
    # NOTE(review): unlike the MatFeatureExtract mapping, this replaces every
    # occurrence of "model" without a trailing dot; kept as in the original.
    _copy_renamed_weights(model.SegFeatureExtract, deeplab_state_dict,
                          lambda name: name.replace("model", "Scale.layer4"))
# Freeze BatchNorm. Parameters whose name contains "bn" stop receiving gradients.
# eval() then puts the model in evaluation mode so BN layers use their running
# statistics. Note that eval() also affects any other train/eval-dependent
# layers, such as dropout.
for param_name, parameter in model.named_parameters():
    if 'bn' in param_name:
        parameter.requires_grad = False
model.eval()  # evaluation mode for BN layers
# SGD with two parameter groups. Parameters from get_1x_lr_params train at the
# base learning rate; those from get_10x_lr_params train at ten times that rate.
learning_rate = 2.5e-4
momentum = 0.9
weight_decay = 0.0005
param_groups = [
    {'params': get_1x_lr_params(model), 'lr': learning_rate},
    {'params': get_10x_lr_params(model), 'lr': 10 * learning_rate},
]
optimizer = optim.SGD(param_groups, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# Start from cleared gradients before the first backward pass.
optimizer.zero_grad()
# Toy inputs x_m (src_img_batch) and x_n (tar_img_batch),
# shape (batch, channel, height, width).
batch_size, img_h, img_w = 2, 256, 256
num_label_values = 5  # toy labels are drawn from {0, ..., 4}
src_img_batch = torch.rand((batch_size, 3, img_h, img_w)).cuda()
tar_img_batch = torch.rand((batch_size, 3, img_h, img_w)).cuda()
# Toy labels y_m (src_lbl_batch) and y_n (tar_lbl_batch), shape (batch, height, width),
# generated as integers and stored as float32 on the GPU.
label_shape = (batch_size, img_h, img_w)
src_lbl_batch = torch.randint(high=num_label_values, size=label_shape).cuda().to(dtype=torch.float32)
tar_lbl_batch = torch.randint(high=num_label_values, size=label_shape).cuda().to(dtype=torch.float32)
# One-hot version of the labels, shape (batch, channel, height, width).
# NOTE(review): confirm the channel count vl2ch produces matches SPNet's class count.
src_lbl_batch_resize = vl2ch(src_lbl_batch).cuda()
tar_lbl_batch_resize = vl2ch(tar_lbl_batch).cuda()
# Single training step. The forward pass returns segmentation maps from the
# segmentation branch (*_img_seg_lbl) and from the propagation branch (*_img_mat_lbl).
(src_img_seg_lbl, tar_img_seg_lbl,
 src_img_mat_lbl, tar_img_mat_lbl) = model(src_img_batch, tar_img_batch,
                                           src_lbl_batch_resize, tar_lbl_batch_resize)

# Class weights are recommended in practice; None keeps this example simple.
class_weight = None
# Supervised losses of the segmentation branch ...
src_seg_loss = loss_calc(src_img_seg_lbl, src_lbl_batch_resize, class_weight=class_weight)
tar_seg_loss = loss_calc(tar_img_seg_lbl, tar_lbl_batch_resize, class_weight=class_weight)
# ... and of the propagation branch, both against the one-hot labels.
src_mat_loss = loss_calc(src_img_mat_lbl, src_lbl_batch_resize, class_weight=class_weight)
tar_mat_loss = loss_calc(tar_img_mat_lbl, tar_lbl_batch_resize, class_weight=class_weight)
# Consistency losses between the two branches' predictions.
src_loss = dist_loss_calc(src_img_seg_lbl, src_img_mat_lbl, class_weight=class_weight)
tar_loss = dist_loss_calc(tar_img_seg_lbl, tar_img_mat_lbl, class_weight=class_weight)

# Total objective (summed left to right), then backpropagate and update.
loss = (src_seg_loss + tar_seg_loss
        + src_mat_loss + tar_mat_loss
        + src_loss + tar_loss)
loss.backward()
optimizer.step()