Commit 8a219fd5 by Bernhard Kerbl

changes:

parent 64f7c223
...@@ -13,6 +13,7 @@ def rasterize_gaussians( ...@@ -13,6 +13,7 @@ def rasterize_gaussians(
rotations, rotations,
cov3Ds_precomp, cov3Ds_precomp,
raster_settings, raster_settings,
rasterizer_state
): ):
return _RasterizeGaussians.apply( return _RasterizeGaussians.apply(
means3D, means3D,
...@@ -24,6 +25,7 @@ def rasterize_gaussians( ...@@ -24,6 +25,7 @@ def rasterize_gaussians(
rotations, rotations,
cov3Ds_precomp, cov3Ds_precomp,
raster_settings, raster_settings,
rasterizer_state
) )
class _RasterizeGaussians(torch.autograd.Function): class _RasterizeGaussians(torch.autograd.Function):
...@@ -39,10 +41,9 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -39,10 +41,9 @@ class _RasterizeGaussians(torch.autograd.Function):
rotations, rotations,
cov3Ds_precomp, cov3Ds_precomp,
raster_settings, raster_settings,
rasterizer_state
): ):
rasterizer_state = _C.create_rasterizer_state()
# Restructure arguments the way that the C++ lib expects them # Restructure arguments the way that the C++ lib expects them
args = ( args = (
raster_settings.bg, raster_settings.bg,
...@@ -115,10 +116,9 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -115,10 +116,9 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_rotations, grad_rotations,
grad_cov3Ds_precomp, grad_cov3Ds_precomp,
None, None,
None,
) )
_C.delete_rasterizer_state(rasterizer_state)
return grads return grads
class GaussianRasterizationSettings(NamedTuple): class GaussianRasterizationSettings(NamedTuple):
...@@ -134,10 +134,17 @@ class GaussianRasterizationSettings(NamedTuple): ...@@ -134,10 +134,17 @@ class GaussianRasterizationSettings(NamedTuple):
campos : torch.Tensor campos : torch.Tensor
prefiltered : bool prefiltered : bool
def createRasterizerState():
return _C.create_rasterizer_state()
def deleteRasterizerState(state):
return _C.delete_rasterize_state(state)
class GaussianRasterizer(nn.Module): class GaussianRasterizer(nn.Module):
def __init__(self, raster_settings): def __init__(self, raster_settings, rasterizer_state):
super().__init__() super().__init__()
self.raster_settings = raster_settings self.raster_settings = raster_settings
self.rasterizer_state = rasterizer_state
def markVisible(self, positions): def markVisible(self, positions):
# Mark visible points (based on frustum culling for camera) with a boolean # Mark visible points (based on frustum culling for camera) with a boolean
...@@ -151,8 +158,8 @@ class GaussianRasterizer(nn.Module): ...@@ -151,8 +158,8 @@ class GaussianRasterizer(nn.Module):
return visible return visible
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
raster_settings = self.raster_settings raster_settings = self.raster_settings
rasterize_state = self.rasterizer_state
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
raise Exception('Please provide excatly one of either SHs or precomputed colors!') raise Exception('Please provide excatly one of either SHs or precomputed colors!')
...@@ -183,5 +190,6 @@ class GaussianRasterizer(nn.Module): ...@@ -183,5 +190,6 @@ class GaussianRasterizer(nn.Module):
rotations, rotations,
cov3D_precomp, cov3D_precomp,
raster_settings, raster_settings,
rasterize_state
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment