Commit 17b88a3

[Enhance] Support the Training of ActionClip (#2620)
1 parent baf385e commit 17b88a3

File tree

4 files changed: +437 -26 lines changed

projects/actionclip/README.md
31 additions, 3 deletions

@@ -46,24 +46,45 @@ Create a symbolic link from `$MMACTION2/data` to `./data` in the current directory

ln -s ../../data ./data
```

+### Training commands
+
+**To train with a single GPU:**
+
+```bash
+mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py
+```
+
+**To train with multiple GPUs:**
+
+```bash
+mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --launcher pytorch --gpus 8
+```
+
+**To train with multiple GPUs via Slurm:**
+
+```bash
+mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --launcher slurm \
+--gpus 8 --gpus-per-node 8 --partition $PARTITION
+```
+
### Testing commands

**To test with a single GPU:**

```bash
-mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT
+mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT
```

**To test with multiple GPUs:**

```bash
-mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
+mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
```

**To test with multiple GPUs via Slurm:**

```bash
-mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \
+mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \
--gpus 8 --gpus-per-node 8 --partition $PARTITION
```

@@ -80,6 +101,13 @@

\[1\] The models are ported from the repo [ActionCLIP](https://github.com/sallymmx/ActionCLIP) and tested on our data. These ported checkpoints are supported for testing only. Due to the variation in testing data, our reported test accuracy differs from that of the original repository (about one point lower on average). Please refer to this [issue](https://github.com/sallymmx/ActionCLIP/issues/14) for more details.

+### Kinetics400 (trained on our K400 dataset)
+
+| frame sampling strategy | gpus | backbone | top1 acc | top5 acc | testing protocol | config | ckpt | log |
+| :---------------------: | :--: | :------: | :------: | :------: | :--------------: | :----: | :--: | :-: |
+| 1x1x8 | 8 | ViT-B/32 | 77.5 | 93.2 | 8 clips x 1 crop | [config](./configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb_20230801-8535b794.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.log) |
+| 1x1x8 | 8 | ViT-B/16 | 81.3 | 95.2 | 8 clips x 1 crop | [config](./configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb_20230801-b307a0cd.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.log) |
+
## Zero-Shot Prediction

We offer two methods for zero-shot prediction as follows. The `test.mp4` can be downloaded from [here](https://github-production-user-asset-6210df.s3.amazonaws.com/58767402/237333525-89ebee9a-573e-4e27-9047-0ad6422fa82f.mp4).
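The training commands above point at the new configs this commit adds (shown next). Before launching a full run, it can be worth confirming the config parses and contains what you expect. A minimal sketch, assuming mmengine is installed and the working directory is `projects/actionclip` as in the README instructions:

```python
# Sketch: parse the new training config and spot-check a few fields.
# Assumes mmengine is installed and the relative config path resolves
# (i.e. run from projects/actionclip).
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py')

print(cfg.model.type)                   # -> 'ActionClip'
print(cfg.train_dataloader.batch_size)  # -> 16 per GPU (8 GPUs x 16 = 128)
print(cfg.optim_wrapper.optimizer.lr)   # -> 5e-06
```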
projects/actionclip/configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py (new file)
162 additions, 0 deletions

custom_imports = dict(imports='models')

num_segs = 8

model = dict(
    type='ActionClip',
    clip_arch='ViT-B/16',
    num_adapter_segs=num_segs,
    num_adapter_layers=6,
    to_float32=True,
    labels_or_label_file='configs/label_map_k400.txt',
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[122.771, 116.746, 104.093],
        std=[68.500, 66.632, 70.323],
        format_shape='NCHW'))

dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'

# The second assignment overrides the first: the petrel (ceph) backend is
# used on OpenMMLab's internal cluster. Delete it to read videos from
# local disk instead.
file_client_args = dict(io_backend='disk')
file_client_args = dict(
    io_backend='petrel',
    path_mapping=dict(
        {'data/kinetics400/': 's3://openmmlab/datasets/action/Kinetics400/'}))

train_pipeline = [
    dict(type='DecordInit', **file_client_args),
    # TSN-style sampling: 8 segments x 1 frame each (the "1x1x8" in the name)
    dict(
        type='SampleFrames', clip_len=1, frame_interval=1, num_clips=num_segs),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(
        type='MultiScaleCrop',
        input_size=224,
        scales=(1, .875, .75, .66),
        random_crop=False,
        num_fixed_crops=13,
        max_wh_scale_gap=1),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]

val_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(
        type='SampleFrames',
        clip_len=1,
        frame_interval=1,
        num_clips=num_segs,
        test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]

test_pipeline = val_pipeline

train_dataloader = dict(
    batch_size=16,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=dict(video=data_root),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=16,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        data_prefix=dict(video=data_root_val),
        pipeline=val_pipeline,
        test_mode=True))
test_dataloader = dict(
    batch_size=1,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# The temporal adapter trains at 10x the base learning rate of the CLIP
# backbone (see `custom_keys` below).
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW', lr=5e-6, betas=(0.9, 0.98), eps=1e-08, weight_decay=0.2),
    paramwise_cfg=dict(custom_keys=dict(adapter=dict(lr_mult=10))))

# 5 epochs of linear warmup, then cosine annealing over the remaining 45.
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.01,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=45,
        eta_min=0,
        by_epoch=True,
        begin=5,
        end=50,
        convert_to_iter_based=True)
]

# Default setting for scaling LR automatically
# - `enable`: whether to scale the LR automatically by default.
# - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=128)

default_scope = 'mmaction'

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100, ignore_last=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    sync_buffers=dict(type='SyncBuffersHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))

log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False
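One detail worth calling out in the optimizer setup: `paramwise_cfg=dict(custom_keys=dict(adapter=dict(lr_mult=10)))` makes every parameter whose name contains `adapter` train at 10x the base learning rate, while the rest of the model stays at 5e-6. A rough sketch of the equivalent param-group construction in plain PyTorch (the two-module model below is an illustrative stand-in, not the actual ActionClip implementation):

```python
# Sketch: what lr_mult=10 for 'adapter' amounts to -- two optimizer
# param groups selected by parameter name. The model is a stand-in.
import torch
from torch import nn

model = nn.ModuleDict({
    'clip_visual': nn.Linear(512, 512),  # stands in for the CLIP backbone
    'adapter': nn.Linear(512, 512),      # stands in for the temporal adapter
})

base_lr = 5e-6
backbone_params, adapter_params = [], []
for name, param in model.named_parameters():
    (adapter_params if 'adapter' in name else backbone_params).append(param)

optimizer = torch.optim.AdamW(
    [
        dict(params=backbone_params, lr=base_lr),      # backbone: base LR
        dict(params=adapter_params, lr=base_lr * 10),  # adapter: 10x LR
    ],
    betas=(0.9, 0.98), eps=1e-8, weight_decay=0.2)

for group in optimizer.param_groups:
    print(f"{len(group['params'])} params at lr={group['lr']:.0e}")
```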
projects/actionclip/configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py (new file)
162 additions, 0 deletions

This config is identical, line for line, to the ViT-B/16 config above except for the model's `clip_arch`:

model = dict(
    type='ActionClip',
    clip_arch='ViT-B/32',  # the only line that differs from the ViT-B/16 config
    num_adapter_segs=num_segs,
    num_adapter_layers=6,
    to_float32=True,
    labels_or_label_file='configs/label_map_k400.txt',
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[122.771, 116.746, 104.093],
        std=[68.500, 66.632, 70.323],
        format_shape='NCHW'))

All other settings (data pipelines, dataloaders, evaluator, 50-epoch training loop, optimizer, schedulers, hooks, and runtime config) match the ViT-B/16 config verbatim.
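Both configs use the same learning-rate schedule: 5 epochs of linear warmup starting from 1% of the base LR, then cosine annealing to zero over the remaining 45 epochs (mmengine steps both per iteration via `convert_to_iter_based=True`). A standalone sketch of the same curve using PyTorch's built-in schedulers rather than mmengine's, stepped per epoch for brevity:

```python
# Sketch: the warmup + cosine LR shape from param_scheduler, reproduced
# with plain PyTorch schedulers (per-epoch instead of per-iteration).
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=5e-6)

warmup = torch.optim.lr_scheduler.LinearLR(
    opt, start_factor=0.01, total_iters=5)   # epochs 0-4: 5e-8 -> 5e-6
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt, T_max=45, eta_min=0)                # epochs 5-49: 5e-6 -> 0
sched = torch.optim.lr_scheduler.SequentialLR(
    opt, schedulers=[warmup, cosine], milestones=[5])

for epoch in range(50):
    if epoch in (0, 4, 5, 27, 49):
        print(f"epoch {epoch:2d}: lr = {opt.param_groups[0]['lr']:.2e}")
    opt.step()    # placeholder optimization step
    sched.step()
```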
