Skip to content

Commit

Permalink
[Feature] Support MobileOne TSN/TSM (#2656)
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 authored Sep 6, 2023
1 parent ed1270c commit 2ddf4b5
Show file tree
Hide file tree
Showing 18 changed files with 643 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .circleci/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jobs:
mim install 'mmcv >= 2.0.0'
pip install git+https://[email protected]/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
pip install git+https://github.com/open-mmlab/[email protected]
pip install -r requirements.txt
- run:
Expand Down Expand Up @@ -126,6 +127,7 @@ jobs:
docker exec mmaction pip install git+https://[email protected]/open-mmlab/[email protected]
docker exec mmaction pip install git+https://[email protected]/open-mmlab/[email protected]
docker exec mmaction pip install git+https://github.com/open-mmlab/[email protected]
docker exec mmaction pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
docker exec mmaction pip install -r requirements.txt
- run:
name: Build and install
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/merge_stage_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ jobs:
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install MMCls
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install MMPretrain
run: pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
- name: Install MMPose
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install PytorchVideo
Expand Down Expand Up @@ -122,6 +124,8 @@ jobs:
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install MMCls
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install MMPretrain
run: pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
- name: Install MMPose
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install unittest dependencies
Expand Down Expand Up @@ -186,6 +190,7 @@ jobs:
mim install 'mmcv >= 2.0.0'
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
pip install git+https://github.com/open-mmlab/[email protected]
pip install -r requirements.txt
- name: Install PytorchVideo
Expand Down Expand Up @@ -228,6 +233,7 @@ jobs:
mim install 'mmcv >= 2.0.0'
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
pip install git+https://github.com/open-mmlab/[email protected]
pip install -r requirements.txt
- name: Install PytorchVideo
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/pr_stage_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ jobs:
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install MMCls
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install MMPretrain
run: pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
- name: Install MMPose
run: pip install git+https://github.com/open-mmlab/[email protected]
- name: Install unittest dependencies
Expand Down Expand Up @@ -119,6 +121,7 @@ jobs:
mim install 'mmcv >= 2.0.0'
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
pip install git+https://github.com/open-mmlab/[email protected]
pip install -r requirements.txt
- name: Install PytorchVideo
Expand Down Expand Up @@ -168,6 +171,7 @@ jobs:
mim install 'mmcv >= 2.0.0'
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/[email protected]
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
pip install git+https://github.com/open-mmlab/[email protected]
pip install -r requirements.txt
- name: Install PytorchVideo
Expand Down
31 changes: 31 additions & 0 deletions configs/_base_/models/tsm_mobileone_s4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# model settings
preprocess_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])

checkpoint = ('https://download.openmmlab.com/mmclassification/'
'v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth')
model = dict(
type='Recognizer2D',
backbone=dict(
type='MobileOneTSM',
arch='s4',
shift_div=8,
num_segments=8,
is_shift=True,
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
cls_head=dict(
type='TSMHead',
num_segments=8,
num_classes=400,
in_channels=2048,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.5,
init_std=0.001,
is_shift=True,
average_clips='prob'),
# model training and testing settings
data_preprocessor=dict(type='ActionDataPreprocessor', **preprocess_cfg),
train_cfg=None,
test_cfg=None)
26 changes: 26 additions & 0 deletions configs/_base_/models/tsn_mobileone_s0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
checkpoint = ('https://download.openmmlab.com/mmclassification/'
'v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
model = dict(
type='Recognizer2D',
backbone=dict(
type='mmpretrain.MobileOne',
arch='s0',
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone'),
norm_eval=False),
cls_head=dict(
type='TSNHead',
num_classes=400,
in_channels=1024,
spatial_type='avg',
consensus=dict(type='AvgConsensus', dim=1),
dropout_ratio=0.4,
init_std=0.01,
average_clips='prob'),
data_preprocessor=dict(
type='ActionDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
format_shape='NCHW'),
train_cfg=None,
test_cfg=None)
2 changes: 2 additions & 0 deletions configs/recognition/tsm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The explosive growth in video streaming gives rise to challenges on performing v
| 1x1x8 | 224x224 | 8 | ResNet50 (NonLocalGauss) | ImageNet | 73.66 | 90.99 | 8 clips x 10 crop | 59.06G | 28.00M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb_20220831-7e54dacf.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.log) |
| 1x1x8 | 224x224 | 8 | ResNet50 (NonLocalEmbedGauss) | ImageNet | 74.34 | 91.23 | 8 clips x 10 crop | 61.30G | 31.68M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb_20220831-35eddb57.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.log) |
| 1x1x8 | 224x224 | 8 | MobileNetV2 | ImageNet | 68.71 | 88.32 | 8 clips x 3 crop | 3.269G | 2.736M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb_20230414-401127fd.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb.log) |
| 1x1x16 | 224x224 | 8 | MobileOne-S4 | ImageNet | 74.38 | 91.71 | 16 clips x 10 crop | 48.65G | 13.72M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb_20230825-a7f8876b.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.log) |

