forked from sczhou/ProPainter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
105 lines (82 loc) · 3.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import json
import argparse
import subprocess
from shutil import copyfile
import torch.distributed as dist
import torch
import torch.multiprocessing as mp
import core
import core.trainer
import core.trainer_flow_w_edge
# import warnings
# warnings.filterwarnings("ignore")
from core.dist import (
get_world_size,
get_local_rank,
get_global_rank,
get_master_ip,
)
# Command-line interface for the training launcher.
#   -c/--config : path to the JSON training configuration file.
#   -p/--port   : TCP port placed into the distributed ``init_method`` URL.
parser = argparse.ArgumentParser(description='E2FGVI')
parser.add_argument('-c', '--config', type=str,
                    default='configs/train_e2fgvi.json')
parser.add_argument('-p', '--port', type=str, default='23490')

# Parsed at import time on purpose: ``main_worker`` reads the module-level
# ``args`` (for the config path) rather than receiving it as a parameter.
args = parser.parse_args()
def _run_name(net, config_file):
    """Return the per-run folder name ``<net>_<config file stem>``.

    Only the final extension is stripped, so config files whose names contain
    extra dots (e.g. ``exp.v1.json`` vs ``exp.v2.json``) map to distinct run
    folders instead of both collapsing to ``<net>_exp``.
    """
    stem = os.path.splitext(os.path.basename(config_file))[0]
    return '{}_{}'.format(net, stem)


def main_worker(rank, config):
    """Prepare one training process and run the trainer named in ``config``.

    Args:
        rank: Process index (supplied by ``mp.spawn``, or passed directly when
            running in a single process). Used as both the local and global
            rank unless ``config`` already contains ``'local_rank'``.
        config: Parsed training config dict. It is mutated in place (rank
            keys, ``save_dir``, ``save_metric_dir``, ``device``) and then
            passed to the trainer's constructor.
    """
    import importlib  # local import: file-level imports are left untouched

    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    if config['distributed']:
        local_rank = int(config['local_rank'])
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=config['init_method'],
            world_size=config['world_size'],
            rank=config['global_rank'],
            group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(config['global_rank']), local_rank))

    # Checkpoints go under <save_dir>/<run>, metrics under ./scores/<run>.
    run_name = _run_name(config['model']['net'], args.config)
    config['save_dir'] = os.path.join(config['save_dir'], run_name)
    config['save_metric_dir'] = os.path.join('./scores', run_name)

    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only the sole process (non-distributed) or global rank 0 creates the run
    # folder and snapshots the config file into it.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(config['save_dir'],
                                   os.path.basename(args.config))
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    # config['trainer']['version'] names a submodule of ``core`` (such as
    # 'trainer' or 'trainer_flow_w_edge') that defines a ``Trainer`` class.
    # import_module reuses already-imported modules and raises a clear
    # ImportError for an unknown version.
    trainer_module = importlib.import_module(
        'core.{}'.format(config['trainer']['version']))
    trainer = trainer_module.Trainer(config)
    trainer.train()
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    mp.set_sharing_strategy('file_system')

    # Load the JSON training config; the context manager closes the file.
    with open(args.config, encoding='utf-8') as config_file:
        config = json.load(config_file)

    # World size = number of GPUs visible on this machine.
    config['world_size'] = torch.cuda.device_count()
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    config['distributed'] = config['world_size'] > 1
    print('world_size:', config['world_size'])

    # NOTE: an OpenMPI launch path (ranks from get_local_rank() /
    # get_global_rank(), then main_worker(-1, config)) is currently disabled;
    # only single-machine launching is supported here.
    if config['distributed']:
        # One spawned worker per GPU; each receives its rank as first argument.
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))
    else:
        # 0 or 1 GPU: run in this process. Spawning device_count() workers
        # would start zero processes on a CPU-only machine, silently skipping
        # training and leaving main_worker's CPU fallback unreachable.
        main_worker(0, config)