Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
diff-gaussian-rasterization
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Alan de Oliveira
diff-gaussian-rasterization
Commits
bccbb2e8
Commit
bccbb2e8
authored
Jun 19, 2023
by
Bernhard Kerbl
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
InternalState explicit now
parent
ffee75d5
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
166 additions
and
105 deletions
+166
-105
rasterizer.h
cuda_rasterizer/rasterizer.h
+11
-3
rasterizer_impl.cu
cuda_rasterizer/rasterizer_impl.cu
+81
-72
rasterizer_impl.h
cuda_rasterizer/rasterizer_impl.h
+37
-23
ext.cpp
ext.cpp
+5
-1
rasterize_points.cu
rasterize_points.cu
+25
-5
rasterize_points.h
rasterize_points.h
+7
-1
No files found.
cuda_rasterizer/rasterizer.h
View file @
bccbb2e8
...
...
@@ -5,10 +5,16 @@
namespace
CudaRasterizer
{
struct
InternalState
;
class
Rasterizer
{
public
:
virtual
InternalState
*
createInternalState
()
=
0
;
virtual
void
killInternalState
(
InternalState
*
)
=
0
;
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
...
...
@@ -33,10 +39,13 @@ namespace CudaRasterizer
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
=
nullptr
)
=
0
;
int
*
radii
,
InternalState
*
state
,
float
*
out_color
)
=
0
;
virtual
void
backward
(
const
int
*
radii
,
const
InternalState
*
state
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
...
...
@@ -51,7 +60,6 @@ namespace CudaRasterizer
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
...
...
cuda_rasterizer/rasterizer_impl.cu
View file @
bccbb2e8
...
...
@@ -171,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii)
int* radii,
InternalState* state,
float* out_color)
{
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training
if (P > maxP)
{
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
if (radii == nullptr)
if (P > state->maxP)
{
internal_radii.resize(maxP);
radii = internal_radii.data().get();
state->maxP = resizeMultiplier * P;
state->cov3D.resize(state->maxP * 6);
state->rgb.resize(state->maxP * 3);
state->tiles_touched.resize(state->maxP);
state->point_offsets.resize(state->maxP);
state->clamped.resize(3 * state->maxP);
state->depths.resize(state->maxP);
state->means2D.resize(state->maxP);
state->conic_opacity.resize(state->maxP);
size_t scan_size;
cub::DeviceScan::InclusiveSum(nullptr,
scan_size,
state->tiles_touched.data().get(),
state->tiles_touched.data().get(),
state->maxP);
state->scanning_space.resize(scan_size);
}
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training
if (width * height > maxPixels)
if (width * height >
state->
maxPixels)
{
maxPixels = width * height;
accum_alpha.resize(
maxPixels);
n_contrib.resize(
maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
state->
maxPixels = width * height;
state->accum_alpha.resize(state->
maxPixels);
state->n_contrib.resize(state->
maxPixels);
state->
ranges.resize(tile_grid.x * tile_grid.y);
}
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
...
...
@@ -227,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations,
opacities,
shs,
clamped.data().get(),
state->
clamped.data().get(),
cov3D_precomp,
colors_precomp,
viewmatrix, projmatrix,
...
...
@@ -236,45 +236,56 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy,
focal_x, focal_y,
radii,
means2D.data().get(),
depths.data().get(),
cov3D.data().get(),
rgb.data().get(),
conic_opacity.data().get(),
state->
means2D.data().get(),
state->
depths.data().get(),
state->
cov3D.data().get(),
state->
rgb.data().get(),
state->
conic_opacity.data().get(),
tile_grid,
tiles_touched.data().get(),
state->
tiles_touched.data().get(),
prefiltered
);
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
tiles_touched.data().get(), point_offsets.data().get(), P);
size_t scanning_space_size = state->scanning_space.size();
cub::DeviceScan::InclusiveSum(
state->scanning_space.data().get(),
scanning_space_size,
state->tiles_touched.data().get(),
state->point_offsets.data().get(),
P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed;
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&num_needed,
state->
point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
{
point_list_keys_unsorted.resize(2 * num_needed);
point_list_keys.resize(2 * num_needed);
point_list_unsorted.resize(2 * num_needed);
point_list.resize(2 * num_needed);
int resizeNum = resizeMultiplier * num_needed;
point_list_keys_unsorted.resize(resizeNum);
point_list_keys.resize(resizeNum);
point_list_unsorted.resize(resizeNum);
size_t sorting_size;
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
2 * num_needed
);
point_list_unsorted.data().get(),
state->
point_list.data().get(),
resizeNum
);
list_sorting_space.resize(sorting_size);
}
if (num_needed > state->point_list.size())
{
state->point_list.resize(resizeMultiplier * num_needed);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding duplicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P,
means2D.data().get(),
depths.data().get(),
point_offsets.data().get(),
state->
means2D.data().get(),
state->depths.data().get(),
state->point_offsets.data().get(),
point_list_keys_unsorted.data().get(),
point_list_unsorted.data().get(),
radii,
...
...
@@ -284,34 +295,36 @@ void CudaRasterizer::RasterizerImpl::forward(
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
// Sort complete list of (duplicated) Gaussian indices by keys
size_t list_sorting_space_size = list_sorting_space.size();
cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get(),
sorting
_size,
list_sorting_space
_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
point_list_unsorted.data().get(),
state->point_list.data().get(),
num_needed, 0, 32 + bit);
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
cudaMemset(
state->
ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
num_needed,
point_list_keys.data().get(),
ranges.data().get()
state->
ranges.data().get()
);
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
state->
rgb.data().get();
FORWARD::render(
tile_grid, block,
ranges.data().get(),
point_list.data().get(),
state->
ranges.data().get(),
state->
point_list.data().get(),
width, height,
means2D.data().get(),
state->
means2D.data().get(),
feature_ptr,
conic_opacity.data().get(),
accum_alpha.data().get(),
n_contrib.data().get(),
state->
conic_opacity.data().get(),
state->
accum_alpha.data().get(),
state->
n_contrib.data().get(),
background,
out_color);
}
...
...
@@ -319,6 +332,8 @@ void CudaRasterizer::RasterizerImpl::forward(
// Produce necessary gradients for optimization, corresponding
// to forward render pass
void CudaRasterizer::RasterizerImpl::backward(
const int* radii,
const InternalState* state,
const int P, int D, int M,
const float* background,
const int width, int height,
...
...
@@ -333,7 +348,6 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* projmatrix,
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
...
...
@@ -345,11 +359,6 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale,
float* dL_drot)
{
if (radii == nullptr)
{
radii = internal_radii.data().get();
}
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
...
...
@@ -359,19 +368,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
state->
rgb.data().get();
BACKWARD::render(
tile_grid,
block,
ranges.data().get(),
point_list.data().get(),
state->
ranges.data().get(),
state->
point_list.data().get(),
width, height,
background,
means2D.data().get(),
conic_opacity.data().get(),
state->
means2D.data().get(),
state->
conic_opacity.data().get(),
color_ptr,
accum_alpha.data().get(),
n_contrib.data().get(),
state->
accum_alpha.data().get(),
state->
n_contrib.data().get(),
dL_dpix,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
...
...
@@ -381,12 +390,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
state->
cov3D.data().get();
BACKWARD::preprocess(P, D, M,
(float3*)means3D,
radii,
shs,
clamped.data().get(),
state->
clamped.data().get(),
(glm::vec3*)scales,
(glm::vec4*)rotations,
scale_modifier,
...
...
cuda_rasterizer/rasterizer_impl.h
View file @
bccbb2e8
...
...
@@ -8,41 +8,53 @@
namespace
CudaRasterizer
{
class
RasterizerImpl
:
public
Rasterizer
//// Internal state kept across forward / backward
struct
InternalState
{
private
:
int
maxP
=
0
;
int
maxPixels
=
0
;
int
resizeMultiplier
=
2
;
// Initial aux structs
size_t
sorting_size
;
size_t
list_sorting_size
;
size_t
scan_size
;
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
bool
>
clamped
;
thrust
::
device_vector
<
float
>
rgb
;
thrust
::
device_vector
<
float
>
depths
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
};
// Auxiliary buffer spaces
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
thrust
::
device_vector
<
bool
>
clamped
;
thrust
::
device_vector
<
int
>
internal_radii
;
// Internal state kept across forward / backward
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
accum_alpha
;
class
RasterizerImpl
:
public
Rasterizer
{
private
:
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
rgb
;
int
resizeMultiplier
=
2
;
public
:
virtual
InternalState
*
createInternalState
()
override
{
return
new
InternalState
();
}
virtual
void
killInternalState
(
InternalState
*
is
)
override
{
delete
is
;
}
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
...
...
@@ -67,10 +79,13 @@ namespace CudaRasterizer
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
)
override
;
int
*
radii
,
InternalState
*
state
,
float
*
out_color
)
override
;
virtual
void
backward
(
const
int
*
radii
,
const
InternalState
*
fixedState
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
...
...
@@ -85,7 +100,6 @@ namespace CudaRasterizer
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
...
...
ext.cpp
View file @
bccbb2e8
#include <torch/extension.h>
#include "rasterize_points.h"
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
// Extension module definition: registers five C++ functions with the module
// object `m`, each under the quoted name passed as the first argument to def().
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"rasterize_gaussians"
,
&
RasterizeGaussiansCUDA
);
m
.
def
(
"rasterize_gaussians_backward"
,
&
RasterizeGaussiansBackwardCUDA
);
m
.
def
(
"mark_visible"
,
&
markVisible
);
// Lifetime management for the opaque rasterizer state handle
m
.
def
(
"create_rasterizer_state"
,
&
createRasterizerState
);
m
.
def
(
"delete_rasterizer_state"
,
&
deleteRasterizerState
);
}
\ No newline at end of file
rasterize_points.cu
View file @
bccbb2e8
...
...
@@ -11,11 +11,26 @@
#include <memory>
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h"
#include "rasterize_points.h"
#include <fstream>
#include <string>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
// Hands out a freshly created rasterizer state as an opaque pointer.
// The file-level renderer is constructed on the first call and reused afterwards.
void* createRasterizerState()
{
	// Build the shared renderer lazily, the first time a state is requested
	if (!cudaRenderer)
	{
		cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
	}

	// The concrete state type is hidden from callers behind void*
	return static_cast<void*>(cudaRenderer->createInternalState());
}
void deleteRasterizerState(void* state)
{
cudaRenderer->killInternalState((CudaRasterizer::InternalState*)state);
}
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
...
...
@@ -35,7 +50,8 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
const bool prefiltered)
const bool prefiltered,
void* internalState)
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
...
...
@@ -83,14 +99,16 @@ RasterizeGaussiansCUDA(
tan_fovx,
tan_fovy,
prefiltered,
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
out_color.contiguous().data<float>());
}
return std::make_tuple(out_color, radii);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& radii,
...
...
@@ -130,7 +148,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
{
cudaRenderer->backward(P, degree, M,
cudaRenderer->backward(
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
P, degree, M,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
...
...
@@ -145,7 +166,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
campos.contiguous().data<float>(),
tan_fovx,
tan_fovy,
radii.contiguous().data<int>(),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
...
...
rasterize_points.h
View file @
bccbb2e8
...
...
@@ -25,10 +25,12 @@ RasterizeGaussiansCUDA(
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
torch
::
Tensor
&
campos
,
const
bool
prefiltered
);
const
bool
prefiltered
,
void
*
internalState
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansBackwardCUDA
(
const
void
*
internalState
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
radii
,
...
...
@@ -46,6 +48,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const
int
degree
,
const
torch
::
Tensor
&
campos
);
void
*
createRasterizerState
();
void
deleteRasterizerState
(
void
*
state
);
torch
::
Tensor
markVisible
(
torch
::
Tensor
&
means3D
,
torch
::
Tensor
&
viewmatrix
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment