Commit 79cbd71d by Bernhard Kerbl

No more persistent state

parent ffee75d5
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define CUDA_RASTERIZER_H_INCLUDED #define CUDA_RASTERIZER_H_INCLUDED
#include <vector> #include <vector>
#include <functional>
namespace CudaRasterizer namespace CudaRasterizer
{ {
...@@ -9,14 +10,17 @@ namespace CudaRasterizer ...@@ -9,14 +10,17 @@ namespace CudaRasterizer
{ {
public: public:
virtual void markVisible( static void markVisible(
int P, int P,
float* means3D, float* means3D,
float* viewmatrix, float* viewmatrix,
float* projmatrix, float* projmatrix,
bool* present) = 0; bool* present);
virtual void forward( static int forward(
std::function<char* (int)> geometryBuffer,
std::function<char* (int)> binningBuffer,
std::function<char* (int)> imageBuffer,
const int P, int D, int M, const int P, int D, int M,
const float* background, const float* background,
const int width, int height, const int width, int height,
...@@ -34,10 +38,10 @@ namespace CudaRasterizer ...@@ -34,10 +38,10 @@ namespace CudaRasterizer
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const bool prefiltered, const bool prefiltered,
float* out_color, float* out_color,
int* radii = nullptr) = 0; int* radii);
virtual void backward( static void backward(
const int P, int D, int M, const int P, int D, int M, int R,
const float* background, const float* background,
const int width, int height, const int width, int height,
const float* means3D, const float* means3D,
...@@ -52,6 +56,9 @@ namespace CudaRasterizer ...@@ -52,6 +56,9 @@ namespace CudaRasterizer
const float* campos, const float* campos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const int* radii, const int* radii,
char* geom_buffer,
char* binning_buffer,
char* image_buffer,
const float* dL_dpix, const float* dL_dpix,
float* dL_dmean2D, float* dL_dmean2D,
float* dL_dconic, float* dL_dconic,
...@@ -61,11 +68,7 @@ namespace CudaRasterizer ...@@ -61,11 +68,7 @@ namespace CudaRasterizer
float* dL_dcov3D, float* dL_dcov3D,
float* dL_dsh, float* dL_dsh,
float* dL_dscale, float* dL_dscale,
float* dL_drot) = 0; float* dL_drot);
virtual ~Rasterizer() {};
static Rasterizer* make(int resizeMultipliyer = 2);
}; };
}; };
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "device_launch_parameters.h" #include "device_launch_parameters.h"
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh> #include <cub/device/device_radix_sort.cuh>
#include <thrust/sequence.h>
#define GLM_FORCE_CUDA #define GLM_FORCE_CUDA
#include <glm/glm.hpp> #include <glm/glm.hpp>
...@@ -127,18 +126,8 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang ...@@ -127,18 +126,8 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang
} }
} }
CudaRasterizer::RasterizerImpl::RasterizerImpl(int resizeMultiplier)
: resizeMultiplier(resizeMultiplier)
{}
// Instantiate rasterizer
CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplier)
{
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
}
// Mark Gaussians as visible/invisible, based on view frustum testing // Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::RasterizerImpl::markVisible( void CudaRasterizer::Rasterizer::markVisible(
int P, int P,
float* means3D, float* means3D,
float* viewmatrix, float* viewmatrix,
...@@ -152,9 +141,53 @@ void CudaRasterizer::RasterizerImpl::markVisible( ...@@ -152,9 +141,53 @@ void CudaRasterizer::RasterizerImpl::markVisible(
present); present);
} }
// Lays out all per-Gaussian geometry arrays inside one contiguous byte
// buffer and returns a GeometryState whose members point into it.
//
// chunk: in/out cursor into the buffer; every obtain() call places one
//        array at the next suitably aligned address and moves the cursor
//        past it, so on return it sits behind the last array.
// P:     element count used to size the arrays (cov3D takes 6 values per
//        element, rgb takes 3).
CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, int P)
{
    const int alignment = 128;
    GeometryState state;

    obtain(chunk, state.depths, P, alignment);
    obtain(chunk, state.clamped, P, alignment);
    obtain(chunk, state.internal_radii, P, alignment);
    obtain(chunk, state.means2D, P, alignment);
    obtain(chunk, state.cov3D, P * 6, alignment);
    obtain(chunk, state.conic_opacity, P, alignment);
    obtain(chunk, state.rgb, P * 3, alignment);
    obtain(chunk, state.tiles_touched, P, alignment);

    // Size query only: with a null temp-storage pointer CUB just reports the
    // scratch size it needs in scan_size, which then sizes scanning_space.
    cub::DeviceScan::InclusiveSum(nullptr, state.scan_size, state.tiles_touched, state.tiles_touched, P);
    obtain(chunk, state.scanning_space, state.scan_size, alignment);

    obtain(chunk, state.point_offsets, P, alignment);
    return state;
}
// Lays out the image-sized auxiliary arrays inside one contiguous byte
// buffer and returns an ImageState whose members point into it.
//
// chunk: in/out cursor into the buffer; advanced past each array in turn.
// N:     element count used for all three arrays.
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, int N)
{
    const int alignment = 128;
    ImageState state;

    obtain(chunk, state.accum_alpha, N, alignment);
    obtain(chunk, state.n_contrib, N, alignment);
    obtain(chunk, state.ranges, N, alignment);
    return state;
}
// Lays out the key/value lists used for sorting, plus the sort's scratch
// space, inside one contiguous byte buffer and returns a BinningState whose
// members point into it.
//
// chunk: in/out cursor into the buffer; advanced past each array in turn.
// P:     number of key/value entries each list must hold.
CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, int P)
{
    const int alignment = 128;
    BinningState state;

    obtain(chunk, state.point_list, P, alignment);
    obtain(chunk, state.point_list_unsorted, P, alignment);
    obtain(chunk, state.point_list_keys, P, alignment);
    obtain(chunk, state.point_list_keys_unsorted, P, alignment);

    // Size query only: with a null temp-storage pointer CUB just reports the
    // scratch size it needs in sorting_size, which then sizes
    // list_sorting_space.
    cub::DeviceRadixSort::SortPairs(
        nullptr, state.sorting_size,
        state.point_list_keys_unsorted, state.point_list_keys,
        state.point_list_unsorted, state.point_list, P);
    obtain(chunk, state.list_sorting_space, state.sorting_size, alignment);

    return state;
}
// Forward rendering procedure for differentiable rasterization // Forward rendering procedure for differentiable rasterization
// of Gaussians. // of Gaussians.
void CudaRasterizer::RasterizerImpl::forward( int CudaRasterizer::Rasterizer::forward(
std::function<char*(int)> geometryBuffer,
std::function<char* (int)> binningBuffer,
std::function<char* (int)> imageBuffer,
const int P, int D, int M, const int P, int D, int M,
const float* background, const float* background,
const int width, int height, const int width, int height,
...@@ -177,41 +210,22 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -177,41 +210,22 @@ void CudaRasterizer::RasterizerImpl::forward(
const float focal_y = height / (2.0f * tan_fovy); const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx); const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training int chunk_size = required<GeometryState>(P);
if (P > maxP) char* chunkptr = geometryBuffer(chunk_size);
{ GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
if (radii == nullptr) if (radii == nullptr)
{ {
internal_radii.resize(maxP); radii = geomState.internal_radii;
radii = internal_radii.data().get();
} }
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1); dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1); dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training // Dynamically resize image-based auxiliary buffers during training
if (width * height > maxPixels) int img_chunk_size = required<ImageState>(width * height);
{ char* img_chunkptr = imageBuffer(img_chunk_size);
maxPixels = width * height; ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
accum_alpha.resize(maxPixels);
n_contrib.resize(maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
}
if (NUM_CHANNELS != 3 && colors_precomp == nullptr) if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
{ {
...@@ -227,7 +241,7 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -227,7 +241,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations, (glm::vec4*)rotations,
opacities, opacities,
shs, shs,
clamped.data().get(), geomState.clamped,
cov3D_precomp, cov3D_precomp,
colors_precomp, colors_precomp,
viewmatrix, projmatrix, viewmatrix, projmatrix,
...@@ -236,47 +250,38 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -236,47 +250,38 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy, tan_fovx, tan_fovy,
focal_x, focal_y, focal_x, focal_y,
radii, radii,
means2D.data().get(), geomState.means2D,
depths.data().get(), geomState.depths,
cov3D.data().get(), geomState.cov3D,
rgb.data().get(), geomState.rgb,
conic_opacity.data().get(), geomState.conic_opacity,
tile_grid, tile_grid,
tiles_touched.data().get(), geomState.tiles_touched,
prefiltered prefiltered
); );
// Compute prefix sum over full list of touched tile counts by Gaussians // Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8] // E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size, cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size,
tiles_touched.data().get(), point_offsets.data().get(), P); geomState.tiles_touched, geomState.point_offsets, P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers // Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed; int num_rendered;
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost); cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
{ int binning_chunk_size = required<BinningState>(num_rendered);
point_list_keys_unsorted.resize(2 * num_needed); char* binning_chunkptr = binningBuffer(binning_chunk_size);
point_list_keys.resize(2 * num_needed); BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
point_list_unsorted.resize(2 * num_needed);
point_list.resize(2 * num_needed);
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
2 * num_needed);
list_sorting_space.resize(sorting_size);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key // For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding dublicated Gaussian indices to be sorted // and corresponding dublicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > ( duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P, P,
means2D.data().get(), geomState.means2D,
depths.data().get(), geomState.depths,
point_offsets.data().get(), geomState.point_offsets,
point_list_keys_unsorted.data().get(), binningState.point_list_keys_unsorted,
point_list_unsorted.data().get(), binningState.point_list_unsorted,
radii, radii,
tile_grid tile_grid
); );
...@@ -285,41 +290,43 @@ void CudaRasterizer::RasterizerImpl::forward( ...@@ -285,41 +290,43 @@ void CudaRasterizer::RasterizerImpl::forward(
// Sort complete list of (duplicated) Gaussian indices by keys // Sort complete list of (duplicated) Gaussian indices by keys
cub::DeviceRadixSort::SortPairs( cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get(), binningState.list_sorting_space,
sorting_size, binningState.sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(), binningState.point_list_keys_unsorted, binningState.point_list_keys,
point_list_unsorted.data().get(), point_list.data().get(), binningState.point_list_unsorted, binningState.point_list,
num_needed, 0, 32 + bit); num_rendered, 0, 32 + bit);
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2)); cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list // Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > ( identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
num_needed, num_rendered,
point_list_keys.data().get(), binningState.point_list_keys,
ranges.data().get() imgState.ranges
); );
// Let each tile blend its range of Gaussians independently in parallel // Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get(); const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
FORWARD::render( FORWARD::render(
tile_grid, block, tile_grid, block,
ranges.data().get(), imgState.ranges,
point_list.data().get(), binningState.point_list,
width, height, width, height,
means2D.data().get(), geomState.means2D,
feature_ptr, feature_ptr,
conic_opacity.data().get(), geomState.conic_opacity,
accum_alpha.data().get(), imgState.accum_alpha,
n_contrib.data().get(), imgState.n_contrib,
background, background,
out_color); out_color);
return num_rendered;
} }
// Produce necessary gradients for optimization, corresponding // Produce necessary gradients for optimization, corresponding
// to forward render pass // to forward render pass
void CudaRasterizer::RasterizerImpl::backward( void CudaRasterizer::Rasterizer::backward(
const int P, int D, int M, const int P, int D, int M, int R,
const float* background, const float* background,
const int width, int height, const int width, int height,
const float* means3D, const float* means3D,
...@@ -334,6 +341,9 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -334,6 +341,9 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* campos, const float* campos,
const float tan_fovx, float tan_fovy, const float tan_fovx, float tan_fovy,
const int* radii, const int* radii,
char* geom_buffer,
char* binning_buffer,
char* img_buffer,
const float* dL_dpix, const float* dL_dpix,
float* dL_dmean2D, float* dL_dmean2D,
float* dL_dconic, float* dL_dconic,
...@@ -345,9 +355,13 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -345,9 +355,13 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale, float* dL_dscale,
float* dL_drot) float* dL_drot)
{ {
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
if (radii == nullptr) if (radii == nullptr)
{ {
radii = internal_radii.data().get(); radii = geomState.internal_radii;
} }
const float focal_y = height / (2.0f * tan_fovy); const float focal_y = height / (2.0f * tan_fovy);
...@@ -359,19 +373,19 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -359,19 +373,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix, // Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients. // opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them. // If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get(); const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
BACKWARD::render( BACKWARD::render(
tile_grid, tile_grid,
block, block,
ranges.data().get(), imgState.ranges,
point_list.data().get(), binningState.point_list,
width, height, width, height,
background, background,
means2D.data().get(), geomState.means2D,
conic_opacity.data().get(), geomState.conic_opacity,
color_ptr, color_ptr,
accum_alpha.data().get(), imgState.accum_alpha,
n_contrib.data().get(), imgState.n_contrib,
dL_dpix, dL_dpix,
(float3*)dL_dmean2D, (float3*)dL_dmean2D,
(float4*)dL_dconic, (float4*)dL_dconic,
...@@ -381,12 +395,12 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -381,12 +395,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance // Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not, // given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves. // use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get(); const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
BACKWARD::preprocess(P, D, M, BACKWARD::preprocess(P, D, M,
(float3*)means3D, (float3*)means3D,
radii, radii,
shs, shs,
clamped.data().get(), geomState.clamped,
(glm::vec3*)scales, (glm::vec3*)scales,
(glm::vec4*)rotations, (glm::vec4*)rotations,
scale_modifier, scale_modifier,
...@@ -404,7 +418,3 @@ void CudaRasterizer::RasterizerImpl::backward( ...@@ -404,7 +418,3 @@ void CudaRasterizer::RasterizerImpl::backward(
(glm::vec3*)dL_dscale, (glm::vec3*)dL_dscale,
(glm::vec4*)dL_drot); (glm::vec4*)dL_drot);
} }
\ No newline at end of file
CudaRasterizer::RasterizerImpl::~RasterizerImpl()
{
}
\ No newline at end of file
...@@ -4,101 +4,60 @@ ...@@ -4,101 +4,60 @@
#include <vector> #include <vector>
#include "rasterizer.h" #include "rasterizer.h"
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
namespace CudaRasterizer namespace CudaRasterizer
{ {
class RasterizerImpl : public Rasterizer template <typename T>
static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
{ {
private: std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
int maxP = 0; ptr = reinterpret_cast<T*>(offset);
int maxPixels = 0; chunk = reinterpret_cast<char*>(ptr + count);
int resizeMultiplier = 2; }
// Initial aux structs struct GeometryState
size_t sorting_size; {
size_t list_sorting_size;
size_t scan_size; size_t scan_size;
thrust::device_vector<float> depths; float* depths;
thrust::device_vector<uint32_t> tiles_touched; char* scanning_space;
thrust::device_vector<uint32_t> point_offsets; bool* clamped;
thrust::device_vector<uint64_t> point_list_keys_unsorted; int* internal_radii;
thrust::device_vector<uint64_t> point_list_keys; float2* means2D;
thrust::device_vector<uint32_t> point_list_unsorted; float* cov3D;
thrust::device_vector<uint32_t> point_list; float4* conic_opacity;
thrust::device_vector<char> scanning_space; float* rgb;
thrust::device_vector<char> list_sorting_space; uint32_t* point_offsets;
thrust::device_vector<bool> clamped; uint32_t* tiles_touched;
thrust::device_vector<int> internal_radii;
static GeometryState fromChunk(char*& chunk, int P);
// Internal state kept across forward / backward };
thrust::device_vector<uint2> ranges;
thrust::device_vector<uint32_t> n_contrib;
thrust::device_vector<float> accum_alpha;
thrust::device_vector<float2> means2D;
thrust::device_vector<float> cov3D;
thrust::device_vector<float4> conic_opacity;
thrust::device_vector<float> rgb;
public:
virtual void markVisible(
int P,
float* means3D,
float* viewmatrix,
float* projmatrix,
bool* present) override;
virtual void forward( struct ImageState
const int P, int D, int M, {
const float* background, uint2* ranges;
const int width, int height, uint32_t* n_contrib;
const float* means3D, float* accum_alpha;
const float* shs,
const float* colors_precomp,
const float* opacities,
const float* scales,
const float scale_modifier,
const float* rotations,
const float* cov3D_precomp,
const float* viewmatrix,
const float* projmatrix,
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii) override;
virtual void backward( static ImageState fromChunk(char*& chunk, int N);
const int P, int D, int M, };
const float* background,
const int width, int height,
const float* means3D,
const float* shs,
const float* colors_precomp,
const float* scales,
const float scale_modifier,
const float* rotations,
const float* cov3D_precomp,
const float* viewmatrix,
const float* projmatrix,
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dopacity,
float* dL_dcolor,
float* dL_dmean3D,
float* dL_dcov3D,
float* dL_dsh,
float* dL_dscale,
float* dL_drot) override;
RasterizerImpl(int resizeMultiplier); struct BinningState
{
size_t sorting_size;
uint64_t* point_list_keys_unsorted;
uint64_t* point_list_keys;
uint32_t* point_list_unsorted;
uint32_t* point_list;
char* list_sorting_space;
virtual ~RasterizerImpl() override; static BinningState fromChunk(char*& chunk, int P);
}; };
// Returns the number of bytes a caller has to allocate so that
// T::fromChunk can place all of its sub-buffers for P elements.
//
// T::fromChunk is run on a cursor that starts at the null address, so the
// address the cursor ends on equals the number of bytes consumed. A further
// 128 bytes are added on top of that.
// NOTE(review): the extra 128 bytes look like slack for a real buffer whose
// base address is not 128-byte aligned — confirm against the allocator.
// NOTE(review): the result is an int, so layouts beyond INT_MAX bytes cannot
// be represented by this interface.
template<typename T>
int required(int P)
{
    char* cursor = nullptr;
    T::fromChunk(cursor, P);
    // Go through std::uintptr_t: casting a pointer directly to int is
    // rejected by GCC/Clang when int is narrower than a pointer.
    return static_cast<int>(reinterpret_cast<std::uintptr_t>(cursor)) + 128;
}
}; };
\ No newline at end of file
...@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function):
) )
# Invoke C++/CUDA rasterizer # Invoke C++/CUDA rasterizer
color, radii = _C.rasterize_gaussians(*args) num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
# Keep relevant tensors for backward # Keep relevant tensors for backward
ctx.raster_settings = raster_settings ctx.raster_settings = raster_settings
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh) ctx.num_rendered = num_rendered
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
return color, radii return color, radii
@staticmethod @staticmethod
def backward(ctx, grad_out_color, _): def backward(ctx, grad_out_color, _):
# Restore necessary values from context # Restore necessary values from context
num_rendered = ctx.num_rendered
raster_settings = ctx.raster_settings raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
# Restructure args as C++ method expects them # Restructure args as C++ method expects them
args = (raster_settings.bg, args = (raster_settings.bg,
...@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_out_color, grad_out_color,
sh, sh,
raster_settings.sh_degree, raster_settings.sh_degree,
raster_settings.campos) raster_settings.campos,
geomBuffer,
num_rendered,
binningBuffer,
imgBuffer)
# Compute gradients for relevant tensors by invoking backward method # Compute gradients for relevant tensors by invoking backward method
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
......
...@@ -13,10 +13,17 @@ ...@@ -13,10 +13,17 @@
#include "cuda_rasterizer/rasterizer.h" #include "cuda_rasterizer/rasterizer.h"
#include <fstream> #include <fstream>
#include <string> #include <string>
#include <functional>
// Wraps a tensor in a callback that resizes it to N elements on request and
// returns the start of its data as a raw byte pointer.
// The callback captures `t` by reference, so it must only be invoked while
// the tensor is still alive.
std::function<char*(int N)> resizeFunctional(torch::Tensor& t) {
    return [&t](int N) {
        t.resize_({N});
        void* storage = t.contiguous().data_ptr();
        return reinterpret_cast<char*>(storage);
    };
}
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr; std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA( RasterizeGaussiansCUDA(
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
...@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA( ...@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA(
const torch::Tensor& campos, const torch::Tensor& campos,
const bool prefiltered) const bool prefiltered)
{ {
if (means3D.ndimension() != 2 || means3D.size(1) != 3) { if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)"); AT_ERROR("means3D must have dimensions (num_points, 3)");
} }
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0); const int P = means3D.size(0);
const int N = 1; // batch size hard-coded const int N = 1; // batch size hard-coded
const int H = image_height; const int H = image_height;
...@@ -58,6 +59,16 @@ RasterizeGaussiansCUDA( ...@@ -58,6 +59,16 @@ RasterizeGaussiansCUDA(
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts); torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Device device(torch::kCUDA);
torch::TensorOptions options(torch::kByte);
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
std::function<char*(int)> geomFunc = resizeFunctional(geomBuffer);
std::function<char*(int)> binningFunc = resizeFunctional(binningBuffer);
std::function<char*(int)> imgFunc = resizeFunctional(imgBuffer);
int rendered = 0;
if(P != 0) if(P != 0)
{ {
int M = 0; int M = 0;
...@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA( ...@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA(
M = sh.size(1); M = sh.size(1);
} }
cudaRenderer->forward(P, degree, M, rendered = CudaRasterizer::Rasterizer::forward(
geomFunc,
binningFunc,
imgFunc,
P, degree, M,
background.contiguous().data<float>(), background.contiguous().data<float>(),
W, H, W, H,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
...@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA( ...@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA(
out_color.contiguous().data<float>(), out_color.contiguous().data<float>(),
radii.contiguous().data<int>()); radii.contiguous().data<int>());
} }
return std::make_tuple(out_color, radii); return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
} }
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
...@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& dL_dout_color, const torch::Tensor& dL_dout_color,
const torch::Tensor& sh, const torch::Tensor& sh,
const int degree, const int degree,
const torch::Tensor& campos) const torch::Tensor& campos,
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer)
{ {
const int P = means3D.size(0); const int P = means3D.size(0);
const int H = dL_dout_color.size(2); const int H = dL_dout_color.size(2);
...@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0) if(P != 0)
{ {
cudaRenderer->backward(P, degree, M, CudaRasterizer::Rasterizer::backward(P, degree, M, R,
background.contiguous().data<float>(), background.contiguous().data<float>(),
W, H, W, H,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
...@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
tan_fovx, tan_fovx,
tan_fovy, tan_fovy,
radii.contiguous().data<int>(), radii.contiguous().data<int>(),
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
dL_dout_color.contiguous().data<float>(), dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(), dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(), dL_dconic.contiguous().data<float>(),
...@@ -166,18 +188,13 @@ torch::Tensor markVisible( ...@@ -166,18 +188,13 @@ torch::Tensor markVisible(
torch::Tensor& viewmatrix, torch::Tensor& viewmatrix,
torch::Tensor& projmatrix) torch::Tensor& projmatrix)
{ {
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0); const int P = means3D.size(0);
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool)); torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
if(P != 0) if(P != 0)
{ {
cudaRenderer->markVisible(P, CudaRasterizer::Rasterizer::markVisible(P,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(), viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(), projmatrix.contiguous().data<float>(),
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tuple> #include <tuple>
#include <string> #include <string>
std::tuple<torch::Tensor, torch::Tensor> std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA( RasterizeGaussiansCUDA(
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
...@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& dL_dout_color, const torch::Tensor& dL_dout_color,
const torch::Tensor& sh, const torch::Tensor& sh,
const int degree, const int degree,
const torch::Tensor& campos); const torch::Tensor& campos,
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer);
torch::Tensor markVisible( torch::Tensor markVisible(
torch::Tensor& means3D, torch::Tensor& means3D,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment