diff --git a/utils/ply_utils.py b/utils/ply_utils.py index c18f8f9..663bc4b 100644 --- a/utils/ply_utils.py +++ b/utils/ply_utils.py @@ -1,4 +1,6 @@ from array import array +import plyfile +import numpy as np import torch @@ -12,24 +14,14 @@ def __init__(self, height, width, min_d=3, max_d=400, batch_size=1, roi=None, dr self.max_d = max_d self.roi = roi self.dropout = dropout - self.data = array('f') + self.data = [] self.projector = Backprojection(batch_size, height, width) def save(self, file): - length = len(self.data) // 6 - header = "ply\n" \ - "format binary_little_endian 1.0\n" \ - f"element vertex {length}\n" \ - f"property float x\n" \ - f"property float y\n" \ - f"property float z\n" \ - f"property float red\n" \ - f"property float green\n" \ - f"property float blue\n" \ - f"end_header\n" - file.write(header.encode(encoding="ascii")) - self.data.tofile(file) + vertices = np.array(list(map(tuple, self.data)), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + vertex_el = plyfile.PlyElement.describe(vertices, 'vertex') + plyfile.PlyData([vertex_el]).write(file) def add_depthmap(self, depth: torch.Tensor, image: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor): @@ -39,7 +31,7 @@ def add_depthmap(self, depth: torch.Tensor, image: torch.Tensor, intrinsics: tor if self.roi is not None: mask[:, :, :self.roi[0], :] = False mask[:, :, self.roi[1]:, :] = False - mask[:, :, :, :self.roi[2]] = False + mask[:, :, :, self.roi[2]] = False mask[:, :, :, self.roi[3]:] = False if self.dropout > 0: mask = mask & (torch.rand_like(depth) > self.dropout) @@ -48,6 +40,6 @@ def add_depthmap(self, depth: torch.Tensor, image: torch.Tensor, intrinsics: tor coords = extrinsics @ coords coords = coords[:, :3, :] data_batch = torch.cat([coords, image.view_as(coords)], dim=1).permute(0, 2, 1) - data_batch = data_batch[mask.view(depth.shape[0], 1, -1).permute(0, 2, 1).expand(-1, -1, 6)] + data_batch = data_batch.view(-1, 6)[mask.view(-1), :] self.data.extend(data_batch.cpu().tolist())