diff --git a/.circleci/test.yml b/.circleci/test.yml
index 5c57cd74b9..2d5713cf1a 100644
--- a/.circleci/test.yml
+++ b/.circleci/test.yml
@@ -66,6 +66,7 @@ jobs:
             mim install 'mmcv >= 2.0.0'
             pip install git+https://git@github.com/open-mmlab/mmdetection.git@dev-3.x
             pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+            pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
             pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
             pip install -r requirements.txt
       - run:
@@ -126,6 +127,7 @@ jobs:
             docker exec mmaction pip install git+https://git@github.com/open-mmlab/mmdetection.git@dev-3.x
             docker exec mmaction pip install git+https://git@github.com/open-mmlab/mmpose.git@dev-1.x
             docker exec mmaction pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+            docker exec mmaction pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
             docker exec mmaction pip install -r requirements.txt
       - run:
           name: Build and install
diff --git a/.github/workflows/merge_stage_test.yml b/.github/workflows/merge_stage_test.yml
index 0b83911506..de01615037 100644
--- a/.github/workflows/merge_stage_test.yml
+++ b/.github/workflows/merge_stage_test.yml
@@ -60,6 +60,8 @@ jobs:
         run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install MMCls
         run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+      - name: Install MMPretrain
+        run: pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
       - name: Install MMPose
         run: pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
       - name: Install PytorchVideo
@@ -122,6 +124,8 @@ jobs:
         run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install MMCls
         run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+      - name: Install MMPretrain
+        run: pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
       - name: Install MMPose
         run: pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
       - name: Install unittest dependencies
@@ -186,6 +190,7 @@ jobs:
           mim install 'mmcv >= 2.0.0'
           pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
           pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
           pip install -r requirements.txt
       - name: Install PytorchVideo
@@ -228,6 +233,7 @@ jobs:
           mim install 'mmcv >= 2.0.0'
           pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
           pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
           pip install -r requirements.txt
       - name: Install PytorchVideo
diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml
index 2513d38596..63b9558e4b 100644
--- a/.github/workflows/pr_stage_test.yml
+++ b/.github/workflows/pr_stage_test.yml
@@ -51,6 +51,8 @@ jobs:
         run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
       - name: Install MMCls
         run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+      - name: Install MMPretrain
+        run: pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
       - name: Install MMPose
         run: pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
       - name: Install unittest dependencies
@@ -119,6 +121,7 @@ jobs:
           mim install 'mmcv >= 2.0.0'
           pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
           pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
           pip install -r requirements.txt
       - name: Install PytorchVideo
@@ -168,6 +171,7 @@ jobs:
           mim install 'mmcv >= 2.0.0'
           pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x
           pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x
+          pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
           pip install git+https://github.com/open-mmlab/mmpose.git@dev-1.x
           pip install -r requirements.txt
       - name: Install PytorchVideo
diff --git a/configs/_base_/models/tsm_mobileone_s4.py b/configs/_base_/models/tsm_mobileone_s4.py
new file mode 100644
index 0000000000..df0c8f8c3c
--- /dev/null
+++ b/configs/_base_/models/tsm_mobileone_s4.py
@@ -0,0 +1,31 @@
+# model settings
+preprocess_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
+
+checkpoint = ('https://download.openmmlab.com/mmclassification/'
+              'v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth')
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='MobileOneTSM',
+        arch='s4',
+        shift_div=8,
+        num_segments=8,
+        is_shift=True,
+        init_cfg=dict(
+            type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
+    cls_head=dict(
+        type='TSMHead',
+        num_segments=8,
+        num_classes=400,
+        in_channels=2048,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.5,
+        init_std=0.001,
+        is_shift=True,
+        average_clips='prob'),
+    # model training and testing settings
+    data_preprocessor=dict(type='ActionDataPreprocessor', **preprocess_cfg),
+    train_cfg=None,
+    test_cfg=None)
diff --git a/configs/_base_/models/tsn_mobileone_s0.py b/configs/_base_/models/tsn_mobileone_s0.py
new file mode 100644
index 0000000000..83a070f143
--- /dev/null
+++ b/configs/_base_/models/tsn_mobileone_s0.py
@@ -0,0 +1,26 @@
+checkpoint = ('https://download.openmmlab.com/mmclassification/'
+              'v0/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='mmpretrain.MobileOne',
+        arch='s0',
+        init_cfg=dict(
+            type='Pretrained', checkpoint=checkpoint, prefix='backbone'),
+        norm_eval=False),
+    cls_head=dict(
+        type='TSNHead',
+        num_classes=400,
+        in_channels=1024,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.4,
+        init_std=0.01,
+        average_clips='prob'),
+    data_preprocessor=dict(
+        type='ActionDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        format_shape='NCHW'),
+    train_cfg=None,
+    test_cfg=None)
diff --git a/configs/recognition/tsm/README.md b/configs/recognition/tsm/README.md
index 3014d0e26b..0a02e14cf6 100644
--- a/configs/recognition/tsm/README.md
+++ b/configs/recognition/tsm/README.md
@@ -30,6 +30,7 @@ The explosive growth in video streaming gives rise to challenges on performing v
 | 1x1x8 | 224x224 | 8 | ResNet50 (NonLocalGauss) | ImageNet | 73.66 | 90.99 | 8 clips x 10 crop | 59.06G | 28.00M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb_20220831-7e54dacf.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.log) |
 | 1x1x8 | 224x224 | 8 | ResNet50 (NonLocalEmbedGauss) | ImageNet | 74.34 | 91.23 | 8 clips x 10 crop | 61.30G | 31.68M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb_20220831-35eddb57.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-embedded-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.log) |
 | 1x1x8 | 224x224 | 8 | MobileNetV2 | ImageNet | 68.71 | 88.32 | 8 clips x 3 crop | 3.269G | 2.736M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb_20230414-401127fd.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb.log) |
+| 1x1x16 | 224x224 | 8 | MobileOne-S4 | ImageNet | 74.38 | 91.71 | 16 clips x 10 crop | 48.65G | 13.72M | [config](/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb_20230825-a7f8876b.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.log) |

 ### Something-something V2

@@ -41,6 +42,7 @@ The explosive growth in video streaming gives rise to challenges on performing v
 1. The **gpus** indicates the number of gpus we used to get the checkpoint. If you want to use a different number of gpus or videos per gpu, the best way is to set `--auto-scale-lr` when calling `tools/train.py`, this parameter will auto-scale the learning rate according to the actual batch size and the original batch size.
 2. The validation set of Kinetics400 we used consists of 19796 videos. These videos are available at [Kinetics400-Validation](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155136485_link_cuhk_edu_hk/EbXw2WX94J1Hunyt3MWNDJUBz-nHvQYhO9pvKqm6g39PMA?e=a9QldB). The corresponding [data list](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_val_list.txt) (each line is of the format 'video_id, num_frames, label_index') and the [label map](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_class2ind.txt) are also available.
+3. The MobileOne backbone supports reparameterization during inference. You can use the provided [reparameterize tool](/tools/convert/reparameterize_model.py) to convert the checkpoint and switch to the [deploy config file](/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_deploy_8xb16-1x1x16-50e_kinetics400-rgb.py).

 For more details on data preparation, you can refer to [Kinetics400](/tools/data/kinetics/README.md).
diff --git a/configs/recognition/tsm/metafile.yml b/configs/recognition/tsm/metafile.yml
index 409f5a95df..0360c16758 100644
--- a/configs/recognition/tsm/metafile.yml
+++ b/configs/recognition/tsm/metafile.yml
@@ -167,6 +167,30 @@ Models:
     Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb.log
     Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb/tsm_imagenet-pretrained-r50-nl-gaussian_8xb16-1x1x8-50e_kinetics400-rgb_20220831-7e54dacf.pth

+  - Name: tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb
+    Config: configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py
+    In Collection: TSM
+    Metadata:
+      Architecture: MobileOne-S4
+      Batch Size: 16
+      Epochs: 50
+      FLOPs: 48.65G
+      Parameters: 13.72M
+      Pretrained: ImageNet
+      Resolution: 224x224
+      Training Data: Kinetics-400
+      Training Resources: 8 GPUs
+      Modality: RGB
+    Results:
+    - Dataset: Kinetics-400
+      Task: Action Recognition
+      Metrics:
+        Top 1 Accuracy: 74.38
+        Top 5 Accuracy: 91.71
+    Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.log
+    Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb_20230825-a7f8876b.pth
+
   - Name: tsm_imagenet-pretrained-r101_8xb16-1x1x8-50e_sthv2-rgb
     Config: configs/recognition/tsm/tsm_imagenet-pretrained-r101_8xb16-1x1x8-50e_sthv2-rgb.py
     In Collection: TSM
diff --git a/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py b/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py
new file mode 100644
index 0000000000..e4fac52656
--- /dev/null
+++ b/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py
@@ -0,0 +1,126 @@
+_base_ = [
+    '../../_base_/models/tsm_mobileone_s4.py',
+    '../../_base_/default_runtime.py'
+]
+
+model = dict(cls_head=dict(num_segments=16))
+# dataset settings
+dataset_type = 'VideoDataset'
+data_root = 'data/kinetics400/videos_train'
+data_root_val = 'data/kinetics400/videos_val'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
+
+file_client_args = dict(io_backend='disk')
+
+train_pipeline = [
+    dict(type='DecordInit', **file_client_args),
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=16),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1,
+        num_fixed_crops=13),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='PackActionInputs')
+]
+val_pipeline = [
+    dict(type='DecordInit', **file_client_args),
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=16,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='PackActionInputs')
+]
+test_pipeline = [
+    dict(type='DecordInit', **file_client_args),
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=16,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='ThreeCrop', crop_size=256),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='PackActionInputs')
+]
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=dict(video=data_root),
+        pipeline=train_pipeline))
+val_dataloader = dict(
+    batch_size=8,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=dict(video=data_root_val),
+        pipeline=val_pipeline,
+        test_mode=True))
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=dict(video=data_root_val),
+        pipeline=test_pipeline,
+        test_mode=True))
+
+val_evaluator = dict(type='AccMetric')
+test_evaluator = val_evaluator
+
+default_hooks = dict(checkpoint=dict(interval=3, max_keep_ckpts=3))
+
+train_cfg = dict(
+    type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+param_scheduler = [
+    dict(type='LinearLR', start_factor=0.1, by_epoch=True, begin=0, end=5),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=50,
+        by_epoch=True,
+        milestones=[25, 45],
+        gamma=0.1)
+]
+
+optim_wrapper = dict(
+    constructor='TSMOptimWrapperConstructor',
+    paramwise_cfg=dict(fc_lr5=True),
+    optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.00002),
+    clip_grad=dict(max_norm=20, norm_type=2))
+
+# Default setting for scaling LR automatically
+# - `enable` means enable scaling LR automatically
+#   or not by default.
+# - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
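+#   For example (a hedged note: the arithmetic below assumes the dataloaders
+#   in this file, i.e. 8 videos per GPU x 8 GPUs = 64 videos per iteration),
+#   enabling it would scale the LR linearly to 0.02 x (64 / 128) = 0.01.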
+auto_scale_lr = dict(enable=True, base_batch_size=128)
diff --git a/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_deploy_8xb16-1x1x16-50e_kinetics400-rgb.py b/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_deploy_8xb16-1x1x16-50e_kinetics400-rgb.py
new file mode 100644
index 0000000000..ecd0ed32e0
--- /dev/null
+++ b/configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_deploy_8xb16-1x1x16-50e_kinetics400-rgb.py
@@ -0,0 +1,5 @@
+_base_ = [
+    './tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py',  # noqa: E501
+]
+
+model = dict(backbone=dict(deploy=True))
diff --git a/configs/recognition/tsn/README.md b/configs/recognition/tsn/README.md
index 8ff8222649..ca21386ce2 100644
--- a/configs/recognition/tsn/README.md
+++ b/configs/recognition/tsn/README.md
@@ -40,6 +40,7 @@ Deep convolutional networks have achieved great success for visual recognition i
 It's possible and convenient to use a 3rd-party backbone for TSN under the framework of MMAction2, here we provide some examples for:

 - [x] Backbones from [MMClassification](https://github.com/open-mmlab/mmclassification/)
+- [x] Backbones from [MMPretrain](https://github.com/open-mmlab/mmpretrain)
 - [x] Backbones from [TorchVision](https://github.com/pytorch/vision/)
 - [x] Backbones from [TIMM (pytorch-image-models)](https://github.com/rwightman/pytorch-image-models)

@@ -49,10 +50,12 @@ It's possible and convenient to use a 3rd-party backbone for TSN under the frame
 | 1x1x3 | MultiStep | 224x224 | 8 | DenseNet161 | ImageNet | 72.07 | 90.15 | 25 clips x 10 crop | 194.6G | 27.36M | [config](/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-dense161_8xb32-1x1x3-100e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-dense161_8xb32-1x1x3-100e_kinetics400-rgb/tsn_imagenet-pretrained-dense161_8xb32-1x1x3-100e_kinetics400-rgb_20220906-5f4c0daf.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-dense161_8xb32-1x1x3-100e_kinetics400-rgb/tsn_imagenet-pretrained-dense161_8xb32-1x1x3-100e_kinetics400-rgb.log) |
 | 1x1x3 | MultiStep | 224x224 | 8 | Swin Transformer | ImageNet | 77.03 | 92.61 | 25 clips x 10 crop | 386.7G | 87.15M | [config](/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb_20220906-65ed814e.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb.log) |
 | 1x1x8 | MultiStep | 224x224 | 8 | Swin Transformer | ImageNet | 79.22 | 94.20 | 25 clips x 10 crop | 386.7G | 87.15M | [config](/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb_20230530-428f0064.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb.log) |
+| 1x1x8 | MultiStep | 224x224 | 8 | MobileOne-S4 | ImageNet | 73.65 | 91.32 | 25 clips x 10 crop | 76G | 13.72M | [config](/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb_20230825-2da3c1f7.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.log) |

 1. Note that some backbones in TIMM are not supported due to multiple reasons. Please refer to [PR #880](https://github.com/open-mmlab/mmaction2/pull/880) for details.
 2. The **gpus** indicates the number of gpus we used to get the checkpoint. If you want to use a different number of gpus or videos per gpu, the best way is to set `--auto-scale-lr` when calling `tools/train.py`, this parameter will auto-scale the learning rate according to the actual batch size and the original batch size.
 3. The validation set of Kinetics400 we used consists of 19796 videos. These videos are available at [Kinetics400-Validation](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155136485_link_cuhk_edu_hk/EbXw2WX94J1Hunyt3MWNDJUBz-nHvQYhO9pvKqm6g39PMA?e=a9QldB). The corresponding [data list](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_val_list.txt) (each line is of the format 'video_id, num_frames, label_index') and the [label map](https://download.openmmlab.com/mmaction/dataset/k400_val/kinetics_class2ind.txt) are also available.
+4. The MobileOne backbone supports reparameterization during inference. You can use the provided [reparameterize tool](/tools/convert/reparameterize_model.py) to convert the checkpoint and switch to the [deploy config file](/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_deploy_8xb32-1x1x8-100e_kinetics400-rgb.py), as sketched below.
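+
+   A minimal conversion sketch, based on the tool's command-line interface (`${CHECKPOINT}` and `${DEPLOY_CHECKPOINT}` are placeholder paths for the trained weights and the converted output):
+
+   ```shell
+   python tools/convert/reparameterize_model.py \
+       configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py \
+       ${CHECKPOINT} \
+       ${DEPLOY_CHECKPOINT}
+   ```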
 For more details on data preparation, you can refer to
diff --git a/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py b/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py
new file mode 100644
index 0000000000..5f07bf40ab
--- /dev/null
+++ b/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py
@@ -0,0 +1,75 @@
+_base_ = ['../tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py']
+
+# model settings
+checkpoint = ('https://download.openmmlab.com/mmclassification/'
+              'v0/mobileone/mobileone-s4_8xb32_in1k_20221110-28d888cb.pth')
+model = dict(
+    backbone=dict(
+        type='mmpretrain.MobileOne',
+        arch='s4',
+        out_indices=(3, ),
+        init_cfg=dict(
+            type='Pretrained', checkpoint=checkpoint, prefix='backbone'),
+        _delete_=True),
+    cls_head=dict(in_channels=2048))
+
+# dataset settings
+dataset_type = 'VideoDataset'
+data_root = 'data/kinetics400/videos_train'
+data_root_val = 'data/kinetics400/videos_val'
+ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
+ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt'
+
+file_client_args = dict(io_backend='disk')
+
+train_pipeline = [
+    dict(type='DecordInit', **file_client_args),
+    dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(
+        type='MultiScaleCrop',
+        input_size=224,
+        scales=(1, 0.875, 0.75, 0.66),
+        random_crop=False,
+        max_wh_scale_gap=1),
+    dict(type='Resize', scale=(224, 224), keep_ratio=False),
+    dict(type='Flip', flip_ratio=0.5),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='PackActionInputs')
+]
+val_pipeline = [
+    dict(type='DecordInit', **file_client_args),
+    dict(
+        type='SampleFrames',
+        clip_len=1,
+        frame_interval=1,
+        num_clips=8,
+        test_mode=True),
+    dict(type='DecordDecode'),
+    dict(type='Resize', scale=(-1, 256)),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='FormatShape', input_format='NCHW'),
+    dict(type='PackActionInputs')
+]
+
+train_dataloader = dict(
+    batch_size=16,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=dict(video=data_root),
+        pipeline=train_pipeline))
+val_dataloader = dict(
+    batch_size=16,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=dict(video=data_root_val),
+        pipeline=val_pipeline,
+        test_mode=True))
diff --git a/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_deploy_8xb32-1x1x8-100e_kinetics400-rgb.py b/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_deploy_8xb32-1x1x8-100e_kinetics400-rgb.py
new file mode 100644
index 0000000000..38ab106a3f
--- /dev/null
+++ b/configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_deploy_8xb32-1x1x8-100e_kinetics400-rgb.py
@@ -0,0 +1,5 @@
+_base_ = [
+    './tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py'  # noqa: E501
+]
+
+model = dict(backbone=dict(deploy=True))
diff --git a/configs/recognition/tsn/metafile.yml b/configs/recognition/tsn/metafile.yml
index 378040098c..06822d633c 100644
--- a/configs/recognition/tsn/metafile.yml
+++ b/configs/recognition/tsn/metafile.yml
@@ -215,6 +215,29 @@ Models:
     Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb/tsn_imagenet-pretrained-swin-transformer_32xb8-1x1x8-50e_kinetics400-rgb.log
     Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb/tsn_imagenet-pretrained-swin-transformer_8xb32-1x1x3-100e_kinetics400-rgb_20220906-65ed814e.pth

+  - Name: tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb
+    Config: configs/recognition/tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py
+    In Collection: TSN
+    Metadata:
+      Architecture: MobileOne-S4
+      Batch Size: 32
+      Epochs: 100
+      FLOPs: 76G
+      Parameters: 13.72M
+      Pretrained: ImageNet
+      Resolution: 224x224
+      Training Data: Kinetics-400
+      Training Resources: 8 GPUs
+      Modality: RGB
+    Results:
+    - Dataset: Kinetics-400
+      Task: Action Recognition
+      Metrics:
+        Top 1 Accuracy: 73.65
+        Top 5 Accuracy: 91.32
+    Training Log: https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.log
+    Weights: https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb_20230825-2da3c1f7.pth
+
   - Name: tsn_imagenet-pretrained-r50_8xb32-1x1x8-50e_sthv2-rgb
     Config: configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x8-50e_sthv2-rgb.py
     In Collection: TSN
diff --git a/mmaction/models/backbones/__init__.py b/mmaction/models/backbones/__init__.py
index 2f4eb4a7e3..8a69a057d6 100644
--- a/mmaction/models/backbones/__init__.py
+++ b/mmaction/models/backbones/__init__.py
@@ -33,3 +33,10 @@
     'TimeSformer', 'UniFormer', 'UniFormerV2', 'VisionTransformer', 'X3D',
     'RGBPoseConv3D'
 ]
+
+try:
+    from .mobileone_tsm import MobileOneTSM  # noqa: F401
+    __all__.append('MobileOneTSM')
+
+except (ImportError, ModuleNotFoundError):
+    pass
diff --git a/mmaction/models/backbones/mobileone_tsm.py b/mmaction/models/backbones/mobileone_tsm.py
new file mode 100644
index 0000000000..96722faf68
--- /dev/null
+++ b/mmaction/models/backbones/mobileone_tsm.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch.nn as nn
+from mmengine.logging import MMLogger
+from mmengine.runner.checkpoint import (_load_checkpoint,
+                                        _load_checkpoint_with_prefix)
+from mmpretrain.models import MobileOne
+
+from mmaction.registry import MODELS
+from .resnet_tsm import TemporalShift
+
+
+@MODELS.register_module()
+class MobileOneTSM(MobileOne):
+    """MobileOne backbone for TSM.
+
+    Args:
+        arch (str | dict): MobileOne architecture. If a string, choose
+            from 's0', 's1', 's2', 's3' and 's4'. If a dict, it should
+            have the keys below:
+
+            - num_blocks (Sequence[int]): Number of blocks in each stage.
+            - width_factor (Sequence[float]): Width factor in each stage.
+            - num_conv_branches (Sequence[int]): Number of conv branches
+              in each stage.
+            - num_se_blocks (Sequence[int]): Number of SE layers in each
+              stage; all the SE layers are placed in the latter blocks
+              of each stage.
+
+            Defaults to 's0'.
+        num_segments (int): Number of frame segments. Defaults to 8.
+        is_shift (bool): Whether to insert temporal shift modules into the
+            backbone. Defaults to True.
+        shift_div (int): Number of divisions for the temporal shift.
+            Defaults to 8.
+        pretrained2d (bool): Whether to load a pretrained 2D model.
+            Defaults to True.
+        **kwargs (keyword arguments, optional): Arguments for MobileOne.
+    """
+
+    def __init__(self,
+                 arch: str,
+                 num_segments: int = 8,
+                 is_shift: bool = True,
+                 shift_div: int = 8,
+                 pretrained2d: bool = True,
+                 **kwargs):
+        super().__init__(arch, **kwargs)
+        self.num_segments = num_segments
+        self.is_shift = is_shift
+        self.shift_div = shift_div
+        self.pretrained2d = pretrained2d
+        self.init_structure()
+
+    def make_temporal_shift(self):
+        """Make temporal shift for some layers.
+
+        To keep reparameterization valid, the shift layer can only be
+        built before each block (the 'block' style), instead of inside
+        its residual branch (the 'blockres' style).
+        """
+
+        def make_block_temporal(stage, num_segments):
+            """Make temporal shift on some blocks.
+
+            Args:
+                stage (nn.Module): Model layers to be shifted.
+                num_segments (int): Number of frame segments.
+
+            Returns:
+                nn.Module: The shifted blocks.
+            """
+            blocks = list(stage.children())
+            for i, b in enumerate(blocks):
+                blocks[i] = TemporalShift(
+                    b, num_segments=num_segments, shift_div=self.shift_div)
+            return nn.Sequential(*blocks)
+
+        self.stage0 = make_block_temporal(
+            nn.Sequential(self.stage0), self.num_segments)[0]
+        for i in range(1, 5):
+            temporal_stage = make_block_temporal(
+                getattr(self, f'stage{i}'), self.num_segments)
+            setattr(self, f'stage{i}', temporal_stage)
+
+    def init_structure(self):
+        """Build the temporal shift structure if ``is_shift`` is enabled."""
+        if self.is_shift:
+            self.make_temporal_shift()
+
+    def load_original_weights(self, logger):
+        """Load weights from the original 2D checkpoint, remapping keys to
+        the temporally wrapped modules."""
+        assert self.init_cfg.get('type') == 'Pretrained', (
+            'Please specify init_cfg to use a pretrained 2d checkpoint')
+        self.pretrained = self.init_cfg.get('checkpoint')
+        prefix = self.init_cfg.get('prefix')
+        if prefix is not None:
+            original_state_dict = _load_checkpoint_with_prefix(
+                prefix, self.pretrained, map_location='cpu')
+        else:
+            original_state_dict = _load_checkpoint(
+                self.pretrained, map_location='cpu')
+        if 'state_dict' in original_state_dict:
+            original_state_dict = original_state_dict['state_dict']
+
+        # map each module's original (unwrapped) name to its current name,
+        # where TemporalShift wrapping inserts a '.net' component
+        wrapped_layers_map = dict()
+        for name, module in self.named_modules():
+            ori_name = name
+            for wrap_prefix in ['.net']:
+                if wrap_prefix in ori_name:
+                    ori_name = ori_name.replace(wrap_prefix, '')
+            wrapped_layers_map[ori_name] = name
+
+        # convert wrapped keys
+        for param_name in list(original_state_dict.keys()):
+            layer_name = '.'.join(param_name.split('.')[:-1])
+            if layer_name in wrapped_layers_map:
+                wrapped_name = param_name.replace(
+                    layer_name, wrapped_layers_map[layer_name])
+                original_state_dict[wrapped_name] = original_state_dict.pop(
+                    param_name)
+
+        msg = self.load_state_dict(original_state_dict, strict=True)
+        logger.info(msg)
+
+    def init_weights(self):
+        """Initiate the parameters either from an existing checkpoint or
+        from scratch."""
+        if self.pretrained2d:
+            logger = MMLogger.get_current_instance()
+            self.load_original_weights(logger)
+        else:
+            super().init_weights()
+
+    def forward(self, x):
+        """Forward the input and unpack the single-element tuple returned
+        by MobileOne."""
+        x = super().forward(x)
+        if isinstance(x, tuple):
+            assert len(x) == 1
+            x = x[0]
+        return x
diff --git a/tests/models/backbones/test_mobileone_tsm.py b/tests/models/backbones/test_mobileone_tsm.py
new file mode 100644
index 0000000000..b018e9f5a2
--- /dev/null
+++ b/tests/models/backbones/test_mobileone_tsm.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import tempfile
+
+import torch
+from mmengine.runner import load_checkpoint, save_checkpoint
+from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
+
+from mmaction.models.backbones.mobileone_tsm import MobileOneTSM
+from mmaction.testing import generate_backbone_demo_inputs
+
+
+def test_mobileone_tsm_backbone():
+    """Test MobileOne TSM backbone."""
+
+    from mmpretrain.models.backbones.mobileone import MobileOneBlock
+
+    from mmaction.models.backbones.resnet_tsm import TemporalShift
+
+    model = MobileOneTSM('s0', pretrained2d=False)
+    model.init_weights()
+    for cur_module in model.modules():
+        if isinstance(cur_module, TemporalShift):
+            # TemporalShift is a wrapper of MobileOneBlock
+            assert isinstance(cur_module.net, MobileOneBlock)
+            assert cur_module.num_segments == model.num_segments
+            assert cur_module.shift_div == model.shift_div
+
+    inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
+
+    feat = model(inputs)
+    assert feat.shape == torch.Size([8, 1024, 2, 2])
+
+    model = MobileOneTSM('s1', pretrained2d=False)
+    feat = model(inputs)
+    assert feat.shape == torch.Size([8, 1280, 2, 2])
+
+    model = MobileOneTSM('s2', pretrained2d=False)
+    feat = model(inputs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+    model = MobileOneTSM('s3', pretrained2d=False)
+    feat = model(inputs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+    model = MobileOneTSM('s4', pretrained2d=False)
+    feat = model(inputs)
+    assert feat.shape == torch.Size([8, 2048, 2, 2])
+
+
+def test_mobileone_init_weight():
+    checkpoint = ('https://download.openmmlab.com/mmclassification/v0'
+                  '/mobileone/mobileone-s0_8xb32_in1k_20221110-0bc94952.pth')
+    model = MobileOneTSM(
+        arch='s0',
+        init_cfg=dict(
+            type='Pretrained', checkpoint=checkpoint, prefix='backbone'))
+    model.init_weights()
+    ori_ckpt = _load_checkpoint_with_prefix(
+        'backbone', model.init_cfg['checkpoint'], map_location='cpu')
+    for name, param in model.named_parameters():
+        ori_name = name.replace('.net', '')
+        assert torch.allclose(param, ori_ckpt[ori_name]), \
+            f'layer {name} fails to load from the pretrained checkpoint'
+
+
+def test_load_deploy_mobileone():
+    # Test that outputs stay the same after loading a reparameterized
+    # (deploy-mode) checkpoint
+    model = MobileOneTSM('s0', pretrained2d=False)
+    inputs = generate_backbone_demo_inputs((8, 3, 64, 64))
+    tmpdir = tempfile.gettempdir()
+    ckpt_path = os.path.join(tmpdir, 'ckpt.pth')
+    model.switch_to_deploy()
+    model.eval()
+    outputs = model(inputs)
+
+    model_deploy = MobileOneTSM('s0', pretrained2d=False, deploy=True)
+    save_checkpoint(model.state_dict(), ckpt_path)
+    load_checkpoint(model_deploy, ckpt_path)
+
+    outputs_load = model_deploy(inputs)
+    for feat, feat_load in zip(outputs, outputs_load):
+        assert torch.allclose(feat, feat_load)
+    os.remove(ckpt_path)
diff --git a/tests/models/recognizers/test_recognizer2d.py b/tests/models/recognizers/test_recognizer2d.py
index 3a13b0ef37..9c48877204 100644
--- a/tests/models/recognizers/test_recognizer2d.py
+++ b/tests/models/recognizers/test_recognizer2d.py
@@ -90,6 +90,16 @@ def test_tsn_mmcls_backbone():
     train_test_step(config, input_shape)

+def test_tsn_mobileone():
+    register_all_modules()
+    config = get_recognizer_cfg(
+        'tsn/custom_backbones/tsn_imagenet-pretrained-mobileone-s4_8xb32-1x1x8-100e_kinetics400-rgb.py'  # noqa: E501
+    )
+    config.model['backbone']['init_cfg'] = None
+    input_shape = (1, 3, 3, 32, 32)
+    train_test_step(config, input_shape)
+
+
 def test_tsn_timm_backbone():
     # test tsn from timm
     register_all_modules()
@@ -142,6 +152,7 @@ def test_tsn_tv_backbone():

 def test_tsm():
     register_all_modules()
+    # test tsm-mobilenetv2
     config = get_recognizer_cfg(
         'tsm/tsm_imagenet-pretrained-mobilenetv2_8xb16-1x1x8-100e_kinetics400-rgb.py'  # noqa: E501
     )
@@ -151,6 +162,7 @@ def test_tsm():
     input_shape = (1, 8, 3, 32, 32)
     train_test_step(config, input_shape)

+    # test tsm-res50
     config = get_recognizer_cfg(
         'tsm/tsm_imagenet-pretrained-r50_8xb16-1x1x8-50e_kinetics400-rgb.py')
     config.model['backbone']['pretrained'] = None
@@ -159,6 +171,16 @@ def test_tsm():
     input_shape = (1, 8, 3, 32, 32)
     train_test_step(config, input_shape)

+    # test tsm-mobileone
+    config = get_recognizer_cfg(
+        'tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py'  # noqa: E501
+    )
+    config.model['backbone']['init_cfg'] = None
+    config.model['backbone']['pretrained2d'] = None
+
+    input_shape = (1, 16, 3, 32, 32)
+    train_test_step(config, input_shape)
+

 def test_trn():
     register_all_modules()
diff --git a/tools/convert/reparameterize_model.py b/tools/convert/reparameterize_model.py
new file mode 100644
index 0000000000..6220e092fc
--- /dev/null
+++ b/tools/convert/reparameterize_model.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+from pathlib import Path
+
+import torch
+
+from mmaction.apis import init_recognizer
+from mmaction.models.recognizers import BaseRecognizer
+
+
+def convert_recognizer_to_deploy(model, checkpoint, save_path):
+    print('Converting...')
+    assert hasattr(model, 'backbone') and \
+        hasattr(model.backbone, 'switch_to_deploy'), \
+        '`model.backbone` must have a `switch_to_deploy` method, ' \
+        f'but {model.backbone.__class__} does not.'
+
+    model.backbone.switch_to_deploy()
+    checkpoint['state_dict'] = model.state_dict()
+    torch.save(checkpoint, save_path)
+
+    print('Done! Saved at path "{}"'.format(save_path))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert the parameters of reparameterizable blocks '
+        '(e.g. MobileOne) from training mode to deployment mode.')
+    parser.add_argument(
+        'config_path',
+        help='The path to the configuration file of the network '
+        'containing the reparameterizable blocks.')
+    parser.add_argument(
+        'checkpoint_path',
+        help='The path to the checkpoint file corresponding to the model.')
+    parser.add_argument(
+        'save_path',
+        help='The path where the converted checkpoint file is stored.')
+    args = parser.parse_args()
+
+    save_path = Path(args.save_path)
+    if save_path.suffix != '.pth' and save_path.suffix != '.tar':
+        print('The save path should point to a .pth or .tar checkpoint file.')
+        exit()
+    save_path.parent.mkdir(parents=True, exist_ok=True)
+
+    model = init_recognizer(
+        args.config_path, checkpoint=args.checkpoint_path, device='cpu')
+    assert isinstance(model, BaseRecognizer), \
+        '`model` must be a `mmaction.models.recognizers.BaseRecognizer` ' \
+        'instance.'
+
+    checkpoint = torch.load(args.checkpoint_path)
+    convert_recognizer_to_deploy(model, checkpoint, args.save_path)
+
+
+if __name__ == '__main__':
+    main()
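A possible end-to-end usage sketch for the converter added above (the checkpoint paths are placeholders, and `tools/test.py` is assumed to be MMAction2's standard test entry point, which this patch does not modify):

```shell
# Fuse the MobileOne training-time branches and save a deploy checkpoint.
python tools/convert/reparameterize_model.py \
    configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_8xb16-1x1x16-50e_kinetics400-rgb.py \
    ${CHECKPOINT} \
    ${DEPLOY_CHECKPOINT}

# Evaluate with the deploy config, which sets `backbone.deploy=True`.
python tools/test.py \
    configs/recognition/tsm/tsm_imagenet-pretrained-mobileone-s4_deploy_8xb16-1x1x16-50e_kinetics400-rgb.py \
    ${DEPLOY_CHECKPOINT}
```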