Commit 64f7c223 by Bernhard Kerbl

work with state

parent bccbb2e8
...@@ -41,6 +41,8 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -41,6 +41,8 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings, raster_settings,
): ):
rasterizer_state = _C.create_rasterizer_state()
# Restructure arguments the way that the C++ lib expects them # Restructure arguments the way that the C++ lib expects them
args = ( args = (
raster_settings.bg, raster_settings.bg,
...@@ -61,6 +63,7 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -61,6 +63,7 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings.sh_degree, raster_settings.sh_degree,
raster_settings.campos, raster_settings.campos,
raster_settings.prefiltered, raster_settings.prefiltered,
rasterizer_state
) )
# Invoke C++/CUDA rasterizer # Invoke C++/CUDA rasterizer
...@@ -68,6 +71,7 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -68,6 +71,7 @@ class _RasterizeGaussians(torch.autograd.Function):
# Keep relevant tensors for backward # Keep relevant tensors for backward
ctx.raster_settings = raster_settings ctx.raster_settings = raster_settings
ctx.rasterizer_state = rasterizer_state
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh) ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh)
return color, radii return color, radii
...@@ -75,11 +79,13 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -75,11 +79,13 @@ class _RasterizeGaussians(torch.autograd.Function):
def backward(ctx, grad_out_color, _): def backward(ctx, grad_out_color, _):
# Restore necessary values from context # Restore necessary values from context
rasterizer_state = ctx.rasterizer_state
raster_settings = ctx.raster_settings raster_settings = ctx.raster_settings
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh = ctx.saved_tensors
# Restructure args as C++ method expects them # Restructure args as C++ method expects them
args = (raster_settings.bg, args = (rasterizer_state,
raster_settings.bg,
means3D, means3D,
radii, radii,
colors_precomp, colors_precomp,
...@@ -111,6 +117,8 @@ class _RasterizeGaussians(torch.autograd.Function): ...@@ -111,6 +117,8 @@ class _RasterizeGaussians(torch.autograd.Function):
None, None,
) )
_C.delete_rasterizer_state(rasterizer_state)
return grads return grads
class GaussianRasterizationSettings(NamedTuple): class GaussianRasterizationSettings(NamedTuple):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment