Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
diff-gaussian-rasterization
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Alan de Oliveira
diff-gaussian-rasterization
Commits
79cbd71d
Commit
79cbd71d
authored
Jun 20, 2023
by
Bernhard Kerbl
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
No more persistent state
parent
ffee75d5
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
232 additions
and
233 deletions
+232
-233
rasterizer.h
cuda_rasterizer/rasterizer.h
+15
-12
rasterizer_impl.cu
cuda_rasterizer/rasterizer_impl.cu
+119
-110
rasterizer_impl.h
cuda_rasterizer/rasterizer_impl.h
+46
-86
rasterizer.py
diff_gaussian_rasterization/rasterizer.py
+10
-4
rasterize_points.cu
rasterize_points.cu
+36
-19
rasterize_points.h
rasterize_points.h
+6
-2
No files found.
cuda_rasterizer/rasterizer.h
View file @
79cbd71d
...
...
@@ -2,6 +2,7 @@
#define CUDA_RASTERIZER_H_INCLUDED
#include <vector>
#include <functional>
namespace
CudaRasterizer
{
...
...
@@ -9,14 +10,17 @@ namespace CudaRasterizer
{
public
:
virtual
void
markVisible
(
static
void
markVisible
(
int
P
,
float
*
means3D
,
float
*
viewmatrix
,
float
*
projmatrix
,
bool
*
present
)
=
0
;
bool
*
present
);
virtual
void
forward
(
static
int
forward
(
std
::
function
<
char
*
(
int
)
>
geometryBuffer
,
std
::
function
<
char
*
(
int
)
>
binningBuffer
,
std
::
function
<
char
*
(
int
)
>
imageBuffer
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
...
...
@@ -34,10 +38,10 @@ namespace CudaRasterizer
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
=
nullptr
)
=
0
;
int
*
radii
)
;
virtual
void
backward
(
const
int
P
,
int
D
,
int
M
,
static
void
backward
(
const
int
P
,
int
D
,
int
M
,
int
R
,
const
float
*
background
,
const
int
width
,
int
height
,
const
float
*
means3D
,
...
...
@@ -47,11 +51,14 @@ namespace CudaRasterizer
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
char
*
geom_buffer
,
char
*
binning_buffer
,
char
*
image_buffer
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
...
...
@@ -61,11 +68,7 @@ namespace CudaRasterizer
float
*
dL_dcov3D
,
float
*
dL_dsh
,
float
*
dL_dscale
,
float
*
dL_drot
)
=
0
;
virtual
~
Rasterizer
()
{};
static
Rasterizer
*
make
(
int
resizeMultipliyer
=
2
);
float
*
dL_drot
);
};
};
...
...
cuda_rasterizer/rasterizer_impl.cu
View file @
79cbd71d
...
...
@@ -8,7 +8,6 @@
#include "device_launch_parameters.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <thrust/sequence.h>
#define GLM_FORCE_CUDA
#include <glm/glm.hpp>
...
...
@@ -127,23 +126,13 @@ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* rang
}
}
CudaRasterizer::RasterizerImpl::RasterizerImpl(int resizeMultiplier)
: resizeMultiplier(resizeMultiplier)
{}
// Instantiate rasterizer
CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplier)
{
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
}
// Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::Rasterizer
Impl
::markVisible(
int P,
float* means3D,
float* viewmatrix,
float* projmatrix,
bool* present)
void CudaRasterizer::Rasterizer::markVisible(
int P,
float* means3D,
float* viewmatrix,
float* projmatrix,
bool* present)
{
checkFrustum << <(P + 255) / 256, 256 >> > (
P,
...
...
@@ -152,9 +141,53 @@ void CudaRasterizer::RasterizerImpl::markVisible(
present);
}
// Lays out every per-Gaussian scratch array of a GeometryState inside the
// caller-supplied memory region that starts at `chunk`.
//
// chunk: in/out cursor. Each obtain() call rounds it up to a 128-byte
//        boundary, hands out the array at that address and advances the
//        cursor past the array.
// P:     number of Gaussians the arrays are sized for.
//
// Returns a GeometryState whose pointers refer into the chunk. required<T>()
// invokes this with a null cursor purely to measure how many bytes get
// consumed, so this function must never dereference the pointers it hands out.
CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, int P)
{
	GeometryState geom;
	obtain(chunk, geom.depths, P, 128);
	// Three flags per Gaussian, matching the 3 * P sizing this buffer had before
	// the chunk-based layout; reserving only P entries under-allocates it.
	// NOTE(review): presumably one flag per colour channel - confirm against
	// the preprocess kernels that write and read `clamped`.
	obtain(chunk, geom.clamped, P * 3, 128);
	obtain(chunk, geom.internal_radii, P, 128);
	obtain(chunk, geom.means2D, P, 128);
	obtain(chunk, geom.cov3D, P * 6, 128);
	obtain(chunk, geom.conic_opacity, P, 128);
	obtain(chunk, geom.rgb, P * 3, 128);
	obtain(chunk, geom.tiles_touched, P, 128);
	// With a null temp-storage pointer CUB performs no scan; it only reports in
	// scan_size how much scratch memory the later in-place prefix sum needs.
	cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
	obtain(chunk, geom.scanning_space, geom.scan_size, 128);
	obtain(chunk, geom.point_offsets, P, 128);
	return geom;
}
// Places the image-side scratch arrays of an ImageState inside the memory
// region that starts at `chunk`.
//
// chunk: in/out cursor; advanced past each array as it is handed out.
// N:     element count reserved for each of the three arrays.
//
// The arrays are carved out in a fixed order (accum_alpha, n_contrib, ranges),
// so every call with the same base pointer and N yields the same layout.
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, int N)
{
	constexpr int kAlignment = 128;

	ImageState state;
	obtain(chunk, state.accum_alpha, N, kAlignment);
	obtain(chunk, state.n_contrib, N, kAlignment);
	obtain(chunk, state.ranges, N, kAlignment);
	return state;
}
// Places the key/value lists used for sorting, plus the sort's own scratch
// space, inside the memory region that starts at `chunk`.
//
// chunk: in/out cursor; advanced past each array as it is handed out.
// P:     number of key/value pairs each list must hold.
//
// The lists are carved out in a fixed order, so every call with the same base
// pointer and P yields the same layout.
CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, int P)
{
	constexpr int kAlignment = 128;

	BinningState state;
	obtain(chunk, state.point_list, P, kAlignment);
	obtain(chunk, state.point_list_unsorted, P, kAlignment);
	obtain(chunk, state.point_list_keys, P, kAlignment);
	obtain(chunk, state.point_list_keys_unsorted, P, kAlignment);

	// Size query only: a null temp-storage pointer makes CUB report the scratch
	// bytes needed for sorting P pairs in sorting_size without sorting anything.
	cub::DeviceRadixSort::SortPairs(
		nullptr, state.sorting_size,
		state.point_list_keys_unsorted, state.point_list_keys,
		state.point_list_unsorted, state.point_list, P);

	obtain(chunk, state.list_sorting_space, state.sorting_size, kAlignment);
	return state;
}
// Forward rendering procedure for differentiable rasterization
// of Gaussians.
void CudaRasterizer::RasterizerImpl::forward(
int CudaRasterizer::Rasterizer::forward(
std::function<char*(int)> geometryBuffer,
std::function<char* (int)> binningBuffer,
std::function<char* (int)> imageBuffer,
const int P, int D, int M,
const float* background,
const int width, int height,
...
...
@@ -177,41 +210,22 @@ void CudaRasterizer::RasterizerImpl::forward(
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training
if (P > maxP)
{
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
int chunk_size = required<GeometryState>(P);
char* chunkptr = geometryBuffer(chunk_size);
GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
if (radii == nullptr)
{
internal_radii.resize(maxP);
radii = internal_radii.data().get();
radii = geomState.internal_radii;
}
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training
if (width * height > maxPixels)
{
maxPixels = width * height;
accum_alpha.resize(maxPixels);
n_contrib.resize(maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
}
int img_chunk_size = required<ImageState>(width * height);
char* img_chunkptr = imageBuffer(img_chunk_size);
ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
{
...
...
@@ -227,7 +241,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations,
opacities,
shs,
clamped.data().get()
,
geomState.clamped
,
cov3D_precomp,
colors_precomp,
viewmatrix, projmatrix,
...
...
@@ -236,47 +250,38 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy,
focal_x, focal_y,
radii,
means2D.data().get()
,
depths.data().get()
,
cov3D.data().get()
,
rgb.data().get()
,
conic_opacity.data().get()
,
geomState.means2D
,
geomState.depths
,
geomState.cov3D
,
geomState.rgb
,
geomState.conic_opacity
,
tile_grid,
tiles_touched.data().get()
,
geomState.tiles_touched
,
prefiltered
);
);
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(
scanning_space.data().get(),
scan_size,
tiles_touched.data().get(), point_offsets.data().get()
, P);
cub::DeviceScan::InclusiveSum(
geomState.scanning_space, geomState.
scan_size,
geomState.tiles_touched, geomState.point_offsets
, P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed;
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
{
point_list_keys_unsorted.resize(2 * num_needed);
point_list_keys.resize(2 * num_needed);
point_list_unsorted.resize(2 * num_needed);
point_list.resize(2 * num_needed);
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
2 * num_needed);
list_sorting_space.resize(sorting_size);
}
int num_rendered;
cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
int binning_chunk_size = required<BinningState>(num_rendered);
char* binning_chunkptr = binningBuffer(binning_chunk_size);
BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
// For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding duplicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P,
means2D.data().get()
,
depths.data().get(),
point_offsets.data().get(),
point_list_keys_unsorted.data().get(),
point_list_unsorted.data().get(),
P,
geomState.means2D
,
geomState.depths,
geomState.point_offsets,
binningState.point_list_keys_unsorted,
binningState.point_list_unsorted,
radii,
tile_grid
);
...
...
@@ -285,41 +290,43 @@ void CudaRasterizer::RasterizerImpl::forward(
// Sort complete list of (duplicated) Gaussian indices by keys
cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get()
,
sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get()
,
point_list_unsorted.data().get(), point_list.data().get()
,
num_
need
ed, 0, 32 + bit);
binningState.list_sorting_space
,
binningState.
sorting_size,
binningState.point_list_keys_unsorted, binningState.point_list_keys
,
binningState.point_list_unsorted, binningState.point_list
,
num_
render
ed, 0, 32 + bit);
cudaMemset(
ranges.data().get()
, 0, tile_grid.x * tile_grid.y * sizeof(uint2));
cudaMemset(
imgState.ranges
, 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_
needed + 255) / 256, 256 >> > (
num_
needed,
point_list_keys.data().get(),
ranges.data().get()
identifyTileRanges << <(num_
rendered + 255) / 256, 256 >> > (
num_
rendered,
binningState.point_list_keys,
imgState.ranges
);
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
rgb.data().get()
;
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
geomState.rgb
;
FORWARD::render(
tile_grid, block,
ranges.data().get()
,
point_list.data().get()
,
imgState.ranges
,
binningState.point_list
,
width, height,
means2D.data().get()
,
geomState.means2D
,
feature_ptr,
conic_opacity.data().get()
,
accum_alpha.data().get()
,
n_contrib.data().get()
,
geomState.conic_opacity
,
imgState.accum_alpha
,
imgState.n_contrib
,
background,
out_color);
return num_rendered;
}
// Produce necessary gradients for optimization, corresponding
// to forward render pass
void CudaRasterizer::Rasterizer
Impl
::backward(
const int P, int D, int M,
void CudaRasterizer::Rasterizer::backward(
const int P, int D, int M,
int R,
const float* background,
const int width, int height,
const float* means3D,
...
...
@@ -334,6 +341,9 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
char* geom_buffer,
char* binning_buffer,
char* img_buffer,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
...
...
@@ -345,9 +355,13 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale,
float* dL_drot)
{
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
if (radii == nullptr)
{
radii =
internal_radii.data().get()
;
radii =
geomState.internal_radii
;
}
const float focal_y = height / (2.0f * tan_fovy);
...
...
@@ -359,19 +373,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
rgb.data().get()
;
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
geomState.rgb
;
BACKWARD::render(
tile_grid,
block,
ranges.data().get()
,
point_list.data().get()
,
imgState.ranges
,
binningState.point_list
,
width, height,
background,
means2D.data().get()
,
conic_opacity.data().get()
,
geomState.means2D
,
geomState.conic_opacity
,
color_ptr,
accum_alpha.data().get()
,
n_contrib.data().get()
,
imgState.accum_alpha
,
imgState.n_contrib
,
dL_dpix,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
...
...
@@ -381,12 +395,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
cov3D.data().get()
;
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
geomState.cov3D
;
BACKWARD::preprocess(P, D, M,
(float3*)means3D,
(float3*)means3D,
radii,
shs,
clamped.data().get()
,
geomState.clamped
,
(glm::vec3*)scales,
(glm::vec4*)rotations,
scale_modifier,
...
...
@@ -403,8 +417,4 @@ void CudaRasterizer::RasterizerImpl::backward(
dL_dsh,
(glm::vec3*)dL_dscale,
(glm::vec4*)dL_drot);
}
CudaRasterizer::RasterizerImpl::~RasterizerImpl()
{
}
\ No newline at end of file
cuda_rasterizer/rasterizer_impl.h
View file @
79cbd71d
...
...
@@ -4,101 +4,60 @@
#include <vector>
#include "rasterizer.h"
#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
namespace
CudaRasterizer
{
class
RasterizerImpl
:
public
Rasterizer
template
<
typename
T
>
static
void
obtain
(
char
*&
chunk
,
T
*&
ptr
,
std
::
size_t
count
,
std
::
size_t
alignment
)
{
private
:
int
maxP
=
0
;
int
maxPixels
=
0
;
int
resizeMultiplier
=
2
;
std
::
size_t
offset
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
chunk
)
+
alignment
-
1
)
&
~
(
alignment
-
1
);
ptr
=
reinterpret_cast
<
T
*>
(
offset
)
;
chunk
=
reinterpret_cast
<
char
*>
(
ptr
+
count
)
;
}
// Initial aux structs
size_t
sorting_size
;
size_t
list_sorting_size
;
struct
GeometryState
{
size_t
scan_size
;
thrust
::
device_vector
<
float
>
depths
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
thrust
::
device_vector
<
bool
>
clamped
;
thrust
::
device_vector
<
int
>
internal_radii
;
// Internal state kept across forward / backward
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
rgb
;
public
:
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
float
*
viewmatrix
,
float
*
projmatrix
,
bool
*
present
)
override
;
float
*
depths
;
char
*
scanning_space
;
bool
*
clamped
;
int
*
internal_radii
;
float2
*
means2D
;
float
*
cov3D
;
float4
*
conic_opacity
;
float
*
rgb
;
uint32_t
*
point_offsets
;
uint32_t
*
tiles_touched
;
static
GeometryState
fromChunk
(
char
*&
chunk
,
int
P
);
};
virtual
void
forward
(
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
const
float
*
means3D
,
const
float
*
shs
,
const
float
*
colors_precomp
,
const
float
*
opacities
,
const
float
*
scales
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
)
override
;
struct
ImageState
{
uint2
*
ranges
;
uint32_t
*
n_contrib
;
float
*
accum_alpha
;
virtual
void
backward
(
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
const
float
*
means3D
,
const
float
*
shs
,
const
float
*
colors_precomp
,
const
float
*
scales
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dopacity
,
float
*
dL_dcolor
,
float
*
dL_dmean3D
,
float
*
dL_dcov3D
,
float
*
dL_dsh
,
float
*
dL_dscale
,
float
*
dL_drot
)
override
;
static
ImageState
fromChunk
(
char
*&
chunk
,
int
N
);
};
RasterizerImpl
(
int
resizeMultiplier
);
struct
BinningState
{
size_t
sorting_size
;
uint64_t
*
point_list_keys_unsorted
;
uint64_t
*
point_list_keys
;
uint32_t
*
point_list_unsorted
;
uint32_t
*
point_list
;
char
*
list_sorting_space
;
virtual
~
RasterizerImpl
()
override
;
static
BinningState
fromChunk
(
char
*&
chunk
,
int
P
)
;
};
// Returns the number of bytes a state of type T needs for P elements.
//
// T must provide `static T fromChunk(char*& chunk, int P)`. The layout routine
// is run on a null cursor, so the address the cursor ends up at equals the
// number of bytes consumed. 128 extra bytes are added on top, which leaves room
// for the real buffer's base address to be rounded up to a 128-byte boundary.
//
// The result is narrowed to int, so layouts larger than INT_MAX bytes are not
// representable through this interface.
template <typename T>
int required(int P)
{
	char* size = nullptr;
	T::fromChunk(size, P);
	// Go through std::uintptr_t: casting a pointer straight to int is ill-formed
	// on platforms where pointers are wider than int.
	return static_cast<int>(reinterpret_cast<std::uintptr_t>(size)) + 128;
}
};
\ No newline at end of file
diff_gaussian_rasterization/rasterizer.py
View file @
79cbd71d
...
...
@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function):
)
# Invoke C++/CUDA rasterizer
color
,
radii
=
_C
.
rasterize_gaussians
(
*
args
)
num_rendered
,
color
,
radii
,
geomBuffer
,
binningBuffer
,
imgBuffer
=
_C
.
rasterize_gaussians
(
*
args
)
# Keep relevant tensors for backward
ctx
.
raster_settings
=
raster_settings
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
)
ctx
.
num_rendered
=
num_rendered
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
,
geomBuffer
,
binningBuffer
,
imgBuffer
)
return
color
,
radii
@staticmethod
def
backward
(
ctx
,
grad_out_color
,
_
):
# Restore necessary values from context
num_rendered
=
ctx
.
num_rendered
raster_settings
=
ctx
.
raster_settings
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
=
ctx
.
saved_tensors
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
,
geomBuffer
,
binningBuffer
,
imgBuffer
=
ctx
.
saved_tensors
# Restructure args as C++ method expects them
args
=
(
raster_settings
.
bg
,
...
...
@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_out_color
,
sh
,
raster_settings
.
sh_degree
,
raster_settings
.
campos
)
raster_settings
.
campos
,
geomBuffer
,
num_rendered
,
binningBuffer
,
imgBuffer
)
# Compute gradients for relevant tensors by invoking backward method
grad_means2D
,
grad_colors_precomp
,
grad_opacities
,
grad_means3D
,
grad_cov3Ds_precomp
,
grad_sh
,
grad_scales
,
grad_rotations
=
_C
.
rasterize_gaussians_backward
(
*
args
)
...
...
rasterize_points.cu
View file @
79cbd71d
...
...
@@ -13,10 +13,17 @@
#include "cuda_rasterizer/rasterizer.h"
#include <fstream>
#include <string>
#include <functional>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
// Wraps a tensor in a callback that resizes it on demand.
//
// The returned callable resizes `t` in place to N elements and returns the
// data pointer of its contiguous form as char*.
//
// The callable holds a reference to `t`, so the tensor must outlive every
// invocation of the callback.
std::function<char*(int N)> resizeFunctional(torch::Tensor& t) {
	return [&t](int N) -> char* {
		t.resize_({N});
		return reinterpret_cast<char*>(t.contiguous().data_ptr());
	};
}
std::tuple<torch::Tensor, torch::Tensor>
std::tuple<
int, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& means3D,
...
...
@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA(
const torch::Tensor& campos,
const bool prefiltered)
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)");
}
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int N = 1; // batch size hard-coded
const int H = image_height;
...
...
@@ -57,7 +58,17 @@ RasterizeGaussiansCUDA(
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Device device(torch::kCUDA);
torch::TensorOptions options(torch::kByte);
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
std::function<char*(int)> geomFunc = resizeFunctional(geomBuffer);
std::function<char*(int)> binningFunc = resizeFunctional(binningBuffer);
std::function<char*(int)> imgFunc = resizeFunctional(imgBuffer);
int rendered = 0;
if(P != 0)
{
int M = 0;
...
...
@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA(
M = sh.size(1);
}
cudaRenderer->forward(P, degree, M,
rendered = CudaRasterizer::Rasterizer::forward(
geomFunc,
binningFunc,
imgFunc,
P, degree, M,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
...
...
@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA(
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
}
return std::make_tuple(
out_color, radii
);
return std::make_tuple(
rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer
);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
...
...
@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& dL_dout_color,
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos)
const torch::Tensor& campos,
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer)
{
const int P = means3D.size(0);
const int H = dL_dout_color.size(2);
...
...
@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
{
cudaRenderer->backward(P, degree, M
,
CudaRasterizer::Rasterizer::backward(P, degree, M, R
,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
...
...
@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
tan_fovx,
tan_fovy,
radii.contiguous().data<int>(),
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
...
...
@@ -166,18 +188,13 @@ torch::Tensor markVisible(
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix)
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
if(P != 0)
{
cudaRenderer->
markVisible(P,
CudaRasterizer::Rasterizer::
markVisible(P,
means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
...
...
rasterize_points.h
View file @
79cbd71d
...
...
@@ -6,7 +6,7 @@
#include <tuple>
#include <string>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
int
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansCUDA
(
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
...
...
@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
torch
::
Tensor
&
campos
);
const
torch
::
Tensor
&
campos
,
const
torch
::
Tensor
&
geomBuffer
,
const
int
R
,
const
torch
::
Tensor
&
binningBuffer
,
const
torch
::
Tensor
&
imageBuffer
);
torch
::
Tensor
markVisible
(
torch
::
Tensor
&
means3D
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment