Skip to content

Commit beec058

Browse files
committed
add training code
1 parent 14682f9 commit beec058

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2332
-372
lines changed

README.md

+6-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Please find installation instructions for PyTorch and PySOT in [`INSTALL.md`](IN
4747

4848
### Add PySOT to your PYTHONPATH
4949
```bash
50-
export PYTHONPATH=/path/to/PySOT:$PYTHONPATH
50+
export PYTHONPATH=/path/to/pysot:$PYTHONPATH
5151
```
5252

5353
### Download models
@@ -57,7 +57,7 @@ Download models in [PySOT Model Zoo](MODEL_ZOO.md) and put the model.pth in the
5757
```bash
5858
python tools/demo.py \
5959
--config experiments/siamrpn_r50_l234_dwxcorr/config.yaml \
60-
--snapshot experiments/siamrpn_r50_l234_dwxcorr/model.pth \
60+
--snapshot experiments/siamrpn_r50_l234_dwxcorr/model.pth
6161
# --video demo/bag.avi # (in case you don't have webcam)
6262
```
6363

@@ -75,7 +75,7 @@ python -u ../../tools/test.py \
7575
The testing results will in the current directory(results/dataset/model_name/)
7676

7777
### Eval tracker
78-
assume still in experiments/siamrpn_r50_l234_dwxcorr
78+
assume still in experiments/siamrpn_r50_l234_dwxcorr_8gpu
7979
``` bash
8080
python ../../tools/eval.py \
8181
--tracker_path ./results \ # result path
@@ -84,6 +84,9 @@ python ../../tools/eval.py \
8484
--tracker_prefix 'model' # tracker_name
8585
```
8686

87+
### Training
88+
See [TRAIN.md](TRAIN.md) for detailed instruction.
89+
8790
## References
8891

8992
- [Fast Online Object Tracking and Segmentation: A Unifying Approach](https://arxiv.org/abs/1812.05050).

TRAIN.md

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# PySOT Training Tutorial
2+
3+
This implements training of SiamRPN with backbone architectures, such as ResNet, AlexNet.
4+
### Add PySOT to your PYTHONPATH
5+
```bash
6+
export PYTHONPATH=/path/to/pysot:$PYTHONPATH
7+
```
8+
9+
## Prepare training dataset
10+
Prepare training dataset, detailed preparations are listed in [training_dataset](training_dataset) directory.
11+
* [VID](http://image-net.org/challenges/LSVRC/2017/)
12+
* [YOUTUBEBB](https://research.google.com/youtube-bb/)
13+
* [DET](http://image-net.org/challenges/LSVRC/2017/)
14+
* [COCO](http://cocodataset.org)
15+
16+
## Download pretrained backbones
17+
Download pretrained backbones from [Google Driver](https://drive.google.com/drive/folders/1DuXVWVYIeynAcvt9uxtkuleV6bs6e3T9) and put them in `pretrained_models` directory
18+
19+
## Training
20+
21+
To train a model (SiamRPN++), run `train.py` with the desired configs:
22+
23+
```bash
24+
cd experiments/siamrpn_r50_l234_dwxcorr_8gpu
25+
```
26+
27+
### Multi-processing Distributed Data Parallel Training
28+
29+
#### Single node, multiple GPUs:
30+
```bash
31+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
32+
python -m torch.distributed.launch \
33+
--nproc_per_node=8 \
34+
--master_port=2333 \
35+
../../tools/train.py --cfg config.yaml
36+
```
37+
38+
#### Multiple nodes:
39+
Node 1: (IP: 192.168.1.1, and has a free port: 2333) master node
40+
```bash
41+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
42+
python -m torch.distributed.launch \
43+
--nnodes=2 \
44+
--node_rank=0 \
45+
--nproc_per_node=8 \
46+
--master_addr=192.168.1.1 \ # adjust your ip here
47+
--master_port=2333 \
48+
../../tools/train.py
49+
```
50+
Node 2:
51+
```bash
52+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
53+
python -m torch.distributed.launch \
54+
--nnodes=2 \
55+
--node_rank=1 \
56+
--nproc_per_node=8 \
57+
--master_addr=192.168.1.1 \
58+
--master_port=2333 \
59+
../../tools/train.py
60+
```
61+
62+
## Testing
63+
After training, you can test snapshots on VOT dataset.
64+
For `AlexNet`, you need to test snapshots from 35 to 50 epoch.
65+
For `ResNet`, you need to test snapshots from 10 to 20 epoch.
66+
67+
```bash
68+
START=10
69+
END=20
70+
seq $START 1 $END | \
71+
xargs -I {} echo "snapshot/checkpoint_e{}.pth" | \
72+
xargs -I {} \
73+
python -u ../tools/test.py \
74+
--snapshot {} \
75+
--config config.py \
76+
--dataset VOT2018 2>&1 | tee logs/test_dataset.log
77+
```
78+
79+
## Evaluation
80+
```
81+
python ../../tools/eval.py \
82+
--tracker_path ./results \ # result path
83+
--dataset VOT2018 \ # dataset name
84+
--num 4 \ # number thread to eval
85+
--tracker_prefix 'ch*' # tracker_name
86+
```

experiments/siammask_r50_l3/config.yaml

+21-9
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
11
META_ARC: "siamrpn_r50_l234_dwxcorr"
22

33
BACKBONE:
4-
TYPE: "pysot.models.backbone.resnet_atrous.resnet50"
5-
LAYERS: [0, 1, 2, 3]
6-
CHANNELS: [1024]
4+
TYPE: "resnet50"
5+
KWARGS:
6+
used_layers: [0, 1, 2, 3]
77

88
ADJUST:
99
ADJUST: true
10-
TYPE: "pysot.models.neck.neck.AdjustAllLayer"
11-
ADJUST_CHANNEL: [256]
10+
TYPE: "AdjustAllLayer"
11+
KWARGS:
12+
in_channels: [1024]
13+
out_channels: [256]
1214

1315
RPN:
14-
TYPE: 'pysot.models.head.rpn.DepthwiseRPN'
16+
TYPE: 'DepthwiseRPN'
17+
KWARGS:
18+
anchor_num: 5
19+
in_channels: 256
20+
out_channels: 256
1521

1622
MASK:
1723
MASK: True
18-
MASK_TYPE: 'pysot.models.head.mask.MaskCorr'
24+
TYPE: 'MaskCorr'
25+
KWARGS:
26+
in_channels: 256
27+
hidden: 256
28+
out_channels: 3969
29+
30+
REFINE:
1931
REFINE: True
20-
REFINE_TYPE: 'pysot.models.head.mask.Refine'
32+
TYPE: 'Refine'
2133

2234
ANCHOR:
2335
STRIDE: 8
@@ -26,7 +38,7 @@ ANCHOR:
2638
ANCHOR_NUM: 5
2739

2840
TRACK:
29-
TYPE: 'pysot.tracker.siammask_tracker.SiamMaskTracker'
41+
TYPE: 'SiamMaskTracker'
3042
PENALTY_K: 0.10
3143
WINDOW_INFLUENCE: 0.41
3244
LR: 0.32

experiments/siamrpn_alex_dwxcorr/config.yaml

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
META_ARC: "siamrpn_alex_dwxcorr"
22

33
BACKBONE:
4-
TYPE: "pysot.models.backbone.alexnet.alexnet"
5-
LAYERS: [-1]
6-
CHANNELS: [256]
7-
WIDTH_MULT: 1.
4+
TYPE: "alexnetlegacy"
5+
KWARGS:
6+
width_mult: 1.0
87

98
ADJUST:
109
ADJUST: False
1110

1211
RPN:
13-
TYPE: 'pysot.models.head.rpn.DepthwiseRPN'
14-
WEIGHTED: False
12+
TYPE: 'DepthwiseRPN'
13+
KWARGS:
14+
anchor_num: 5
15+
in_channels: 256
16+
out_channels: 256
1517

1618
MASK:
1719
MASK: False
@@ -23,7 +25,7 @@ ANCHOR:
2325
ANCHOR_NUM: 5
2426

2527
TRACK:
26-
TYPE: 'pysot.tracker.siamrpn_tracker.SiamRPNTracker'
28+
TYPE: 'SiamRPNTracker'
2729
PENALTY_K: 0.16
2830
WINDOW_INFLUENCE: 0.40
2931
LR: 0.30

experiments/siamrpn_alex_dwxcorr_otb/config.yaml

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
META_ARC: "siamrpn_alex_dwxcorr"
22

33
BACKBONE:
4-
TYPE: "pysot.models.backbone.alexnet.alexnet"
5-
LAYERS: [-1]
6-
CHANNELS: [256]
7-
WIDTH_MULT: 1.
4+
TYPE: "alexnetlegacy"
5+
KWARGS:
6+
width_mult: 1.0
87

98
ADJUST:
109
ADJUST: False
1110

1211
RPN:
13-
TYPE: 'pysot.models.head.rpn.DepthwiseRPN'
14-
WEIGHTED: False
12+
TYPE: 'DepthwiseRPN'
13+
KWARGS:
14+
anchor_num: 5
15+
in_channels: 256
16+
out_channels: 256
1517

1618
MASK:
1719
MASK: False
@@ -23,7 +25,7 @@ ANCHOR:
2325
ANCHOR_NUM: 5
2426

2527
TRACK:
26-
TYPE: 'pysot.tracker.siamrpn_tracker.SiamRPNTracker'
28+
TYPE: 'SiamRPNTracker'
2729
PENALTY_K: 0.16
2830
WINDOW_INFLUENCE: 0.40
2931
LR: 0.30

experiments/siamrpn_mobilev2_l234_dwxcorr/config.yaml

+14-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
META_ARC: "siamrpn_mobilev2_l234_dwxcorr"
22

33
BACKBONE:
4-
TYPE: "pysot.models.backbone.mobile_v2.mobilenetv2"
5-
LAYERS: [3, 5, 7]
6-
WIDTH_MULT: 1.4
7-
CHANNELS: [44, 134, 448]
4+
TYPE: "mobilenetv2"
5+
KWARGS:
6+
used_layers: [3, 5, 7]
7+
width_mult: 1.4
88

99
ADJUST:
1010
ADJUST: true
11-
TYPE: "pysot.models.neck.neck.AdjustAllLayer"
12-
ADJUST_CHANNEL: [256, 256, 256]
11+
TYPE: "AdjustAllLayer"
12+
KWARGS:
13+
in_channels: [44, 134, 448]
14+
out_channels: [256, 256, 256]
1315

1416
RPN:
15-
TYPE: 'pysot.models.head.rpn.MultiRPN'
16-
WEIGHTED: False
17+
TYPE: 'MultiRPN'
18+
KWARGS:
19+
anchor_num: 5
20+
in_channels: [256, 256, 256]
21+
weighted: True
1722

1823
MASK:
1924
MASK: False
@@ -25,7 +30,7 @@ ANCHOR:
2530
ANCHOR_NUM: 5
2631

2732
TRACK:
28-
TYPE: 'pysot.tracker.siamrpn_tracker.SiamRPNTracker'
33+
TYPE: 'SiamRPNTracker'
2934
PENALTY_K: 0.04
3035
WINDOW_INFLUENCE: 0.4
3136
LR: 0.5

experiments/siamrpn_r50_l234_dwxcorr/config.yaml

+14-9
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
META_ARC: "siamrpn_r50_l234_dwxcorr"
22

33
BACKBONE:
4-
TYPE: "pysot.models.backbone.resnet_atrous.resnet50"
5-
LAYERS: [2, 3, 4]
6-
CHANNELS: [512, 1024, 2048]
4+
TYPE: "resnet50"
5+
KWARGS:
6+
used_layers: [2, 3, 4]
77

88
ADJUST:
99
ADJUST: true
10-
TYPE: "pysot.models.neck.neck.AdjustAllLayer"
11-
ADJUST_CHANNEL: [256, 256, 256]
10+
TYPE: "AdjustAllLayer"
11+
KWARGS:
12+
in_channels: [512, 1024, 2048]
13+
out_channels: [256, 256, 256]
1214

1315
RPN:
14-
TYPE: 'pysot.models.head.rpn.MultiRPN'
15-
WEIGHTED: True
16+
TYPE: 'MultiRPN'
17+
KWARGS:
18+
anchor_num: 5
19+
in_channels: [256, 256, 256]
20+
weighted: true
1621

1722
MASK:
18-
MASK: False
23+
MASK: false
1924

2025
ANCHOR:
2126
STRIDE: 8
@@ -24,7 +29,7 @@ ANCHOR:
2429
ANCHOR_NUM: 5
2530

2631
TRACK:
27-
TYPE: 'pysot.tracker.siamrpn_tracker.SiamRPNTracker'
32+
TYPE: 'SiamRPNTracker'
2833
PENALTY_K: 0.05
2934
WINDOW_INFLUENCE: 0.42
3035
LR: 0.38

experiments/siamrpn_r50_l234_dwxcorr_lt/config.yaml

+14-9
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1-
META_ARC: "siamrpn_r50_l234_dwxcorr"
1+
META_ARC: "siamrpn_r50_l234_dwxcorr_lt"
22

33
BACKBONE:
4-
TYPE: "pysot.models.backbone.resnet_atrous.resnet50"
5-
LAYERS: [2, 3, 4]
6-
CHANNELS: [512, 1024, 2048]
4+
TYPE: "resnet50"
5+
KWARGS:
6+
used_layers: [2, 3, 4]
77

88
ADJUST:
99
ADJUST: true
10-
TYPE: "pysot.models.neck.neck.AdjustAllLayer"
11-
ADJUST_CHANNEL: [128, 256, 512]
10+
TYPE: "AdjustAllLayer"
11+
KWARGS:
12+
in_channels: [512, 1024, 2048]
13+
out_channels: [128, 256, 512]
1214

1315
RPN:
14-
TYPE: 'pysot.models.head.rpn.MultiRPN'
15-
WEIGHTED: True
16+
TYPE: 'MultiRPN'
17+
KWARGS:
18+
anchor_num: 5
19+
in_channels: [128, 256, 512]
20+
weighted: True
1621

1722
MASK:
1823
MASK: False
@@ -24,7 +29,7 @@ ANCHOR:
2429
ANCHOR_NUM: 5
2530

2631
TRACK:
27-
TYPE: 'pysot.tracker.siamrpnlt_tracker.SiamRPNLTTracker'
32+
TYPE: 'SiamRPNLTTracker'
2833
PENALTY_K: 0.05
2934
WINDOW_INFLUENCE: 0.28
3035
LR: 0.22

0 commit comments

Comments
 (0)