Skip to content

Commit

Permalink
rasterize_to_pixels_fwd_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Aug 7, 2024
1 parent 76ca887 commit 7c78f4b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
6 changes: 4 additions & 2 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ torch::Tensor isect_offset_encode_tensor(const torch::Tensor &isect_ids, // [n_i
const uint32_t C, const uint32_t tile_width,
const uint32_t tile_height);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_tensor(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
rasterize_to_pixels_fwd_tensor(
// Gaussian parameters
const torch::Tensor &means2d, // [C, N, 2]
const torch::Tensor &conics, // [C, N, 3]
Expand All @@ -122,7 +123,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
const torch::Tensor &flatten_ids, // [n_isects]
const bool calc_depth // whether to calculate depth
);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand Down
44 changes: 34 additions & 10 deletions gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ __global__ void rasterize_to_pixels_fwd_kernel(
const int32_t *__restrict__ flatten_ids, // [n_isects]
S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM]
S *__restrict__ render_alphas, // [C, image_height, image_width, 1]
S *__restrict__ render_depths, // [C, image_height, image_width, 1] optional
int32_t *__restrict__ last_ids // [C, image_height, image_width]
) {
// each thread draws one pixel, but also timeshares caching gaussians in a
Expand All @@ -39,6 +40,9 @@ __global__ void rasterize_to_pixels_fwd_kernel(
tile_offsets += camera_id * tile_height * tile_width;
render_colors += camera_id * image_height * image_width * COLOR_DIM;
render_alphas += camera_id * image_height * image_width;
if (render_depths != nullptr) {
render_depths += camera_id * image_height * image_width;
}
last_ids += camera_id * image_height * image_width;
if (backgrounds != nullptr) {
backgrounds += camera_id * COLOR_DIM;
Expand Down Expand Up @@ -84,6 +88,7 @@ __global__ void rasterize_to_pixels_fwd_kernel(
uint32_t tr = block.thread_rank();

S pix_out[COLOR_DIM] = {0.f};
S depth_out = 0.f;
for (uint32_t b = 0; b < num_batches; ++b) {
// resync all threads before beginning next batch
// end early if entire tile is done
Expand Down Expand Up @@ -140,10 +145,18 @@ __global__ void rasterize_to_pixels_fwd_kernel(
int32_t g = id_batch[t];
const S vis = alpha * T;
const S *c_ptr = colors + g * COLOR_DIM;
// accumulate color
PRAGMA_UNROLL
for (uint32_t k = 0; k < COLOR_DIM; ++k) {
pix_out[k] += c_ptr[k] * vis;
}
// accumulate depth
if (render_depths != nullptr) {
S depth =
mean2d.z +
(conic02 * (mean2d.x - px) + conic12 * (mean2d.y - py)) / conic22;
depth_out += depth * vis;
}

This comment has been minimized.

Copy link
@mateosss

mateosss Aug 20, 2024

Does this depth formulation from a gaussian have any reference in the literature?

This comment has been minimized.

Copy link
@mateosss

This comment has been minimized.

Copy link
@liruilong940607

liruilong940607 Aug 20, 2024

Collaborator

It's motivated by https://arxiv.org/abs/2406.01467, but slightly different: their formulation calculates t-depth, while here we are calculating z-depth. @mateosss

cur_idx = batch_start + t;

T = next_T;
Expand All @@ -162,13 +175,17 @@ __global__ void rasterize_to_pixels_fwd_kernel(
render_colors[pix_id * COLOR_DIM + k] =
backgrounds == nullptr ? pix_out[k] : (pix_out[k] + T * backgrounds[k]);
}
if (render_depths != nullptr) {
render_depths[pix_id] = depth_out;
}
// index in bin of last gaussian in this pixel
last_ids[pix_id] = static_cast<int32_t>(cur_idx);
}
}

template <uint32_t CDIM>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
call_kernel_with_dim(
// Gaussian parameters
const torch::Tensor &means2d, // [C, N, 3] or [nnz, 3]
const torch::Tensor &conics, // [C, N, 6] or [nnz, 6]
Expand All @@ -179,8 +196,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
) {
const torch::Tensor &flatten_ids, // [n_isects]
bool calc_depth) {
DEVICE_GUARD(means2d);
CHECK_INPUT(means2d);
CHECK_INPUT(conics);
Expand Down Expand Up @@ -209,6 +226,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
means2d.options().dtype(torch::kFloat32));
torch::Tensor alphas = torch::empty({C, image_height, image_width, 1},
means2d.options().dtype(torch::kFloat32));
torch::Tensor depths;
if (calc_depth) {
depths = torch::empty({C, image_height, image_width, 1},
means2d.options().dtype(torch::kFloat32));
}
torch::Tensor last_ids = torch::empty({C, image_height, image_width},
means2d.options().dtype(torch::kInt32));

Expand Down Expand Up @@ -236,12 +258,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
image_width, image_height, tile_size, tile_width, tile_height,
tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(),
renders.data_ptr<float>(), alphas.data_ptr<float>(),
calc_depth ? depths.data_ptr<float>() : nullptr,
last_ids.data_ptr<int32_t>());

return std::make_tuple(renders, alphas, last_ids);
return std::make_tuple(renders, alphas, depths, last_ids);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_tensor(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
rasterize_to_pixels_fwd_tensor(
// Gaussian parameters
const torch::Tensor &means2d, // [C, N, 3] or [nnz, 3]
const torch::Tensor &conics, // [C, N, 6] or [nnz, 6]
Expand All @@ -252,16 +276,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
) {
const torch::Tensor &flatten_ids, // [n_isects]
bool calc_depth) {
CHECK_INPUT(colors);
uint32_t channels = colors.size(-1);

#define __GS__CALL_(N) \
case N: \
return call_kernel_with_dim<N>(means2d, conics, colors, opacities, \
backgrounds, image_width, image_height, \
tile_size, tile_offsets, flatten_ids);
return call_kernel_with_dim<N>( \
means2d, conics, colors, opacities, backgrounds, image_width, \
image_height, tile_size, tile_offsets, flatten_ids, calc_depth);

// TODO: an optimization can be done by passing the actual number of channels into
// the kernel functions and avoid unnecessary global memory writes. This requires
Expand Down

0 comments on commit 7c78f4b

Please sign in to comment.