Commit bccbb2e8 by Bernhard Kerbl

InternalState explicit now

parent ffee75d5
......@@ -5,10 +5,16 @@
namespace CudaRasterizer
{
struct InternalState;
class Rasterizer
{
public:
virtual InternalState* createInternalState() = 0;
virtual void killInternalState(InternalState*) = 0;
virtual void markVisible(
int P,
float* means3D,
......@@ -33,10 +39,13 @@ namespace CudaRasterizer
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii = nullptr) = 0;
int* radii,
InternalState* state,
float* out_color) = 0;
virtual void backward(
const int* radii,
const InternalState* state,
const int P, int D, int M,
const float* background,
const int width, int height,
......@@ -47,11 +56,10 @@ namespace CudaRasterizer
const float scale_modifier,
const float* rotations,
const float* cov3D_precomp,
const float* viewmatrix,
const float* viewmatrix,
const float* projmatrix,
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
......
......@@ -171,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii)
int* radii,
InternalState* state,
float* out_color)
{
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training
if (P > maxP)
{
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
if (radii == nullptr)
if (P > state->maxP)
{
internal_radii.resize(maxP);
radii = internal_radii.data().get();
state->maxP = resizeMultiplier * P;
state->cov3D.resize(state->maxP * 6);
state->rgb.resize(state->maxP * 3);
state->tiles_touched.resize(state->maxP);
state->point_offsets.resize(state->maxP);
state->clamped.resize(3 * state->maxP);
state->depths.resize(state->maxP);
state->means2D.resize(state->maxP);
state->conic_opacity.resize(state->maxP);
size_t scan_size;
cub::DeviceScan::InclusiveSum(nullptr,
scan_size,
state->tiles_touched.data().get(),
state->tiles_touched.data().get(),
state->maxP);
state->scanning_space.resize(scan_size);
}
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training
if (width * height > maxPixels)
if (width * height > state->maxPixels)
{
maxPixels = width * height;
accum_alpha.resize(maxPixels);
n_contrib.resize(maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
state->maxPixels = width * height;
state->accum_alpha.resize(state->maxPixels);
state->n_contrib.resize(state->maxPixels);
state->ranges.resize(tile_grid.x * tile_grid.y);
}
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
......@@ -227,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations,
opacities,
shs,
clamped.data().get(),
state->clamped.data().get(),
cov3D_precomp,
colors_precomp,
viewmatrix, projmatrix,
......@@ -236,45 +236,56 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy,
focal_x, focal_y,
radii,
means2D.data().get(),
depths.data().get(),
cov3D.data().get(),
rgb.data().get(),
conic_opacity.data().get(),
state->means2D.data().get(),
state->depths.data().get(),
state->cov3D.data().get(),
state->rgb.data().get(),
state->conic_opacity.data().get(),
tile_grid,
tiles_touched.data().get(),
state->tiles_touched.data().get(),
prefiltered
);
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
tiles_touched.data().get(), point_offsets.data().get(), P);
size_t scanning_space_size = state->scanning_space.size();
cub::DeviceScan::InclusiveSum(
state->scanning_space.data().get(),
scanning_space_size,
state->tiles_touched.data().get(),
state->point_offsets.data().get(),
P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed;
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&num_needed, state->point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
{
point_list_keys_unsorted.resize(2 * num_needed);
point_list_keys.resize(2 * num_needed);
point_list_unsorted.resize(2 * num_needed);
point_list.resize(2 * num_needed);
int resizeNum = resizeMultiplier * num_needed;
point_list_keys_unsorted.resize(resizeNum);
point_list_keys.resize(resizeNum);
point_list_unsorted.resize(resizeNum);
size_t sorting_size;
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
2 * num_needed);
point_list_unsorted.data().get(), state->point_list.data().get(),
resizeNum);
list_sorting_space.resize(sorting_size);
}
if (num_needed > state->point_list.size())
{
state->point_list.resize(resizeMultiplier * num_needed);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding dublicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P,
means2D.data().get(),
depths.data().get(),
point_offsets.data().get(),
state->means2D.data().get(),
state->depths.data().get(),
state->point_offsets.data().get(),
point_list_keys_unsorted.data().get(),
point_list_unsorted.data().get(),
radii,
......@@ -284,34 +295,36 @@ void CudaRasterizer::RasterizerImpl::forward(
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
// Sort complete list of (duplicated) Gaussian indices by keys
size_t list_sorting_space_size = list_sorting_space.size();
cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get(),
sorting_size,
list_sorting_space_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
point_list_unsorted.data().get(),
state->point_list.data().get(),
num_needed, 0, 32 + bit);
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
cudaMemset(state->ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
num_needed,
point_list_keys.data().get(),
ranges.data().get()
state->ranges.data().get()
);
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : state->rgb.data().get();
FORWARD::render(
tile_grid, block,
ranges.data().get(),
point_list.data().get(),
state->ranges.data().get(),
state->point_list.data().get(),
width, height,
means2D.data().get(),
state->means2D.data().get(),
feature_ptr,
conic_opacity.data().get(),
accum_alpha.data().get(),
n_contrib.data().get(),
state->conic_opacity.data().get(),
state->accum_alpha.data().get(),
state->n_contrib.data().get(),
background,
out_color);
}
......@@ -319,6 +332,8 @@ void CudaRasterizer::RasterizerImpl::forward(
// Produce necessary gradients for optimization, corresponding
// to forward render pass
void CudaRasterizer::RasterizerImpl::backward(
const int* radii,
const InternalState* state,
const int P, int D, int M,
const float* background,
const int width, int height,
......@@ -333,7 +348,6 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* projmatrix,
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
......@@ -345,11 +359,6 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale,
float* dL_drot)
{
if (radii == nullptr)
{
radii = internal_radii.data().get();
}
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
......@@ -359,19 +368,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : state->rgb.data().get();
BACKWARD::render(
tile_grid,
block,
ranges.data().get(),
point_list.data().get(),
state->ranges.data().get(),
state->point_list.data().get(),
width, height,
background,
means2D.data().get(),
conic_opacity.data().get(),
state->means2D.data().get(),
state->conic_opacity.data().get(),
color_ptr,
accum_alpha.data().get(),
n_contrib.data().get(),
state->accum_alpha.data().get(),
state->n_contrib.data().get(),
dL_dpix,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
......@@ -381,12 +390,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : state->cov3D.data().get();
BACKWARD::preprocess(P, D, M,
(float3*)means3D,
radii,
shs,
clamped.data().get(),
state->clamped.data().get(),
(glm::vec3*)scales,
(glm::vec4*)rotations,
scale_modifier,
......
......@@ -8,41 +8,53 @@
namespace CudaRasterizer
{
class RasterizerImpl : public Rasterizer
//// Internal state kept across forward / backward
struct InternalState
{
private:
int maxP = 0;
int maxPixels = 0;
int resizeMultiplier = 2;
// Initial aux structs
size_t sorting_size;
size_t list_sorting_size;
size_t scan_size;
thrust::device_vector<uint2> ranges;
thrust::device_vector<float2> means2D;
thrust::device_vector<float4> conic_opacity;
thrust::device_vector<float> accum_alpha;
thrust::device_vector<uint32_t> n_contrib;
thrust::device_vector<float> cov3D;
thrust::device_vector<bool> clamped;
thrust::device_vector<float> rgb;
thrust::device_vector<float> depths;
thrust::device_vector<uint32_t> tiles_touched;
thrust::device_vector<uint32_t> point_offsets;
thrust::device_vector<uint64_t> point_list_keys_unsorted;
thrust::device_vector<uint64_t> point_list_keys;
thrust::device_vector<uint32_t> point_list_unsorted;
thrust::device_vector<uint32_t> point_list;
thrust::device_vector<uint32_t> tiles_touched;
thrust::device_vector<char> scanning_space;
thrust::device_vector<char> list_sorting_space;
thrust::device_vector<bool> clamped;
thrust::device_vector<int> internal_radii;
thrust::device_vector<uint32_t> point_list;
};
// Internal state kept across forward / backward
thrust::device_vector<uint2> ranges;
thrust::device_vector<uint32_t> n_contrib;
thrust::device_vector<float> accum_alpha;
// Auxiliary buffer spaces
thrust::device_vector<uint64_t> point_list_keys_unsorted;
thrust::device_vector<uint64_t> point_list_keys;
thrust::device_vector<uint32_t> point_list_unsorted;
thrust::device_vector<char> list_sorting_space;
thrust::device_vector<float2> means2D;
thrust::device_vector<float> cov3D;
thrust::device_vector<float4> conic_opacity;
thrust::device_vector<float> rgb;
class RasterizerImpl : public Rasterizer
{
private:
int resizeMultiplier = 2;
public:
virtual InternalState* createInternalState() override
{
return new InternalState();
}
virtual void killInternalState(InternalState* is) override
{
delete is;
}
virtual void markVisible(
int P,
float* means3D,
......@@ -67,10 +79,13 @@ namespace CudaRasterizer
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii) override;
int* radii,
InternalState* state,
float* out_color) override;
virtual void backward(
const int* radii,
const InternalState* fixedState,
const int P, int D, int M,
const float* background,
const int width, int height,
......@@ -85,7 +100,6 @@ namespace CudaRasterizer
const float* projmatrix,
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
......
#include <torch/extension.h>
#include "rasterize_points.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
m.def("mark_visible", &markVisible);
m.def("create_rasterizer_state", &createRasterizerState);
m.def("delete_rasterizer_state", &deleteRasterizerState);
}
\ No newline at end of file
......@@ -11,11 +11,26 @@
#include <memory>
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h"
#include "rasterize_points.h"
#include <fstream>
#include <string>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
void* createRasterizerState()
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
return (void*)cudaRenderer->createInternalState();
}
void deleteRasterizerState(void* state)
{
cudaRenderer->killInternalState((CudaRasterizer::InternalState*)state);
}
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
......@@ -35,7 +50,8 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
const bool prefiltered)
const bool prefiltered,
void* internalState)
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
......@@ -83,14 +99,16 @@ RasterizeGaussiansCUDA(
tan_fovx,
tan_fovy,
prefiltered,
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
out_color.contiguous().data<float>());
}
return std::make_tuple(out_color, radii);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& radii,
......@@ -130,7 +148,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
{
cudaRenderer->backward(P, degree, M,
cudaRenderer->backward(
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
P, degree, M,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
......@@ -145,7 +166,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
campos.contiguous().data<float>(),
tan_fovx,
tan_fovy,
radii.contiguous().data<int>(),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
......
......@@ -25,10 +25,12 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
const bool prefiltered);
const bool prefiltered,
void* internalState);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& radii,
......@@ -39,13 +41,17 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& cov3D_precomp,
const torch::Tensor& viewmatrix,
const torch::Tensor& projmatrix,
const float tan_fovx,
const float tan_fovx,
const float tan_fovy,
const torch::Tensor& dL_dout_color,
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos);
void* createRasterizerState();
void deleteRasterizerState(void* state);
torch::Tensor markVisible(
torch::Tensor& means3D,
torch::Tensor& viewmatrix,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment