forked from valeoai/SP4ASC
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
111 lines (99 loc) · 3.07 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import torch
import tarfile
import argparse
from torch.utils.data import DataLoader
from sp4asc.datasets.dcase import DCaseDataset
from sp4asc.models import get_net
from sp4asc.models.cnns import LogMelSpectrogram
from sp4asc.training import TrainingManager
def make_tarfile(output_filename, source_dir):
    """Archive the top-level entries of ``source_dir`` into a gzipped tar.

    Used to snapshot the source code next to the training logs for
    reproducibility. Skips the ``data`` directory and ``*.egg-info``
    packaging artifacts. Also skips the output archive itself: the tar
    file is opened before the directory is listed, so without this guard
    an archive created inside ``source_dir`` would try to add the very
    file it is currently writing.

    :param output_filename: Path of the ``.tgz`` file to create.
    :param source_dir: Directory whose top-level entries are archived
        (paths are stored as given, i.e. prefixed with ``source_dir``).
    """
    out_path = os.path.abspath(output_filename)
    with tarfile.open(output_filename, "w:gz") as tar:
        for entry in os.listdir(source_dir):
            if entry == "data":
                continue  # dataset directory: huge and not source code
            if entry.split(".")[-1] == "egg-info":
                continue  # pip/setuptools build metadata
            full_path = os.path.join(source_dir, entry)
            if os.path.abspath(full_path) == out_path:
                continue  # never archive the tarball into itself
            tar.add(full_path)
if __name__ == "__main__":
    # Training entry point: parse args, load a python config module,
    # snapshot the source tree, build datasets/loaders/model/optimizer,
    # then hand everything to TrainingManager.

    # --- Args
    parser = argparse.ArgumentParser(description="Training with mixup")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/example.py",
        help="Path to config file describing training parameters",
    )
    args = parser.parse_args()

    # ---
    print("Training script: ", os.path.realpath(__file__))

    # --- Config
    # Turn the file path into a dotted module path (e.g. "configs/example.py"
    # -> "configs.example") and import its module-level `config` dict.
    # NOTE(review): assumes the config path is relative to the CWD and on
    # sys.path, and that it contains exactly one ".py" occurrence — confirm.
    name_config = args.config.replace(".py", "").replace(os.path.sep, ".")
    config = __import__(name_config, fromlist=["config"]).config
    print("Config parameters:")
    print(config)

    # --- Log dir
    # The dotted config name doubles as the run name under out_dir.
    # NOTE(review): plain string concatenation — presumably out_dir ends
    # with a path separator; verify against the config files.
    path2log = config["out_dir"] + name_config
    os.makedirs(path2log, exist_ok=True)
    # Snapshot the source directory alongside the logs for reproducibility.
    make_tarfile(path2log + "/src.tgz", os.path.dirname(os.path.realpath(__file__)))

    # --- Datasets (train/val splits of the same DCASE 2020 mobile dev set)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    train_dataset = DCaseDataset(
        current_dir + "/data/TAU-urban-acoustic-scenes-2020-mobile-development/",
        split="train",
    )
    test_dataset = DCaseDataset(
        current_dir + "/data/TAU-urban-acoustic-scenes-2020-mobile-development/",
        split="val",
    )
    # Train loader shuffles and drops the last partial batch (stable batch
    # size for mixup); the eval loader keeps every sample in order.
    loader_train = DataLoader(
        train_dataset,
        batch_size=config["batchsize"],
        shuffle=True,
        pin_memory=True,
        num_workers=config["num_workers"],
        drop_last=True,
    )
    loader_test = DataLoader(
        test_dataset,
        batch_size=config["batchsize"],
        shuffle=False,
        pin_memory=True,
        num_workers=config["num_workers"],
        drop_last=False,
    )

    # --- Get network
    # Log-mel front-end is kept separate from the classifier network;
    # both are passed to the training manager.
    spectrogram = LogMelSpectrogram()
    # get_net maps a config name to a network constructor taking
    # (dropout, specAugment) — see sp4asc.models.
    net = get_net[config["net"]](
        config["dropout"],
        config["specAugment"],
    )
    print("\n\nNet at training time")
    print(net)
    print("Nb. of parameters at training time: ", net.get_nb_parameters() / 1e3, "k")

    # --- Optimizer and LR schedule
    optim = torch.optim.AdamW(
        [
            {"params": net.parameters()},
        ],
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )
    # Cosine decay over the full training run, annealing lr to eta_min.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim,
        config["max_epoch"],
        eta_min=config["eta_min"],
    )

    # --- Training
    mng = TrainingManager(
        net,
        spectrogram,
        loader_train,
        loader_test,
        optim,
        scheduler,
        config,
        path2log,
    )
    mng.train()