Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support bg color #94

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def gaussian_point_rasterisation(
# output
pixel_offset_of_last_effective_point: ti.types.ndarray(ti.i32, ndim=2),
pixel_valid_point_count: ti.types.ndarray(ti.i32, ndim=2), # output
background_color: ti.types.ndarray(ti.f32, ndim=1), # (3)
rgb_only: ti.template(), # input
):
ti.loop_config(block_dim=(TILE_WIDTH * TILE_HEIGHT))
Expand Down Expand Up @@ -469,6 +470,10 @@ def gaussian_point_rasterisation(
valid_point_count += 1
T_i = next_T_i
# end of point group loop

background_color_vector = ti.math.vec3([background_color[0], background_color[1], background_color[2]])
background_color_vector = ti.math.clamp(background_color_vector, 0, 1)
accumulated_color += background_color_vector * T_i

# end of point group id loop

Expand Down Expand Up @@ -517,7 +522,7 @@ def gaussian_point_rasterisation_backward(
point_uv_conic_and_rescale: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)
point_alpha_after_activation: ti.types.ndarray(ti.f32, ndim=1), # (M)
point_color: ti.types.ndarray(ti.f32, ndim=2), # (M, 3)

background_color: ti.types.ndarray(ti.f32, ndim=1), # (3)
need_extra_info: ti.template(),
magnitude_grad_viewspace: ti.types.ndarray(ti.f32, ndim=1), # (N)
# (H, W, 2)
Expand Down Expand Up @@ -558,6 +563,9 @@ def gaussian_point_rasterisation_backward(
last_effective_point = pixel_offset_of_last_effective_point[pixel_v, pixel_u]
accumulated_alpha: ti.f32 = pixel_accumulated_alpha[pixel_v, pixel_u]
T_i = 1.0 - accumulated_alpha # T_i = \prod_{j=1}^{i-1} (1 - a_j)
T_final = T_i
background_color_vector = ti.math.vec3([
background_color[0], background_color[1], background_color[2]])
# \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} \sum_{j=i+1}^{n} c_j a_j T(j)
# let w_i = \sum_{j=i+1}^{n} c_j a_j T(j)
# we have w_n = 0, w_{i-1} = w_i + c_i a_i T(i)
Expand Down Expand Up @@ -652,6 +660,8 @@ def gaussian_point_rasterisation_backward(
# \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} w_i
alpha_grad_from_rgb = (color * T_i - w_i / (1. - alpha)) \
* pixel_rgb_grad
alpha_grad_from_rgb -= pixel_rgb_grad * background_color_vector * \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide the formula to calculate the gradient properly? I am a bit confused by the name convention.

T_final / (1 - alpha)
# w_{i-1} = w_i + c_i a_i T(i)
w_i += color * alpha * T_i
alpha_grad: ti.f32 = alpha_grad_from_rgb.sum()
Expand Down Expand Up @@ -800,7 +810,9 @@ class GaussianPointCloudRasterisationInput:
# Kx4, x to the right, y down, z forward, K is the number of objects
q_pointcloud_camera: torch.Tensor
# Kx3, x to the right, y down, z forward, K is the number of objects

t_pointcloud_camera: torch.Tensor
background_color: Optional[torch.Tensor] = None # 3
color_max_sh_band: int = 2

@dataclass
Expand Down Expand Up @@ -837,6 +849,7 @@ def forward(ctx,
t_pointcloud_camera,
camera_info,
color_max_sh_band,
background_color,
):
point_in_camera_mask = torch.zeros(
size=(pointcloud.shape[0],), dtype=torch.int8, device=pointcloud.device)
Expand Down Expand Up @@ -994,7 +1007,8 @@ def forward(ctx,
rasterized_depth=rasterized_depth,
pixel_accumulated_alpha=pixel_accumulated_alpha,
pixel_offset_of_last_effective_point=pixel_offset_of_last_effective_point,
pixel_valid_point_count=pixel_valid_point_count)
pixel_valid_point_count=pixel_valid_point_count,
background_color=background_color)
ctx.save_for_backward(
pointcloud,
pointcloud_features,
Expand All @@ -1016,6 +1030,7 @@ def forward(ctx,
point_uv_conic_and_rescale,
point_alpha_after_activation,
point_color,
background_color
)
ctx.camera_info = camera_info
ctx.color_max_sh_band = color_max_sh_band
Expand Down Expand Up @@ -1044,7 +1059,8 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
point_in_camera, \
point_uv_conic, \
point_alpha_after_activation, \
point_color = ctx.saved_tensors
point_color, \
background_color = ctx.saved_tensors
camera_info = ctx.camera_info
color_max_sh_band = ctx.color_max_sh_band
grad_rasterized_image = grad_rasterized_image.contiguous()
Expand Down Expand Up @@ -1093,6 +1109,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
point_uv_conic_and_rescale=point_uv_conic.contiguous(),
point_alpha_after_activation=point_alpha_after_activation.contiguous(),
point_color=point_color.contiguous(),
background_color=background_color.contiguous(),
need_extra_info=True,
magnitude_grad_viewspace=magnitude_grad_viewspace.contiguous(),
magnitude_grad_viewspace_on_image=magnitude_grad_viewspace_on_image.contiguous(),
Expand Down Expand Up @@ -1160,7 +1177,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
None, \
grad_q_pointcloud_camera, \
grad_t_pointcloud_camera, \
None, None
None, None, None

self._module_function = _module_function

Expand Down Expand Up @@ -1189,6 +1206,10 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput):
q_pointcloud_camera = input_data.q_pointcloud_camera
t_pointcloud_camera = input_data.t_pointcloud_camera
color_max_sh_band = input_data.color_max_sh_band
background_color = input_data.background_color
if background_color is None:
background_color = torch.zeros((3, ), dtype=torch.float32,
device=pointcloud.device)
camera_info = input_data.camera_info
assert camera_info.camera_width % TILE_WIDTH == 0
assert camera_info.camera_height % TILE_HEIGHT == 0
Expand All @@ -1201,4 +1222,5 @@ def forward(self, input_data: GaussianPointCloudRasterisationInput):
t_pointcloud_camera,
camera_info,
color_max_sh_band,
background_color,
)
Loading