Skip to content

Commit

Permalink
[Enhance] Support the Training of ActionClip (#2620)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dai-Wenxun authored Oct 12, 2023
1 parent baf385e commit 17b88a3
Show file tree
Hide file tree
Showing 4 changed files with 437 additions and 26 deletions.
34 changes: 31 additions & 3 deletions projects/actionclip/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,45 @@ Create a symbolic link from `$MMACTION2/data` to `./data` in the current directo
ln -s ../../data ./data
```

### Training commands

**To train with single GPU:**

```bash
mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py
```

**To train with multiple GPUs:**

```bash
mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --launcher pytorch --gpus 8
```

**To train with multiple GPUs by slurm:**

```bash
mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --launcher slurm \
--gpus 8 --gpus-per-node 8 --partition $PARTITION
```

### Testing commands

**To test with single GPU:**

```bash
mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT
mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT
```

**To test with multiple GPUs:**

```bash
mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
```

**To test with multiple GPUs by slurm:**

```bash
mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \
mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \
--gpus 8 --gpus-per-node 8 --partition $PARTITION
```

Expand All @@ -80,6 +101,13 @@ mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb

\[1\] The models are ported from the repo [ActionCLIP](https://github.com/sallymmx/ActionCLIP) and tested on our data. Currently, we only support the testing of ActionCLIP models. Due to the variation in testing data, our reported test accuracy differs from that of the original repository (on average, it is lower by one point). Please refer to this [issue](https://github.com/sallymmx/ActionCLIP/issues/14) for more details.

### Kinetics400 (Trained on Our K400 dataset)

| frame sampling strategy | gpus | backbone | top1 acc | top5 acc | testing protocol | config | ckpt | log |
| :---------------------: | :--: | :------: | :------: | :------: | :---------------: | :-------------------------------------------: | :------------------------------------------: | :-----------------------------------------: |
| 1x1x8 | 8 | ViT-B/32 | 77.5 | 93.2 | 8 clips x 1 crop | [config](./configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb_20230801-8535b794.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.log) |
| 1x1x8 | 8 | ViT-B/16 | 81.3 | 95.2 | 8 clips x 1 crop | [config](./configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb_20230801-b307a0cd.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.log) |

## Zero-Shot Prediction

We offer two methods for zero-shot prediction as follows. The `test.mp4` can be downloaded from [here](https://github-production-user-asset-6210df.s3.amazonaws.com/58767402/237333525-89ebee9a-573e-4e27-9047-0ad6422fa82f.mp4).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
custom_imports = dict(imports='models')

num_segs = 8

model = dict(
type='ActionClip',
clip_arch='ViT-B/16',
num_adapter_segs=num_segs,
num_adapter_layers=6,
to_float32=True,
labels_or_label_file='configs/label_map_k400.txt',
data_preprocessor=dict(
type='ActionDataPreprocessor',
mean=[122.771, 116.746, 104.093],
std=[68.500, 66.632, 70.323],
format_shape='NCHW'))

dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'

file_client_args = dict(io_backend='disk')
file_client_args = dict(
io_backend='petrel',
path_mapping=dict(
{'data/kinetics400/': 's3://openmmlab/datasets/action/Kinetics400/'}))

train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='SampleFrames', clip_len=1, frame_interval=1, num_clips=num_segs),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, .875, .75, .66),
random_crop=False,
num_fixed_crops=13,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]

val_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=num_segs,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]

test_pipeline = val_pipeline

train_dataloader = dict(
batch_size=16,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=dict(video=data_root),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=16,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=dict(video=data_root_val),
pipeline=val_pipeline,
test_mode=True))
test_dataloader = dict(
batch_size=1,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=dict(video=data_root_val),
pipeline=test_pipeline,
test_mode=True))

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=5e-6, betas=(0.9, 0.98), eps=1e-08, weight_decay=0.2),
paramwise_cfg=dict(custom_keys=dict(adapter=dict(lr_mult=10))))

param_scheduler = [
dict(
type='LinearLR',
start_factor=0.01,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=45,
eta_min=0,
by_epoch=True,
begin=5,
end=50,
convert_to_iter_based=True)
]

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=128)

default_scope = 'mmaction'

default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100, ignore_last=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5),
sampler_seed=dict(type='DistSamplerSeedHook'),
sync_buffers=dict(type='SyncBuffersHook'))

env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))

log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
custom_imports = dict(imports='models')

num_segs = 8

model = dict(
type='ActionClip',
clip_arch='ViT-B/32',
num_adapter_segs=num_segs,
num_adapter_layers=6,
to_float32=True,
labels_or_label_file='configs/label_map_k400.txt',
data_preprocessor=dict(
type='ActionDataPreprocessor',
mean=[122.771, 116.746, 104.093],
std=[68.500, 66.632, 70.323],
format_shape='NCHW'))

dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'

file_client_args = dict(io_backend='disk')
file_client_args = dict(
io_backend='petrel',
path_mapping=dict(
{'data/kinetics400/': 's3://openmmlab/datasets/action/Kinetics400/'}))

train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='SampleFrames', clip_len=1, frame_interval=1, num_clips=num_segs),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='RandomResizedCrop'),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, .875, .75, .66),
random_crop=False,
num_fixed_crops=13,
max_wh_scale_gap=1),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]

val_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=num_segs,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]

test_pipeline = val_pipeline

train_dataloader = dict(
batch_size=16,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=dict(video=data_root),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=16,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=dict(video=data_root_val),
pipeline=val_pipeline,
test_mode=True))
test_dataloader = dict(
batch_size=1,
num_workers=16,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=dict(video=data_root_val),
pipeline=test_pipeline,
test_mode=True))

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=5e-6, betas=(0.9, 0.98), eps=1e-08, weight_decay=0.2),
paramwise_cfg=dict(custom_keys=dict(adapter=dict(lr_mult=10))))

param_scheduler = [
dict(
type='LinearLR',
start_factor=0.01,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=45,
eta_min=0,
by_epoch=True,
begin=5,
end=50,
convert_to_iter_based=True)
]

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=128)

default_scope = 'mmaction'

default_hooks = dict(
runtime_info=dict(type='RuntimeInfoHook'),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100, ignore_last=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5),
sampler_seed=dict(type='DistSamplerSeedHook'),
sync_buffers=dict(type='SyncBuffersHook'))

env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))

log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False
Loading

0 comments on commit 17b88a3

Please sign in to comment.