From 17b88a327f7089c64e217dfd10d2d6ab1652de9d Mon Sep 17 00:00:00 2001 From: wxDai Date: Thu, 12 Oct 2023 11:53:31 +0800 Subject: [PATCH] [Enhance] Support the Training of ActionClip (#2620) --- projects/actionclip/README.md | 34 +++- ...6-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py | 162 ++++++++++++++++++ ...2-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py | 162 ++++++++++++++++++ projects/actionclip/models/actionclip.py | 105 +++++++++--- 4 files changed, 437 insertions(+), 26 deletions(-) create mode 100644 projects/actionclip/configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py create mode 100644 projects/actionclip/configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py diff --git a/projects/actionclip/README.md b/projects/actionclip/README.md index df694fd538..ffe14a4cae 100644 --- a/projects/actionclip/README.md +++ b/projects/actionclip/README.md @@ -46,24 +46,45 @@ Create a symbolic link from `$MMACTION2/data` to `./data` in the current directo ln -s ../../data ./data ``` +### Training commands + +**To train with single GPU:** + +```bash +mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py +``` + +**To train with multiple GPUs:** + +```bash +mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --launcher pytorch --gpus 8 +``` + +**To train with multiple GPUs by slurm:** + +```bash +mim train mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --launcher slurm \ + --gpus 8 --gpus-per-node 8 --partition $PARTITION +``` + ### Testing commands **To test with single GPU:** ```bash -mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT +mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT ``` **To test with multiple GPUs:** ```bash -mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8 +mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8 ``` **To test with multiple GPUs by slurm:** ```bash -mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \ +mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \ --gpus 8 --gpus-per-node 8 --partition $PARTITION ``` @@ -80,6 +101,13 @@ mim test mmaction configs/actionclip_vit-base-p32-res224-clip-pre_1x1x8_k400-rgb \[1\] The models are ported from the repo [ActionCLIP](https://github.com/sallymmx/ActionCLIP) and tested on our data. Currently, we only support the testing of ActionCLIP models. Due to the variation in testing data, our reported test accuracy differs from that of the original repository (on average, it is lower by one point). Please refer to this [issue](https://github.com/sallymmx/ActionCLIP/issues/14) for more details. +### Kinetics400 (Trained on Our K400 dataset) + +| frame sampling strategy | gpus | backbone | top1 acc | top5 acc | testing protocol | config | ckpt | log | +| :---------------------: | :--: | :------: | :------: | :------: | :---------------: | :-------------------------------------------: | :------------------------------------------: | :-----------------------------------------: | +| 1x1x8 | 8 | ViT-B/32 | 77.5 | 93.2 | 8 clips x 1 crop | [config](./configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb_20230801-8535b794.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.log) | +| 1x1x8 | 8 | ViT-B/16 | 81.3 | 95.2 | 8 clips x 1 crop | [config](./configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb_20230801-b307a0cd.pth) | [log](https://download.openmmlab.com/mmaction/v1.0/projects/actionclip/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.log) | + ## Zero-Shot Prediction We offer two methods for zero-shot prediction as follows. The `test.mp4` can be downloaded from [here](https://github-production-user-asset-6210df.s3.amazonaws.com/58767402/237333525-89ebee9a-573e-4e27-9047-0ad6422fa82f.mp4). diff --git a/projects/actionclip/configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py b/projects/actionclip/configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py new file mode 100644 index 0000000000..732fd6fac0 --- /dev/null +++ b/projects/actionclip/configs/actionclip_vit-base-p16-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py @@ -0,0 +1,162 @@ +custom_imports = dict(imports='models') + +num_segs = 8 + +model = dict( + type='ActionClip', + clip_arch='ViT-B/16', + num_adapter_segs=num_segs, + num_adapter_layers=6, + to_float32=True, + labels_or_label_file='configs/label_map_k400.txt', + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[122.771, 116.746, 104.093], + std=[68.500, 66.632, 70.323], + format_shape='NCHW')) + +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' + +file_client_args = dict(io_backend='disk') +file_client_args = dict( + io_backend='petrel', + path_mapping=dict( + {'data/kinetics400/': 's3://openmmlab/datasets/action/Kinetics400/'})) + +train_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', clip_len=1, frame_interval=1, num_clips=num_segs), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='RandomResizedCrop'), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, .875, .75, .66), + random_crop=False, + num_fixed_crops=13, + max_wh_scale_gap=1), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] + +val_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=num_segs, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] + +test_pipeline = val_pipeline + +train_dataloader = dict( + batch_size=16, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=16, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) +test_dataloader = dict( + batch_size=1, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +val_evaluator = dict(type='AccMetric') +test_evaluator = val_evaluator + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=5e-6, betas=(0.9, 0.98), eps=1e-08, weight_decay=0.2), + paramwise_cfg=dict(custom_keys=dict(adapter=dict(lr_mult=10)))) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=0.01, + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=45, + eta_min=0, + by_epoch=True, + begin=5, + end=50, + convert_to_iter_based=True) +] + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (16 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=128) + +default_scope = 'mmaction' + +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=100, ignore_last=False), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5), + sampler_seed=dict(type='DistSamplerSeedHook'), + sync_buffers=dict(type='SyncBuffersHook')) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) + +log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/projects/actionclip/configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py b/projects/actionclip/configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py new file mode 100644 index 0000000000..0991730c71 --- /dev/null +++ b/projects/actionclip/configs/actionclip_vit-base-p32-res224-clip-pre_g8xb16_1x1x8_k400-rgb.py @@ -0,0 +1,162 @@ +custom_imports = dict(imports='models') + +num_segs = 8 + +model = dict( + type='ActionClip', + clip_arch='ViT-B/32', + num_adapter_segs=num_segs, + num_adapter_layers=6, + to_float32=True, + labels_or_label_file='configs/label_map_k400.txt', + data_preprocessor=dict( + type='ActionDataPreprocessor', + mean=[122.771, 116.746, 104.093], + std=[68.500, 66.632, 70.323], + format_shape='NCHW')) + +dataset_type = 'VideoDataset' +data_root = 'data/kinetics400/videos_train' +data_root_val = 'data/kinetics400/videos_val' +ann_file_train = 'data/kinetics400/kinetics400_train_list_videos.txt' +ann_file_val = 'data/kinetics400/kinetics400_val_list_videos.txt' +ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt' + +file_client_args = dict(io_backend='disk') +file_client_args = dict( + io_backend='petrel', + path_mapping=dict( + {'data/kinetics400/': 's3://openmmlab/datasets/action/Kinetics400/'})) + +train_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', clip_len=1, frame_interval=1, num_clips=num_segs), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='RandomResizedCrop'), + dict( + type='MultiScaleCrop', + input_size=224, + scales=(1, .875, .75, .66), + random_crop=False, + num_fixed_crops=13, + max_wh_scale_gap=1), + dict(type='Resize', scale=(224, 224), keep_ratio=False), + dict(type='Flip', flip_ratio=0.5), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] + +val_pipeline = [ + dict(type='DecordInit', **file_client_args), + dict( + type='SampleFrames', + clip_len=1, + frame_interval=1, + num_clips=num_segs, + test_mode=True), + dict(type='DecordDecode'), + dict(type='Resize', scale=(-1, 256)), + dict(type='CenterCrop', crop_size=224), + dict(type='FormatShape', input_format='NCHW'), + dict(type='PackActionInputs') +] + +test_pipeline = val_pipeline + +train_dataloader = dict( + batch_size=16, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file=ann_file_train, + data_prefix=dict(video=data_root), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=16, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_val, + data_prefix=dict(video=data_root_val), + pipeline=val_pipeline, + test_mode=True)) +test_dataloader = dict( + batch_size=1, + num_workers=16, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + ann_file=ann_file_test, + data_prefix=dict(video=data_root_val), + pipeline=test_pipeline, + test_mode=True)) + +val_evaluator = dict(type='AccMetric') +test_evaluator = val_evaluator + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=50, val_begin=1, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=5e-6, betas=(0.9, 0.98), eps=1e-08, weight_decay=0.2), + paramwise_cfg=dict(custom_keys=dict(adapter=dict(lr_mult=10)))) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=0.01, + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=45, + eta_min=0, + by_epoch=True, + begin=5, + end=50, + convert_to_iter_based=True) +] + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (16 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=128) + +default_scope = 'mmaction' + +default_hooks = dict( + runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=100, ignore_last=False), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5), + sampler_seed=dict(type='DistSamplerSeedHook'), + sync_buffers=dict(type='SyncBuffersHook')) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) + +log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/projects/actionclip/models/actionclip.py b/projects/actionclip/models/actionclip.py index 923b78c68f..6b125b40b2 100644 --- a/projects/actionclip/models/actionclip.py +++ b/projects/actionclip/models/actionclip.py @@ -1,9 +1,11 @@ -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import clip import mmengine +import numpy as np import torch import torch.nn.functional as F +from mmengine.dist import all_gather, get_rank from mmengine.model import BaseModel from mmengine.structures import LabelData @@ -11,7 +13,23 @@ from .adapter import TransformerAdapter -def text_prompt(labels_or_label_file, template=None): +class GatherLayer(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]: + ctx.save_for_backward(input) + output = all_gather(input) + return tuple(output) + + @staticmethod + def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor: + input, = ctx.saved_tensors + grad_out = torch.zeros_like(input) + grad_out[:] = grads[get_rank()] + return grad_out + + +def text_prompt(labels_or_label_file, templates_or_template_file=None): if isinstance(labels_or_label_file, str): labels = mmengine.list_from_file(labels_or_label_file) elif isinstance(labels_or_label_file, list): @@ -20,8 +38,8 @@ def text_prompt(labels_or_label_file, template=None): raise ValueError(f'`labels_or_label_file` must be `list` or `str`, ' f'but got {type(labels_or_label_file)}') - if template is None: - template = [ + if templates_or_template_file is None: + templates = [ 'a photo of action {}', 'a picture of action {}', 'Human action of {}', '{}, an action', '{} this is an action', '{}, a video of action', 'Playing action of {}', '{}', @@ -30,15 +48,15 @@ def text_prompt(labels_or_label_file, template=None): 'Video classification of {}', 'A video of {}', 'The man is {}', 'The woman is {}' ] - elif isinstance(template, str): - template = [template] - elif not mmengine.is_seq_of(template, str): + elif isinstance(templates_or_template_file, str): + templates = mmengine.list_from_file(templates_or_template_file) + elif not mmengine.is_seq_of(templates_or_template_file, str): raise ValueError(f'`template` must be list of `str`, `str` or `None`, ' - f'but got {type(template)}') + f'but got {type(templates_or_template_file)}') - num_prompt = len(template) + num_prompt = len(templates) prompt = torch.cat( - [clip.tokenize(t.format(c)) for t in template for c in labels]) + [clip.tokenize(t.format(c)) for t in templates for c in labels]) return prompt, num_prompt @@ -49,18 +67,25 @@ def __init__(self, clip_arch: str, num_adapter_segs: int, num_adapter_layers: int = 6, + to_float32: bool = False, labels_or_label_file: Optional[Union[List[str], str]] = None, - template: Optional[Union[List[str], str]] = None, - data_preprocessor: Optional[Dict] = None): + templates_or_template_file: Optional[Union[List[str], + str]] = None, + data_preprocessor: Optional[Dict] = None, + loss: Dict = dict(type='CrossEntropyLoss', loss_weight=0.5)): super(ActionClip, self).__init__(data_preprocessor=data_preprocessor) - self.clip = clip.load(clip_arch)[0] + self.clip = clip.load(clip_arch, device='cpu')[0] + if to_float32: + self.clip.float() + self.adapter = TransformerAdapter(self.clip, num_adapter_segs, num_adapter_layers) + self.loss = MODELS.build(loss) + if labels_or_label_file is not None: - self.prompt, self.num_prompt = text_prompt(labels_or_label_file, - template) - self.text_features = None + self.prompt, self.num_prompt = text_prompt( + labels_or_label_file, templates_or_template_file) def encode_video(self, video): b, n, c, h, w = video.shape @@ -95,14 +120,13 @@ def forward(self, bsz = len(data_samples) num_views = video_features.shape[0] // bsz - if self.text_features is None: - text_features = self.encode_text(self.prompt.to(inputs.device)) - self.text_features = text_features / text_features.norm( - dim=-1, keepdim=True) + text_features = self.encode_text(self.prompt.to(inputs.device)) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) # (bsz*num_views, num_prompt, num_classes) -> # (bsz, num_views*num_prompt, num_classes) - similarity = (100.0 * video_features @ self.text_features.T). \ + similarity = (100.0 * video_features @ text_features.T). \ view(bsz, num_views * self.num_prompt, -1) cls_scores = F.softmax(similarity, dim=2).mean(dim=1) @@ -112,6 +136,41 @@ def forward(self, return data_samples + elif mode == 'loss': + video_features = self.encode_video(inputs) + video_features = video_features / video_features.norm( + dim=-1, keepdim=True) + + text_id = np.random.randint( + self.num_prompt, size=len(data_samples)) + real_labels = [x.gt_labels.item.item() for x in data_samples] + selected_prompt = self.prompt.view( + self.num_prompt, -1, + self.prompt.shape[-1])[text_id, real_labels].to(inputs.device) + + text_features = self.encode_text(selected_prompt) + text_features = text_features / text_features.norm( + dim=-1, keepdim=True) + + video_features = torch.cat( + GatherLayer.apply(video_features), dim=0) + text_features = torch.cat(GatherLayer.apply(text_features), dim=0) + + logit_scale = self.clip.logit_scale.exp() + logits_per_video = logit_scale * video_features @ text_features.t() + logits_per_text = logits_per_video.t() + labels = torch.arange(logits_per_video.shape[0]).to( + logit_scale.device) + + sim_loss_v2t = self.loss(logits_per_video, labels) + sim_loss_t2v = self.loss(logits_per_text, labels) + + losses = dict() + losses['sim_loss_v2t'] = sim_loss_v2t + losses['sim_loss_t2v'] = sim_loss_t2v + return losses + else: - raise RuntimeError(f'Invalid mode "{mode}". ' - 'Only supports `predict` and `tensor` mode. ') + raise RuntimeError( + f'Invalid mode "{mode}". ' + 'Only supports `predict`, `loss` and `tensor` mode. ')