@@ -314,16 +314,17 @@ def load_frame(self, frame_size):
314314# Tools used for pickling.
315315
316316def _getattribute (obj , name ):
317+ top = obj
317318 for subpath in name .split ('.' ):
318319 if subpath == '<locals>' :
319320 raise AttributeError ("Can't get local attribute {!r} on {!r}"
320- .format (name , obj ))
321+ .format (name , top ))
321322 try :
322323 parent = obj
323324 obj = getattr (obj , subpath )
324325 except AttributeError :
325326 raise AttributeError ("Can't get attribute {!r} on {!r}"
326- .format (name , obj )) from None
327+ .format (name , top )) from None
327328 return obj , parent
328329
329330def whichmodule (obj , name ):
@@ -396,6 +397,8 @@ def decode_long(data):
396397 return int .from_bytes (data , byteorder = 'little' , signed = True )
397398
398399
400+ _NoValue = object ()
401+
399402# Pickling machinery
400403
401404class _Pickler :
@@ -530,10 +533,11 @@ def save(self, obj, save_persistent_id=True):
530533 self .framer .commit_frame ()
531534
532535 # Check for persistent id (defined by a subclass)
533- pid = self .persistent_id (obj )
534- if pid is not None and save_persistent_id :
535- self .save_pers (pid )
536- return
536+ if save_persistent_id :
537+ pid = self .persistent_id (obj )
538+ if pid is not None :
539+ self .save_pers (pid )
540+ return
537541
538542 # Check the memo
539543 x = self .memo .get (id (obj ))
@@ -542,8 +546,8 @@ def save(self, obj, save_persistent_id=True):
542546 return
543547
544548 rv = NotImplemented
545- reduce = getattr (self , "reducer_override" , None )
546- if reduce is not None :
549+ reduce = getattr (self , "reducer_override" , _NoValue )
550+ if reduce is not _NoValue :
547551 rv = reduce (obj )
548552
549553 if rv is NotImplemented :
@@ -556,8 +560,8 @@ def save(self, obj, save_persistent_id=True):
556560
557561 # Check private dispatch table if any, or else
558562 # copyreg.dispatch_table
559- reduce = getattr (self , 'dispatch_table' , dispatch_table ).get (t )
560- if reduce is not None :
563+ reduce = getattr (self , 'dispatch_table' , dispatch_table ).get (t , _NoValue )
564+ if reduce is not _NoValue :
561565 rv = reduce (obj )
562566 else :
563567 # Check for a class with a custom metaclass; treat as regular
@@ -567,12 +571,12 @@ def save(self, obj, save_persistent_id=True):
567571 return
568572
569573 # Check for a __reduce_ex__ method, fall back to __reduce__
570- reduce = getattr (obj , "__reduce_ex__" , None )
571- if reduce is not None :
574+ reduce = getattr (obj , "__reduce_ex__" , _NoValue )
575+ if reduce is not _NoValue :
572576 rv = reduce (self .proto )
573577 else :
574- reduce = getattr (obj , "__reduce__" , None )
575- if reduce is not None :
578+ reduce = getattr (obj , "__reduce__" , _NoValue )
579+ if reduce is not _NoValue :
576580 rv = reduce ()
577581 else :
578582 raise PicklingError ("Can't pickle %r object: %r" %
@@ -780,14 +784,10 @@ def save_float(self, obj):
780784 self .write (FLOAT + repr (obj ).encode ("ascii" ) + b'\n ' )
781785 dispatch [float ] = save_float
782786
783- def save_bytes (self , obj ):
784- if self .proto < 3 :
785- if not obj : # bytes object is empty
786- self .save_reduce (bytes , (), obj = obj )
787- else :
788- self .save_reduce (codecs .encode ,
789- (str (obj , 'latin1' ), 'latin1' ), obj = obj )
790- return
787+ def _save_bytes_no_memo (self , obj ):
788+ # helper for writing bytes objects for protocol >= 3
789+ # without memoizing them
790+ assert self .proto >= 3
791791 n = len (obj )
792792 if n <= 0xff :
793793 self .write (SHORT_BINBYTES + pack ("<B" , n ) + obj )
@@ -797,28 +797,44 @@ def save_bytes(self, obj):
797797 self ._write_large_bytes (BINBYTES + pack ("<I" , n ), obj )
798798 else :
799799 self .write (BINBYTES + pack ("<I" , n ) + obj )
800+
801+ def save_bytes (self , obj ):
802+ if self .proto < 3 :
803+ if not obj : # bytes object is empty
804+ self .save_reduce (bytes , (), obj = obj )
805+ else :
806+ self .save_reduce (codecs .encode ,
807+ (str (obj , 'latin1' ), 'latin1' ), obj = obj )
808+ return
809+ self ._save_bytes_no_memo (obj )
800810 self .memoize (obj )
801811 dispatch [bytes ] = save_bytes
802812
813+ def _save_bytearray_no_memo (self , obj ):
814+ # helper for writing bytearray objects for protocol >= 5
815+ # without memoizing them
816+ assert self .proto >= 5
817+ n = len (obj )
818+ if n >= self .framer ._FRAME_SIZE_TARGET :
819+ self ._write_large_bytes (BYTEARRAY8 + pack ("<Q" , n ), obj )
820+ else :
821+ self .write (BYTEARRAY8 + pack ("<Q" , n ) + obj )
822+
803823 def save_bytearray (self , obj ):
804824 if self .proto < 5 :
805825 if not obj : # bytearray is empty
806826 self .save_reduce (bytearray , (), obj = obj )
807827 else :
808828 self .save_reduce (bytearray , (bytes (obj ),), obj = obj )
809829 return
810- n = len (obj )
811- if n >= self .framer ._FRAME_SIZE_TARGET :
812- self ._write_large_bytes (BYTEARRAY8 + pack ("<Q" , n ), obj )
813- else :
814- self .write (BYTEARRAY8 + pack ("<Q" , n ) + obj )
830+ self ._save_bytearray_no_memo (obj )
815831 self .memoize (obj )
816832 dispatch [bytearray ] = save_bytearray
817833
818834 if _HAVE_PICKLE_BUFFER :
819835 def save_picklebuffer (self , obj ):
820836 if self .proto < 5 :
821- raise PicklingError ("PickleBuffer can only pickled with "
837+ raise PicklingError ("PickleBuffer can only be pickled with "
822838 "protocol >= 5" )
823839 with obj .raw () as m :
824840 if not m .contiguous :
@@ -830,10 +846,18 @@ def save_picklebuffer(self, obj):
830846 if in_band :
831847 # Write data in-band
832848 # XXX The C implementation avoids a copy here
849+ buf = m .tobytes ()
850+ in_memo = id (buf ) in self .memo
833851 if m .readonly :
834- self .save_bytes (m .tobytes ())
852+ if in_memo :
853+ self ._save_bytes_no_memo (buf )
854+ else :
855+ self .save_bytes (buf )
835856 else :
836- self .save_bytearray (m .tobytes ())
857+ if in_memo :
858+ self ._save_bytearray_no_memo (buf )
859+ else :
860+ self .save_bytearray (buf )
837861 else :
838862 # Write data out-of-band
839863 self .write (NEXT_BUFFER )
@@ -1070,11 +1094,16 @@ def save_global(self, obj, name=None):
10701094 (obj , module_name , name ))
10711095
10721096 if self .proto >= 2 :
1073- code = _extension_registry .get ((module_name , name ))
1074- if code :
1075- assert code > 0
1097+ code = _extension_registry .get ((module_name , name ), _NoValue )
1098+ if code is not _NoValue :
10761099 if code <= 0xff :
1077- write (EXT1 + pack ("<B" , code ))
1100+ data = pack ("<B" , code )
1101+ if data == b'\0 ' :
1102+ # Should never happen in normal circumstances,
1103+ # since the type and the value of the code are
1104+ # checked in copyreg.add_extension().
1105+ raise RuntimeError ("extension code 0 is out of range" )
1106+ write (EXT1 + data )
10781107 elif code <= 0xffff :
10791108 write (EXT2 + pack ("<H" , code ))
10801109 else :
@@ -1088,11 +1117,35 @@ def save_global(self, obj, name=None):
10881117 self .save (module_name )
10891118 self .save (name )
10901119 write (STACK_GLOBAL )
1091- elif parent is not module :
1092- self .save_reduce (getattr , (parent , lastname ))
1093- elif self .proto >= 3 :
1094- write (GLOBAL + bytes (module_name , "utf-8" ) + b'\n ' +
1095- bytes (name , "utf-8" ) + b'\n ' )
1120+ elif '.' in name :
1121+ # In protocol < 4, objects with multi-part __qualname__
1122+ # are represented as
1123+ # getattr(getattr(..., attrname1), attrname2).
1124+ dotted_path = name .split ('.' )
1125+ name = dotted_path .pop (0 )
1126+ save = self .save
1127+ for attrname in dotted_path :
1128+ save (getattr )
1129+ if self .proto < 2 :
1130+ write (MARK )
1131+ self ._save_toplevel_by_name (module_name , name )
1132+ for attrname in dotted_path :
1133+ save (attrname )
1134+ if self .proto < 2 :
1135+ write (TUPLE )
1136+ else :
1137+ write (TUPLE2 )
1138+ write (REDUCE )
1139+ else :
1140+ self ._save_toplevel_by_name (module_name , name )
1141+
1142+ self .memoize (obj )
1143+
1144+ def _save_toplevel_by_name (self , module_name , name ):
1145+ if self .proto >= 3 :
1146+ # Non-ASCII identifiers are supported only with protocols >= 3.
1147+ self .write (GLOBAL + bytes (module_name , "utf-8" ) + b'\n ' +
1148+ bytes (name , "utf-8" ) + b'\n ' )
10961149 else :
10971150 if self .fix_imports :
10981151 r_name_mapping = _compat_pickle .REVERSE_NAME_MAPPING
@@ -1102,14 +1155,12 @@ def save_global(self, obj, name=None):
11021155 elif module_name in r_import_mapping :
11031156 module_name = r_import_mapping [module_name ]
11041157 try :
1105- write (GLOBAL + bytes (module_name , "ascii" ) + b'\n ' +
1106- bytes (name , "ascii" ) + b'\n ' )
1158+ self . write (GLOBAL + bytes (module_name , "ascii" ) + b'\n ' +
1159+ bytes (name , "ascii" ) + b'\n ' )
11071160 except UnicodeEncodeError :
11081161 raise PicklingError (
11091162 "can't pickle global identifier '%s.%s' using "
1110- "pickle protocol %i" % (module , name , self .proto )) from None
1111-
1112- self .memoize (obj )
1163+ "pickle protocol %i" % (module_name , name , self .proto )) from None
11131164
11141165 def save_type (self , obj ):
11151166 if obj is type (None ):
@@ -1546,9 +1597,8 @@ def load_ext4(self):
15461597 dispatch [EXT4 [0 ]] = load_ext4
15471598
15481599 def get_extension (self , code ):
1549- nil = []
1550- obj = _extension_cache .get (code , nil )
1551- if obj is not nil :
1600+ obj = _extension_cache .get (code , _NoValue )
1601+ if obj is not _NoValue :
15521602 self .append (obj )
15531603 return
15541604 key = _inverted_registry .get (code )
@@ -1705,8 +1755,8 @@ def load_build(self):
17051755 stack = self .stack
17061756 state = stack .pop ()
17071757 inst = stack [- 1 ]
1708- setstate = getattr (inst , "__setstate__" , None )
1709- if setstate is not None :
1758+ setstate = getattr (inst , "__setstate__" , _NoValue )
1759+ if setstate is not _NoValue :
17101760 setstate (state )
17111761 return
17121762 slotstate = None
0 commit comments