Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 50b888d

Browse filesBrowse files
Allow NEWOBJ instruction for items added via torch.serialization.add_safe_globals
ghstack-source-id: 34a8fc3 Pull Request resolved: #129251
1 parent cc99c01 commit 50b888d
Copy full SHA for 50b888d

File tree

Expand file treeCollapse file tree

4 files changed

+85
-5
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+85
-5
lines changed

‎test/distributed/_tensor/test_dtensor.py

Copy file name to clipboardExpand all lines: test/distributed/_tensor/test_dtensor.py
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,16 @@ def test_dtensor_save_load(self):
536536
buffer.seek(0)
537537
reloaded_st = torch.load(buffer)
538538
self.assertEqual(sharded_tensor, reloaded_st)
539+
# Test weights_only load
540+
try:
541+
torch.serialization.add_safe_globals(
542+
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
543+
)
544+
buffer.seek(0)
545+
reloaded_st = torch.load(buffer, weights_only=True)
546+
self.assertEqual(sharded_tensor, reloaded_st)
547+
finally:
548+
torch.serialization.clear_safe_globals()
539549

540550

541551
class DTensorMeshTest(DTensorTestBase):

‎test/test_serialization.py

Copy file name to clipboardExpand all lines: test/test_serialization.py
+61-1Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import shutil
1616
import pathlib
1717
import platform
18-
from collections import OrderedDict
18+
from collections import namedtuple, OrderedDict
1919
from copy import deepcopy
2020
from itertools import product
2121

@@ -804,6 +804,17 @@ def wrapper(*args, **kwargs):
804804
def __exit__(self, *args, **kwargs):
805805
torch.save = self.torch_save
806806

807+
Point = namedtuple('Point', ['x', 'y'])
808+
809+
class ClassThatUsesBuildInstruction:
810+
def __init__(self, num):
811+
self.num = num
812+
813+
def __reduce_ex__(self, proto):
814+
# Third item, state here will cause pickle to push a BUILD instruction
815+
return ClassThatUsesBuildInstruction, (self.num,), {'foo': 'bar'}
816+
817+
807818
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
808819
class TestBothSerialization(TestCase):
809820
@parametrize("weights_only", (True, False))
@@ -1049,6 +1060,55 @@ def __reduce__(self):
10491060
finally:
10501061
torch.serialization.clear_safe_globals()
10511062

1063+
def test_weights_only_safe_globals_newobj(self):
1064+
# This will use NEWOBJ
1065+
p = Point(x=1, y=2)
1066+
with BytesIOContext() as f:
1067+
torch.save(p, f)
1068+
f.seek(0)
1069+
with self.assertRaisesRegex(pickle.UnpicklingError,
1070+
"GLOBAL __main__.Point was not an allowed global by default"):
1071+
torch.load(f, weights_only=True)
1072+
f.seek(0)
1073+
try:
1074+
torch.serialization.add_safe_globals([Point])
1075+
loaded_p = torch.load(f, weights_only=True)
1076+
self.assertEqual(loaded_p, p)
1077+
finally:
1078+
torch.serialization.clear_safe_globals()
1079+
1080+
def test_weights_only_safe_globals_build(self):
1081+
counter = 0
1082+
1083+
def fake_set_state(obj, *args):
1084+
nonlocal counter
1085+
counter += 1
1086+
1087+
c = ClassThatUsesBuildInstruction(2)
1088+
with BytesIOContext() as f:
1089+
torch.save(c, f)
1090+
f.seek(0)
1091+
with self.assertRaisesRegex(pickle.UnpicklingError,
1092+
"GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"):
1093+
torch.load(f, weights_only=True)
1094+
try:
1095+
torch.serialization.add_safe_globals([ClassThatUsesBuildInstruction])
1096+
# Test dict update path
1097+
f.seek(0)
1098+
loaded_c = torch.load(f, weights_only=True)
1099+
self.assertEqual(loaded_c.num, 2)
1100+
self.assertEqual(loaded_c.foo, 'bar')
1101+
# Test setstate path
1102+
ClassThatUsesBuildInstruction.__setstate__ = fake_set_state
1103+
f.seek(0)
1104+
loaded_c = torch.load(f, weights_only=True)
1105+
self.assertEqual(loaded_c.num, 2)
1106+
self.assertEqual(counter, 1)
1107+
self.assertFalse(hasattr(loaded_c, 'foo'))
1108+
finally:
1109+
torch.serialization.clear_safe_globals()
1110+
ClassThatUsesBuildInstruction.__setstate__ = None
1111+
10521112
@parametrize('weights_only', (False, True))
10531113
def test_serialization_math_bits(self, weights_only):
10541114
t = torch.randn(1, dtype=torch.cfloat)

‎torch/_weights_only_unpickler.py

Copy file name to clipboardExpand all lines: torch/_weights_only_unpickler.py
+11-3Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,12 @@ def load(self):
231231
elif key[0] == NEWOBJ[0]:
232232
args = self.stack.pop()
233233
cls = self.stack.pop()
234-
if cls is not torch.nn.Parameter:
234+
if cls is torch.nn.Parameter:
235+
self.append(torch.nn.Parameter(*args))
236+
elif cls in _get_user_allowed_globals().values():
237+
self.append(cls.__new__(cls, *args))
238+
else:
235239
raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
236-
self.append(torch.nn.Parameter(*args))
237240
elif key[0] == REDUCE[0]:
238241
args = self.stack.pop()
239242
func = self.stack[-1]
@@ -255,9 +258,14 @@ def load(self):
255258
inst.__setstate__(state)
256259
elif type(inst) is OrderedDict:
257260
inst.__dict__.update(state)
261+
elif type(inst) in _get_user_allowed_globals().values():
262+
if hasattr(inst, "__setstate__"):
263+
inst.__setstate__(state)
264+
else:
265+
inst.__dict__.update(state)
258266
else:
259267
raise RuntimeError(
260-
f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
268+
f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
261269
)
262270
# Stack manipulation
263271
elif key[0] == APPEND[0]:

‎torch/serialization.py

Copy file name to clipboardExpand all lines: torch/serialization.py
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ def get_safe_globals() -> List[Any]:
203203

204204
def add_safe_globals(safe_globals: List[Any]) -> None:
205205
"""
206-
Marks the given globals as safe for ``weights_only`` load.
206+
Marks the given globals as safe for ``weights_only`` load. For example, functions
207+
added to this list can be called during unpickling, classes could be instantiated
208+
and have state set.
207209
208210
Args:
209211
safe_globals (List[Any]): list of globals to mark as safe

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.