forked from yaohungt/Multimodal-Transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
152 lines (128 loc) · 5.99 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import argparse
from src.utils import *
from torch.utils.data import DataLoader
from src import train
# Command-line interface. The parsed namespace is stored in `args`.
parser = argparse.ArgumentParser(description='MOSEI Sentiment Analysis')
# NOTE(review): `-f` is never read in this file; presumably it only exists to
# absorb the `-f <connection file>` argument a Jupyter kernel passes — confirm.
parser.add_argument('-f', default='', type=str)

# Fixed
parser.add_argument('--model', type=str, default='MulT',
                    help='name of the model to use (Transformer, etc.)')

# Tasks
# At most one of --vonly/--aonly/--lonly may be given (validated right after
# parsing); giving none of them turns all three on.
parser.add_argument('--vonly', action='store_true',
                    help='use the crossmodal fusion into v (default: False)')
parser.add_argument('--aonly', action='store_true',
                    help='use the crossmodal fusion into a (default: False)')
parser.add_argument('--lonly', action='store_true',
                    help='use the crossmodal fusion into l (default: False)')
parser.add_argument('--aligned', action='store_true',
                    help='consider aligned experiment or not (default: False)')
parser.add_argument('--dataset', type=str, default='mosei_senti',
                    help='dataset to use (default: mosei_senti)')
parser.add_argument('--data_path', type=str, default='data',
                    help='path for storing the dataset')

# Dropouts
parser.add_argument('--attn_dropout', type=float, default=0.1,
                    help='attention dropout')
parser.add_argument('--attn_dropout_a', type=float, default=0.0,
                    help='attention dropout (for audio)')
parser.add_argument('--attn_dropout_v', type=float, default=0.0,
                    help='attention dropout (for visual)')
parser.add_argument('--relu_dropout', type=float, default=0.1,
                    help='relu dropout')
parser.add_argument('--embed_dropout', type=float, default=0.25,
                    help='embedding dropout')
parser.add_argument('--res_dropout', type=float, default=0.1,
                    help='residual block dropout')
parser.add_argument('--out_dropout', type=float, default=0.0,
                    help='output layer dropout')

# Architecture
parser.add_argument('--nlevels', type=int, default=5,
                    help='number of layers in the network (default: 5)')
parser.add_argument('--num_heads', type=int, default=5,
                    help='number of heads for the transformer network (default: 5)')
# `store_false`: args.attn_mask is True unless the flag is passed, in which
# case it becomes False -- i.e. passing --attn_mask *disables* the mask.
parser.add_argument('--attn_mask', action='store_false',
                    help='disable the attention mask for Transformer '
                         '(default: mask enabled)')

# Tuning
parser.add_argument('--batch_size', type=int, default=24, metavar='N',
                    help='batch size (default: 24)')
parser.add_argument('--clip', type=float, default=0.8,
                    help='gradient clip value (default: 0.8)')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='initial learning rate (default: 1e-3)')
parser.add_argument('--optim', type=str, default='Adam',
                    help='optimizer to use (default: Adam)')
parser.add_argument('--num_epochs', type=int, default=40,
                    help='number of epochs (default: 40)')
parser.add_argument('--when', type=int, default=20,
                    help='when to decay learning rate (default: 20)')
parser.add_argument('--batch_chunk', type=int, default=1,
                    help='number of chunks per batch (default: 1)')

# Logistics
parser.add_argument('--log_interval', type=int, default=30,
                    help='frequency of result logging (default: 30)')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--no_cuda', action='store_true',
                    help='do not use cuda')
parser.add_argument('--name', type=str, default='mult',
                    help='name of the trial (default: "mult")')

# Reads sys.argv at import time.
args = parser.parse_args()
# Seed the CPU RNG before anything else draws random numbers.
torch.manual_seed(args.seed)

# Normalised dataset key, used for every table lookup below.
dataset = args.dataset.strip().lower()

# Crossmodal fusion targets: at most one of --lonly/--vonly/--aonly may be
# chosen; choosing none of them enables all three.
valid_partial_mode = sum((args.lonly, args.vonly, args.aonly))
if valid_partial_mode > 1:
    raise ValueError("You can only choose one of {l/v/a}only.")
if not valid_partial_mode:
    args.lonly = args.vonly = args.aonly = True

# Per-dataset output width; datasets not listed here fall back to 1 where
# these tables are consulted.
output_dim_dict = {
    'mosi': 1,
    'mosei_senti': 1,
    'iemocap': 8,
}
# Per-dataset loss name; datasets not listed here fall back to a default
# where this table is consulted.
criterion_dict = {
    'iemocap': 'CrossEntropyLoss',
}

# Device selection. CPU float tensors are the default unless a CUDA device
# exists and --no_cuda was not given.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (set_default_dtype / set_default_device replace it) — revisit when
# upgrading.
torch.set_default_tensor_type('torch.FloatTensor')
cuda_present = torch.cuda.is_available()
if cuda_present and args.no_cuda:
    print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda")
use_cuda = cuda_present and not args.no_cuda
if use_cuda:
    torch.cuda.manual_seed(args.seed)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
####################################################################
#
# Load the dataset (aligned or non-aligned)
#
####################################################################

print("Start loading the data....")

# NOTE(review): get_data is presumably provided by the `src.utils` star
# import; it is called once per split with the parsed args and dataset key.
train_data = get_data(args, dataset, 'train')
valid_data = get_data(args, dataset, 'valid')
test_data = get_data(args, dataset, 'test')

# Only the training split is shuffled. Shuffling the validation/test loaders
# makes evaluation batch order non-deterministic, and each pass through a
# shuffling loader draws a seed from the global torch RNG, perturbing the
# training randomness that --seed is meant to fix.
# NOTE(review): assumes train.initiate aggregates metrics over the whole
# split, so evaluation order does not matter — confirm.
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

print('Finish loading the data....')
if not args.aligned:
    print("### Note: You are running in unaligned mode.")
####################################################################
#
# Hyperparameters
#
####################################################################

# The parsed namespace doubles as the hyper-parameter container: `hyp_params`
# is the same object as `args`, so every attribute set here also appears on
# `args` (options such as --when and --batch_chunk are already present).
hyp_params = args

# Input feature sizes and sequence lengths, as reported by the training split.
(hyp_params.orig_d_l,
 hyp_params.orig_d_a,
 hyp_params.orig_d_v) = train_data.get_dim()
(hyp_params.l_len,
 hyp_params.a_len,
 hyp_params.v_len) = train_data.get_seq_len()

hyp_params.layers = args.nlevels
hyp_params.use_cuda = use_cuda
hyp_params.dataset = dataset

# Split sizes.
hyp_params.n_train = len(train_data)
hyp_params.n_valid = len(valid_data)
hyp_params.n_test = len(test_data)

# Model name is normalised to upper case.
hyp_params.model = args.model.strip().upper()

# Unknown datasets fall back to a single output and an L1 loss.
hyp_params.output_dim = output_dim_dict.get(dataset, 1)
hyp_params.criterion = criterion_dict.get(dataset, 'L1Loss')
if __name__ == '__main__':
    # Hand the hyper-parameters plus the train/valid/test loaders to the
    # training entry point; its return value is kept as `test_loss`.
    split_loaders = (train_loader, valid_loader, test_loader)
    test_loss = train.initiate(hyp_params, *split_loaders)