
Move autograd metadata from VariableImpl to TensorImpl #13827

Closed · wants to merge 60 commits

Conversation

@yf225 (Contributor) commented Nov 11, 2018

Changes originally in this PR:

1. Move `Variable::Impl` data members into `TensorImpl` as an `AutogradMeta` struct.
2. Change `Variable::Impl` functions to use the data members in the `AutogradMeta` struct.
3. Add a `shallow_copy_and_detach()` function to each subclass of `TensorImpl`.
4. Do a shallow copy when the user calls `make_variable(tensor)` / `make_variable_view(tensor)` / `variable.set_data(tensor)` / `variable.detach()`.

Changes moved from #13645:

1. Add a flag to `Variable` that disallows size/stride/storage_ptr changes from in-place operations such as `resize_` / `resize_as_` / `set_` / `transpose_`, and set this flag to true when people call `tensor.data` in Python.
2. Add text to the docs actively discouraging changing the shape or storage of `tensor_detached` and expecting `tensor` to also be updated.

This covers the 1st and 2nd PRs mentioned in #13638.
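A minimal structural sketch of the refactor, using simplified stand-in types rather than the real PyTorch classes: autograd state hangs off `TensorImpl` as an `AutogradMeta` member, and detaching shallow-copies the impl so that storage is shared, autograd history is dropped, and later metadata changes are refused.

```cpp
#include <memory>

// Simplified stand-ins for the real PyTorch types (illustration only).
struct AutogradMeta {
  bool requires_grad = false;
  // grad_, grad_fn_, hooks_, ... live here in the real struct.
};

struct TensorImpl {
  std::shared_ptr<float> storage_;               // shared, never deep-copied here
  std::unique_ptr<AutogradMeta> autograd_meta_;  // null for non-autograd tensors
  bool allow_tensor_metadata_change_ = true;

  // Shallow copy: alias the same storage, drop autograd history, and
  // forbid later size/stride/storage_ptr changes on the copy.
  std::unique_ptr<TensorImpl> shallow_copy_and_detach() const {
    auto copy = std::make_unique<TensorImpl>();
    copy->storage_ = storage_;
    copy->autograd_meta_ = nullptr;
    copy->allow_tensor_metadata_change_ = false;
    return copy;
  }
};
```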

@yf225 yf225 closed this Nov 11, 2018
@yf225 yf225 reopened this Nov 11, 2018
@yf225 yf225 changed the title [WIP] Move autograd metadata from VariableImpl to TensorImpl Move autograd metadata from VariableImpl to TensorImpl Nov 12, 2018
@yf225 yf225 requested review from gchanan and ezyang November 12, 2018 19:28
@yf225 yf225 mentioned this pull request Nov 13, 2018
@yf225 yf225 force-pushed the tensorimpl_autogradmeta branch 2 times, most recently from 82d723d to 4e4a9f7 Compare November 13, 2018 18:21
@yf225 yf225 force-pushed the tensorimpl_autogradmeta branch 6 times, most recently from ac524cd to 6d74b6a Compare November 13, 2018 20:48
Review threads (resolved) on:

- torch/csrc/autograd/variable.cpp
- torch/csrc/autograd/variable.h (two threads, outdated)
- torch/csrc/autograd/variable.cpp (outdated)
- torch/tensor.py (outdated)
- aten/src/ATen/core/TensorImpl.h (four threads, outdated)
- aten/src/TH/THTensor.cpp (outdated)
```diff
@@ -1710,6 +1710,16 @@ def test_is_nonzero(self):
         with self.assertRaisesRegex(RuntimeError, "bool value of Tensor with no values is ambiguous"):
             torch.sparse_coo_tensor(([0, 1],), self.ValueTensor(2, 0), (4, 0)).is_nonzero()
 
+    def test_allow_size_or_storage_change(self):
+        def do_test(t):
+            a = self.SparseTensor(3, 3)
```

```diff
@@ -9321,6 +9321,26 @@ def test_reverse_binary_ops_multiple_device(self):
         with self.assertRaisesRegex(RuntimeError, "expected both inputs to be on same device"):
             torch.tensor(2).to("cuda:1") // torch.tensor(3).to("cuda:0")
 
+    def test_allow_size_or_storage_change(self):
```

```diff
@@ -105,7 +105,7 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
 static int THPVariable_clear(THPVariable *self)
 {
   Py_CLEAR(self->backward_hooks);
-  if (self->cdata.defined()) {
+  if (self->cdata.defined() && self->cdata.get_autograd_meta()) {
```

Review thread (resolved, outdated) on torch/csrc/autograd/python_variable.cpp
@yf225 yf225 force-pushed the tensorimpl_autogradmeta branch 3 times, most recently from aa46050 to 46e7780 Compare November 15, 2018 17:29
@yf225 yf225 force-pushed the tensorimpl_autogradmeta branch 4 times, most recently from 52b58f7 to ba1abc4 Compare November 20, 2018 05:09
@facebook-github-bot (Contributor): @yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang (Contributor) commented Dec 18, 2018:

Hi Will, do you want a rereview? If so, what's new?

@yf225 (Contributor, Author) commented Dec 18, 2018:

@ezyang the new change is mostly 859aa38, where we pass an argument to make_variable()/make_variable_view() to control the value of allow_tensor_metadata_change. I also wanted to wait for Greg to take another pass at the PR. cc @gchanan
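A self-contained sketch of the shape of that change, with simplified stand-in types rather than the actual PyTorch declarations: the factory takes an explicit flag instead of hard-coding the policy.

```cpp
#include <utility>

// Simplified stand-ins; the real code wraps a TensorImpl in a Variable.
struct Tensor { /* data, sizes, strides, ... */ };

struct Variable {
  Tensor data;
  bool requires_grad;
  bool allow_tensor_metadata_change;
};

// Callers decide whether the resulting Variable may later have its
// sizes/strides/storage pointer changed by in-place metadata operations.
Variable make_variable(Tensor data,
                       bool requires_grad = false,
                       bool allow_tensor_metadata_change = true) {
  return Variable{std::move(data), requires_grad, allow_tensor_metadata_change};
}
```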

```diff
@@ -4859,7 +4859,7 @@ class DerivedStateModule(torch.jit.ScriptModule):
     def __init__(self):
         super(TestScript.DerivedStateModule, self).__init__()
         self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
-        self.register_buffer('derived', torch.neg(self.param).detach())
+        self.register_buffer('derived', torch.neg(self.param).detach().clone())
```
A reviewer (Contributor) commented:
what's going on here? Is the detached thing being inplace modified later?
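(Context, as a hedged sketch rather than the author's answer: `.detach()` returns a tensor that aliases its source's storage and, after this PR, refuses metadata changes such as `resize_`/`set_`; `.detach().clone()` materializes an independent tensor with neither property. A minimal libtorch demonstration of the aliasing half:)

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto src = torch::neg(torch::ones({3, 4}));

  // detach() aliases src's storage: no copy is made.
  auto detached = src.detach();
  std::cout << (detached.data_ptr() == src.data_ptr()) << "\n";  // 1: shared storage

  // detach().clone() copies the data into fresh storage, so later
  // in-place operations on the buffer cannot affect src.
  auto buffer = src.detach().clone();
  std::cout << (buffer.data_ptr() == src.data_ptr()) << "\n";    // 0: separate storage
}
```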

Review thread (resolved, outdated) on torch/tensor.py
@gchanan (Contributor) left a comment:
it would be nice if you could track down more places where we could set allow_tensor_metadata_change (is that even possible?). But that could be a follow-up (again, assuming it's possible).


@yf225 (Contributor, Author) commented Dec 22, 2018:

Ran into a weird use-after-free error in internal tests. Currently investigating.


@facebook-github-bot (Contributor): @yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Dec 27, 2018
Summary:
Changes originally in this PR:
1. Move Variable::Impl data members into TensorImpl as `AutogradMeta` struct
2. Change Variable::Impl functions to use data members in `AutogradMeta` struct
3. Add `shallow_copy_and_detach()` function to each subclass of TensorImpl
4. Do shallow copy when the user calls `make_variable(tensor)` / `make_variable_view(tensor)` / `variable.set_data(tensor)` / `variable.detach()`

Changes moved from pytorch/pytorch#13645:
1. Add a flag to Variable to disallow size/stride/storage_ptr changes from in-place operations such as `resize_` / `resize_as_` / `set_` / `transpose_`, and set this flag to true when people call `tensor.data` in Python.
2. Write text in the docs to actively discourage changing the shape or storage of `tensor_detached` and expecting `tensor` to also be updated.

This is the 1st+2nd PR mentioned in pytorch/pytorch#13638.
Pull Request resolved: pytorch/pytorch#13827

Differential Revision: D13507173

Pulled By: yf225

fbshipit-source-id: b177b08438d534a8197e34e1ad4a837e2db0ed6a
@ezyang ezyang added the merged label Jun 25, 2019