Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
diff-gaussian-rasterization
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Alan de Oliveira
diff-gaussian-rasterization
Commits
bccbb2e8
Commit
bccbb2e8
authored
Jun 19, 2023
by
Bernhard Kerbl
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
InternalState explicit now
parent
ffee75d5
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
171 additions
and
110 deletions
+171
-110
rasterizer.h
cuda_rasterizer/rasterizer.h
+12
-4
rasterizer_impl.cu
cuda_rasterizer/rasterizer_impl.cu
+81
-72
rasterizer_impl.h
cuda_rasterizer/rasterizer_impl.h
+40
-26
ext.cpp
ext.cpp
+5
-1
rasterize_points.cu
rasterize_points.cu
+25
-5
rasterize_points.h
rasterize_points.h
+8
-2
No files found.
cuda_rasterizer/rasterizer.h
View file @
bccbb2e8
...
@@ -5,10 +5,16 @@
...
@@ -5,10 +5,16 @@
namespace
CudaRasterizer
namespace
CudaRasterizer
{
{
struct
InternalState
;
class
Rasterizer
class
Rasterizer
{
{
public
:
public
:
virtual
InternalState
*
createInternalState
()
=
0
;
virtual
void
killInternalState
(
InternalState
*
)
=
0
;
virtual
void
markVisible
(
virtual
void
markVisible
(
int
P
,
int
P
,
float
*
means3D
,
float
*
means3D
,
...
@@ -33,10 +39,13 @@ namespace CudaRasterizer
...
@@ -33,10 +39,13 @@ namespace CudaRasterizer
const
float
*
cam_pos
,
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
,
int
*
radii
=
nullptr
)
=
0
;
InternalState
*
state
,
float
*
out_color
)
=
0
;
virtual
void
backward
(
virtual
void
backward
(
const
int
*
radii
,
const
InternalState
*
state
,
const
int
P
,
int
D
,
int
M
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
float
*
background
,
const
int
width
,
int
height
,
const
int
width
,
int
height
,
...
@@ -47,11 +56,10 @@ namespace CudaRasterizer
...
@@ -47,11 +56,10 @@ namespace CudaRasterizer
const
float
scale_modifier
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dconic
,
...
...
cuda_rasterizer/rasterizer_impl.cu
View file @
bccbb2e8
...
@@ -171,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -171,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
const float* cam_pos,
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
const bool prefiltered,
float* out_color,
int* radii,
int* radii)
InternalState* state,
float* out_color)
{
{
const float focal_y = height / (2.0f * tan_fovy);
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training
// Dynamically resize auxiliary buffers during training
if (P > maxP)
if (P > state->maxP)
{
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
if (radii == nullptr)
{
{
internal_radii.resize(maxP);
state->maxP = resizeMultiplier * P;
radii = internal_radii.data().get();
state->cov3D.resize(state->maxP * 6);
state->rgb.resize(state->maxP * 3);
state->tiles_touched.resize(state->maxP);
state->point_offsets.resize(state->maxP);
state->clamped.resize(3 * state->maxP);
state->depths.resize(state->maxP);
state->means2D.resize(state->maxP);
state->conic_opacity.resize(state->maxP);
size_t scan_size;
cub::DeviceScan::InclusiveSum(nullptr,
scan_size,
state->tiles_touched.data().get(),
state->tiles_touched.data().get(),
state->maxP);
state->scanning_space.resize(scan_size);
}
}
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training
// Dynamically resize image-based auxiliary buffers during training
if (width * height > maxPixels)
if (width * height >
state->
maxPixels)
{
{
maxPixels = width * height;
state->
maxPixels = width * height;
accum_alpha.resize(
maxPixels);
state->accum_alpha.resize(state->
maxPixels);
n_contrib.resize(
maxPixels);
state->n_contrib.resize(state->
maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
state->
ranges.resize(tile_grid.x * tile_grid.y);
}
}
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
...
@@ -227,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -227,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations,
(glm::vec4*)rotations,
opacities,
opacities,
shs,
shs,
clamped.data().get(),
state->
clamped.data().get(),
cov3D_precomp,
cov3D_precomp,
colors_precomp,
colors_precomp,
viewmatrix, projmatrix,
viewmatrix, projmatrix,
...
@@ -236,45 +236,56 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -236,45 +236,56 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy,
tan_fovx, tan_fovy,
focal_x, focal_y,
focal_x, focal_y,
radii,
radii,
means2D.data().get(),
state->
means2D.data().get(),
depths.data().get(),
state->
depths.data().get(),
cov3D.data().get(),
state->
cov3D.data().get(),
rgb.data().get(),
state->
rgb.data().get(),
conic_opacity.data().get(),
state->
conic_opacity.data().get(),
tile_grid,
tile_grid,
tiles_touched.data().get(),
state->
tiles_touched.data().get(),
prefiltered
prefiltered
);
);
// Compute prefix sum over full list of touched tile counts by Gaussians
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
size_t scanning_space_size = state->scanning_space.size();
tiles_touched.data().get(), point_offsets.data().get(), P);
cub::DeviceScan::InclusiveSum(
state->scanning_space.data().get(),
scanning_space_size,
state->tiles_touched.data().get(),
state->point_offsets.data().get(),
P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed;
int num_needed;
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&num_needed,
state->
point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
if (num_needed > point_list_keys_unsorted.size())
{
{
point_list_keys_unsorted.resize(2 * num_needed);
int resizeNum = resizeMultiplier * num_needed;
point_list_keys.resize(2 * num_needed);
point_list_keys_unsorted.resize(resizeNum);
point_list_unsorted.resize(2 * num_needed);
point_list_keys.resize(resizeNum);
point_list.resize(2 * num_needed);
point_list_unsorted.resize(resizeNum);
size_t sorting_size;
cub::DeviceRadixSort::SortPairs(
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
point_list_unsorted.data().get(),
state->
point_list.data().get(),
2 * num_needed
);
resizeNum
);
list_sorting_space.resize(sorting_size);
list_sorting_space.resize(sorting_size);
}
}
if (num_needed > state->point_list.size())
{
state->point_list.resize(resizeMultiplier * num_needed);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key
// For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding dublicated Gaussian indices to be sorted
// and corresponding dublicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P,
P,
means2D.data().get(),
state->
means2D.data().get(),
depths.data().get(),
state->depths.data().get(),
point_offsets.data().get(),
state->point_offsets.data().get(),
point_list_keys_unsorted.data().get(),
point_list_keys_unsorted.data().get(),
point_list_unsorted.data().get(),
point_list_unsorted.data().get(),
radii,
radii,
...
@@ -284,34 +295,36 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -284,34 +295,36 @@ void CudaRasterizer::RasterizerImpl::forward(
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
// Sort complete list of (duplicated) Gaussian indices by keys
// Sort complete list of (duplicated) Gaussian indices by keys
size_t list_sorting_space_size = list_sorting_space.size();
cub::DeviceRadixSort::SortPairs(
cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get(),
list_sorting_space.data().get(),
sorting
_size,
list_sorting_space
_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
point_list_unsorted.data().get(),
state->point_list.data().get(),
num_needed, 0, 32 + bit);
num_needed, 0, 32 + bit);
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
cudaMemset(
state->
ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list
// Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
num_needed,
num_needed,
point_list_keys.data().get(),
point_list_keys.data().get(),
ranges.data().get()
state->
ranges.data().get()
);
);
// Let each tile blend its range of Gaussians independently in parallel
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
state->
rgb.data().get();
FORWARD::render(
FORWARD::render(
tile_grid, block,
tile_grid, block,
ranges.data().get(),
state->
ranges.data().get(),
point_list.data().get(),
state->
point_list.data().get(),
width, height,
width, height,
means2D.data().get(),
state->
means2D.data().get(),
feature_ptr,
feature_ptr,
conic_opacity.data().get(),
state->
conic_opacity.data().get(),
accum_alpha.data().get(),
state->
accum_alpha.data().get(),
n_contrib.data().get(),
state->
n_contrib.data().get(),
background,
background,
out_color);
out_color);
}
}
...
@@ -319,6 +332,8 @@ void CudaRasterizer::RasterizerImpl::forward(
...
@@ -319,6 +332,8 @@ void CudaRasterizer::RasterizerImpl::forward(
// Produce necessary gradients for optimization, corresponding
// Produce necessary gradients for optimization, corresponding
// to forward render pass
// to forward render pass
void CudaRasterizer::RasterizerImpl::backward(
void CudaRasterizer::RasterizerImpl::backward(
const int* radii,
const InternalState* state,
const int P, int D, int M,
const int P, int D, int M,
const float* background,
const float* background,
const int width, int height,
const int width, int height,
...
@@ -333,7 +348,6 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -333,7 +348,6 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* projmatrix,
const float* projmatrix,
const float* campos,
const float* campos,
const float tan_fovx, float tan_fovy,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dmean2D,
float* dL_dconic,
float* dL_dconic,
...
@@ -345,11 +359,6 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -345,11 +359,6 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale,
float* dL_dscale,
float* dL_drot)
float* dL_drot)
{
{
if (radii == nullptr)
{
radii = internal_radii.data().get();
}
const float focal_y = height / (2.0f * tan_fovy);
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
const float focal_x = width / (2.0f * tan_fovx);
...
@@ -359,19 +368,19 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -359,19 +368,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients.
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
state->
rgb.data().get();
BACKWARD::render(
BACKWARD::render(
tile_grid,
tile_grid,
block,
block,
ranges.data().get(),
state->
ranges.data().get(),
point_list.data().get(),
state->
point_list.data().get(),
width, height,
width, height,
background,
background,
means2D.data().get(),
state->
means2D.data().get(),
conic_opacity.data().get(),
state->
conic_opacity.data().get(),
color_ptr,
color_ptr,
accum_alpha.data().get(),
state->
accum_alpha.data().get(),
n_contrib.data().get(),
state->
n_contrib.data().get(),
dL_dpix,
dL_dpix,
(float3*)dL_dmean2D,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
(float4*)dL_dconic,
...
@@ -381,12 +390,12 @@ void CudaRasterizer::RasterizerImpl::backward(
...
@@ -381,12 +390,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
state->
cov3D.data().get();
BACKWARD::preprocess(P, D, M,
BACKWARD::preprocess(P, D, M,
(float3*)means3D,
(float3*)means3D,
radii,
radii,
shs,
shs,
clamped.data().get(),
state->
clamped.data().get(),
(glm::vec3*)scales,
(glm::vec3*)scales,
(glm::vec4*)rotations,
(glm::vec4*)rotations,
scale_modifier,
scale_modifier,
...
...
cuda_rasterizer/rasterizer_impl.h
View file @
bccbb2e8
...
@@ -8,41 +8,53 @@
...
@@ -8,41 +8,53 @@
namespace
CudaRasterizer
namespace
CudaRasterizer
{
{
class
RasterizerImpl
:
public
Rasterizer
//// Internal state kept across forward / backward
struct
InternalState
{
{
private
:
int
maxP
=
0
;
int
maxP
=
0
;
int
maxPixels
=
0
;
int
maxPixels
=
0
;
int
resizeMultiplier
=
2
;
// Initial aux structs
thrust
::
device_vector
<
uint2
>
ranges
;
size_t
sorting_size
;
thrust
::
device_vector
<
float2
>
means2D
;
size_t
list_sorting_size
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
size_t
scan_size
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
bool
>
clamped
;
thrust
::
device_vector
<
float
>
rgb
;
thrust
::
device_vector
<
float
>
depths
;
thrust
::
device_vector
<
float
>
depths
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
thrust
::
device_vector
<
bool
>
clamped
;
};
thrust
::
device_vector
<
int
>
internal_radii
;
// Internal state kept across forward / backward
// Auxiliary buffer spaces
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
thrust
::
device_vector
<
float2
>
means2D
;
class
RasterizerImpl
:
public
Rasterizer
thrust
::
device_vector
<
float
>
cov3D
;
{
thrust
::
device_vector
<
float4
>
conic_opacity
;
private
:
thrust
::
device_vector
<
float
>
rgb
;
int
resizeMultiplier
=
2
;
public
:
public
:
virtual
InternalState
*
createInternalState
()
override
{
return
new
InternalState
();
}
virtual
void
killInternalState
(
InternalState
*
is
)
override
{
delete
is
;
}
virtual
void
markVisible
(
virtual
void
markVisible
(
int
P
,
int
P
,
float
*
means3D
,
float
*
means3D
,
...
@@ -67,10 +79,13 @@ namespace CudaRasterizer
...
@@ -67,10 +79,13 @@ namespace CudaRasterizer
const
float
*
cam_pos
,
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
,
int
*
radii
)
override
;
InternalState
*
state
,
float
*
out_color
)
override
;
virtual
void
backward
(
virtual
void
backward
(
const
int
*
radii
,
const
InternalState
*
fixedState
,
const
int
P
,
int
D
,
int
M
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
float
*
background
,
const
int
width
,
int
height
,
const
int
width
,
int
height
,
...
@@ -85,7 +100,6 @@ namespace CudaRasterizer
...
@@ -85,7 +100,6 @@ namespace CudaRasterizer
const
float
*
projmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dconic
,
...
...
ext.cpp
View file @
bccbb2e8
#include <torch/extension.h>
#include <torch/extension.h>
#include "rasterize_points.h"
#include "rasterize_points.h"
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"rasterize_gaussians"
,
&
RasterizeGaussiansCUDA
);
m
.
def
(
"rasterize_gaussians"
,
&
RasterizeGaussiansCUDA
);
m
.
def
(
"rasterize_gaussians_backward"
,
&
RasterizeGaussiansBackwardCUDA
);
m
.
def
(
"rasterize_gaussians_backward"
,
&
RasterizeGaussiansBackwardCUDA
);
m
.
def
(
"mark_visible"
,
&
markVisible
);
m
.
def
(
"mark_visible"
,
&
markVisible
);
m
.
def
(
"create_rasterizer_state"
,
&
createRasterizerState
);
m
.
def
(
"delete_rasterizer_state"
,
&
deleteRasterizerState
);
}
}
\ No newline at end of file
rasterize_points.cu
View file @
bccbb2e8
...
@@ -11,11 +11,26 @@
...
@@ -11,11 +11,26 @@
#include <memory>
#include <memory>
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h"
#include "cuda_rasterizer/rasterizer.h"
#include "rasterize_points.h"
#include <fstream>
#include <fstream>
#include <string>
#include <string>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
void* createRasterizerState()
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
return (void*)cudaRenderer->createInternalState();
}
void deleteRasterizerState(void* state)
{
cudaRenderer->killInternalState((CudaRasterizer::InternalState*)state);
}
std::tuple<torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& background,
...
@@ -35,7 +50,8 @@ RasterizeGaussiansCUDA(
...
@@ -35,7 +50,8 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const torch::Tensor& sh,
const int degree,
const int degree,
const torch::Tensor& campos,
const torch::Tensor& campos,
const bool prefiltered)
const bool prefiltered,
void* internalState)
{
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
...
@@ -83,14 +99,16 @@ RasterizeGaussiansCUDA(
...
@@ -83,14 +99,16 @@ RasterizeGaussiansCUDA(
tan_fovx,
tan_fovx,
tan_fovy,
tan_fovy,
prefiltered,
prefiltered,
out_color.contiguous().data<float>(),
radii.contiguous().data<int>(),
radii.contiguous().data<int>());
(CudaRasterizer::InternalState*)internalState,
out_color.contiguous().data<float>());
}
}
return std::make_tuple(out_color, radii);
return std::make_tuple(out_color, radii);
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
RasterizeGaussiansBackwardCUDA(
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& means3D,
const torch::Tensor& radii,
const torch::Tensor& radii,
...
@@ -130,7 +148,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -130,7 +148,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
if(P != 0)
{
{
cudaRenderer->backward(P, degree, M,
cudaRenderer->backward(
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
P, degree, M,
background.contiguous().data<float>(),
background.contiguous().data<float>(),
W, H,
W, H,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
...
@@ -145,7 +166,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -145,7 +166,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
campos.contiguous().data<float>(),
campos.contiguous().data<float>(),
tan_fovx,
tan_fovx,
tan_fovy,
tan_fovy,
radii.contiguous().data<int>(),
dL_dout_color.contiguous().data<float>(),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
...
...
rasterize_points.h
View file @
bccbb2e8
...
@@ -25,10 +25,12 @@ RasterizeGaussiansCUDA(
...
@@ -25,10 +25,12 @@ RasterizeGaussiansCUDA(
const
torch
::
Tensor
&
sh
,
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
int
degree
,
const
torch
::
Tensor
&
campos
,
const
torch
::
Tensor
&
campos
,
const
bool
prefiltered
);
const
bool
prefiltered
,
void
*
internalState
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansBackwardCUDA
(
RasterizeGaussiansBackwardCUDA
(
const
void
*
internalState
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
radii
,
const
torch
::
Tensor
&
radii
,
...
@@ -39,13 +41,17 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -39,13 +41,17 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const
torch
::
Tensor
&
cov3D_precomp
,
const
torch
::
Tensor
&
cov3D_precomp
,
const
torch
::
Tensor
&
viewmatrix
,
const
torch
::
Tensor
&
viewmatrix
,
const
torch
::
Tensor
&
projmatrix
,
const
torch
::
Tensor
&
projmatrix
,
const
float
tan_fovx
,
const
float
tan_fovx
,
const
float
tan_fovy
,
const
float
tan_fovy
,
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
sh
,
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
int
degree
,
const
torch
::
Tensor
&
campos
);
const
torch
::
Tensor
&
campos
);
void
*
createRasterizerState
();
void
deleteRasterizerState
(
void
*
state
);
torch
::
Tensor
markVisible
(
torch
::
Tensor
markVisible
(
torch
::
Tensor
&
means3D
,
torch
::
Tensor
&
means3D
,
torch
::
Tensor
&
viewmatrix
,
torch
::
Tensor
&
viewmatrix
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment