Commit 73917be7 by bkerbl

Debug functionality

parent f6f13c68
......@@ -163,4 +163,13 @@ __forceinline__ __device__ bool in_frustum(int idx,
return true;
}
#define CHECK_CUDA(A, debug) \
A; if(debug) { \
auto ret = cudaDeviceSynchronize(); \
if (ret != cudaSuccess) { \
std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
throw std::runtime_error(cudaGetErrorString(ret)); \
} \
}
#endif
\ No newline at end of file
......@@ -49,7 +49,8 @@ namespace CudaRasterizer
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii = nullptr);
int* radii = nullptr,
bool debug = false);
static void backward(
const int P, int D, int M, int R,
......@@ -79,7 +80,8 @@ namespace CudaRasterizer
float* dL_dcov3D,
float* dL_dsh,
float* dL_dscale,
float* dL_drot);
float* dL_drot,
bool debug);
};
};
......
......@@ -216,7 +216,8 @@ int CudaRasterizer::Rasterizer::forward(
const float tan_fovx, float tan_fovy,
const bool prefiltered,
float* out_color,
int* radii)
int* radii,
bool debug)
{
const float focal_y = height / (2.0f * tan_fovy);
const float focal_x = width / (2.0f * tan_fovx);
......@@ -244,7 +245,7 @@ int CudaRasterizer::Rasterizer::forward(
}
// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
FORWARD::preprocess(
CHECK_CUDA(FORWARD::preprocess(
P, D, M,
means3D,
(glm::vec3*)scales,
......@@ -269,16 +270,15 @@ int CudaRasterizer::Rasterizer::forward(
tile_grid,
geomState.tiles_touched,
prefiltered
);
), debug)
// Compute prefix sum over full list of touched tile counts by Gaussians
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size,
geomState.tiles_touched, geomState.point_offsets, P);
CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug)
// Retrieve total number of Gaussian instances to launch and resize aux buffers
int num_rendered;
cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost);
CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
size_t binning_chunk_size = required<BinningState>(num_rendered);
char* binning_chunkptr = binningBuffer(binning_chunk_size);
......@@ -294,32 +294,32 @@ int CudaRasterizer::Rasterizer::forward(
binningState.point_list_keys_unsorted,
binningState.point_list_unsorted,
radii,
tile_grid
);
tile_grid)
CHECK_CUDA(, debug)
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
// Sort complete list of (duplicated) Gaussian indices by keys
cub::DeviceRadixSort::SortPairs(
CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
binningState.list_sorting_space,
binningState.sorting_size,
binningState.point_list_keys_unsorted, binningState.point_list_keys,
binningState.point_list_unsorted, binningState.point_list,
num_rendered, 0, 32 + bit);
num_rendered, 0, 32 + bit), debug)
cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2));
CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug);
// Identify start and end of per-tile workloads in sorted list
if (num_rendered > 0)
identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
num_rendered,
binningState.point_list_keys,
imgState.ranges
);
imgState.ranges);
CHECK_CUDA(, debug)
// Let each tile blend its range of Gaussians independently in parallel
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
FORWARD::render(
CHECK_CUDA(FORWARD::render(
tile_grid, block,
imgState.ranges,
binningState.point_list,
......@@ -330,7 +330,7 @@ int CudaRasterizer::Rasterizer::forward(
imgState.accum_alpha,
imgState.n_contrib,
background,
out_color);
out_color), debug)
return num_rendered;
}
......@@ -365,7 +365,8 @@ void CudaRasterizer::Rasterizer::backward(
float* dL_dcov3D,
float* dL_dsh,
float* dL_dscale,
float* dL_drot)
float* dL_drot,
bool debug)
{
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
......@@ -386,7 +387,7 @@ void CudaRasterizer::Rasterizer::backward(
// opacity and RGB of Gaussians from per-pixel loss gradients.
// If we were given precomputed colors and not SHs, use them.
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
BACKWARD::render(
CHECK_CUDA(BACKWARD::render(
tile_grid,
block,
imgState.ranges,
......@@ -402,13 +403,13 @@ void CudaRasterizer::Rasterizer::backward(
(float3*)dL_dmean2D,
(float4*)dL_dconic,
dL_dopacity,
dL_dcolor);
dL_dcolor), debug)
// Take care of the rest of preprocessing. Was the precomputed covariance
// given to us or a scales/rot pair? If precomputed, pass that. If not,
// use the one we computed ourselves.
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
BACKWARD::preprocess(P, D, M,
CHECK_CUDA(BACKWARD::preprocess(P, D, M,
(float3*)means3D,
radii,
shs,
......@@ -429,5 +430,5 @@ void CudaRasterizer::Rasterizer::backward(
dL_dcov3D,
dL_dsh,
(glm::vec3*)dL_dscale,
(glm::vec4*)dL_drot);
(glm::vec4*)dL_drot), debug)
}
\ No newline at end of file
......@@ -14,6 +14,10 @@ import torch.nn as nn
import torch
from . import _C
def cpu_deep_copy_tuple(input_tuple):
copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
return tuple(copied_tensors)
def rasterize_gaussians(
means3D,
means2D,
......@@ -72,10 +76,20 @@ class _RasterizeGaussians(torch.autograd.Function):
raster_settings.sh_degree,
raster_settings.campos,
raster_settings.prefiltered,
raster_settings.debug
)
# Invoke C++/CUDA rasterizer
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
if raster_settings.debug:
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
try:
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
except Exception as ex:
torch.save(cpu_args, "snapshot_fw.dump")
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
raise ex
else:
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
# Keep relevant tensors for backward
ctx.raster_settings = raster_settings
......@@ -111,10 +125,20 @@ class _RasterizeGaussians(torch.autograd.Function):
geomBuffer,
num_rendered,
binningBuffer,
imgBuffer)
imgBuffer,
raster_settings.debug)
# Compute gradients for relevant tensors by invoking backward method
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
if raster_settings.debug:
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
try:
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
except Exception as ex:
print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
torch.save(cpu_args, "snapshot_bw.dump")
raise ex
else:
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
grads = (
grad_means3D,
......@@ -142,6 +166,7 @@ class GaussianRasterizationSettings(NamedTuple):
sh_degree : int
campos : torch.Tensor
prefiltered : bool
debug : bool
class GaussianRasterizer(nn.Module):
def __init__(self, raster_settings):
......
......@@ -51,7 +51,8 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
const bool prefiltered)
const bool prefiltered,
const bool debug)
{
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
AT_ERROR("means3D must have dimensions (num_points, 3)");
......@@ -107,7 +108,8 @@ RasterizeGaussiansCUDA(
tan_fovy,
prefiltered,
out_color.contiguous().data<float>(),
radii.contiguous().data<int>());
radii.contiguous().data<int>(),
debug);
}
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
}
......@@ -133,7 +135,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer)
const torch::Tensor& imageBuffer,
const bool debug)
{
const int P = means3D.size(0);
const int H = dL_dout_color.size(1);
......@@ -185,7 +188,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
dL_dcov3D.contiguous().data<float>(),
dL_dsh.contiguous().data<float>(),
dL_dscales.contiguous().data<float>(),
dL_drotations.contiguous().data<float>());
dL_drotations.contiguous().data<float>(),
debug);
}
return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
......
......@@ -34,7 +34,8 @@ RasterizeGaussiansCUDA(
const torch::Tensor& sh,
const int degree,
const torch::Tensor& campos,
const bool prefiltered);
const bool prefiltered,
const bool debug);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
......@@ -57,7 +58,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
const torch::Tensor& geomBuffer,
const int R,
const torch::Tensor& binningBuffer,
const torch::Tensor& imageBuffer);
const torch::Tensor& imageBuffer,
const bool debug);
torch::Tensor markVisible(
torch::Tensor& means3D,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment