Skip to content

Commit c01b4bf

Browse files
initial version of example (NVIDIAGameWorks#621)
added Or's suggestions Signed-off-by: Charles Loop <[email protected]> Updated tutorial with simplified api Signed-off-by: operel <[email protected]> Add ctor for features Signed-off-by: operel <[email protected]> remove features from make_dense Signed-off-by: operel <[email protected]> simplify create dense octree Signed-off-by: operel <[email protected]> Signed-off-by: operel <[email protected]> Co-authored-by: Charles Loop <[email protected]>
1 parent 6323e80 commit c01b4bf

File tree

5 files changed

+212
-1
lines changed

5 files changed

+212
-1
lines changed

ci/gitlab_jenkins_templates/ubuntu_test_CI.jenkins

+8
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ spec:
158158
build_passed = false
159159
echo e.toString()
160160
}
161+
try {
162+
stage("SPC Convolution 3D Recipe") {
163+
sh 'cd /kaolin/examples/recipes/spc/ && python spc_conv3d_example.py'
164+
}
165+
} catch(e) {
166+
build_passed = false
167+
echo e.toString()
168+
}
161169
if (build_passed) {
162170
currentBuild.result = "SUCCESS"
163171
updateGitlabCommitStatus(name: "test-${configName}-${arch}", state: "success")

ci/gitlab_jenkins_templates/windows_test_CI.jenkins

+8
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,14 @@ spec:
156156
'''
157157
}
158158
}
159+
stage("SPC Convolution 3D Recipe") {
160+
catchError(stageResult: "failure") {
161+
powershell '''
162+
cd c:\\kaolin\\examples\\recipes\\spc
163+
python spc_conv3d_example.py
164+
'''
165+
}
166+
}
159167
stage("Run pytest - io") {
160168
catchError(stageResult: "failure") {
161169
timeout(time: 5, unit: "MINUTES") {
+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# ==============================================================================================================
2+
# The following code demonstrates the usage of kaolin's "Structured Point Cloud (SPC)" 3d convolution
3+
# functionality. Note that this sample does NOT demonstrate how to use Kaolin's Pytorch 3d convolution layers.
4+
# Rather, 3d convolutions are used to 'filter' color data useful for level-of-detail management during
5+
# rendering. This can be thought of as the 3d analog of generating a 2d mipmap.
6+
#
7+
# Note this is a low level interface: practitioners are encouraged to visit the references below.
8+
# ==============================================================================================================
9+
# See also:
10+
#
11+
# - Code: kaolin.ops.spc.SPC
12+
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
13+
#
14+
# - Tutorial: Understanding Structured Point Clouds (SPCs)
15+
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/understanding_spcs_tutorial.ipynb
16+
#
17+
# - Documentation: Structured Point Clouds
18+
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=spc#kaolin-ops-spc
19+
# ==============================================================================================================
20+
21+
import torch
22+
import kaolin
23+
24+
# The following function applies a series of SPC convolutions to encode the entire hierarchy into a single tensor.
25+
# Each step applies a convolution on the "highest" level of the SPC with some averaging kernel.
26+
# Therefore, each step locally averages the "colored point hierarchy", where each "colored point"
27+
# corresponds to a point in the SPC point hierarchy.
28+
# For a description of inputs 'octree', 'point_hierachy', 'level', 'pyramids', and 'exsum', as well a
29+
# detailed description of the mathematics of SPC convolutions, see:
30+
# https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.spc.html?highlight=SPC#kaolin.ops.spc.Conv3d
31+
# The input 'color' is Pytorch tensor containing color features corresponding to some 'level' of the hierarchy.
32+
def encode(colors, octree, point_hierachy, pyramids, exsum, level):
33+
34+
# SPC convolutions are characterized by a set of 'kernel vectors' and corresponding 'weights'.
35+
36+
# kernel_vectors is the "kernel support" -
37+
# a listing of 3D coordinates where the weights of the convolution are non-null,
38+
# in this case a it's a simple dense 2x2x2 grid.
39+
kernel_vectors = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[0,1,1],
40+
[1,0,0],[1,0,1],[1,1,0],[1,1,1]],
41+
dtype=torch.short, device='cuda')
42+
43+
# The weights specify how the input colors 'under' the kernel are mapped to an output color,
44+
# in this case a simple average.
45+
weights = torch.diag(torch.tensor([0.125, 0.125, 0.125, 0.125],
46+
dtype=torch.float32, device='cuda')) # Tensor of (4, 4)
47+
weights = weights.repeat(8,1,1).contiguous() # Tensor of (8, 4, 4)
48+
49+
# Storage for the output color hierarchy is allocated. This includes points at the bottom of the hierarchy,
50+
# as well as intermediate SPC levels (which may store different features)
51+
color_hierarchy = torch.empty((pyramids[0,1,level+1],4), dtype=torch.float32, device='cuda')
52+
# Copy the input colors into the highest level of color_hierarchy. pyramids is used here to select all leaf
53+
# points at the bottom of the hierarchy and set them to some pre-sampled random color. Points at intermediate
54+
# levels are left empty.
55+
color_hierarchy[pyramids[0,1,level]:pyramids[0,1,level+1]] = colors[:]
56+
57+
# Performs the 3d convolutions in a bottom up fashion to 'filter' colors from the previous level
58+
for l in range(level,0,-1):
59+
60+
# Apply the 3d convolution. Note that jump=1 means the inputs and outputs differ by 1 level
61+
# This is analogous to to a stride=2 in grid based convolutions
62+
colors, ll = kaolin.ops.spc.conv3d(octree,
63+
point_hierachy,
64+
l,
65+
pyramids,
66+
exsum,
67+
colors,
68+
weights,
69+
kernel_vectors,
70+
jump=1)
71+
# Copy the output colors into the color hierarchy
72+
color_hierarchy[pyramids[0,1,ll]:pyramids[0,1,l]] = colors[:]
73+
print(f"At level {l}, output feature shape is:\n{colors.shape}")
74+
75+
# Normalize the colors.
76+
color_hierarchy /= color_hierarchy[:,3:]
77+
# Normalization is needed here due to the sparse nature of SPCs. When a point under a kernel is not
78+
# present in the point hierarchy, the corresponding data is treated as zeros. Normalization is equivalent
79+
# to having the filter weights sum to one. This may not always be desirable, e.g. alpha blending.
80+
81+
return color_hierarchy
82+
83+
84+
# Highest level of SPC
85+
level = 3
86+
87+
# Construct a fully occupied Structured Point Cloud with N levels of detail
88+
# See https://kaolin.readthedocs.io/en/latest/modules/kaolin.rep.html?highlight=SPC#kaolin.rep.Spc
89+
spc = kaolin.rep.Spc.make_dense(level, device='cuda')
90+
91+
# In kaolin, operations are batched by default, the spc object above contains a single item batch, hence [0]
92+
num_points_last_lod = spc.num_points(level)[0]
93+
94+
# Create tensor of random colors for all points in the highest level of detail
95+
colors = torch.rand((num_points_last_lod, 4), dtype=torch.float32, device='cuda')
96+
# Set 4th color channel to one for subsequent color normalization
97+
colors[:,3] = 1
98+
99+
print(f'Input SPC features: {colors.shape}')
100+
101+
# Encode color hierarchy by invoking a series of convolutions, until we end up with a single tensor.
102+
color_hierarchy = encode(colors=colors,
103+
octree=spc.octrees,
104+
point_hierachy=spc.point_hierarchies,
105+
pyramids=spc.pyramids,
106+
exsum=spc.exsum,
107+
level=level)
108+
109+
# Print root node color
110+
print(f'Final encoded value (average of averages):')
111+
print(color_hierarchy[0])
112+
# This will be the average of averages, over the entire spc hierarchy. Since the initial random colors
113+
# came from a uniform distribution, this should approach [0.5, 0.5, 0.5, 1.0] as 'level' increases

kaolin/ops/spc/points.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
'coords_to_trilinear',
2323
'coords_to_trilinear_coeffs',
2424
'unbatched_points_to_octree',
25-
'quantize_points'
25+
'quantize_points',
26+
'create_dense_spc'
2627
]
2728

2829
import warnings
30+
import numpy as np
2931
import torch
3032

3133
from kaolin import _C
@@ -293,3 +295,18 @@ def coords_to_trilinear_coeffs(coords, points, level):
293295
coords_ = (2**level) * (coords * 0.5 + 0.5)
294296

295297
return _C.ops.spc.coords_to_trilinear_cuda(coords_.contiguous(), points.contiguous()).reshape(*shape)
298+
299+
300+
def create_dense_spc(level, device):
301+
"""Creates a dense SPC model
302+
303+
Args:
304+
level (int): The level at which the octree will be initialized to.
305+
device (torch.device): Torch device to keep the spc octree
306+
307+
Returns:
308+
(torch.ByteTensor): the octree tensor
309+
"""
310+
lengths = torch.tensor([sum(8 ** l for l in range(level))], dtype=torch.int32)
311+
octree = torch.full((lengths,), 255, device=device, dtype=torch.uint8)
312+
return octree, lengths

kaolin/rep/spc.py

+65
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,56 @@ def __init__(self, octrees, lengths, max_level=None, pyramids=None,
138138
self._point_hierarchies = point_hierarchies
139139
self.features = features
140140

141+
@classmethod
142+
def make_dense(cls, level, device='cuda'):
143+
"""Creates a dense, fully occupied Spc object.
144+
The Spc will have ``level`` levels of detail.
145+
146+
Args:
147+
level (int):
148+
Number of levels to use for the dense Spc.
149+
device (torch.device):
150+
Torch device to keep the spc octree
151+
152+
Return:
153+
(kaolin.rep.Spc): a new fully occupied ``Spc``.
154+
"""
155+
from ..ops.spc import create_dense_spc
156+
octree, lengths = create_dense_spc(level, device) # Create a single entry batch
157+
return Spc(octrees=octree, lengths=lengths)
158+
159+
@classmethod
160+
def from_features(cls, feature_grids, masks=None):
161+
"""Creates a sparse Spc object from the feature grid.
162+
163+
Args:
164+
feature_grids (torch.Tensor):
165+
The sparse 3D feature grids, of shape
166+
:math:`(\text{batch_size}, \text{feature_dim}, X, Y, Z)`
167+
masks (optional, torch.BoolTensor):
168+
The topology mask, showing where are the features,
169+
of shape :math:`(\text{batch_size}, X, Y, Z)`.
170+
Default: A feature is determined when not full of zeros.
171+
172+
Returns:
173+
(torch.ByteTensor, torch.IntTensor, torch.Tensor):
174+
a tuple containing:
175+
176+
- The octree, of size :math:`(\text{num_nodes})`
177+
178+
- The lengths of each octree, of size :math:`(\text{batch_size})`
179+
180+
- The coalescent features, of same dtype than ``feature_grids``,
181+
of shape :math:`(\text{num_features}, \text{feature_dim})`.
182+
Return:
183+
(kaolin.rep.Spc): a ``Spc``, with length of :math:`(\text{batch_size})`,
184+
an octree of size octree, of size :math:`(\text{num_nodes})`, and the features field
185+
of the same dtype as ``feature_grids`` and of shape :math:`(\text{num_features}, \text{feature_dim})`.
186+
"""
187+
from ..ops.spc import feature_grids_to_spc
188+
octrees, lengths, coalescent_features = feature_grids_to_spc(feature_grids, masks=masks)
189+
return Spc(octrees=octrees, lengths=lengths, features=coalescent_features)
190+
141191
# TODO(cfujitsang): could be interesting to separate into multiple functions
142192
def _apply_scan_octrees(self):
143193
# to break circular dependency
@@ -237,3 +287,18 @@ def to_dict(self, keys=None):
237287
return {k: getattr(self, k) for k in self.KEYS}
238288
else:
239289
return {k: getattr(self, k) for k in keys}
290+
291+
def num_points(self, lod: int):
292+
"""
293+
Returns how many points the SPC holds at a given level of detail.
294+
295+
Args:
296+
lod (int):
297+
Index of a level of detail.
298+
Level 0 is considered the root and always holds a single point,
299+
level 1 holds up to :math:`(\text{num_points}=8)` points,
300+
level 2 holds up to :math:`(\text{num_points}=8^{2})`, and so forth.
301+
Return:
302+
(torch.Tensor): The number of points each SPC entry holds for the given level of detail.
303+
"""
304+
return self.pyramids[:, 0, lod]

0 commit comments

Comments
 (0)