Commit 96d20d3

[Enhance] Support 2D&3D Optical Flow Training (#2631)

makecent authored Aug 15, 2023
1 parent 33088ee commit 96d20d3
Showing 6 changed files with 314 additions and 45 deletions.
New file: SlowOnly 3D optical-flow training config
@@ -0,0 +1,146 @@
_base_ = '../../_base_/default_runtime.py'

model = dict(
    type='Recognizer3D',
    backbone=dict(
        type='ResNet3dSlowOnly',
        depth=50,
        pretrained=None,
        lateral=False,
        in_channels=2,
        conv1_kernel=(1, 7, 7),
        conv1_stride_t=1,
        pool1_stride_t=1,
        inflate=(0, 0, 1, 1),
        norm_eval=False),
    cls_head=dict(
        type='I3DHead',
        in_channels=2048,
        num_classes=400,
        spatial_type='avg',
        dropout_ratio=0.5,
        average_clips='prob'),
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[128, 128],
        std=[128, 128],
        format_shape='NCTHW'))
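Optical flow has two channels (x and y displacement), hence in_channels=2 and a single mean/std of 128 per channel: flow is typically stored as uint8 JPEGs centered at 128, so the preprocessor maps values to roughly [-1, 1]. A minimal NumPy sanity check, mirroring the 16-frame, 224x224 train pipeline below (illustration only, not part of the config):

import numpy as np

# Hypothetical batch in NCTHW layout: (batch, channels, frames, h, w).
flow = np.random.randint(0, 256, size=(2, 2, 16, 224, 224), dtype=np.uint8)
# What ActionDataPreprocessor does with mean=std=[128, 128]:
normed = (flow.astype(np.float32) - 128.0) / 128.0
print(normed.shape, normed.min() >= -1.0, normed.max() <= 1.0)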

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_flow.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_flow.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_flow.txt'
file_client_args = dict(io_backend='disk')
train_pipeline = [
    dict(type='SampleFrames', clip_len=16, frame_interval=4, num_clips=1),
    dict(type='RawFrameDecode', **file_client_args),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs')
]

val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=16,
        frame_interval=4,
        num_clips=2,
        test_mode=True),
    dict(type='RawFrameDecode', **file_client_args),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs')
]

test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=16,
        frame_interval=4,
        num_clips=10,
        test_mode=True),
    dict(type='RawFrameDecode', **file_client_args),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs')
]
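For reference, the head's average_clips='prob' consensus averages over num_clips x num_crops views, which differs per stage. A quick, hedged tally:

# Views per video = num_clips x spatial crops (assuming ThreeCrop yields 3).
views = {
    'train': 1 * 1,   # 1 clip, one RandomResizedCrop
    'val': 2 * 1,     # 2 clips, CenterCrop
    'test': 10 * 3,   # 10 clips x ThreeCrop = 30 views
}
print(views)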

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        filename_tmpl='{}_{:05d}.jpg',
        modality='Flow',
        data_prefix=dict(img=data_root),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=16,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        filename_tmpl='{}_{:05d}.jpg',
        modality='Flow',
        data_prefix=dict(img=data_root_val),
        pipeline=val_pipeline,
        test_mode=True))
test_dataloader = dict(
    batch_size=1,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        filename_tmpl='{}_{:05d}.jpg',
        modality='Flow',
        data_prefix=dict(img=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))
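With modality='Flow', RawFrameDecode expands filename_tmpl twice per frame index, once per flow direction. A small illustration, assuming the conventional x_/y_ prefixes produced by denseflow:

# Hedged illustration of how the template resolves for Flow frames.
tmpl = '{}_{:05d}.jpg'
print(tmpl.format('x', 1))  # x_00001.jpg  (horizontal flow plane)
print(tmpl.format('y', 1))  # y_00001.jpg  (vertical flow plane)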

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=256, val_begin=1, val_interval=8)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning policy
param_scheduler = [
    dict(type='LinearLR', start_factor=0.1, by_epoch=True, begin=0, end=34),
    dict(
        type='CosineAnnealingLR',
        T_max=222,
        eta_min=0,
        by_epoch=True,
        begin=34,
        end=256)
]

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=1e-4),
    clip_grad=dict(max_norm=40, norm_type=2))
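Note that T_max = 256 - 34 = 222, so the cosine phase spans exactly the epochs after warmup. A rough sketch of the resulting schedule, assuming LinearLR ramps the factor from 0.1 to 1.0 (approximate, not MMEngine's exact stepping):

import math

def lr_at(epoch, base_lr=0.2, warmup=34, max_epochs=256):
    # Linear warmup from 0.1 * base_lr, then cosine annealing to 0.
    if epoch < warmup:
        return base_lr * (0.1 + 0.9 * epoch / warmup)
    t = (epoch - warmup) / (max_epochs - warmup)  # cosine over T_max=222
    return base_lr * 0.5 * (1 + math.cos(math.pi * t))

print(round(lr_at(0), 4), round(lr_at(34), 4), round(lr_at(256), 6))
# 0.02 0.2 0.0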

# runtime settings
default_hooks = dict(checkpoint=dict(interval=8, max_keep_ckpts=3))

# Default setting for scaling LR automatically
# - `enable`: whether to scale the LR automatically.
# - `base_batch_size` = (8 GPUs) x (16 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=128)
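When enable=True, MMEngine rescales the optimizer LR linearly against base_batch_size; a sketch of the rule, under the assumption that the scaling is purely linear:

# Linear LR scaling (assumption: mirrors MMEngine's auto_scale_lr).
base_lr, base_batch_size = 0.2, 128
actual_batch_size = 4 * 16  # hypothetical: 4 GPUs x 16 samples per GPU
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)  # 0.1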
New file: TSN 2D optical-flow training config
@@ -0,0 +1,141 @@
_base_ = '../../_base_/default_runtime.py'

clip_len = 5

model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='ResNet',
        pretrained='https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
        depth=50,
        in_channels=2 * clip_len,  # ``in_channels`` should be 2 * clip_len
        norm_eval=False),
    cls_head=dict(
        type='TSNHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.4,
        init_std=0.01,
        average_clips='prob'),
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[128, 128] * clip_len,  # one (mean, mean) pair per stacked frame
        std=[128, 128] * clip_len,  # one (std, std) pair per stacked frame
        format_shape='NCHW'))
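Each 2D sample stacks the x/y planes of clip_len consecutive flow frames along the channel axis, so in_channels, mean, and std all scale with clip_len. A quick consistency check:

# Sanity check for the clip_len-dependent sizes above.
clip_len = 5
in_channels = 2 * clip_len          # 10 channels into the 2D ResNet
mean = [128, 128] * clip_len        # 10 per-channel statistics
assert len(mean) == in_channels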

# dataset settings
dataset_type = 'RawframeDataset'
data_root = 'data/kinetics400/rawframes_train'
data_root_val = 'data/kinetics400/rawframes_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_flow.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_flow.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_flow.txt'
file_client_args = dict(io_backend='disk')
train_pipeline = [
    dict(
        type='SampleFrames', clip_len=clip_len, frame_interval=1, num_clips=3),
    dict(type='RawFrameDecode', **file_client_args),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
val_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=clip_len,
        frame_interval=1,
        num_clips=3,
        test_mode=True),
    dict(type='RawFrameDecode', **file_client_args),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
test_pipeline = [
    dict(
        type='SampleFrames',
        clip_len=clip_len,
        frame_interval=1,
        num_clips=25,
        test_mode=True),
    dict(type='RawFrameDecode', **file_client_args),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='TenCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCHW'),
    dict(type='PackActionInputs')
]
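Testing follows the classic TSN protocol: 25 clips x TenCrop = 250 spatio-temporal views per video, each a 10-channel image once FormatShape folds the flow pairs into channels (see the formatting.py change below). The shape bookkeeping, under those assumptions:

# Hedged shape arithmetic for the test pipeline above.
num_clips, num_crops, clip_len = 25, 10, 5
views = num_clips * num_crops       # 250 views per video
channels = 2 * clip_len             # 10 = (x, y) planes x clip_len
print((views, channels, 224, 224))  # per-video batch fed to the backbone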

train_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        filename_tmpl='{}_{:05d}.jpg',
        modality='Flow',
        data_prefix=dict(img=data_root),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=32,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        filename_tmpl='{}_{:05d}.jpg',
        modality='Flow',
        data_prefix=dict(img=data_root_val),
        pipeline=val_pipeline,
        test_mode=True))
test_dataloader = dict(
    batch_size=1,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        filename_tmpl='{}_{:05d}.jpg',
        modality='Flow',
        data_prefix=dict(img=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=110, val_begin=1, val_interval=5)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=0.0001),
    clip_grad=dict(max_norm=40, norm_type=2))

param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=110,
        by_epoch=True,
        milestones=[70, 100],
        gamma=0.1)
]
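The step schedule decays the LR tenfold at epochs 70 and 100; a tiny sketch makes the three phases explicit:

# Hedged sketch of MultiStepLR with milestones=[70, 100], gamma=0.1.
def lr_at(epoch, base_lr=0.005):
    return base_lr * 0.1 ** sum(epoch >= m for m in (70, 100))

print(lr_at(0), lr_at(70), lr_at(100))  # 0.005 0.0005 5e-05 (up to rounding)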

default_hooks = dict(checkpoint=dict(interval=5, max_keep_ckpts=3))

# Default setting for scaling LR automatically
# - `enable`: whether to scale the LR automatically.
# - `base_batch_size` = (8 GPUs) x (32 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=256)
40 changes: 11 additions & 29 deletions mmaction/datasets/transforms/formatting.py
@@ -204,16 +204,20 @@ class FormatShape(BaseTransform):
     """Format final imgs shape to the given input_format.

     Required keys:
         - imgs (optional)
         - heatmap_imgs (optional)
+        - modality (optional)
         - num_clips
         - clip_len

     Modified Keys:
-        - imgs (optional)
-        - input_shape (optional)
+        - imgs

     Added Keys:
+        - input_shape
         - heatmap_input_shape (optional)

     Args:
@@ -227,7 +231,7 @@ def __init__(self, input_format: str, collapse: bool = False) -> None:
         self.input_format = input_format
         self.collapse = collapse
         if self.input_format not in [
-                'NCTHW', 'NCHW', 'NCHW_Flow', 'NCTHW_Heatmap', 'NPTCHW'
+                'NCTHW', 'NCHW', 'NCTHW_Heatmap', 'NPTCHW'
         ]:
             raise ValueError(
                 f'The input format {self.input_format} is invalid.')
@@ -300,36 +304,14 @@ def transform(self, results: Dict) -> Dict:
         elif self.input_format == 'NCHW':
             imgs = results['imgs']
             imgs = np.transpose(imgs, (0, 3, 1, 2))
+            if 'modality' in results and results['modality'] == 'Flow':
+                clip_len = results['clip_len']
+                imgs = imgs.reshape((-1, clip_len * imgs.shape[1]) +
+                                    imgs.shape[2:])
+                # M x C x H x W
             results['imgs'] = imgs
             results['input_shape'] = imgs.shape

-        elif self.input_format == 'NCHW_Flow':
-            num_imgs = len(results['imgs'])
-            assert num_imgs % 2 == 0
-            n = num_imgs // 2
-            h, w = results['imgs'][0].shape
-            x_flow = np.empty((n, h, w), dtype=np.float32)
-            y_flow = np.empty((n, h, w), dtype=np.float32)
-            for i in range(n):
-                x_flow[i] = results['imgs'][2 * i]
-                y_flow[i] = results['imgs'][2 * i + 1]
-            imgs = np.stack([x_flow, y_flow], axis=-1)
-
-            num_clips = results['num_clips']
-            clip_len = results['clip_len']
-            imgs = imgs.reshape((-1, num_clips, clip_len) + imgs.shape[1:])
-            # N_crops x N_clips x T x H x W x C
-            imgs = np.transpose(imgs, (0, 1, 2, 5, 3, 4))
-            # N_crops x N_clips x T x C x H x W
-            imgs = imgs.reshape((-1, imgs.shape[2] * imgs.shape[3]) +
-                                imgs.shape[4:])
-            # M' x C' x H x W
-            # M' = N_crops x N_clips
-            # C' = T x C
-            results['imgs'] = imgs
-            results['input_shape'] = imgs.shape

         elif self.input_format == 'NPTCHW':
             num_proposals = results['num_proposals']
             num_clips = results['num_clips']
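With flow frames now decoded as H x W x 2 arrays (see the loading.py change below), the dedicated 'NCHW_Flow' branch becomes redundant: the plain 'NCHW' path simply folds clip_len consecutive 2-channel frames into the channel dimension. A self-contained sketch of that reshape, with shapes matching the TSN config above:

import numpy as np

# 3 clips x clip_len=5 flow frames, each 224 x 224 x 2 after the spatial
# transforms (hypothetical random data, illustration only).
clip_len = 5
imgs = np.random.rand(15, 224, 224, 2).astype(np.float32)

imgs = np.transpose(imgs, (0, 3, 1, 2))  # M x C x H x W, with C=2
imgs = imgs.reshape((-1, clip_len * imgs.shape[1]) + imgs.shape[2:])
print(imgs.shape)  # (3, 10, 224, 224): one 10-channel image per clip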
8 changes: 2 additions & 6 deletions mmaction/datasets/transforms/loading.py
@@ -1418,11 +1418,7 @@ def transform(self, results: dict) -> dict:
         for i, frame_idx in enumerate(results['frame_inds']):
             # Avoid loading duplicated frames
             if frame_idx in cache:
-                if modality == 'RGB':
-                    imgs.append(cp.deepcopy(imgs[cache[frame_idx]]))
-                else:
-                    imgs.append(cp.deepcopy(imgs[2 * cache[frame_idx]]))
-                    imgs.append(cp.deepcopy(imgs[2 * cache[frame_idx] + 1]))
+                imgs.append(cp.deepcopy(imgs[cache[frame_idx]]))
                 continue
             else:
                 cache[frame_idx] = i
@@ -1443,7 +1439,7 @@ def transform(self, results: dict) -> dict:
                 x_frame = mmcv.imfrombytes(x_img_bytes, flag='grayscale')
                 y_img_bytes = self.file_client.get(y_filepath)
                 y_frame = mmcv.imfrombytes(y_img_bytes, flag='grayscale')
-                imgs.extend([x_frame, y_frame])
+                imgs.append(np.stack([x_frame, y_frame], axis=-1))
             else:
                 raise NotImplementedError
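Stacking each (x, y) pair into one H x W x 2 array is what lets Flow share the RGB bookkeeping: one list entry per frame index, so the duplicate-frame cache no longer needs the 2*i / 2*i+1 indexing. A minimal sketch of the new decode result, with the mmcv decoding replaced by zero arrays:

import numpy as np

# Stand-ins for mmcv.imfrombytes(..., flag='grayscale') on x_/y_ JPEGs.
x_frame = np.zeros((224, 224), dtype=np.uint8)
y_frame = np.zeros((224, 224), dtype=np.uint8)

# Old: imgs.extend([x_frame, y_frame])  -> two H x W entries per index.
# New: one H x W x 2 entry per frame index.
frame = np.stack([x_frame, y_frame], axis=-1)
print(frame.shape)  # (224, 224, 2)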