1
+ import sys
2
+ import os
3
+
4
+ sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' )))
5
+ ROOT_PATH = os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))
6
+
7
+ import threading
8
+ import time
9
+ import json
10
+ import numpy as np
11
+ import cv2
12
+ import random
13
+ import torch
14
+ from torch .utils .data import DataLoader
15
+ from torch .nn import DataParallel
16
+ from tqdm import tqdm
17
+ import torch .nn .functional as F
18
+ from torch .utils .tensorboard import SummaryWriter
19
+ import matplotlib .pyplot as plt
20
+ from multiprocessing import Process , Manager , Lock
21
+
22
+ from lib .options import parse_config
23
+ from lib .mesh_util import *
24
+ from lib .sample_util import *
25
+ from lib .train_util import *
26
+ from lib .data import *
27
+ from lib .model import *
28
+
29
# Parse the run configuration once at module import time; `train` below
# receives it explicitly so it can also be driven programmatically.
opt = parse_config()
31
+
32
def train(opt):
    """Train the DMC geometry network (netG) with normals from a frozen NormalNet.

    Per epoch: predict surface normals for each batch with ``netN`` (no grad),
    feed image + normals into ``netG``, backprop the summed loss, log to
    TensorBoard, and periodically checkpoint weights/optimizer state and dump a
    sample point cloud. After each epoch the learning rate is decayed, the
    dataset cache is cleared, and a new random set of camera yaws is drawn for
    the training set.

    Args:
        opt: parsed configuration namespace from ``lib.options.parse_config``.
    """
    # Seed every RNG from a single timestamp so numpy/random/torch stay in
    # sync (calling time.time() three times could straddle a second boundary).
    seed = int(time.time())
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    log = SummaryWriter(opt.log_path)
    total_iteration = 0

    # set cuda
    cuda = torch.device('cuda:%s' % opt.gpu_ids[0])
    netG = DMCNet(opt, projection_mode='perspective').to(cuda)
    netN = NormalNet().to(cuda)
    print('Using Network: ', netG.name, netN.name)

    gpu_ids = [int(i) for i in opt.gpu_ids.split(',')]
    netG = DataParallel(netG, device_ids=gpu_ids)
    netN = DataParallel(netN, device_ids=gpu_ids)

    optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.learning_rate)
    lr = opt.learning_rate

    def set_train():
        # Only netG is trained; netN stays frozen in eval mode (set below).
        netG.train()

    if opt.load_netG_checkpoint_path is not None:
        print('loading for net G ...', opt.load_netG_checkpoint_path)
        netG.load_state_dict(
            torch.load(opt.load_netG_checkpoint_path, map_location=cuda),
            strict=False)

    if opt.load_netN_checkpoint_path is not None:
        print('loading for net N ...', opt.load_netN_checkpoint_path)
        netN.load_state_dict(
            torch.load(opt.load_netN_checkpoint_path, map_location=cuda),
            strict=False)

    print("loaded finished!")

    # FIX: netN is only ever run for inference (under torch.no_grad) but was
    # left in train mode, so BatchNorm layers would use per-batch statistics
    # and mutate their running stats. Freeze it in eval mode.
    netN.eval()

    # Draw 30 random camera yaw angles (degrees) for this epoch.
    # NOTE(review): np.random.choice samples WITH replacement by default, so
    # duplicate yaws are possible — confirm whether replace=False was intended.
    yaw_list = sorted(np.random.choice(range(360), 30))
    print(yaw_list)

    train_dataset = DMCDataset(opt, cache_data=Manager().dict(),
                               cache_data_lock=Lock(), phase='train',
                               yaw_list=yaw_list)
    test_dataset = DMCDataset(opt, cache_data=Manager().dict(),
                              cache_data_lock=Lock(), phase='test',
                              yaw_list=yaw_list)

    projection_mode = train_dataset.projection_mode
    print('projection_mode:', projection_mode)

    # create data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=not opt.serial_batches,
                                   num_workers=opt.num_threads,
                                   pin_memory=opt.pin_memory)
    print('train data size: ', len(train_data_loader))

    # NOTE: batch size should be 1 and use all the points for evaluation
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=1, shuffle=False,
                                  num_workers=opt.num_threads,
                                  pin_memory=opt.pin_memory)
    print('test data size: ', len(test_data_loader))

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

    # Persist the full option set next to the results for reproducibility.
    opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
    with open(opt_log, 'w') as outfile:
        outfile.write(json.dumps(vars(opt), indent=2))

    # training
    start_epoch = 0
    print("start training......")

    for epoch in range(start_epoch, opt.num_epoch):
        epoch_start_time = time.time()
        set_train()
        iter_data_time = time.time()
        # Re-seed per epoch (single timestamp, same rationale as above).
        seed = int(time.time())
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        train_bar = tqdm(enumerate(train_data_loader))
        for train_idx, train_data in train_bar:
            total_iteration += 1
            iter_start_time = time.time()

            # Move every tensor in the batch dict onto the training device.
            for key in train_data:
                if torch.is_tensor(train_data[key]):
                    train_data[key] = train_data[key].to(device=cuda)

            # Predict normals with the frozen network; mask out background.
            with torch.no_grad():
                net_normal = netN(train_data['image'])
                net_normal = net_normal * train_data['mask']

            train_data['normal'] = net_normal.detach()
            # Call the module (not .forward) so DataParallel/hook machinery runs.
            res, error = netG(train_data)
            optimizerG.zero_grad()
            # DataParallel returns one loss per replica; reduce before backward.
            if len(gpu_ids) > 1:
                error = error.sum()
            error.backward()
            optimizerG.step()

            iter_net_time = time.time()
            # Estimated seconds remaining in this epoch, from the mean
            # per-iteration wall time so far.
            eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
                iter_net_time - epoch_start_time)

            log.add_scalar('loss', error.item() / len(gpu_ids), total_iteration)
            if train_idx % opt.freq_plot == 0:
                descrip = 'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
                    opt.name, epoch, train_idx, len(train_data_loader), error.item() / len(gpu_ids), lr, opt.sigma,
                    iter_start_time - iter_data_time,
                    iter_net_time - iter_start_time, int(eta // 60),
                    int(eta - 60 * (eta // 60)))
                train_bar.set_description(descrip)

            if train_idx % opt.freq_save == 0:
                torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
                torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
                torch.save(optimizerG.state_dict(), '%s/%s/optim_latest' % (opt.checkpoints_path, opt.name))
                torch.save(optimizerG.state_dict(), '%s/%s/optim_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))

            if train_idx % opt.freq_save_ply == 0:
                # Dump the first sample's predicted occupancies as a point cloud.
                save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
                r = res[0].cpu()
                points = train_data['samples'][0].transpose(0, 1).cpu()
                save_samples_truncted_prob(save_path, points.detach().numpy(), r.detach().numpy())

            iter_data_time = time.time()

        # update learning rate
        lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)
        train_dataset.clear_cache()

        # Resample camera yaws for the next epoch.
        # NOTE(review): only the train set gets the new yaws; test_dataset keeps
        # the initial list — confirm that a fixed eval view set is intended.
        yaw_list = sorted(np.random.choice(range(360), 30))
        train_dataset.yaw_list = yaw_list

    log.close()
159
+
160
+
161
if __name__ == '__main__':
    # Script entry point: run training with the module-level parsed options.
    train(opt)
0 commit comments