Commit e41a365f by Bernhard Kerbl

working with instances

parent 3851457e
......@@ -4,6 +4,7 @@ import torch
from . import _C
def rasterize_gaussians(
instance,
means3D,
means2D,
sh,
......@@ -16,6 +17,7 @@ def rasterize_gaussians(
rasterizer_state
):
return _RasterizeGaussians.apply(
instance,
means3D,
means2D,
sh,
......@@ -32,6 +34,7 @@ class _RasterizeGaussians(torch.autograd.Function):
@staticmethod
def forward(
ctx,
instance,
means3D,
means2D,
sh,
......@@ -46,6 +49,7 @@ class _RasterizeGaussians(torch.autograd.Function):
# Restructure arguments the way that the C++ lib expects them
args = (
instance,
raster_settings.bg,
means3D,
colors_precomp,
......@@ -72,6 +76,7 @@ class _RasterizeGaussians(torch.autograd.Function):
# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
ctx.instance = instance
ctx.rasterizer_state = rasterizer_state
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
return color, radii
......@@ -80,12 +85,14 @@ class _RasterizeGaussians(torch.autograd.Function):
def backward(ctx, grad_out_color, _):
# Restore necessary values from context
instance = ctx.instance
rasterizer_state = ctx.rasterizer_state
raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors
# Restructure args as C++ method expects them
args = (rasterizer_state,
args = (instance,
rasterizer_state,
raster_settings.bg,
means3D,
radii,
......@@ -107,6 +114,7 @@ class _RasterizeGaussians(torch.autograd.Function):
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
grads = (
None,
grad_means3D,
grad_means2D,
grad_sh,
......@@ -134,33 +142,32 @@ class GaussianRasterizationSettings(NamedTuple):
campos : torch.Tensor
prefiltered : bool
def createRasterizerState():
return _C.create_rasterizer_state()
def deleteRasterizerState(state):
return _C.delete_rasterize_state(state)
class GaussianRasterizer(nn.Module):
def __init__(self, raster_settings, rasterizer_state):
def __init__(self):
super().__init__()
self.raster_settings = raster_settings
self.rasterizer_state = rasterizer_state
self.instance = _C.create_rasterizer()
def __del__(self):
_C.delete_rasterizer(self.instance)
def markVisible(self, positions):
def createRasterizerState(self):
return _C.create_rasterizer_state(self.instance)
def deleteRasterizerState(self, state):
_C.delete_rasterizer_state(self.instance, state)
def markVisible(self, raster_settings, positions):
# Mark visible points (based on frustum culling for camera) with a boolean
with torch.no_grad():
raster_settings = self.raster_settings
visible = _C.mark_visible(
self.instance,
positions,
raster_settings.viewmatrix,
raster_settings.projmatrix)
return visible
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
raster_settings = self.raster_settings
rasterize_state = self.rasterizer_state
def forward(self, rasterizer_state, raster_settings, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
raise Exception('Please provide excatly one of either SHs or precomputed colors!')
......@@ -181,6 +188,7 @@ class GaussianRasterizer(nn.Module):
# Invoke C++/CUDA rasterization routine
return rasterize_gaussians(
self.instance,
means3D,
means2D,
shs,
......@@ -190,6 +198,6 @@ class GaussianRasterizer(nn.Module):
rotations,
cov3D_precomp,
raster_settings,
rasterize_state
rasterizer_state
)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment