@@ -227,9 +227,6 @@ def _test_serialization(self, weights_only):
227
227
def test_serialization (self ):
228
228
self ._test_serialization (False )
229
229
230
- def test_serialization_safe (self ):
231
- self ._test_serialization (True )
232
-
233
230
def test_serialization_filelike (self ):
234
231
# Test serialization (load and save) with a filelike object
235
232
b = self ._test_serialization_data ()
@@ -366,9 +363,6 @@ def _test_serialization(conversion):
366
363
def test_serialization_sparse (self ):
367
364
self ._test_serialization (False )
368
365
369
- def test_serialization_sparse_safe (self ):
370
- self ._test_serialization (True )
371
-
372
366
def test_serialization_sparse_invalid (self ):
373
367
x = torch .zeros (3 , 3 )
374
368
x [1 ][1 ] = 1
@@ -514,9 +508,6 @@ def __reduce__(self):
514
508
def test_serialization_backwards_compat (self ):
515
509
self ._test_serialization_backwards_compat (False )
516
510
517
- def test_serialization_backwards_compat_safe (self ):
518
- self ._test_serialization_backwards_compat (True )
519
-
520
511
def test_serialization_save_warnings (self ):
521
512
with warnings .catch_warnings (record = True ) as warns :
522
513
with tempfile .NamedTemporaryFile () as checkpoint :
@@ -561,7 +552,8 @@ def load_bytes():
561
552
def check_map_locations (map_locations , dtype , intended_device ):
562
553
for fileobject_lambda in fileobject_lambdas :
563
554
for map_location in map_locations :
564
- tensor = torch .load (fileobject_lambda (), map_location = map_location )
555
+ # weigts_only=False as the downloaded file path uses the old serialization format
556
+ tensor = torch .load (fileobject_lambda (), map_location = map_location , weights_only = False )
565
557
566
558
self .assertEqual (tensor .device , intended_device )
567
559
self .assertEqual (tensor .dtype , dtype )
@@ -604,7 +596,8 @@ def test_load_nonexistent_device(self):
604
596
605
597
error_msg = r'Attempting to deserialize object on a CUDA device'
606
598
with self .assertRaisesRegex (RuntimeError , error_msg ):
607
- _ = torch .load (buf )
599
+ # weights_only=False as serialized is in legacy format
600
+ _ = torch .load (buf , weights_only = False )
608
601
609
602
@unittest .skipIf ((3 , 8 , 0 ) <= sys .version_info < (3 , 8 , 2 ), "See https://bugs.python.org/issue39681" )
610
603
def test_serialization_filelike_api_requirements (self ):
@@ -724,7 +717,8 @@ def test_serialization_storage_slice(self):
724
717
b'\x00 \x00 \x00 \x00 ' )
725
718
726
719
buf = io .BytesIO (serialized )
727
- (s1 , s2 ) = torch .load (buf )
720
+ # serialized was saved with PyTorch 0.3.1
721
+ (s1 , s2 ) = torch .load (buf , weights_only = False )
728
722
self .assertEqual (s1 [0 ], 0 )
729
723
self .assertEqual (s2 [0 ], 0 )
730
724
self .assertEqual (s1 .data_ptr () + 4 , s2 .data_ptr ())
@@ -841,6 +835,24 @@ def wrapper(*args, **kwargs):
841
835
def __exit__ (self , * args , ** kwargs ):
842
836
torch .save = self .torch_save
843
837
838
+
839
+ # used to set weights_only=False in _use_new_zipfile_serialization=False tests
840
+ class load_method :
841
+ def __init__ (self , weights_only ):
842
+ self .weights_only = weights_only
843
+ self .torch_load = torch .load
844
+
845
+ def __enter__ (self , * args , ** kwargs ):
846
+ def wrapper (* args , ** kwargs ):
847
+ kwargs ['weights_only' ] = self .weights_only
848
+ return self .torch_load (* args , ** kwargs )
849
+
850
+ torch .load = wrapper
851
+
852
+ def __exit__ (self , * args , ** kwargs ):
853
+ torch .load = self .torch_load
854
+
855
+
844
856
Point = namedtuple ('Point' , ['x' , 'y' ])
845
857
846
858
class ClassThatUsesBuildInstruction :
@@ -877,14 +889,25 @@ def test(f_new, f_old):
877
889
878
890
torch .save (x , f_old , _use_new_zipfile_serialization = False )
879
891
f_old .seek (0 )
880
- x_old_load = torch .load (f_old , weights_only = weights_only )
892
+ x_old_load = torch .load (f_old , weights_only = False )
881
893
self .assertEqual (x_old_load , x_new_load )
882
894
883
895
with AlwaysWarnTypedStorageRemoval (True ), warnings .catch_warnings (record = True ) as w :
884
896
with tempfile .NamedTemporaryFile () as f_new , tempfile .NamedTemporaryFile () as f_old :
885
897
test (f_new , f_old )
886
898
self .assertTrue (len (w ) == 0 , msg = f"Expected no warnings but got { [str (x ) for x in w ]} " )
887
899
900
+ def test_old_serialization_fails_with_weights_only (self ):
901
+ a = torch .randn (5 , 5 )
902
+ with BytesIOContext () as f :
903
+ torch .save (a , f , _use_new_zipfile_serialization = False )
904
+ f .seek (0 )
905
+ with self .assertRaisesRegex (
906
+ RuntimeError ,
907
+ "Cannot use ``weights_only=True`` with files saved in the .tar format used before version 1.6."
908
+ ):
909
+ torch .load (f , weights_only = True )
910
+
888
911
889
912
class TestOldSerialization (TestCase , SerializationMixin ):
890
913
# unique_key is necessary because on Python 2.7, if a warning passed to
@@ -960,8 +983,7 @@ def test_serialization_offset(self):
960
983
self .assertEqual (i , i_loaded )
961
984
self .assertEqual (j , j_loaded )
962
985
963
- @parametrize ('weights_only' , (True , False ))
964
- def test_serialization_offset_filelike (self , weights_only ):
986
+ def test_serialization_offset_filelike (self ):
965
987
a = torch .randn (5 , 5 )
966
988
b = torch .randn (1024 , 1024 , 512 , dtype = torch .float32 )
967
989
i , j = 41 , 43
@@ -973,16 +995,16 @@ def test_serialization_offset_filelike(self, weights_only):
973
995
self .assertTrue (f .tell () > 2 * 1024 * 1024 * 1024 )
974
996
f .seek (0 )
975
997
i_loaded = pickle .load (f )
976
- a_loaded = torch .load (f , weights_only = weights_only )
998
+ a_loaded = torch .load (f )
977
999
j_loaded = pickle .load (f )
978
- b_loaded = torch .load (f , weights_only = weights_only )
1000
+ b_loaded = torch .load (f )
979
1001
self .assertTrue (torch .equal (a , a_loaded ))
980
1002
self .assertTrue (torch .equal (b , b_loaded ))
981
1003
self .assertEqual (i , i_loaded )
982
1004
self .assertEqual (j , j_loaded )
983
1005
984
1006
def run (self , * args , ** kwargs ):
985
- with serialization_method (use_zip = False ):
1007
+ with serialization_method (use_zip = False ), load_method ( weights_only = False ) :
986
1008
return super ().run (* args , ** kwargs )
987
1009
988
1010
0 commit comments