Skip to content

Commit 289effd

Browse files
committed
add depthsplat depth models
1 parent 95ffabe commit 289effd

14 files changed

+3364
-14
lines changed

.gitignore

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
share/python-wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
MANIFEST
28+
29+
# PyInstaller
30+
# Usually these files are written by a python script from a template
31+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
32+
*.manifest
33+
*.spec
34+
35+
# Installer logs
36+
pip-log.txt
37+
pip-delete-this-directory.txt
38+
39+
# Unit test / coverage reports
40+
htmlcov/
41+
.tox/
42+
.nox/
43+
.coverage
44+
.coverage.*
45+
.cache
46+
nosetests.xml
47+
coverage.xml
48+
*.cover
49+
*.py,cover
50+
.hypothesis/
51+
.pytest_cache/
52+
cover/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
.pybuilder/
76+
target/
77+
78+
# Jupyter Notebook
79+
.ipynb_checkpoints
80+
81+
# IPython
82+
profile_default/
83+
ipython_config.py
84+
85+
# pyenv
86+
# For a library or package, you might want to ignore these files since the code is
87+
# intended to run in multiple environments; otherwise, check them in:
88+
# .python-version
89+
90+
# pipenv
91+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
93+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
94+
# install all needed dependencies.
95+
#Pipfile.lock
96+
97+
# poetry
98+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99+
# This is especially recommended for binary packages to ensure reproducibility, and is more
100+
# commonly ignored for libraries.
101+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102+
#poetry.lock
103+
104+
# pdm
105+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106+
#pdm.lock
107+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108+
# in version control.
109+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110+
.pdm.toml
111+
.pdm-python
112+
.pdm-build/
113+
114+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115+
__pypackages__/
116+
117+
# Celery stuff
118+
celerybeat-schedule
119+
celerybeat.pid
120+
121+
# SageMath parsed files
122+
*.sage.py
123+
124+
# Environments
125+
.env
126+
.venv
127+
env/
128+
venv/
129+
ENV/
130+
env.bak/
131+
venv.bak/
132+
133+
# Spyder project settings
134+
.spyderproject
135+
.spyproject
136+
137+
# Rope project settings
138+
.ropeproject
139+
140+
# mkdocs documentation
141+
/site
142+
143+
# mypy
144+
.mypy_cache/
145+
.dmypy.json
146+
dmypy.json
147+
148+
# Pyre type checker
149+
.pyre/
150+
151+
# pytype static type analyzer
152+
.pytype/
153+
154+
# Cython debug symbols
155+
cython_debug/
156+
157+
# PyCharm
158+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160+
# and can be added to the global gitignore or merged into this file. For a more nuclear
161+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
162+
#.idea/
163+
164+
output/
165+
pretrained/

README.md

+27-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ This project is developed based on our previous works:
4848
- [AANet: Adaptive Aggregation Network for Efficient Stereo Matching, CVPR 2020](https://github.com/haofeixu/aanet)
4949

5050

51+
## Updates
52+
53+
- 2025-01-04: Check out [DepthSplat](https://haofeixu.github.io/depthsplat/) for a modern multi-view depth model, which leverages monocular depth ([Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2)) to significantly improve the robustness of UniMatch.
54+
55+
- 2025-01-04: The UniMatch depth model served as the foundational backbone of [MVSplat (ECCV 2024, Oral)](https://donydchen.github.io/mvsplat/) for sparse-view feed-forward 3DGS reconstruction.
5156

5257
## Installation
5358

@@ -67,11 +72,22 @@ bash pip_install.sh
6772
```
6873

6974

75+
To use the [depth models from DepthSplat](https://github.com/cvg/depthsplat/blob/main/MODEL_ZOO.md), you need to create a new conda environment with higher version dependencies:
76+
77+
```
78+
conda create -y -n depthsplat-depth python=3.10
79+
conda activate depthsplat-depth
80+
pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
81+
pip install tensorboard==2.9.1 einops opencv-python>=4.8.1.78 matplotlib
82+
```
83+
7084

7185
## Model Zoo
7286

7387
A large number of pretrained models with different speed-accuracy trade-offs for flow, stereo and depth are available at [MODEL_ZOO.md](MODEL_ZOO.md).
7488

89+
Check out [DepthSplat's Model Zoo](https://github.com/cvg/depthsplat/blob/main/MODEL_ZOO.md) for better depth models.
90+
7591
We assume the downloaded weights are located under the `pretrained` directory.
7692

7793
Otherwise, you may need to change the corresponding paths in the scripts.
@@ -82,7 +98,7 @@ Otherwise, you may need to change the corresponding paths in the scripts.
8298

8399
Given an image pair or a video sequence, our code supports generating prediction results of optical flow, disparity and depth.
84100

85-
Please refer to [scripts/gmflow_demo.sh](scripts/gmflow_demo.sh), [scripts/gmstereo_demo.sh](scripts/gmstereo_demo.sh) and [scripts/gmdepth_demo.sh](scripts/gmdepth_demo.sh) for example usages.
101+
Please refer to [scripts/gmflow_demo.sh](scripts/gmflow_demo.sh), [scripts/gmstereo_demo.sh](scripts/gmstereo_demo.sh), [scripts/gmdepth_demo.sh](scripts/gmdepth_demo.sh) and [scripts/depthsplat_depth_demo.sh](scripts/depthsplat_depth_demo.sh) for example usages.
86102

87103

88104

@@ -142,6 +158,16 @@ This work is a substantial extension of our previous conference paper [GMFlow (C
142158
}
143159
```
144160

161+
Please consider citing [DepthSplat](https://arxiv.org/abs/2410.13862) if DepthSplat's depth model is used in your research.
162+
163+
```
164+
@article{xu2024depthsplat,
165+
title = {DepthSplat: Connecting Gaussian Splatting and Depth},
166+
author = {Xu, Haofei and Peng, Songyou and Wang, Fangjinhua and Blum, Hermann and Barath, Daniel and Geiger, Andreas and Pollefeys, Marc},
167+
journal = {arXiv preprint arXiv:2410.13862},
168+
year = {2024}
169+
}
170+
```
145171

146172

147173
## Acknowledgements

main_depth.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch.utils.tensorboard import SummaryWriter
77

88
from unimatch.unimatch import UniMatch
9+
from unimatch.unimatch_depthsplat import UniMatchDepthSplat
910
from dataloader.depth.datasets import DemonDataset, ScannetDataset
1011
from dataloader.depth import augmentation
1112
from loss.depth_loss import depth_loss_func, depth_grad_loss_func
@@ -85,6 +86,10 @@ def get_args_parser():
8586
parser.add_argument('--num_reg_refine', default=1, type=int,
8687
help='number of additional local regression refinement')
8788

89+
# depthsplat depth model
90+
parser.add_argument('--depthsplat_depth', action='store_true')
91+
parser.add_argument('--vit_type', default='vits', type=str, choices=['vits', 'vitb', 'vitl'])
92+
8893
# loss
8994
parser.add_argument('--depth_loss_weight', default=20, type=float)
9095
parser.add_argument('--depth_grad_loss_weight', default=20, type=float)
@@ -143,14 +148,20 @@ def main(args):
143148
setup_for_distributed(args.local_rank == 0)
144149

145150
# model
146-
model = UniMatch(feature_channels=args.feature_channels,
147-
num_scales=args.num_scales,
148-
upsample_factor=args.upsample_factor,
149-
num_head=args.num_head,
150-
ffn_dim_expansion=args.ffn_dim_expansion,
151-
num_transformer_layers=args.num_transformer_layers,
152-
reg_refine=args.reg_refine,
153-
task=args.task).to(device)
151+
if args.depthsplat_depth:
152+
model = UniMatchDepthSplat(num_scales=args.num_scales,
153+
upsample_factor=args.upsample_factor,
154+
vit_type=args.vit_type,
155+
).to(device)
156+
else:
157+
model = UniMatch(feature_channels=args.feature_channels,
158+
num_scales=args.num_scales,
159+
upsample_factor=args.upsample_factor,
160+
num_head=args.num_head,
161+
ffn_dim_expansion=args.ffn_dim_expansion,
162+
num_transformer_layers=args.num_transformer_layers,
163+
reg_refine=args.reg_refine,
164+
task=args.task).to(device)
154165

155166
if print_info:
156167
print(model)

scripts/depthsplat_depth_demo.sh

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/usr/bin/env bash
2+
3+
4+
# depthsplat-depth-small
5+
CUDA_VISIBLE_DEVICES=0 python main_depth.py \
6+
--inference_dir demo/depth-scannet \
7+
--output_path output/depthsplat-depth-small \
8+
--resume pretrained/depthsplat-depth-small-3d79dd5e.pth \
9+
--depthsplat_depth
10+
11+
# predict depth for both images
12+
# --pred_bidir_depth
13+
14+
15+
16+
# depthsplat-depth-base
17+
CUDA_VISIBLE_DEVICES=0 python main_depth.py \
18+
--inference_dir demo/depth-scannet \
19+
--output_path output/depthsplat-depth-base \
20+
--resume pretrained/depthsplat-depth-base-f57113bd.pth \
21+
--depthsplat_depth \
22+
--vit_type vitb \
23+
--num_scales 2 \
24+
--upsample_factor 4
25+
26+
27+
28+
# depthsplat-depth-large
29+
CUDA_VISIBLE_DEVICES=0 python main_depth.py \
30+
--inference_dir demo/depth-scannet \
31+
--output_path output/depthsplat-depth-large \
32+
--resume pretrained/depthsplat-depth-large-50d3d7cf.pth \
33+
--depthsplat_depth \
34+
--vit_type vitl \
35+
--num_scales 2 \
36+
--upsample_factor 4
37+

unimatch/backbone.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ class CNNEncoder(nn.Module):
4040
def __init__(self, output_dim=128,
4141
norm_layer=nn.InstanceNorm2d,
4242
num_output_scales=1,
43+
return_all_scales=False,
4344
**kwargs,
4445
):
4546
super(CNNEncoder, self).__init__()
4647
self.num_branch = num_output_scales
48+
self.return_all_scales = return_all_scales
4749

4850
feature_dims = [64, 96, 128]
4951

@@ -56,14 +58,17 @@ def __init__(self, output_dim=128,
5658
self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
5759

5860
# highest resolution 1/4 or 1/8
59-
stride = 2 if num_output_scales == 1 else 1
61+
if return_all_scales: # depthsplat
62+
stride = 2
63+
else:
64+
stride = 2 if num_output_scales == 1 else 1
6065
self.layer3 = self._make_layer(feature_dims[2], stride=stride,
6166
norm_layer=norm_layer,
6267
) # 1/4 or 1/8
6368

6469
self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
6570

66-
if self.num_branch > 1:
71+
if self.num_branch > 1 and not return_all_scales:
6772
if self.num_branch == 4:
6873
strides = (1, 2, 4, 8)
6974
elif self.num_branch == 3:
@@ -99,16 +104,27 @@ def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
99104
return nn.Sequential(*layers)
100105

101106
def forward(self, x):
107+
output_all_scales = []
102108
x = self.conv1(x)
103109
x = self.norm1(x)
104110
x = self.relu1(x)
105111

106112
x = self.layer1(x) # 1/2
113+
if self.return_all_scales:
114+
output_all_scales.append(x)
115+
107116
x = self.layer2(x) # 1/4
117+
if self.return_all_scales:
118+
output_all_scales.append(x)
119+
108120
x = self.layer3(x) # 1/8 or 1/4
109121

110122
x = self.conv2(x)
111123

124+
if self.return_all_scales:
125+
output_all_scales.append(x)
126+
return output_all_scales
127+
112128
if self.num_branch > 1:
113129
out = self.trident_conv([x] * self.num_branch) # high to low res
114130
else:

0 commit comments

Comments
 (0)