Commit 64f7c223 by Bernhard Kerbl

work with state

parent bccbb2e8
......@@ -41,6 +41,8 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings,
):
rasterizer_state = _C.create_rasterizer_state()
# Restructure arguments the way that the C++ lib expects them
args = (
raster_settings.bg,
......@@ -61,6 +63,7 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings.sh_degree,
raster_settings.campos,
raster_settings.prefiltered,
rasterizer_state
)
# Invoke C++/CUDA rasterizer
......@@ -68,6 +71,7 @@ class _RasterizeGaussians(torch.autograd.Function):
# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
ctx.rasterizer_state = rasterizer_state
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
return color, radii
......@@ -75,11 +79,13 @@ class _RasterizeGaussians(torch.autograd.Function):
def backward(ctx, grad_out_color, _):
# Restore necessary values from context
rasterizer_state = ctx.rasterizer_state
raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors
# Restructure args as C++ method expects them
args = (raster_settings.bg,
args = (rasterizer_state,
raster_settings.bg,
means3D,
radii,
colors_precomp,
......@@ -111,6 +117,8 @@ class _RasterizeGaussians(torch.autograd.Function):
None,
)
_C.delete_rasterizer_state(rasterizer_state)
return grads
class GaussianRasterizationSettings(NamedTuple):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment