Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit cd15d7b

Browse filesBrowse files
Prevent _legacy_load with weights_only=True (#144993)
Prevent _legacy_load with weights_only=True (#144914) Pull Request resolved: #144914 Approved by: https://github.com/malfet, https://github.com/albanD (cherry picked from commit 7c3aa1d) Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
1 parent a2639bc commit cd15d7b
Copy full SHA for cd15d7b

File tree

3 files changed

+48
-29
lines changed
Filter options

3 files changed

+48
-29
lines changed

‎test/quantization/bc/test_backward_compatibility.py

Copy file name to clipboardExpand all lines: test/quantization/bc/test_backward_compatibility.py
+4-2Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ def _test_op(
112112
torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file)
113113
torch.save(qmodule(input_tensor), expected_file)
114114

115-
input_tensor = torch.load(input_file)
115+
# weights_only=False as file was saved in .tar format
116+
input_tensor = torch.load(input_file, weights_only=False)
116117
# weights_only = False as sometimes get ScriptObject here
117118
qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
118119
qmodule_scripted = torch.jit.load(scripted_module_file)
119120
qmodule_traced = torch.jit.load(traced_module_file)
120-
expected = torch.load(expected_file)
121+
# weights_only=False as file was saved in .tar format
122+
expected = torch.load(expected_file, weights_only=False)
121123
self.assertEqual(qmodule(input_tensor), expected, atol=prec)
122124
self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
123125
self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)

‎test/test_serialization.py

Copy file name to clipboardExpand all lines: test/test_serialization.py
+40-18Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,6 @@ def _test_serialization(self, weights_only):
224224
def test_serialization(self):
225225
self._test_serialization(False)
226226

227-
def test_serialization_safe(self):
228-
self._test_serialization(True)
229-
230227
def test_serialization_filelike(self):
231228
# Test serialization (load and save) with a filelike object
232229
b = self._test_serialization_data()
@@ -362,9 +359,6 @@ def _test_serialization(conversion):
362359
def test_serialization_sparse(self):
363360
self._test_serialization(False)
364361

365-
def test_serialization_sparse_safe(self):
366-
self._test_serialization(True)
367-
368362
def test_serialization_sparse_invalid(self):
369363
x = torch.zeros(3, 3)
370364
x[1][1] = 1
@@ -510,9 +504,6 @@ def __reduce__(self):
510504
def test_serialization_backwards_compat(self):
511505
self._test_serialization_backwards_compat(False)
512506

513-
def test_serialization_backwards_compat_safe(self):
514-
self._test_serialization_backwards_compat(True)
515-
516507
def test_serialization_save_warnings(self):
517508
with warnings.catch_warnings(record=True) as warns:
518509
with tempfile.NamedTemporaryFile() as checkpoint:
@@ -557,7 +548,8 @@ def load_bytes():
557548
def check_map_locations(map_locations, dtype, intended_device):
558549
for fileobject_lambda in fileobject_lambdas:
559550
for map_location in map_locations:
560-
tensor = torch.load(fileobject_lambda(), map_location=map_location)
551+
# weigts_only=False as the downloaded file path uses the old serialization format
552+
tensor = torch.load(fileobject_lambda(), map_location=map_location, weights_only=False)
561553

562554
self.assertEqual(tensor.device, intended_device)
563555
self.assertEqual(tensor.dtype, dtype)
@@ -600,7 +592,8 @@ def test_load_nonexistent_device(self):
600592

601593
error_msg = r'Attempting to deserialize object on a CUDA device'
602594
with self.assertRaisesRegex(RuntimeError, error_msg):
603-
_ = torch.load(buf)
595+
# weights_only=False as serialized is in legacy format
596+
_ = torch.load(buf, weights_only=False)
604597

605598
@unittest.skipIf((3, 8, 0) <= sys.version_info < (3, 8, 2), "See https://bugs.python.org/issue39681")
606599
def test_serialization_filelike_api_requirements(self):
@@ -720,7 +713,8 @@ def test_serialization_storage_slice(self):
720713
b'\x00\x00\x00\x00')
721714

722715
buf = io.BytesIO(serialized)
723-
(s1, s2) = torch.load(buf)
716+
# serialized was saved with PyTorch 0.3.1
717+
(s1, s2) = torch.load(buf, weights_only=False)
724718
self.assertEqual(s1[0], 0)
725719
self.assertEqual(s2[0], 0)
726720
self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
@@ -837,6 +831,24 @@ def wrapper(*args, **kwargs):
837831
def __exit__(self, *args, **kwargs):
838832
torch.save = self.torch_save
839833

834+
835+
# used to set weights_only=False in _use_new_zipfile_serialization=False tests
836+
class load_method:
837+
def __init__(self, weights_only):
838+
self.weights_only = weights_only
839+
self.torch_load = torch.load
840+
841+
def __enter__(self, *args, **kwargs):
842+
def wrapper(*args, **kwargs):
843+
kwargs['weights_only'] = self.weights_only
844+
return self.torch_load(*args, **kwargs)
845+
846+
torch.load = wrapper
847+
848+
def __exit__(self, *args, **kwargs):
849+
torch.load = self.torch_load
850+
851+
840852
Point = namedtuple('Point', ['x', 'y'])
841853

842854
class ClassThatUsesBuildInstruction:
@@ -873,14 +885,25 @@ def test(f_new, f_old):
873885

874886
torch.save(x, f_old, _use_new_zipfile_serialization=False)
875887
f_old.seek(0)
876-
x_old_load = torch.load(f_old, weights_only=weights_only)
888+
x_old_load = torch.load(f_old, weights_only=False)
877889
self.assertEqual(x_old_load, x_new_load)
878890

879891
with AlwaysWarnTypedStorageRemoval(True), warnings.catch_warnings(record=True) as w:
880892
with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old:
881893
test(f_new, f_old)
882894
self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
883895

896+
def test_old_serialization_fails_with_weights_only(self):
897+
a = torch.randn(5, 5)
898+
with BytesIOContext() as f:
899+
torch.save(a, f, _use_new_zipfile_serialization=False)
900+
f.seek(0)
901+
with self.assertRaisesRegex(
902+
RuntimeError,
903+
"Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
904+
):
905+
torch.load(f, weights_only=True)
906+
884907

885908
class TestOldSerialization(TestCase, SerializationMixin):
886909
# unique_key is necessary because on Python 2.7, if a warning passed to
@@ -956,8 +979,7 @@ def test_serialization_offset(self):
956979
self.assertEqual(i, i_loaded)
957980
self.assertEqual(j, j_loaded)
958981

959-
@parametrize('weights_only', (True, False))
960-
def test_serialization_offset_filelike(self, weights_only):
982+
def test_serialization_offset_filelike(self):
961983
a = torch.randn(5, 5)
962984
b = torch.randn(1024, 1024, 512, dtype=torch.float32)
963985
i, j = 41, 43
@@ -969,16 +991,16 @@ def test_serialization_offset_filelike(self, weights_only):
969991
self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024)
970992
f.seek(0)
971993
i_loaded = pickle.load(f)
972-
a_loaded = torch.load(f, weights_only=weights_only)
994+
a_loaded = torch.load(f)
973995
j_loaded = pickle.load(f)
974-
b_loaded = torch.load(f, weights_only=weights_only)
996+
b_loaded = torch.load(f)
975997
self.assertTrue(torch.equal(a, a_loaded))
976998
self.assertTrue(torch.equal(b, b_loaded))
977999
self.assertEqual(i, i_loaded)
9781000
self.assertEqual(j, j_loaded)
9791001

9801002
def run(self, *args, **kwargs):
981-
with serialization_method(use_zip=False):
1003+
with serialization_method(use_zip=False), load_method(weights_only=False):
9821004
return super().run(*args, **kwargs)
9831005

9841006

‎torch/serialization.py

Copy file name to clipboardExpand all lines: torch/serialization.py
+4-9Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,15 +1482,10 @@ def _get_wo_message(message: str) -> str:
14821482
"please torch.save your checkpoint with this option in order to use mmap."
14831483
)
14841484
if weights_only:
1485-
try:
1486-
return _legacy_load(
1487-
opened_file,
1488-
map_location,
1489-
_weights_only_unpickler,
1490-
**pickle_load_args,
1491-
)
1492-
except pickle.UnpicklingError as e:
1493-
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
1485+
raise RuntimeError(
1486+
"Cannot use ``weights_only=True`` with files saved in the "
1487+
".tar format used before version 1.6. " + UNSAFE_MESSAGE
1488+
)
14941489
return _legacy_load(
14951490
opened_file, map_location, pickle_module, **pickle_load_args
14961491
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.