Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
D
diff-gaussian-rasterization
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Alan de Oliveira
diff-gaussian-rasterization
Commits
79cbd71d
Commit
79cbd71d
authored
Jun 20, 2023
by
Bernhard Kerbl
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
No more persistent state
parent
ffee75d5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
113 additions
and
123 deletions
+113
-123
rasterizer.h
cuda_rasterizer/rasterizer.h
+15
-12
rasterizer_impl.cu
cuda_rasterizer/rasterizer_impl.cu
+0
-0
rasterizer_impl.h
cuda_rasterizer/rasterizer_impl.h
+46
-86
rasterizer.py
diff_gaussian_rasterization/rasterizer.py
+10
-4
rasterize_points.cu
rasterize_points.cu
+36
-19
rasterize_points.h
rasterize_points.h
+6
-2
No files found.
cuda_rasterizer/rasterizer.h
View file @
79cbd71d
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define CUDA_RASTERIZER_H_INCLUDED
#define CUDA_RASTERIZER_H_INCLUDED
#include <vector>
#include <vector>
#include <functional>
namespace
CudaRasterizer
namespace
CudaRasterizer
{
{
...
@@ -9,14 +10,17 @@ namespace CudaRasterizer
...
@@ -9,14 +10,17 @@ namespace CudaRasterizer
{
{
public
:
public
:
virtual
void
markVisible
(
static
void
markVisible
(
int
P
,
int
P
,
float
*
means3D
,
float
*
means3D
,
float
*
viewmatrix
,
float
*
viewmatrix
,
float
*
projmatrix
,
float
*
projmatrix
,
bool
*
present
)
=
0
;
bool
*
present
);
virtual
void
forward
(
static
int
forward
(
std
::
function
<
char
*
(
int
)
>
geometryBuffer
,
std
::
function
<
char
*
(
int
)
>
binningBuffer
,
std
::
function
<
char
*
(
int
)
>
imageBuffer
,
const
int
P
,
int
D
,
int
M
,
const
int
P
,
int
D
,
int
M
,
const
float
*
background
,
const
float
*
background
,
const
int
width
,
int
height
,
const
int
width
,
int
height
,
...
@@ -34,10 +38,10 @@ namespace CudaRasterizer
...
@@ -34,10 +38,10 @@ namespace CudaRasterizer
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
const
bool
prefiltered
,
float
*
out_color
,
float
*
out_color
,
int
*
radii
=
nullptr
)
=
0
;
int
*
radii
)
;
virtual
void
backward
(
static
void
backward
(
const
int
P
,
int
D
,
int
M
,
const
int
P
,
int
D
,
int
M
,
int
R
,
const
float
*
background
,
const
float
*
background
,
const
int
width
,
int
height
,
const
int
width
,
int
height
,
const
float
*
means3D
,
const
float
*
means3D
,
...
@@ -47,11 +51,14 @@ namespace CudaRasterizer
...
@@ -47,11 +51,14 @@ namespace CudaRasterizer
const
float
scale_modifier
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
int
*
radii
,
char
*
geom_buffer
,
char
*
binning_buffer
,
char
*
image_buffer
,
const
float
*
dL_dpix
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dconic
,
...
@@ -61,11 +68,7 @@ namespace CudaRasterizer
...
@@ -61,11 +68,7 @@ namespace CudaRasterizer
float
*
dL_dcov3D
,
float
*
dL_dcov3D
,
float
*
dL_dsh
,
float
*
dL_dsh
,
float
*
dL_dscale
,
float
*
dL_dscale
,
float
*
dL_drot
)
=
0
;
float
*
dL_drot
);
virtual
~
Rasterizer
()
{};
static
Rasterizer
*
make
(
int
resizeMultipliyer
=
2
);
};
};
};
};
...
...
cuda_rasterizer/rasterizer_impl.cu
View file @
79cbd71d
This diff is collapsed.
Click to expand it.
cuda_rasterizer/rasterizer_impl.h
View file @
79cbd71d
...
@@ -4,101 +4,60 @@
...
@@ -4,101 +4,60 @@
#include <vector>
#include <vector>
#include "rasterizer.h"
#include "rasterizer.h"
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
namespace
CudaRasterizer
namespace
CudaRasterizer
{
{
class
RasterizerImpl
:
public
Rasterizer
template
<
typename
T
>
static
void
obtain
(
char
*&
chunk
,
T
*&
ptr
,
std
::
size_t
count
,
std
::
size_t
alignment
)
{
{
private
:
std
::
size_t
offset
=
(
reinterpret_cast
<
std
::
uintptr_t
>
(
chunk
)
+
alignment
-
1
)
&
~
(
alignment
-
1
);
int
maxP
=
0
;
ptr
=
reinterpret_cast
<
T
*>
(
offset
)
;
int
maxPixels
=
0
;
chunk
=
reinterpret_cast
<
char
*>
(
ptr
+
count
)
;
int
resizeMultiplier
=
2
;
}
// Initial aux structs
struct
GeometryState
size_t
sorting_size
;
{
size_t
list_sorting_size
;
size_t
scan_size
;
size_t
scan_size
;
thrust
::
device_vector
<
float
>
depths
;
float
*
depths
;
thrust
::
device_vector
<
uint32_t
>
tiles_touched
;
char
*
scanning_space
;
thrust
::
device_vector
<
uint32_t
>
point_offsets
;
bool
*
clamped
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys_unsorted
;
int
*
internal_radii
;
thrust
::
device_vector
<
uint64_t
>
point_list_keys
;
float2
*
means2D
;
thrust
::
device_vector
<
uint32_t
>
point_list_unsorted
;
float
*
cov3D
;
thrust
::
device_vector
<
uint32_t
>
point_list
;
float4
*
conic_opacity
;
thrust
::
device_vector
<
char
>
scanning_space
;
float
*
rgb
;
thrust
::
device_vector
<
char
>
list_sorting_space
;
uint32_t
*
point_offsets
;
thrust
::
device_vector
<
bool
>
clamped
;
uint32_t
*
tiles_touched
;
thrust
::
device_vector
<
int
>
internal_radii
;
static
GeometryState
fromChunk
(
char
*&
chunk
,
int
P
);
// Internal state kept across forward / backward
};
thrust
::
device_vector
<
uint2
>
ranges
;
thrust
::
device_vector
<
uint32_t
>
n_contrib
;
thrust
::
device_vector
<
float
>
accum_alpha
;
thrust
::
device_vector
<
float2
>
means2D
;
thrust
::
device_vector
<
float
>
cov3D
;
thrust
::
device_vector
<
float4
>
conic_opacity
;
thrust
::
device_vector
<
float
>
rgb
;
public
:
virtual
void
markVisible
(
int
P
,
float
*
means3D
,
float
*
viewmatrix
,
float
*
projmatrix
,
bool
*
present
)
override
;
virtual
void
forward
(
struct
ImageState
const
int
P
,
int
D
,
int
M
,
{
const
float
*
background
,
uint2
*
ranges
;
const
int
width
,
int
height
,
uint32_t
*
n_contrib
;
const
float
*
means3D
,
float
*
accum_alpha
;
const
float
*
shs
,
const
float
*
colors_precomp
,
const
float
*
opacities
,
const
float
*
scales
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
cam_pos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
bool
prefiltered
,
float
*
out_color
,
int
*
radii
)
override
;
virtual
void
backward
(
static
ImageState
fromChunk
(
char
*&
chunk
,
int
N
);
const
int
P
,
int
D
,
int
M
,
};
const
float
*
background
,
const
int
width
,
int
height
,
const
float
*
means3D
,
const
float
*
shs
,
const
float
*
colors_precomp
,
const
float
*
scales
,
const
float
scale_modifier
,
const
float
*
rotations
,
const
float
*
cov3D_precomp
,
const
float
*
viewmatrix
,
const
float
*
projmatrix
,
const
float
*
campos
,
const
float
tan_fovx
,
float
tan_fovy
,
const
int
*
radii
,
const
float
*
dL_dpix
,
float
*
dL_dmean2D
,
float
*
dL_dconic
,
float
*
dL_dopacity
,
float
*
dL_dcolor
,
float
*
dL_dmean3D
,
float
*
dL_dcov3D
,
float
*
dL_dsh
,
float
*
dL_dscale
,
float
*
dL_drot
)
override
;
RasterizerImpl
(
int
resizeMultiplier
);
struct
BinningState
{
size_t
sorting_size
;
uint64_t
*
point_list_keys_unsorted
;
uint64_t
*
point_list_keys
;
uint32_t
*
point_list_unsorted
;
uint32_t
*
point_list
;
char
*
list_sorting_space
;
virtual
~
RasterizerImpl
()
override
;
static
BinningState
fromChunk
(
char
*&
chunk
,
int
P
)
;
};
};
/// Returns the buffer size, in bytes, to request for a state struct of type T
/// holding P items.
///
/// T must expose `static T fromChunk(char*& chunk, int P)`. The chunk pointer
/// is passed by reference starting from a null base; fromChunk is expected to
/// advance it past everything it lays out, so the final pointer value equals
/// the number of bytes consumed.
///
/// @param P  item count forwarded unchanged to T::fromChunk.
/// @return   bytes consumed by T::fromChunk plus 128 bytes of headroom
///           (NOTE(review): presumably slack for alignment padding - confirm).
template<typename T>
int required(int P)
{
	char* size = nullptr;
	T::fromChunk(size, P);
	// Go through std::uintptr_t: casting a pointer straight to int is ill-formed
	// (or silently truncating) on platforms where int is narrower than a pointer.
	return static_cast<int>(reinterpret_cast<std::uintptr_t>(size)) + 128;
}
};
};
\ No newline at end of file
diff_gaussian_rasterization/rasterizer.py
View file @
79cbd71d
...
@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function):
...
@@ -64,19 +64,21 @@ class _RasterizeGaussians(torch.autograd.Function):
)
)
# Invoke C++/CUDA rasterizer
# Invoke C++/CUDA rasterizer
color
,
radii
=
_C
.
rasterize_gaussians
(
*
args
)
num_rendered
,
color
,
radii
,
geomBuffer
,
binningBuffer
,
imgBuffer
=
_C
.
rasterize_gaussians
(
*
args
)
# Keep relevant tensors for backward
# Keep relevant tensors for backward
ctx
.
raster_settings
=
raster_settings
ctx
.
raster_settings
=
raster_settings
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
)
ctx
.
num_rendered
=
num_rendered
ctx
.
save_for_backward
(
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
,
geomBuffer
,
binningBuffer
,
imgBuffer
)
return
color
,
radii
return
color
,
radii
@staticmethod
@staticmethod
def
backward
(
ctx
,
grad_out_color
,
_
):
def
backward
(
ctx
,
grad_out_color
,
_
):
# Restore necessary values from context
# Restore necessary values from context
num_rendered
=
ctx
.
num_rendered
raster_settings
=
ctx
.
raster_settings
raster_settings
=
ctx
.
raster_settings
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
=
ctx
.
saved_tensors
colors_precomp
,
means3D
,
scales
,
rotations
,
cov3Ds_precomp
,
radii
,
sh
,
geomBuffer
,
binningBuffer
,
imgBuffer
=
ctx
.
saved_tensors
# Restructure args as C++ method expects them
# Restructure args as C++ method expects them
args
=
(
raster_settings
.
bg
,
args
=
(
raster_settings
.
bg
,
...
@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function):
...
@@ -94,7 +96,11 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_out_color
,
grad_out_color
,
sh
,
sh
,
raster_settings
.
sh_degree
,
raster_settings
.
sh_degree
,
raster_settings
.
campos
)
raster_settings
.
campos
,
geomBuffer
,
num_rendered
,
binningBuffer
,
imgBuffer
)
# Compute gradients for relevant tensors by invoking backward method
# Compute gradients for relevant tensors by invoking backward method
grad_means2D
,
grad_colors_precomp
,
grad_opacities
,
grad_means3D
,
grad_cov3Ds_precomp
,
grad_sh
,
grad_scales
,
grad_rotations
=
_C
.
rasterize_gaussians_backward
(
*
args
)
grad_means2D
,
grad_colors_precomp
,
grad_opacities
,
grad_means3D
,
grad_cov3Ds_precomp
,
grad_sh
,
grad_scales
,
grad_rotations
=
_C
.
rasterize_gaussians_backward
(
*
args
)
...
...
rasterize_points.cu
View file @
79cbd71d
...
@@ -13,10 +13,17 @@
...
@@ -13,10 +13,17 @@
#include "cuda_rasterizer/rasterizer.h"
#include "cuda_rasterizer/rasterizer.h"
#include <fstream>
#include <fstream>
#include <string>
#include <string>
#include <functional>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
// Wraps tensor `t` in a callback that resizes it on demand.
//
// The returned callable resizes `t` in place to a 1-D shape of N elements and
// hands back the tensor's data pointer as char*. It captures `t` by reference,
// so the tensor passed in must outlive every use of the callback.
std::function<char*(int N)> resizeFunctional(torch::Tensor& t) {
    return [&t](int N) -> char* {
        t.resize_({N});
        void* const storage = t.contiguous().data_ptr();
        return reinterpret_cast<char*>(storage);
    };
}
std::tuple<torch::Tensor, torch::Tensor>
std::tuple<
int, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
RasterizeGaussiansCUDA(
const torch::Tensor& background,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& means3D,
...
@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA(
...
@@ -37,16 +44,10 @@ RasterizeGaussiansCUDA(
const torch::Tensor& campos,
const torch::Tensor& campos,
const bool prefiltered)
const bool prefiltered)
{
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)");
AT_ERROR("means3D must have dimensions (num_points, 3)");
}
}
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int P = means3D.size(0);
const int N = 1; // batch size hard-coded
const int N = 1; // batch size hard-coded
const int H = image_height;
const int H = image_height;
...
@@ -57,7 +58,17 @@ RasterizeGaussiansCUDA(
...
@@ -57,7 +58,17 @@ RasterizeGaussiansCUDA(
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor out_color = torch::full({N, NUM_CHANNELS, H, W}, 0.0, float_opts);
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
torch::Device device(torch::kCUDA);
torch::TensorOptions options(torch::kByte);
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
std::function<char*(int)> geomFunc = resizeFunctional(geomBuffer);
std::function<char*(int)> binningFunc = resizeFunctional(binningBuffer);
std::function<char*(int)> imgFunc = resizeFunctional(imgBuffer);
int rendered = 0;
if(P != 0)
if(P != 0)
{
{
int M = 0;
int M = 0;
...
@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA(
...
@@ -66,7 +77,11 @@ RasterizeGaussiansCUDA(
M = sh.size(1);
M = sh.size(1);
}
}
cudaRenderer->forward(P, degree, M,
rendered = CudaRasterizer::Rasterizer::forward(
geomFunc,
binningFunc,
imgFunc,
P, degree, M,
background.contiguous().data<float>(),
background.contiguous().data<float>(),
W, H,
W, H,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
...
@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA(
...
@@ -86,7 +101,7 @@ RasterizeGaussiansCUDA(
out_color.contiguous().data<float>(),
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
radii.contiguous().data<int>());
}
}
return std::make_tuple(
out_color, radii
);
return std::make_tuple(
rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer
);
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
...
@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -106,7 +121,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& dL_dout_color,
const torch::Tensor& dL_dout_color,
const torch::Tensor& sh,
const torch::Tensor& sh,
const int degree,
const int degree,
const torch::Tensor& campos)
const torch::Tensor& campos,
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer)
{
{
const int P = means3D.size(0);
const int P = means3D.size(0);
const int H = dL_dout_color.size(2);
const int H = dL_dout_color.size(2);
...
@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -130,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
if(P != 0)
{
{
cudaRenderer->backward(P, degree, M
,
CudaRasterizer::Rasterizer::backward(P, degree, M, R
,
background.contiguous().data<float>(),
background.contiguous().data<float>(),
W, H,
W, H,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
...
@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -146,6 +165,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
tan_fovx,
tan_fovx,
tan_fovy,
tan_fovy,
radii.contiguous().data<int>(),
radii.contiguous().data<int>(),
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
dL_dout_color.contiguous().data<float>(),
dL_dout_color.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dmeans2D.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
dL_dconic.contiguous().data<float>(),
...
@@ -166,18 +188,13 @@ torch::Tensor markVisible(
...
@@ -166,18 +188,13 @@ torch::Tensor markVisible(
torch::Tensor& viewmatrix,
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix)
torch::Tensor& projmatrix)
{
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int P = means3D.size(0);
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
if(P != 0)
if(P != 0)
{
{
cudaRenderer->
markVisible(P,
CudaRasterizer::Rasterizer::
markVisible(P,
means3D.contiguous().data<float>(),
means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
...
...
rasterize_points.h
View file @
79cbd71d
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <tuple>
#include <tuple>
#include <string>
#include <string>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
int
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
RasterizeGaussiansCUDA
(
RasterizeGaussiansCUDA
(
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
background
,
const
torch
::
Tensor
&
means3D
,
const
torch
::
Tensor
&
means3D
,
...
@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
...
@@ -44,7 +44,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
dL_dout_color
,
const
torch
::
Tensor
&
sh
,
const
torch
::
Tensor
&
sh
,
const
int
degree
,
const
int
degree
,
const
torch
::
Tensor
&
campos
);
const
torch
::
Tensor
&
campos
,
const
torch
::
Tensor
&
geomBuffer
,
const
int
R
,
const
torch
::
Tensor
&
binningBuffer
,
const
torch
::
Tensor
&
imageBuffer
);
torch
::
Tensor
markVisible
(
torch
::
Tensor
markVisible
(
torch
::
Tensor
&
means3D
,
torch
::
Tensor
&
means3D
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment