Skip to content

Commit

Permalink
Error when torch.load-ing a JIT model (pytorch#15578)
Browse files Browse the repository at this point in the history
Summary:
Throw a warning when calling `torch.load` on a zip file

Fixes pytorch#15570
Pull Request resolved: pytorch#15578

Differential Revision: D13555954

Pulled By: driazati

fbshipit-source-id: a37ecdb3dd0c23eff809f86e2f8b74cd48ff7277
  • Loading branch information
David Riazati authored and facebook-github-bot committed Dec 28, 2018
1 parent fb22f76 commit 692898f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,6 +2070,22 @@ def fn(x):
warns = [str(w.message) for w in warns]
self.assertEqual(len(warns), 0)

@unittest.skipIf(sys.platform == "win32", "TODO: need to fix this test case for Windows")
def test_torch_load_error(self):
class J(torch.jit.ScriptModule):
def __init__(self):
super(J, self).__init__()

@torch.jit.script_method
def forward(self, input):
return input + 100

j = J()
with tempfile.NamedTemporaryFile() as f:
j.save(f.name)
with self.assertRaisesRegex(RuntimeError, "is a zip"):
torch.load(f.name)


class TestBatched(TestCase):
# generate random examples and create an batchtensor with them
Expand Down
4 changes: 4 additions & 0 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import torch
import tarfile
import zipfile
import tempfile
import warnings
from contextlib import closing, contextmanager
Expand Down Expand Up @@ -546,6 +547,9 @@ def persistent_load(saved_id):
try:
return legacy_load(f)
except tarfile.TarError:
if zipfile.is_zipfile(f):
# .zip is used for torch.jit.save and will throw an un-pickling error here
raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
# if not a tarfile, reset file offset and proceed
f.seek(0)

Expand Down

0 comments on commit 692898f

Please sign in to comment.