
Commit 7bd0362

Add Human3.6M pretrained models

* Added Human3.6M pretrained models
* Update README.md

1 parent b171a0f commit 7bd0362

File tree: 8 files changed (+123, -66 lines)

README.md (+19 -25)

@@ -22,34 +22,27 @@ pip install -r requirements.txt
 
 #### Human3.6M
 1. Download and preprocess the dataset by following the instructions in [mvn/datasets/human36m_preprocessing/README.md](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/mvn/datasets/human36m_preprocessing/README.md).
-2. Place the preprocessed dataset to `data/human36m`. If you don't want to store the dataset in the directory with code, just create a soft symbolic link: `ln -s {PATH_TO_HUMAN36M_DATASET} ./data/human36m`.
-3. Download pretrained backbone's weights from [here](https://drive.google.com/open?id=1TGHBfa9LsFPVS5CH6Qkcy5Jr2QsJdPEa) and place them here: `data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth` (ResNet-152 trained on COCO dataset and finetuned jointly on MPII and Human3.6M).
-4. If you want to train Volumetric model, you need rough estimations of the 3D skeleton both for train and val splits. You have two options:
-    - Rough 3D skeletons can be estimated by Algebraic model and placed to `data/precalculated_results/human36m/results_train.pkl` and `data/precalculated_results/human36m/results_val.pkl` respectively.
-    - Other option is to use the ground truth (GT) estimate of the 3D skeleton by setting `use_gt_pelvis: true` in a config file. Here you don't need any precalculated results, but such training mode overestimates the resulting accuracy, because the pelvis is always perfectly defined.
+2. Place the preprocessed dataset in `./data/human36m`. If you don't want to store the dataset in the directory with the code, just create a soft symbolic link: `ln -s {PATH_TO_HUMAN36M_DATASET} ./data/human36m`.
+3. Download the pretrained backbone's weights from [here](https://drive.google.com/open?id=1TGHBfa9LsFPVS5CH6Qkcy5Jr2QsJdPEa) and place them at `./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth` (ResNet-152 trained on the COCO dataset and finetuned jointly on MPII and Human3.6M).
+4. If you want to train the Volumetric model, you need rough estimations of the 3D skeleton for both the train and val splits. In the paper we estimate 3D skeletons via the Algebraic model. You can use the [pretrained](#model-zoo) Algebraic model to produce predictions, or just take the [precalculated 3D skeletons](#model-zoo).
 
-#### CMU Panoptic
-*Will be added soon*
-
-## Train
-Every experiment is defined by `.config` files. Configs with experiments from the paper can be found in `experiments` directory (results can be found below):
+## Model zoo
+In this section we collect pretrained models and configs. All **pretrained weights** and **precalculated 3D skeletons** can be downloaded from [Google Drive](https://drive.google.com/open?id=1TGHBfa9LsFPVS5CH6Qkcy5Jr2QsJdPEa) and placed in the `./data` dir, so that the eval configs work out of the box (without additional path setup).
 
 **Human3.6M:**
-1. Algebraic w/o confidences — [experiments/human36m/train/human36m_alg_no_conf.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_alg_no_conf.yaml)
-2. Algebraic w/ confidences — [experiments/human36m/train/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_alg.yaml)
-3. Volumetric (softmax aggregation) — [experiments/human36m/train/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_vol_softmax.yaml)
-4. Volumetric (softmax aggregation, GT pelvis) — [experiments/human36m/train/human36m_vol_softmax_gtpelvis.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_vol_softmax_gtpelvis.yaml)
-
-**CMU Panoptic**
-
-*Will be added soon*
 
+| Model | Train config | Eval config | Weights | Precalculated results | MPJPE (relative to pelvis), mm |
+|----------------------|:------------|:------------|:-------:|:---------------------:|-------------------------------:|
+| Algebraic | [train/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/train/human36m_alg.yaml) | [eval/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/eval/human36m_alg.yaml) | [link](https://drive.google.com/file/d/1HAqMwH94kCfTs9jUHiuCB7vt94rMvxWe/view?usp=sharing) | [link](https://drive.google.com/drive/folders/1LCzMQswdn4UM9fbRYOZb3FmMZ7pZFyIP?usp=sharing) | 22.4 |
+| Volumetric (softmax) | [train/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/train/human36m_vol_softmax.yaml) | [eval/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/eval/human36m_vol_softmax.yaml) | [link](https://drive.google.com/file/d/1r6Ut3oMKPxhyxRh3PZ05taaXwekhJWqj/view?usp=sharing) | | **20.5** |
+## Train
+Every experiment is defined by `.config` files. Configs with experiments from the paper can be found in the `./experiments` directory (see the [model zoo](#model-zoo)).
 
 #### Single-GPU
-To train a Volumetric model with softmax aggregation and GT-estimated pelvises using **1 GPU**, run:
+To train a Volumetric model with softmax aggregation using **1 GPU**, run:
 ```bash
 python3 train.py \
-  --config experiments/human36m/train/human36m_vol_softmax_gtpelvis.yaml \
+  --config train/human36m_vol_softmax.yaml \
   --logdir ./logs
 ```
 
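The model-zoo paragraph above assumes everything is unpacked under `./data`. A minimal sketch of the resulting layout, pieced together from the checkpoint and `pred_results_path` entries in the configs further down this commit (only files touched by this commit are shown; the tree itself is an assumption, not taken from the repo docs):

```
data/
├── human36m/                                  # preprocessed dataset (or a symlink)
└── pretrained/human36m/
    ├── pose_resnet_4.5_pixels_human36m.pth    # backbone weights
    ├── human36m_alg_10-04-2019/checkpoints/0060/
    │   ├── weights.pth                        # Algebraic model weights
    │   └── results/
    │       ├── train.pkl                      # precalculated 3D skeletons (train)
    │       └── val.pkl                        # precalculated 3D skeletons (val)
    └── human36m_vol_softmax_10-08-2019/checkpoints/0040/
        └── weights.pth                        # Volumetric model weights
```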

@@ -58,11 +51,11 @@ The training will start with the config file specified by `--config`, and logs (
 #### Multi-GPU (*in testing*)
 Multi-GPU training is implemented with PyTorch's [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel). It can be used both for single-machine and multi-machine (cluster) training. To run the processes, use the PyTorch [launch utility](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py).
 
-To train a Volumetric model with softmax aggregation and GT-estimated pelvises using **2 GPUs on single machine**, run:
+To train a Volumetric model with softmax aggregation using **2 GPUs on a single machine**, run:
 ```bash
 python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=2345 \
   train.py \
-  --config experiments/human36m/train/human36m_vol_softmax_gtpelvis.yaml \
+  --config train/human36m_vol_softmax.yaml \
   --logdir ./logs
 ```
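The README stops at the single-machine case. Since the same launch utility also handles clusters, a hypothetical two-machine variant could look like the following (the `--nnodes`, `--node_rank`, and `--master_addr` flags are standard `torch.distributed.launch` options; their use with this repo is untested here):

```bash
# On node 0 (assumed reachable at NODE0_IP); repeat on node 1 with --node_rank=1.
python3 -m torch.distributed.launch \
    --nnodes=2 --node_rank=0 --nproc_per_node=2 \
    --master_addr=NODE0_IP --master_port=2345 \
    train.py \
    --config train/human36m_vol_softmax.yaml \
    --logdir ./logs
```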

@@ -86,7 +79,7 @@ Run:
 ```bash
 python3 train.py \
   --eval --eval_dataset val \
-  --config experiments/human36m/eval/human36m_vol_softmax.yaml \
+  --config eval/human36m_vol_softmax.yaml \
   --logdir ./logs
 ```
 Argument `--eval_dataset` can be `val` or `train`. Results can be seen in `logs` directory or in the tensorboard.
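This eval entry point is also the plausible way to regenerate the "precalculated 3D skeletons" that step 4 of the setup and the Volumetric configs rely on: run the pretrained Algebraic model over the train split and point `pred_results_path` at its output. This recipe is inferred from the configs in this commit rather than documented anywhere, so treat it as a sketch:

```bash
# Hypothetical: dump Algebraic predictions for the train split, then set
# dataset.train.pred_results_path in the Volumetric config to the produced .pkl.
python3 train.py \
    --eval --eval_dataset train \
    --config eval/human36m_alg.yaml \
    --logdir ./logs
```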
@@ -111,8 +104,8 @@ MPJPE relative to pelvis:
 | Kadkhodamohammadi & Padoy [\[5\]](#references) | 49.1 |
 | [Qiu et al.](https://github.com/microsoft/multiview-human-pose-estimation-pytorch) [\[9\]](#references) | 26.2 |
 | RANSAC (our implementation) | 27.4 |
-| **Ours, algebraic** | 22.6 |
-| **Ours, volumetric** | **20.8** |
+| **Ours, algebraic** | 22.4 |
+| **Ours, volumetric** | **20.5** |
 
 <br>
 MPJPE absolute (scenes with invalid ground-truth annotations are excluded):
@@ -190,6 +183,7 @@ Volumetric triangulation additionally improves accuracy, drastically reducing th
 - [Ivan Bulygin](https://github.com/blufzzz)
 
 # News
+**18 Oct 2019:** Pretrained models (algebraic and volumetric) for Human3.6M are released.
 **8 Oct 2019:** Code is released!
 
 # References
New config file (+74 lines):

@@ -0,0 +1,74 @@
+title: "human36m_alg"
+kind: "human36m"
+vis_freq: 1000
+vis_n_elements: 10
+
+image_shape: [384, 384]
+
+opt:
+  criterion: "MSESmooth"
+  mse_smooth_threshold: 400
+
+  n_objects_per_epoch: 15000
+  n_epochs: 9999
+
+  batch_size: 8
+  val_batch_size: 100
+
+  lr: 0.00001
+
+  scale_keypoints_3d: 0.1
+
+model:
+  name: "alg"
+
+  init_weights: true
+  checkpoint: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/weights.pth"
+
+
+  use_confidences: true
+  heatmap_multiplier: 100.0
+  heatmap_softmax: true
+
+  backbone:
+    name: "resnet152"
+    style: "simple"
+
+    init_weights: true
+    checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
+
+    num_joints: 17
+    num_layers: 152
+
+dataset:
+  kind: "human36m"
+
+  train:
+    h36m_root: "./data/human36m/processed"
+    labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
+    with_damaged_actions: true
+    undistort_images: true
+
+    scale_bbox: 1.0
+
+    shuffle: true
+    randomize_n_views: false
+    min_n_views: null
+    max_n_views: null
+    num_workers: 8
+
+  val:
+    h36m_root: "./data/human36m/processed"
+    labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
+    with_damaged_actions: true
+    undistort_images: true
+
+    scale_bbox: 1.0
+
+    shuffle: false
+    randomize_n_views: false
+    min_n_views: null
+    max_n_views: null
+    num_workers: 8
+
+    retain_every_n_frames_in_test: 1
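Judging by its `title`, checkpoint path, and `retain_every_n_frames_in_test: 1` (full test set, versus 30 in the old debug config below), this new file appears to be the `eval/human36m_alg.yaml` referenced in the model-zoo table. Under that assumption, it would be used with the eval entry point shown above:

```bash
python3 train.py \
    --eval --eval_dataset val \
    --config eval/human36m_alg.yaml \
    --logdir ./logs
```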
Modified config file:

@@ -1,46 +1,33 @@
-title: "debug"
+title: "human36m_ransac"
 kind: "human36m"
 vis_freq: 1000
 vis_n_elements: 10
 
 image_shape: [384, 384]
 
 opt:
-  criterion: "MAE"
+  criterion: "MSESmooth"
+  mse_smooth_threshold: 400
 
-  use_volumetric_ce_loss: true
-  volumetric_ce_loss_weight: 0.01
-
-  n_objects_per_epoch: 50
+  n_objects_per_epoch: 15000
   n_epochs: 9999
 
-  batch_size: 5
-  val_batch_size: 10
+  batch_size: 8
+  val_batch_size: 100
 
-  lr: 0.0001
-  process_features_lr: 0.001
-  volume_net_lr: 0.001
+  lr: 0.00001
 
   scale_keypoints_3d: 0.1
 
 model:
-  name: "vol"
-  kind: "mpii"
-  volume_aggregation_method: "softmax"
+  name: "ransac"
 
   init_weights: false
   checkpoint: ""
 
-  use_gt_pelvis: false
-
-  cuboid_side: 2500.0
-
-  volume_size: 64
-  volume_multiplier: 1.0
-  volume_softmax: true
-
-  heatmap_softmax: true
+  direct_optimization: true
   heatmap_multiplier: 100.0
+  heatmap_softmax: true
 
   backbone:
     name: "resnet152"

@@ -58,8 +45,6 @@ dataset:
   train:
     h36m_root: "./data/human36m/processed"
     labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
-    pred_results_path: "./data/precalculated_results/human36m/results_train.pkl"
-
     with_damaged_actions: true
     undistort_images: true
 

@@ -74,8 +59,6 @@ dataset:
   val:
     h36m_root: "./data/human36m/processed"
     labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
-    pred_results_path: "./data/precalculated_results/human36m/results_val.pkl"
-
     with_damaged_actions: true
     undistort_images: true
 

@@ -87,4 +70,4 @@ dataset:
     max_n_views: null
     num_workers: 8
 
-    retain_every_n_frames_in_test: 30
+    retain_every_n_frames_in_test: 1

experiments/human36m/train/human36m_vol_softmax_gtpelvis.yaml → experiments/human36m/eval/human36m_vol_softmax.yaml (+8 -6)

@@ -1,4 +1,4 @@
-title: "human36m_vol_softmax_gtpelvis"
+title: "human36m_vol_softmax"
 kind: "human36m"
 vis_freq: 1000
 vis_n_elements: 10

@@ -15,7 +15,7 @@ opt:
   n_epochs: 9999
 
   batch_size: 5
-  val_batch_size: 10
+  val_batch_size: 20
 
   lr: 0.0001
   process_features_lr: 0.001

@@ -28,10 +28,10 @@ model:
   kind: "mpii"
   volume_aggregation_method: "softmax"
 
-  init_weights: false
-  checkpoint: ""
+  init_weights: true
+  checkpoint: "./data/pretrained/human36m/human36m_vol_softmax_10-08-2019/checkpoints/0040/weights.pth"
 
-  use_gt_pelvis: true
+  use_gt_pelvis: false
 
   cuboid_side: 2500.0

@@ -58,6 +58,7 @@ dataset:
   train:
     h36m_root: "./data/human36m/processed"
     labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
+    pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/train.pkl"
 
     with_damaged_actions: true
     undistort_images: true

@@ -73,6 +74,7 @@ dataset:
   val:
     h36m_root: "./data/human36m/processed"
     labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
+    pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/val.pkl"
 
     with_damaged_actions: true
     undistort_images: true

@@ -85,4 +87,4 @@ dataset:
     max_n_views: null
     num_workers: 8
 
-    retain_every_n_frames_in_test: 30
+    retain_every_n_frames_in_test: 1
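The two `pred_results_path` additions wire the Volumetric model to the Algebraic model's dumped predictions, now shipped alongside the pretrained weights. A quick sanity check one might run after downloading (hypothetical snippet; it assumes the `.pkl` holds one predicted skeleton per dataset sample, as suggested by `sample['pred_keypoints_3d'] = self.keypoints_3d_pred[idx]` in `mvn/datasets/human36m.py` below):

```python
import pickle

path = "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/train.pkl"
with open(path, "rb") as f:
    keypoints_3d_pred = pickle.load(f)

# Expect one entry per train sample; each entry should be indexable by the dataset.
print(type(keypoints_3d_pred), len(keypoints_3d_pred))
```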

experiments/human36m/train/human36m_alg.yaml (+1 -1)

@@ -9,7 +9,7 @@ opt:
   criterion: "MSESmooth"
   mse_smooth_threshold: 400
 
-  n_objects_per_epoch: 10000
+  n_objects_per_epoch: 15000
   n_epochs: 9999
 
   batch_size: 8

mvn/datasets/human36m.py (+2 -4)

@@ -180,10 +180,8 @@ def __getitem__(self, idx):
         # save sample's index
         sample['indexes'] = idx
 
-        try:
+        if self.keypoints_3d_pred is not None:
             sample['pred_keypoints_3d'] = self.keypoints_3d_pred[idx]
-        except AttributeError:
-            pass
 
         sample.default_factory = None
         return sample

@@ -270,4 +268,4 @@ def evaluate(self, keypoints_3d_predicted, split_by_subject=False, transfer_cmu_
             'per_pose_error_relative': self.evaluate_using_per_pose_error(per_pose_error_relative, split_by_subject)
         }
 
-        return result['per_pose_error']['Average']['Average'], result
+        return result['per_pose_error_relative']['Average']['Average'], result
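Two notes on these hunks. First, the old `try`/`except AttributeError` only covered a missing attribute; if `keypoints_3d_pred` exists but is `None` (presumably when no `pred_results_path` is set in the config), `self.keypoints_3d_pred[idx]` raises a `TypeError` that the except clause would not catch, so the explicit `None` check is the safer guard. A standalone illustration (hypothetical snippet, not repo code):

```python
class ToyDataset:
    def __init__(self, keypoints_3d_pred=None):
        # None when no precalculated results are supplied
        self.keypoints_3d_pred = keypoints_3d_pred

    def __getitem__(self, idx):
        sample = {}
        # With try/except AttributeError, this line raised an uncaught TypeError
        # whenever keypoints_3d_pred was None; the explicit check avoids that.
        if self.keypoints_3d_pred is not None:
            sample['pred_keypoints_3d'] = self.keypoints_3d_pred[idx]
        return sample

print(ToyDataset()[0])  # {} rather than a crash
```

Second, the changed return value makes `evaluate` report the pelvis-relative MPJPE, which matches the metric quoted in the README tables above.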

mvn/models/pose_resnet.py (+1 -1)

@@ -372,6 +372,6 @@ def get_pose_net(config, device='cuda:0'):
         print("Parameters [{}] were not inited".format(not_inited_params))
 
     model.load_state_dict(new_pretrained_state_dict, strict=False)
-    print("Successfully loaded pretrained weights")
+    print("Successfully loaded pretrained weights for backbone")
 
     return model

train.py (+7 -1)

@@ -7,6 +7,7 @@
 from collections import defaultdict
 from itertools import islice
 import pickle
+import copy
 
 import numpy as np
 import cv2

@@ -406,7 +407,12 @@ def main(args):
 
     if config.model.init_weights:
         state_dict = torch.load(config.model.checkpoint)
-        model.load_state_dict(state_dict, strict=False)
+        for key in list(state_dict.keys()):
+            new_key = key.replace("module.", "")
+            state_dict[new_key] = state_dict.pop(key)
+
+        model.load_state_dict(state_dict, strict=True)
+        print("Successfully loaded pretrained weights for whole model")
 
     # criterion
     criterion_class = {
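The key-renaming loop exists because checkpoints saved from a model wrapped in `(Distributed)DataParallel` carry a `module.` prefix on every parameter name; stripping it lets the unwrapped model load the released checkpoints with `strict=True`. A minimal sketch of the same idiom (`Net` is a hypothetical stand-in; the wrapping and loading calls are standard PyTorch API):

```python
import torch
import torch.nn as nn

class Net(nn.Module):  # hypothetical stand-in for the repo's model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

net = Net()
wrapped = nn.DataParallel(net)
state_dict = wrapped.state_dict()
print(list(state_dict))  # ['module.fc.weight', 'module.fc.bias']

# Strip the "module." prefix so the unwrapped model accepts strict loading:
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
net.load_state_dict(state_dict, strict=True)
```

The switch from `strict=False` to `strict=True` also means a key mismatch now fails loudly instead of silently skipping weights, which is the behavior you want when loading released checkpoints.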
