50% Faster!! Modular!! Pure Python or CUDA Extension
Welcome to this repository, which houses a refactored codebase designed to enhance the flexibility and performance of Gaussian splatting.
Gaussian splatting is a powerful technique used in various computer graphics and vision applications. It represents a scene as a set of points, each carrying a Gaussian distribution that is splatted onto the image plane, allowing for an efficient and accurate representation of spatial data. However, the original PyTorch implementation of Gaussian splatting (https://github.com/graphdeco-inria/gaussian-splatting) posed certain challenges:
- Both the forward and backward passes were encapsulated within a single PyTorch extension function, making it difficult to access intermediate variables without delving into the C code.
- Even if one attempted to modify the C code, they would have to derive the gradient formula manually and subsequently implement those formulas in the backward pass.
- Modular Design: The refactored Gaussian splatting implementation separates the forward and backward passes into multiple PyTorch extension functions, promoting greater modularity and easier access to intermediate data. In addition, by leveraging PyTorch Autograd, the refactored codebase eliminates the need to derive gradient formulas manually.
- Flexible: LiteGS offers two sets of modular APIs. One is implemented in CUDA, while the other is purely Python-based. You can easily modify the calculation logic in the Python API without delving into C code, allowing for rapid prototyping of your ideas. Moreover, we've carefully designed the GPU memory layout to ensure the Python API maintains reasonable training speeds. For performance enthusiasts, the CUDA-implemented API is also available for modification.
- Better Performance and Fewer Resources: LiteGS not only surpasses the original 3DGS code in terms of flexibility and readability but also achieves a significant speedup of approximately 50%. Our optimizations focus on improving training performance while reducing GPU memory usage.
- Nothing New Algorithmically: We've preserved the original 3DGS algorithm, with only minor training-logic adjustments due to the BVH creation.
To get started, you’ll need to install the required submodules:
- Install simple-knn:

cd gaussian_splatting/submodules/simple-knn
python setup.py build_ext --inplace -j8
python setup.py install

- Install gaussian_raster:

cd gaussian_splatting/submodules/gaussian_raster
mkdir ./build
cd ./build
cmake -DCMAKE_PREFIX_PATH=@1/share/cmake ../
cmake --build . --config Release

Replace @1 with the installation path of your PyTorch, which looks like "$PYTHONHOME$/Lib/site-packages/torch".
Begin training with the following command:
./train.py --random_background --sh_degree 0 -s dataset/garden -i images_4 -m output/garden
The metrics rely heavily on the COLMAP result. We use the same processed data as RetinaGS (https://github.com/MooreThreads/RetinaGS).
The training results of LiteGS on the Mip-NeRF 360 datasets with an NVIDIA A100 are shown below. The training and evaluation command is python ./full_eval.py --mipnerf360 SOURCE_PATH
Metric \ Scene | Bicycle | flowers | garden | stump | treehill | room | counter | kitchen | bonsai |
---|---|---|---|---|---|---|---|---|---|
SSIM(Test) | 0.764 | 0.612 | 0.858 | 0.787 | 0.637 | 0.919 | 0.906 | 0.920 | 0.940 |
PSNR(Test) | 25.454 | 22.038 | 27.622 | 27.174 | 22.768 | 32.187 | 29.307 | 32.100 | 33.225 |
LPIPS(Test) | 0.218 | 0.335 | 0.112 | 0.218 | 0.318 | 0.186 | 0.179 | 0.113 | 0.166 |
The metrics of the original 3DGS are shown below for comparison. They demonstrate that our refactored code does not compromise accuracy.
Metric \ Scene | Bicycle | flowers | garden | stump | treehill | room | counter | kitchen | bonsai |
---|---|---|---|---|---|---|---|---|---|
SSIM(Test) | 0.770 | 0.623 | 0.866 | 0.771 | 0.641 | 0.931 | 0.919 | 0.933 | 0.950 |
PSNR(Test) | 25.33 | 21.85 | 27.58 | 26.74 | 22.46 | 31.93 | 29.54 | 31.52 | 33.10 |
LPIPS(Test) | 0.200 | 0.320 | 0.107 | 0.223 | 0.317 | 0.220 | 0.204 | 0.129 | 0.205 |
The training time on an NVIDIA A100 (mm:ss) is shown below, illustrating the approximately 50% speedup achieved by LiteGS.
Repo \ Scene | Bicycle | flowers | garden | stump | treehill | room | counter | kitchen | bonsai |
---|---|---|---|---|---|---|---|---|---|
Original(d9fad7b) | 32:09 | 24:37 | 31:34 | 25:58 | 24:08 | 21:59 | 21:45 | 25:41 | 19:10 |
LiteGS | 19:46 | 15:41 | 19:34 | 19:19 | 17:36 | 11:26 | 11:50 | 12:44 | 12:17 |
In contrast to the original 3DGS, which encapsulates almost the entire rendering process into a single torch extension function, LiteGS divides the rendering process into several functions. Users can access intermediate variables and insert compute logic with Python scripts instead of modifying C code. LiteGS breaks down the rendering into the following steps:
- Cluster Culling: LiteGS divides the Gaussian points into several chunks, each containing 1024 points. The first step of rendering is frustum culling on these chunks.
- Cluster Compaction: Similar to mesh rendering, LiteGS compacts visible primitives after frustum culling. Specifically, it compacts each property of the Gaussian points into sequential memory.
- 3DGS Projection: This step projects Gaussian points into screen space, with no changes from the original version.
- Create Visibility Table: The fourth step creates the visibility table, which maps each tile to its visible primitives.
- Rasterization: The final step is rasterization. Each tile rasterizes its visible primitives in parallel.
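As an illustration of the cluster data structure these steps operate on, here is a minimal PyTorch sketch (not the actual LiteGS code) that splits a point cloud into 1024-point chunks and computes the per-chunk AABBs consumed by the culling step:

```python
import torch

CHUNK_SIZE = 1024  # points per chunk, as described above

def build_chunks(xyz: torch.Tensor):
    """xyz: [3, N] positions in the [property, points] layout; N is assumed
    to be padded to a multiple of CHUNK_SIZE for simplicity."""
    chunked = xyz.reshape(3, -1, CHUNK_SIZE)     # [3, num_chunks, 1024]
    aabb_min = chunked.min(dim=2).values.T       # [num_chunks, 3]
    aabb_max = chunked.max(dim=2).values.T       # [num_chunks, 3]
    return chunked, aabb_min, aabb_max

# usage: 8 chunks of random points
chunked, aabb_min, aabb_max = build_chunks(torch.randn(3, 8 * CHUNK_SIZE))
```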
Due to LiteGS's clustering of point clouds, density control requires a new implementation. LiteGS clones and splits points each epoch, clustering them into new chunks without rebuilding the BVH. Every 10 epochs, the opacity of points is reset, invalid points are removed, and the BVH is rebuilt. Because opacity resets happen more frequently in LiteGS than in the original 3DGS, we opt for decaying opacity instead of resetting it to a value close to zero.
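A minimal sketch of the decay-and-prune idea is below; the decay factor and prune threshold are illustrative values, not the actual LiteGS constants:

```python
import torch

def maintain_opacity(opacity: torch.Tensor, decay: float = 0.6,
                     prune_threshold: float = 0.005):
    """opacity: [N] activated opacities. Decay instead of hard-resetting so
    points can recover between the relatively frequent maintenance steps;
    points whose opacity has faded out are treated as invalid and removed."""
    opacity = opacity * decay
    keep_mask = opacity > prune_threshold
    return opacity[keep_mask], keep_mask

opacity, keep_mask = maintain_opacity(torch.rand(1000))
```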
Inside gaussian_splatting/wrapper.py there are two sets of GS functions. Functions with the "v1" suffix are implemented in Python, while those with the "v2" suffix are implemented as CUDA extensions. The CUDA extensions offer significant performance improvements but lack the flexibility of Python. The choice between the two implementations depends on the specific use case:
- PythonScript: Offers greater flexibility. It is suitable for rapid prototyping and development where training speed is not critical.
- CudaExtension: Provides the highest performance and is recommended for production environments where training speed is of utmost importance.
To switch between the two implementations, simply modify the gaussian_splatting/wrapper.py file.
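For example, a switch inside wrapper.py could look like the sketch below; the names project_v1 / project_v2 are placeholders standing in for whatever "v1"/"v2" function pairs the file actually defines:

```python
# Hypothetical sketch of selecting an implementation inside wrapper.py; the
# names project_v1 / project_v2 stand in for the actual "v1"/"v2" pairs.
USE_CUDA_EXTENSION = True  # set to False to use the pure-Python path

if USE_CUDA_EXTENSION:
    project = project_v2   # CUDA extension: fastest, but harder to modify
else:
    project = project_v1   # pure Python: easy to inspect and prototype with
```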
Here is an example that demonstrates the flexibility of LiteGS. In this instance, our goal is to create a more precise bounding box for a 2D Gaussian when generating visibility tables. In the original 3DGS, the bounding box for the 2D Gaussian is set to three times the length of the major axis. However, this coefficient can be reduced if opacity is taken into account.
Let’s attempt to implement this change in the original 3DGS. The tasks to be completed are as follows:
- Modify the C++ functions, including both declarations and definitions.
- Modify the CUDA global function.
- Compile
In contrast, with LiteGS, only the Python scripts need to be altered.
original:

# bound the 2D Gaussian via a fixed coefficient on its eigenvalues
axis_length = (3.0 * eigen_val.abs()).sqrt().ceil()

modified:

# shrink the bound using opacity: alpha * exp(-d^2 / (2 * lambda)) drops below
# 1/255 once d^2 > 2 * ln(255 * alpha) * lambda, so the fixed coefficient can be
# replaced by 2 * ln(255 * opacity)
coefficient = 2 * ((255 * opacity).log())
axis_length = (coefficient * eigen_val.abs()).sqrt().ceil()
If you find this project useful in your research, please consider citing:
@misc{LiteGS,
title={LiteGS},
author={LiteGS Contributors},
howpublished = {\url{https://sh-code.mthreads.com/ai/Pytorch-GaussianSplatting}},
year={2024}
}
In this section, we detail the key optimizations implemented at each stage mentioned earlier.
At a high level, we restructured the memory layout: the dimension order of tensors is changed from [points, property] to [property, points]. After this change, tensor computations implemented in pure Python load and store memory sequentially.
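A small PyTorch sketch of the layout change (illustrative tensors, not the actual LiteGS data):

```python
import torch

num_points = 1_000_000
# original layout: [points, property]; reading a single property (e.g. x)
# walks memory with a stride of 3
xyz_pp = torch.randn(num_points, 3)
x_strided = xyz_pp[:, 0]                  # non-contiguous view, stride 3

# LiteGS layout: [property, points]; each property is one contiguous row,
# so element-wise kernels on the GPU issue coalesced, sequential accesses
xyz = xyz_pp.t().contiguous()             # [3, num_points]
x, y, z = xyz[0], xyz[1], xyz[2]          # contiguous 1-D views
depth_proxy = 0.5 * x + 0.2 * y + 0.1 * z
```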
A simple BVH is built for density control. In contrast to the original implementation, where new points are appended to the end of the tensors, leading to decreasing spatial locality over time, our approach maintains spatial locality.
The original implementation culls primitives directly, which can be time-consuming in large scenes. We optimized this by performing frustum culling on cluster AABBs (Axis-Aligned Bounding Boxes) instead.
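A minimal sketch of the AABB-vs-frustum test is below; the plane convention (planes stored as (nx, ny, nz, d) with inward-facing normals) is an assumption for illustration:

```python
import torch

def cull_chunks(aabb_min: torch.Tensor, aabb_max: torch.Tensor,
                planes: torch.Tensor) -> torch.Tensor:
    """Frustum-cull chunks by their AABBs instead of per primitive.
    aabb_min, aabb_max: [C, 3] corners of each 1024-point chunk.
    planes: [6, 4] frustum planes (nx, ny, nz, d), normals pointing inward.
    Returns a boolean mask of chunks that intersect the frustum."""
    centers = (aabb_min + aabb_max) * 0.5        # [C, 3]
    extents = (aabb_max - aabb_min) * 0.5        # [C, 3]
    normals, d = planes[:, :3], planes[:, 3]     # [6, 3], [6]
    dist = centers @ normals.T + d               # signed distance of centers, [C, 6]
    radius = extents @ normals.abs().T           # projected half-extent, [C, 6]
    # a chunk survives if it is not completely behind any plane
    return (dist + radius >= 0).all(dim=1)
```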
Besides, we compact the parameters of the visible Gaussian points, which improves GPU utilization. The original implementation skips calculations for culled Gaussian points, but the GPU schedules instructions per warp (32 threads), so culled points still occupy lanes in partially active warps.
An important note is that compaction should be applied at the granularity of Gaussian-point clusters. Applying it to individual primitives can slow down the backward pass due to out-of-order writes.
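A minimal sketch of cluster-level compaction, assuming the [property, points] layout and the 1024-point chunks described above:

```python
import torch

def compact_chunks(prop: torch.Tensor, visible_mask: torch.Tensor,
                   chunk_size: int = 1024) -> torch.Tensor:
    """prop: [P, num_chunks * chunk_size], one property group of the points.
    visible_mask: [num_chunks] boolean result of the cluster culling step.
    Whole chunks are gathered, so the per-point order inside each chunk is
    preserved and the writes stay sequential."""
    chunked = prop.reshape(prop.shape[0], -1, chunk_size)     # [P, C, 1024]
    compacted = chunked[:, visible_mask]                      # [P, V, 1024]
    return compacted.reshape(prop.shape[0], -1)               # [P, V * 1024]

xyz_compact = compact_chunks(torch.randn(3, 8 * 1024), torch.rand(8) > 0.5)
```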
The projection step is unchanged from the original implementation. We highly recommend using the CUDA-implemented function instead of the pure-Python one here. The essential difference between the two functions is whether the intermediate data is written to GPU global memory: the projection step contains several consecutive 4x4 matrix multiplications, and writing the intermediates to global memory significantly increases the data-transfer overhead.
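The sketch below (placeholder matrices, not the exact LiteGS projection math) illustrates the point: in the pure-Python path every intermediate of the matmul chain is materialized as a tensor in global memory, whereas the fused CUDA kernel keeps those values in registers:

```python
import torch

N = 1024
points_h    = torch.randn(4, N)   # homogeneous points, [property, points] layout
view_matrix = torch.eye(4)        # placeholder world -> camera transform
proj_matrix = torch.eye(4)        # placeholder camera -> clip transform

# each line below materializes a full [4, N] (or [3, N]) tensor in memory
cam_space  = view_matrix @ points_h            # world -> camera
clip_space = proj_matrix @ cam_space           # camera -> clip
ndc        = clip_space[:3] / clip_space[3:]   # perspective divide
```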
Similar to the "binning" step in a compute rasterizer, the purpose of this step is to map Gaussian points to tiles. The key is to create an accurate bounding box for each primitive, since a tighter visibility table reduces computation. The original 3DGS uses the long axis of the 2D Gaussian to create a bounding square, which is rough. By taking transparency and the short axis into consideration, LiteGS creates a more accurate bounding box for each 2D Gaussian.
We trade some accuracy for speed here. A bounding box built from the long and short axes needs two passes to create the visibility table: one pass to compute the allocation size and one pass to fill the table. A 2D AABB makes the allocation size easy to compute, so it needs only one pass. LiteGS therefore uses 2D AABBs.
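A minimal sketch of why the 2D AABB needs only one pass: the number of visibility-table entries each primitive requires can be computed directly from its tile span (illustrative code, with an assumed 16-pixel tile size):

```python
import torch

def tile_entry_counts(xy_min: torch.Tensor, xy_max: torch.Tensor,
                      tile_size: int = 16) -> torch.Tensor:
    """xy_min, xy_max: [N, 2] screen-space corners of each primitive's 2D AABB.
    Returns how many (tile, primitive) entries each primitive needs, so the
    visibility table can be allocated without a separate counting pass."""
    tile_min = torch.div(xy_min, tile_size, rounding_mode="floor")
    tile_max = torch.div(xy_max, tile_size, rounding_mode="floor")
    spans = (tile_max - tile_min + 1).clamp(min=0)   # tiles covered per axis
    return (spans[:, 0] * spans[:, 1]).long()

counts = tile_entry_counts(torch.rand(100, 2) * 500,
                           torch.rand(100, 2) * 500 + 500)
```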
There is a significant efficiency problem in the original implementation. In the backward pass of rasterization, the gradients of the properties are computed per pixel and then need to be summed. The original implementation does this directly with atomicAdd.
A better approach is a parallel reduction, which many repositories already use. The code follows:
// block-level tree reduction over the per-pixel gradients in shared memory,
// assuming element_num values per property (a power of two)
for (int i = element_num / 2; i > 16; i /= 2)
{
    if (threadIdx.x < i)
    {
        property0[threadIdx.x] += property0[threadIdx.x + i];
        property1[threadIdx.x] += property1[threadIdx.x + i];
        property2[threadIdx.x] += property2[threadIdx.x + i];
        ...
    }
    __syncthreads();  // barrier must sit outside the divergent branch
}
// finish the last 32 partial sums with warp shuffles
float grad0 = property0[threadIdx.x];
float grad1 = property1[threadIdx.x];
...
for (int i = 16; i > 0; i /= 2)
{
    grad0 += __shfl_down_sync(0xffffffff, grad0, i);
    grad1 += __shfl_down_sync(0xffffffff, grad1, i);
    ...
}
However, this is still not the best implementation: the warp stalls are severe. From a data-oriented perspective, we need a multi-batch reduction, in which each warp reduces a different property. The code follows:
// multi-batch reduction: each warp sums one property over the tile's pixels
int warps_num_in_block = blockDim.x / 32;
int threadid_in_warp = threadIdx.x % 32;
for (int property_id = threadIdx.x / 32; property_id < property_num; property_id += warps_num_in_block)
{
    float gradient_sum = 0;
    // each lane strides over the tile's pixels for this property
    for (int offset = threadid_in_warp; offset < tilesize * tilesize; offset += 32)
    {
        gradient_sum += gradient_buffer[property_id * tilesize * tilesize + offset];
    }
    // warp-level reduction of the per-lane partial sums
    for (int offset = 16; offset > 0; offset /= 2)
    {
        gradient_sum += __shfl_down_sync(0xffffffff, gradient_sum, offset);
    }
    if (threadid_in_warp == 0)
    {
        shared_gradient_sum[property_id] = gradient_sum;
    }
}