# stage1.py
# Stage 1: overfit a ResnetFC MLP on a single 3D shape and save its weights.
import warnings
warnings.filterwarnings("ignore")
from mmcv import Config
import numpy as np
import torch
import torch.nn as nn
from glob import glob
import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
import trimesh
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pytorch3d.transforms import matrix_to_axis_angle, matrix_to_euler_angles
from tqdm import tqdm, trange
import argparse
import pytorch_lightning as pl
from src.models import ResnetFC
from src.utils import select_sampling_method_online, extract_mesh_from_mlp
from pathlib import Path
# Seed Lightning, the CUDA generators and NumPy with one shared fixed value
# at import time, before any model or sampler is created.
_SEED = 1997

pl.seed_everything(_SEED)
torch.cuda.manual_seed_all(_SEED)
np.random.seed(_SEED)
torch.cuda.manual_seed(_SEED)
def overfit(args):
    """Overfit one ResnetFC MLP to a single 3D shape and save its weights.

    The network maps 3-D sample points to one scalar per point (passed
    through a sigmoid) and is trained with an MSE loss against the labels
    returned by ``select_sampling_method_online``.  Every 500 iterations a
    mesh is extracted from the current network and exported as ``.ply``, and
    a fresh set of training points is sampled.

    Args:
        args: Parsed command-line namespace.  Attributes read here:
            ``n_blocks``, ``d_hidden``, ``lr``, ``pcd_path``, ``output_dir``,
            ``nsample``, ``epochs`` and (optionally) ``device``.

    Side effects:
        Creates ``<output_dir>/preds`` and writes intermediate ``.ply``
        meshes plus the final ``*_overfitted_model_fine.pth`` state dict
        there; prints progress to stdout.
    """
    # Honor --device when the caller supplies it, but fall back to CPU when
    # CUDA is requested on a machine without it instead of crashing.
    requested = getattr(args, "device", None) or "cuda"
    if str(requested).startswith("cuda") and not torch.cuda.is_available():
        requested = "cpu"
    device = torch.device(requested)

    model = ResnetFC(d_in=3, d_out=1, n_blocks=args.n_blocks, d_hidden=args.d_hidden)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    filename = Path(args.pcd_path)
    out_root = Path(args.output_dir) / "preds"
    out_root.mkdir(exist_ok=True, parents=True)

    print(f"Overfitting on {filename}")
    data = select_sampling_method_online(
        filename, num_sample_inout=args.nsample, is_train=True
    )
    save_name = out_root / f"{data['name']}_overfitted_model_fine.pth"

    for i in trange(args.epochs):
        # Add a batch dimension and swap the last two axes.
        # NOTE(review): presumably yields (1, N, 3) points and (1, N, 1)
        # labels -- confirm against select_sampling_method_online.
        pts = data["samples"][None].to(device).transpose(2, 1)
        gt = data["labels"][None].to(device).transpose(1, 2)

        optimizer.zero_grad()
        pred = model(pts).sigmoid().squeeze(-1)
        loss = torch.nn.functional.mse_loss(pred, gt.squeeze(-1))
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f"Mesh: {data['name']} | Iter: {i} | Loss: {loss.item()}")

        if i % 500 == 0:
            # Mesh export is best-effort: a failure is reported but must not
            # abort training.
            try:
                model.eval()
                with torch.no_grad():
                    mesh, _ = extract_mesh_from_mlp(model, N=128)
                mesh.export(
                    str(out_root / f"{data['name']}_pred_128_iter_{i+1}_fine.ply")
                )
            except Exception as e:
                print(f"Error in exporting mesh: {e}")
            finally:
                # Always return to training mode, even if extraction raised.
                model.train()

            # Sample new points every 500 iterations for robust fitting,
            # using the same sample count the run was configured with.
            data = select_sampling_method_online(filename, args.nsample)

    # Save once after the loop so the file always holds the final weights.
    torch.save(
        model.state_dict(),
        str(save_name),
    )
    print(f"{filename} overfitted and saved to {out_root}")
if __name__ == "__main__":
    # Command-line entry point: parse the Stage 1 options, seed the RNGs from
    # --seed, then overfit a single shape.
    parser = argparse.ArgumentParser("Stage 1")
    parser.add_argument(
        "--epochs",
        "-e",
        default=6000,
        type=int,
        help="number of optimisation iterations",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        metavar="LR",
        help="learning rate (absolute lr)",
    )
    parser.add_argument(
        "--nsample",
        type=int,
        default=10000,
        help="number of points sampled per training batch",
    )
    parser.add_argument(
        "--output_dir",
        default="./output/",
        help="path where to save, empty for no saving",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=1997, type=int, help="random seed")
    parser.add_argument(
        "--d_hidden", default=128, type=int, help="number of hidden dims of MLP"
    )
    parser.add_argument(
        "--n_blocks", default=5, type=int, help="number of resnet blocks"
    )
    parser.add_argument(
        "--pcd_path",
        default="/path/to/mesh.ply",
        type=str,
        help="path of 3D shape as .obj/.ply",
    )
    args = parser.parse_args()

    # Apply --seed, which was previously parsed but ignored.  The default
    # equals the value seeded at import time, so default runs are unchanged.
    pl.seed_everything(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    overfit(args)