Description
03_gsplat-rendering.py cannot directly load nerfstudio's splatfacto training result (XX.ckpt), because its loading code cannot read that checkpoint's data.
I have rewritten the loading code to match the layout of the splatfacto checkpoint XX.ckpt.
You can replace the old loading code with the version below, and then view nerfstudio's splatfacto training results directly.
The code:
else:
    ckpt = torch.load(args.ckpt, map_location=device)["pipeline"]
    print(ckpt.keys())
    print()
    means = ckpt["_model.means"]  # (N, 3)
    scales = torch.exp(ckpt["_model.scales"])  # (N, 3); exp is applied to the stored values
    # (N, 4); normalize each quaternion to unit length, keeping its direction
    quats = F.normalize(ckpt["_model.quats"], p=2, dim=-1)
    sh0 = ckpt["_model.features_dc"]  # (N, 3)
    sh0 = torch.unsqueeze(sh0, dim=1)  # (N, 1, 3)
    shN = ckpt["_model.features_rest"]  # (N, 15, 3)
    colors = torch.cat([sh0, shN], dim=-2)  # (N, 16, 3)
    opacities = torch.sigmoid(ckpt["_model.opacities"])  # (N, 1)
    opacities = opacities.squeeze(dim=1)  # (N,)
    sh_degree = int(math.sqrt(colors.shape[-2]) - 1)  # 16 coefficients -> degree 3
    # crop: keep only the Gaussians inside the box
    aabb = torch.tensor((-999.0, -999.0, -999.0, 999.0, 999.0, 999.0), device=device)
    edges = aabb[3:] - aabb[:3]
    sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
    sel = torch.where(sel)[0]
    means, quats, scales, colors, opacities = (
        means[sel],
        quats[sel],
        scales[sel],
        colors[sel],
        opacities[sel],
    )
    # repeat the scene into a grid (to mimic a large-scale setting)
    # note: each arange below has `repeats` entries only when `repeats` is odd
    repeats = args.scene_grid  # defaults to 1
    gridx, gridy = torch.meshgrid(
        [
            torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
            torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
        ],
        indexing="ij",
    )
    grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(-1, 3)
    means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
    means = means.reshape(-1, 3)
    quats = quats.repeat(repeats**2, 1)
    scales = scales.repeat(repeats**2, 1)
    colors = colors.repeat(repeats**2, 1, 1)
    opacities = opacities.repeat(repeats**2)
    print("Number of Gaussians:", len(means))
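
The code above relies on two shape assumptions that are easy to check by hand: the number of SH coefficients per Gaussian must be a perfect square, (degree + 1)^2, and the grid built from `torch.arange(-(repeats // 2), repeats // 2 + 1)` must contain exactly `repeats**2` cells so that `means` stays the same length as the repeated `quats`, `scales`, `colors` and `opacities`. Below is a minimal pure-Python sketch of both checks. The helper names `sh_degree_from_coeffs` and `grid_cells` are hypothetical and are not part of gsplat or nerfstudio.

```python
import math


def sh_degree_from_coeffs(num_coeffs: int) -> int:
    """Return the SH degree d such that (d + 1) ** 2 == num_coeffs.

    Hypothetical helper: a stricter version of
    int(math.sqrt(colors.shape[-2]) - 1) that rejects invalid counts.
    """
    if num_coeffs < 1:
        raise ValueError(f"{num_coeffs} is not a valid SH coefficient count")
    root = math.isqrt(num_coeffs)
    if root * root != num_coeffs:
        raise ValueError(f"{num_coeffs} is not a perfect square")
    return root - 1


def grid_cells(repeats: int) -> int:
    """Count the cells of the 2D grid whose axes are
    range(-(repeats // 2), repeats // 2 + 1), mirroring the arange calls above.
    """
    per_axis = len(range(-(repeats // 2), repeats // 2 + 1))
    return per_axis * per_axis


if __name__ == "__main__":
    # features_dc (1 coefficient) + features_rest (15 coefficients) = 16
    print("SH degree for 16 coefficients:", sh_degree_from_coeffs(16))
    for repeats in (1, 2, 3):
        print(repeats, "-> grid cells:", grid_cells(repeats), "repeats**2:", repeats**2)
```

For odd values of `repeats` the two counts agree. For an even value such as 2, each axis has 3 entries, so the grid has 9 cells while the other tensors are repeated only 4 times. Under that reading, `--scene_grid` should be kept odd.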