### Something-something V2

Expand All @@ -41,6 +42,7 @@ The explosive growth in video streaming gives rise to challenges on performing v

1. The **gpus** indicates the number of gpus we used to get the checkpoint. If you want to use a different number of gpus or videos per gpu, the best way is to set `--auto-scale-lr` when calling `tools/train.py`, this parameter will auto-scale the learning rate according to the actual batch size and the original batch size.
2. The validation set of Kinetics400 we used consists of 19796 videos. These videos are available at [Kinetics400-Validation](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155136485_link_cuhk_edu_hk/EbXw2WX94J1Hunyt3MWNDJUBz-nHvQYhO9pvKqm6g39PMA?e=a9QldB). The corresponding [data list](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_val_list.txt) (each line is of the format 'video_id, num_frames, label_index') and the [label map](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_class2ind.txt) are also available.
3. MoibleOne backbone supports reparameterization during inference. You can use the provided [reparameterize tool](/tools/convert/reparameterize_model.py) to convert the checkpoint and switch to the [deploy config file](/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_deploy_8xb16-1x1x16-50e_kinetics400-rgb.py).

For more details on data preparation, you can refer to [Kinetics400](/tools/data/kinetics/README.md).

Expand Down
24 changes: 24 additions & 0 deletions configs/recognition/tsm/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,30 @@ Models:
Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.log
Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb_20220831-7e54dacf.pth

- Name: tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb
Config: configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py
In Collection: TSM
Metadata:
Architecture: MobileOne-S4
Batch Size: 16
Epochs: 100
FLOPs: 48.65G
Parameters: 13.72M
Pretrained: ImageNet
Resolution: 224x224
Training Data: Kinetics-400
Training Resources: 8 GPUs
Modality: RGB
Results:
- Dataset: Kinetics-400
Task: Action Recognition
Metrics:
Top 1 Accuracy: 74.38
Top 5 Accuracy: 91.71
Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.log
Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb_20230825-a7f8876b.pth


- Name: tsm_imagenet-pretrained-r101_8xb16-1x1x8-50e_sthv2-rgb
Config: configs/recognition/tsm/tsm_imagenet-pretrained-r101_8xb16-1x1x8-50e_sthv2-rgb.py
In Collection: TSM
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
_base_ = [
'../../_base_/models/tsm_mobileone_s4.py',
'../../_base_/default_runtime.py'
]

model = dict(cls_head=dict(num_segments=16))
# dataset settings
dataset_type = 'VideoDataset'
data_root = 'data/kinetics400/videos_train'
data_root_val = 'data/kinetics400/videos_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'

file_client_args = dict(io_backend='disk')

train_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=16),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(
type='MultiScaleCrop',
input_size=224,
scales=(1, 0.875, 0.75, 0.66),
random_crop=False,
max_wh_scale_gap=1,
num_fixed_crops=13),
dict(type='Resize', scale=(224, 224), keep_ratio=False),
dict(type='Flip', flip_ratio=0.5),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]
val_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=16,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='CenterCrop', crop_size=224),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]
test_pipeline = [
dict(type='DecordInit', **file_client_args),
dict(
type='SampleFrames',
clip_len=1,
frame_interval=1,
num_clips=16,
test_mode=True),
dict(type='DecordDecode'),
dict(type='Resize', scale=(-1, 256)),
dict(type='ThreeCrop', crop_size=256),
dict(type='FormatShape', input_format='NCHW'),
dict(type='PackActionInputs')
]

train_dataloader = dict(
batch_size=8,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=dict(video=data_root),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=8,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=dict(video=data_root_val),
pipeline=val_pipeline,
test_mode=True))
test_dataloader = dict(
batch_size=1,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=dict(video=data_root_val),
pipeline=test_pipeline,
test_mode=True))

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

default_hooks = dict(checkpoint=dict(interval=3, max_keep_ckpts=3))

train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(type='LinearLR', start_factor=0.1, by_epoch=True, begin=0, end=5),
dict(
type='MultiStepLR',
begin=0,
end=50,
by_epoch=True,
milestones=[25, 45],
gamma=0.1)
]

optim_wrapper = dict(
constructor='TSMOptimWrapperConstructor',
paramwise_cfg=dict(fc_lr5=True),
optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00002),
clip_grad=dict(max_norm=20, norm_type=2))

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
auto_scale_lr = dict(enable=True, base_batch_size=128)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = [
'./tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py', # noqa: E501
]

model = dict(backbone=dict(deploy=True))
Loading

0 comments on commit 2ddf4b5

Please sign in to comment.