Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
diff-gaussian-rasterization
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Alan de Oliveira
diff-gaussian-rasterization
Commits
79cbd71d
Commit
79cbd71d
authored
Jun 20, 2023
by
Bernhard Kerbl
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
No more persistent state
parent
ffee75d5
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
224 additions
and
224 deletions
+224
-224
rasterizer.h
cuda_rasterizer/rasterizer.h
+14
-11
rasterizer_impl.cu
cuda_rasterizer/rasterizer_impl.cu
+112
-102
rasterizer_impl.h
cuda_rasterizer/rasterizer_impl.h
+46
-86
rasterizer.py
diff_gaussian_rasterization/rasterizer.py
+10
-4
rasterize_points.cu
rasterize_points.cu
+36
-19
rasterize_points.h
rasterize_points.h
+6
-2
No files found.
cuda_rasterizer/rasterizer.h
View file @
79cbd71d
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define CUDA_RASTERIZER_H_INCLUDED
#define CUDA_RASTERIZER_H_INCLUDED
#include <vector>
#include <vector>
#include <functional>
namespace
CudaRasterizer
namespace
CudaRasterizer
{
{
...
@@ -9,14 +10,17 @@ namespace CudaRasterizer
...
@@ -9,14 +10,17 @@ namespace CudaRasterizer
{
{
public
:
public
:
virtual
void
markVisible
(
static
void
markVisible
(
int
P
,
int
P
,
float
*
means3D
,
float
*
means3D
,
float
*
viewmatrix
,
float
*
viewmatrix
,
float
*
projmatrix
,
float
*
projmatrix
,
bool
*
present
)
=
0
;
bool
*
present
);
virtual
void
forward
(
static
int
forward
(
std
::
function
<
char
*
(
int
)
>
geometryBuffer
,
std
::
function
<
char
*
(
int
)
>
binningBuffer
,
std
::
function
<
char
*
(
int
)
>
imageBuffer
,
const
int
P
,
int
D
,
int
M
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
float
*
background
,
const
int
width
,
int
height
,
const
int
width
,
int
height
,
...
@@ -34,10 +38,10 @@ namespace CudaRasterizer
...
@@ -34,10 +38,10 @@ namespace CudaRasterizer
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
const
bool
prefiltered
,
float
*
out_color
,
float
*
out_color
,
int
*
radii
=
nullptr
)
=
0
;
int
*
radii
)
;
virtual
void
backward
(
static
void
backward
(
const
int
P
,
int
D
,
int
M
,
const
int
P
,
int
D
,
int
M
,
int
R
,
const
float
*
background
,
const
float
*
background
,
const
int
width
,
int
height
,
const
int
width
,
int
height
,
const
float
*
means3D
,
const
float
*
means3D
,
...
@@ -52,6 +56,9 @@ namespace CudaRasterizer
...
@@ -52,6 +56,9 @@ namespace CudaRasterizer
const
float
*
campos
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
int
*
radii
,
char
*
geom_buffer
,
char
*
binning_buffer
,
char
*
image_buffer
,
const
float
*
dL_dpix
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dconic
,
...
@@ -61,11 +68,7 @@ namespace CudaRasterizer
...
@@ -61,11 +68,7 @@ namespace CudaRasterizer
float
*
dL_dcov3D
,
float
*
dL_dcov3D
,
float
*
dL_dsh
,
float
*
dL_dsh
,
float
*
dL_dscale
,
float
*
dL_dscale
,
float
*
dL_drot
)
=
0
;
float
*
dL_drot
);
virtual
~
Rasterizer
()
{};
static
Rasterizer
*
make
(
int
resizeMultipliyer
=
2
);
};
};
};
};
...
...
cuda_rasterizer/rasterizer_impl.cu
View file @
79cbd71d
...
@@ -8,7 +8,6 @@
...
@@ -8,7 +8,6 @@
#include "device_launch_parameters.h"
#include "device_launch_parameters.h"
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <thrust/sequence.h>
#define GLM_FORCE_CUDA
#define GLM_FORCE_CUDA
#include <glm/glm.hpp>
#include <glm/glm.hpp>
...
@@ -127,18 +126,8 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang
...
@@ -127,18 +126,8 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang
}
}
}
}
CudaRasterizer::RasterizerImpl::RasterizerImpl(int resizeMultiplier)
: resizeMultiplier(resizeMultiplier)
{}
// Instantiate rasterizer
CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplier)
{
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
}
// Mark Gaussians as visible/invisible, based on view frustum testing
// Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::Rasterizer
Impl
::markVisible(
void CudaRasterizer::Rasterizer::markVisible(
int P,
int P,
float* means3D,
float* means3D,
float* viewmatrix,
float* viewmatrix,
...
@@ -152,9 +141,53 @@ void CudaRasterizer::RasterizerImpl::markVisible(
...
@@ -152,9 +141,53 @@ void CudaRasterizer::RasterizerImpl::markVisible(
present);
present);
}
}
// Lays out the per-Gaussian working arrays of a GeometryState inside a
// caller-provided byte chunk.
//
// chunk: in/out cursor into the backing memory. Every obtain() call rounds
//        the cursor up to a 128-byte boundary, assigns the array pointer and
//        advances the cursor past the array. The pointers are only computed,
//        never dereferenced here, so the function can also be run on a null
//        cursor to measure the space a layout needs.
// P:     number of Gaussians.
//
// Returns a GeometryState whose pointers reference sub-ranges of the chunk.
CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, int P)
{
	GeometryState geom;
	obtain(chunk, geom.depths, P, 128);
	// clamped needs three entries per Gaussian: the buffer this replaces was
	// sized 3 * P. Reserving only P would let writes spill into the array
	// that follows it in the chunk.
	obtain(chunk, geom.clamped, P * 3, 128);
	obtain(chunk, geom.internal_radii, P, 128);
	obtain(chunk, geom.means2D, P, 128);
	obtain(chunk, geom.cov3D, P * 6, 128);
	obtain(chunk, geom.conic_opacity, P, 128);
	obtain(chunk, geom.rgb, P * 3, 128);
	obtain(chunk, geom.tiles_touched, P, 128);
	// Size query only: with a null temp-storage pointer CUB writes the number
	// of scratch bytes it needs into scan_size and performs no scan.
	cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
	obtain(chunk, geom.scanning_space, geom.scan_size, 128);
	obtain(chunk, geom.point_offsets, P, 128);
	return geom;
}
// Lays out the image-sized working arrays of an ImageState inside a
// caller-provided byte chunk.
//
// chunk: in/out cursor; each obtain() call aligns it and moves it past the
//        array it just placed.
// N:     element count reserved for each of the three arrays.
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, int N)
{
	constexpr int kAlignment = 128;

	ImageState state;
	obtain(chunk, state.accum_alpha, N, kAlignment);
	obtain(chunk, state.n_contrib, N, kAlignment);
	obtain(chunk, state.ranges, N, kAlignment);
	return state;
}
// Lays out the key/value lists used for sorting, plus the sort's scratch
// space, inside a caller-provided byte chunk.
//
// chunk: in/out cursor; each obtain() call aligns it and moves it past the
//        array it just placed.
// P:     number of key/value entries to reserve in every list.
CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, int P)
{
	constexpr int kAlignment = 128;

	BinningState state;
	obtain(chunk, state.point_list, P, kAlignment);
	obtain(chunk, state.point_list_unsorted, P, kAlignment);
	obtain(chunk, state.point_list_keys, P, kAlignment);
	obtain(chunk, state.point_list_keys_unsorted, P, kAlignment);

	// Size query only: a null temp-storage pointer makes CUB report the
	// scratch bytes it needs in sorting_size without sorting anything.
	cub::DeviceRadixSort::SortPairs(
		nullptr, state.sorting_size,
		state.point_list_keys_unsorted, state.point_list_keys,
		state.point_list_unsorted, state.point_list, P);

	obtain(chunk, state.list_sorting_space, state.sorting_size, kAlignment);
	return state;
}
// Forward rendering procedure for differentiable rasterization
// Forward rendering procedure for differentiable rasterization
// of Gaussians.
// of Gaussians.
void CudaRasterizer::RasterizerImpl::forward(
int CudaRasterizer::Rasterizer::forward(
std::function<char*(int)> geometryBuffer,
std::function<char* (int)> binningBuffer,
std::function<char* (int)> imageBuffer,
const int P, int D, int M,
const int P, int D, int M,
const float* background,
const float* background,
const int width, int height,
const int width, int height,
...
@@ -177,41 +210,22 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -177,41 +210,22 @@ void CudaRasterizer::RasterizerImpl::forward(
const float focal_y = height / (2.0f * tan_fovy);
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training
int chunk_size = required<GeometryState>(P);
if (P > maxP)
char* chunkptr = geometryBuffer(chunk_size);
{
GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
if (radii == nullptr)
if (radii == nullptr)
{
{
internal_radii.resize(maxP);
radii = geomState.internal_radii;
radii = internal_radii.data().get();
}
}
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training
// Dynamically resize image-based auxiliary buffers during training
if (width * height > maxPixels)
int img_chunk_size = required<ImageState>(width * height);
{
char* img_chunkptr = imageBuffer(img_chunk_size);
maxPixels = width * height;
ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
accum_alpha.resize(maxPixels);
n_contrib.resize(maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
}
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
{
{
...
@@ -227,7 +241,7 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -227,7 +241,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations,
(glm::vec4*)rotations,
opacities,
opacities,
shs,
shs,
clamped.data().get()
,
geomState.clamped
,
cov3D_precomp,
cov3D_precomp,
colors_precomp,
colors_precomp,
viewmatrix, projmatrix,
viewmatrix, projmatrix,
...
@@ -236,47 +250,38 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -236,47 +250,38 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy,
tan_fovx, tan_fovy,
focal_x, focal_y,
focal_x, focal_y,
radii,
radii,
means2D.data().get()
,
geomState.means2D
,
depths.data().get()
,
geomState.depths
,
cov3D.data().get()
,
geomState.cov3D
,
rgb.data().get()
,
geomState.rgb
,
conic_opacity.data().get()
,
geomState.conic_opacity
,
tile_grid,
tile_grid,
tiles_touched.data().get()
,
geomState.tiles_touched
,
prefiltered
prefiltered
);
);
// Compute prefix sum over full list of touched tile counts by Gaussians
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(
scanning_space.data().get(),
scan_size,
cub::DeviceScan::InclusiveSum(
geomState.scanning_space, geomState.
scan_size,
tiles_touched.data().get(), point_offsets.data().get()
, P);
geomState.tiles_touched, geomState.point_offsets
, P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed;
int num_rendered;
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
{
int binning_chunk_size = required<BinningState>(num_rendered);
point_list_keys_unsorted.resize(2 * num_needed);
char* binning_chunkptr = binningBuffer(binning_chunk_size);
point_list_keys.resize(2 * num_needed);
BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
point_list_unsorted.resize(2 * num_needed);
point_list.resize(2 * num_needed);
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
2 * num_needed);
list_sorting_space.resize(sorting_size);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key
// For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding duplicated Gaussian indices to be sorted
// and corresponding duplicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P,
P,
means2D.data().get()
,
geomState.means2D
,
depths.data().get(),
geomState.depths,
point_offsets.data().get(),
geomState.point_offsets,
point_list_keys_unsorted.data().get(),
binningState.point_list_keys_unsorted,
point_list_unsorted.data().get(),
binningState.point_list_unsorted,
radii,
radii,
tile_grid
tile_grid
);
);
...
@@ -285,41 +290,43 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -285,41 +290,43 @@ void CudaRasterizer::RasterizerImpl::forward(
// Sort complete list of (duplicated) Gaussian indices by keys
// Sort complete list of (duplicated) Gaussian indices by keys
cub::DeviceRadixSort::SortPairs(
cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get()
,
binningState.list_sorting_space
,
sorting_size,
binningState.
sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get()
,
binningState.point_list_keys_unsorted, binningState.point_list_keys
,
point_list_unsorted.data().get(), point_list.data().get()
,
binningState.point_list_unsorted, binningState.point_list
,
num_
need
ed, 0, 32 + bit);
num_
render
ed, 0, 32 + bit);
cudaMemset(
ranges.data().get()
, 0, tile_grid.x * tile_grid.y * sizeof(uint2));
cudaMemset(
imgState.ranges
, 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list
// Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_
needed + 255) / 256, 256 >> > (
identifyTileRanges << <(num_
rendered + 255) / 256, 256 >> > (
num_
needed,
num_
rendered,
point_list_keys.data().get(),
binningState.point_list_keys,
ranges.data().get()
imgState.ranges
);
);
// Let each tile blend its range of Gaussians independently in parallel
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
rgb.data().get()
;
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
geomState.rgb
;
FORWARD::render(
FORWARD::render(
tile_grid, block,
tile_grid, block,
ranges.data().get()
,
imgState.ranges
,
point_list.data().get()
,
binningState.point_list
,
width, height,
width, height,
means2D.data().get()
,
geomState.means2D
,
feature_ptr,
feature_ptr,
conic_opacity.data().get()
,
geomState.conic_opacity
,
accum_alpha.data().get()
,
imgState.accum_alpha
,
n_contrib.data().get()
,
imgState.n_contrib
,
background,
background,
out_color);
out_color);
return num_rendered;
}
}
// Produce necessary gradients for optimization, corresponding
// Produce necessary gradients for optimization, corresponding
// to forward render pass
// to forward render pass
void CudaRasterizer::Rasterizer
Impl
::backward(
void CudaRasterizer::Rasterizer::backward(
const int P, int D, int M,
const int P, int D, int M,
int R,
const float* background,
const float* background,
const int width, int height,
const int width, int height,
const float* means3D,
const float* means3D,
...
@@ -334,6 +341,9 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -334,6 +341,9 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* campos,
const float* campos,
const float tan_fovx, float tan_fovy,
const float tan_fovx, float tan_fovy,
const int* radii,
const int* radii,
char* geom_buffer,
char* binning_buffer,
char* img_buffer,
const float* dL_dpix,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dconic,
...
@@ -345,9 +355,13 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -345,9 +355,13 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale,
float* dL_dscale,
float* dL_drot)
float* dL_drot)
{
{
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
if (radii == nullptr)
if (radii == nullptr)
{
{
radii =
internal_radii.data().get()
;
radii =
geomState.internal_radii
;
}
}
const float focal_y = height / (2.0f * tan_fovy);
const float focal_y = height / (2.0f * tan_fovy);
...
@@ -359,19 +373,19 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -359,19 +373,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients.
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
rgb.data().get()
;
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
geomState.rgb
;
BACKWARD::render(
BACKWARD::render(
tile_grid,
tile_grid,
block,
block,
ranges.data().get()
,
imgState.ranges
,
point_list.data().get()
,
binningState.point_list
,
width, height,
width, height,
background,
background,
means2D.data().get()
,
geomState.means2D
,
conic_opacity.data().get()
,
geomState.conic_opacity
,
color_ptr,
color_ptr,
accum_alpha.data().get()
,
imgState.accum_alpha
,
n_contrib.data().get()
,
imgState.n_contrib
,
dL_dpix,
dL_dpix,
(float3*)dL_dmean2D,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
(float4*)dL_dconic,
...
@@ -381,12 +395,12 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -381,12 +395,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
cov3D.data().get()
;
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
geomState.cov3D
;
BACKWARD::preprocess(P, D, M,
BACKWARD::preprocess(P, D, M,
(float3*)means3D,
(float3*)means3D,
radii,
radii,
shs,
shs,
clamped.data().get()
,
geomState.clamped
,
(glm::vec3*)scales,
(glm::vec3*)scales,
(glm::vec4*)rotations,
(glm::vec4*)rotations,
scale_modifier,
scale_modifier,
...
@@ -404,7 +418,3 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -404,7 +418,3 @@ void CudaRasterizer::RasterizerImpl::backward(
(glm::vec3*)dL_dscale,
(glm::vec3*)dL_dscale,
(glm::vec4*)dL_drot);
(glm::vec4*)dL_drot);
}
}
\ No newline at end of file
CudaRasterizer::RasterizerImpl::~RasterizerImpl()
{
}
\ No newline at end of file
cuda_rasterizer/rasterizer_impl.h
View file @
79cbd71d
...
@@ -4,101 +4,60 @@
...
@@ -4,101 +4,60 @@
#include <vector>
#include <vector>
#include "rasterizer.h"
#include "rasterizer.h"
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
namespace
CudaRasterizer
namespace
CudaRasterizer
{
{
class
RasterizerImpl
:
public
Rasterizer
template
<
typename
T
>
static
void
obtain
(
char
*&
chunk
,
T
*&
ptr
,
std
::
size_t
count
,
std
::
size_t
alignment
)
{
{
private
:
std
::
size_t
offset
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
chunk
)
+
alignment
-
1
)
&
~
(
alignment
-
1
);
int
maxP
=
0
;
ptr
=
reinterpret_cast
<
T
*>
(
offset
)
;
int
maxPixels
=
0
;
chunk
=
reinterpret_cast
<
char
*>
(
ptr
+
count
)
;
int
resizeMultiplier
=
2
;
}
// Initial aux structs
struct
GeometryState
size_t
sorting_size
;
{
size_t
list_sorting_size
;
size_t
scan_size
;
size_t
scan_size
;
thrust
::
device_vector
<
float
>
depths
;
float
*
depths
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
char
*
scanning_space
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
bool
*
clamped
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
int
*
internal_radii
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
float2
*
means2D
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
float
*
cov3D
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
float4
*
conic_opacity
;
thrust
::
device_vector
<
char
>
scanning_space
;
float
*
rgb
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
uint32_t
*
point_offsets
;
thrust
::
device_vector
<
bool
>
clamped
;
uint32_t
*
tiles_touched
;
thrust
::
device_vector
<
int
>
internal_radii
;
static
GeometryState
fromChunk
(
char
*&
chunk
,
int
P
);
// Internal state kept across forward / backward
};
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
rgb
;
public
:
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
float
*
viewmatrix
,
float
*
projmatrix
,
bool
*
present
)
override
;
virtual
void
forward
(
struct
ImageState
const
int
P
,
int
D
,
int
M
,
{
const
float
*
background
,
uint2
*
ranges
;
const
int
width
,
int
height
,
uint32_t
*
n_contrib
;
const
float
*
means3D
,
float
*
accum_alpha
;
const
float
*
shs
,
const
float
*
colors_precomp
,
const
float
*
opacities
,
const
float
*
scales
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
)
override
;
virtual
void
backward
(
static
ImageState
fromChunk
(
char
*&
chunk
,
int
N
);
const
int
P
,
int
D
,
int
M
,
};
const
float
*
background
,
const
int
width
,
int
height
,
const
float
*
means3D
,
const
float
*
shs
,
const
float
*
colors_precomp
,
const
float
*
scales
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dopacity
,
float
*
dL_dcolor
,
float
*
dL_dmean3D
,
float
*
dL_dcov3D
,
float
*
dL_dsh
,
float
*
dL_dscale
,
float
*
dL_drot
)
override
;
RasterizerImpl
(
int
resizeMultiplier
);
struct
BinningState
{
size_t
sorting_size
;
uint64_t
*
point_list_keys_unsorted
;
uint64_t
*
point_list_keys
;
uint32_t
*
point_list_unsorted
;
uint32_t
*
point_list
;
char
*
list_sorting_space
;
virtual
~
RasterizerImpl
()
override
;
static
BinningState
fromChunk
(
char
*&
chunk
,
int
P
)
;
};
};
// Computes how many bytes a state type T needs for P elements.
//
// T must provide a static fromChunk(char*&, int). The layout is run on a null
// cursor, so the cursor's final value is the number of bytes consumed.
// 128 extra bytes are added on top (presumably headroom so the real buffer's
// base address can be rounded up to the layout's alignment - confirm).
template<typename T>
int required(int P)
{
	char* size = nullptr;
	T::fromChunk(size, P);
	// Go through uintptr_t: casting a pointer directly to int is ill-formed
	// where pointers are wider than int and is rejected by GCC/Clang.
	return static_cast<int>(reinterpret_cast<std::uintptr_t>(size)) + 128;
}
};
};
\ No newline at end of file
diff_gaussian_rasterization/rasterizer.py
View file @
79cbd71d
...
@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function):
...
@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function):
)
)
# Invoke C++/CUDA rasterizer
# Invoke C++/CUDA rasterizer
color
,
radii
=
_C
.
rasterize_gaussians
(
*
args
)
num_rendered
,
color
,
radii
,
geomBuffer
,
binningBuffer
,
imgBuffer
=
_C
.
rasterize_gaussians
(
*
args
)
# Keep relevant tensors for backward
# Keep relevant tensors for backward
ctx
.
raster_settings
=
raster_settings
ctx
.
raster_settings
=
raster_settings
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
)
ctx
.
num_rendered
=
num_rendered
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
,
geomBuffer
,
binningBuffer
,
imgBuffer
)
return
color
,
radii
return
color
,
radii
@staticmethod
@staticmethod
def
backward
(
ctx
,
grad_out_color
,
_
):
def
backward
(
ctx
,
grad_out_color
,
_
):
# Restore necessary values from context
# Restore necessary values from context
num_rendered
=
ctx
.
num_rendered
raster_settings
=
ctx
.
raster_settings
raster_settings
=
ctx
.
raster_settings
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
=
ctx
.
saved_tensors
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
,
geomBuffer
,
binningBuffer
,
imgBuffer
=
ctx
.
saved_tensors
# Restructure args as C++ method expects them
# Restructure args as C++ method expects them
args
=
(
raster_settings
.
bg
,
args
=
(
raster_settings
.
bg
,
...
@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function):
...
@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_out_color
,
grad_out_color
,
sh
,
sh
,
raster_settings
.
sh_degree
,
raster_settings
.
sh_degree
,
raster_settings
.
campos
)
raster_settings
.
campos
,
geomBuffer
,
num_rendered
,
binningBuffer
,
imgBuffer
)
# Compute gradients for relevant tensors by invoking backward method
# Compute gradients for relevant tensors by invoking backward method
grad_means2D
,
grad_colors_precomp
,
grad_opacities
,
grad_means3D
,
grad_cov3Ds_precomp
,
grad_sh
,
grad_scales
,
grad_rotations
=
_C
.
rasterize_gaussians_backward
(
*
args
)
grad_means2D
,
grad_colors_precomp
,
grad_opacities
,
grad_means3D
,
grad_cov3Ds_precomp
,
grad_sh
,
grad_scales
,
grad_rotations
=
_C
.
rasterize_gaussians_backward
(
*
args
)
...
...
rasterize_points.cu
View file @
79cbd71d
...
@@ -13,10 +13,17 @@
...
@@ -13,10 +13,17 @@
#include "cuda_rasterizer/rasterizer.h"
#include "cuda_rasterizer/rasterizer.h"
#include <fstream>
#include <fstream>
#include <string>
#include <string>
#include <functional>
// Builds a callback that resizes tensor `t` to N elements and returns a raw
// byte pointer to its data.
//
// The callback captures `t` by reference, so the tensor must outlive every
// call made through the returned functional.
std::function<char*(int N)> resizeFunctional(torch::Tensor& t) {
    return [&t](int N) -> char* {
        t.resize_({N});
        return reinterpret_cast<char*>(t.contiguous().data_ptr());
    };
}
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& means3D,
...
@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA(
...
@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA(
const torch::Tensor& campos,
const torch::Tensor& campos,
const bool prefiltered)
const bool prefiltered)
{
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)");
AT_ERROR("means3D must have dimensions (num_points, 3)");
}
}
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int P = means3D.size(0);
const int N = 1; // batch size hard-coded
const int N = 1; // batch size hard-coded
const int H = image_height;
const int H = image_height;
...
@@ -58,6 +59,16 @@ RasterizeGaussiansCUDA(
...
@@ -58,6 +59,16 @@ RasterizeGaussiansCUDA(
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Device device(torch::kCUDA);
torch::TensorOptions options(torch::kByte);
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
std::function<char*(int)> geomFunc = resizeFunctional(geomBuffer);
std::function<char*(int)> binningFunc = resizeFunctional(binningBuffer);
std::function<char*(int)> imgFunc = resizeFunctional(imgBuffer);
int rendered = 0;
if(P != 0)
if(P != 0)
{
{
int M = 0;
int M = 0;
...
@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA(
...
@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA(
M = sh.size(1);
M = sh.size(1);
}
}
cudaRenderer->forward(P, degree, M,
rendered = CudaRasterizer::Rasterizer::forward(
geomFunc,
binningFunc,
imgFunc,
P, degree, M,
background.contiguous().data<float>(),
background.contiguous().data<float>(),
W, H,
W, H,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
...
@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA(
...
@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA(
out_color.contiguous().data<float>(),
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
radii.contiguous().data<int>());
}
}
return std::make_tuple(
out_color, radii
);
return std::make_tuple(
rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer
);
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
...
@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& dL_dout_color,
const torch::Tensor& dL_dout_color,
const torch::Tensor& sh,
const torch::Tensor& sh,
const int degree,
const int degree,
const torch::Tensor& campos)
const torch::Tensor& campos,
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer)
{
{
const int P = means3D.size(0);
const int P = means3D.size(0);
const int H = dL_dout_color.size(2);
const int H = dL_dout_color.size(2);
...
@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
if(P != 0)
{
{
cudaRenderer->backward(P, degree, M
,
CudaRasterizer::Rasterizer::backward(P, degree, M, R
,
background.contiguous().data<float>(),
background.contiguous().data<float>(),
W, H,
W, H,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
...
@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
tan_fovx,
tan_fovx,
tan_fovy,
tan_fovy,
radii.contiguous().data<int>(),
radii.contiguous().data<int>(),
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
dL_dout_color.contiguous().data<float>(),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
...
@@ -166,18 +188,13 @@ torch::Tensor markVisible(
...
@@ -166,18 +188,13 @@ torch::Tensor markVisible(
torch::Tensor& viewmatrix,
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix)
torch::Tensor& projmatrix)
{
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int P = means3D.size(0);
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
if(P != 0)
if(P != 0)
{
{
cudaRenderer->
markVisible(P,
CudaRasterizer::Rasterizer::
markVisible(P,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
...
...
rasterize_points.h
View file @
79cbd71d
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <tuple>
#include <tuple>
#include <string>
#include <string>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
int
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansCUDA
(
RasterizeGaussiansCUDA
(
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
means3D
,
...
@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
sh
,
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
int
degree
,
const
torch
::
Tensor
&
campos
);
const
torch
::
Tensor
&
campos
,
const
torch
::
Tensor
&
geomBuffer
,
const
int
R
,
const
torch
::
Tensor
&
binningBuffer
,
const
torch
::
Tensor
&
imageBuffer
);
torch
::
Tensor
markVisible
(
torch
::
Tensor
markVisible
(
torch
::
Tensor
&
means3D
,
torch
::
Tensor
&
means3D
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment