Skip to content

Commit

Permalink
Merge branch 'main' into sarlinpe/cleanup-train
Browse files — browse the repository at this point in the history
  • Loading branch information
sarlinpe authored Dec 10, 2023
2 parents 4c52c03 + 4a82835 commit c53503c
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 14 deletions.
8 changes: 4 additions & 4 deletions gluefactory/configs/superpoint+lsd+gluestick-homography.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
data:
name: homographies
homography:
difficulty: 0.5
max_angle: 30
difficulty: 0.7
max_angle: 45
patch_shape: [640, 480]
photometric:
p: 0.75
train_size: 900000
val_size: 1000
batch_size: 80 # 20 per 10GB of GPU mem (12 for triplet)
batch_size: 160 # 20 per 10GB of GPU mem (12 for triplet)
num_workers: 15
model:
name: gluefactory.models.two_view_pipeline
Expand Down Expand Up @@ -70,4 +70,4 @@ train:
n_steps: 4
submodules: []
# clip_grad: 10 # Use only with mixed precision
# load_experiment:
# load_experiment:
13 changes: 9 additions & 4 deletions gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
data:
name: gluefactory.datasets.megadepth
train_num_per_scene: 300
val_pairs: valid_pairs.txt
views: 2
min_overlap: 0.1
max_overlap: 0.7
num_overlap_bins: 3
preprocessing:
resize: 640
square_pad: True
batch_size: 60
batch_size: 160
num_workers: 15
model:
name: gluefactory.models.two_view_pipeline
Expand Down Expand Up @@ -53,9 +58,9 @@ model:
train:
seed: 0
epochs: 200
log_every_iter: 10
eval_every_iter: 100
save_every_iter: 500
log_every_iter: 400
eval_every_iter: 700
save_every_iter: 1400
lr: 1e-4
lr_schedule:
type: exp # exp or multi_step
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/datasets/eth3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def download_eth3d(self):
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True, parents=True)
url_base = "https://cvg-data.inf.ethz.ch/ETH3D_undistorted/"
url_base = "https://cvg-data.inf.ethz.ch/SOLD2/SOLD2_ETH3D_undistorted/"
zip_name = "ETH3D_undistorted.zip"
zip_path = tmp_dir / zip_name
torch.hub.download_url_to_file(url_base + zip_name, zip_path)
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/models/lines/deeplsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def download_model(self, path):

if not path.parent.is_dir():
path.parent.mkdir(parents=True, exist_ok=True)
link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download"
link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
cmd = ["wget", link, "-O", path]
print("Downloading DeepLSD model...")
subprocess.run(cmd, check=True)
Expand Down
4 changes: 1 addition & 3 deletions gluefactory/models/matchers/gluestick.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _init(self, conf):
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
self.load_state_dict(state_dict)
self.load_state_dict(state_dict, strict=False)

def _forward(self, data):
device = data["keypoints0"].device
Expand Down Expand Up @@ -200,8 +200,6 @@ def _forward(self, data):
kpts0 = normalize_keypoints(kpts0, image_size0)
kpts1 = normalize_keypoints(kpts1, image_size1)

assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
desc0 = desc0 + self.kenc(kpts0, data["keypoint_scores0"])
desc1 = desc1 + self.kenc(kpts1, data["keypoint_scores1"])

Expand Down
2 changes: 1 addition & 1 deletion gluefactory/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def sigint_handler(signal, frame):

results = None # fix bug with it saving

lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_scheduler)
lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)
if args.restore:
optimizer.load_state_dict(init_cp["optimizer"])
if "lr_scheduler" in init_cp:
Expand Down

0 comments on commit c53503c

Please sign in to comment.