Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
diff-gaussian-rasterization
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Alan de Oliveira
diff-gaussian-rasterization
Commits
55b0c1b0
Commit
55b0c1b0
authored
Jun 20, 2023
by
Bernhard Kerbl
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
This was a bad idea, undoing it
parent
e41a365f
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
126 additions
and
226 deletions
+126
-226
rasterizer.h
cuda_rasterizer/rasterizer.h
+3
-13
rasterizer_impl.cu
cuda_rasterizer/rasterizer_impl.cu
+72
-86
rasterizer_impl.h
cuda_rasterizer/rasterizer_impl.h
+23
-37
rasterizer.py
diff_gaussian_rasterization/rasterizer.py
+9
-33
ext.cpp
ext.cpp
+1
-7
rasterize_points.cu
rasterize_points.cu
+17
-35
rasterize_points.h
rasterize_points.h
+1
-15
No files found.
cuda_rasterizer/rasterizer.h
View file @
55b0c1b0
...
...
@@ -5,16 +5,10 @@
namespace
CudaRasterizer
{
struct
InternalState
;
class
Rasterizer
{
public
:
virtual
InternalState
*
createInternalState
()
=
0
;
virtual
void
killInternalState
(
InternalState
*
)
=
0
;
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
...
...
@@ -39,13 +33,10 @@ namespace CudaRasterizer
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
int
*
radii
,
InternalState
*
state
,
float
*
out_color
)
=
0
;
float
*
out_color
,
int
*
radii
=
nullptr
)
=
0
;
virtual
void
backward
(
const
int
*
radii
,
const
InternalState
*
state
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
...
...
@@ -60,6 +51,7 @@ namespace CudaRasterizer
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
...
...
@@ -74,8 +66,6 @@ namespace CudaRasterizer
virtual
~
Rasterizer
()
{};
static
Rasterizer
*
make
(
int
resizeMultipliyer
=
2
);
static
void
kill
(
Rasterizer
*
rasterizer
);
};
};
...
...
cuda_rasterizer/rasterizer_impl.cu
View file @
55b0c1b0
...
...
@@ -137,11 +137,6 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
}
void CudaRasterizer::Rasterizer::kill(Rasterizer* rasterizer)
{
delete rasterizer;
}
// Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::RasterizerImpl::markVisible(
int P,
...
...
@@ -176,46 +171,46 @@ void CudaRasterizer::RasterizerImpl::forward(
const float* cam_pos,
const float tan_fovx, float tan_fovy,
const bool prefiltered,
int* radii,
InternalState* state,
float* out_color)
float* out_color,
int* radii)
{
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
// Dynamically resize auxiliary buffers during training
if (P >
state->
maxP)
if (P > maxP)
{
state->maxP = resizeMultiplier * P;
state->cov3D.resize(state->maxP * 6);
state->rgb.resize(state->maxP * 3);
state->tiles_touched.resize(state->maxP);
state->point_offsets.resize(state->maxP);
state->clamped.resize(3 * state->maxP);
state->depths.resize(state->maxP);
state->means2D.resize(state->maxP);
state->conic_opacity.resize(state->maxP);
size_t scan_size;
cub::DeviceScan::InclusiveSum(nullptr,
scan_size,
state->tiles_touched.data().get(),
state->tiles_touched.data().get(),
state->maxP);
state->scanning_space.resize(scan_size);
maxP = resizeMultiplier * P;
cov3D.resize(maxP * 6);
rgb.resize(maxP * 3);
tiles_touched.resize(maxP);
point_offsets.resize(maxP);
clamped.resize(3 * maxP);
depths.resize(maxP);
means2D.resize(maxP);
conic_opacity.resize(maxP);
cub::DeviceScan::InclusiveSum(nullptr, scan_size, tiles_touched.data().get(), tiles_touched.data().get(), maxP);
scanning_space.resize(scan_size);
}
if (radii == nullptr)
{
internal_radii.resize(maxP);
radii = internal_radii.data().get();
}
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
dim3 block(BLOCK_X, BLOCK_Y, 1);
// Dynamically resize image-based auxiliary buffers during training
if (width * height >
state->
maxPixels)
if (width * height > maxPixels)
{
state->
maxPixels = width * height;
state->accum_alpha.resize(state->
maxPixels);
state->n_contrib.resize(state->
maxPixels);
state->
ranges.resize(tile_grid.x * tile_grid.y);
maxPixels = width * height;
accum_alpha.resize(
maxPixels);
n_contrib.resize(
maxPixels);
ranges.resize(tile_grid.x * tile_grid.y);
}
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
...
...
@@ -232,7 +227,7 @@ void CudaRasterizer::RasterizerImpl::forward(
(glm::vec4*)rotations,
opacities,
shs,
state->
clamped.data().get(),
clamped.data().get(),
cov3D_precomp,
colors_precomp,
viewmatrix, projmatrix,
...
...
@@ -241,56 +236,45 @@ void CudaRasterizer::RasterizerImpl::forward(
tan_fovx, tan_fovy,
focal_x, focal_y,
radii,
state->
means2D.data().get(),
state->
depths.data().get(),
state->
cov3D.data().get(),
state->
rgb.data().get(),
state->
conic_opacity.data().get(),
means2D.data().get(),
depths.data().get(),
cov3D.data().get(),
rgb.data().get(),
conic_opacity.data().get(),
tile_grid,
state->
tiles_touched.data().get(),
tiles_touched.data().get(),
prefiltered
);
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
size_t scanning_space_size = state->scanning_space.size();
cub::DeviceScan::InclusiveSum(
state->scanning_space.data().get(),
scanning_space_size,
state->tiles_touched.data().get(),
state->point_offsets.data().get(),
P);
cub::DeviceScan::InclusiveSum(scanning_space.data().get(), scan_size,
tiles_touched.data().get(), point_offsets.data().get(), P);
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_needed;
cudaMemcpy(&num_needed,
state->
point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&num_needed, point_offsets.data().get() + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
if (num_needed > point_list_keys_unsorted.size())
{
int resizeNum = resizeMultiplier * num_needed;
point_list_keys_unsorted.resize(resizeNum);
point_list_keys.resize(resizeNum);
point_list_unsorted.resize(resizeNum);
size_t sorting_size;
point_list_keys_unsorted.resize(2 * num_needed);
point_list_keys.resize(2 * num_needed);
point_list_unsorted.resize(2 * num_needed);
point_list.resize(2 * num_needed);
cub::DeviceRadixSort::SortPairs(
nullptr, sorting_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(),
state->
point_list.data().get(),
resizeNum
);
point_list_unsorted.data().get(), point_list.data().get(),
2 * num_needed
);
list_sorting_space.resize(sorting_size);
}
if (num_needed > state->point_list.size())
{
state->point_list.resize(resizeMultiplier * num_needed);
}
// For each instance to be rendered, produce adequate [ tile | depth ] key
// and corresponding duplicated Gaussian indices to be sorted
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
P,
state->
means2D.data().get(),
state->depths.data().get(),
state->point_offsets.data().get(),
means2D.data().get(),
depths.data().get(),
point_offsets.data().get(),
point_list_keys_unsorted.data().get(),
point_list_unsorted.data().get(),
radii,
...
...
@@ -300,36 +284,34 @@ void CudaRasterizer::RasterizerImpl::forward(
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
// Sort complete list of (duplicated) Gaussian indices by keys
size_t list_sorting_space_size = list_sorting_space.size();
cub::DeviceRadixSort::SortPairs(
list_sorting_space.data().get(),
list_sorting_space
_size,
sorting
_size,
point_list_keys_unsorted.data().get(), point_list_keys.data().get(),
point_list_unsorted.data().get(),
state->point_list.data().get(),
point_list_unsorted.data().get(), point_list.data().get(),
num_needed, 0, 32 + bit);
cudaMemset(
state->
ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
cudaMemset(ranges.data().get(), 0, tile_grid.x * tile_grid.y * sizeof(uint2));
// Identify start and end of per-tile workloads in sorted list
identifyTileRanges << <(num_needed + 255) / 256, 256 >> > (
num_needed,
point_list_keys.data().get(),
state->
ranges.data().get()
ranges.data().get()
);
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp :
state->
rgb.data().get();
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : rgb.data().get();
FORWARD::render(
tile_grid, block,
state->
ranges.data().get(),
state->
point_list.data().get(),
ranges.data().get(),
point_list.data().get(),
width, height,
state->
means2D.data().get(),
means2D.data().get(),
feature_ptr,
state->
conic_opacity.data().get(),
state->
accum_alpha.data().get(),
state->
n_contrib.data().get(),
conic_opacity.data().get(),
accum_alpha.data().get(),
n_contrib.data().get(),
background,
out_color);
}
...
...
@@ -337,8 +319,6 @@ void CudaRasterizer::RasterizerImpl::forward(
// Produce necessary gradients for optimization, corresponding
// to forward render pass
void CudaRasterizer::RasterizerImpl::backward(
const int* radii,
const InternalState* state,
const int P, int D, int M,
const float* background,
const int width, int height,
...
...
@@ -353,6 +333,7 @@ void CudaRasterizer::RasterizerImpl::backward(
const float* projmatrix,
const float* campos,
const float tan_fovx, float tan_fovy,
const int* radii,
const float* dL_dpix,
float* dL_dmean2D,
float* dL_dconic,
...
...
@@ -364,6 +345,11 @@ void CudaRasterizer::RasterizerImpl::backward(
float* dL_dscale,
float* dL_drot)
{
if (radii == nullptr)
{
radii = internal_radii.data().get();
}
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
...
...
@@ -373,19 +359,19 @@ void CudaRasterizer::RasterizerImpl::backward(
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp :
state->
rgb.data().get();
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : rgb.data().get();
BACKWARD::render(
tile_grid,
block,
state->
ranges.data().get(),
state->
point_list.data().get(),
ranges.data().get(),
point_list.data().get(),
width, height,
background,
state->
means2D.data().get(),
state->
conic_opacity.data().get(),
means2D.data().get(),
conic_opacity.data().get(),
color_ptr,
state->
accum_alpha.data().get(),
state->
n_contrib.data().get(),
accum_alpha.data().get(),
n_contrib.data().get(),
dL_dpix,
(float3*)dL_dmean2D,
(float4*)dL_dconic,
...
...
@@ -395,12 +381,12 @@ void CudaRasterizer::RasterizerImpl::backward(
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp :
state->
cov3D.data().get();
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : cov3D.data().get();
BACKWARD::preprocess(P, D, M,
(float3*)means3D,
radii,
shs,
state->
clamped.data().get(),
clamped.data().get(),
(glm::vec3*)scales,
(glm::vec4*)rotations,
scale_modifier,
...
...
cuda_rasterizer/rasterizer_impl.h
View file @
55b0c1b0
...
...
@@ -8,53 +8,41 @@
namespace
CudaRasterizer
{
//// Internal state kept across forward / backward
struct
InternalState
class
RasterizerImpl
:
public
Rasterizer
{
private
:
int
maxP
=
0
;
int
maxPixels
=
0
;
int
resizeMultiplier
=
2
;
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
bool
>
clamped
;
thrust
::
device_vector
<
float
>
rgb
;
// Initial aux structs
size_t
sorting_size
;
size_t
list_sorting_size
;
size_t
scan_size
;
thrust
::
device_vector
<
float
>
depths
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
};
// Auxiliary buffer spaces
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
thrust
::
device_vector
<
char
>
scanning_space
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
thrust
::
device_vector
<
bool
>
clamped
;
thrust
::
device_vector
<
int
>
internal_radii
;
class
RasterizerImpl
:
public
Rasterizer
{
private
:
// Internal state kept across forward / backward
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
accum_alpha
;
int
resizeMultiplier
=
2
;
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
rgb
;
public
:
virtual
InternalState
*
createInternalState
()
override
{
return
new
InternalState
();
}
virtual
void
killInternalState
(
InternalState
*
is
)
override
{
delete
is
;
}
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
...
...
@@ -79,13 +67,10 @@ namespace CudaRasterizer
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
int
*
radii
,
InternalState
*
state
,
float
*
out_color
)
override
;
float
*
out_color
,
int
*
radii
)
override
;
virtual
void
backward
(
const
int
*
radii
,
const
InternalState
*
fixedState
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
int
width
,
int
height
,
...
...
@@ -100,6 +85,7 @@ namespace CudaRasterizer
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
...
...
diff_gaussian_rasterization/rasterizer.py
View file @
55b0c1b0
...
...
@@ -4,7 +4,6 @@ import torch
from
.
import
_C
def
rasterize_gaussians
(
instance
,
means3D
,
means2D
,
sh
,
...
...
@@ -14,10 +13,8 @@ def rasterize_gaussians(
rotations
,
cov3Ds_precomp
,
raster_settings
,
rasterizer_state
):
return
_RasterizeGaussians
.
apply
(
instance
,
means3D
,
means2D
,
sh
,
...
...
@@ -27,14 +24,12 @@ def rasterize_gaussians(
rotations
,
cov3Ds_precomp
,
raster_settings
,
rasterizer_state
)
class
_RasterizeGaussians
(
torch
.
autograd
.
Function
):
@staticmethod
def
forward
(
ctx
,
instance
,
means3D
,
means2D
,
sh
,
...
...
@@ -44,12 +39,10 @@ class _RasterizeGaussians(torch.autograd.Function):
rotations
,
cov3Ds_precomp
,
raster_settings
,
rasterizer_state
):
# Restructure arguments the way that the C++ lib expects them
args
=
(
instance
,
raster_settings
.
bg
,
means3D
,
colors_precomp
,
...
...
@@ -68,7 +61,6 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings
.
sh_degree
,
raster_settings
.
campos
,
raster_settings
.
prefiltered
,
rasterizer_state
)
# Invoke C++/CUDA rasterizer
...
...
@@ -76,8 +68,6 @@ class _RasterizeGaussians(torch.autograd.Function):
# Keep relevant tensors for backward
ctx
.
raster_settings
=
raster_settings
ctx
.
instance
=
instance
ctx
.
rasterizer_state
=
rasterizer_state
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
)
return
color
,
radii
...
...
@@ -85,15 +75,11 @@ class _RasterizeGaussians(torch.autograd.Function):
def
backward
(
ctx
,
grad_out_color
,
_
):
# Restore necessary values from context
instance
=
ctx
.
instance
rasterizer_state
=
ctx
.
rasterizer_state
raster_settings
=
ctx
.
raster_settings
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
=
ctx
.
saved_tensors
# Restructure args as C++ method expects them
args
=
(
instance
,
rasterizer_state
,
raster_settings
.
bg
,
args
=
(
raster_settings
.
bg
,
means3D
,
radii
,
colors_precomp
,
...
...
@@ -114,7 +100,6 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_means2D
,
grad_colors_precomp
,
grad_opacities
,
grad_means3D
,
grad_cov3Ds_precomp
,
grad_sh
,
grad_scales
,
grad_rotations
=
_C
.
rasterize_gaussians_backward
(
*
args
)
grads
=
(
None
,
grad_means3D
,
grad_means2D
,
grad_sh
,
...
...
@@ -124,7 +109,6 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_rotations
,
grad_cov3Ds_precomp
,
None
,
None
,
)
return
grads
...
...
@@ -143,31 +127,25 @@ class GaussianRasterizationSettings(NamedTuple):
prefiltered
:
bool
class
GaussianRasterizer
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
,
raster_settings
):
super
()
.
__init__
()
self
.
instance
=
_C
.
create_rasterizer
()
def
__del__
(
self
):
_C
.
delete_rasterizer
(
self
.
instance
)
self
.
raster_settings
=
raster_settings
def
createRasterizerState
(
self
):
return
_C
.
create_rasterizer_state
(
self
.
instance
)
def
deleteRasterizerState
(
self
,
state
):
_C
.
delete_rasterizer_state
(
self
.
instance
,
state
)
def
markVisible
(
self
,
raster_settings
,
positions
):
def
markVisible
(
self
,
positions
):
# Mark visible points (based on frustum culling for camera) with a boolean
with
torch
.
no_grad
():
raster_settings
=
self
.
raster_settings
visible
=
_C
.
mark_visible
(
self
.
instance
,
positions
,
raster_settings
.
viewmatrix
,
raster_settings
.
projmatrix
)
return
visible
def
forward
(
self
,
rasterizer_state
,
raster_settings
,
means3D
,
means2D
,
opacities
,
shs
=
None
,
colors_precomp
=
None
,
scales
=
None
,
rotations
=
None
,
cov3D_precomp
=
None
):
def
forward
(
self
,
means3D
,
means2D
,
opacities
,
shs
=
None
,
colors_precomp
=
None
,
scales
=
None
,
rotations
=
None
,
cov3D_precomp
=
None
):
raster_settings
=
self
.
raster_settings
if
(
shs
is
None
and
colors_precomp
is
None
)
or
(
shs
is
not
None
and
colors_precomp
is
not
None
):
raise
Exception
(
'Please provide excatly one of either SHs or precomputed colors!'
)
...
...
@@ -188,7 +166,6 @@ class GaussianRasterizer(nn.Module):
# Invoke C++/CUDA rasterization routine
return
rasterize_gaussians
(
self
.
instance
,
means3D
,
means2D
,
shs
,
...
...
@@ -198,6 +175,5 @@ class GaussianRasterizer(nn.Module):
rotations
,
cov3D_precomp
,
raster_settings
,
rasterizer_state
)
ext.cpp
View file @
55b0c1b0
#include <torch/extension.h>
#include "rasterize_points.h"
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"rasterize_gaussians"
,
&
RasterizeGaussiansCUDA
);
m
.
def
(
"rasterize_gaussians_backward"
,
&
RasterizeGaussiansBackwardCUDA
);
m
.
def
(
"mark_visible"
,
&
markVisible
);
m
.
def
(
"create_rasterizer_state"
,
&
createRasterizerState
);
m
.
def
(
"delete_rasterizer_state"
,
&
deleteRasterizerState
);
m
.
def
(
"create_rasterizer"
,
&
createRasterizer
);
m
.
def
(
"delete_rasterizer"
,
&
deleteRasterizer
);
}
\ No newline at end of file
rasterize_points.cu
View file @
55b0c1b0
...
...
@@ -11,33 +11,13 @@
#include <memory>
#include "cuda_rasterizer/config.h"
#include "cuda_rasterizer/rasterizer.h"
#include "rasterize_points.h"
#include <fstream>
#include <string>
void* createRasterizer()
{
return (void*)CudaRasterizer::Rasterizer::make();
}
void deleteRasterizer(void* rasterizer)
{
CudaRasterizer::Rasterizer::kill((CudaRasterizer::Rasterizer*)rasterizer);
}
void* createRasterizerState(void* rasterizer)
{
return (void*)((CudaRasterizer::Rasterizer*)rasterizer)->createInternalState();
}
void deleteRasterizerState(void* rasterizer, void* state)
{
((CudaRasterizer::Rasterizer*)rasterizer)->killInternalState((CudaRasterizer::InternalState*)state);
}
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& colors,
...
...
@@ -55,14 +35,18 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
const bool prefiltered,
void* internalState)
const bool prefiltered)
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)");
}
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int N = 1; // batch size hard-coded
const int H = image_height;
...
...
@@ -82,7 +66,7 @@ RasterizeGaussiansCUDA(
M = sh.size(1);
}
((CudaRasterizer::Rasterizer*)rasterizer)
->forward(P, degree, M,
cudaRenderer
->forward(P, degree, M,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
...
...
@@ -99,17 +83,14 @@ RasterizeGaussiansCUDA(
tan_fovx,
tan_fovy,
prefiltered,
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
out_color.contiguous().data<float>());
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
}
return std::make_tuple(out_color, radii);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& radii,
...
...
@@ -149,10 +130,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
{
((CudaRasterizer::Rasterizer*)rasterizer)->backward(
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
P, degree, M,
cudaRenderer->backward(P, degree, M,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
...
...
@@ -167,6 +145,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
campos.contiguous().data<float>(),
tan_fovx,
tan_fovy,
radii.contiguous().data<int>(),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
...
...
@@ -183,11 +162,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
}
torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D,
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix)
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
...
...
@@ -195,7 +177,7 @@ torch::Tensor markVisible(
if(P != 0)
{
((CudaRasterizer::Rasterizer*)rasterizer)
->markVisible(P,
cudaRenderer
->markVisible(P,
means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
...
...
rasterize_points.h
View file @
55b0c1b0
...
...
@@ -8,7 +8,6 @@
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansCUDA
(
void
*
rasterizer
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
colors
,
...
...
@@ -26,13 +25,10 @@ RasterizeGaussiansCUDA(
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
torch
::
Tensor
&
campos
,
const
bool
prefiltered
,
void
*
internalState
);
const
bool
prefiltered
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansBackwardCUDA
(
void
*
rasterizer
,
const
void
*
internalState
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
radii
,
...
...
@@ -50,16 +46,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const
int
degree
,
const
torch
::
Tensor
&
campos
);
void
*
createRasterizerState
(
void
*
rasterizer
);
void
deleteRasterizerState
(
void
*
rasterizer
,
void
*
state
);
void
*
createRasterizer
();
void
deleteRasterizer
(
void
*
rasterizer
);
torch
::
Tensor
markVisible
(
void
*
rasterizer
,
torch
::
Tensor
&
means3D
,
torch
::
Tensor
&
viewmatrix
,
torch
::
Tensor
&
projmatrix
);
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment