Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 7c3aa1d

Browse filesBrowse files
mikaylagawareckipytorchmergebot
authored andcommitted
Prevent _legacy_load with weights_only=True (#144914)
Pull Request resolved: #144914 Approved by: https://github.com/malfet, https://github.com/albanD
1 parent cf28d61 commit 7c3aa1d
Copy full SHA for 7c3aa1d

File tree

Expand file treeCollapse file tree

3 files changed

+48
-29
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+48
-29
lines changed

‎test/quantization/bc/test_backward_compatibility.py

Copy file name to clipboardExpand all lines: test/quantization/bc/test_backward_compatibility.py
+4-2Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ def _test_op(
110110
torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
111111
torch.save(qmodule(input_tensor), expected_file)
112112

113-
input_tensor = torch.load(input_file)
113+
# weights_only=False as file was saved in .tar format
114+
input_tensor = torch.load(input_file, weights_only=False)
114115
# weights_only = False as sometimes get ScriptObject here
115116
qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
116117
qmodule_scripted = torch.jit.load(scripted_module_file)
117118
qmodule_traced = torch.jit.load(traced_module_file)
118-
expected = torch.load(expected_file)
119+
# weights_only=False as file was saved in .tar format
120+
expected = torch.load(expected_file, weights_only=False)
119121
self.assertEqual(qmodule(input_tensor), expected, atol=prec)
120122
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
121123
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)

‎test/test_serialization.py

Copy file name to clipboardExpand all lines: test/test_serialization.py
+40-18Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,6 @@ def _test_serialization(self, weights_only):
227227
def test_serialization(self):
228228
self._test_serialization(False)
229229

230-
def test_serialization_safe(self):
231-
self._test_serialization(True)
232-
233230
def test_serialization_filelike(self):
234231
# Test serialization (load and save) with a filelike object
235232
b = self._test_serialization_data()
@@ -366,9 +363,6 @@ def _test_serialization(conversion):
366363
def test_serialization_sparse(self):
367364
self._test_serialization(False)
368365

369-
def test_serialization_sparse_safe(self):
370-
self._test_serialization(True)
371-
372366
def test_serialization_sparse_invalid(self):
373367
x = torch.zeros(3, 3)
374368
x[1][1] = 1
@@ -514,9 +508,6 @@ def __reduce__(self):
514508
def test_serialization_backwards_compat(self):
515509
self._test_serialization_backwards_compat(False)
516510

517-
def test_serialization_backwards_compat_safe(self):
518-
self._test_serialization_backwards_compat(True)
519-
520511
def test_serialization_save_warnings(self):
521512
with warnings.catch_warnings(record=True) as warns:
522513
with tempfile.NamedTemporaryFile() as checkpoint:
@@ -561,7 +552,8 @@ def load_bytes():
561552
def check_map_locations(map_locations, dtype, intended_device):
562553
for fileobject_lambda in fileobject_lambdas:
563554
for map_location in map_locations:
564-
tensor = torch.load(fileobject_lambda(), map_location=map_location)
555+
# weigts_only=False as the downloaded file path uses the old serialization format
556+
tensor = torch.load(fileobject_lambda(), map_location=map_location, weights_only=False)
565557

566558
self.assertEqual(tensor.device, intended_device)
567559
self.assertEqual(tensor.dtype, dtype)
@@ -604,7 +596,8 @@ def test_load_nonexistent_device(self):
604596

605597
error_msg = r'Attempting to deserialize object on a CUDA device'
606598
with self.assertRaisesRegex(RuntimeError, error_msg):
607-
_ = torch.load(buf)
599+
# weights_only=False as serialized is in legacy format
600+
_ = torch.load(buf, weights_only=False)
608601

609602
@unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681")
610603
def test_serialization_filelike_api_requirements(self):
@@ -724,7 +717,8 @@ def test_serialization_storage_slice(self):
724717
b'\x00\x00\x00\x00')
725718

726719
buf = io.BytesIO(serialized)
727-
(s1, s2) = torch.load(buf)
720+
# serialized was saved with PyTorch 0.3.1
721+
(s1, s2) = torch.load(buf, weights_only=False)
728722
self.assertEqual(s1[0], 0)
729723
self.assertEqual(s2[0], 0)
730724
self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
@@ -841,6 +835,24 @@ def wrapper(*args, **kwargs):
841835
def __exit__(self, *args, **kwargs):
842836
torch.save = self.torch_save
843837

838+
839+
# used to set weights_only=False in _use_new_zipfile_serialization=False tests
840+
class load_method:
841+
def __init__(self, weights_only):
842+
self.weights_only = weights_only
843+
self.torch_load = torch.load
844+
845+
def __enter__(self, *args, **kwargs):
846+
def wrapper(*args, **kwargs):
847+
kwargs['weights_only'] = self.weights_only
848+
return self.torch_load(*args, **kwargs)
849+
850+
torch.load = wrapper
851+
852+
def __exit__(self, *args, **kwargs):
853+
torch.load = self.torch_load
854+
855+
844856
Point = namedtuple('Point', ['x', 'y'])
845857

846858
class ClassThatUsesBuildInstruction:
@@ -877,14 +889,25 @@ def test(f_new, f_old):
877889

878890
torch.save(x, f_old, _use_new_zipfile_serialization=False)
879891
f_old.seek(0)
880-
x_old_load = torch.load(f_old, weights_only=weights_only)
892+
x_old_load = torch.load(f_old, weights_only=False)
881893
self.assertEqual(x_old_load, x_new_load)
882894

883895
with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
884896
with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
885897
test(f_new, f_old)
886898
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
887899

900+
def test_old_serialization_fails_with_weights_only(self):
901+
a = torch.randn(5, 5)
902+
with BytesIOContext() as f:
903+
torch.save(a, f, _use_new_zipfile_serialization=False)
904+
f.seek(0)
905+
with self.assertRaisesRegex(
906+
RuntimeError,
907+
"Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
908+
):
909+
torch.load(f, weights_only=True)
910+
888911

889912
class TestOldSerialization(TestCase, SerializationMixin):
890913
# unique_key is necessary because on Python 2.7, if a warning passed to
@@ -960,8 +983,7 @@ def test_serialization_offset(self):
960983
self.assertEqual(i, i_loaded)
961984
self.assertEqual(j, j_loaded)
962985

963-
@parametrize('weights_only', (True, False))
964-
def test_serialization_offset_filelike(self, weights_only):
986+
def test_serialization_offset_filelike(self):
965987
a = torch.randn(5, 5)
966988
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
967989
i, j = 41, 43
@@ -973,16 +995,16 @@ def test_serialization_offset_filelike(self, weights_only):
973995
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
974996
f.seek(0)
975997
i_loaded = pickle.load(f)
976-
a_loaded = torch.load(f, weights_only=weights_only)
998+
a_loaded = torch.load(f)
977999
j_loaded = pickle.load(f)
978-
b_loaded = torch.load(f, weights_only=weights_only)
1000+
b_loaded = torch.load(f)
9791001
self.assertTrue(torch.equal(a, a_loaded))
9801002
self.assertTrue(torch.equal(b, b_loaded))
9811003
self.assertEqual(i, i_loaded)
9821004
self.assertEqual(j, j_loaded)
9831005

9841006
def run(self, *args, **kwargs):
985-
with serialization_method(use_zip=False):
1007+
with serialization_method(use_zip=False), load_method(weights_only=False):
9861008
return super().run(*args, **kwargs)
9871009

9881010

‎torch/serialization.py

Copy file name to clipboardExpand all lines: torch/serialization.py
+4-9Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1501,15 +1501,10 @@ def _get_wo_message(message: str) -> str:
15011501
"please torch.save your checkpoint with this option in order to use mmap."
15021502
)
15031503
if weights_only:
1504-
try:
1505-
return _legacy_load(
1506-
opened_file,
1507-
map_location,
1508-
_weights_only_unpickler,
1509-
**pickle_load_args,
1510-
)
1511-
except pickle.UnpicklingError as e:
1512-
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
1504+
raise RuntimeError(
1505+
"Cannot use ``weights_only=True`` with files saved in the "
1506+
".tar format used before version 1.6. " + UNSAFE_MESSAGE
1507+
)
15131508
return _legacy_load(
15141509
opened_file, map_location, pickle_module, **pickle_load_args
15151510
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.