Skip to content

Commit

Permalink
pytorch 1.11 support: don't use conv2d_gradfix on v1.11, port grid_sa…
Browse files Browse the repository at this point in the history
…mple_gradfix to the new API

thanks @timothybrooks for the fix!

for #145
  • Loading branch information
jannehellsten committed Apr 22, 2022
1 parent 69c7ef0 commit 407db86
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 5 additions & 0 deletions torch_utils/ops/conv2d_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import contextlib
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
Expand All @@ -20,6 +21,7 @@

enabled = False # Enable the custom op by setting this to true.
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

@contextlib.contextmanager
def no_weight_gradients(disable=True):
Expand Down Expand Up @@ -48,6 +50,9 @@ def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if _use_pytorch_1_11_api:
# The work-around code doesn't work on PyTorch 1.11.0 onwards
return False
if input.device.type != 'cuda':
return False
return True
Expand Down
8 changes: 7 additions & 1 deletion torch_utils/ops/grid_sample_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""

import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
Expand All @@ -20,6 +21,7 @@
#----------------------------------------------------------------------------

enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

#----------------------------------------------------------------------------

Expand Down Expand Up @@ -56,7 +58,11 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
else:
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid

Expand Down

0 comments on commit 407db86

Please sign in to comment.