Fix some typing problems #1940

Merged: 11 commits, May 18, 2023
6 changes: 3 additions & 3 deletions nerfstudio/cameras/rays.py
@@ -152,14 +152,14 @@ def get_weights(self, densities: Float[Tensor, "*batch num_samples 1"]) -> Float
     @overload
     @staticmethod
     def get_weights_and_transmittance_from_alphas(
-        alphas: Float[Tensor, "*batch num_samples 1"], weights_only: Literal[False] = False
-    ) -> Tuple[Float[Tensor, "*batch num_samples 1"], Float[Tensor, "*batch num_samples 1"]]:
+        alphas: Float[Tensor, "*batch num_samples 1"], weights_only: Literal[True]
+    ) -> Float[Tensor, "*batch num_samples 1"]:
         ...

     @overload
     @staticmethod
     def get_weights_and_transmittance_from_alphas(
-        alphas: Float[Tensor, "*batch num_samples 1"], weights_only: Literal[True]
+        alphas: Float[Tensor, "*batch num_samples 1"], weights_only: Literal[False] = False
     ) -> Tuple[Float[Tensor, "*batch num_samples 1"], Float[Tensor, "*batch num_samples 1"]]:
         ...
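Note: for readers unfamiliar with the pattern above, `@overload` signatures keyed on a `Literal` flag let the type checker pick the precise return type at each call site — here, `weights_only=True` returns a single tensor instead of a tuple. A minimal, self-contained sketch (illustrative names, not nerfstudio code):

from typing import List, Literal, Tuple, Union, overload

@overload
def moments(xs: List[float], mean_only: Literal[True]) -> float:
    ...

@overload
def moments(xs: List[float], mean_only: Literal[False] = False) -> Tuple[float, float]:
    ...

def moments(xs: List[float], mean_only: bool = False) -> Union[float, Tuple[float, float]]:
    """Return the mean, or (mean, variance), of a sequence of numbers."""
    mean = sum(xs) / len(xs)
    if mean_only:
        return mean
    variance = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, variance

m = moments([1.0, 2.0], mean_only=True)  # checker infers float
mv = moments([1.0, 2.0])                 # checker infers Tuple[float, float]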
18 changes: 12 additions & 6 deletions nerfstudio/engine/callbacks.py
@@ -17,10 +17,17 @@
 """
 from __future__ import annotations
 
-from dataclasses import InitVar, dataclass
+from dataclasses import dataclass
 from enum import Enum, auto
 from inspect import signature
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+
+from torch.cuda.amp.grad_scaler import GradScaler
+
+from nerfstudio.engine.optimizers import Optimizers
+
+if TYPE_CHECKING:
+    from nerfstudio.pipelines.base_pipeline import Pipeline
 
 
 @dataclass
@@ -30,12 +37,11 @@ class TrainingCallbackAttributes:
     Instead of providing access to the entire Trainer object, we only provide these attributes.
     This should be least prone to errors and fairly clean from a user perspective."""
 
-    # TODO(ethan): type this without circular imports
-    optimizers: Optional[InitVar]
+    optimizers: Optional[Optimizers]
     """optimizers for training"""
-    grad_scaler: Optional[InitVar]
+    grad_scaler: Optional[GradScaler]
     """gradient scalers"""
-    pipeline: Optional[InitVar]
+    pipeline: Optional["Pipeline"]  # Prevent circular import.
     """reference to training pipeline"""
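Note: the `if TYPE_CHECKING:` import above runs only under static analysis, which is the standard way to break an import cycle while keeping precise annotations. A minimal sketch with a hypothetical module name:

from __future__ import annotations  # annotations are deferred, not evaluated at runtime

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by the type checker, so a cycle with this module is harmless.
    from pipelines import Pipeline  # hypothetical module

@dataclass
class TrainingAttributes:
    pipeline: Optional["Pipeline"] = None  # string annotation defers the lookup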
9 changes: 1 addition & 8 deletions nerfstudio/engine/schedulers.py
@@ -21,17 +21,10 @@
 
 import numpy as np
 from torch.optim import Optimizer, lr_scheduler
+from torch.optim.lr_scheduler import LRScheduler
 
 from nerfstudio.configs.base_config import InstantiateConfig
 
-try:
-    from torch.optim.lr_scheduler import (  # pylint: disable=ungrouped-imports
-        LRScheduler,
-    )
-except ImportError:
-    # Backward compatibility for PyTorch 1.x
-    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-
 
 @dataclass
 class SchedulerConfig(InstantiateConfig):
6 changes: 3 additions & 3 deletions nerfstudio/engine/trainer.py
@@ -173,9 +173,9 @@ def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
 
         self.callbacks = self.pipeline.get_training_callbacks(
             TrainingCallbackAttributes(
-                optimizers=self.optimizers,  # type: ignore
-                grad_scaler=self.grad_scaler,  # type: ignore
-                pipeline=self.pipeline,  # type: ignore
+                optimizers=self.optimizers,
+                grad_scaler=self.grad_scaler,
+                pipeline=self.pipeline,
             )
         )
2 changes: 1 addition & 1 deletion nerfstudio/exporter/marching_cubes.py
@@ -112,7 +112,7 @@ def evaluate_multiresolution_sdf(
         if pts_to_eval.shape[0] > 0:
             pts_sdf_eval = evaluate(pts_to_eval.contiguous())
             assert pts_sdf is not None
-            pts_sdf[mask] = pts_sdf_eval  # pylint: disable=unsupported-assignment-operation
+            pts_sdf[mask] = pts_sdf_eval
 
         if pid < 3:
             # Update mask
18 changes: 15 additions & 3 deletions nerfstudio/field_components/activations.py
@@ -16,7 +16,11 @@
 Special activation functions.
 """
 
+from typing import TYPE_CHECKING
+
 import torch
+from jaxtyping import Float
+from torch import Tensor
 from torch.autograd import Function
 from torch.cuda.amp import custom_bwd, custom_fwd
 
@@ -37,6 +41,14 @@ def backward(ctx, g):
         return g * torch.exp(x.clamp(-15, 15))
 
 
-trunc_exp = _TruncExp.apply
-"""Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding
-gradients."""
+if TYPE_CHECKING:
+
+    def trunc_exp(_: Float[Tensor, "*bs"], /) -> Float[Tensor, "*bs"]:
+        """Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding
+        gradients."""
+        raise NotImplementedError()
+
+else:
+    trunc_exp = _TruncExp.apply
+    """Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding
+    gradients."""
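Note: `torch.autograd.Function.apply` is effectively untyped, so the `TYPE_CHECKING` branch above gives checkers a precise signature while runtime still binds the real implementation. A runnable sketch of the same pattern (the `_Square` example is illustrative, not nerfstudio code):

from typing import TYPE_CHECKING

import torch
from torch.autograd import Function

class _Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out  # d(x^2)/dx = 2x

if TYPE_CHECKING:
    def square(x: torch.Tensor) -> torch.Tensor: ...  # signature for the checker only
else:
    square = _Square.apply  # real implementation at runtime

y = square(torch.ones(3, requires_grad=True))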
3 changes: 3 additions & 0 deletions nerfstudio/field_components/encodings.py
@@ -621,6 +621,7 @@ def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*
         """Sample features from this encoder. Expects ``in_tensor`` to be in range [-1, 1]"""
         original_shape = in_tensor.shape
 
+        assert any(self.coo_combs)
         output = 1.0 if self.reduce == "product" else 0.0  # identity for corresponding op
         for ci, coo_comb in enumerate(self.coo_combs):
             grid = self.plane_coefs[ci].unsqueeze(0)  # [1, feature_dim, reso1, reso2]
@@ -634,6 +635,8 @@ def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*
             else:
                 output = output + interp
 
+        # Typing: output gets converted to a tensor after the first iteration of the loop
+        assert isinstance(output, Tensor)
         return output.reshape(*original_shape[:-1], self.num_components)
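Note: an accumulator initialized to a Python float makes the checker infer `float | Tensor` for the loop result; `assert isinstance` narrows it back to `Tensor` and doubles as a runtime guard (hence the `assert any(...)` ensuring the loop runs at least once). A minimal sketch:

from typing import List

import torch
from torch import Tensor

def sum_tensors(ts: List[Tensor]) -> Tensor:
    total = 0.0  # inferred as float | Tensor once the loop adds tensors
    for t in ts:
        total = total + t
    assert isinstance(total, Tensor)  # narrows the type; fails fast if ts was empty
    return total

print(sum_tensors([torch.ones(2), torch.ones(2)]))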
8 changes: 7 additions & 1 deletion nerfstudio/field_components/temporal_grid.py
@@ -174,6 +174,11 @@ class TemporalGridEncoder(nn.Module):
         align_corners: same as other interpolation operators
     """
 
+    sampling_index: Tensor
+    index_a_mask: Tensor
+    index_b_mask: Tensor
+    index_list: Tensor
+
     def __init__(
         self,
         temporal_dim: int = 64,
@@ -345,13 +350,14 @@ def forward(
             self.gridtype_id,
             self.align_corners,
         )
+        assert isinstance(outputs, Tensor)
         return outputs
 
     def get_temporal_tv_loss(self) -> Float[Tensor, ""]:
         """Apply TV loss on the temporal channels.
         Sample a random channel combination (i.e., row for the combination table),
         and then compute loss on it.
         """
-        row_idx = torch.randint(0, len(self.index_list), [1]).item()
+        row_idx = torch.randint(0, len(self.index_list), [1])
         feat_idx = self.index_list[row_idx]
         return (self.embeddings[:, feat_idx[0]] - self.embeddings[:, feat_idx[1]]).abs().mean()
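Note: tensors registered via `register_buffer` are assigned through `setattr`, so static checkers cannot see them; the class-level annotations added above declare their types. (The `.item()` removal works because a length-1 integer tensor can index like an integer, avoiding the loosely typed number `.item()` returns.) A minimal sketch:

import torch
from torch import Tensor, nn

class GridSketch(nn.Module):
    index_list: Tensor  # declared for the checker; value is set by register_buffer

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("index_list", torch.arange(8).reshape(4, 2))

    def first_pair(self) -> Tensor:
        return self.index_list[0]  # checker now knows index_list is a Tensor

print(GridSketch().first_pair())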
5 changes: 4 additions & 1 deletion nerfstudio/fields/base_field.py
@@ -45,12 +45,15 @@ def __init__(self) -> None:
         self._sample_locations = None
         self._density_before_activation = None
 
-    def density_fn(self, positions: Shaped[Tensor, "*bs 3"]) -> Shaped[Tensor, "*bs 1"]:
+    def density_fn(
+        self, positions: Shaped[Tensor, "*bs 3"], times: Optional[Shaped[Tensor, "*bs 1"]] = None
+    ) -> Shaped[Tensor, "*bs 1"]:
         """Returns only the density. Used primarily with the density grid.
 
         Args:
             positions: the origin of the samples/frustums
         """
+        del times
         # Need to figure out a better way to describe positions with a ray.
         ray_samples = RaySamples(
             frustums=Frustums(
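Note: adding an optional `times` parameter to the base `density_fn` keeps the time-conditioned overrides in the nerfplayer fields below signature-compatible with the base class, which is what the Liskov substitution check in type checkers enforces. A minimal sketch:

from typing import Optional

class BaseFieldSketch:
    def density_fn(self, positions: float, times: Optional[float] = None) -> float:
        del times  # unused in the static base field
        return positions * 2.0

class TemporalFieldSketch(BaseFieldSketch):
    # Override matches the base signature, so checkers accept it.
    def density_fn(self, positions: float, times: Optional[float] = None) -> float:
        assert times is not None, "temporal field requires times"
        return positions * times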
2 changes: 2 additions & 0 deletions nerfstudio/fields/density_fields.py
@@ -47,6 +47,8 @@ class HashMLPDensityField(Field):
         use_linear: whether to skip the MLP and use a single linear layer instead
     """
 
+    aabb: Tensor
+
     def __init__(
         self,
         aabb: Tensor,
6 changes: 4 additions & 2 deletions nerfstudio/fields/nerfacto_field.py
@@ -17,7 +17,7 @@
 """
 
 
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Type
 
 import numpy as np
 import torch
@@ -80,6 +80,8 @@ class TCNNNerfactoField(Field):
         spatial_distortion: spatial distortion to apply to the scene
     """
 
+    aabb: Tensor
+
     def __init__(
         self,
         aabb: Tensor,
@@ -420,4 +422,4 @@ def get_outputs(
         return outputs
 
 
-field_implementation_to_class: Dict[str, Field] = {"tcnn": TCNNNerfactoField, "torch": TorchNerfactoField}
+field_implementation_to_class: Dict[str, Type[Field]] = {"tcnn": TCNNNerfactoField, "torch": TorchNerfactoField}
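Note: the registry maps names to field classes, not instances, so `Type[Field]` is the accurate value type; with `Dict[str, Field]`, calling a looked-up entry would be a type error. A minimal sketch:

from typing import Dict, Type

class FieldSketch:
    def __init__(self, resolution: int) -> None:
        self.resolution = resolution

class TcnnFieldSketch(FieldSketch): ...
class TorchFieldSketch(FieldSketch): ...

registry: Dict[str, Type[FieldSketch]] = {"tcnn": TcnnFieldSketch, "torch": TorchFieldSketch}
field = registry["torch"](resolution=128)  # instantiate the class looked up by name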
5 changes: 4 additions & 1 deletion nerfstudio/fields/nerfplayer_nerfacto_field.py
@@ -104,13 +104,16 @@ def __init__(
             },
         )
 
-    def density_fn(self, positions: Float[Tensor, "*bs 3"], times: Float[Tensor, "bs 1"]) -> Float[Tensor, "*bs 1"]:
+    def density_fn(
+        self, positions: Float[Tensor, "*bs 3"], times: Optional[Float[Tensor, "bs 1"]] = None
+    ) -> Float[Tensor, "*bs 1"]:
         """Returns only the density. Used primarily with the density grid.
 
         Args:
             positions: the origin of the samples/frustums
             times: the time of rays
         """
+        assert times is not None, "TemporalHashMLPDensityField requires times to be specified"
         if len(positions.shape) == 3 and len(times.shape) == 2:
             # position is [ray, sample, 3]; times is [ray, 1]
             times = times[:, None]  # RaySamples can handle the shape
6 changes: 4 additions & 2 deletions nerfstudio/fields/nerfplayer_ngp_field.py
@@ -115,7 +115,7 @@ def __init__(
             level_dim=features_per_level,
             base_resolution=base_resolution,
             log2_hashmap_size=log2_hashmap_size,
-            desired_resolution=1024 * (self.aabb.max() - self.aabb.min()),
+            desired_resolution=int(1024 * (self.aabb.max() - self.aabb.min())),
         )
         self.mlp_base_decode = tcnn.Network(
             n_input_dims=num_levels * features_per_level,
@@ -201,7 +201,9 @@ def get_outputs(
         rgb = self.mlp_head(h).view(*ray_samples.frustums.directions.shape[:-1], -1).to(directions)
         return {FieldHeadNames.RGB: rgb}
 
-    def density_fn(self, positions: Float[Tensor, "*bs 3"], times: Float[Tensor, "*bs 1"]) -> Float[Tensor, "*bs 1"]:
+    def density_fn(
+        self, positions: Float[Tensor, "*bs 3"], times: Optional[Float[Tensor, "*bs 1"]] = None
+    ) -> Float[Tensor, "*bs 1"]:
         """Returns only the density. Used primarily with the density grid.
         Overwrite this function since density is time dependent now.
6 changes: 6 additions & 0 deletions nerfstudio/fields/sdf_field.py
@@ -48,6 +48,8 @@ class LearnedVariance(nn.Module):
         init_val: initial value in NeuS variance network
     """
 
+    variance: Tensor
+
     def __init__(self, init_val):
         super().__init__()
         self.register_parameter("variance", nn.Parameter(init_val * torch.ones(1), requires_grad=True))
@@ -205,6 +207,9 @@ def __init__(
 
         self._cos_anneal_ratio = 1.0
 
+        if self.use_grid_feature:
+            assert self.spatial_distortion is not None, "spatial distortion must be provided when using grid feature"
+
     def initialize_geo_layers(self) -> None:
         """
         Initialize layers for geometric network (sdf)
@@ -255,6 +260,7 @@ def set_cos_anneal_ratio(self, anneal: float) -> None:
     def forward_geonetwork(self, inputs: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch geo_features+1"]:
         """forward the geonetwork"""
         if self.use_grid_feature:
+            assert self.spatial_distortion is not None, "spatial distortion must be provided when using grid feature"
            positions = self.spatial_distortion(inputs)
             # map range [-2, 2] to [0, 1]
             positions = (positions + 2.0) / 4.0
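Note: `self.spatial_distortion` is `Optional`, so calling it directly is a type error; the asserts above narrow it to a concrete type inside the guarded branch and fail fast at runtime with a clear message. A minimal sketch:

from typing import Callable, Optional

class SDFSketch:
    def __init__(self, distortion: Optional[Callable[[float], float]] = None, use_grid: bool = False):
        self.distortion = distortion
        self.use_grid = use_grid

    def forward(self, x: float) -> float:
        if self.use_grid:
            # Narrows Optional[Callable] -> Callable for the checker.
            assert self.distortion is not None, "distortion must be provided when using grid"
            x = self.distortion(x)
        return x

print(SDFSketch(distortion=lambda v: v * 0.5, use_grid=True).forward(2.0))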
4 changes: 2 additions & 2 deletions nerfstudio/generative/stable_diffusion.py
@@ -19,7 +19,7 @@
 
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Union, cast
 
 import appdirs
 import mediapy
@@ -223,7 +223,7 @@ def sds_loss(
 
         if grad_scaler is not None:
             latents = grad_scaler.scale(latents)
-        loss = _SDSGradient.apply(latents, grad)
+        loss = cast(Tensor, _SDSGradient.apply(latents, grad))
 
         return loss
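Note: `cast` is the zero-cost counterpart to the `assert isinstance` narrowing used elsewhere in this PR — it tells the checker the type of an `Any`-returning call (here `Function.apply`) and does nothing at runtime. A minimal sketch:

from typing import Any, cast

import torch
from torch import Tensor

def untyped_op(x: Tensor) -> Any:  # stands in for Function.apply, which returns Any
    return x * 2

loss = cast(Tensor, untyped_op(torch.ones(3)))  # no runtime check, purely for the checker
print(loss.mean())  # Tensor methods now type-check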
24 changes: 14 additions & 10 deletions nerfstudio/model_components/losses.py
@@ -16,13 +16,14 @@
 Collection of Losses.
 """
 from enum import Enum
-from typing import Dict, Literal, Optional, TypeVar, cast
+from typing import Dict, Literal, Optional, Tuple, cast
 
 import torch
 from jaxtyping import Bool, Float
 from torch import Tensor, nn
 
 from nerfstudio.cameras.rays import RaySamples
+from nerfstudio.field_components.field_heads import FieldHeadNames
 from nerfstudio.utils.math import masked_reduction, normalized_depth_scale_and_shift
 
 L1Loss = nn.L1Loss
@@ -103,20 +104,24 @@ def ray_samples_to_sdist(ray_samples):
     return sdist
 
 
-def interlevel_loss(weights_list, ray_samples_list):
+def interlevel_loss(weights_list, ray_samples_list) -> torch.Tensor:
     """Calculates the proposal loss in the MipNeRF-360 paper.
 
    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133
     """
     c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
     w = weights_list[-1][..., 0].detach()
+    assert len(ray_samples_list) > 0
+
     loss_interlevel = 0.0
     for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]):
         sdist = ray_samples_to_sdist(ray_samples)
         cp = sdist  # (num_rays, num_samples + 1)
         wp = weights[..., 0]  # (num_rays, num_samples)
         loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp))
 
+    assert isinstance(loss_interlevel, Tensor)
     return loss_interlevel
@@ -398,7 +403,8 @@ def forward(
         Returns:
             gradient loss based on reduction function
         """
-        total = 0
+        assert self.__scales >= 1
+        total = 0.0
 
         for scale in range(self.__scales):
             step = pow(2, scale)
@@ -410,7 +416,8 @@
             )
             total += grad_loss
 
-        return cast(Tensor, total)
+        assert isinstance(total, Tensor)
+        return total
@@ -534,12 +541,9 @@ def backward(ctx, output_grad, grad_scaling):
         return output_grad * scaling, grad_scaling
 
 
-K = TypeVar("K")
-
-
 def scale_gradients_by_distance_squared(
-    field_outputs: Dict[K, torch.Tensor], ray_samples: RaySamples
-) -> Dict[K, torch.Tensor]:
+    field_outputs: Dict[FieldHeadNames, torch.Tensor], ray_samples: RaySamples
+) -> Dict[FieldHeadNames, torch.Tensor]:
     """
     Scale gradients by the ray distance to the pixel
     as suggested in `Radiance Field Gradient Scaling for Unbiased Near-Camera Training` paper
@@ -554,5 +558,5 @@ def scale_gradients_by_distance_squared(
     ray_dist = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
     scaling = torch.square(ray_dist).clamp(0, 1)
     for key, value in field_outputs.items():
-        out[key], _ = _GradientScaler.apply(value, scaling)
+        out[key], _ = cast(Tuple[Tensor, Tensor], _GradientScaler.apply(value, scaling))
     return out
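Note: replacing the unconstrained `TypeVar` with the concrete `FieldHeadNames` key type documents the actual contract — every caller passes field-head-keyed outputs — and lets the checker flag mistaken keys. A minimal sketch with a stand-in enum:

from enum import Enum
from typing import Dict

class FieldHeadNames(Enum):  # stand-in for nerfstudio's enum
    RGB = "rgb"
    DENSITY = "density"

def halve(outputs: Dict[FieldHeadNames, float]) -> Dict[FieldHeadNames, float]:
    return {k: v * 0.5 for k, v in outputs.items()}

halve({FieldHeadNames.RGB: 1.0})  # OK
halve({"rgb": 1.0})               # runs, but the type checker flags the str key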
2 changes: 2 additions & 0 deletions nerfstudio/model_components/ray_generators.py
@@ -32,6 +32,8 @@ class RayGenerator(nn.Module):
         pose_optimizer: pose optimization module, for optimizing noisy camera intrinsics/extrinsics.
     """
 
+    image_coords: Tensor
+
     def __init__(self, cameras: Cameras, pose_optimizer: CameraOptimizer) -> None:
         super().__init__()
         self.cameras = cameras