Commit 55b0c1b0 by Bernhard Kerbl

This was a bad idea, undoing it

parent e41a365f
...@@ -5,16 +5,10 @@ ...@@ -5,16 +5,10 @@
namespace CudaRasterizer namespace CudaRasterizer
{ {
struct InternalState;
class Rasterizer class Rasterizer
{ {
public: public:
virtual InternalState* createInternalState() = 0;
virtual void killInternalState(InternalState*) = 0;
virtual void markVisible( virtual void markVisible(
int P, int P,
float* means3D, float* means3D,
...@@ -39,13 +33,10 @@ namespace CudaRasterizer ...@@ -39,13 +33,10 @@ namespace CudaRasterizer
const float* cam_pos, const float* cam_pos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const bool prefiltered, const bool prefiltered,
int* radii, float* out_color,
InternalState* state, int* radii = nullptr) = 0;
float* out_color) = 0;
virtual void backward( virtual void backward(
const int* radii,
const InternalState* state,
const int P, int D, int M, const int P, int D, int M,
const float* background, const float* background,
const int width, int height, const int width, int height,
...@@ -56,10 +47,11 @@ namespace CudaRasterizer ...@@ -56,10 +47,11 @@ namespace CudaRasterizer
const float scale_modifier, const float scale_modifier,
const float* rotations, const float* rotations,
const float* cov3D_precomp, const float* cov3D_precomp,
const float* viewmatrix, const float* viewmatrix,
const float* projmatrix, const float* projmatrix,
const float* campos, const float* campos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix, const float* dL_dpix,
float* dL_dmean2D, float* dL_dmean2D,
float* dL_dconic, float* dL_dconic,
...@@ -74,8 +66,6 @@ namespace CudaRasterizer ...@@ -74,8 +66,6 @@ namespace CudaRasterizer
virtual ~Rasterizer() {}; virtual ~Rasterizer() {};
static Rasterizer* make(int resizeMultipliyer = 2); static Rasterizer* make(int resizeMultipliyer = 2);
static void kill(Rasterizer* rasterizer);
}; };
}; };
......
...@@ -137,11 +137,6 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie ...@@ -137,11 +137,6 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
return new CudaRasterizer::RasterizerImpl(resizeMultiplier); return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
} }
void CudaRasterizer::Rasterizer::kill(Rasterizer* rasterizer)
{
delete rasterizer;
}
// Mark Gaussians as visible/invisible, based on view frustum testing // Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::RasterizerImpl::markVisible( void CudaRasterizer::RasterizerImpl::markVisible(
int P, int P,
...@@ -176,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -176,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
const float* cam_pos, const float* cam_pos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const bool prefiltered, const bool prefiltered,
int* radii, float* out_color,
InternalState* state, int* radii)
float* out_color)
{ {
const float focal_y = height / (2.0f * tan_fovy); const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx); const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training // Dynamically resize auxiliary buffers during training
if (P > state->maxP) if (P > maxP)
{ {
state->maxP = resizeMultiplier * P; maxP = resizeMultiplier * P;
state->cov3D.resize(state->maxP * 6); cov3D.resize(maxP * 6);
state->rgb.resize(state->maxP * 3); rgb.resize(maxP * 3);
state->tiles_touched.resize(state->maxP); tiles_touched.resize(maxP);
state->point_offsets.resize(state->maxP); point_offsets.resize(maxP);
state->clamped.resize(3 * state->maxP); clamped.resize(3 * maxP);
state->depths.resize(state->maxP); depths.resize(maxP);
state->means2D.resize(state->maxP); means2D.resize(maxP);
state->conic_opacity.resize(state->maxP); conic_opacity.resize(maxP);
size_t scan_size; cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
cub::DeviceScan::InclusiveSum(nullptr, scanning_space.resize(scan_size);
scan_size, }
state->tiles_touched.data().get(),
state->tiles_touched.data().get(), if (radii == nullptr)
state->maxP); {
state->scanning_space.resize(scan_size); internal_radii.resize(maxP);
radii = internal_radii.data().get();
} }
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1); dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training // Dynamically resize image-based auxiliary buffers during training
if (width * height > state->maxPixels) if (width * height > maxPixels)
{ {
state->maxPixels = width * height; maxPixels = width * height;
state->accum_alpha.resize(state->maxPixels); accum_alpha.resize(maxPixels);
state->n_contrib.resize(state->maxPixels); n_contrib.resize(maxPixels);
state->ranges.resize(tile_grid.x * tile_grid.y); ranges.resize(tile_grid.x * tile_grid.y);
} }
if (NUM_CHANNELS != 3 && colors_precomp == nullptr) if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
...@@ -232,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -232,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations, (glm::vec4*)rotations,
opacities, opacities,
shs, shs,
state->clamped.data().get(), clamped.data().get(),
cov3D_precomp, cov3D_precomp,
colors_precomp, colors_precomp,
viewmatrix, projmatrix, viewmatrix, projmatrix,
...@@ -241,56 +236,45 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -241,56 +236,45 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy, tan_fovx, tan_fovy,
focal_x, focal_y, focal_x, focal_y,
radii, radii,
state->means2D.data().get(), means2D.data().get(),
state->depths.data().get(), depths.data().get(),
state->cov3D.data().get(), cov3D.data().get(),
state->rgb.data().get(), rgb.data().get(),
state->conic_opacity.data().get(), conic_opacity.data().get(),
tile_grid, tile_grid,
state->tiles_touched.data().get(), tiles_touched.data().get(),
prefiltered prefiltered
); );
// Compute prefix sum over full list of touched tile counts by Gaussians // Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
size_t scanning_space_size = state->scanning_space.size(); cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
cub::DeviceScan::InclusiveSum( tiles_touched.data().get(), point_offsets.data().get(), P);
state->scanning_space.data().get(),
scanning_space_size,
state->tiles_touched.data().get(),
state->point_offsets.data().get(),
P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers // Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed; int num_needed;
cudaMemcpy(&num_needed, state->point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost); cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size()) if (num_needed > point_list_keys_unsorted.size())
{ {
int resizeNum = resizeMultiplier * num_needed; point_list_keys_unsorted.resize(2 * num_needed);
point_list_keys_unsorted.resize(resizeNum); point_list_keys.resize(2 * num_needed);
point_list_keys.resize(resizeNum); point_list_unsorted.resize(2 * num_needed);
point_list_unsorted.resize(resizeNum); point_list.resize(2 * num_needed);
size_t sorting_size;
cub::DeviceRadixSort::SortPairs( cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size, nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(), point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), state->point_list.data().get(), point_list_unsorted.data().get(), point_list.data().get(),
resizeNum); 2 * num_needed);
list_sorting_space.resize(sorting_size); list_sorting_space.resize(sorting_size);
} }
if (num_needed > state->point_list.size())
{
state->point_list.resize(resizeMultiplier * num_needed);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key // For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding duplicated Gaussian indices to be sorted // and corresponding duplicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > ( duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P, P,
state->means2D.data().get(), means2D.data().get(),
state->depths.data().get(), depths.data().get(),
state->point_offsets.data().get(), point_offsets.data().get(),
point_list_keys_unsorted.data().get(), point_list_keys_unsorted.data().get(),
point_list_unsorted.data().get(), point_list_unsorted.data().get(),
radii, radii,
...@@ -300,36 +284,34 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -300,36 +284,34 @@ void CudaRasterizer::RasterizerImpl::forward(
int bit = getHigherMsb(tile_grid.x * tile_grid.y); int bit = getHigherMsb(tile_grid.x * tile_grid.y);
// Sort complete list of (duplicated) Gaussian indices by keys // Sort complete list of (duplicated) Gaussian indices by keys
size_t list_sorting_space_size = list_sorting_space.size();
cub::DeviceRadixSort::SortPairs( cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get(), list_sorting_space.data().get(),
list_sorting_space_size, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(), point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list_unsorted.data().get(), point_list.data().get(),
state->point_list.data().get(),
num_needed, 0, 32 + bit); num_needed, 0, 32 + bit);
cudaMemset(state->ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2)); cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list // Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > ( identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
num_needed, num_needed,
point_list_keys.data().get(), point_list_keys.data().get(),
state->ranges.data().get() ranges.data().get()
); );
// Let each tile blend its range of Gaussians independently in parallel // Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : state->rgb.data().get(); const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
FORWARD::render( FORWARD::render(
tile_grid, block, tile_grid, block,
state->ranges.data().get(), ranges.data().get(),
state->point_list.data().get(), point_list.data().get(),
width, height, width, height,
state->means2D.data().get(), means2D.data().get(),
feature_ptr, feature_ptr,
state->conic_opacity.data().get(), conic_opacity.data().get(),
state->accum_alpha.data().get(), accum_alpha.data().get(),
state->n_contrib.data().get(), n_contrib.data().get(),
background, background,
out_color); out_color);
} }
...@@ -337,8 +319,6 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -337,8 +319,6 @@ void CudaRasterizer::RasterizerImpl::forward(
// Produce necessary gradients for optimization, corresponding // Produce necessary gradients for optimization, corresponding
// to forward render pass // to forward render pass
void CudaRasterizer::RasterizerImpl::backward( void CudaRasterizer::RasterizerImpl::backward(
const int* radii,
const InternalState* state,
const int P, int D, int M, const int P, int D, int M,
const float* background, const float* background,
const int width, int height, const int width, int height,
...@@ -353,6 +333,7 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -353,6 +333,7 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* projmatrix, const float* projmatrix,
const float* campos, const float* campos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix, const float* dL_dpix,
float* dL_dmean2D, float* dL_dmean2D,
float* dL_dconic, float* dL_dconic,
...@@ -364,6 +345,11 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -364,6 +345,11 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale, float* dL_dscale,
float* dL_drot) float* dL_drot)
{ {
if (radii == nullptr)
{
radii = internal_radii.data().get();
}
const float focal_y = height / (2.0f * tan_fovy); const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx); const float focal_x = width / (2.0f * tan_fovx);
...@@ -373,19 +359,19 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -373,19 +359,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix, // Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients. // opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them. // If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : state->rgb.data().get(); const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
BACKWARD::render( BACKWARD::render(
tile_grid, tile_grid,
block, block,
state->ranges.data().get(), ranges.data().get(),
state->point_list.data().get(), point_list.data().get(),
width, height, width, height,
background, background,
state->means2D.data().get(), means2D.data().get(),
state->conic_opacity.data().get(), conic_opacity.data().get(),
color_ptr, color_ptr,
state->accum_alpha.data().get(), accum_alpha.data().get(),
state->n_contrib.data().get(), n_contrib.data().get(),
dL_dpix, dL_dpix,
(float3*)dL_dmean2D, (float3*)dL_dmean2D,
(float4*)dL_dconic, (float4*)dL_dconic,
...@@ -395,12 +381,12 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -395,12 +381,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance // Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not, // given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves. // use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : state->cov3D.data().get(); const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
BACKWARD::preprocess(P, D, M, BACKWARD::preprocess(P, D, M,
(float3*)means3D, (float3*)means3D,
radii, radii,
shs, shs,
state->clamped.data().get(), clamped.data().get(),
(glm::vec3*)scales, (glm::vec3*)scales,
(glm::vec4*)rotations, (glm::vec4*)rotations,
scale_modifier, scale_modifier,
......
...@@ -8,53 +8,41 @@ ...@@ -8,53 +8,41 @@
namespace CudaRasterizer namespace CudaRasterizer
{ {
//// Internal state kept across forward / backward class RasterizerImpl : public Rasterizer
struct InternalState
{ {
private:
int maxP = 0; int maxP = 0;
int maxPixels = 0; int maxPixels = 0;
int resizeMultiplier = 2;
thrust::device_vector<uint2> ranges; // Initial aux structs
thrust::device_vector<float2> means2D; size_t sorting_size;
thrust::device_vector<float4> conic_opacity; size_t list_sorting_size;
thrust::device_vector<float> accum_alpha; size_t scan_size;
thrust::device_vector<uint32_t> n_contrib;
thrust::device_vector<float> cov3D;
thrust::device_vector<bool> clamped;
thrust::device_vector<float> rgb;
thrust::device_vector<float> depths; thrust::device_vector<float> depths;
thrust::device_vector<uint32_t> point_offsets;
thrust::device_vector<uint32_t> tiles_touched; thrust::device_vector<uint32_t> tiles_touched;
thrust::device_vector<uint32_t> point_offsets;
thrust::device_vector<char> scanning_space; thrust::device_vector<uint64_t> point_list_keys_unsorted;
thrust::device_vector<uint64_t> point_list_keys;
thrust::device_vector<uint32_t> point_list_unsorted;
thrust::device_vector<uint32_t> point_list; thrust::device_vector<uint32_t> point_list;
}; thrust::device_vector<char> scanning_space;
thrust::device_vector<char> list_sorting_space;
thrust::device_vector<bool> clamped;
thrust::device_vector<int> internal_radii;
// Auxiliary buffer spaces // Internal state kept across forward / backward
thrust::device_vector<uint64_t> point_list_keys_unsorted; thrust::device_vector<uint2> ranges;
thrust::device_vector<uint64_t> point_list_keys; thrust::device_vector<uint32_t> n_contrib;
thrust::device_vector<uint32_t> point_list_unsorted; thrust::device_vector<float> accum_alpha;
thrust::device_vector<char> list_sorting_space;
class RasterizerImpl : public Rasterizer thrust::device_vector<float2> means2D;
{ thrust::device_vector<float> cov3D;
private: thrust::device_vector<float4> conic_opacity;
thrust::device_vector<float> rgb;
int resizeMultiplier = 2;
public: public:
virtual InternalState* createInternalState() override
{
return new InternalState();
}
virtual void killInternalState(InternalState* is) override
{
delete is;
}
virtual void markVisible( virtual void markVisible(
int P, int P,
float* means3D, float* means3D,
...@@ -79,13 +67,10 @@ namespace CudaRasterizer ...@@ -79,13 +67,10 @@ namespace CudaRasterizer
const float* cam_pos, const float* cam_pos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const bool prefiltered, const bool prefiltered,
int* radii, float* out_color,
InternalState* state, int* radii) override;
float* out_color) override;
virtual void backward( virtual void backward(
const int* radii,
const InternalState* fixedState,
const int P, int D, int M, const int P, int D, int M,
const float* background, const float* background,
const int width, int height, const int width, int height,
...@@ -100,6 +85,7 @@ namespace CudaRasterizer ...@@ -100,6 +85,7 @@ namespace CudaRasterizer
const float* projmatrix, const float* projmatrix,
const float* campos, const float* campos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix, const float* dL_dpix,
float* dL_dmean2D, float* dL_dmean2D,
float* dL_dconic, float* dL_dconic,
......
...@@ -4,7 +4,6 @@ import torch ...@@ -4,7 +4,6 @@ import torch
from . import _C from . import _C
def rasterize_gaussians( def rasterize_gaussians(
instance,
means3D, means3D,
means2D, means2D,
sh, sh,
...@@ -14,10 +13,8 @@ def rasterize_gaussians( ...@@ -14,10 +13,8 @@ def rasterize_gaussians(
rotations, rotations,
cov3Ds_precomp, cov3Ds_precomp,
raster_settings, raster_settings,
rasterizer_state
): ):
return _RasterizeGaussians.apply( return _RasterizeGaussians.apply(
instance,
means3D, means3D,
means2D, means2D,
sh, sh,
...@@ -27,14 +24,12 @@ def rasterize_gaussians( ...@@ -27,14 +24,12 @@ def rasterize_gaussians(
rotations, rotations,
cov3Ds_precomp, cov3Ds_precomp,
raster_settings, raster_settings,
rasterizer_state
) )
class _RasterizeGaussians(torch.autograd.Function): class _RasterizeGaussians(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,
instance,
means3D, means3D,
means2D, means2D,
sh, sh,
...@@ -44,12 +39,10 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -44,12 +39,10 @@ class _RasterizeGaussians(torch.autograd.Function):
rotations, rotations,
cov3Ds_precomp, cov3Ds_precomp,
raster_settings, raster_settings,
rasterizer_state
): ):
# Restructure arguments the way that the C++ lib expects them # Restructure arguments the way that the C++ lib expects them
args = ( args = (
instance,
raster_settings.bg, raster_settings.bg,
means3D, means3D,
colors_precomp, colors_precomp,
...@@ -68,7 +61,6 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -68,7 +61,6 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings.sh_degree, raster_settings.sh_degree,
raster_settings.campos, raster_settings.campos,
raster_settings.prefiltered, raster_settings.prefiltered,
rasterizer_state
) )
# Invoke C++/CUDA rasterizer # Invoke C++/CUDA rasterizer
...@@ -76,8 +68,6 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -76,8 +68,6 @@ class _RasterizeGaussians(torch.autograd.Function):
# Keep relevant tensors for backward # Keep relevant tensors for backward
ctx.raster_settings = raster_settings ctx.raster_settings = raster_settings
ctx.instance = instance
ctx.rasterizer_state = rasterizer_state
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh) ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
return color, radii return color, radii
...@@ -85,15 +75,11 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -85,15 +75,11 @@ class _RasterizeGaussians(torch.autograd.Function):
def backward(ctx, grad_out_color, _): def backward(ctx, grad_out_color, _):
# Restore necessary values from context # Restore necessary values from context
instance = ctx.instance
rasterizer_state = ctx.rasterizer_state
raster_settings = ctx.raster_settings raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors
# Restructure args as C++ method expects them # Restructure args as C++ method expects them
args = (instance, args = (raster_settings.bg,
rasterizer_state,
raster_settings.bg,
means3D, means3D,
radii, radii,
colors_precomp, colors_precomp,
...@@ -114,7 +100,6 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -114,7 +100,6 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
grads = ( grads = (
None,
grad_means3D, grad_means3D,
grad_means2D, grad_means2D,
grad_sh, grad_sh,
...@@ -124,7 +109,6 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -124,7 +109,6 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_rotations, grad_rotations,
grad_cov3Ds_precomp, grad_cov3Ds_precomp,
None, None,
None,
) )
return grads return grads
...@@ -143,31 +127,25 @@ class GaussianRasterizationSettings(NamedTuple): ...@@ -143,31 +127,25 @@ class GaussianRasterizationSettings(NamedTuple):
prefiltered : bool prefiltered : bool
class GaussianRasterizer(nn.Module): class GaussianRasterizer(nn.Module):
def __init__(self): def __init__(self, raster_settings):
super().__init__() super().__init__()
self.instance = _C.create_rasterizer() self.raster_settings = raster_settings
def __del__(self):
_C.delete_rasterizer(self.instance)
def createRasterizerState(self):
return _C.create_rasterizer_state(self.instance)
def deleteRasterizerState(self, state): def markVisible(self, positions):
_C.delete_rasterizer_state(self.instance, state)
def markVisible(self, raster_settings, positions):
# Mark visible points (based on frustum culling for camera) with a boolean # Mark visible points (based on frustum culling for camera) with a boolean
with torch.no_grad(): with torch.no_grad():
raster_settings = self.raster_settings
visible = _C.mark_visible( visible = _C.mark_visible(
self.instance,
positions, positions,
raster_settings.viewmatrix, raster_settings.viewmatrix,
raster_settings.projmatrix) raster_settings.projmatrix)
return visible return visible
def forward(self, rasterizer_state, raster_settings, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
raster_settings = self.raster_settings
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
raise Exception('Please provide excatly one of either SHs or precomputed colors!') raise Exception('Please provide excatly one of either SHs or precomputed colors!')
...@@ -188,7 +166,6 @@ class GaussianRasterizer(nn.Module): ...@@ -188,7 +166,6 @@ class GaussianRasterizer(nn.Module):
# Invoke C++/CUDA rasterization routine # Invoke C++/CUDA rasterization routine
return rasterize_gaussians( return rasterize_gaussians(
self.instance,
means3D, means3D,
means2D, means2D,
shs, shs,
...@@ -198,6 +175,5 @@ class GaussianRasterizer(nn.Module): ...@@ -198,6 +175,5 @@ class GaussianRasterizer(nn.Module):
rotations, rotations,
cov3D_precomp, cov3D_precomp,
raster_settings, raster_settings,
rasterizer_state
) )
#include <torch/extension.h> #include <torch/extension.h>
#include "rasterize_points.h" #include "rasterize_points.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
{
m.def("rasterize_gaussians", &RasterizeGaussiansCUDA); m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA); m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
m.def("mark_visible", &markVisible); m.def("mark_visible", &markVisible);
m.def("create_rasterizer_state", &createRasterizerState);
m.def("delete_rasterizer_state", &deleteRasterizerState);
m.def("create_rasterizer", &createRasterizer);
m.def("delete_rasterizer", &deleteRasterizer);
} }
\ No newline at end of file
...@@ -11,33 +11,13 @@ ...@@ -11,33 +11,13 @@
#include <memory> #include <memory>
#include "cuda_rasterizer/config.h" #include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h" #include "cuda_rasterizer/rasterizer.h"
#include "rasterize_points.h"
#include <fstream> #include <fstream>
#include <string> #include <string>
void* createRasterizer() static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
{
return (void*)CudaRasterizer::Rasterizer::make();
}
void deleteRasterizer(void* rasterizer)
{
CudaRasterizer::Rasterizer::kill((CudaRasterizer::Rasterizer*)rasterizer);
}
void* createRasterizerState(void* rasterizer)
{
return (void*)((CudaRasterizer::Rasterizer*)rasterizer)->createInternalState();
}
void deleteRasterizerState(void* rasterizer, void* state)
{
((CudaRasterizer::Rasterizer*)rasterizer)->killInternalState((CudaRasterizer::InternalState*)state);
}
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA( RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
const torch::Tensor& colors, const torch::Tensor& colors,
...@@ -55,14 +35,18 @@ RasterizeGaussiansCUDA( ...@@ -55,14 +35,18 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh, const torch::Tensor& sh,
const int degree, const int degree,
const torch::Tensor& campos, const torch::Tensor& campos,
const bool prefiltered, const bool prefiltered)
void* internalState)
{ {
if (means3D.ndimension() != 2 || means3D.size(1) != 3) { if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)"); AT_ERROR("means3D must have dimensions (num_points, 3)");
} }
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0); const int P = means3D.size(0);
const int N = 1; // batch size hard-coded const int N = 1; // batch size hard-coded
const int H = image_height; const int H = image_height;
...@@ -82,7 +66,7 @@ RasterizeGaussiansCUDA( ...@@ -82,7 +66,7 @@ RasterizeGaussiansCUDA(
M = sh.size(1); M = sh.size(1);
} }
((CudaRasterizer::Rasterizer*)rasterizer)->forward(P, degree, M, cudaRenderer->forward(P, degree, M,
background.contiguous().data<float>(), background.contiguous().data<float>(),
W, H, W, H,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
...@@ -99,17 +83,14 @@ RasterizeGaussiansCUDA( ...@@ -99,17 +83,14 @@ RasterizeGaussiansCUDA(
tan_fovx, tan_fovx,
tan_fovy, tan_fovy,
prefiltered, prefiltered,
radii.contiguous().data<int>(), out_color.contiguous().data<float>(),
(CudaRasterizer::InternalState*)internalState, radii.contiguous().data<int>());
out_color.contiguous().data<float>());
} }
return std::make_tuple(out_color, radii); return std::make_tuple(out_color, radii);
} }
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA( RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
const torch::Tensor& radii, const torch::Tensor& radii,
...@@ -149,10 +130,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -149,10 +130,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0) if(P != 0)
{ {
((CudaRasterizer::Rasterizer*)rasterizer)->backward( cudaRenderer->backward(P, degree, M,
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
P, degree, M,
background.contiguous().data<float>(), background.contiguous().data<float>(),
W, H, W, H,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
...@@ -167,6 +145,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -167,6 +145,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
campos.contiguous().data<float>(), campos.contiguous().data<float>(),
tan_fovx, tan_fovx,
tan_fovy, tan_fovy,
radii.contiguous().data<int>(),
dL_dout_color.contiguous().data<float>(), dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(), dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(), dL_dconic.contiguous().data<float>(),
...@@ -183,11 +162,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -183,11 +162,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
} }
torch::Tensor markVisible( torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D, torch::Tensor& means3D,
torch::Tensor& viewmatrix, torch::Tensor& viewmatrix,
torch::Tensor& projmatrix) torch::Tensor& projmatrix)
{ {
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0); const int P = means3D.size(0);
...@@ -195,7 +177,7 @@ torch::Tensor markVisible( ...@@ -195,7 +177,7 @@ torch::Tensor markVisible(
if(P != 0) if(P != 0)
{ {
((CudaRasterizer::Rasterizer*)rasterizer)->markVisible(P, cudaRenderer->markVisible(P,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(), viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(), projmatrix.contiguous().data<float>(),
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA( RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
const torch::Tensor& colors, const torch::Tensor& colors,
...@@ -26,13 +25,10 @@ RasterizeGaussiansCUDA( ...@@ -26,13 +25,10 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh, const torch::Tensor& sh,
const int degree, const int degree,
const torch::Tensor& campos, const torch::Tensor& campos,
const bool prefiltered, const bool prefiltered);
void* internalState);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA( RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
const torch::Tensor& radii, const torch::Tensor& radii,
...@@ -43,23 +39,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -43,23 +39,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& cov3D_precomp, const torch::Tensor& cov3D_precomp,
const torch::Tensor& viewmatrix, const torch::Tensor& viewmatrix,
const torch::Tensor& projmatrix, const torch::Tensor& projmatrix,
const float tan_fovx, const float tan_fovx,
const float tan_fovy, const float tan_fovy,
const torch::Tensor& dL_dout_color, const torch::Tensor& dL_dout_color,
const torch::Tensor& sh, const torch::Tensor& sh,
const int degree, const int degree,
const torch::Tensor& campos); const torch::Tensor& campos);
void* createRasterizerState(void* rasterizer);
void deleteRasterizerState(void* rasterizer, void* state);
void* createRasterizer();
void deleteRasterizer(void* rasterizer);
torch::Tensor markVisible( torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D, torch::Tensor& means3D,
torch::Tensor& viewmatrix, torch::Tensor& viewmatrix,
torch::Tensor& projmatrix); torch::Tensor& projmatrix);
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment