Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gradient scaling #1901

Merged
merged 9 commits into from
May 12, 2023
Merged
39 changes: 38 additions & 1 deletion nerfstudio/model_components/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Collection of Losses.
"""
from enum import Enum
from typing import Literal
from typing import Dict, Literal

import torch
from torch import nn
Expand Down Expand Up @@ -502,3 +502,40 @@ def tv_loss(grids: TensorType["grids", "feature_dim", "row", "column"]) -> Tenso
h_tv = torch.pow((grids[:, :, 1:, :] - grids[:, :, :-1, :]), 2).sum()
w_tv = torch.pow((grids[:, :, :, 1:] - grids[:, :, :, :-1]), 2).sum()
return 2 * (h_tv / h_tv_count + w_tv / w_tv_count) / number_of_grids


class _GradientScaler(torch.autograd.Function): # typing: ignore, pylint: disable=abstract-method
"""
Scale gradients by a constant factor.
"""

@staticmethod
def forward(ctx, value, scaling): # pylint: disable=arguments-differ
ctx.save_for_backward(scaling)
return value, scaling

@staticmethod
def backward(ctx, output_grad, grad_scaling): # pylint: disable=arguments-differ
(scaling,) = ctx.saved_tensors
return output_grad * scaling, grad_scaling


def scale_gradients_by_distance_squared(
field_outputs: Dict[str, torch.Tensor], ray_samples: RaySamples
) -> Dict[str, torch.Tensor]:
"""
Scale gradients by the ray distance to the pixel
as suggested in `Radiance Field Gradient Scaling for Unbiased Near-Camera Training` paper

Note: The scaling is applied on the interval of [0, 1] along the ray!

Example:
GradientLoss should be called right after obtaining the densities and colors from the field. ::
>>> field_outputs = scale_gradient_by_distance_squared(field_outputs, ray_samples)
"""
out = {}
ray_dist = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2
scaling = torch.square(ray_dist).clamp(0, 1)
for key, value in field_outputs.items():
out[key], _ = _GradientScaler.apply(value, scaling)
return out
9 changes: 8 additions & 1 deletion nerfstudio/models/nerfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from torchmetrics.functional import structural_similarity_index_measure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.cameras.rays import RayBundle, RaySamples
from nerfstudio.engine.callbacks import (
TrainingCallback,
TrainingCallbackAttributes,
Expand All @@ -44,6 +44,7 @@
interlevel_loss,
orientation_loss,
pred_normal_loss,
scale_gradients_by_distance_squared,
)
from nerfstudio.model_components.ray_samplers import (
ProposalNetworkSampler,
Expand Down Expand Up @@ -127,6 +128,8 @@ class NerfactoModelConfig(ModelConfig):
"""Whether to predict normals or not."""
disable_scene_contraction: bool = False
"""Whether to disable scene contraction or not."""
use_gradient_scaling: bool = False
"""Use gradient scaler where the gradients are lower for points closer to the camera."""


class NerfactoModel(Model):
Expand Down Expand Up @@ -262,8 +265,12 @@ def set_anneal(step):
return callbacks

def get_outputs(self, ray_bundle: RayBundle):
ray_samples: RaySamples
ray_samples, weights_list, ray_samples_list = self.proposal_sampler(ray_bundle, density_fns=self.density_fns)
field_outputs = self.field(ray_samples, compute_normals=self.config.predict_normals)
if self.config.use_gradient_scaling:
field_outputs = scale_gradients_by_distance_squared(field_outputs, ray_samples)

weights = ray_samples.get_weights(field_outputs[FieldHeadNames.DENSITY])
weights_list.append(weights)
ray_samples_list.append(ray_samples)
Expand Down