torch impl working
Ruilong Li committed Jul 16, 2024
1 parent 7562a39 commit 76ca887
Showing 18 changed files with 905 additions and 464 deletions.
113 changes: 110 additions & 3 deletions examples/simple_viewer.py
@@ -17,8 +17,10 @@
import torch
import torch.nn.functional as F
import viser
from torch import Tensor

from gsplat._helper import load_test_data
from gsplat.rendering import rasterization
from gsplat.rendering import _rasterization, rasterization

parser = argparse.ArgumentParser()
parser.add_argument(
@@ -38,6 +40,94 @@
torch.manual_seed(42)
device = "cuda"


def getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"):
tanHalfFovY = math.tan((fovY / 2))
tanHalfFovX = math.tan((fovX / 2))

top = tanHalfFovY * znear
bottom = -top
right = tanHalfFovX * znear
left = -right

P = torch.zeros(4, 4, device=device)

z_sign = 1.0

P[0, 0] = 2.0 * znear / (right - left)
P[1, 1] = 2.0 * znear / (top - bottom)
P[0, 2] = (right + left) / (right - left)
P[1, 2] = (top + bottom) / (top - bottom)
P[3, 2] = z_sign
P[2, 2] = z_sign * zfar / (zfar - znear)
P[2, 3] = -(zfar * znear) / (zfar - znear)
return P

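For reference, with left = -right and bottom = -top as set above, the matrix assembled by getProjectionMatrix reduces to the following (a sketch of the math implied by the code; unlike the OpenGL convention it looks down +z and maps depth to [0, 1]):

$$
P =
\begin{bmatrix}
n/r & 0 & 0 & 0 \\
0 & n/t & 0 & 0 \\
0 & 0 & \frac{f}{f-n} & -\frac{fn}{f-n} \\
0 & 0 & 1 & 0
\end{bmatrix},
\qquad
r = n \tan\frac{\mathrm{fovX}}{2}, \quad
t = n \tan\frac{\mathrm{fovY}}{2},
$$

with n = znear and f = zfar.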

def _depths_to_points(depthmap, world_view_transform, full_proj_transform):
c2w = (world_view_transform.T).inverse()
H, W = depthmap.shape[:2]
ndc2pix = (
torch.tensor([[W / 2, 0, 0, (W) / 2], [0, H / 2, 0, (H) / 2], [0, 0, 0, 1]])
.float()
.cuda()
.T
)
projection_matrix = c2w.T @ full_proj_transform
intrins = (projection_matrix @ ndc2pix)[:3, :3].T

grid_x, grid_y = torch.meshgrid(
torch.arange(W, device="cuda").float(),
torch.arange(H, device="cuda").float(),
indexing="xy",
)
points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(
-1, 3
)
rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T
rays_o = c2w[:3, 3]
points = depthmap.reshape(-1, 1) * rays_d + rays_o
return points

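In effect, _depths_to_points back-projects every pixel along its camera ray, with the pixel-space intrinsics K recovered from the full projection transform via ndc2pix (a reading of the code above):

$$
X(u, v) = o + d(u, v) \, R_{c2w} \, K^{-1} [u, v, 1]^\top,
$$

where o and R_{c2w} are the camera origin and rotation taken from c2w, and d(u, v) is the depth-map value at the pixel.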

def _depth_to_normal(depth, world_view_transform, full_proj_transform):
points = _depths_to_points(
depth, world_view_transform, full_proj_transform
).reshape(*depth.shape[:2], 3)
output = torch.zeros_like(points)
dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0)
dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1)
normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
output[1:-1, 1:-1, :] = normal_map
return output

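The normal at each interior pixel is then the normalized cross product of central differences of that point map along the two image axes, with the one-pixel border left at zero:

$$
n = \operatorname{normalize}\left( \Delta_{\mathrm{row}} P \times \Delta_{\mathrm{col}} P \right).
$$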

def depth_to_normal(
depths: Tensor, # [C, H, W, 1]
viewmats: Tensor, # [C, 4, 4]
Ks: Tensor, # [C, 3, 3]
near_plane: float = 0.01,
far_plane: float = 1e10,
) -> Tensor:
height, width = depths.shape[1:3]

normals = []
for cid, depth in enumerate(depths):
FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item()))
FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item()))
world_view_transform = viewmats[cid].transpose(0, 1)
projection_matrix = getProjectionMatrix(
znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device
).transpose(0, 1)
full_proj_transform = (
world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
).squeeze(0)
normal = _depth_to_normal(depth, world_view_transform, full_proj_transform)
normals.append(normal)
normals = torch.stack(normals, dim=0)
return normals

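A quick smoke test for the helper above (hypothetical shapes and intrinsics, not part of this commit):

# Hypothetical check: 1 camera, a 32x48 depth map, identity world-to-camera.
_C, _H, _W = 1, 32, 48
_depths = torch.rand(_C, _H, _W, 1, device=device)
_viewmats = torch.eye(4, device=device)[None]  # [C, 4, 4]
_Ks = torch.tensor(
    [[[40.0, 0.0, _W / 2], [0.0, 40.0, _H / 2], [0.0, 0.0, 1.0]]], device=device
)  # [C, 3, 3]
_normals = depth_to_normal(_depths, _viewmats, _Ks)
assert _normals.shape == (_C, _H, _W, 3)  # zero on the one-pixel border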

if args.ckpt is None:
(
means,
@@ -55,8 +145,20 @@
N = len(means)
print("Number of Gaussians:", N)

ckpt = torch.load("results/garden/ckpts/ckpt_6999.pt", map_location=device)[
"splats"
]
means = ckpt["means3d"]
quats = F.normalize(ckpt["quats"], p=2, dim=-1)
scales = torch.exp(ckpt["scales"])
opacities = torch.sigmoid(ckpt["opacities"])
sh0 = ckpt["sh0"]
shN = ckpt["shN"]
colors = torch.cat([sh0, shN], dim=-2)
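# A degree-d SH expansion has (d + 1) ** 2 coefficients per channel, so the
# degree can be recovered from the coefficient count: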
sh_degree = int(math.sqrt(colors.shape[-2]) - 1)

# batched render
render_colors, render_alphas, meta = rasterization(
render_colors, render_alphas, meta = _rasterization(
means, # [N, 3]
quats, # [N, 4]
scales, # [N, 3]
@@ -66,13 +168,17 @@
Ks, # [C, 3, 3]
width,
height,
render_mode="RGB+D",
render_mode="RGB+ED",
sh_degree=sh_degree,
accurate_depth=False,
)
assert render_colors.shape == (C, height, width, 4)
assert render_alphas.shape == (C, height, width, 1)

render_rgbs = render_colors[..., 0:3]
render_depths = render_colors[..., 3:4]
render_normals = depth_to_normal(render_depths, viewmats, Ks)
render_normals = render_normals * 0.5 + 0.5 # [-1, 1] -> [0, 1]
render_depths = render_depths / render_depths.max()

# dump batch images
@@ -82,6 +188,7 @@
[
render_rgbs.reshape(C * height, width, 3),
render_depths.reshape(C * height, width, 1).expand(-1, -1, 3),
render_normals.reshape(C * height, width, 3),
render_alphas.reshape(C * height, width, 1).expand(-1, -1, 3),
],
dim=1,
113 changes: 77 additions & 36 deletions gsplat/cuda/_torch_impl.py
@@ -60,6 +60,7 @@ def _persp_proj(
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
reduce_z: bool = False,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of prespective projection for 3D Gaussians.
@@ -69,12 +70,13 @@
Ks: Camera intrinsics. [C, 3, 3].
width: Image width.
height: Image height.
reduce_z: Whether to drop the z components and return only the 2D means and covariances.
Returns:
A tuple:
- **means2d**: Projected means. [C, N, 2].
- **cov2d**: Projected covariances. [C, N, 2, 2].
- **means2d**: Projected means. [C, N, 2] if `reduce_z=True`, otherwise [C, N, 3].
- **cov2d**: Projected covariances. [C, N, 2, 2] if `reduce_z=True`, otherwise [C, N, 3, 3].
"""
C, N, _ = means.shape

@@ -92,14 +94,21 @@
ty = tz * torch.clamp(ty / tz, min=-lim_y, max=lim_y)

O = torch.zeros((C, N), device=means.device, dtype=means.dtype)
I = torch.ones((C, N), device=means.device, dtype=means.dtype)
J = torch.stack(
[fx / tz, O, -fx * tx / tz2, O, fy / tz, -fy * ty / tz2], dim=-1
).reshape(C, N, 2, 3)
[fx / tz, O, -fx * tx / tz2, O, fy / tz, -fy * ty / tz2, O, O, I], dim=-1
).reshape(C, N, 3, 3)

cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2))
means2d = torch.einsum("cij,cnj->cni", Ks[:, :2, :3], means) # [C, N, 2]
means2d = means2d / tz[..., None] # [C, N, 2]
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]
means2d = torch.einsum("cij,cnj->cni", Ks[:, :3, :3], means) # [C, N, 3]
if reduce_z:
cov2d = cov2d[..., :2, :2]
means2d = means2d[..., :2] / means2d[..., 2:] # [C, N, 2]
else:
means2d = torch.cat(
[means2d[..., :2] / means2d[..., 2:], means2d[..., 2:]], dim=-1
) # [C, N, 3]
return means2d, cov2d

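With reduce_z=False, the usual 2x3 projection Jacobian gains a third row that passes camera-space depth through unchanged, so cov2d carries depth variances and depth-image covariances alongside the usual 2D terms:

$$
J =
\begin{bmatrix}
f_x / t_z & 0 & -f_x t_x / t_z^2 \\
0 & f_y / t_z & -f_y t_y / t_z^2 \\
0 & 0 & 1
\end{bmatrix},
\qquad
\mathrm{cov2d} = J \, \Sigma_{\mathrm{cam}} \, J^\top.
$$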

def _world_to_cam(
@@ -129,7 +138,7 @@ def _world_to_cam(

def _fully_fused_projection(
means: Tensor, # [N, 3]
covars: Tensor, # [N, 3, 3]
covars: Tensor, # [N, 3, 3] or [N, 6]
viewmats: Tensor, # [C, 4, 4]
Ks: Tensor, # [C, 3, 3]
width: int,
Expand All @@ -138,6 +147,7 @@ def _fully_fused_projection(
near_plane: float = 0.01,
far_plane: float = 1e10,
calc_compensations: bool = False,
triu: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()`
@@ -146,34 +156,43 @@
This is a minimal implementation of the fully fused version, which has more
arguments. Not all arguments are supported.
"""
if triu:
covars = torch.stack(
[
covars[..., 0],
covars[..., 1],
covars[..., 2],
covars[..., 1],
covars[..., 3],
covars[..., 4],
covars[..., 2],
covars[..., 4],
covars[..., 5],
],
dim=-1,
).reshape(
-1, 3, 3
) # [N, 3, 3]

means_c, covars_c = _world_to_cam(means, covars, viewmats)
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)
det_orig = (
covars2d[..., 0, 0] * covars2d[..., 1, 1]
- covars2d[..., 0, 1] * covars2d[..., 1, 0]
means2d, covars2d = _persp_proj(
means_c, covars_c, Ks, width, height, reduce_z=False
)
covars2d = covars2d + torch.eye(2, device=means.device, dtype=means.dtype) * eps2d
det_orig = torch.det(covars2d) # [C, N]

det = (
covars2d[..., 0, 0] * covars2d[..., 1, 1]
- covars2d[..., 0, 1] * covars2d[..., 1, 0]
eps = torch.tensor(
[[eps2d, 0.0, 0.0], [0.0, eps2d, 0.0], [0.0, 0.0, eps2d]], device=means.device
)
covars2d = covars2d + eps

det = torch.det(covars2d) # [C, N]
det = det.clamp(min=1e-10)

if calc_compensations:
compensations = torch.sqrt(torch.clamp(det_orig / det, min=0.0))
else:
compensations = None

conics = torch.stack(
[
covars2d[..., 1, 1] / det,
-(covars2d[..., 0, 1] + covars2d[..., 1, 0]) / 2.0 / det,
covars2d[..., 0, 0] / det,
],
dim=-1,
) # [C, N, 3]

depths = means_c[..., 2] # [C, N]

b = (covars2d[..., 0, 0] + covars2d[..., 1, 1]) / 2 # (...,)
@@ -194,7 +213,14 @@
radius[~inside] = 0.0

radii = radius.int()
return radii, means2d, depths, conics, compensations
conics = torch.inverse(covars2d) # [C, N, 3, 3]

if triu:
conics = conics.reshape(*conics.shape[:-2], 9)
conics = (
conics[..., [0, 1, 2, 4, 5, 8]] + conics[..., [0, 3, 6, 4, 7, 8]]
) / 2.0
return radii, means2d, conics, compensations

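The index pairs [0, 1, 2, 4, 5, 8] and [0, 3, 6, 4, 7, 8] average each entry of the flattened 3x3 inverse with its transpose counterpart, packing the symmetric conic as its upper triangle (xx, xy, xz, yy, yz, zz). A standalone round-trip sketch of the packing (not part of the commit):

import torch

S = torch.randn(5, 3, 3)
S = S @ S.transpose(-1, -2) + torch.eye(3)  # batch of SPD matrices
flat = torch.inverse(S).reshape(-1, 9)
triu6 = (flat[..., [0, 1, 2, 4, 5, 8]] + flat[..., [0, 3, 6, 4, 7, 8]]) / 2.0
# Unpacking mirrors the `triu` branch at the top of _fully_fused_projection.
full = triu6[..., [0, 1, 2, 1, 3, 4, 2, 4, 5]].reshape(-1, 3, 3)
assert torch.allclose(full, torch.inverse(S), atol=1e-5)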

@torch.no_grad()
@@ -300,8 +326,8 @@ def _isect_offset_encode(


def accumulate(
means2d: Tensor, # [C, N, 2]
conics: Tensor, # [C, N, 3]
means2d: Tensor, # [C, N, 3]
conics: Tensor, # [C, N, 6]
opacities: Tensor, # [C, N]
colors: Tensor, # [C, N, channels]
gaussian_ids: Tensor, # [M]
@@ -360,10 +386,10 @@ def accumulate(
pixel_ids_x = pixel_ids % image_width
pixel_ids_y = pixel_ids // image_width
pixel_coords = torch.stack([pixel_ids_x, pixel_ids_y], dim=-1) + 0.5 # [M, 2]
deltas = pixel_coords - means2d[camera_ids, gaussian_ids] # [M, 2]
c = conics[camera_ids, gaussian_ids] # [M, 3]
deltas = pixel_coords - means2d[camera_ids, gaussian_ids, :2] # [M, 2]
c = conics[camera_ids, gaussian_ids] # [M, 6]
sigmas = (
0.5 * (c[:, 0] * deltas[:, 0] ** 2 + c[:, 2] * deltas[:, 1] ** 2)
0.5 * (c[:, 0] * deltas[:, 0] ** 2 + c[:, 3] * deltas[:, 1] ** 2)
+ c[:, 1] * deltas[:, 0] * deltas[:, 1]
) # [M]
alphas = torch.clamp_max(
@@ -373,6 +399,16 @@
indices = camera_ids * image_height * image_width + pixel_ids
total_pixels = C * image_height * image_width

# calculate depths
mu = means2d[camera_ids, gaussian_ids] # [M, 3]
o = torch.cat(
[pixel_coords, torch.zeros_like(pixel_ids_x)[..., None]], dim=-1
) # [M, 3]
A = c[:, -1] # [M] conic (z, z) entry
B = torch.einsum("...i,...i->...", c[:, [2, 4, 5]], (mu - o)) # [M]
D = B / A # [M]
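    # Reading of the math above: with the conic (inverse covariance) expressed in
    # (pixel_x, pixel_y, z) coordinates, the Gaussian density along a pixel's ray
    # peaks where conic_xz * dx + conic_yz * dy + conic_zz * (z - mu_z) vanishes,
    # i.e. at z = B / A. D is therefore the depth of each Gaussian's density peak
    # at that pixel, and is alpha-composited below with the same weights as color.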

# alpha compositing
weights, trans = render_weight_from_alpha(
alphas, ray_indices=indices, n_rays=total_pixels
)
@@ -385,13 +421,16 @@
alphas = accumulate_along_rays(
weights, None, ray_indices=indices, n_rays=total_pixels
).reshape(C, image_height, image_width, 1)
depths = accumulate_along_rays(
weights, D[..., None], ray_indices=indices, n_rays=total_pixels
).reshape(C, image_height, image_width, 1)

return renders, alphas
return renders, alphas, depths


def _rasterize_to_pixels(
means2d: Tensor, # [C, N, 2]
conics: Tensor, # [C, N, 3]
means2d: Tensor, # [C, N, 3]
conics: Tensor, # [C, N, 6]
colors: Tensor, # [C, N, channels]
opacities: Tensor, # [C, N]
image_width: int,
@@ -434,6 +473,7 @@
(C, image_height, image_width, colors.shape[-1]), device=device
)
render_alphas = torch.zeros((C, image_height, image_width, 1), device=device)
render_depths = torch.zeros((C, image_height, image_width, 1), device=device)

# Split Gaussians into batches and iteratively accumulate the renderings
block_size = tile_size * tile_size
@@ -464,7 +504,7 @@
break

# Accumulate the renderings within this batch of Gaussians.
renders_step, accs_step = accumulate(
renders_step, accs_step, depths_step = accumulate(
means2d,
conics,
opacities,
Expand All @@ -477,14 +517,15 @@ def _rasterize_to_pixels(
)
render_colors = render_colors + renders_step * transmittances[..., None]
render_alphas = render_alphas + accs_step * transmittances[..., None]
render_depths = render_depths + depths_step * transmittances[..., None]

render_alphas = render_alphas
if backgrounds is not None:
render_colors = render_colors + backgrounds[:, None, None, :] * (
1.0 - render_alphas
)

return render_colors, render_alphas
return render_colors, render_alphas, render_depths


def _eval_sh_bases_fast(basis_dim: int, dirs: Tensor):