1+ import sys
2+ import os
3+
# Make the repository root importable when this script is run directly
# (it lives one level below the root, next to the lib/ package).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Absolute path to the repository root (parent of this script's directory).
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6+
7+ import threading
8+ import time
9+ import json
10+ import numpy as np
11+ import cv2
12+ import random
13+ import torch
14+ from torch .utils .data import DataLoader
15+ from torch .nn import DataParallel
16+ from tqdm import tqdm
17+ import torch .nn .functional as F
18+ from torch .utils .tensorboard import SummaryWriter
19+ import matplotlib .pyplot as plt
20+ from multiprocessing import Process , Manager , Lock
21+
22+ from lib .options import parse_config
23+ from lib .mesh_util import *
24+ from lib .sample_util import *
25+ from lib .train_util import *
26+ from lib .data import *
27+ from lib .model import *
28+
# get options
# Parsed once at import time (module-level side effect); passed to train() below.
opt = parse_config()
31+
def train(opt):
    """Train DMCNet using surface normals predicted by a frozen NormalNet.

    Builds the networks, datasets and loaders from ``opt``, then runs the
    epoch/iteration loop: predict normals under ``no_grad``, forward/backward
    DMCNet, log the loss to TensorBoard, and periodically save checkpoints
    and a sample point cloud.

    Side effects: creates ``opt.checkpoints_path/opt.name`` and
    ``opt.results_path/opt.name`` directories, writes ``opt.txt`` with the
    full option set, and writes TensorBoard events to ``opt.log_path``.

    Args:
        opt: parsed options namespace from ``parse_config()``.
    """
    # Seed every RNG from wall-clock time: runs are deliberately
    # non-reproducible (re-seeded again at the start of each epoch below).
    np.random.seed(int(time.time()))
    random.seed(int(time.time()))
    torch.manual_seed(int(time.time()))
    # set cuda
    log = SummaryWriter(opt.log_path)
    total_iteration = 0
    # opt.gpu_ids is used both as an indexable string (first char = primary
    # device) and split on ',' below — presumably e.g. "0,1"; single-digit
    # ids only for the primary device. TODO confirm.
    cuda = torch.device('cuda:%s' % opt.gpu_ids[0])
    netG = DMCNet(opt, projection_mode='perspective').to(cuda)
    netN = NormalNet().to(cuda)
    print('Using Network: ', netG.name, netN.name)
    gpu_ids = [int(i) for i in opt.gpu_ids.split(',')]
    netG = DataParallel(netG, device_ids=gpu_ids)
    netN = DataParallel(netN, device_ids=gpu_ids)

    # Only netG is optimized; netN is inference-only (see no_grad block below).
    optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.learning_rate)
    lr = opt.learning_rate

    def set_train():
        # Put the trainable network in train mode.
        # NOTE(review): netN is never switched to eval(); if NormalNet
        # contains dropout/batch-norm, its predictions depend on the mode
        # it happens to be in — confirm this is intended.
        netG.train()

    # strict=False tolerates missing/unexpected keys (e.g. the "module."
    # prefix of DataParallel checkpoints) — mismatches are silently ignored.
    if opt.load_netG_checkpoint_path is not None:
        print('loading for net G ...', opt.load_netG_checkpoint_path)
        netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda), strict=False)

    if opt.load_netN_checkpoint_path is not None:
        print('loading for net N ...', opt.load_netN_checkpoint_path)
        netN.load_state_dict(torch.load(opt.load_netN_checkpoint_path, map_location=cuda), strict=False)

    print("loaded finished!")

    # Sample 30 yaw angles (degrees, 0..359) for this epoch's views.
    # NOTE(review): np.random.choice defaults to replace=True, so angles
    # may repeat — confirm whether duplicates are acceptable.
    yaw_list = sorted(np.random.choice(range(360), 30))
    print(yaw_list)

    # A shared Manager().dict() plus a Lock lets DataLoader worker
    # processes share one sample cache per dataset.
    train_dataset = DMCDataset(opt, cache_data=Manager().dict(), cache_data_lock=Lock(), phase='train', yaw_list=yaw_list)
    test_dataset = DMCDataset(opt, cache_data=Manager().dict(), cache_data_lock=Lock(), phase='test', yaw_list=yaw_list)

    projection_mode = train_dataset.projection_mode
    print('projection_mode:', projection_mode)
    # create data loader
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=opt.batch_size, shuffle=not opt.serial_batches,
                                   num_workers=opt.num_threads, pin_memory=opt.pin_memory)
    print('train data size: ', len(train_data_loader))

    # NOTE: batch size should be 1 and use all the points for evaluation
    # NOTE(review): test_data_loader is built but never consumed in this
    # function — evaluation presumably happens elsewhere.
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=1, shuffle=False,
                                  num_workers=opt.num_threads, pin_memory=opt.pin_memory)
    print('test data size: ', len(test_data_loader))

    os.makedirs(opt.checkpoints_path, exist_ok=True)
    os.makedirs(opt.results_path, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_path, opt.name), exist_ok=True)
    os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

    # Dump the full option set next to the results for later inspection.
    opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
    with open(opt_log, 'w') as outfile:
        outfile.write(json.dumps(vars(opt), indent=2))

    # training
    start_epoch = 0
    print("start training......")

    for epoch in range(start_epoch, opt.num_epoch):
        epoch_start_time = time.time()
        set_train()
        iter_data_time = time.time()
        # Re-seed per epoch so augmentation/sampling differs across epochs.
        np.random.seed(int(time.time()))
        random.seed(int(time.time()))
        torch.manual_seed(int(time.time()))
        train_bar = tqdm(enumerate(train_data_loader))
        for train_idx, train_data in train_bar:
            total_iteration += 1
            iter_start_time = time.time()
            # retrieve the data: move every tensor entry of the batch dict
            # onto the primary CUDA device.
            for key in train_data:
                if torch.is_tensor(train_data[key]):
                    train_data[key] = train_data[key].to(device=cuda)

            # predict normal (frozen netN, masked to the foreground).
            with torch.no_grad():
                net_normal = netN.forward(train_data['image'])
                net_normal = net_normal * train_data['mask']

            train_data['normal'] = net_normal.detach()
            res, error = netG.forward(train_data)
            optimizerG.zero_grad()
            # With multi-GPU DataParallel the loss comes back one entry per
            # device; reduce to a scalar before backward.
            if len(gpu_ids) > 1:
                error = error.sum()
            error.backward()
            optimizerG.step()

            iter_net_time = time.time()
            # Rough ETA (seconds) for the rest of this epoch, from the
            # average per-iteration time so far.
            eta = ((iter_net_time - epoch_start_time) / (train_idx + 1)) * len(train_data_loader) - (
                    iter_net_time - epoch_start_time)

            # Log the per-device average loss.
            log.add_scalar('loss', error.item() / len(gpu_ids), total_iteration)
            if train_idx % opt.freq_plot == 0:
                descrip = 'Name: {0} | Epoch: {1} | {2}/{3} | Err: {4:.06f} | LR: {5:.06f} | Sigma: {6:.02f} | dataT: {7:.05f} | netT: {8:.05f} | ETA: {9:02d}:{10:02d}'.format(
                    opt.name, epoch, train_idx, len(train_data_loader), error.item() / len(gpu_ids), lr, opt.sigma,
                    iter_start_time - iter_data_time,
                    iter_net_time - iter_start_time, int(eta // 60),
                    int(eta - 60 * (eta // 60)))
                train_bar.set_description(descrip)

            if train_idx % opt.freq_save == 0:
                # NOTE(review): this saves the DataParallel-wrapped state
                # dict ("module."-prefixed keys); loading relies on the
                # strict=False load above.
                torch.save(netG.state_dict(), '%s/%s/netG_latest' % (opt.checkpoints_path, opt.name))
                torch.save(netG.state_dict(), '%s/%s/netG_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))
                torch.save(optimizerG.state_dict(), '%s/%s/optim_latest' % (opt.checkpoints_path, opt.name))
                torch.save(optimizerG.state_dict(), '%s/%s/optim_epoch_%d' % (opt.checkpoints_path, opt.name, epoch))

            if train_idx % opt.freq_save_ply == 0:
                # Dump the first batch element's sample points and predicted
                # values as a .ply for quick visual inspection (overwritten
                # each time).
                save_path = '%s/%s/pred.ply' % (opt.results_path, opt.name)
                r = res[0].cpu()
                points = train_data['samples'][0].transpose(0, 1).cpu()
                save_samples_truncted_prob(save_path, points.detach().numpy(), r.detach().numpy())

            iter_data_time = time.time()

        # update learning rate
        lr = adjust_learning_rate(optimizerG, epoch, lr, opt.schedule, opt.gamma)
        # Drop the shared sample cache so the next epoch refills it.
        train_dataset.clear_cache()

        # Resample the view angles for the next epoch.
        # NOTE(review): only train_dataset gets the new list; test_dataset
        # keeps the yaw list chosen at startup — confirm intended.
        yaw_list = sorted(np.random.choice(range(360), 30))
        train_dataset.yaw_list = yaw_list
    log.close()
159+
160+
if __name__ == '__main__':
    # Script entry point: train with the module-level parsed options.
    train(opt)