Commit e41a365f by Bernhard Kerbl

working with instances

parent 3851457e
...@@ -4,6 +4,7 @@ import torch ...@@ -4,6 +4,7 @@ import torch
from . import _C from . import _C
def rasterize_gaussians( def rasterize_gaussians(
instance,
means3D, means3D,
means2D, means2D,
sh, sh,
...@@ -16,6 +17,7 @@ def rasterize_gaussians( ...@@ -16,6 +17,7 @@ def rasterize_gaussians(
rasterizer_state rasterizer_state
): ):
return _RasterizeGaussians.apply( return _RasterizeGaussians.apply(
instance,
means3D, means3D,
means2D, means2D,
sh, sh,
...@@ -32,6 +34,7 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -32,6 +34,7 @@ class _RasterizeGaussians(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,
instance,
means3D, means3D,
means2D, means2D,
sh, sh,
...@@ -46,6 +49,7 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -46,6 +49,7 @@ class _RasterizeGaussians(torch.autograd.Function):
# Restructure arguments the way that the C++ lib expects them # Restructure arguments the way that the C++ lib expects them
args = ( args = (
instance,
raster_settings.bg, raster_settings.bg,
means3D, means3D,
colors_precomp, colors_precomp,
...@@ -72,6 +76,7 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -72,6 +76,7 @@ class _RasterizeGaussians(torch.autograd.Function):
# Keep relevant tensors for backward # Keep relevant tensors for backward
ctx.raster_settings = raster_settings ctx.raster_settings = raster_settings
ctx.instance = instance
ctx.rasterizer_state = rasterizer_state ctx.rasterizer_state = rasterizer_state
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh) ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
return color, radii return color, radii
...@@ -80,12 +85,14 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -80,12 +85,14 @@ class _RasterizeGaussians(torch.autograd.Function):
def backward(ctx, grad_out_color, _): def backward(ctx, grad_out_color, _):
# Restore necessary values from context # Restore necessary values from context
instance = ctx.instance
rasterizer_state = ctx.rasterizer_state rasterizer_state = ctx.rasterizer_state
raster_settings = ctx.raster_settings raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors
# Restructure args as C++ method expects them # Restructure args as C++ method expects them
args = (rasterizer_state, args = (instance,
rasterizer_state,
raster_settings.bg, raster_settings.bg,
means3D, means3D,
radii, radii,
...@@ -107,6 +114,7 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -107,6 +114,7 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args) grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
grads = ( grads = (
None,
grad_means3D, grad_means3D,
grad_means2D, grad_means2D,
grad_sh, grad_sh,
...@@ -134,33 +142,32 @@ class GaussianRasterizationSettings(NamedTuple): ...@@ -134,33 +142,32 @@ class GaussianRasterizationSettings(NamedTuple):
campos : torch.Tensor campos : torch.Tensor
prefiltered : bool prefiltered : bool
def createRasterizerState():
return _C.create_rasterizer_state()
def deleteRasterizerState(state):
return _C.delete_rasterize_state(state)
class GaussianRasterizer(nn.Module): class GaussianRasterizer(nn.Module):
def __init__(self, raster_settings, rasterizer_state): def __init__(self):
super().__init__() super().__init__()
self.raster_settings = raster_settings self.instance = _C.create_rasterizer()
self.rasterizer_state = rasterizer_state
def __del__(self):
_C.delete_rasterizer(self.instance)
def markVisible(self, positions): def createRasterizerState(self):
return _C.create_rasterizer_state(self.instance)
def deleteRasterizerState(self, state):
_C.delete_rasterizer_state(self.instance, state)
def markVisible(self, raster_settings, positions):
# Mark visible points (based on frustum culling for camera) with a boolean # Mark visible points (based on frustum culling for camera) with a boolean
with torch.no_grad(): with torch.no_grad():
raster_settings = self.raster_settings
visible = _C.mark_visible( visible = _C.mark_visible(
self.instance,
positions, positions,
raster_settings.viewmatrix, raster_settings.viewmatrix,
raster_settings.projmatrix) raster_settings.projmatrix)
return visible return visible
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None): def forward(self, rasterizer_state, raster_settings, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
raster_settings = self.raster_settings
rasterize_state = self.rasterizer_state
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None): if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
raise Exception('Please provide excatly one of either SHs or precomputed colors!') raise Exception('Please provide excatly one of either SHs or precomputed colors!')
...@@ -181,6 +188,7 @@ class GaussianRasterizer(nn.Module): ...@@ -181,6 +188,7 @@ class GaussianRasterizer(nn.Module):
# Invoke C++/CUDA rasterization routine # Invoke C++/CUDA rasterization routine
return rasterize_gaussians( return rasterize_gaussians(
self.instance,
means3D, means3D,
means2D, means2D,
shs, shs,
...@@ -190,6 +198,6 @@ class GaussianRasterizer(nn.Module): ...@@ -190,6 +198,6 @@ class GaussianRasterizer(nn.Module):
rotations, rotations,
cov3D_precomp, cov3D_precomp,
raster_settings, raster_settings,
rasterize_state rasterizer_state
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment