Commit 3851457e by Bernhard Kerbl

changes

parent 8a219fd5
...@@ -74,6 +74,8 @@ namespace CudaRasterizer ...@@ -74,6 +74,8 @@ namespace CudaRasterizer
virtual ~Rasterizer() {}; virtual ~Rasterizer() {};
static Rasterizer* make(int resizeMultipliyer = 2); static Rasterizer* make(int resizeMultipliyer = 2);
static void kill(Rasterizer* rasterizer);
}; };
}; };
......
...@@ -137,6 +137,11 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie ...@@ -137,6 +137,11 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
return new CudaRasterizer::RasterizerImpl(resizeMultiplier); return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
} }
void CudaRasterizer::Rasterizer::kill(Rasterizer* rasterizer)
{
delete rasterizer;
}
// Mark Gaussians as visible/invisible, based on view frustum testing // Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::RasterizerImpl::markVisible( void CudaRasterizer::RasterizerImpl::markVisible(
int P, int P,
......
...@@ -8,4 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) ...@@ -8,4 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("mark_visible", &markVisible); m.def("mark_visible", &markVisible);
m.def("create_rasterizer_state", &createRasterizerState); m.def("create_rasterizer_state", &createRasterizerState);
m.def("delete_rasterizer_state", &deleteRasterizerState); m.def("delete_rasterizer_state", &deleteRasterizerState);
m.def("create_rasterizer", &createRasterizer);
m.def("delete_rasterizer", &deleteRasterizer);
} }
\ No newline at end of file
...@@ -15,24 +15,29 @@ ...@@ -15,24 +15,29 @@
#include <fstream> #include <fstream>
#include <string> #include <string>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr; void* createRasterizer()
{
return (void*)CudaRasterizer::Rasterizer::make();
}
void* createRasterizerState() void deleteRasterizer(void* rasterizer)
{ {
if (cudaRenderer == nullptr) CudaRasterizer::Rasterizer::kill((CudaRasterizer::Rasterizer*)rasterizer);
{ }
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
} void* createRasterizerState(void* rasterizer)
return (void*)cudaRenderer->createInternalState(); {
return (void*)((CudaRasterizer::Rasterizer*)rasterizer)->createInternalState();
} }
void deleteRasterizerState(void* state) void deleteRasterizerState(void* rasterizer, void* state)
{ {
cudaRenderer->killInternalState((CudaRasterizer::InternalState*)state); ((CudaRasterizer::Rasterizer*)rasterizer)->killInternalState((CudaRasterizer::InternalState*)state);
} }
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA( RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
const torch::Tensor& colors, const torch::Tensor& colors,
...@@ -58,11 +63,6 @@ RasterizeGaussiansCUDA( ...@@ -58,11 +63,6 @@ RasterizeGaussiansCUDA(
AT_ERROR("means3D must have dimensions (num_points, 3)"); AT_ERROR("means3D must have dimensions (num_points, 3)");
} }
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0); const int P = means3D.size(0);
const int N = 1; // batch size hard-coded const int N = 1; // batch size hard-coded
const int H = image_height; const int H = image_height;
...@@ -82,7 +82,7 @@ RasterizeGaussiansCUDA( ...@@ -82,7 +82,7 @@ RasterizeGaussiansCUDA(
M = sh.size(1); M = sh.size(1);
} }
cudaRenderer->forward(P, degree, M, ((CudaRasterizer::Rasterizer*)rasterizer)->forward(P, degree, M,
background.contiguous().data<float>(), background.contiguous().data<float>(),
W, H, W, H,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
...@@ -108,6 +108,7 @@ RasterizeGaussiansCUDA( ...@@ -108,6 +108,7 @@ RasterizeGaussiansCUDA(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA( RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState, const void* internalState,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
...@@ -148,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -148,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0) if(P != 0)
{ {
cudaRenderer->backward( ((CudaRasterizer::Rasterizer*)rasterizer)->backward(
radii.contiguous().data<int>(), radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState, (CudaRasterizer::InternalState*)internalState,
P, degree, M, P, degree, M,
...@@ -182,14 +183,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -182,14 +183,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
} }
torch::Tensor markVisible( torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D, torch::Tensor& means3D,
torch::Tensor& viewmatrix, torch::Tensor& viewmatrix,
torch::Tensor& projmatrix) torch::Tensor& projmatrix)
{ {
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0); const int P = means3D.size(0);
...@@ -197,7 +195,7 @@ torch::Tensor markVisible( ...@@ -197,7 +195,7 @@ torch::Tensor markVisible(
if(P != 0) if(P != 0)
{ {
cudaRenderer->markVisible(P, ((CudaRasterizer::Rasterizer*)rasterizer)->markVisible(P,
means3D.contiguous().data<float>(), means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(), viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(), projmatrix.contiguous().data<float>(),
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA( RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
const torch::Tensor& colors, const torch::Tensor& colors,
...@@ -30,6 +31,7 @@ RasterizeGaussiansCUDA( ...@@ -30,6 +31,7 @@ RasterizeGaussiansCUDA(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA( RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState, const void* internalState,
const torch::Tensor& background, const torch::Tensor& background,
const torch::Tensor& means3D, const torch::Tensor& means3D,
...@@ -48,11 +50,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te ...@@ -48,11 +50,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const int degree, const int degree,
const torch::Tensor& campos); const torch::Tensor& campos);
void* createRasterizerState(); void* createRasterizerState(void* rasterizer);
void deleteRasterizerState(void* state); void deleteRasterizerState(void* rasterizer, void* state);
void* createRasterizer();
void deleteRasterizer(void* rasterizer);
torch::Tensor markVisible( torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D, torch::Tensor& means3D,
torch::Tensor& viewmatrix, torch::Tensor& viewmatrix,
torch::Tensor& projmatrix); torch::Tensor& projmatrix);
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment