Skip to content

Commit 5208bca

Browse files
committed
first commit
1 parent f3972ab commit 5208bca

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

55 files changed

+827864
-0
lines changed

README.md

+19
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@ This repository contains the official pytorch implementation of ”*DeepMultiCap
77

88
![Teaser Image](assets/teaser.jpg)
99

10+
### News
11+
* **[2021/9/26]** We added more scans to [MultiHuman dataset](https://github.com/y-zheng18/MultiHuman-Dataset). You can use MultiHuman to train/fine-tune our model or your own models!
12+
* **[2021/9/18]** [MultiHuman dataset](https://github.com/y-zheng18/MultiHuman-Dataset) for evaluation purpose is available!
13+
14+
## Requirements
15+
- [PyTorch](https://pytorch.org/)
16+
- torchvision
17+
- trimesh
18+
- numpy
19+
- matplotlib
20+
- PIL
21+
- skimage
22+
- tqdm
23+
- cv2
24+
- json
25+
- taichi
26+
- taichi_three
27+
- taichi_glsl
28+
1029

1130
## Citation
1231
```

apps/train_dmc.py

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5+
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6+
7+
import threading
8+
import time
9+
import json
10+
import numpy as np
11+
import cv2
12+
import random
13+
import torch
14+
from torch.utils.data import DataLoader
15+
from torch.nn import DataParallel
16+
from tqdm import tqdm
17+
import torch.nn.functional as F
18+
from torch.utils.tensorboard import SummaryWriter
19+
import matplotlib.pyplot as plt
20+
from multiprocessing import Process, Manager, Lock
21+
22+
from lib.options import parse_config
23+
from lib.mesh_util import *
24+
from lib.sample_util import *
25+
from lib.train_util import *
26+
from lib.data import *
27+
from lib.model import *
28+
29+
# get options
30+
opt = parse_config()
31+
32+
def train(opt):
    """Train DMCNet (geometry network) with normals predicted by a frozen NormalNet.

    Args:
        opt: parsed configuration namespace from lib.options.parse_config().
             Fields read here include gpu_ids, learning_rate, batch_size,
             num_threads, pin_memory, checkpoints_path, results_path, name,
             num_epoch, freq_plot, freq_save, freq_save_ply, sigma, schedule,
             gamma, serial_batches, load_netG_checkpoint_path,
             load_netN_checkpoint_path, log_path.

    Side effects: writes TensorBoard logs, checkpoint files, an opt.txt dump
    of the options, and sample .ply point clouds under the configured paths.
    """
    # Seed every RNG from wall-clock time: runs are deliberately
    # NON-reproducible (fresh augmentation / yaw sampling each run).
    np.random.seed(int(time.time()))
    random.seed(int(time.time()))
    torch.manual_seed(int(time.time()))
    # set cuda
    log = SummaryWriter(opt.log_path)
    total_iteration = 0
    # Primary device is the first id in opt.gpu_ids (a comma-separated string).
    cuda = torch.device('cuda:%s' % opt.gpu_ids[0])
    netG = DMCNet(opt, projection_mode='perspective').to(cuda)
    netN = NormalNet().to(cuda)
    print('Using Network: ', netG.name, netN.name)
    gpu_ids = [int(i) for i in opt.gpu_ids.split(',')]
    netG = DataParallel(netG, device_ids=gpu_ids)
    netN = DataParallel(netN, device_ids=gpu_ids)

    optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.learning_rate)
    lr = opt.learning_rate

    def set_train():
        # Only netG is trained; netN stays a fixed normal predictor.
        # NOTE(review): netN is never switched to eval() — if NormalNet
        # contains BatchNorm/Dropout it runs in train mode under no_grad
        # below; confirm this is intended.
        netG.train()

    # Checkpoints are loaded AFTER DataParallel wrapping, so the state dicts
    # are expected to carry the 'module.' prefix (strict=False tolerates
    # partial mismatches silently).
    if opt.load_netG_checkpoint_path is not None:
        print('loading for net G ...', opt.load_netG_checkpoint_path)
        netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda), strict=False)

    if opt.load_netN_checkpoint_path is not None:
        print('loading for net N ...', opt.load_netN_checkpoint_path)
        netN.load_state_dict(torch.load(opt.load_netN_checkpoint_path, map_location=cuda), strict=False)

    print("loaded finished!")

    # Random subset of 30 yaw angles (out of 360) used by the dataset;
    # re-sampled at the end of every epoch.
    yaw_list = sorted(np.random.choice(range(360), 30))
    print(yaw_list)

    # Shared Manager dict + Lock let DataLoader worker processes share a cache.
    train_dataset = DMCDataset(opt, cache_data=Manager().dict(), cache_data_lock=Lock(), phase='train', yaw_list=yaw_list)
    test_dataset = DMCDataset(opt, cache_data=Manager().dict(), cache_data_lock=Lock(), phase='test', yaw_list=yaw_list)

    projection_mode = train_dataset.projection_mode
    print('projection_mode:', projection_mode)
    # create data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size, shuffle=not opt.serial_batches,
                                   num_workers=opt.num_threads, pin_memory=opt.pin_memory)
    print('train data size: ', len(train_data_loader))

    # NOTE: batch size should be 1 and use all the points for evaluation
    # NOTE(review): test_data_loader is built but never iterated in this
    # function — evaluation presumably happens elsewhere; confirm.
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=1, shuffle=False,
                                  num_workers=opt.num_threads, pin_memory=opt.pin_memory)
    print('test data size: ', len(test_data_loader))

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

    # Persist the full option set next to the results for reproducibility.
    opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
    with open(opt_log, 'w') as outfile:
        outfile.write(json.dumps(vars(opt), indent=2))

    # training
    start_epoch = 0
    print("start training......")

    for epoch in range(start_epoch, opt.num_epoch):
        epoch_start_time = time.time()
        set_train()
        iter_data_time = time.time()
        # Re-seed per epoch so worker shuffling/augmentation differs each epoch.
        np.random.seed(int(time.time()))
        random.seed(int(time.time()))
        torch.manual_seed(int(time.time()))
        train_bar = tqdm(enumerate(train_data_loader))
        for train_idx, train_data in train_bar:
            total_iteration += 1
            iter_start_time = time.time()
            # retrieve the data
            for key in train_data:
                if torch.is_tensor(train_data[key]):
                    train_data[key] = train_data[key].to(device=cuda)

            # predict normal
            with torch.no_grad():
                net_normal = netN.forward(train_data['image'])
                # Zero out normals outside the subject silhouette.
                net_normal = net_normal * train_data['mask']

            train_data['normal'] = net_normal.detach()
            res, error = netG.forward(train_data)
            optimizerG.zero_grad()
            # With DataParallel over >1 GPU the per-replica losses come back
            # as a vector; sum them to a scalar before backward.
            if len(gpu_ids) > 1:
                error = error.sum()
            error.backward()
            optimizerG.step()

            iter_net_time = time.time()
            # ETA in seconds: average time per finished iteration extrapolated
            # to the remaining iterations of this epoch.
            eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
                    iter_net_time - epoch_start_time)

            # Divide by GPU count so the logged loss is the per-replica mean.
            log.add_scalar('loss', error.item() / len(gpu_ids), total_iteration)
            if train_idx % opt.freq_plot == 0:
                descrip = 'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
                    opt.name, epoch, train_idx, len(train_data_loader), error.item() / len(gpu_ids), lr, opt.sigma,
                    iter_start_time - iter_data_time,
                    iter_net_time - iter_start_time, int(eta // 60),
                    int(eta - 60 * (eta // 60)))
                train_bar.set_description(descrip)

            # Checkpoint both 'latest' and a per-epoch snapshot (the per-epoch
            # file is overwritten repeatedly within the same epoch).
            if train_idx % opt.freq_save == 0:
                torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
                torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
                torch.save(optimizerG.state_dict(), '%s/%s/optim_latest' % (opt.checkpoints_path, opt.name))
                torch.save(optimizerG.state_dict(), '%s/%s/optim_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))

            # Dump the first sample's query points + predictions as a .ply
            # for visual inspection (file is overwritten each time).
            if train_idx % opt.freq_save_ply == 0:
                save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
                r = res[0].cpu()
                points = train_data['samples'][0].transpose(0, 1).cpu()
                save_samples_truncted_prob(save_path, points.detach().numpy(), r.detach().numpy())

            iter_data_time = time.time()

        # update learning rate
        lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)
        train_dataset.clear_cache()

        # Re-draw the yaw subset for the next epoch.
        yaw_list = sorted(np.random.choice(range(360), 30))
        train_dataset.yaw_list = yaw_list
    log.close()
# Script entry point: run training with the module-level parsed options.
if __name__ == '__main__':
    train(opt)

apps/train_normal_net.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5+
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6+
7+
import threading
8+
import time
9+
import json
10+
import numpy as np
11+
import cv2
12+
import random
13+
import torch
14+
from torch.utils.data import DataLoader
15+
from torch.nn import DataParallel
16+
from tqdm import tqdm
17+
from torch.utils.tensorboard import SummaryWriter
18+
import torch.nn.functional as F
19+
import matplotlib.pyplot as plt
20+
21+
from lib.mesh_util import *
22+
from lib.sample_util import *
23+
from lib.train_util import *
24+
from lib.data import *
25+
from lib.model import *
26+
from lib.geometry import index
27+
from lib.loss_util import VGGPerceptualLoss
28+
from lib.options import parse_config
29+
30+
# get options
31+
opt = parse_config()
32+
log = SummaryWriter(opt.log_path)
33+
34+
def train(opt):
    """Train NormalNet to predict normal maps from images.

    Loss is 5 * L1(normal_gt, pred) + VGG perceptual loss.

    Args:
        opt: parsed configuration namespace from lib.options.parse_config().
             Fields read here include gpu_ids, batch_size, serial_batches,
             num_threads, pin_memory, load_netN_checkpoint_path,
             learning_rate, checkpoints_path, results_path, name, freq_plot,
             freq_save, freq_normal_show, sigma.

    Side effects: writes TensorBoard scalars to the module-level `log`,
    checkpoint files, and debug comparison images under the result paths.
    """

    gpu_ids = [int(i) for i in opt.gpu_ids.split(',')]
    # Primary device is the first configured GPU id.
    cuda = torch.device("cuda:%d" % gpu_ids[0])
    netN = NormalNet()
    dataset = NormalDataset(opt)
    train_data_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches,
                                   num_workers=opt.num_threads, pin_memory=opt.pin_memory)

    netN.to(cuda)
    netN = DataParallel(netN, device_ids=gpu_ids)

    # NOTE(review): torch.load here has no map_location — loading a
    # checkpoint saved on a different device may fail; the sibling
    # train_dmc.py passes map_location=cuda. Confirm.
    if opt.load_netN_checkpoint_path is not None:
        netN.load_state_dict(torch.load(opt.load_netN_checkpoint_path), strict=False)

    lr = opt.learning_rate
    optimizerN = torch.optim.Adam(netN.parameters(), lr=lr)

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

    # NOTE(review): epoch count is hard-coded at 100 here while
    # train_dmc.py uses opt.num_epoch — confirm this divergence is intended.
    EPOCH = 100
    total_iteration = 0

    perceptual_loss = VGGPerceptualLoss().to(cuda)
    for epoch in range(EPOCH):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        train_bar = tqdm(enumerate(train_data_loader))
        for train_idx, train_data in train_bar:

            total_iteration += 1
            iter_start_time = time.time()

            # retrieve the data
            image_tensor = train_data['img'].to(device=cuda)
            normal_tensor = train_data['normal'].to(device=cuda)
            mask_tensor = train_data['mask'].to(device=cuda)

            res = netN.forward(image_tensor)
            # Zero out predictions outside the silhouette before the loss.
            res = res * mask_tensor
            error = F.l1_loss(normal_tensor, res)
            # squeeze(1) drops what is presumably a singleton view dimension
            # (batch, 1, C, H, W) before the VGG loss — TODO confirm shape.
            perceptual_error = perceptual_loss(normal_tensor.squeeze(1), res.squeeze(1))

            # Weighted combination: L1 dominates by a factor of 5.
            error = 5*error + perceptual_error

            optimizerN.zero_grad()
            error.backward()
            optimizerN.step()

            iter_net_time = time.time()
            # ETA in seconds: mean per-iteration time extrapolated over the
            # remaining iterations of this epoch.
            eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
                    iter_net_time - epoch_start_time)

            log.add_scalar('loss', error.item(), total_iteration)
            if train_idx % opt.freq_plot == 0:
                descrip = 'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
                    opt.name, epoch, train_idx, len(train_data_loader), error.item(), lr, opt.sigma,
                    iter_start_time - iter_data_time,
                    iter_net_time - iter_start_time, int(eta // 60),
                    int(eta - 60 * (eta // 60)))
                train_bar.set_description(descrip)

            # NOTE(review): optimizer state is never checkpointed here,
            # unlike train_dmc.py — resuming loses Adam moments; confirm.
            if train_idx % opt.freq_save == 0 and train_idx != 0:
                torch.save(netN.state_dict(), '%s/%s/netN_latest' % (opt.checkpoints_path, opt.name))
                torch.save(netN.state_dict(), '%s/%s/netN_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))

            # Periodic side-by-side visualization: input | prediction | GT.
            # [0][0] indexes first batch item, then first view — the +0.5
            # shift maps values centered on 0 back into displayable range.
            if train_idx % opt.freq_normal_show == 0:
                show_img = (image_tensor[0][0].cpu().detach().permute(1, 2, 0) + 0.5).numpy()
                net_normal = (res[0][0].cpu().detach().permute(1, 2, 0) + 0.5).numpy()
                gt_normal = (normal_tensor[0][0].cpu().detach().permute(1, 2, 0) + 0.5).numpy()
                plt.subplot(131)
                plt.imshow(show_img)
                plt.subplot(132)
                plt.imshow(net_normal)
                plt.subplot(133)
                plt.imshow(gt_normal)
                plt.savefig('%s/%s/epoch%03d_%05d.jpg' % (opt.results_path, opt.name, epoch, train_idx))
                plt.close('all')

            iter_data_time = time.time()
if __name__ == "__main__":
119+
train(opt)

0 commit comments

Comments
 (0)