Commit 3851457e by Bernhard Kerbl

changes

parent 8a219fd5
......@@ -74,6 +74,8 @@ namespace CudaRasterizer
virtual ~Rasterizer() {};
static Rasterizer* make(int resizeMultipliyer = 2);
static void kill(Rasterizer* rasterizer);
};
};
......
......@@ -137,6 +137,11 @@ CudaRasterizer::Rasterizer* CudaRasterizer::Rasterizer::make(int resizeMultiplie
return new CudaRasterizer::RasterizerImpl(resizeMultiplier);
}
void CudaRasterizer::Rasterizer::kill(Rasterizer* rasterizer)
{
delete rasterizer;
}
// Mark Gaussians as visible/invisible, based on view frustum testing
void CudaRasterizer::RasterizerImpl::markVisible(
int P,
......
......@@ -8,4 +8,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("mark_visible", &markVisible);
m.def("create_rasterizer_state", &createRasterizerState);
m.def("delete_rasterizer_state", &deleteRasterizerState);
m.def("create_rasterizer", &createRasterizer);
m.def("delete_rasterizer", &deleteRasterizer);
}
\ No newline at end of file
......@@ -15,24 +15,29 @@
#include <fstream>
#include <string>
static std::unique_ptr<CudaRasterizer::Rasterizer> cudaRenderer = nullptr;
void* createRasterizer()
{
return (void*)CudaRasterizer::Rasterizer::make();
}
void* createRasterizerState()
void deleteRasterizer(void* rasterizer)
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
return (void*)cudaRenderer->createInternalState();
CudaRasterizer::Rasterizer::kill((CudaRasterizer::Rasterizer*)rasterizer);
}
void* createRasterizerState(void* rasterizer)
{
return (void*)((CudaRasterizer::Rasterizer*)rasterizer)->createInternalState();
}
void deleteRasterizerState(void* state)
void deleteRasterizerState(void* rasterizer, void* state)
{
cudaRenderer->killInternalState((CudaRasterizer::InternalState*)state);
((CudaRasterizer::Rasterizer*)rasterizer)->killInternalState((CudaRasterizer::InternalState*)state);
}
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& colors,
......@@ -58,11 +63,6 @@ RasterizeGaussiansCUDA(
AT_ERROR("means3D must have dimensions (num_points, 3)");
}
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
const int N = 1; // batch size hard-coded
const int H = image_height;
......@@ -82,7 +82,7 @@ RasterizeGaussiansCUDA(
M = sh.size(1);
}
cudaRenderer->forward(P, degree, M,
((CudaRasterizer::Rasterizer*)rasterizer)->forward(P, degree, M,
background.contiguous().data<float>(),
W, H,
means3D.contiguous().data<float>(),
......@@ -108,6 +108,7 @@ RasterizeGaussiansCUDA(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& means3D,
......@@ -148,7 +149,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
if(P != 0)
{
cudaRenderer->backward(
((CudaRasterizer::Rasterizer*)rasterizer)->backward(
radii.contiguous().data<int>(),
(CudaRasterizer::InternalState*)internalState,
P, degree, M,
......@@ -182,14 +183,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
}
torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D,
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix)
{
if (cudaRenderer == nullptr)
{
cudaRenderer = std::unique_ptr<CudaRasterizer::Rasterizer>(CudaRasterizer::Rasterizer::make());
}
const int P = means3D.size(0);
......@@ -197,7 +195,7 @@ torch::Tensor markVisible(
if(P != 0)
{
cudaRenderer->markVisible(P,
((CudaRasterizer::Rasterizer*)rasterizer)->markVisible(P,
means3D.contiguous().data<float>(),
viewmatrix.contiguous().data<float>(),
projmatrix.contiguous().data<float>(),
......
......@@ -8,6 +8,7 @@
std::tuple<torch::Tensor, torch::Tensor>
RasterizeGaussiansCUDA(
void* rasterizer,
const torch::Tensor& background,
const torch::Tensor& means3D,
const torch::Tensor& colors,
......@@ -30,6 +31,7 @@ RasterizeGaussiansCUDA(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
void* rasterizer,
const void* internalState,
const torch::Tensor& background,
const torch::Tensor& means3D,
......@@ -48,11 +50,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const int degree,
const torch::Tensor& campos);
void* createRasterizerState();
void* createRasterizerState(void* rasterizer);
void deleteRasterizerState(void* state);
void deleteRasterizerState(void* rasterizer, void* state);
void* createRasterizer();
void deleteRasterizer(void* rasterizer);
torch::Tensor markVisible(
void* rasterizer,
torch::Tensor& means3D,
torch::Tensor& viewmatrix,
torch::Tensor& projmatrix);
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment