diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5c1b1b6637..12308c1dbc 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -315,7 +315,7 @@ jobs: with: python-version: ${{ env.PYTHON_VERSION }} - name: install ruff - run: python -m pip install ruff + run: python -m pip install ruff==0.0.291 # astral-sh/ruff#7778 - name: run python lint run: ruff extra_tests wasm examples --exclude='./.*',./Lib,./vm/Lib,./benches/ --select=E9,F63,F7,F82 --show-source - name: install prettier diff --git a/Lib/enum.py b/Lib/enum.py index 31afdd3a24..625e9ea56a 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,14 +1,40 @@ import sys +import builtins as bltns from types import MappingProxyType, DynamicClassAttribute +from operator import or_ as _or_ +from functools import reduce __all__ = [ - 'EnumMeta', - 'Enum', 'IntEnum', 'Flag', 'IntFlag', - 'auto', 'unique', + 'EnumType', 'EnumMeta', + 'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag', 'ReprEnum', + 'auto', 'unique', 'property', 'verify', 'member', 'nonmember', + 'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP', + 'global_flag_repr', 'global_enum_repr', 'global_str', 'global_enum', + 'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE', + 'pickle_by_global_name', 'pickle_by_enum_name', ] +# Dummy value for Enum and Flag as there are explicit checks for them +# before they have been created. +# This is also why there are checks in EnumType like `if Enum is not None` +Enum = Flag = EJECT = _stdlib_enums = ReprEnum = None + +class nonmember(object): + """ + Protects item from becoming an Enum member during class creation. + """ + def __init__(self, value): + self.value = value + +class member(object): + """ + Forces item to become an Enum member during class creation. + """ + def __init__(self, value): + self.value = value + def _is_descriptor(obj): """ Returns True if obj is a descriptor, False otherwise. @@ -41,33 +67,297 @@ def _is_sunder(name): name[-2:-1] != '_' ) -def _make_class_unpicklable(cls): +def _is_internal_class(cls_name, obj): + # do not use `re` as `re` imports `enum` + if not isinstance(obj, type): + return False + qualname = getattr(obj, '__qualname__', '') + s_pattern = cls_name + '.' + getattr(obj, '__name__', '') + e_pattern = '.' + s_pattern + return qualname == s_pattern or qualname.endswith(e_pattern) + +def _is_private(cls_name, name): + # do not use `re` as `re` imports `enum` + pattern = '_%s__' % (cls_name, ) + pat_len = len(pattern) + if ( + len(name) > pat_len + and name.startswith(pattern) + and name[pat_len:pat_len+1] != ['_'] + and (name[-1] != '_' or name[-2] != '_') + ): + return True + else: + return False + +def _is_single_bit(num): """ - Make the given class un-picklable. + True if only one bit set in num (should be an int) + """ + if num == 0: + return False + num &= num - 1 + return num == 0 + +def _make_class_unpicklable(obj): + """ + Make the given obj un-picklable. + + obj should be either a dictionary, or an Enum """ def _break_on_call_reduce(self, proto): raise TypeError('%r cannot be pickled' % self) - cls.__reduce_ex__ = _break_on_call_reduce - cls.__module__ = '' + if isinstance(obj, dict): + obj['__reduce_ex__'] = _break_on_call_reduce + obj['__module__'] = '' + else: + setattr(obj, '__reduce_ex__', _break_on_call_reduce) + setattr(obj, '__module__', '') + +def _iter_bits_lsb(num): + # num must be a positive integer + original = num + if isinstance(num, Enum): + num = num.value + if num < 0: + raise ValueError('%r is not a positive integer' % original) + while num: + b = num & (~num + 1) + yield b + num ^= b + +def show_flag_values(value): + return list(_iter_bits_lsb(value)) + +def bin(num, max_bits=None): + """ + Like built-in bin(), except negative values are represented in + twos-compliment, and the leading bit always indicates sign + (0=positive, 1=negative). + + >>> bin(10) + '0b0 1010' + >>> bin(~10) # ~10 is -11 + '0b1 0101' + """ + + ceiling = 2 ** (num).bit_length() + if num >= 0: + s = bltns.bin(num + ceiling).replace('1', '0', 1) + else: + s = bltns.bin(~num ^ (ceiling - 1) + ceiling) + sign = s[:3] + digits = s[3:] + if max_bits is not None: + if len(digits) < max_bits: + digits = (sign[-1] * max_bits + digits)[-max_bits:] + return "%s %s" % (sign, digits) + +def _dedent(text): + """ + Like textwrap.dedent. Rewritten because we cannot import textwrap. + """ + lines = text.split('\n') + blanks = 0 + for i, ch in enumerate(lines[0]): + if ch != ' ': + break + for j, l in enumerate(lines): + lines[j] = l[i:] + return '\n'.join(lines) + +class _auto_null: + def __repr__(self): + return '_auto_null' +_auto_null = _auto_null() -_auto_null = object() class auto: """ Instances are replaced with an appropriate value in Enum class suites. """ - value = _auto_null + def __init__(self, value=_auto_null): + self.value = value + + def __repr__(self): + return "auto(%r)" % self.value + +class property(DynamicClassAttribute): + """ + This is a descriptor, used to define attributes that act differently + when accessed through an enum member and through an enum class. + Instance access is the same as property(), but access to an attribute + through the enum class will instead look in the class' _member_map_ for + a corresponding enum member. + """ + + def __get__(self, instance, ownerclass=None): + if instance is None: + try: + return ownerclass._member_map_[self.name] + except KeyError: + raise AttributeError( + '%r has no attribute %r' % (ownerclass, self.name) + ) + else: + if self.fget is None: + # look for a member by this name. + try: + return ownerclass._member_map_[self.name] + except KeyError: + raise AttributeError( + '%r has no attribute %r' % (ownerclass, self.name) + ) from None + else: + return self.fget(instance) + + def __set__(self, instance, value): + if self.fset is None: + raise AttributeError( + " cannot set attribute %r" % (self.clsname, self.name) + ) + else: + return self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError( + " cannot delete attribute %r" % (self.clsname, self.name) + ) + else: + return self.fdel(instance) + + def __set_name__(self, ownerclass, name): + self.name = name + self.clsname = ownerclass.__name__ + + +class _proto_member: + """ + intermediate step for enum members between class execution and final creation + """ + + def __init__(self, value): + self.value = value + + def __set_name__(self, enum_class, member_name): + """ + convert each quasi-member into an instance of the new enum class + """ + # first step: remove ourself from enum_class + delattr(enum_class, member_name) + # second step: create member based on enum_class + value = self.value + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if enum_class._member_type_ is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not enum_class._use_args_: + enum_member = enum_class._new_member_(enum_class) + else: + enum_member = enum_class._new_member_(enum_class, *args) + if not hasattr(enum_member, '_value_'): + if enum_class._member_type_ is object: + enum_member._value_ = value + else: + try: + enum_member._value_ = enum_class._member_type_(*args) + except Exception as exc: + new_exc = TypeError( + '_value_ not set in __new__, unable to create it' + ) + new_exc.__cause__ = exc + raise new_exc + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + enum_member._sort_order_ = len(enum_class._member_names_) + + if Flag is not None and issubclass(enum_class, Flag): + enum_class._flag_mask_ |= value + if _is_single_bit(value): + enum_class._singles_mask_ |= value + enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1 + + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + try: + try: + # try to do a fast lookup to avoid the quadratic loop + enum_member = enum_class._value2member_map_[value] + except TypeError: + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member._value_ == value: + enum_member = canonical_member + break + else: + raise KeyError + except KeyError: + # this could still be an alias if the value is multi-bit and the + # class is a flag class + if ( + Flag is None + or not issubclass(enum_class, Flag) + ): + # no other instances found, record this member in _member_names_ + enum_class._member_names_.append(member_name) + elif ( + Flag is not None + and issubclass(enum_class, Flag) + and _is_single_bit(value) + ): + # no other instances found, record this member in _member_names_ + enum_class._member_names_.append(member_name) + # if necessary, get redirect in place and then add it to _member_map_ + found_descriptor = None + for base in enum_class.__mro__[1:]: + descriptor = base.__dict__.get(member_name) + if descriptor is not None: + if isinstance(descriptor, (property, DynamicClassAttribute)): + found_descriptor = descriptor + break + elif ( + hasattr(descriptor, 'fget') and + hasattr(descriptor, 'fset') and + hasattr(descriptor, 'fdel') + ): + found_descriptor = descriptor + continue + if found_descriptor: + redirect = property() + redirect.member = enum_member + redirect.__set_name__(enum_class, member_name) + # earlier descriptor found; copy fget, fset, fdel to this one. + redirect.fget = found_descriptor.fget + redirect.fset = found_descriptor.fset + redirect.fdel = found_descriptor.fdel + setattr(enum_class, member_name, redirect) + else: + setattr(enum_class, member_name, enum_member) + # now add to _member_map_ (even aliases) + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_.setdefault(value, enum_member) + except TypeError: + # keep track of the value in a list so containment checks are quick + enum_class._unhashable_values_.append(value) class _EnumDict(dict): """ Track enum member order and ensure member names are not reused. - EnumMeta will use the names found in self._member_names as the + EnumType will use the names found in self._member_names as the enumeration member names. """ def __init__(self): super().__init__() - self._member_names = [] + self._member_names = {} # use a dict to keep insertion order self._last_values = [] self._ignore = [] self._auto_called = False @@ -81,17 +371,33 @@ def __setitem__(self, key, value): Single underscore (sunder) names are reserved. """ - if _is_sunder(key): + if _is_internal_class(self._cls_name, value): + import warnings + warnings.warn( + "In 3.13 classes created inside an enum will not become a member. " + "Use the `member` decorator to keep the current behavior.", + DeprecationWarning, + stacklevel=2, + ) + if _is_private(self._cls_name, key): + # also do nothing, name will be a normal attribute + pass + elif _is_sunder(key): if key not in ( - '_order_', '_create_pseudo_member_', - '_generate_next_value_', '_missing_', '_ignore_', + '_order_', + '_generate_next_value_', '_numeric_repr_', '_missing_', '_ignore_', + '_iter_member_', '_iter_member_by_value_', '_iter_member_by_def_', ): - raise ValueError('_names_ are reserved for future Enum use') + raise ValueError( + '_sunder_ names, such as %r, are reserved for future Enum use' + % (key, ) + ) if key == '_generate_next_value_': # check if members already defined as auto() if self._auto_called: raise TypeError("_generate_next_value_ must be defined before members") - setattr(self, '_generate_next_value', value) + _gnv = value.__func__ if isinstance(value, staticmethod) else value + setattr(self, '_generate_next_value', _gnv) elif key == '_ignore_': if isinstance(value, str): value = value.replace(',',' ').split() @@ -109,43 +415,77 @@ def __setitem__(self, key, value): key = '_order_' elif key in self._member_names: # descriptor overwriting an enum? - raise TypeError('Attempted to reuse key: %r' % key) + raise TypeError('%r already defined as %r' % (key, self[key])) elif key in self._ignore: pass - elif not _is_descriptor(value): + elif isinstance(value, nonmember): + # unwrap value here; it won't be processed by the below `else` + value = value.value + elif _is_descriptor(value): + pass + # TODO: uncomment next three lines in 3.13 + # elif _is_internal_class(self._cls_name, value): + # # do nothing, name will be a normal attribute + # pass + else: if key in self: # enum overwriting a descriptor? - raise TypeError('%r already defined as: %r' % (key, self[key])) - if isinstance(value, auto): - if value.value == _auto_null: - value.value = self._generate_next_value( - key, - 1, - len(self._member_names), - self._last_values[:], - ) - self._auto_called = True + raise TypeError('%r already defined as %r' % (key, self[key])) + elif isinstance(value, member): + # unwrap value here -- it will become a member value = value.value - self._member_names.append(key) - self._last_values.append(value) + non_auto_store = True + single = False + if isinstance(value, auto): + single = True + value = (value, ) + if type(value) is tuple and any(isinstance(v, auto) for v in value): + # insist on an actual tuple, no subclasses, in keeping with only supporting + # top-level auto() usage (not contained in any other data structure) + auto_valued = [] + for v in value: + if isinstance(v, auto): + non_auto_store = False + if v.value == _auto_null: + v.value = self._generate_next_value( + key, 1, len(self._member_names), self._last_values[:], + ) + self._auto_called = True + v = v.value + self._last_values.append(v) + auto_valued.append(v) + if single: + value = auto_valued[0] + else: + value = tuple(auto_valued) + self._member_names[key] = None + if non_auto_store: + self._last_values.append(value) super().__setitem__(key, value) + def update(self, members, **more_members): + try: + for name in members.keys(): + self[name] = members[name] + except AttributeError: + for name, value in members: + self[name] = value + for name, value in more_members.items(): + self[name] = value -# Dummy value for Enum as EnumMeta explicitly checks for it, but of course -# until EnumMeta finishes running the first time the Enum class doesn't exist. -# This is also why there are checks in EnumMeta like `if Enum is not None` -Enum = None -class EnumMeta(type): +class EnumType(type): """ Metaclass for Enum """ + @classmethod - def __prepare__(metacls, cls, bases): + def __prepare__(metacls, cls, bases, **kwds): # check that previous enum members do not exist - metacls._check_for_existing_members(cls, bases) + metacls._check_for_existing_members_(cls, bases) # create the namespace dict enum_dict = _EnumDict() + enum_dict._cls_name = cls # inherit previous flags and _generate_next_value_ function member_type, first_enum = metacls._get_mixins_(cls, bases) if first_enum is not None: @@ -154,138 +494,125 @@ def __prepare__(metacls, cls, bases): ) return enum_dict - def __new__(metacls, cls, bases, classdict): + def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **kwds): # an Enum class is final once enumeration items have been defined; it # cannot be mixed with other types (int, float, etc.) if it has an # inherited __new__ unless a new __new__ is defined (or the resulting # class will fail). # + if _simple: + return super().__new__(metacls, cls, bases, classdict, **kwds) + # # remove any keys listed in _ignore_ classdict.setdefault('_ignore_', []).append('_ignore_') ignore = classdict['_ignore_'] for key in ignore: classdict.pop(key, None) + # + # grab member names + member_names = classdict._member_names + # + # check for illegal enum names (any others?) + invalid_names = set(member_names) & {'mro', ''} + if invalid_names: + raise ValueError('invalid enum member name(s) %s' % ( + ','.join(repr(n) for n in invalid_names) + )) + # + # adjust the sunders + _order_ = classdict.pop('_order_', None) + # convert to normal dict + classdict = dict(classdict.items()) + # + # data type of member and the controlling Enum class member_type, first_enum = metacls._get_mixins_(cls, bases) __new__, save_new, use_args = metacls._find_new_( classdict, member_type, first_enum, ) - - # save enum items into separate mapping so they don't get baked into - # the new class - enum_members = {k: classdict[k] for k in classdict._member_names} - for name in classdict._member_names: - del classdict[name] - - # adjust the sunders - _order_ = classdict.pop('_order_', None) - - # check for illegal enum names (any others?) - invalid_names = set(enum_members) & {'mro', ''} - if invalid_names: - raise ValueError('Invalid enum member name: {0}'.format( - ','.join(invalid_names))) - - # create a default docstring if one has not been provided - if '__doc__' not in classdict: - classdict['__doc__'] = 'An enumeration.' - - # create our new Enum type - enum_class = super().__new__(metacls, cls, bases, classdict) - enum_class._member_names_ = [] # names in definition order - enum_class._member_map_ = {} # name->value map - enum_class._member_type_ = member_type - - # save DynamicClassAttribute attributes from super classes so we know - # if we can take the shortcut of storing members in the class dict - dynamic_attributes = { - k for c in enum_class.mro() - for k, v in c.__dict__.items() - if isinstance(v, DynamicClassAttribute) - } - - # Reverse value->name map for hashable values. - enum_class._value2member_map_ = {} - - # If a custom type is mixed into the Enum, and it does not know how - # to pickle itself, pickle.dumps will succeed but pickle.loads will - # fail. Rather than have the error show up later and possibly far - # from the source, sabotage the pickle protocol for this class so - # that pickle.dumps also fails. + classdict['_new_member_'] = __new__ + classdict['_use_args_'] = use_args + # + # convert future enum members into temporary _proto_members + for name in member_names: + value = classdict[name] + classdict[name] = _proto_member(value) + # + # house-keeping structures + classdict['_member_names_'] = [] + classdict['_member_map_'] = {} + classdict['_value2member_map_'] = {} + classdict['_unhashable_values_'] = [] + classdict['_member_type_'] = member_type + # now set the __repr__ for the value + classdict['_value_repr_'] = metacls._find_data_repr_(cls, bases) + # + # Flag structures (will be removed if final class is not a Flag + classdict['_boundary_'] = ( + boundary + or getattr(first_enum, '_boundary_', None) + ) + classdict['_flag_mask_'] = 0 + classdict['_singles_mask_'] = 0 + classdict['_all_bits_'] = 0 + classdict['_inverted_'] = None + try: + exc = None + enum_class = super().__new__(metacls, cls, bases, classdict, **kwds) + except RuntimeError as e: + # any exceptions raised by member.__new__ will get converted to a + # RuntimeError, so get that original exception back and raise it instead + exc = e.__cause__ or e + if exc is not None: + raise exc + # + # update classdict with any changes made by __init_subclass__ + classdict.update(enum_class.__dict__) # - # However, if the new class implements its own __reduce_ex__, do not - # sabotage -- it's on them to make sure it works correctly. We use - # __reduce_ex__ instead of any of the others as it is preferred by - # pickle over __reduce__, and it handles all pickle protocols. - if '__reduce_ex__' not in classdict: - if member_type is not object: - methods = ('__getnewargs_ex__', '__getnewargs__', - '__reduce_ex__', '__reduce__') - if not any(m in member_type.__dict__ for m in methods): - _make_class_unpicklable(enum_class) - - # instantiate them, checking for duplicates as we go - # we instantiate first instead of checking for duplicates first in case - # a custom __new__ is doing something funky with the values -- such as - # auto-numbering ;) - for member_name in classdict._member_names: - value = enum_members[member_name] - if not isinstance(value, tuple): - args = (value, ) - else: - args = value - if member_type is tuple: # special case for tuple enums - args = (args, ) # wrap it one more time - if not use_args: - enum_member = __new__(enum_class) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = value - else: - enum_member = __new__(enum_class, *args) - if not hasattr(enum_member, '_value_'): - if member_type is object: - enum_member._value_ = value - else: - enum_member._value_ = member_type(*args) - value = enum_member._value_ - enum_member._name_ = member_name - enum_member.__objclass__ = enum_class - enum_member.__init__(*args) - # If another member with the same value was already defined, the - # new member becomes an alias to the existing one. - for name, canonical_member in enum_class._member_map_.items(): - if canonical_member._value_ == enum_member._value_: - enum_member = canonical_member - break - else: - # Aliases don't appear in member names (only in __members__). - enum_class._member_names_.append(member_name) - # performance boost for any member that would not shadow - # a DynamicClassAttribute - if member_name not in dynamic_attributes: - setattr(enum_class, member_name, enum_member) - # now add to _member_map_ - enum_class._member_map_[member_name] = enum_member - try: - # This may fail if value is not hashable. We can't add the value - # to the map, and by-value lookups for this value will be - # linear. - enum_class._value2member_map_[value] = enum_member - except TypeError: - pass - # double check that repr and friends are not the mixin's or various # things break (such as pickle) # however, if the method is defined in the Enum itself, don't replace # it + # + # Also, special handling for ReprEnum + if ReprEnum is not None and ReprEnum in bases: + if member_type is object: + raise TypeError( + 'ReprEnum subclasses must be mixed with a data type (i.e.' + ' int, str, float, etc.)' + ) + if '__format__' not in classdict: + enum_class.__format__ = member_type.__format__ + classdict['__format__'] = enum_class.__format__ + if '__str__' not in classdict: + method = member_type.__str__ + if method is object.__str__: + # if member_type does not define __str__, object.__str__ will use + # its __repr__ instead, so we'll also use its __repr__ + method = member_type.__repr__ + enum_class.__str__ = method + classdict['__str__'] = enum_class.__str__ for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): - if name in classdict: - continue - class_method = getattr(enum_class, name) - obj_method = getattr(member_type, name, None) - enum_method = getattr(first_enum, name, None) - if obj_method is not None and obj_method is class_method: - setattr(enum_class, name, enum_method) - + if name not in classdict: + # check for mixin overrides before replacing + enum_method = getattr(first_enum, name) + found_method = getattr(enum_class, name) + object_method = getattr(object, name) + data_type_method = getattr(member_type, name) + if found_method in (data_type_method, object_method): + setattr(enum_class, name, enum_method) + # + # for Flag, add __or__, __and__, __xor__, and __invert__ + if Flag is not None and issubclass(enum_class, Flag): + for name in ( + '__or__', '__and__', '__xor__', + '__ror__', '__rand__', '__rxor__', + '__invert__' + ): + if name not in classdict: + enum_method = getattr(Flag, name) + setattr(enum_class, name, enum_method) + classdict[name] = enum_method + # # replace any other __new__ with our own (as long as Enum is not None, # anyway) -- again, this is to support pickle if Enum is not None: @@ -294,23 +621,69 @@ def __new__(metacls, cls, bases, classdict): if save_new: enum_class.__new_member__ = __new__ enum_class.__new__ = Enum.__new__ - + # # py3 support for definition order (helps keep py2/py3 code in sync) + # + # _order_ checking is spread out into three/four steps + # - if enum_class is a Flag: + # - remove any non-single-bit flags from _order_ + # - remove any aliases from _order_ + # - check that _order_ and _member_names_ match + # + # step 1: ensure we have a list if _order_ is not None: if isinstance(_order_, str): _order_ = _order_.replace(',', ' ').split() + # + # remove Flag structures if final class is not a Flag + if ( + Flag is None and cls != 'Flag' + or Flag is not None and not issubclass(enum_class, Flag) + ): + delattr(enum_class, '_boundary_') + delattr(enum_class, '_flag_mask_') + delattr(enum_class, '_singles_mask_') + delattr(enum_class, '_all_bits_') + delattr(enum_class, '_inverted_') + elif Flag is not None and issubclass(enum_class, Flag): + # set correct __iter__ + member_list = [m._value_ for m in enum_class] + if member_list != sorted(member_list): + enum_class._iter_member_ = enum_class._iter_member_by_def_ + if _order_: + # _order_ step 2: remove any items from _order_ that are not single-bit + _order_ = [ + o + for o in _order_ + if o not in enum_class._member_map_ or _is_single_bit(enum_class[o]._value_) + ] + # + if _order_: + # _order_ step 3: remove aliases from _order_ + _order_ = [ + o + for o in _order_ + if ( + o not in enum_class._member_map_ + or + (o in enum_class._member_map_ and o in enum_class._member_names_) + )] + # _order_ step 4: verify that _order_ and _member_names_ match if _order_ != enum_class._member_names_: - raise TypeError('member order does not match _order_') + raise TypeError( + 'member order does not match _order_:\n %r\n %r' + % (enum_class._member_names_, _order_) + ) return enum_class - def __bool__(self): + def __bool__(cls): """ classes/types should always be True. """ return True - def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1): + def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None): """ Either returns an existing member, or creates a new enum class. @@ -345,10 +718,25 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s qualname=qualname, type=type, start=start, + boundary=boundary, ) def __contains__(cls, member): + """ + Return True if member is a member of this enum + raises TypeError if member is not an enum member + + note: in 3.12 TypeError will no longer be raised, and True will also be + returned if member is the value of a member in this enum + """ if not isinstance(member, Enum): + import warnings + warnings.warn( + "in 3.12 __contains__ will no longer raise TypeError, but will return True or\n" + "False depending on whether the value is a member or the value of a member", + DeprecationWarning, + stacklevel=2, + ) raise TypeError( "unsupported operand type(s) for 'in': '%s' and '%s'" % ( type(member).__qualname__, cls.__class__.__qualname__)) @@ -358,14 +746,26 @@ def __delattr__(cls, attr): # nicer error message when someone tries to delete an attribute # (see issue19025). if attr in cls._member_map_: - raise AttributeError("%s: cannot delete Enum member." % cls.__name__) + raise AttributeError("%r cannot delete member %r." % (cls.__name__, attr)) super().__delattr__(attr) - def __dir__(self): - return ( - ['__class__', '__doc__', '__members__', '__module__'] - + self._member_names_ + def __dir__(cls): + interesting = set([ + '__class__', '__contains__', '__doc__', '__getitem__', + '__iter__', '__len__', '__members__', '__module__', + '__name__', '__qualname__', + ] + + cls._member_names_ ) + if cls._new_member_ is not object.__new__: + interesting.add('__new__') + if cls.__init_subclass__ is not object.__init_subclass__: + interesting.add('__init_subclass__') + if cls._member_type_ is object: + return sorted(interesting) + else: + # return whatever mixed-in data type has + return sorted(set(dir(cls._member_type_)) | interesting) def __getattr__(cls, name): """ @@ -384,18 +784,24 @@ def __getattr__(cls, name): raise AttributeError(name) from None def __getitem__(cls, name): + """ + Return the member matching `name`. + """ return cls._member_map_[name] def __iter__(cls): """ - Returns members in definition order. + Return members in definition order. """ return (cls._member_map_[name] for name in cls._member_names_) def __len__(cls): + """ + Return the number of members (no aliases) + """ return len(cls._member_names_) - @property + @bltns.property def __members__(cls): """ Returns a mapping of member name->value. @@ -406,11 +812,14 @@ def __members__(cls): return MappingProxyType(cls._member_map_) def __repr__(cls): - return "" % cls.__name__ + if Flag is not None and issubclass(cls, Flag): + return "" % cls.__name__ + else: + return "" % cls.__name__ def __reversed__(cls): """ - Returns members in reverse definition order. + Return members in reverse definition order. """ return (cls._member_map_[name] for name in reversed(cls._member_names_)) @@ -424,10 +833,10 @@ def __setattr__(cls, name, value): """ member_map = cls.__dict__.get('_member_map_', {}) if name in member_map: - raise AttributeError('Cannot reassign members.') + raise AttributeError('cannot reassign member %r' % (name, )) super().__setattr__(name, value) - def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, start=1): + def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, start=1, boundary=None): """ Convenience method to create a new Enum class. @@ -441,7 +850,7 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s """ metacls = cls.__class__ bases = (cls, ) if type is None else (type, cls) - _, first_enum = cls._get_mixins_(cls, bases) + _, first_enum = cls._get_mixins_(class_name, bases) classdict = metacls.__prepare__(class_name, bases) # special processing needed for names? @@ -462,25 +871,24 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s else: member_name, member_value = item classdict[member_name] = member_value - enum_class = metacls.__new__(metacls, class_name, bases, classdict) # TODO: replace the frame hack if a blessed way to know the calling # module is ever developed if module is None: try: module = sys._getframe(2).f_globals['__name__'] - except (AttributeError, ValueError, KeyError) as exc: + except (AttributeError, ValueError, KeyError): pass if module is None: - _make_class_unpicklable(enum_class) + _make_class_unpicklable(classdict) else: - enum_class.__module__ = module + classdict['__module__'] = module if qualname is not None: - enum_class.__qualname__ = qualname + classdict['__qualname__'] = qualname - return enum_class + return metacls.__new__(metacls, class_name, bases, classdict, boundary=boundary) - def _convert_(cls, name, module, filter, source=None): + def _convert_(cls, name, module, filter, source=None, *, boundary=None, as_global=False): """ Create a new Enum subclass that replaces a collection of global constants """ @@ -489,9 +897,9 @@ def _convert_(cls, name, module, filter, source=None): # module; # also, replace the __reduce_ex__ method so unpickling works in # previous Python versions - module_globals = vars(sys.modules[module]) + module_globals = sys.modules[module].__dict__ if source: - source = vars(source) + source = source.__dict__ else: source = module_globals # _value2member_map_ is populated in the same order every time @@ -507,30 +915,29 @@ def _convert_(cls, name, module, filter, source=None): except TypeError: # unless some values aren't comparable, in which case sort by name members.sort(key=lambda t: t[0]) - cls = cls(name, members, module=module) - cls.__reduce_ex__ = _reduce_ex_by_name - module_globals.update(cls.__members__) + body = {t[0]: t[1] for t in members} + body['__module__'] = module + tmp_cls = type(name, (object, ), body) + cls = _simple_enum(etype=cls, boundary=boundary or KEEP)(tmp_cls) + if as_global: + global_enum(cls) + else: + sys.modules[cls.__module__].__dict__.update(cls.__members__) module_globals[name] = cls return cls - def _convert(cls, *args, **kwargs): - import warnings - warnings.warn("_convert is deprecated and will be removed in 3.9, use " - "_convert_ instead.", DeprecationWarning, stacklevel=2) - return cls._convert_(*args, **kwargs) - - @staticmethod - def _check_for_existing_members(class_name, bases): + @classmethod + def _check_for_existing_members_(mcls, class_name, bases): for chain in bases: for base in chain.__mro__: - if issubclass(base, Enum) and base._member_names_: + if isinstance(base, EnumType) and base._member_names_: raise TypeError( - "%s: cannot extend enumeration %r" - % (class_name, base.__name__) + " cannot extend %r" + % (class_name, base) ) - @staticmethod - def _get_mixins_(class_name, bases): + @classmethod + def _get_mixins_(mcls, class_name, bases): """ Returns the type for creating enum members, and the first inherited enum class. @@ -540,44 +947,62 @@ def _get_mixins_(class_name, bases): if not bases: return object, Enum - def _find_data_type(bases): - data_types = [] - for chain in bases: - candidate = None - for base in chain.__mro__: - if base is object: - continue - elif issubclass(base, Enum): - if base._member_type_ is not object: - data_types.append(base._member_type_) - break - elif '__new__' in base.__dict__: - if issubclass(base, Enum): - continue - data_types.append(candidate or base) - break - else: - candidate = base - if len(data_types) > 1: - raise TypeError('%r: too many data types: %r' % (class_name, data_types)) - elif data_types: - return data_types[0] - else: - return None + mcls._check_for_existing_members_(class_name, bases) # ensure final parent class is an Enum derivative, find any concrete # data type, and check that Enum has no members first_enum = bases[-1] - if not issubclass(first_enum, Enum): + if not isinstance(first_enum, EnumType): raise TypeError("new enumerations should be created as " "`EnumName([mixin_type, ...] [data_type,] enum_type)`") - member_type = _find_data_type(bases) or object - if first_enum._member_names_: - raise TypeError("Cannot extend enumerations") + member_type = mcls._find_data_type_(class_name, bases) or object return member_type, first_enum - @staticmethod - def _find_new_(classdict, member_type, first_enum): + @classmethod + def _find_data_repr_(mcls, class_name, bases): + for chain in bases: + for base in chain.__mro__: + if base is object: + continue + elif isinstance(base, EnumType): + # if we hit an Enum, use it's _value_repr_ + return base._value_repr_ + elif '__repr__' in base.__dict__: + # this is our data repr + return base.__dict__['__repr__'] + return None + + @classmethod + def _find_data_type_(mcls, class_name, bases): + # a datatype has a __new__ method + data_types = set() + base_chain = set() + for chain in bases: + candidate = None + for base in chain.__mro__: + base_chain.add(base) + if base is object: + continue + elif isinstance(base, EnumType): + if base._member_type_ is not object: + data_types.add(base._member_type_) + break + elif '__new__' in base.__dict__ or '__dataclass_fields__' in base.__dict__: + if isinstance(base, EnumType): + continue + data_types.add(candidate or base) + break + else: + candidate = candidate or base + if len(data_types) > 1: + raise TypeError('too many data types for %r: %r' % (class_name, data_types)) + elif data_types: + return data_types.pop() + else: + return None + + @classmethod + def _find_new_(mcls, classdict, member_type, first_enum): """ Returns the __new__ to be used for creating the enum members. @@ -591,7 +1016,7 @@ def _find_new_(classdict, member_type, first_enum): __new__ = classdict.get('__new__', None) # should __new__ be saved as __new_member__ later? - save_new = __new__ is not None + save_new = first_enum is not None and __new__ is not None if __new__ is None: # check all possibles for __new_member__ before falling back to @@ -615,19 +1040,54 @@ def _find_new_(classdict, member_type, first_enum): # if a non-object.__new__ is used then whatever value/tuple was # assigned to the enum member name will be passed to __new__ and to the # new enum member's __init__ - if __new__ is object.__new__: + if first_enum is None or __new__ in (Enum.__new__, object.__new__): use_args = False else: use_args = True return __new__, save_new, use_args +EnumMeta = EnumType -class Enum(metaclass=EnumMeta): +class Enum(metaclass=EnumType): """ - Generic enumeration. + Create a collection of name/value pairs. + + Example enumeration: + + >>> class Color(Enum): + ... RED = 1 + ... BLUE = 2 + ... GREEN = 3 + + Access them by: + + - attribute access:: + + >>> Color.RED + + + - value lookup: + + >>> Color(1) + + + - name lookup: + + >>> Color['RED'] + + + Enumerations can be iterated over, and know how many members they have: + + >>> len(Color) + 3 - Derive from this class to define new enumerations. + >>> list(Color) + [, , ] + + Methods can be added to enumerations, and members can have their own + attributes -- see the documentation for details. """ + def __new__(cls, value): # all enum instances are actually created during class construction # without calling this method; this method is called by the metaclass' @@ -654,19 +1114,33 @@ def __new__(cls, value): except Exception as e: exc = e result = None - if isinstance(result, cls): - return result - else: - ve_exc = ValueError("%r is not a valid %s" % (value, cls.__name__)) - if result is None and exc is None: - raise ve_exc - elif exc is None: - exc = TypeError( - 'error in %s._missing_: returned %r instead of None or a valid member' - % (cls.__name__, result) - ) - exc.__context__ = ve_exc - raise exc + try: + if isinstance(result, cls): + return result + elif ( + Flag is not None and issubclass(cls, Flag) + and cls._boundary_ is EJECT and isinstance(result, int) + ): + return result + else: + ve_exc = ValueError("%r is not a valid %s" % (value, cls.__qualname__)) + if result is None and exc is None: + raise ve_exc + elif exc is None: + exc = TypeError( + 'error in %s._missing_: returned %r instead of None or a valid member' + % (cls.__name__, result) + ) + if not isinstance(exc, ValueError): + exc.__context__ = ve_exc + raise exc + finally: + # ensure all variables that could hold an exception are destroyed + exc = None + ve_exc = None + + def __init__(self, *args, **kwds): + pass def _generate_next_value_(name, start, count, last_values): """ @@ -675,14 +1149,32 @@ def _generate_next_value_(name, start, count, last_values): name: the name of the member start: the initial start value or None count: the number of existing members - last_value: the last value assigned or None + last_values: the list of values assigned """ - for last_value in reversed(last_values): - try: - return last_value + 1 - except TypeError: - pass - else: + if not last_values: + return start + try: + last = last_values[-1] + last_values.sort() + if last == last_values[-1]: + # no difference between old and new methods + return last + 1 + else: + # trigger old method (with warning) + raise TypeError + except TypeError: + import warnings + warnings.warn( + "In 3.13 the default `auto()`/`_generate_next_value_` will require all values to be sortable and support adding +1\n" + "and the value returned will be the largest value in the enum incremented by 1", + DeprecationWarning, + stacklevel=3, + ) + for v in reversed(last_values): + try: + return v + 1 + except TypeError: + pass return start @classmethod @@ -690,42 +1182,44 @@ def _missing_(cls, value): return None def __repr__(self): - return "<%s.%s: %r>" % ( - self.__class__.__name__, self._name_, self._value_) + v_repr = self.__class__._value_repr_ or repr + return "<%s.%s: %s>" % (self.__class__.__name__, self._name_, v_repr(self._value_)) def __str__(self): - return "%s.%s" % (self.__class__.__name__, self._name_) + return "%s.%s" % (self.__class__.__name__, self._name_, ) def __dir__(self): """ Returns all members and all public methods """ - added_behavior = [ - m - for cls in self.__class__.mro() - for m in cls.__dict__ - if m[0] != '_' and m not in self._member_map_ - ] + [m for m in self.__dict__ if m[0] != '_'] - return (['__class__', '__doc__', '__module__'] + added_behavior) + if self.__class__._member_type_ is object: + interesting = set(['__class__', '__doc__', '__eq__', '__hash__', '__module__', 'name', 'value']) + else: + interesting = set(object.__dir__(self)) + for name in getattr(self, '__dict__', []): + if name[0] != '_': + interesting.add(name) + for cls in self.__class__.mro(): + for name, obj in cls.__dict__.items(): + if name[0] == '_': + continue + if isinstance(obj, property): + # that's an enum.property + if obj.fget is not None or name not in self._member_map_: + interesting.add(name) + else: + # in case it was added by `dir(self)` + interesting.discard(name) + else: + interesting.add(name) + names = sorted( + set(['__class__', '__doc__', '__eq__', '__hash__', '__module__']) + | interesting + ) + return names def __format__(self, format_spec): - """ - Returns format using actual value type unless __str__ has been overridden. - """ - # mixed-in Enums should use the mixed-in type's __format__, otherwise - # we can get strange results with the Enum name showing up instead of - # the value - - # pure Enum branch, or branch with __str__ explicitly overridden - str_overridden = type(self).__str__ not in (Enum.__str__, Flag.__str__) - if self._member_type_ is object or str_overridden: - cls = str - val = str(self) - # mix-in branch - else: - cls = self._member_type_ - val = self._value_ - return cls.__format__(val, format_spec) + return str.__format__(str(self), format_spec) def __hash__(self): return hash(self._name_) @@ -733,36 +1227,107 @@ def __hash__(self): def __reduce_ex__(self, proto): return self.__class__, (self._value_, ) - # DynamicClassAttribute is used to provide access to the `name` and - # `value` properties of enum members while keeping some measure of + def __deepcopy__(self,memo): + return self + + def __copy__(self): + return self + + # enum.property is used to provide access to the `name` and + # `value` attributes of enum members while keeping some measure of # protection from modification, while still allowing for an enumeration # to have members named `name` and `value`. This works because enumeration - # members are not set directly on the enum class -- __getattr__ is - # used to look them up. + # members are not set directly on the enum class; they are kept in a + # separate structure, _member_map_, which is where enum.property looks for + # them - @DynamicClassAttribute + @property def name(self): """The name of the Enum member.""" return self._name_ - @DynamicClassAttribute + @property def value(self): """The value of the Enum member.""" return self._value_ -class IntEnum(int, Enum): - """Enum where members are also (and must be) ints""" +class ReprEnum(Enum): + """ + Only changes the repr(), leaving str() and format() to the mixed-in type. + """ -def _reduce_ex_by_name(self, proto): +class IntEnum(int, ReprEnum): + """ + Enum where members are also (and must be) ints + """ + + +class StrEnum(str, ReprEnum): + """ + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError('too many arguments for str(): %r' % (values, )) + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError('%r is not a string' % (values[0], )) + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError('encoding must be a string, not %r' % (values[1], )) + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError('errors must be a string, not %r' % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + def _generate_next_value_(name, start, count, last_values): + """ + Return the lower-cased version of the member name. + """ + return name.lower() + + +def pickle_by_global_name(self, proto): + # should not be used with Flag-type enums return self.name +_reduce_ex_by_global_name = pickle_by_global_name + +def pickle_by_enum_name(self, proto): + # should not be used with Flag-type enums + return getattr, (self.__class__, self._name_) -class Flag(Enum): +class FlagBoundary(StrEnum): + """ + control how out of range values are handled + "strict" -> error is raised [default for Flag] + "conform" -> extra bits are discarded + "eject" -> lose flag status + "keep" -> keep flag status and all bits [default for IntFlag] + """ + STRICT = auto() + CONFORM = auto() + EJECT = auto() + KEEP = auto() +STRICT, CONFORM, EJECT, KEEP = FlagBoundary + + +class Flag(Enum, boundary=STRICT): """ Support for flags """ + _numeric_repr_ = repr + def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -770,49 +1335,128 @@ def _generate_next_value_(name, start, count, last_values): name: the name of the member start: the initial start value or None count: the number of existing members - last_value: the last value assigned or None + last_values: the last value assigned or None """ if not count: return start if start is not None else 1 - for last_value in reversed(last_values): - try: - high_bit = _high_bit(last_value) - break - except Exception: - raise TypeError('Invalid Flag value: %r' % last_value) from None + last_value = max(last_values) + try: + high_bit = _high_bit(last_value) + except Exception: + raise TypeError('invalid flag value %r' % last_value) from None return 2 ** (high_bit+1) @classmethod - def _missing_(cls, value): + def _iter_member_by_value_(cls, value): """ - Returns member (possibly creating it) if one can be found for value. + Extract all members from the value in definition (i.e. increasing value) order. """ - original_value = value - if value < 0: - value = ~value - possible_member = cls._create_pseudo_member_(value) - if original_value < 0: - possible_member = ~possible_member - return possible_member + for val in _iter_bits_lsb(value & cls._flag_mask_): + yield cls._value2member_map_.get(val) + + _iter_member_ = _iter_member_by_value_ @classmethod - def _create_pseudo_member_(cls, value): + def _iter_member_by_def_(cls, value): """ - Create a composite member iff value contains only members. + Extract all members from the value in definition order. """ - pseudo_member = cls._value2member_map_.get(value, None) - if pseudo_member is None: - # verify all bits are accounted for - _, extra_flags = _decompose(cls, value) - if extra_flags: - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + yield from sorted( + cls._iter_member_by_value_(value), + key=lambda m: m._sort_order_, + ) + + @classmethod + def _missing_(cls, value): + """ + Create a composite member containing all canonical members present in `value`. + + If non-member values are present, result depends on `_boundary_` setting. + """ + if not isinstance(value, int): + raise ValueError( + "%r is not a valid %s" % (value, cls.__qualname__) + ) + # check boundaries + # - value must be in range (e.g. -16 <-> +15, i.e. ~15 <-> 15) + # - value must not include any skipped flags (e.g. if bit 2 is not + # defined, then 0d10 is invalid) + flag_mask = cls._flag_mask_ + singles_mask = cls._singles_mask_ + all_bits = cls._all_bits_ + neg_value = None + if ( + not ~all_bits <= value <= all_bits + or value & (all_bits ^ flag_mask) + ): + if cls._boundary_ is STRICT: + max_bits = max(value.bit_length(), flag_mask.bit_length()) + raise ValueError( + "%r invalid value %r\n given %s\n allowed %s" % ( + cls, value, bin(value, max_bits), bin(flag_mask, max_bits), + )) + elif cls._boundary_ is CONFORM: + value = value & flag_mask + elif cls._boundary_ is EJECT: + return value + elif cls._boundary_ is KEEP: + if value < 0: + value = ( + max(all_bits+1, 2**(value.bit_length())) + + value + ) + else: + raise ValueError( + '%r unknown flag boundary %r' % (cls, cls._boundary_, ) + ) + if value < 0: + neg_value = value + value = all_bits + 1 + value + # get members and unknown + unknown = value & ~flag_mask + aliases = value & ~singles_mask + member_value = value & singles_mask + if unknown and cls._boundary_ is not KEEP: + raise ValueError( + '%s(%r) --> unknown values %r [%s]' + % (cls.__name__, value, unknown, bin(unknown)) + ) + # normal Flag? + if cls._member_type_ is object: # construct a singleton enum pseudo-member pseudo_member = object.__new__(cls) - pseudo_member._name_ = None + else: + pseudo_member = cls._member_type_.__new__(cls, value) + if not hasattr(pseudo_member, '_value_'): pseudo_member._value_ = value - # use setdefault in case another thread already created a composite - # with this value - pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) + if member_value or aliases: + members = [] + combined_value = 0 + for m in cls._iter_member_(member_value): + members.append(m) + combined_value |= m._value_ + if aliases: + value = member_value | aliases + for n, pm in cls._member_map_.items(): + if pm not in members and pm._value_ and pm._value_ & value == pm._value_: + members.append(pm) + combined_value |= pm._value_ + unknown = value ^ combined_value + pseudo_member._name_ = '|'.join([m._name_ for m in members]) + if not combined_value: + pseudo_member._name_ = None + elif unknown and cls._boundary_ is STRICT: + raise ValueError('%r: no members with value %r' % (cls, unknown)) + elif unknown: + pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown) + else: + pseudo_member._name_ = None + # use setdefault in case another thread already created a composite + # with this value + # note: zero is a special case -- always add it + pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) + if neg_value is not None: + cls._value2member_map_[neg_value] = pseudo_member return pseudo_member def __contains__(self, other): @@ -821,133 +1465,85 @@ def __contains__(self, other): """ if not isinstance(other, self.__class__): raise TypeError( - "unsupported operand type(s) for 'in': '%s' and '%s'" % ( + "unsupported operand type(s) for 'in': %r and %r" % ( type(other).__qualname__, self.__class__.__qualname__)) return other._value_ & self._value_ == other._value_ + def __iter__(self): + """ + Returns flags in definition order. + """ + yield from self._iter_member_(self._value_) + + def __len__(self): + return self._value_.bit_count() + def __repr__(self): - cls = self.__class__ - if self._name_ is not None: - return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) - members, uncovered = _decompose(cls, self._value_) - return '<%s.%s: %r>' % ( - cls.__name__, - '|'.join([str(m._name_ or m._value_) for m in members]), - self._value_, - ) + cls_name = self.__class__.__name__ + v_repr = self.__class__._value_repr_ or repr + if self._name_ is None: + return "<%s: %s>" % (cls_name, v_repr(self._value_)) + else: + return "<%s.%s: %s>" % (cls_name, self._name_, v_repr(self._value_)) def __str__(self): - cls = self.__class__ - if self._name_ is not None: - return '%s.%s' % (cls.__name__, self._name_) - members, uncovered = _decompose(cls, self._value_) - if len(members) == 1 and members[0]._name_ is None: - return '%s.%r' % (cls.__name__, members[0]._value_) + cls_name = self.__class__.__name__ + if self._name_ is None: + return '%s(%r)' % (cls_name, self._value_) else: - return '%s.%s' % ( - cls.__name__, - '|'.join([str(m._name_ or m._value_) for m in members]), - ) + return "%s.%s" % (cls_name, self._name_) def __bool__(self): return bool(self._value_) def __or__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ | other._value_) + value = self._value_ + return self.__class__(value | other) def __and__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ & other._value_) + value = self._value_ + return self.__class__(value & other) def __xor__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ ^ other._value_) + value = self._value_ + return self.__class__(value ^ other) def __invert__(self): - members, uncovered = _decompose(self.__class__, self._value_) - inverted = self.__class__(0) - for m in self.__class__: - if m not in members and not (m._value_ & self._value_): - inverted = inverted | m - return self.__class__(inverted) + if self._inverted_ is None: + if self._boundary_ in (EJECT, KEEP): + self._inverted_ = self.__class__(~self._value_) + else: + self._inverted_ = self.__class__(self._singles_mask_ & ~self._value_) + return self._inverted_ + + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ -class IntFlag(int, Flag): +class IntFlag(int, ReprEnum, Flag, boundary=KEEP): """ Support for integer-based Flags """ - @classmethod - def _missing_(cls, value): - """ - Returns member (possibly creating it) if one can be found for value. - """ - if not isinstance(value, int): - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) - new_member = cls._create_pseudo_member_(value) - return new_member - - @classmethod - def _create_pseudo_member_(cls, value): - """ - Create a composite member iff value contains only members. - """ - pseudo_member = cls._value2member_map_.get(value, None) - if pseudo_member is None: - need_to_create = [value] - # get unaccounted for bits - _, extra_flags = _decompose(cls, value) - # timer = 10 - while extra_flags: - # timer -= 1 - bit = _high_bit(extra_flags) - flag_value = 2 ** bit - if (flag_value not in cls._value2member_map_ and - flag_value not in need_to_create - ): - need_to_create.append(flag_value) - if extra_flags == -flag_value: - extra_flags = 0 - else: - extra_flags ^= flag_value - for value in reversed(need_to_create): - # construct singleton pseudo-members - pseudo_member = int.__new__(cls, value) - pseudo_member._name_ = None - pseudo_member._value_ = value - # use setdefault in case another thread already created a composite - # with this value - pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) - return pseudo_member - - def __or__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - result = self.__class__(self._value_ | self.__class__(other)._value_) - return result - - def __and__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - return self.__class__(self._value_ & self.__class__(other)._value_) - - def __xor__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - return self.__class__(self._value_ ^ self.__class__(other)._value_) - - __ror__ = __or__ - __rand__ = __and__ - __rxor__ = __xor__ - - def __invert__(self): - result = self.__class__(~self._value_) - return result - def _high_bit(value): """ @@ -970,44 +1566,478 @@ def unique(enumeration): (enumeration, alias_details)) return enumeration -def _decompose(flag, value): - """ - Extract all members from the value. - """ - # _decompose is only called if the value is not named - not_covered = value - negative = value < 0 - # issue29167: wrap accesses to _value2member_map_ in a list to avoid race - # conditions between iterating over it and having more pseudo- - # members added to it - if negative: - # only check for named flags - flags_to_check = [ - (m, v) - for v, m in list(flag._value2member_map_.items()) - if m.name is not None - ] - else: - # check for named flags and powers-of-two flags - flags_to_check = [ - (m, v) - for v, m in list(flag._value2member_map_.items()) - if m.name is not None or _power_of_two(v) - ] - members = [] - for member, member_value in flags_to_check: - if member_value and member_value & value == member_value: - members.append(member) - not_covered &= ~member_value - if not members and value in flag._value2member_map_: - members.append(flag._value2member_map_[value]) - members.sort(key=lambda m: m._value_, reverse=True) - if len(members) > 1 and members[0].value == value: - # we have the breakdown, don't need the value member itself - members.pop(0) - return members, not_covered - def _power_of_two(value): if value < 1: return False return value == 2 ** _high_bit(value) + +def global_enum_repr(self): + """ + use module.enum_name instead of class.enum_name + + the module is the last module in case of a multi-module name + """ + module = self.__class__.__module__.split('.')[-1] + return '%s.%s' % (module, self._name_) + +def global_flag_repr(self): + """ + use module.flag_name instead of class.flag_name + + the module is the last module in case of a multi-module name + """ + module = self.__class__.__module__.split('.')[-1] + cls_name = self.__class__.__name__ + if self._name_ is None: + return "%s.%s(%r)" % (module, cls_name, self._value_) + if _is_single_bit(self): + return '%s.%s' % (module, self._name_) + if self._boundary_ is not FlagBoundary.KEEP: + return '|'.join(['%s.%s' % (module, name) for name in self.name.split('|')]) + else: + name = [] + for n in self._name_.split('|'): + if n[0].isdigit(): + name.append(n) + else: + name.append('%s.%s' % (module, n)) + return '|'.join(name) + +def global_str(self): + """ + use enum_name instead of class.enum_name + """ + if self._name_ is None: + cls_name = self.__class__.__name__ + return "%s(%r)" % (cls_name, self._value_) + else: + return self._name_ + +def global_enum(cls, update_str=False): + """ + decorator that makes the repr() of an enum member reference its module + instead of its class; also exports all members to the enum's module's + global namespace + """ + if issubclass(cls, Flag): + cls.__repr__ = global_flag_repr + else: + cls.__repr__ = global_enum_repr + if not issubclass(cls, ReprEnum) or update_str: + cls.__str__ = global_str + sys.modules[cls.__module__].__dict__.update(cls.__members__) + return cls + +def _simple_enum(etype=Enum, *, boundary=None, use_args=None): + """ + Class decorator that converts a normal class into an :class:`Enum`. No + safety checks are done, and some advanced behavior (such as + :func:`__init_subclass__`) is not available. Enum creation can be faster + using :func:`simple_enum`. + + >>> from enum import Enum, _simple_enum + >>> @_simple_enum(Enum) + ... class Color: + ... RED = auto() + ... GREEN = auto() + ... BLUE = auto() + >>> Color + + """ + def convert_class(cls): + nonlocal use_args + cls_name = cls.__name__ + if use_args is None: + use_args = etype._use_args_ + __new__ = cls.__dict__.get('__new__') + if __new__ is not None: + new_member = __new__.__func__ + else: + new_member = etype._member_type_.__new__ + attrs = {} + body = {} + if __new__ is not None: + body['__new_member__'] = new_member + body['_new_member_'] = new_member + body['_use_args_'] = use_args + body['_generate_next_value_'] = gnv = etype._generate_next_value_ + body['_member_names_'] = member_names = [] + body['_member_map_'] = member_map = {} + body['_value2member_map_'] = value2member_map = {} + body['_unhashable_values_'] = [] + body['_member_type_'] = member_type = etype._member_type_ + body['_value_repr_'] = etype._value_repr_ + if issubclass(etype, Flag): + body['_boundary_'] = boundary or etype._boundary_ + body['_flag_mask_'] = None + body['_all_bits_'] = None + body['_singles_mask_'] = None + body['_inverted_'] = None + body['__or__'] = Flag.__or__ + body['__xor__'] = Flag.__xor__ + body['__and__'] = Flag.__and__ + body['__ror__'] = Flag.__ror__ + body['__rxor__'] = Flag.__rxor__ + body['__rand__'] = Flag.__rand__ + body['__invert__'] = Flag.__invert__ + for name, obj in cls.__dict__.items(): + if name in ('__dict__', '__weakref__'): + continue + if _is_dunder(name) or _is_private(cls_name, name) or _is_sunder(name) or _is_descriptor(obj): + body[name] = obj + else: + attrs[name] = obj + if cls.__dict__.get('__doc__') is None: + body['__doc__'] = 'An enumeration.' + # + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + # however, if the method is defined in the Enum itself, don't replace + # it + enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + if name not in body: + # check for mixin overrides before replacing + enum_method = getattr(etype, name) + found_method = getattr(enum_class, name) + object_method = getattr(object, name) + data_type_method = getattr(member_type, name) + if found_method in (data_type_method, object_method): + setattr(enum_class, name, enum_method) + gnv_last_values = [] + if issubclass(enum_class, Flag): + # Flag / IntFlag + single_bits = multi_bits = 0 + for name, value in attrs.items(): + if isinstance(value, auto) and auto.value is _auto_null: + value = gnv(name, 1, len(member_names), gnv_last_values) + if value in value2member_map: + # an alias to an existing member + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = value2member_map[value] + else: + # create the member + if use_args: + if not isinstance(value, tuple): + value = (value, ) + member = new_member(enum_class, *value) + value = value[0] + else: + member = new_member(enum_class) + if __new__ is None: + member._value_ = value + member._name_ = name + member.__objclass__ = enum_class + member.__init__(value) + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = member + member._sort_order_ = len(member_names) + value2member_map[value] = member + if _is_single_bit(value): + # not a multi-bit alias, record in _member_names_ and _flag_mask_ + member_names.append(name) + single_bits |= value + else: + multi_bits |= value + gnv_last_values.append(value) + enum_class._flag_mask_ = single_bits | multi_bits + enum_class._singles_mask_ = single_bits + enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1 + # set correct __iter__ + member_list = [m._value_ for m in enum_class] + if member_list != sorted(member_list): + enum_class._iter_member_ = enum_class._iter_member_by_def_ + else: + # Enum / IntEnum / StrEnum + for name, value in attrs.items(): + if isinstance(value, auto): + if value.value is _auto_null: + value.value = gnv(name, 1, len(member_names), gnv_last_values) + value = value.value + if value in value2member_map: + # an alias to an existing member + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = value2member_map[value] + else: + # create the member + if use_args: + if not isinstance(value, tuple): + value = (value, ) + member = new_member(enum_class, *value) + value = value[0] + else: + member = new_member(enum_class) + if __new__ is None: + member._value_ = value + member._name_ = name + member.__objclass__ = enum_class + member.__init__(value) + member._sort_order_ = len(member_names) + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = member + value2member_map[value] = member + member_names.append(name) + gnv_last_values.append(value) + if '__new__' in body: + enum_class.__new_member__ = enum_class.__new__ + enum_class.__new__ = Enum.__new__ + return enum_class + return convert_class + +@_simple_enum(StrEnum) +class EnumCheck: + """ + various conditions to check an enumeration for + """ + CONTINUOUS = "no skipped integer values" + NAMED_FLAGS = "multi-flag aliases may not contain unnamed flags" + UNIQUE = "one name per value" +CONTINUOUS, NAMED_FLAGS, UNIQUE = EnumCheck + + +class verify: + """ + Check an enumeration for various constraints. (see EnumCheck) + """ + def __init__(self, *checks): + self.checks = checks + def __call__(self, enumeration): + checks = self.checks + cls_name = enumeration.__name__ + if Flag is not None and issubclass(enumeration, Flag): + enum_type = 'flag' + elif issubclass(enumeration, Enum): + enum_type = 'enum' + else: + raise TypeError("the 'verify' decorator only works with Enum and Flag") + for check in checks: + if check is UNIQUE: + # check for duplicate names + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + alias_details = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates]) + raise ValueError('aliases found in %r: %s' % + (enumeration, alias_details)) + elif check is CONTINUOUS: + values = set(e.value for e in enumeration) + if len(values) < 2: + continue + low, high = min(values), max(values) + missing = [] + if enum_type == 'flag': + # check for powers of two + for i in range(_high_bit(low)+1, _high_bit(high)): + if 2**i not in values: + missing.append(2**i) + elif enum_type == 'enum': + # check for powers of one + for i in range(low+1, high): + if i not in values: + missing.append(i) + else: + raise Exception('verify: unknown type %r' % enum_type) + if missing: + raise ValueError(('invalid %s %r: missing values %s' % ( + enum_type, cls_name, ', '.join((str(m) for m in missing))) + )[:256]) + # limit max length to protect against DOS attacks + elif check is NAMED_FLAGS: + # examine each alias and check for unnamed flags + member_names = enumeration._member_names_ + member_values = [m.value for m in enumeration] + missing_names = [] + missing_value = 0 + for name, alias in enumeration._member_map_.items(): + if name in member_names: + # not an alias + continue + if alias.value < 0: + # negative numbers are not checked + continue + values = list(_iter_bits_lsb(alias.value)) + missed = [v for v in values if v not in member_values] + if missed: + missing_names.append(name) + missing_value |= reduce(_or_, missed) + if missing_names: + if len(missing_names) == 1: + alias = 'alias %s is missing' % missing_names[0] + else: + alias = 'aliases %s and %s are missing' % ( + ', '.join(missing_names[:-1]), missing_names[-1] + ) + if _is_single_bit(missing_value): + value = 'value 0x%x' % missing_value + else: + value = 'combined values of 0x%x' % missing_value + raise ValueError( + 'invalid Flag %r: %s %s [use enum.show_flag_values(value) for details]' + % (cls_name, alias, value) + ) + return enumeration + +def _test_simple_enum(checked_enum, simple_enum): + """ + A function that can be used to test an enum created with :func:`_simple_enum` + against the version created by subclassing :class:`Enum`:: + + >>> from enum import Enum, _simple_enum, _test_simple_enum + >>> @_simple_enum(Enum) + ... class Color: + ... RED = auto() + ... GREEN = auto() + ... BLUE = auto() + >>> class CheckedColor(Enum): + ... RED = auto() + ... GREEN = auto() + ... BLUE = auto() + >>> # TODO: RUSTPYTHON + >>> # _test_simple_enum(CheckedColor, Color) + + If differences are found, a :exc:`TypeError` is raised. + """ + failed = [] + if checked_enum.__dict__ != simple_enum.__dict__: + checked_dict = checked_enum.__dict__ + checked_keys = list(checked_dict.keys()) + simple_dict = simple_enum.__dict__ + simple_keys = list(simple_dict.keys()) + member_names = set( + list(checked_enum._member_map_.keys()) + + list(simple_enum._member_map_.keys()) + ) + for key in set(checked_keys + simple_keys): + if key in ('__module__', '_member_map_', '_value2member_map_', '__doc__'): + # keys known to be different, or very long + continue + elif key in member_names: + # members are checked below + continue + elif key not in simple_keys: + failed.append("missing key: %r" % (key, )) + elif key not in checked_keys: + failed.append("extra key: %r" % (key, )) + else: + checked_value = checked_dict[key] + simple_value = simple_dict[key] + if callable(checked_value) or isinstance(checked_value, bltns.property): + continue + if key == '__doc__': + # remove all spaces/tabs + compressed_checked_value = checked_value.replace(' ','').replace('\t','') + compressed_simple_value = simple_value.replace(' ','').replace('\t','') + if compressed_checked_value != compressed_simple_value: + failed.append("%r:\n %s\n %s" % ( + key, + "checked -> %r" % (checked_value, ), + "simple -> %r" % (simple_value, ), + )) + elif checked_value != simple_value: + failed.append("%r:\n %s\n %s" % ( + key, + "checked -> %r" % (checked_value, ), + "simple -> %r" % (simple_value, ), + )) + failed.sort() + for name in member_names: + failed_member = [] + if name not in simple_keys: + failed.append('missing member from simple enum: %r' % name) + elif name not in checked_keys: + failed.append('extra member in simple enum: %r' % name) + else: + checked_member_dict = checked_enum[name].__dict__ + checked_member_keys = list(checked_member_dict.keys()) + simple_member_dict = simple_enum[name].__dict__ + simple_member_keys = list(simple_member_dict.keys()) + for key in set(checked_member_keys + simple_member_keys): + if key in ('__module__', '__objclass__', '_inverted_'): + # keys known to be different or absent + continue + elif key not in simple_member_keys: + failed_member.append("missing key %r not in the simple enum member %r" % (key, name)) + elif key not in checked_member_keys: + failed_member.append("extra key %r in simple enum member %r" % (key, name)) + else: + checked_value = checked_member_dict[key] + simple_value = simple_member_dict[key] + if checked_value != simple_value: + failed_member.append("%r:\n %s\n %s" % ( + key, + "checked member -> %r" % (checked_value, ), + "simple member -> %r" % (simple_value, ), + )) + if failed_member: + failed.append('%r member mismatch:\n %s' % ( + name, '\n '.join(failed_member), + )) + for method in ( + '__str__', '__repr__', '__reduce_ex__', '__format__', + '__getnewargs_ex__', '__getnewargs__', '__reduce_ex__', '__reduce__' + ): + if method in simple_keys and method in checked_keys: + # cannot compare functions, and it exists in both, so we're good + continue + elif method not in simple_keys and method not in checked_keys: + # method is inherited -- check it out + checked_method = getattr(checked_enum, method, None) + simple_method = getattr(simple_enum, method, None) + if hasattr(checked_method, '__func__'): + checked_method = checked_method.__func__ + simple_method = simple_method.__func__ + if checked_method != simple_method: + failed.append("%r: %-30s %s" % ( + method, + "checked -> %r" % (checked_method, ), + "simple -> %r" % (simple_method, ), + )) + else: + # if the method existed in only one of the enums, it will have been caught + # in the first checks above + pass + if failed: + raise TypeError('enum mismatch:\n %s' % '\n '.join(failed)) + +def _old_convert_(etype, name, module, filter, source=None, *, boundary=None): + """ + Create a new Enum subclass that replaces a collection of global constants + """ + # convert all constants from source (or module) that pass filter() to + # a new Enum called name, and export the enum and its members back to + # module; + # also, replace the __reduce_ex__ method so unpickling works in + # previous Python versions + module_globals = sys.modules[module].__dict__ + if source: + source = source.__dict__ + else: + source = module_globals + # _value2member_map_ is populated in the same order every time + # for a consistent reverse mapping of number to name when there + # are multiple names for the same number. + members = [ + (name, value) + for name, value in source.items() + if filter(name)] + try: + # sort by value + members.sort(key=lambda t: (t[1], t[0])) + except TypeError: + # unless some values aren't comparable, in which case sort by name + members.sort(key=lambda t: t[0]) + cls = etype(name, members, module=module, boundary=boundary or KEEP) + return cls + +_stdlib_enums = IntEnum, StrEnum, IntFlag diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 1cccd27dee..be242e93f7 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -1,16 +1,46 @@ +import copy import enum +import doctest import inspect +import os import pydoc import sys import unittest import threading +import typing +import builtins as bltns from collections import OrderedDict -from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto +from datetime import date +from enum import Enum, EnumMeta, IntEnum, StrEnum, EnumType, Flag, IntFlag, unique, auto +from enum import STRICT, CONFORM, EJECT, KEEP, _simple_enum, _test_simple_enum +from enum import verify, UNIQUE, CONTINUOUS, NAMED_FLAGS, ReprEnum +from enum import member, nonmember, _iter_bits_lsb from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL -from test.support import ALWAYS_EQ, check__all__, threading_helper +from test import support +from test.support import ALWAYS_EQ +from test.support import threading_helper +from textwrap import dedent from datetime import timedelta +python_version = sys.version_info[:2] + +def load_tests(loader, tests, ignore): + tests.addTests(doctest.DocTestSuite(enum)) + if os.path.exists('Doc/library/enum.rst'): + tests.addTests(doctest.DocFileSuite( + '../../Doc/library/enum.rst', + optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE, + )) + if os.path.exists('Doc/howto/enum.rst'): + tests.addTests(doctest.DocFileSuite( + '../../Doc/howto/enum.rst', + optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE, + )) + return tests + +MODULE = __name__ +SHORT_MODULE = MODULE.split('.')[-1] # for pickle tests try: @@ -41,19 +71,35 @@ class FloatStooges(float, Enum): class FlagStooges(Flag): LARRY = 1 CURLY = 2 - MOE = 3 + MOE = 4 + BIG = 389 except Exception as exc: FlagStooges = exc +class FlagStoogesWithZero(Flag): + NOFLAG = 0 + LARRY = 1 + CURLY = 2 + MOE = 4 + BIG = 389 + +class IntFlagStooges(IntFlag): + LARRY = 1 + CURLY = 2 + MOE = 4 + BIG = 389 + +class IntFlagStoogesWithZero(IntFlag): + NOFLAG = 0 + LARRY = 1 + CURLY = 2 + MOE = 4 + BIG = 389 + # for pickle test and subclass tests -try: - class StrEnum(str, Enum): - 'accepts only string values' - class Name(StrEnum): - BDFL = 'Guido van Rossum' - FLUFL = 'Barry Warsaw' -except Exception as exc: - Name = exc +class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' try: Question = Enum('Question', 'who what when where why', module=__name__) @@ -93,6 +139,12 @@ def test_pickle_exception(assertion, exception, obj): class TestHelpers(unittest.TestCase): # _is_descriptor, _is_sunder, _is_dunder + sunder_names = '_bad_', '_good_', '_what_ho_' + dunder_names = '__mal__', '__bien__', '__que_que__' + private_names = '_MyEnum__private', '_MyEnum__still_private' + private_and_sunder_names = '_MyEnum__private_', '_MyEnum__also_private_' + random_names = 'okay', '_semi_private', '_weird__', '_MyEnum__' + def test_is_descriptor(self): class foo: pass @@ -102,21 +154,40 @@ class foo: setattr(obj, attr, 1) self.assertTrue(enum._is_descriptor(obj)) - def test_is_sunder(self): + def test_sunder(self): + for name in self.sunder_names + self.private_and_sunder_names: + self.assertTrue(enum._is_sunder(name), '%r is a not sunder name?' % name) + for name in self.dunder_names + self.private_names + self.random_names: + self.assertFalse(enum._is_sunder(name), '%r is a sunder name?' % name) for s in ('_a_', '_aa_'): self.assertTrue(enum._is_sunder(s)) - for s in ('a', 'a_', '_a', '__a', 'a__', '__a__', '_a__', '__a_', '_', '__', '___', '____', '_____',): self.assertFalse(enum._is_sunder(s)) - def test_is_dunder(self): + def test_dunder(self): + for name in self.dunder_names: + self.assertTrue(enum._is_dunder(name), '%r is a not dunder name?' % name) + for name in self.sunder_names + self.private_names + self.private_and_sunder_names + self.random_names: + self.assertFalse(enum._is_dunder(name), '%r is a dunder name?' % name) for s in ('__a__', '__aa__'): self.assertTrue(enum._is_dunder(s)) for s in ('a', 'a_', '_a', '__a', 'a__', '_a_', '_a__', '__a_', '_', '__', '___', '____', '_____',): self.assertFalse(enum._is_dunder(s)) + + def test_is_private(self): + for name in self.private_names + self.private_and_sunder_names: + self.assertTrue(enum._is_private('MyEnum', name), '%r is a not private name?') + for name in self.sunder_names + self.dunder_names + self.random_names: + self.assertFalse(enum._is_private('MyEnum', name), '%r is a private name?') + + def test_iter_bits_lsb(self): + self.assertEqual(list(_iter_bits_lsb(7)), [1, 2, 4]) + self.assertRaisesRegex(ValueError, '-8 is not a positive integer', list, _iter_bits_lsb(-8)) + + # for subclassing tests class classproperty: @@ -132,199 +203,794 @@ def __init__(self, fget=None, fset=None, fdel=None, doc=None): def __get__(self, instance, ownerclass): return self.fget(ownerclass) +# for global repr tests + +@enum.global_enum +class HeadlightsK(IntFlag, boundary=enum.KEEP): + OFF_K = 0 + LOW_BEAM_K = auto() + HIGH_BEAM_K = auto() + FOG_K = auto() + + +@enum.global_enum +class HeadlightsC(IntFlag, boundary=enum.CONFORM): + OFF_C = 0 + LOW_BEAM_C = auto() + HIGH_BEAM_C = auto() + FOG_C = auto() + + +@enum.global_enum +class NoName(Flag): + ONE = 1 + TWO = 2 + # tests -class TestEnum(unittest.TestCase): +class _EnumTests: + """ + Test for behavior that is the same across the different types of enumerations. + """ - def setUp(self): - class Season(Enum): - SPRING = 1 - SUMMER = 2 - AUTUMN = 3 - WINTER = 4 - self.Season = Season + values = None - class Konstants(float, Enum): - E = 2.7182818 - PI = 3.1415926 - TAU = 2 * PI - self.Konstants = Konstants + def setUp(self): + class BaseEnum(self.enum_type): + @enum.property + def first(self): + return '%s is first!' % self.name + class MainEnum(BaseEnum): + first = auto() + second = auto() + third = auto() + if issubclass(self.enum_type, Flag): + dupe = 3 + else: + dupe = third + self.MainEnum = MainEnum + # + class NewStrEnum(self.enum_type): + def __str__(self): + return self.name.upper() + first = auto() + self.NewStrEnum = NewStrEnum + # + class NewFormatEnum(self.enum_type): + def __format__(self, spec): + return self.name.upper() + first = auto() + self.NewFormatEnum = NewFormatEnum + # + class NewStrFormatEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + first = auto() + self.NewStrFormatEnum = NewStrFormatEnum + # + class NewBaseEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + class NewSubEnum(NewBaseEnum): + first = auto() + self.NewSubEnum = NewSubEnum + # + self.is_flag = False + self.names = ['first', 'second', 'third'] + if issubclass(MainEnum, StrEnum): + self.values = self.names + elif MainEnum._member_type_ is str: + self.values = ['1', '2', '3'] + elif issubclass(self.enum_type, Flag): + self.values = [1, 2, 4] + self.is_flag = True + self.dupe2 = MainEnum(5) + else: + self.values = self.values or [1, 2, 3] + # + if not getattr(self, 'source_values', False): + self.source_values = self.values - class Grades(IntEnum): - A = 5 - B = 4 - C = 3 - D = 2 - F = 0 - self.Grades = Grades + def assertFormatIsValue(self, spec, member): + self.assertEqual(spec.format(member), spec.format(member.value)) - class Directional(str, Enum): - EAST = 'east' - WEST = 'west' - NORTH = 'north' - SOUTH = 'south' - self.Directional = Directional + def assertFormatIsStr(self, spec, member): + self.assertEqual(spec.format(member), spec.format(str(member))) - from datetime import date - class Holiday(date, Enum): - NEW_YEAR = 2013, 1, 1 - IDES_OF_MARCH = 2013, 3, 15 - self.Holiday = Holiday + def test_attribute_deletion(self): + class Season(self.enum_type): + SPRING = auto() + SUMMER = auto() + AUTUMN = auto() + # + def spam(cls): + pass + # + self.assertTrue(hasattr(Season, 'spam')) + del Season.spam + self.assertFalse(hasattr(Season, 'spam')) + # + with self.assertRaises(AttributeError): + del Season.SPRING + with self.assertRaises(AttributeError): + del Season.DRY + with self.assertRaises(AttributeError): + del Season.SPRING.name - def test_dir_on_class(self): - Season = self.Season + def test_basics(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(repr(TE), "") + self.assertEqual(str(TE), "") + self.assertEqual(format(TE), "") + self.assertTrue(TE(5) is self.dupe2) + else: + self.assertEqual(repr(TE), "") + self.assertEqual(str(TE), "") + self.assertEqual(format(TE), "") + self.assertEqual(list(TE), [TE.first, TE.second, TE.third]) + self.assertEqual( + [m.name for m in TE], + self.names, + ) self.assertEqual( - set(dir(Season)), - set(['__class__', '__doc__', '__members__', '__module__', - 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']), + [m.value for m in TE], + self.values, + ) + self.assertEqual( + [m.first for m in TE], + ['first is first!', 'second is first!', 'third is first!'] + ) + for member, name in zip(TE, self.names, strict=True): + self.assertIs(TE[name], member) + for member, value in zip(TE, self.values, strict=True): + self.assertIs(TE(value), member) + if issubclass(TE, StrEnum): + self.assertTrue(TE.dupe is TE('third') is TE['dupe']) + elif TE._member_type_ is str: + self.assertTrue(TE.dupe is TE('3') is TE['dupe']) + elif issubclass(TE, Flag): + self.assertTrue(TE.dupe is TE(3) is TE['dupe']) + else: + self.assertTrue(TE.dupe is TE(self.values[2]) is TE['dupe']) + + def test_bool_is_true(self): + class Empty(self.enum_type): + pass + self.assertTrue(Empty) + # + self.assertTrue(self.MainEnum) + for member in self.MainEnum: + self.assertTrue(member) + + def test_changing_member_fails(self): + MainEnum = self.MainEnum + with self.assertRaises(AttributeError): + self.MainEnum.second = 'really first' + + @unittest.skipIf( + python_version >= (3, 12), + '__contains__ now returns True/False for all inputs', ) + def test_contains_er(self): + MainEnum = self.MainEnum + self.assertIn(MainEnum.third, MainEnum) + with self.assertRaises(TypeError): + with self.assertWarns(DeprecationWarning): + self.source_values[1] in MainEnum + with self.assertRaises(TypeError): + with self.assertWarns(DeprecationWarning): + 'first' in MainEnum + val = MainEnum.dupe + self.assertIn(val, MainEnum) + # + class OtherEnum(Enum): + one = auto() + two = auto() + self.assertNotIn(OtherEnum.two, MainEnum) - def test_dir_on_item(self): - Season = self.Season - self.assertEqual( - set(dir(Season.WINTER)), - set(['__class__', '__doc__', '__module__', 'name', 'value']), + @unittest.skipIf( + python_version < (3, 12), + '__contains__ works only with enum memmbers before 3.12', ) + def test_contains_tf(self): + MainEnum = self.MainEnum + self.assertIn(MainEnum.first, MainEnum) + self.assertTrue(self.source_values[0] in MainEnum) + self.assertFalse('first' in MainEnum) + val = MainEnum.dupe + self.assertIn(val, MainEnum) + # + class OtherEnum(Enum): + one = auto() + two = auto() + self.assertNotIn(OtherEnum.two, MainEnum) + + def test_dir_on_class(self): + TE = self.MainEnum + self.assertEqual(set(dir(TE)), set(enum_dir(TE))) + + def test_dir_on_item(self): + TE = self.MainEnum + self.assertEqual(set(dir(TE.first)), set(member_dir(TE.first))) def test_dir_with_added_behavior(self): - class Test(Enum): - this = 'that' - these = 'those' + class Test(self.enum_type): + this = auto() + these = auto() def wowser(self): return ("Wowser! I'm %s!" % self.name) - self.assertEqual( - set(dir(Test)), - set(['__class__', '__doc__', '__members__', '__module__', 'this', 'these']), - ) - self.assertEqual( - set(dir(Test.this)), - set(['__class__', '__doc__', '__module__', 'name', 'value', 'wowser']), - ) + self.assertTrue('wowser' not in dir(Test)) + self.assertTrue('wowser' in dir(Test.this)) def test_dir_on_sub_with_behavior_on_super(self): # see issue22506 - class SuperEnum(Enum): + class SuperEnum(self.enum_type): def invisible(self): return "did you see me?" class SubEnum(SuperEnum): - sample = 5 - self.assertEqual( - set(dir(SubEnum.sample)), - set(['__class__', '__doc__', '__module__', 'name', 'value', 'invisible']), - ) + sample = auto() + self.assertTrue('invisible' not in dir(SubEnum)) + self.assertTrue('invisible' in dir(SubEnum.sample)) def test_dir_on_sub_with_behavior_including_instance_dict_on_super(self): # see issue40084 - class SuperEnum(IntEnum): - def __new__(cls, value, description=""): - obj = int.__new__(cls, value) - obj._value_ = value - obj.description = description + class SuperEnum(self.enum_type): + def __new__(cls, *value, **kwds): + new = self.enum_type._member_type_.__new__ + if self.enum_type._member_type_ is object: + obj = new(cls) + else: + if isinstance(value[0], tuple): + create_value ,= value[0] + else: + create_value = value + obj = new(cls, *create_value) + obj._value_ = value[0] if len(value) == 1 else value + obj.description = 'test description' return obj class SubEnum(SuperEnum): - sample = 5 - self.assertTrue({'description'} <= set(dir(SubEnum.sample))) + sample = self.source_values[1] + self.assertTrue('description' not in dir(SubEnum)) + self.assertTrue('description' in dir(SubEnum.sample), dir(SubEnum.sample)) def test_enum_in_enum_out(self): - Season = self.Season - self.assertIs(Season(Season.WINTER), Season.WINTER) - - def test_enum_value(self): - Season = self.Season - self.assertEqual(Season.SPRING.value, 1) - - def test_intenum_value(self): - self.assertEqual(IntStooges.CURLY.value, 2) - - def test_enum(self): - Season = self.Season - lst = list(Season) - self.assertEqual(len(lst), len(Season)) - self.assertEqual(len(Season), 4, Season) - self.assertEqual( - [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst) - - for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split(), 1): - e = Season(i) - self.assertEqual(e, getattr(Season, season)) - self.assertEqual(e.value, i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, season) - self.assertIn(e, Season) - self.assertIs(type(e), Season) - self.assertIsInstance(e, Season) - self.assertEqual(str(e), 'Season.' + season) - self.assertEqual( - repr(e), - ''.format(season, i), - ) - - def test_value_name(self): - Season = self.Season - self.assertEqual(Season.SPRING.name, 'SPRING') - self.assertEqual(Season.SPRING.value, 1) - with self.assertRaises(AttributeError): - Season.SPRING.name = 'invierno' - with self.assertRaises(AttributeError): - Season.SPRING.value = 2 - - def test_changing_member(self): - Season = self.Season - with self.assertRaises(AttributeError): - Season.WINTER = 'really cold' - - def test_attribute_deletion(self): - class Season(Enum): - SPRING = 1 - SUMMER = 2 - AUTUMN = 3 - WINTER = 4 + Main = self.MainEnum + self.assertIs(Main(Main.first), Main.first) - def spam(cls): - pass - - self.assertTrue(hasattr(Season, 'spam')) - del Season.spam - self.assertFalse(hasattr(Season, 'spam')) - - with self.assertRaises(AttributeError): - del Season.SPRING - with self.assertRaises(AttributeError): - del Season.DRY - with self.assertRaises(AttributeError): - del Season.SPRING.name - - def test_bool_of_class(self): - class Empty(Enum): - pass - self.assertTrue(bool(Empty)) - - def test_bool_of_member(self): - class Count(Enum): - zero = 0 - one = 1 - two = 2 - for member in Count: - self.assertTrue(bool(member)) + def test_hash(self): + MainEnum = self.MainEnum + mapping = {} + mapping[MainEnum.first] = '1225' + mapping[MainEnum.second] = '0315' + mapping[MainEnum.third] = '0704' + self.assertEqual(mapping[MainEnum.second], '0315') def test_invalid_names(self): with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): mro = 9 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _create_= 11 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _get_mixins_ = 9 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _find_new_ = 1 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _any_name_ = 9 + def test_object_str_override(self): + "check that setting __str__ to object's is not reset to Enum's" + class Generic(self.enum_type): + item = self.source_values[2] + def __repr__(self): + return "%s.test" % (self._name_, ) + __str__ = object.__str__ + self.assertEqual(str(Generic.item), 'item.test') + + def test_overridden_str(self): + # TODO: RUSTPYTHON, format(NS.first) does not use __str__ + if isinstance(self, TestIntFlag) or isinstance(self, TestIntEnum) or isinstance(self, TestMinimalFloat): + self.skipTest("format(NS.first) does not use __str__") + NS = self.NewStrEnum + self.assertEqual(str(NS.first), NS.first.name.upper()) + self.assertEqual(format(NS.first), NS.first.name.upper()) + + def test_overridden_str_format(self): + NSF = self.NewStrFormatEnum + self.assertEqual(str(NSF.first), NSF.first.name.title()) + self.assertEqual(format(NSF.first), ''.join(reversed(NSF.first.name))) + + def test_overridden_str_format_inherited(self): + NSE = self.NewSubEnum + self.assertEqual(str(NSE.first), NSE.first.name.title()) + self.assertEqual(format(NSE.first), ''.join(reversed(NSE.first.name))) + + def test_programmatic_function_string(self): + MinorEnum = self.enum_type('MinorEnum', 'june july august') + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + values = self.values + if self.enum_type is StrEnum: + values = ['june','july','august'] + for month, av in zip('june july august'.split(), values): + e = MinorEnum[month] + self.assertEqual(e.value, av, list(MinorEnum)) + self.assertEqual(e.name, month) + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_programmatic_function_string_list(self): + MinorEnum = self.enum_type('MinorEnum', ['june', 'july', 'august']) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + values = self.values + if self.enum_type is StrEnum: + values = ['june','july','august'] + for month, av in zip('june july august'.split(), values): + e = MinorEnum[month] + self.assertEqual(e.value, av) + self.assertEqual(e.name, month) + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_programmatic_function_iterable(self): + MinorEnum = self.enum_type( + 'MinorEnum', + (('june', self.source_values[0]), ('july', self.source_values[1]), ('august', self.source_values[2])) + ) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for month, av in zip('june july august'.split(), self.values): + e = MinorEnum[month] + self.assertEqual(e.value, av) + self.assertEqual(e.name, month) + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_programmatic_function_from_dict(self): + MinorEnum = self.enum_type( + 'MinorEnum', + OrderedDict((('june', self.source_values[0]), ('july', self.source_values[1]), ('august', self.source_values[2]))) + ) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for month, av in zip('june july august'.split(), self.values): + e = MinorEnum[month] + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_repr(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(repr(TE(0)), "") + self.assertEqual(repr(TE.dupe), "") + self.assertEqual(repr(self.dupe2), "") + elif issubclass(TE, StrEnum): + self.assertEqual(repr(TE.dupe), "") + else: + self.assertEqual(repr(TE.dupe), "" % (self.values[2], ), TE._value_repr_) + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(repr(member), "" % (member.name, member.value)) + + def test_repr_override(self): + class Generic(self.enum_type): + first = auto() + second = auto() + third = auto() + def __repr__(self): + return "don't you just love shades of %s?" % self.name + self.assertEqual( + repr(Generic.third), + "don't you just love shades of third?", + ) + + def test_inherited_repr(self): + class MyEnum(self.enum_type): + def __repr__(self): + return "My name is %s." % self.name + class MySubEnum(MyEnum): + this = auto() + that = auto() + theother = auto() + self.assertEqual(repr(MySubEnum.that), "My name is that.") + + def test_multiple_superclasses_repr(self): + class _EnumSuperClass(metaclass=EnumMeta): + pass + class E(_EnumSuperClass, Enum): + A = 1 + self.assertEqual(repr(E.A), "") + + def test_reversed_iteration_order(self): + self.assertEqual( + list(reversed(self.MainEnum)), + [self.MainEnum.third, self.MainEnum.second, self.MainEnum.first], + ) + +class _PlainOutputTests: + + def test_str(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(str(TE(0)), "MainEnum(0)") + self.assertEqual(str(TE.dupe), "MainEnum.dupe") + self.assertEqual(str(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(str(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(str(member), "MainEnum.%s" % (member.name, )) + + def test_format(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(format(TE.dupe), "MainEnum.dupe") + self.assertEqual(format(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(format(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(format(member), "MainEnum.%s" % (member.name, )) + + def test_overridden_format(self): + NF = self.NewFormatEnum + self.assertEqual(str(NF.first), "NewFormatEnum.first", '%s %r' % (NF.__str__, NF.first)) + self.assertEqual(format(NF.first), "FIRST") + + def test_format_specs(self): + TE = self.MainEnum + self.assertFormatIsStr('{}', TE.second) + self.assertFormatIsStr('{:}', TE.second) + self.assertFormatIsStr('{:20}', TE.second) + self.assertFormatIsStr('{:^20}', TE.second) + self.assertFormatIsStr('{:>20}', TE.second) + self.assertFormatIsStr('{:<20}', TE.second) + self.assertFormatIsStr('{:5.2}', TE.second) + + +class _MixedOutputTests: + + def test_str(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(str(TE.dupe), "MainEnum.dupe") + self.assertEqual(str(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(str(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(str(member), "MainEnum.%s" % (member.name, )) + + def test_format(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(format(TE.dupe), "MainEnum.dupe") + self.assertEqual(format(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(format(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(format(member), "MainEnum.%s" % (member.name, )) + + def test_overridden_format(self): + NF = self.NewFormatEnum + self.assertEqual(str(NF.first), "NewFormatEnum.first") + self.assertEqual(format(NF.first), "FIRST") + + def test_format_specs(self): + TE = self.MainEnum + self.assertFormatIsStr('{}', TE.first) + self.assertFormatIsStr('{:}', TE.first) + self.assertFormatIsStr('{:20}', TE.first) + self.assertFormatIsStr('{:^20}', TE.first) + self.assertFormatIsStr('{:>20}', TE.first) + self.assertFormatIsStr('{:<20}', TE.first) + self.assertFormatIsStr('{:5.2}', TE.first) + + +class _MinimalOutputTests: + + def test_str(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(str(TE.dupe), "3") + self.assertEqual(str(self.dupe2), "5") + else: + self.assertEqual(str(TE.dupe), str(self.values[2])) + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(str(member), str(value)) + + def test_format(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(format(TE.dupe), "3") + self.assertEqual(format(self.dupe2), "5") + else: + self.assertEqual(format(TE.dupe), format(self.values[2])) + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(format(member), format(value)) + + def test_overridden_format(self): + NF = self.NewFormatEnum + self.assertEqual(str(NF.first), str(self.values[0])) + self.assertEqual(format(NF.first), "FIRST") + + def test_format_specs(self): + TE = self.MainEnum + self.assertFormatIsValue('{}', TE.third) + self.assertFormatIsValue('{:}', TE.third) + self.assertFormatIsValue('{:20}', TE.third) + self.assertFormatIsValue('{:^20}', TE.third) + self.assertFormatIsValue('{:>20}', TE.third) + self.assertFormatIsValue('{:<20}', TE.third) + if TE._member_type_ is float: + self.assertFormatIsValue('{:n}', TE.third) + self.assertFormatIsValue('{:5.2}', TE.third) + self.assertFormatIsValue('{:f}', TE.third) + + def test_copy(self): + TE = self.MainEnum + copied = copy.copy(TE) + self.assertEqual(copied, TE) + self.assertIs(copied, TE) + deep = copy.deepcopy(TE) + self.assertEqual(deep, TE) + self.assertIs(deep, TE) + + def test_copy_member(self): + TE = self.MainEnum + copied = copy.copy(TE.first) + self.assertIs(copied, TE.first) + deep = copy.deepcopy(TE.first) + self.assertIs(deep, TE.first) + +class _FlagTests: + + def test_default_missing_with_wrong_type_value(self): + with self.assertRaisesRegex( + ValueError, + "'RED' is not a valid ", + ) as ctx: + self.MainEnum('RED') + self.assertIs(ctx.exception.__context__, None) + + def test_closed_invert_expectations(self): + class ClosedAB(self.enum_type): + A = 1 + B = 2 + MASK = 3 + A, B = ClosedAB + AB_MASK = ClosedAB.MASK + # + self.assertIs(~A, B) + self.assertIs(~B, A) + self.assertIs(~(A|B), ClosedAB(0)) + self.assertIs(~AB_MASK, ClosedAB(0)) + self.assertIs(~ClosedAB(0), (A|B)) + # + class ClosedXYZ(self.enum_type): + X = 4 + Y = 2 + Z = 1 + MASK = 7 + X, Y, Z = ClosedXYZ + XYZ_MASK = ClosedXYZ.MASK + # + self.assertIs(~X, Y|Z) + self.assertIs(~Y, X|Z) + self.assertIs(~Z, X|Y) + self.assertIs(~(X|Y), Z) + self.assertIs(~(X|Z), Y) + self.assertIs(~(Y|Z), X) + self.assertIs(~(X|Y|Z), ClosedXYZ(0)) + self.assertIs(~XYZ_MASK, ClosedXYZ(0)) + self.assertIs(~ClosedXYZ(0), (X|Y|Z)) + + def test_open_invert_expectations(self): + class OpenAB(self.enum_type): + A = 1 + B = 2 + MASK = 255 + A, B = OpenAB + AB_MASK = OpenAB.MASK + # + if OpenAB._boundary_ in (EJECT, KEEP): + self.assertIs(~A, OpenAB(254)) + self.assertIs(~B, OpenAB(253)) + self.assertIs(~(A|B), OpenAB(252)) + self.assertIs(~AB_MASK, OpenAB(0)) + self.assertIs(~OpenAB(0), AB_MASK) + else: + self.assertIs(~A, B) + self.assertIs(~B, A) + self.assertIs(~(A|B), OpenAB(0)) + self.assertIs(~AB_MASK, OpenAB(0)) + self.assertIs(~OpenAB(0), (A|B)) + # + class OpenXYZ(self.enum_type): + X = 4 + Y = 2 + Z = 1 + MASK = 31 + X, Y, Z = OpenXYZ + XYZ_MASK = OpenXYZ.MASK + # + if OpenXYZ._boundary_ in (EJECT, KEEP): + self.assertIs(~X, OpenXYZ(27)) + self.assertIs(~Y, OpenXYZ(29)) + self.assertIs(~Z, OpenXYZ(30)) + self.assertIs(~(X|Y), OpenXYZ(25)) + self.assertIs(~(X|Z), OpenXYZ(26)) + self.assertIs(~(Y|Z), OpenXYZ(28)) + self.assertIs(~(X|Y|Z), OpenXYZ(24)) + self.assertIs(~XYZ_MASK, OpenXYZ(0)) + self.assertTrue(~OpenXYZ(0), XYZ_MASK) + else: + self.assertIs(~X, Y|Z) + self.assertIs(~Y, X|Z) + self.assertIs(~Z, X|Y) + self.assertIs(~(X|Y), Z) + self.assertIs(~(X|Z), Y) + self.assertIs(~(Y|Z), X) + self.assertIs(~(X|Y|Z), OpenXYZ(0)) + self.assertIs(~XYZ_MASK, OpenXYZ(0)) + self.assertTrue(~OpenXYZ(0), (X|Y|Z)) + + +class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase): + enum_type = Enum + + +class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): + enum_type = Flag + + +class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = IntEnum + + +class TestStrEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = StrEnum + + +class TestIntFlag(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): + enum_type = IntFlag + + +class TestMixedInt(_EnumTests, _MixedOutputTests, unittest.TestCase): + class enum_type(int, Enum): pass + + +class TestMixedStr(_EnumTests, _MixedOutputTests, unittest.TestCase): + class enum_type(str, Enum): pass + + +class TestMixedIntFlag(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): + class enum_type(int, Flag): pass + + +class TestMixedDate(_EnumTests, _MixedOutputTests, unittest.TestCase): + + values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)] + source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + + class enum_type(date, Enum): + def _generate_next_value_(name, start, count, last_values): + values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + return values[count] + + +class TestMinimalDate(_EnumTests, _MinimalOutputTests, unittest.TestCase): + + values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)] + source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + + class enum_type(date, ReprEnum): + def _generate_next_value_(name, start, count, last_values): + values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + return values[count] + + +class TestMixedFloat(_EnumTests, _MixedOutputTests, unittest.TestCase): + + values = [1.1, 2.2, 3.3] + + class enum_type(float, Enum): + def _generate_next_value_(name, start, count, last_values): + values = [1.1, 2.2, 3.3] + return values[count] + + +class TestMinimalFloat(_EnumTests, _MinimalOutputTests, unittest.TestCase): + + values = [4.4, 5.5, 6.6] + + class enum_type(float, ReprEnum): + def _generate_next_value_(name, start, count, last_values): + values = [4.4, 5.5, 6.6] + return values[count] + + +class TestSpecial(unittest.TestCase): + """ + various operations that are not attributable to every possible enum + """ + + def setUp(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + self.Season = Season + # + class Grades(IntEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 0 + self.Grades = Grades + # + class Directional(str, Enum): + EAST = 'east' + WEST = 'west' + NORTH = 'north' + SOUTH = 'south' + self.Directional = Directional + # + from datetime import date + class Holiday(date, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + self.Holiday = Holiday + def test_bool(self): # plain Enum members are always True class Logic(Enum): @@ -347,71 +1013,57 @@ class IntLogic(int, Enum): self.assertTrue(IntLogic.true) self.assertFalse(IntLogic.false) - def test_contains(self): - Season = self.Season - self.assertIn(Season.AUTUMN, Season) - with self.assertRaises(TypeError): - 3 in Season - with self.assertRaises(TypeError): - 'AUTUMN' in Season - - val = Season(3) - self.assertIn(val, Season) - - class OtherEnum(Enum): - one = 1; two = 2 - self.assertNotIn(OtherEnum.two, Season) - def test_comparisons(self): Season = self.Season with self.assertRaises(TypeError): Season.SPRING < Season.WINTER with self.assertRaises(TypeError): Season.SPRING > 4 - + # self.assertNotEqual(Season.SPRING, 1) - + # class Part(Enum): SPRING = 1 CLIP = 2 BARREL = 3 - + # self.assertNotEqual(Season.SPRING, Part.SPRING) with self.assertRaises(TypeError): Season.SPRING < Part.CLIP - def test_enum_duplicates(self): - class Season(Enum): - SPRING = 1 - SUMMER = 2 - AUTUMN = FALL = 3 - WINTER = 4 - ANOTHER_SPRING = 1 - lst = list(Season) - self.assertEqual( - lst, - [Season.SPRING, Season.SUMMER, - Season.AUTUMN, Season.WINTER, - ]) - self.assertIs(Season.FALL, Season.AUTUMN) - self.assertEqual(Season.FALL.value, 3) - self.assertEqual(Season.AUTUMN.value, 3) - self.assertIs(Season(3), Season.AUTUMN) - self.assertIs(Season(1), Season.SPRING) - self.assertEqual(Season.FALL.name, 'AUTUMN') - self.assertEqual( - [k for k,v in Season.__members__.items() if v.name != k], - ['FALL', 'ANOTHER_SPRING'], - ) + @unittest.skip('to-do list') + def test_dir_with_custom_dunders(self): + class PlainEnum(Enum): + pass + cls_dir = dir(PlainEnum) + self.assertNotIn('__repr__', cls_dir) + self.assertNotIn('__str__', cls_dir) + self.assertNotIn('__format__', cls_dir) + self.assertNotIn('__init__', cls_dir) + # + class MyEnum(Enum): + def __repr__(self): + return object.__repr__(self) + def __str__(self): + return object.__repr__(self) + def __format__(self): + return object.__repr__(self) + def __init__(self): + pass + cls_dir = dir(MyEnum) + self.assertIn('__repr__', cls_dir) + self.assertIn('__str__', cls_dir) + self.assertIn('__format__', cls_dir) + self.assertIn('__init__', cls_dir) - def test_duplicate_name(self): + def test_duplicate_name_error(self): with self.assertRaises(TypeError): class Color(Enum): red = 1 green = 2 blue = 3 red = 4 - + # with self.assertRaises(TypeError): class Color(Enum): red = 1 @@ -419,186 +1071,345 @@ class Color(Enum): blue = 3 def red(self): return 'red' - + # with self.assertRaises(TypeError): class Color(Enum): - @property + @enum.property def red(self): return 'redder' red = 1 green = 2 blue = 3 + def test_enum_function_with_qualname(self): + if isinstance(Theory, Exception): + raise Theory + self.assertEqual(Theory.__qualname__, 'spanish_inquisition') + + def test_enum_of_types(self): + """Support using Enum to refer to types deliberately.""" + class MyTypes(Enum): + i = int + f = float + s = str + self.assertEqual(MyTypes.i.value, int) + self.assertEqual(MyTypes.f.value, float) + self.assertEqual(MyTypes.s.value, str) + class Foo: + pass + class Bar: + pass + class MyTypes2(Enum): + a = Foo + b = Bar + self.assertEqual(MyTypes2.a.value, Foo) + self.assertEqual(MyTypes2.b.value, Bar) + class SpamEnumNotInner: + pass + class SpamEnum(Enum): + spam = SpamEnumNotInner + self.assertEqual(SpamEnum.spam.value, SpamEnumNotInner) + + def test_enum_of_generic_aliases(self): + class E(Enum): + a = typing.List[int] + b = list[int] + self.assertEqual(E.a.value, typing.List[int]) + self.assertEqual(E.b.value, list[int]) + self.assertEqual(repr(E.a), '') + self.assertEqual(repr(E.b), '') + + @unittest.skipIf( + python_version >= (3, 13), + 'inner classes are not members', + ) + def test_nested_classes_in_enum_are_members(self): + """ + Check for warnings pre-3.13 + """ + with self.assertWarnsRegex(DeprecationWarning, 'will not become a member'): + class Outer(Enum): + a = 1 + b = 2 + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, Outer)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.value.foo.value, 10) + self.assertEqual( + list(Outer.Inner.value), + [Outer.Inner.value.foo, Outer.Inner.value.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b, Outer.Inner], + ) + + @unittest.skipIf( + python_version < (3, 13), + 'inner classes are still members', + ) + def test_nested_classes_in_enum_are_not_members(self): + """Support locally-defined nested classes.""" + class Outer(Enum): + a = 1 + b = 2 + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, type)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.foo.value, 10) + self.assertEqual( + list(Outer.Inner), + [Outer.Inner.foo, Outer.Inner.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b], + ) + + def test_nested_classes_in_enum_with_nonmember(self): + class Outer(Enum): + a = 1 + b = 2 + @nonmember + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, type)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.foo.value, 10) + self.assertEqual( + list(Outer.Inner), + [Outer.Inner.foo, Outer.Inner.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b], + ) + + def test_enum_of_types_with_nonmember(self): + """Support using Enum to refer to types deliberately.""" + class MyTypes(Enum): + i = int + f = nonmember(float) + s = str + self.assertEqual(MyTypes.i.value, int) + self.assertTrue(MyTypes.f is float) + self.assertEqual(MyTypes.s.value, str) + class Foo: + pass + class Bar: + pass + class MyTypes2(Enum): + a = Foo + b = nonmember(Bar) + self.assertEqual(MyTypes2.a.value, Foo) + self.assertTrue(MyTypes2.b is Bar) + class SpamEnumIsInner: + pass + class SpamEnum(Enum): + spam = nonmember(SpamEnumIsInner) + self.assertTrue(SpamEnum.spam is SpamEnumIsInner) + + def test_nested_classes_in_enum_with_member(self): + """Support locally-defined nested classes.""" + class Outer(Enum): + a = 1 + b = 2 + @member + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, Outer)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.value.foo.value, 10) + self.assertEqual( + list(Outer.Inner.value), + [Outer.Inner.value.foo, Outer.Inner.value.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b, Outer.Inner], + ) + def test_enum_with_value_name(self): class Huh(Enum): name = 1 value = 2 - self.assertEqual( - list(Huh), - [Huh.name, Huh.value], - ) + self.assertEqual(list(Huh), [Huh.name, Huh.value]) self.assertIs(type(Huh.name), Huh) self.assertEqual(Huh.name.name, 'name') self.assertEqual(Huh.name.value, 1) - def test_format_enum(self): - Season = self.Season - self.assertEqual('{}'.format(Season.SPRING), - '{}'.format(str(Season.SPRING))) - self.assertEqual( '{:}'.format(Season.SPRING), - '{:}'.format(str(Season.SPRING))) - self.assertEqual('{:20}'.format(Season.SPRING), - '{:20}'.format(str(Season.SPRING))) - self.assertEqual('{:^20}'.format(Season.SPRING), - '{:^20}'.format(str(Season.SPRING))) - self.assertEqual('{:>20}'.format(Season.SPRING), - '{:>20}'.format(str(Season.SPRING))) - self.assertEqual('{:<20}'.format(Season.SPRING), - '{:<20}'.format(str(Season.SPRING))) - - def test_str_override_enum(self): - class EnumWithStrOverrides(Enum): - one = auto() - two = auto() + def test_inherited_data_type(self): + class HexInt(int): + __qualname__ = 'HexInt' + def __repr__(self): + return hex(self) + class MyEnum(HexInt, enum.Enum): + __qualname__ = 'MyEnum' + A = 1 + B = 2 + C = 3 + self.assertEqual(repr(MyEnum.A), '') + globals()['HexInt'] = HexInt + globals()['MyEnum'] = MyEnum + test_pickle_dump_load(self.assertIs, MyEnum.A) + test_pickle_dump_load(self.assertIs, MyEnum) + # + class SillyInt(HexInt): + __qualname__ = 'SillyInt' + pass + class MyOtherEnum(SillyInt, enum.Enum): + __qualname__ = 'MyOtherEnum' + D = 4 + E = 5 + F = 6 + self.assertIs(MyOtherEnum._member_type_, SillyInt) + globals()['SillyInt'] = SillyInt + globals()['MyOtherEnum'] = MyOtherEnum + test_pickle_dump_load(self.assertIs, MyOtherEnum.E) + test_pickle_dump_load(self.assertIs, MyOtherEnum) + # + # This did not work in 3.10, but does now with pickling by name + class UnBrokenInt(int): + __qualname__ = 'UnBrokenInt' + def __new__(cls, value): + return int.__new__(cls, value) + class MyUnBrokenEnum(UnBrokenInt, Enum): + __qualname__ = 'MyUnBrokenEnum' + G = 7 + H = 8 + I = 9 + self.assertIs(MyUnBrokenEnum._member_type_, UnBrokenInt) + self.assertIs(MyUnBrokenEnum(7), MyUnBrokenEnum.G) + globals()['UnBrokenInt'] = UnBrokenInt + globals()['MyUnBrokenEnum'] = MyUnBrokenEnum + test_pickle_dump_load(self.assertIs, MyUnBrokenEnum.I) + test_pickle_dump_load(self.assertIs, MyUnBrokenEnum) - def __str__(self): - return 'Str!' - self.assertEqual(str(EnumWithStrOverrides.one), 'Str!') - self.assertEqual('{}'.format(EnumWithStrOverrides.one), 'Str!') - - def test_format_override_enum(self): - class EnumWithFormatOverride(Enum): - one = 1.0 - two = 2.0 - def __format__(self, spec): - return 'Format!!' - self.assertEqual(str(EnumWithFormatOverride.one), 'EnumWithFormatOverride.one') - self.assertEqual('{}'.format(EnumWithFormatOverride.one), 'Format!!') + def test_floatenum_fromhex(self): + h = float.hex(FloatStooges.MOE.value) + self.assertIs(FloatStooges.fromhex(h), FloatStooges.MOE) + h = float.hex(FloatStooges.MOE.value + 0.01) + with self.assertRaises(ValueError): + FloatStooges.fromhex(h) - def test_str_and_format_override_enum(self): - class EnumWithStrFormatOverrides(Enum): - one = auto() - two = auto() - def __str__(self): - return 'Str!' - def __format__(self, spec): - return 'Format!' - self.assertEqual(str(EnumWithStrFormatOverrides.one), 'Str!') - self.assertEqual('{}'.format(EnumWithStrFormatOverrides.one), 'Format!') - - def test_str_override_mixin(self): - class MixinEnumWithStrOverride(float, Enum): - one = 1.0 - two = 2.0 - def __str__(self): - return 'Overridden!' - self.assertEqual(str(MixinEnumWithStrOverride.one), 'Overridden!') - self.assertEqual('{}'.format(MixinEnumWithStrOverride.one), 'Overridden!') - - def test_str_and_format_override_mixin(self): - class MixinWithStrFormatOverrides(float, Enum): - one = 1.0 - two = 2.0 - def __str__(self): - return 'Str!' - def __format__(self, spec): - return 'Format!' - self.assertEqual(str(MixinWithStrFormatOverrides.one), 'Str!') - self.assertEqual('{}'.format(MixinWithStrFormatOverrides.one), 'Format!') - - def test_format_override_mixin(self): - class TestFloat(float, Enum): - one = 1.0 - two = 2.0 - def __format__(self, spec): - return 'TestFloat success!' - self.assertEqual(str(TestFloat.one), 'TestFloat.one') - self.assertEqual('{}'.format(TestFloat.one), 'TestFloat success!') + def test_programmatic_function_type(self): + MinorEnum = Enum('MinorEnum', 'june july august', type=int) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def assertFormatIsValue(self, spec, member): - self.assertEqual(spec.format(member), spec.format(member.value)) + def test_programmatic_function_string_with_start(self): + MinorEnum = Enum('MinorEnum', 'june july august', start=10) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 10): + e = MinorEnum(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + + def test_programmatic_function_type_with_start(self): + MinorEnum = Enum('MinorEnum', 'june july august', type=int, start=30) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 30): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def test_format_enum_date(self): - Holiday = self.Holiday - self.assertFormatIsValue('{}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:^20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:>20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:<20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:%Y %m}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:%Y %m %M:00}', Holiday.IDES_OF_MARCH) - - def test_format_enum_float(self): - Konstants = self.Konstants - self.assertFormatIsValue('{}', Konstants.TAU) - self.assertFormatIsValue('{:}', Konstants.TAU) - self.assertFormatIsValue('{:20}', Konstants.TAU) - self.assertFormatIsValue('{:^20}', Konstants.TAU) - self.assertFormatIsValue('{:>20}', Konstants.TAU) - self.assertFormatIsValue('{:<20}', Konstants.TAU) - self.assertFormatIsValue('{:n}', Konstants.TAU) - self.assertFormatIsValue('{:5.2}', Konstants.TAU) - self.assertFormatIsValue('{:f}', Konstants.TAU) - - def test_format_enum_int(self): - Grades = self.Grades - self.assertFormatIsValue('{}', Grades.C) - self.assertFormatIsValue('{:}', Grades.C) - self.assertFormatIsValue('{:20}', Grades.C) - self.assertFormatIsValue('{:^20}', Grades.C) - self.assertFormatIsValue('{:>20}', Grades.C) - self.assertFormatIsValue('{:<20}', Grades.C) - self.assertFormatIsValue('{:+}', Grades.C) - self.assertFormatIsValue('{:08X}', Grades.C) - self.assertFormatIsValue('{:b}', Grades.C) - - def test_format_enum_str(self): - Directional = self.Directional - self.assertFormatIsValue('{}', Directional.WEST) - self.assertFormatIsValue('{:}', Directional.WEST) - self.assertFormatIsValue('{:20}', Directional.WEST) - self.assertFormatIsValue('{:^20}', Directional.WEST) - self.assertFormatIsValue('{:>20}', Directional.WEST) - self.assertFormatIsValue('{:<20}', Directional.WEST) + def test_programmatic_function_string_list_with_start(self): + MinorEnum = Enum('MinorEnum', ['june', 'july', 'august'], start=20) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 20): + e = MinorEnum(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + + def test_programmatic_function_type_from_subclass(self): + MinorEnum = IntEnum('MinorEnum', 'june july august') + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def test_object_str_override(self): - class Colors(Enum): - RED, GREEN, BLUE = 1, 2, 3 - def __repr__(self): - return "test.%s" % (self._name_, ) - __str__ = object.__str__ - self.assertEqual(str(Colors.RED), 'test.RED') + def test_programmatic_function_type_from_subclass_with_start(self): + MinorEnum = IntEnum('MinorEnum', 'june july august', start=40) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 40): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def test_enum_str_override(self): - class MyStrEnum(Enum): - def __str__(self): - return 'MyStr' - class MyMethodEnum(Enum): - def hello(self): - return 'Hello! My name is %s' % self.name - class Test1Enum(MyMethodEnum, int, MyStrEnum): - One = 1 - Two = 2 - self.assertTrue(Test1Enum._member_type_ is int) - self.assertEqual(str(Test1Enum.One), 'MyStr') - self.assertEqual(format(Test1Enum.One, ''), 'MyStr') - # - class Test2Enum(MyStrEnum, MyMethodEnum): - One = 1 - Two = 2 - self.assertEqual(str(Test2Enum.One), 'MyStr') - self.assertEqual(format(Test1Enum.One, ''), 'MyStr') + # TODO: RUSTPYTHON, AssertionError: is not + @unittest.expectedFailure + def test_intenum_from_bytes(self): + self.assertIs(IntStooges.from_bytes(b'\x00\x03', 'big'), IntStooges.MOE) + with self.assertRaises(ValueError): + IntStooges.from_bytes(b'\x00\x05', 'big') - def test_inherited_data_type(self): - class HexInt(int): - def __repr__(self): - return hex(self) - class MyEnum(HexInt, enum.Enum): - A = 1 - B = 2 - C = 3 - self.assertEqual(repr(MyEnum.A), '') + def test_reserved_sunder_error(self): + with self.assertRaisesRegex( + ValueError, + '_sunder_ names, such as ._bad_., are reserved', + ): + class Bad(Enum): + _bad_ = 1 def test_too_many_data_types(self): with self.assertRaisesRegex(TypeError, 'too many data types'): @@ -615,116 +1426,6 @@ def repr(self): class Huh(MyStr, MyInt, Enum): One = 1 - def test_hash(self): - Season = self.Season - dates = {} - dates[Season.WINTER] = '1225' - dates[Season.SPRING] = '0315' - dates[Season.SUMMER] = '0704' - dates[Season.AUTUMN] = '1031' - self.assertEqual(dates[Season.AUTUMN], '1031') - - def test_intenum_from_scratch(self): - class phy(int, Enum): - pi = 3 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_intenum_inherited(self): - class IntEnum(int, Enum): - pass - class phy(IntEnum): - pi = 3 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_floatenum_from_scratch(self): - class phy(float, Enum): - pi = 3.1415926 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_floatenum_inherited(self): - class FloatEnum(float, Enum): - pass - class phy(FloatEnum): - pi = 3.1415926 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_strenum_from_scratch(self): - class phy(str, Enum): - pi = 'Pi' - tau = 'Tau' - self.assertTrue(phy.pi < phy.tau) - - def test_strenum_inherited(self): - class StrEnum(str, Enum): - pass - class phy(StrEnum): - pi = 'Pi' - tau = 'Tau' - self.assertTrue(phy.pi < phy.tau) - - - def test_intenum(self): - class WeekDay(IntEnum): - SUNDAY = 1 - MONDAY = 2 - TUESDAY = 3 - WEDNESDAY = 4 - THURSDAY = 5 - FRIDAY = 6 - SATURDAY = 7 - - self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c') - self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2]) - - lst = list(WeekDay) - self.assertEqual(len(lst), len(WeekDay)) - self.assertEqual(len(WeekDay), 7) - target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' - target = target.split() - for i, weekday in enumerate(target, 1): - e = WeekDay(i) - self.assertEqual(e, i) - self.assertEqual(int(e), i) - self.assertEqual(e.name, weekday) - self.assertIn(e, WeekDay) - self.assertEqual(lst.index(e)+1, i) - self.assertTrue(0 < e < 8) - self.assertIs(type(e), WeekDay) - self.assertIsInstance(e, int) - self.assertIsInstance(e, Enum) - - def test_intenum_duplicates(self): - class WeekDay(IntEnum): - SUNDAY = 1 - MONDAY = 2 - TUESDAY = TEUSDAY = 3 - WEDNESDAY = 4 - THURSDAY = 5 - FRIDAY = 6 - SATURDAY = 7 - self.assertIs(WeekDay.TEUSDAY, WeekDay.TUESDAY) - self.assertEqual(WeekDay(3).name, 'TUESDAY') - self.assertEqual([k for k,v in WeekDay.__members__.items() - if v.name != k], ['TEUSDAY', ]) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_intenum_from_bytes(self): - self.assertIs(IntStooges.from_bytes(b'\x00\x03', 'big'), IntStooges.MOE) - with self.assertRaises(ValueError): - IntStooges.from_bytes(b'\x00\x05', 'big') - - def test_floatenum_fromhex(self): - h = float.hex(FloatStooges.MOE.value) - self.assertIs(FloatStooges.fromhex(h), FloatStooges.MOE) - h = float.hex(FloatStooges.MOE.value + 0.01) - with self.assertRaises(ValueError): - FloatStooges.fromhex(h) - def test_pickle_enum(self): if isinstance(Stooges, Exception): raise Stooges @@ -755,12 +1456,7 @@ def test_pickle_enum_function_with_module(self): test_pickle_dump_load(self.assertIs, Question.who) test_pickle_dump_load(self.assertIs, Question) - def test_enum_function_with_qualname(self): - if isinstance(Theory, Exception): - raise Theory - self.assertEqual(Theory.__qualname__, 'spanish_inquisition') - - def test_class_nested_enum_and_pickle_protocol_four(self): + def test_pickle_nested_class(self): # would normally just have this directly in the class namespace class NestedEnum(Enum): twigs = 'common' @@ -774,11 +1470,11 @@ def test_pickle_by_name(self): class ReplaceGlobalInt(IntEnum): ONE = 1 TWO = 2 - ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_name + ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_global_name for proto in range(HIGHEST_PROTOCOL): self.assertEqual(ReplaceGlobalInt.TWO.__reduce_ex__(proto), 'TWO') - def test_exploding_pickle(self): + def test_pickle_explodes(self): BadPickle = Enum( 'BadPickle', 'dill sweet bread-n-butter', module=__name__) globals()['BadPickle'] = BadPickle @@ -819,185 +1515,6 @@ class Season(Enum): [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING], ) - def test_reversed_iteration_order(self): - self.assertEqual( - list(reversed(self.Season)), - [self.Season.WINTER, self.Season.AUTUMN, self.Season.SUMMER, - self.Season.SPRING] - ) - - def test_programmatic_function_string(self): - SummerMonth = Enum('SummerMonth', 'june july august') - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_string_with_start(self): - SummerMonth = Enum('SummerMonth', 'june july august', start=10) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 10): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_string_list(self): - SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_string_list_with_start(self): - SummerMonth = Enum('SummerMonth', ['june', 'july', 'august'], start=20) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 20): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_iterable(self): - SummerMonth = Enum( - 'SummerMonth', - (('june', 1), ('july', 2), ('august', 3)) - ) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_from_dict(self): - SummerMonth = Enum( - 'SummerMonth', - OrderedDict((('june', 1), ('july', 2), ('august', 3))) - ) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type(self): - SummerMonth = Enum('SummerMonth', 'june july august', type=int) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type_with_start(self): - SummerMonth = Enum('SummerMonth', 'june july august', type=int, start=30) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 30): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type_from_subclass(self): - SummerMonth = IntEnum('SummerMonth', 'june july august') - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type_from_subclass_with_start(self): - SummerMonth = IntEnum('SummerMonth', 'june july august', start=40) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 40): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - def test_subclassing(self): if isinstance(Name, Exception): raise Name @@ -1011,14 +1528,19 @@ class Color(Enum): red = 1 green = 2 blue = 3 + # with self.assertRaises(TypeError): class MoreColor(Color): cyan = 4 magenta = 5 yellow = 6 - with self.assertRaisesRegex(TypeError, "EvenMoreColor: cannot extend enumeration 'Color'"): + # + with self.assertRaisesRegex(TypeError, " cannot extend "): class EvenMoreColor(Color, IntEnum): chartruese = 7 + # + with self.assertRaisesRegex(TypeError, " cannot extend "): + Color('Foo', ('pink', 'black')) def test_exclude_methods(self): class whatever(Enum): @@ -1121,32 +1643,13 @@ class Color(Enum): with self.assertRaises(KeyError): Color['chartreuse'] - def test_new_repr(self): - class Color(Enum): - red = 1 - green = 2 - blue = 3 - def __repr__(self): - return "don't you just love shades of %s?" % self.name - self.assertEqual( - repr(Color.blue), - "don't you just love shades of blue?", - ) - - def test_inherited_repr(self): - class MyEnum(Enum): - def __repr__(self): - return "My name is %s." % self.name - class MyIntEnum(int, MyEnum): - this = 1 - that = 2 - theother = 3 - self.assertEqual(repr(MyIntEnum.that), "My name is that.") + # tests that need to be evalualted for moving def test_multiple_mixin_mro(self): class auto_enum(type(Enum)): def __new__(metacls, cls, bases, classdict): temp = type(classdict)() + temp._cls_name = cls names = set(classdict._member_names) i = 0 for k in classdict._member_names: @@ -1193,14 +1696,16 @@ def __new__(cls, *args): return self def __getnewargs__(self): return self._args - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1215,7 +1720,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1236,7 +1742,7 @@ class NEI(NamedInt, Enum): test_pickle_dump_load(self.assertIs, NEI.y) test_pickle_dump_load(self.assertIs, NEI) - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON, fails on pickle @unittest.expectedFailure def test_subclasses_with_getnewargs_ex(self): class NamedInt(int): @@ -1252,14 +1758,16 @@ def __new__(cls, *args): return self def __getnewargs_ex__(self): return self._args, {} - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1274,7 +1782,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1309,14 +1818,16 @@ def __new__(cls, *args): return self def __reduce__(self): return self.__class__, self._args - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1331,7 +1842,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1366,14 +1878,16 @@ def __new__(cls, *args): return self def __reduce_ex__(self, proto): return self.__class__, self._args - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1388,7 +1902,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1397,7 +1912,6 @@ class NEI(NamedInt, Enum): x = ('the-x', 1) y = ('the-y', 2) - self.assertIs(NEI.__new__, Enum.__new__) self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") globals()['NamedInt'] = NamedInt @@ -1421,14 +1935,16 @@ def __new__(cls, *args): self._intname = name self._args = _args return self - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1451,7 +1967,6 @@ class NEI(NamedInt, Enum): __qualname__ = 'NEI' x = ('the-x', 1) y = ('the-y', 2) - self.assertIs(NEI.__new__, Enum.__new__) self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") globals()['NamedInt'] = NamedInt @@ -1459,10 +1974,14 @@ class NEI(NamedInt, Enum): NI5 = NamedInt('test', 5) self.assertEqual(NI5, 5) self.assertEqual(NEI.y.value, 2) - test_pickle_exception(self.assertRaises, TypeError, NEI.x) - test_pickle_exception(self.assertRaises, PicklingError, NEI) + with self.assertRaisesRegex(TypeError, "name and value must be specified"): + test_pickle_dump_load(self.assertIs, NEI.y) + # fix pickle support and try again + NEI.__reduce_ex__ = enum.pickle_by_enum_name + test_pickle_dump_load(self.assertIs, NEI.y) + test_pickle_dump_load(self.assertIs, NEI) - def test_subclasses_without_direct_pickle_support_using_name(self): + def test_subclasses_with_direct_pickle_support(self): class NamedInt(int): __qualname__ = 'NamedInt' def __new__(cls, *args): @@ -1474,14 +1993,16 @@ def __new__(cls, *args): self._intname = name self._args = _args return self - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1496,7 +2017,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1580,13 +2102,10 @@ class Color(AutoNumber): self.assertEqual(list(map(int, Color)), [1, 2, 3]) def test_equality(self): - class AlwaysEqual: - def __eq__(self, other): - return True class OrdinaryEnum(Enum): a = 1 - self.assertEqual(AlwaysEqual(), OrdinaryEnum.a) - self.assertEqual(OrdinaryEnum.a, AlwaysEqual()) + self.assertEqual(ALWAYS_EQ, OrdinaryEnum.a) + self.assertEqual(OrdinaryEnum.a, ALWAYS_EQ) def test_ordered_mixin(self): class OrderedEnum(Enum): @@ -1663,6 +2182,15 @@ def test(self): class Test(Base): test = 1 self.assertEqual(Test.test.test, 'dynamic') + self.assertEqual(Test.test.value, 1) + class Base2(Enum): + @enum.property + def flash(self): + return 'flashy dynamic' + class Test(Base2): + flash = 1 + self.assertEqual(Test.flash.flash, 'flashy dynamic') + self.assertEqual(Test.flash.value, 1) def test_no_duplicates(self): class UniqueEnum(Enum): @@ -1699,7 +2227,7 @@ class Planet(Enum): def __init__(self, mass, radius): self.mass = mass # in kilograms self.radius = radius # in meters - @property + @enum.property def surface_gravity(self): # universal gravitational constant (m3 kg-1 s-2) G = 6.67300E-11 @@ -1735,124 +2263,41 @@ def __new__(cls, value, period): self.assertTrue(Period.month_1 is Period.day_30) self.assertTrue(Period.week_4 is Period.day_28) - def test_nonhash_value(self): - class AutoNumberInAList(Enum): - def __new__(cls): - value = [len(cls.__members__) + 1] - obj = object.__new__(cls) - obj._value_ = value - return obj - class ColorInAList(AutoNumberInAList): - red = () - green = () - blue = () - self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) - for enum, value in zip(ColorInAList, range(3)): - value += 1 - self.assertEqual(enum.value, [value]) - self.assertIs(ColorInAList([value]), enum) - - def test_conflicting_types_resolved_in_new(self): - class LabelledIntEnum(int, Enum): - def __new__(cls, *args): - value, label = args - obj = int.__new__(cls, value) - obj.label = label - obj._value_ = value - return obj - - class LabelledList(LabelledIntEnum): - unprocessed = (1, "Unprocessed") - payment_complete = (2, "Payment Complete") - - self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) - self.assertEqual(LabelledList.unprocessed, 1) - self.assertEqual(LabelledList(1), LabelledList.unprocessed) - - def test_auto_number(self): - class Color(Enum): - red = auto() - blue = auto() - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 1) - self.assertEqual(Color.blue.value, 2) - self.assertEqual(Color.green.value, 3) - - def test_auto_name(self): - class Color(Enum): - def _generate_next_value_(name, start, count, last): - return name - red = auto() - blue = auto() - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 'red') - self.assertEqual(Color.blue.value, 'blue') - self.assertEqual(Color.green.value, 'green') - - def test_auto_name_inherit(self): - class AutoNameEnum(Enum): - def _generate_next_value_(name, start, count, last): - return name - class Color(AutoNameEnum): - red = auto() - blue = auto() - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 'red') - self.assertEqual(Color.blue.value, 'blue') - self.assertEqual(Color.green.value, 'green') - - def test_auto_garbage(self): - class Color(Enum): - red = 'red' - blue = auto() - self.assertEqual(Color.blue.value, 1) - - def test_auto_garbage_corrected(self): - class Color(Enum): - red = 'red' - blue = 2 - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 'red') - self.assertEqual(Color.blue.value, 2) - self.assertEqual(Color.green.value, 3) - - def test_auto_order(self): - with self.assertRaises(TypeError): - class Color(Enum): - red = auto() - green = auto() - blue = auto() - def _generate_next_value_(name, start, count, last): - return name + def test_nonhash_value(self): + class AutoNumberInAList(Enum): + def __new__(cls): + value = [len(cls.__members__) + 1] + obj = object.__new__(cls) + obj._value_ = value + return obj + class ColorInAList(AutoNumberInAList): + red = () + green = () + blue = () + self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) + for enum, value in zip(ColorInAList, range(3)): + value += 1 + self.assertEqual(enum.value, [value]) + self.assertIs(ColorInAList([value]), enum) - def test_auto_order_wierd(self): - weird_auto = auto() - weird_auto.value = 'pathological case' - class Color(Enum): - red = weird_auto - def _generate_next_value_(name, start, count, last): - return name - blue = auto() - self.assertEqual(list(Color), [Color.red, Color.blue]) - self.assertEqual(Color.red.value, 'pathological case') - self.assertEqual(Color.blue.value, 'blue') + def test_conflicting_types_resolved_in_new(self): + class LabelledIntEnum(int, Enum): + def __new__(cls, *args): + value, label = args + obj = int.__new__(cls, value) + obj.label = label + obj._value_ = value + return obj - def test_duplicate_auto(self): - class Dupes(Enum): - first = primero = auto() - second = auto() - third = auto() - self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) + class LabelledList(LabelledIntEnum): + unprocessed = (1, "Unprocessed") + payment_complete = (2, "Payment Complete") + + self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) + self.assertEqual(LabelledList.unprocessed, 1) + self.assertEqual(LabelledList(1), LabelledList.unprocessed) - def test_default_missing(self): + def test_default_missing_no_chained_exception(self): class Color(Enum): RED = 1 GREEN = 2 @@ -1864,7 +2309,7 @@ class Color(Enum): else: raise Exception('Exception not raised.') - def test_missing(self): + def test_missing_override(self): class Color(Enum): red = 1 green = 2 @@ -1901,6 +2346,40 @@ def _missing_(cls, item): else: raise Exception('Exception not raised.') + def test_missing_exceptions_reset(self): + import gc + import weakref + # + class TestEnum(enum.Enum): + VAL1 = 'val1' + VAL2 = 'val2' + # + class Class1: + def __init__(self): + # Gracefully handle an exception of our own making + try: + raise ValueError() + except ValueError: + pass + # + class Class2: + def __init__(self): + # Gracefully handle an exception of Enum's making + try: + TestEnum('invalid_value') + except ValueError: + pass + # No strong refs here so these are free to die. + class_1_ref = weakref.ref(Class1()) + class_2_ref = weakref.ref(Class2()) + # + # The exception raised by Enum used to create a reference loop and thus + # Class2 instances would stick around until the next garbage collection + # cycle, unlike Class1. Verify Class2 no longer does this. + gc.collect() # For PyPy or other GCs. + self.assertIs(class_1_ref(), None) + self.assertIs(class_2_ref(), None) + def test_multiple_mixin(self): class MaxMixin: @classproperty @@ -1932,6 +2411,7 @@ class Color(MaxMixin, StrMixin, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 3) @@ -1941,6 +2421,7 @@ class Color(StrMixin, MaxMixin, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 3) @@ -1950,6 +2431,7 @@ class CoolColor(StrMixin, SomeEnum, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(CoolColor.RED.value, 1) self.assertEqual(CoolColor.GREEN.value, 2) self.assertEqual(CoolColor.BLUE.value, 3) @@ -1959,6 +2441,7 @@ class CoolerColor(StrMixin, AnotherEnum, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(CoolerColor.RED.value, 1) self.assertEqual(CoolerColor.GREEN.value, 2) self.assertEqual(CoolerColor.BLUE.value, 3) @@ -1969,6 +2452,7 @@ class CoolestColor(StrMixin, SomeEnum, AnotherEnum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(CoolestColor.RED.value, 1) self.assertEqual(CoolestColor.GREEN.value, 2) self.assertEqual(CoolestColor.BLUE.value, 3) @@ -1979,6 +2463,7 @@ class ConfusedColor(StrMixin, AnotherEnum, SomeEnum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(ConfusedColor.RED.value, 1) self.assertEqual(ConfusedColor.GREEN.value, 2) self.assertEqual(ConfusedColor.BLUE.value, 3) @@ -1989,6 +2474,7 @@ class ReformedColor(StrMixin, IntEnum, SomeEnum, AnotherEnum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(ReformedColor.RED.value, 1) self.assertEqual(ReformedColor.GREEN.value, 2) self.assertEqual(ReformedColor.BLUE.value, 3) @@ -1998,13 +2484,6 @@ class ReformedColor(StrMixin, IntEnum, SomeEnum, AnotherEnum): self.assertTrue(issubclass(ReformedColor, int)) def test_multiple_inherited_mixin(self): - class StrEnum(str, Enum): - def __new__(cls, *args, **kwargs): - for a in args: - if not isinstance(a, str): - raise TypeError("Enumeration '%s' (%s) is not" - " a string" % (a, type(a).__name__)) - return str.__new__(cls, *args, **kwargs) @unique class Decision1(StrEnum): REVERT = "REVERT" @@ -2028,11 +2507,12 @@ def __repr__(self): return hex(self) class MyIntEnum(HexMixin, MyInt, enum.Enum): - pass + __repr__ = HexMixin.__repr__ class Foo(MyIntEnum): TEST = 1 self.assertTrue(isinstance(Foo.TEST, MyInt)) + self.assertEqual(Foo._member_type_, MyInt) self.assertEqual(repr(Foo.TEST), "0x1") class Fee(MyIntEnum): @@ -2044,6 +2524,50 @@ def __new__(cls, value): return member self.assertEqual(Fee.TEST, 2) + def test_multiple_mixin_with_common_data_type(self): + class CaseInsensitiveStrEnum(str, Enum): + @classmethod + def _missing_(cls, value): + for member in cls._member_map_.values(): + if member._value_.lower() == value.lower(): + return member + return super()._missing_(value) + # + class LenientStrEnum(str, Enum): + def __init__(self, *args): + self._valid = True + @classmethod + def _missing_(cls, value): + unknown = cls._member_type_.__new__(cls, value) + unknown._valid = False + unknown._name_ = value.upper() + unknown._value_ = value + cls._member_map_[value] = unknown + return unknown + @enum.property + def valid(self): + return self._valid + # + class JobStatus(CaseInsensitiveStrEnum, LenientStrEnum): + ACTIVE = "active" + PENDING = "pending" + TERMINATED = "terminated" + # + JS = JobStatus + self.assertEqual(list(JobStatus), [JS.ACTIVE, JS.PENDING, JS.TERMINATED]) + self.assertEqual(JS.ACTIVE, 'active') + self.assertEqual(JS.ACTIVE.value, 'active') + self.assertIs(JS('Active'), JS.ACTIVE) + self.assertTrue(JS.ACTIVE.valid) + missing = JS('missing') + self.assertEqual(list(JobStatus), [JS.ACTIVE, JS.PENDING, JS.TERMINATED]) + self.assertEqual(JS.ACTIVE, 'active') + self.assertEqual(JS.ACTIVE.value, 'active') + self.assertIs(JS('Active'), JS.ACTIVE) + self.assertTrue(JS.ACTIVE.valid) + self.assertTrue(isinstance(missing, JS)) + self.assertFalse(missing.valid) + def test_empty_globals(self): # bpo-35717: sys._getframe(2).f_globals['__name__'] fails with KeyError # when using compile and exec because f_globals is empty @@ -2053,8 +2577,358 @@ def test_empty_globals(self): local_ls = {} exec(code, global_ns, local_ls) + def test_strenum(self): + class GoodStrEnum(StrEnum): + one = '1' + two = '2' + three = b'3', 'ascii' + four = b'4', 'latin1', 'strict' + self.assertEqual(GoodStrEnum.one, '1') + self.assertEqual(str(GoodStrEnum.one), '1') + self.assertEqual('{}'.format(GoodStrEnum.one), '1') + self.assertEqual(GoodStrEnum.one, str(GoodStrEnum.one)) + self.assertEqual(GoodStrEnum.one, '{}'.format(GoodStrEnum.one)) + self.assertEqual(repr(GoodStrEnum.one), "") + # + class DumbMixin: + def __str__(self): + return "don't do this" + class DumbStrEnum(DumbMixin, StrEnum): + five = '5' + six = '6' + seven = '7' + __str__ = DumbMixin.__str__ # needed as of 3.11 + self.assertEqual(DumbStrEnum.seven, '7') + self.assertEqual(str(DumbStrEnum.seven), "don't do this") + # + class EnumMixin(Enum): + def hello(self): + print('hello from %s' % (self, )) + class HelloEnum(EnumMixin, StrEnum): + eight = '8' + self.assertEqual(HelloEnum.eight, '8') + self.assertEqual(HelloEnum.eight, str(HelloEnum.eight)) + # + class GoodbyeMixin: + def goodbye(self): + print('%s wishes you a fond farewell') + class GoodbyeEnum(GoodbyeMixin, EnumMixin, StrEnum): + nine = '9' + self.assertEqual(GoodbyeEnum.nine, '9') + self.assertEqual(GoodbyeEnum.nine, str(GoodbyeEnum.nine)) + # + with self.assertRaisesRegex(TypeError, '1 is not a string'): + class FirstFailedStrEnum(StrEnum): + one = 1 + two = '2' + with self.assertRaisesRegex(TypeError, "2 is not a string"): + class SecondFailedStrEnum(StrEnum): + one = '1' + two = 2, + three = '3' + with self.assertRaisesRegex(TypeError, '2 is not a string'): + class ThirdFailedStrEnum(StrEnum): + one = '1' + two = 2 + with self.assertRaisesRegex(TypeError, 'encoding must be a string, not %r' % (sys.getdefaultencoding, )): + class ThirdFailedStrEnum(StrEnum): + one = '1' + two = b'2', sys.getdefaultencoding + with self.assertRaisesRegex(TypeError, 'errors must be a string, not 9'): + class ThirdFailedStrEnum(StrEnum): + one = '1' + two = b'2', 'ascii', 9 + + # TODO: RUSTPYTHON, fails on encoding testing : TypeError: Expected type 'str' but 'builtin_function_or_method' found + @unittest.expectedFailure + def test_custom_strenum(self): + class CustomStrEnum(str, Enum): + pass + class OkayEnum(CustomStrEnum): + one = '1' + two = '2' + three = b'3', 'ascii' + four = b'4', 'latin1', 'strict' + self.assertEqual(OkayEnum.one, '1') + self.assertEqual(str(OkayEnum.one), 'OkayEnum.one') + self.assertEqual('{}'.format(OkayEnum.one), 'OkayEnum.one') + self.assertEqual(repr(OkayEnum.one), "") + # + class DumbMixin: + def __str__(self): + return "don't do this" + class DumbStrEnum(DumbMixin, CustomStrEnum): + five = '5' + six = '6' + seven = '7' + __str__ = DumbMixin.__str__ # needed as of 3.11 + self.assertEqual(DumbStrEnum.seven, '7') + self.assertEqual(str(DumbStrEnum.seven), "don't do this") + # + class EnumMixin(Enum): + def hello(self): + print('hello from %s' % (self, )) + class HelloEnum(EnumMixin, CustomStrEnum): + eight = '8' + self.assertEqual(HelloEnum.eight, '8') + self.assertEqual(str(HelloEnum.eight), 'HelloEnum.eight') + # + class GoodbyeMixin: + def goodbye(self): + print('%s wishes you a fond farewell') + class GoodbyeEnum(GoodbyeMixin, EnumMixin, CustomStrEnum): + nine = '9' + self.assertEqual(GoodbyeEnum.nine, '9') + self.assertEqual(str(GoodbyeEnum.nine), 'GoodbyeEnum.nine') + # + class FirstFailedStrEnum(CustomStrEnum): + one = 1 # this will become '1' + two = '2' + class SecondFailedStrEnum(CustomStrEnum): + one = '1' + two = 2, # this will become '2' + three = '3' + class ThirdFailedStrEnum(CustomStrEnum): + one = '1' + two = 2 # this will become '2' + with self.assertRaisesRegex(TypeError, '.encoding. must be str, not '): + class ThirdFailedStrEnum(CustomStrEnum): + one = '1' + two = b'2', sys.getdefaultencoding + with self.assertRaisesRegex(TypeError, '.errors. must be str, not '): + class ThirdFailedStrEnum(CustomStrEnum): + one = '1' + two = b'2', 'ascii', 9 + + def test_missing_value_error(self): + with self.assertRaisesRegex(TypeError, "_value_ not set in __new__"): + class Combined(str, Enum): + # + def __new__(cls, value, sequence): + enum = str.__new__(cls, value) + if '(' in value: + fis_name, segment = value.split('(', 1) + segment = segment.strip(' )') + else: + fis_name = value + segment = None + enum.fis_name = fis_name + enum.segment = segment + enum.sequence = sequence + return enum + # + def __repr__(self): + return "<%s.%s>" % (self.__class__.__name__, self._name_) + # + key_type = 'An$(1,2)', 0 + company_id = 'An$(3,2)', 1 + code = 'An$(5,1)', 2 + description = 'Bn$', 3 + + + def test_private_variable_is_normal_attribute(self): + class Private(Enum): + __corporal = 'Radar' + __major_ = 'Hoolihan' + self.assertEqual(Private._Private__corporal, 'Radar') + self.assertEqual(Private._Private__major_, 'Hoolihan') + + @unittest.skip("Accessing all values retained for performance reasons, see GH-93910") + def test_exception_for_member_from_member_access(self): + with self.assertRaisesRegex(AttributeError, " member has no attribute .NO."): + class Di(Enum): + YES = 1 + NO = 0 + nope = Di.YES.NO + + + def test_dynamic_members_with_static_methods(self): + # + foo_defines = {'FOO_CAT': 'aloof', 'BAR_DOG': 'friendly', 'FOO_HORSE': 'big'} + class Foo(Enum): + vars().update({ + k: v + for k, v in foo_defines.items() + if k.startswith('FOO_') + }) + def upper(self): + return self.value.upper() + self.assertEqual(list(Foo), [Foo.FOO_CAT, Foo.FOO_HORSE]) + self.assertEqual(Foo.FOO_CAT.value, 'aloof') + self.assertEqual(Foo.FOO_HORSE.upper(), 'BIG') + # + with self.assertRaisesRegex(TypeError, "'FOO_CAT' already defined as 'aloof'"): + class FooBar(Enum): + vars().update({ + k: v + for k, v in foo_defines.items() + if k.startswith('FOO_') + }, + **{'FOO_CAT': 'small'}, + ) + def upper(self): + return self.value.upper() + + def test_repr_with_dataclass(self): + "ensure dataclass-mixin has correct repr()" + from dataclasses import dataclass + @dataclass + class Foo: + __qualname__ = 'Foo' + a: int + class Entries(Foo, Enum): + ENTRY1 = 1 + self.assertTrue(isinstance(Entries.ENTRY1, Foo)) + self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_) + self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) + self.assertEqual(repr(Entries.ENTRY1), '') + + def test_repr_with_init_data_type_mixin(self): + # non-data_type is a mixin that doesn't define __new__ + class Foo: + def __init__(self, a): + self.a = a + def __repr__(self): + return f'Foo(a={self.a!r})' + class Entries(Foo, Enum): + ENTRY1 = 1 + # + self.assertEqual(repr(Entries.ENTRY1), 'Foo(a=1)') + + def test_repr_and_str_with_non_data_type_mixin(self): + # non-data_type is a mixin that doesn't define __new__ + class Foo: + def __repr__(self): + return 'Foo' + def __str__(self): + return 'ooF' + class Entries(Foo, Enum): + ENTRY1 = 1 + # + self.assertEqual(repr(Entries.ENTRY1), 'Foo') + self.assertEqual(str(Entries.ENTRY1), 'ooF') + + def test_value_backup_assign(self): + # check that enum will add missing values when custom __new__ does not + class Some(Enum): + def __new__(cls, val): + return object.__new__(cls) + x = 1 + y = 2 + self.assertEqual(Some.x.value, 1) + self.assertEqual(Some.y.value, 2) + + def test_custom_flag_bitwise(self): + class MyIntFlag(int, Flag): + ONE = 1 + TWO = 2 + FOUR = 4 + self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO) + self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag)) + + def test_int_flags_copy(self): + class MyIntFlag(IntFlag): + ONE = 1 + TWO = 2 + FOUR = 4 + + flags = MyIntFlag.ONE | MyIntFlag.TWO + copied = copy.copy(flags) + deep = copy.deepcopy(flags) + self.assertEqual(copied, flags) + self.assertEqual(deep, flags) + + flags = MyIntFlag.ONE | MyIntFlag.TWO | 8 + copied = copy.copy(flags) + deep = copy.deepcopy(flags) + self.assertEqual(copied, flags) + self.assertEqual(deep, flags) + self.assertEqual(copied.value, 1 | 2 | 8) + + def test_namedtuple_as_value(self): + from collections import namedtuple + TTuple = namedtuple('TTuple', 'id a blist') + class NTEnum(Enum): + NONE = TTuple(0, 0, []) + A = TTuple(1, 2, [4]) + B = TTuple(2, 4, [0, 1, 2]) + self.assertEqual(repr(NTEnum.NONE), "") + self.assertEqual(NTEnum.NONE.value, TTuple(id=0, a=0, blist=[])) + self.assertEqual( + [x.value for x in NTEnum], + [TTuple(id=0, a=0, blist=[]), TTuple(id=1, a=2, blist=[4]), TTuple(id=2, a=4, blist=[0, 1, 2])], + ) + + def test_flag_with_custom_new(self): + class FlagFromChar(IntFlag): + def __new__(cls, c): + value = 1 << c + self = int.__new__(cls, value) + self._value_ = value + return self + # + a = ord('a') + # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) + # + # + class FlagFromChar(Flag): + def __new__(cls, c): + value = 1 << c + self = object.__new__(cls) + self._value_ = value + return self + # + a = ord('a') + z = 1 + # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674) + self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672) + self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674) + # + # + class FlagFromChar(int, Flag, boundary=KEEP): + def __new__(cls, c): + value = 1 << c + self = int.__new__(cls, value) + self._value_ = value + return self + # + a = ord('a') + # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) + + def test_init_exception(self): + class Base: + def __new__(cls, *args): + return object.__new__(cls) + def __init__(self, x): + raise ValueError("I don't like", x) + with self.assertRaises(TypeError): + class MyEnum(Base, enum.Enum): + A = 'a' + def __init__(self, y): + self.y = y + with self.assertRaises(ValueError): + class MyEnum(Base, enum.Enum): + A = 'a' + def __init__(self, y): + self.y = y + def __new__(cls, value): + member = Base.__new__(cls) + member._value_ = Base(value) + return member + class TestOrder(unittest.TestCase): + "test usage of the `_order_` attribute" def test_same_members(self): class Color(Enum): @@ -2116,7 +2990,7 @@ class Color(Enum): verde = green -class TestFlag(unittest.TestCase): +class OldTestFlag(unittest.TestCase): """Tests of the Flags.""" class Perm(Flag): @@ -2132,68 +3006,12 @@ class Open(Flag): class Color(Flag): BLACK = 0 RED = 1 + ROJO = 1 GREEN = 2 BLUE = 4 PURPLE = RED|BLUE - - def test_str(self): - Perm = self.Perm - self.assertEqual(str(Perm.R), 'Perm.R') - self.assertEqual(str(Perm.W), 'Perm.W') - self.assertEqual(str(Perm.X), 'Perm.X') - self.assertEqual(str(Perm.R | Perm.W), 'Perm.R|W') - self.assertEqual(str(Perm.R | Perm.W | Perm.X), 'Perm.R|W|X') - self.assertEqual(str(Perm(0)), 'Perm.0') - self.assertEqual(str(~Perm.R), 'Perm.W|X') - self.assertEqual(str(~Perm.W), 'Perm.R|X') - self.assertEqual(str(~Perm.X), 'Perm.R|W') - self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X') - self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.0') - self.assertEqual(str(Perm(~0)), 'Perm.R|W|X') - - Open = self.Open - self.assertEqual(str(Open.RO), 'Open.RO') - self.assertEqual(str(Open.WO), 'Open.WO') - self.assertEqual(str(Open.AC), 'Open.AC') - self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') - self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') - self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO') - self.assertEqual(str(~Open.WO), 'Open.CE|RW') - self.assertEqual(str(~Open.AC), 'Open.CE') - self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC') - self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW') - - def test_repr(self): - Perm = self.Perm - self.assertEqual(repr(Perm.R), '') - self.assertEqual(repr(Perm.W), '') - self.assertEqual(repr(Perm.X), '') - self.assertEqual(repr(Perm.R | Perm.W), '') - self.assertEqual(repr(Perm.R | Perm.W | Perm.X), '') - self.assertEqual(repr(Perm(0)), '') - self.assertEqual(repr(~Perm.R), '') - self.assertEqual(repr(~Perm.W), '') - self.assertEqual(repr(~Perm.X), '') - self.assertEqual(repr(~(Perm.R | Perm.W)), '') - self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '') - self.assertEqual(repr(Perm(~0)), '') - - Open = self.Open - self.assertEqual(repr(Open.RO), '') - self.assertEqual(repr(Open.WO), '') - self.assertEqual(repr(Open.AC), '') - self.assertEqual(repr(Open.RO | Open.CE), '') - self.assertEqual(repr(Open.WO | Open.CE), '') - self.assertEqual(repr(~Open.RO), '') - self.assertEqual(repr(~Open.WO), '') - self.assertEqual(repr(~Open.AC), '') - self.assertEqual(repr(~(Open.RO | Open.CE)), '') - self.assertEqual(repr(~(Open.WO | Open.CE)), '') - - def test_format(self): - Perm = self.Perm - self.assertEqual(format(Perm.R, ''), 'Perm.R') - self.assertEqual(format(Perm.R | Perm.X, ''), 'Perm.R|X') + WHITE = RED|GREEN|BLUE + BLANCO = RED|GREEN|BLUE def test_or(self): Perm = self.Perm @@ -2238,22 +3056,6 @@ def test_xor(self): self.assertIs(Open.RO ^ Open.CE, Open.CE) self.assertIs(Open.CE ^ Open.CE, Open.RO) - def test_invert(self): - Perm = self.Perm - RW = Perm.R | Perm.W - RX = Perm.R | Perm.X - WX = Perm.W | Perm.X - RWX = Perm.R | Perm.W | Perm.X - values = list(Perm) + [RW, RX, WX, RWX, Perm(0)] - for i in values: - self.assertIs(type(~i), Perm) - self.assertEqual(~~i, i) - for i in Perm: - self.assertIs(~~i, i) - Open = self.Open - self.assertIs(Open.WO & ~Open.WO, Open.RO) - self.assertIs((Open.WO|Open.CE) & ~Open.WO, Open.CE) - def test_bool(self): Perm = self.Perm for f in Perm: @@ -2262,6 +3064,74 @@ def test_bool(self): for f in Open: self.assertEqual(bool(f.value), bool(f)) + def test_boundary(self): + self.assertIs(enum.Flag._boundary_, STRICT) + class Iron(Flag, boundary=CONFORM): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Iron._boundary_, CONFORM) + # + class Water(Flag, boundary=STRICT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Water._boundary_, STRICT) + # + class Space(Flag, boundary=EJECT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Space._boundary_, EJECT) + # + class Bizarre(Flag, boundary=KEEP): + b = 3 + c = 4 + d = 6 + # + self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7) + # + self.assertIs(Iron(7), Iron.ONE|Iron.TWO) + self.assertIs(Iron(~9), Iron.TWO) + # + self.assertEqual(Space(7), 7) + self.assertTrue(type(Space(7)) is int) + # + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertIs(Bizarre(3), Bizarre.b) + self.assertIs(Bizarre(6), Bizarre.d) + # + class SkipFlag(enum.Flag): + A = 1 + B = 2 + C = 4 | B + # + self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C)) + self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42) + # + class SkipIntFlag(enum.IntFlag): + A = 1 + B = 2 + C = 4 | B + # + self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C)) + self.assertEqual(SkipIntFlag(42).value, 42) + # + class MethodHint(Flag): + HiddenText = 0x10 + DigitsOnly = 0x01 + LettersOnly = 0x02 + OnlyMask = 0x0f + # + self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask') + + + def test_iter(self): + Color = self.Color + Open = self.Open + self.assertEqual(list(Color), [Color.RED, Color.GREEN, Color.BLUE]) + self.assertEqual(list(Open), [Open.WO, Open.RW, Open.CE]) + def test_programatic_function_string(self): Perm = Flag('Perm', 'R W X') lst = list(Perm) @@ -2340,22 +3210,81 @@ def test_programatic_function_from_dict(self): def test_pickle(self): if isinstance(FlagStooges, Exception): raise FlagStooges - test_pickle_dump_load(self.assertIs, FlagStooges.CURLY|FlagStooges.MOE) + test_pickle_dump_load(self.assertIs, FlagStooges.CURLY) + test_pickle_dump_load(self.assertEqual, + FlagStooges.CURLY|FlagStooges.MOE) + test_pickle_dump_load(self.assertEqual, + FlagStooges.CURLY&~FlagStooges.CURLY) test_pickle_dump_load(self.assertIs, FlagStooges) - - def test_contains(self): + test_pickle_dump_load(self.assertEqual, FlagStooges.BIG) + test_pickle_dump_load(self.assertEqual, + FlagStooges.CURLY|FlagStooges.BIG) + + test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.CURLY) + test_pickle_dump_load(self.assertEqual, + FlagStoogesWithZero.CURLY|FlagStoogesWithZero.MOE) + test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.NOFLAG) + test_pickle_dump_load(self.assertEqual, FlagStoogesWithZero.BIG) + test_pickle_dump_load(self.assertEqual, + FlagStoogesWithZero.CURLY|FlagStoogesWithZero.BIG) + + test_pickle_dump_load(self.assertIs, IntFlagStooges.CURLY) + test_pickle_dump_load(self.assertEqual, + IntFlagStooges.CURLY|IntFlagStooges.MOE) + test_pickle_dump_load(self.assertEqual, + IntFlagStooges.CURLY|IntFlagStooges.MOE|0x30) + test_pickle_dump_load(self.assertEqual, IntFlagStooges(0)) + test_pickle_dump_load(self.assertEqual, IntFlagStooges(0x30)) + test_pickle_dump_load(self.assertIs, IntFlagStooges) + test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG) + test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG|1) + test_pickle_dump_load(self.assertEqual, + IntFlagStooges.CURLY|IntFlagStooges.BIG) + + test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.CURLY) + test_pickle_dump_load(self.assertEqual, + IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.MOE) + test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.NOFLAG) + test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG) + test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG|1) + test_pickle_dump_load(self.assertEqual, + IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.BIG) + + @unittest.skipIf( + python_version >= (3, 12), + '__contains__ now returns True/False for all inputs', + ) + def test_contains_er(self): Open = self.Open Color = self.Color self.assertFalse(Color.BLACK in Open) self.assertFalse(Open.RO in Color) with self.assertRaises(TypeError): - 'BLACK' in Color + with self.assertWarns(DeprecationWarning): + 'BLACK' in Color with self.assertRaises(TypeError): - 'RO' in Open + with self.assertWarns(DeprecationWarning): + 'RO' in Open with self.assertRaises(TypeError): - 1 in Color + with self.assertWarns(DeprecationWarning): + 1 in Color with self.assertRaises(TypeError): - 1 in Open + with self.assertWarns(DeprecationWarning): + 1 in Open + + @unittest.skipIf( + python_version < (3, 12), + '__contains__ only works with enum memmbers before 3.12', + ) + def test_contains_tf(self): + Open = self.Open + Color = self.Color + self.assertFalse(Color.BLACK in Open) + self.assertFalse(Open.RO in Color) + self.assertFalse('BLACK' in Color) + self.assertFalse('RO' in Open) + self.assertTrue(1 in Color) + self.assertTrue(1 in Open) def test_member_contains(self): Perm = self.Perm @@ -2377,6 +3306,48 @@ def test_member_contains(self): self.assertFalse(W in RX) self.assertFalse(X in RW) + def test_member_iter(self): + Color = self.Color + self.assertEqual(list(Color.BLACK), []) + self.assertEqual(list(Color.PURPLE), [Color.RED, Color.BLUE]) + self.assertEqual(list(Color.BLUE), [Color.BLUE]) + self.assertEqual(list(Color.GREEN), [Color.GREEN]) + self.assertEqual(list(Color.WHITE), [Color.RED, Color.GREEN, Color.BLUE]) + self.assertEqual(list(Color.WHITE), [Color.RED, Color.GREEN, Color.BLUE]) + + def test_member_length(self): + self.assertEqual(self.Color.__len__(self.Color.BLACK), 0) + self.assertEqual(self.Color.__len__(self.Color.GREEN), 1) + self.assertEqual(self.Color.__len__(self.Color.PURPLE), 2) + self.assertEqual(self.Color.__len__(self.Color.BLANCO), 3) + + def test_number_reset_and_order_cleanup(self): + class Confused(Flag): + _order_ = 'ONE TWO FOUR DOS EIGHT SIXTEEN' + ONE = auto() + TWO = auto() + FOUR = auto() + DOS = 2 + EIGHT = auto() + SIXTEEN = auto() + self.assertEqual( + list(Confused), + [Confused.ONE, Confused.TWO, Confused.FOUR, Confused.EIGHT, Confused.SIXTEEN]) + self.assertIs(Confused.TWO, Confused.DOS) + self.assertEqual(Confused.DOS._value_, 2) + self.assertEqual(Confused.EIGHT._value_, 8) + self.assertEqual(Confused.SIXTEEN._value_, 16) + + def test_aliases(self): + Color = self.Color + self.assertEqual(Color(1).name, 'RED') + self.assertEqual(Color['ROJO'].name, 'RED') + self.assertEqual(Color(7).name, 'WHITE') + self.assertEqual(Color['BLANCO'].name, 'WHITE') + self.assertIs(Color.BLANCO, Color.WHITE) + Open = self.Open + self.assertIs(Open['AC'], Open.AC) + def test_auto_number(self): class Color(Flag): red = auto() @@ -2389,24 +3360,11 @@ class Color(Flag): self.assertEqual(Color.green.value, 4) def test_auto_number_garbage(self): - with self.assertRaisesRegex(TypeError, 'Invalid Flag value: .not an int.'): + with self.assertRaisesRegex(TypeError, 'invalid flag value .not an int.'): class Color(Flag): red = 'not an int' blue = auto() - def test_cascading_failure(self): - class Bizarre(Flag): - c = 3 - d = 4 - f = 6 - # Bizarre.c | Bizarre.d - self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5) - self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5) - self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2) - self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2) - self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1) - self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1) - def test_duplicate_auto(self): class Dupes(Enum): first = primero = auto() @@ -2414,13 +3372,6 @@ class Dupes(Enum): third = auto() self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) - def test_bizarre(self): - class Bizarre(Flag): - b = 3 - c = 4 - d = 6 - self.assertEqual(repr(Bizarre(7)), '') - def test_multiple_mixin(self): class AllMixin: @classproperty @@ -2449,6 +3400,7 @@ class Color(AllMixin, StrMixin, Flag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) @@ -2458,14 +3410,15 @@ class Color(StrMixin, AllMixin, Flag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') - @unittest.skip("TODO: RUSTPYTHON, inconsistent test result on Windows due to threading") @threading_helper.reap_threads + @threading_helper.requires_working_threading() def test_unique_composite(self): # override __eq__ to be identity only class TestFlag(Flag): @@ -2495,22 +3448,59 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.start_threads(threads): - pass + with threading_helper.wait_threads_exit(): + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, 'at least one thread failed while creating composite members') self.assertEqual(256, len(seen), 'too many composite members created') + def test_init_subclass(self): + class MyEnum(Flag): + def __init_subclass__(cls, **kwds): + super().__init_subclass__(**kwds) + self.assertFalse(cls.__dict__.get('_test', False)) + cls._test1 = 'MyEnum' + # + class TheirEnum(MyEnum): + def __init_subclass__(cls, **kwds): + super(TheirEnum, cls).__init_subclass__(**kwds) + cls._test2 = 'TheirEnum' + class WhoseEnum(TheirEnum): + def __init_subclass__(cls, **kwds): + pass + class NoEnum(WhoseEnum): + ONE = 1 + self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum') + self.assertFalse(NoEnum.__dict__.get('_test1', False)) + self.assertFalse(NoEnum.__dict__.get('_test2', False)) + # + class OurEnum(MyEnum): + def __init_subclass__(cls, **kwds): + cls._test2 = 'OurEnum' + class WhereEnum(OurEnum): + def __init_subclass__(cls, **kwds): + pass + class NeverEnum(WhereEnum): + ONE = 1 + self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum') + self.assertFalse(WhereEnum.__dict__.get('_test1', False)) + self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum') + self.assertFalse(NeverEnum.__dict__.get('_test1', False)) + self.assertFalse(NeverEnum.__dict__.get('_test2', False)) + -class TestIntFlag(unittest.TestCase): +class OldTestIntFlag(unittest.TestCase): """Tests of the IntFlags.""" class Perm(IntFlag): - X = 1 << 0 - W = 1 << 1 R = 1 << 2 + W = 1 << 1 + X = 1 << 0 class Open(IntFlag): RO = 0 @@ -2522,9 +3512,17 @@ class Open(IntFlag): class Color(IntFlag): BLACK = 0 RED = 1 + ROJO = 1 GREEN = 2 BLUE = 4 PURPLE = RED|BLUE + WHITE = RED|GREEN|BLUE + BLANCO = RED|GREEN|BLUE + + class Skip(IntFlag): + FIRST = 1 + SECOND = 2 + EIGHTH = 8 def test_type(self): Perm = self.Perm @@ -2541,77 +3539,54 @@ def test_type(self): self.assertTrue(isinstance(Open.WO | Open.RW, Open)) self.assertEqual(Open.WO | Open.RW, 3) + def test_global_repr_keep(self): + self.assertEqual( + repr(HeadlightsK(0)), + '%s.OFF_K' % SHORT_MODULE, + ) + self.assertEqual( + repr(HeadlightsK(2**0 + 2**2 + 2**3)), + '%(m)s.LOW_BEAM_K|%(m)s.FOG_K|8' % {'m': SHORT_MODULE}, + ) + self.assertEqual( + repr(HeadlightsK(2**3)), + '%(m)s.HeadlightsK(8)' % {'m': SHORT_MODULE}, + ) - def test_str(self): - Perm = self.Perm - self.assertEqual(str(Perm.R), 'Perm.R') - self.assertEqual(str(Perm.W), 'Perm.W') - self.assertEqual(str(Perm.X), 'Perm.X') - self.assertEqual(str(Perm.R | Perm.W), 'Perm.R|W') - self.assertEqual(str(Perm.R | Perm.W | Perm.X), 'Perm.R|W|X') - self.assertEqual(str(Perm.R | 8), 'Perm.8|R') - self.assertEqual(str(Perm(0)), 'Perm.0') - self.assertEqual(str(Perm(8)), 'Perm.8') - self.assertEqual(str(~Perm.R), 'Perm.W|X') - self.assertEqual(str(~Perm.W), 'Perm.R|X') - self.assertEqual(str(~Perm.X), 'Perm.R|W') - self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X') - self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.-8') - self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X') - self.assertEqual(str(Perm(~0)), 'Perm.R|W|X') - self.assertEqual(str(Perm(~8)), 'Perm.R|W|X') - - Open = self.Open - self.assertEqual(str(Open.RO), 'Open.RO') - self.assertEqual(str(Open.WO), 'Open.WO') - self.assertEqual(str(Open.AC), 'Open.AC') - self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') - self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') - self.assertEqual(str(Open(4)), 'Open.4') - self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO') - self.assertEqual(str(~Open.WO), 'Open.CE|RW') - self.assertEqual(str(~Open.AC), 'Open.CE') - self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|RW|WO') - self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW') - self.assertEqual(str(Open(~4)), 'Open.CE|AC|RW|WO') + def test_global_repr_conform1(self): + self.assertEqual( + repr(HeadlightsC(0)), + '%s.OFF_C' % SHORT_MODULE, + ) + self.assertEqual( + repr(HeadlightsC(2**0 + 2**2 + 2**3)), + '%(m)s.LOW_BEAM_C|%(m)s.FOG_C' % {'m': SHORT_MODULE}, + ) + self.assertEqual( + repr(HeadlightsC(2**3)), + '%(m)s.OFF_C' % {'m': SHORT_MODULE}, + ) - def test_repr(self): - Perm = self.Perm - self.assertEqual(repr(Perm.R), '') - self.assertEqual(repr(Perm.W), '') - self.assertEqual(repr(Perm.X), '') - self.assertEqual(repr(Perm.R | Perm.W), '') - self.assertEqual(repr(Perm.R | Perm.W | Perm.X), '') - self.assertEqual(repr(Perm.R | 8), '') - self.assertEqual(repr(Perm(0)), '') - self.assertEqual(repr(Perm(8)), '') - self.assertEqual(repr(~Perm.R), '') - self.assertEqual(repr(~Perm.W), '') - self.assertEqual(repr(~Perm.X), '') - self.assertEqual(repr(~(Perm.R | Perm.W)), '') - self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '') - self.assertEqual(repr(~(Perm.R | 8)), '') - self.assertEqual(repr(Perm(~0)), '') - self.assertEqual(repr(Perm(~8)), '') + def test_global_enum_str(self): + self.assertEqual(str(NoName.ONE & NoName.TWO), 'NoName(0)') + self.assertEqual(str(NoName(0)), 'NoName(0)') - Open = self.Open - self.assertEqual(repr(Open.RO), '') - self.assertEqual(repr(Open.WO), '') - self.assertEqual(repr(Open.AC), '') - self.assertEqual(repr(Open.RO | Open.CE), '') - self.assertEqual(repr(Open.WO | Open.CE), '') - self.assertEqual(repr(Open(4)), '') - self.assertEqual(repr(~Open.RO), '') - self.assertEqual(repr(~Open.WO), '') - self.assertEqual(repr(~Open.AC), '') - self.assertEqual(repr(~(Open.RO | Open.CE)), '') - self.assertEqual(repr(~(Open.WO | Open.CE)), '') - self.assertEqual(repr(Open(~4)), '') + # TODO: RUSTPYTHON, format(NewPerm.R) does not use __str__ + @unittest.expectedFailure def test_format(self): Perm = self.Perm self.assertEqual(format(Perm.R, ''), '4') self.assertEqual(format(Perm.R | Perm.X, ''), '5') + # + class NewPerm(IntFlag): + R = 1 << 2 + W = 1 << 1 + X = 1 << 0 + def __str__(self): + return self._name_ + self.assertEqual(format(NewPerm.R, ''), 'R') + self.assertEqual(format(NewPerm.R | Perm.X, ''), 'R|X') def test_or(self): Perm = self.Perm @@ -2689,8 +3664,7 @@ def test_invert(self): RWX = Perm.R | Perm.W | Perm.X values = list(Perm) + [RW, RX, WX, RWX, Perm(0)] for i in values: - self.assertEqual(~i, ~i.value) - self.assertEqual((~i).value, ~i.value) + self.assertEqual(~i, (~i).value) self.assertIs(type(~i), Perm) self.assertEqual(~~i, i) for i in Perm: @@ -2699,6 +3673,58 @@ def test_invert(self): self.assertIs(Open.WO & ~Open.WO, Open.RO) self.assertIs((Open.WO|Open.CE) & ~Open.WO, Open.CE) + def test_boundary(self): + self.assertIs(enum.IntFlag._boundary_, KEEP) + class Simple(IntFlag, boundary=KEEP): + SINGLE = 1 + # + class Iron(IntFlag, boundary=STRICT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Iron._boundary_, STRICT) + # + class Water(IntFlag, boundary=CONFORM): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Water._boundary_, CONFORM) + # + class Space(IntFlag, boundary=EJECT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Space._boundary_, EJECT) + # + class Bizarre(IntFlag, boundary=KEEP): + b = 3 + c = 4 + d = 6 + # + self.assertRaisesRegex(ValueError, 'invalid value 5', Iron, 5) + # + self.assertIs(Water(7), Water.ONE|Water.TWO) + self.assertIs(Water(~9), Water.TWO) + # + self.assertEqual(Space(7), 7) + self.assertTrue(type(Space(7)) is int) + # + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertIs(Bizarre(3), Bizarre.b) + self.assertIs(Bizarre(6), Bizarre.d) + # + simple = Simple.SINGLE | Iron.TWO + self.assertEqual(simple, 3) + self.assertIsInstance(simple, Simple) + self.assertEqual(repr(simple), ': 3>') + self.assertEqual(str(simple), '3') + + def test_iter(self): + Color = self.Color + Open = self.Open + self.assertEqual(list(Color), [Color.RED, Color.GREEN, Color.BLUE]) + self.assertEqual(list(Open), [Open.WO, Open.RW, Open.CE]) + def test_programatic_function_string(self): Perm = IntFlag('Perm', 'R W X') lst = list(Perm) @@ -2800,7 +3826,11 @@ def test_programatic_function_from_empty_tuple(self): self.assertEqual(len(lst), len(Thing)) self.assertEqual(len(Thing), 0, Thing) - def test_contains(self): + @unittest.skipIf( + python_version >= (3, 12), + '__contains__ now returns True/False for all inputs', + ) + def test_contains_er(self): Open = self.Open Color = self.Color self.assertTrue(Color.GREEN in Color) @@ -2808,13 +3838,33 @@ def test_contains(self): self.assertFalse(Color.GREEN in Open) self.assertFalse(Open.RW in Color) with self.assertRaises(TypeError): - 'GREEN' in Color + with self.assertWarns(DeprecationWarning): + 'GREEN' in Color with self.assertRaises(TypeError): - 'RW' in Open + with self.assertWarns(DeprecationWarning): + 'RW' in Open with self.assertRaises(TypeError): - 2 in Color + with self.assertWarns(DeprecationWarning): + 2 in Color with self.assertRaises(TypeError): - 2 in Open + with self.assertWarns(DeprecationWarning): + 2 in Open + + @unittest.skipIf( + python_version < (3, 12), + '__contains__ only works with enum memmbers before 3.12', + ) + def test_contains_tf(self): + Open = self.Open + Color = self.Color + self.assertTrue(Color.GREEN in Color) + self.assertTrue(Open.RW in Open) + self.assertTrue(Color.GREEN in Open) + self.assertTrue(Open.RW in Color) + self.assertFalse('GREEN' in Color) + self.assertFalse('RW' in Open) + self.assertTrue(2 in Color) + self.assertTrue(2 in Open) def test_member_contains(self): Perm = self.Perm @@ -2838,6 +3888,30 @@ def test_member_contains(self): with self.assertRaises(TypeError): self.assertFalse('test' in RW) + def test_member_iter(self): + Color = self.Color + self.assertEqual(list(Color.BLACK), []) + self.assertEqual(list(Color.PURPLE), [Color.RED, Color.BLUE]) + self.assertEqual(list(Color.BLUE), [Color.BLUE]) + self.assertEqual(list(Color.GREEN), [Color.GREEN]) + self.assertEqual(list(Color.WHITE), [Color.RED, Color.GREEN, Color.BLUE]) + + def test_member_length(self): + self.assertEqual(self.Color.__len__(self.Color.BLACK), 0) + self.assertEqual(self.Color.__len__(self.Color.GREEN), 1) + self.assertEqual(self.Color.__len__(self.Color.PURPLE), 2) + self.assertEqual(self.Color.__len__(self.Color.BLANCO), 3) + + def test_aliases(self): + Color = self.Color + self.assertEqual(Color(1).name, 'RED') + self.assertEqual(Color['ROJO'].name, 'RED') + self.assertEqual(Color(7).name, 'WHITE') + self.assertEqual(Color['BLANCO'].name, 'WHITE') + self.assertIs(Color.BLANCO, Color.WHITE) + Open = self.Open + self.assertIs(Open['AC'], Open.AC) + def test_bool(self): Perm = self.Perm for f in Perm: @@ -2846,6 +3920,7 @@ def test_bool(self): for f in Open: self.assertEqual(bool(f.value), bool(f)) + def test_multiple_mixin(self): class AllMixin: @classproperty @@ -2869,11 +3944,12 @@ class Color(AllMixin, IntFlag): self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) self.assertEqual(Color.ALL.value, 7) - self.assertEqual(str(Color.BLUE), 'Color.BLUE') + self.assertEqual(str(Color.BLUE), '4') class Color(AllMixin, StrMixin, IntFlag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) @@ -2883,14 +3959,15 @@ class Color(StrMixin, AllMixin, IntFlag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') - @unittest.skip("TODO: RUSTPYTHON, inconsistent test result due to threading") @threading_helper.reap_threads + @threading_helper.requires_working_threading() def test_unique_composite(self): # override __eq__ to be identity only class TestFlag(IntFlag): @@ -2920,8 +3997,9 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.start_threads(threads): - pass + with threading_helper.wait_threads_exit(): + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, @@ -2954,6 +4032,7 @@ class Clean(Enum): one = 1 two = 'dos' tres = 4.0 + # @unique class Cleaner(IntEnum): single = 1 @@ -2979,27 +4058,400 @@ class Dirtier(IntEnum): turkey = 3 def test_unique_with_name(self): - @unique + @verify(UNIQUE) class Silly(Enum): one = 1 two = 'dos' name = 3 - @unique + # + @verify(UNIQUE) + class Sillier(IntEnum): + single = 1 + name = 2 + triple = 3 + value = 4 + +class TestVerify(unittest.TestCase): + + def test_continuous(self): + @verify(CONTINUOUS) + class Auto(Enum): + FIRST = auto() + SECOND = auto() + THIRD = auto() + FORTH = auto() + # + @verify(CONTINUOUS) + class Manual(Enum): + FIRST = 3 + SECOND = 4 + THIRD = 5 + FORTH = 6 + # + with self.assertRaisesRegex(ValueError, 'invalid enum .Missing.: missing values 5, 6, 7, 8, 9, 10, 12'): + @verify(CONTINUOUS) + class Missing(Enum): + FIRST = 3 + SECOND = 4 + THIRD = 11 + FORTH = 13 + # + with self.assertRaisesRegex(ValueError, 'invalid flag .Incomplete.: missing values 32'): + @verify(CONTINUOUS) + class Incomplete(Flag): + FIRST = 4 + SECOND = 8 + THIRD = 16 + FORTH = 64 + # + with self.assertRaisesRegex(ValueError, 'invalid flag .StillIncomplete.: missing values 16'): + @verify(CONTINUOUS) + class StillIncomplete(Flag): + FIRST = 4 + SECOND = 8 + THIRD = 11 + FORTH = 32 + + + def test_composite(self): + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertEqual(Bizarre.b.value, 3) + self.assertEqual(Bizarre.c.value, 4) + self.assertEqual(Bizarre.d.value, 6) + with self.assertRaisesRegex( + ValueError, + "invalid Flag 'Bizarre': aliases b and d are missing combined values of 0x3 .use enum.show_flag_values.value. for details.", + ): + @verify(NAMED_FLAGS) + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + # + self.assertEqual(enum.show_flag_values(3), [1, 2]) + class Bizarre(IntFlag): + b = 3 + c = 4 + d = 6 + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertEqual(Bizarre.b.value, 3) + self.assertEqual(Bizarre.c.value, 4) + self.assertEqual(Bizarre.d.value, 6) + with self.assertRaisesRegex( + ValueError, + "invalid Flag 'Bizarre': alias d is missing value 0x2 .use enum.show_flag_values.value. for details.", + ): + @verify(NAMED_FLAGS) + class Bizarre(IntFlag): + c = 4 + d = 6 + self.assertEqual(enum.show_flag_values(2), [2]) + + def test_unique_clean(self): + @verify(UNIQUE) + class Clean(Enum): + one = 1 + two = 'dos' + tres = 4.0 + # + @verify(UNIQUE) + class Cleaner(IntEnum): + single = 1 + double = 2 + triple = 3 + + def test_unique_dirty(self): + with self.assertRaisesRegex(ValueError, 'tres.*one'): + @verify(UNIQUE) + class Dirty(Enum): + one = 1 + two = 'dos' + tres = 1 + with self.assertRaisesRegex( + ValueError, + 'double.*single.*turkey.*triple', + ): + @verify(UNIQUE) + class Dirtier(IntEnum): + single = 1 + double = 1 + triple = 3 + turkey = 3 + + def test_unique_with_name(self): + @verify(UNIQUE) + class Silly(Enum): + one = 1 + two = 'dos' + name = 3 + # + @verify(UNIQUE) class Sillier(IntEnum): single = 1 name = 2 triple = 3 value = 4 + def test_negative_alias(self): + @verify(NAMED_FLAGS) + class Color(Flag): + RED = 1 + GREEN = 2 + BLUE = 4 + WHITE = -1 + # no error means success + + +class TestInternals(unittest.TestCase): + + sunder_names = '_bad_', '_good_', '_what_ho_' + dunder_names = '__mal__', '__bien__', '__que_que__' + private_names = '_MyEnum__private', '_MyEnum__still_private' + private_and_sunder_names = '_MyEnum__private_', '_MyEnum__also_private_' + random_names = 'okay', '_semi_private', '_weird__', '_MyEnum__' + + def test_sunder(self): + for name in self.sunder_names + self.private_and_sunder_names: + self.assertTrue(enum._is_sunder(name), '%r is a not sunder name?' % name) + for name in self.dunder_names + self.private_names + self.random_names: + self.assertFalse(enum._is_sunder(name), '%r is a sunder name?' % name) + + def test_dunder(self): + for name in self.dunder_names: + self.assertTrue(enum._is_dunder(name), '%r is a not dunder name?' % name) + for name in self.sunder_names + self.private_names + self.private_and_sunder_names + self.random_names: + self.assertFalse(enum._is_dunder(name), '%r is a dunder name?' % name) + + def test_is_private(self): + for name in self.private_names + self.private_and_sunder_names: + self.assertTrue(enum._is_private('MyEnum', name), '%r is a not private name?') + for name in self.sunder_names + self.dunder_names + self.random_names: + self.assertFalse(enum._is_private('MyEnum', name), '%r is a private name?') + + def test_auto_number(self): + class Color(Enum): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 1) + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 3) + + def test_auto_name(self): + class Color(Enum): + def _generate_next_value_(name, start, count, last): + return name + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 'blue') + self.assertEqual(Color.green.value, 'green') + + def test_auto_name_inherit(self): + class AutoNameEnum(Enum): + def _generate_next_value_(name, start, count, last): + return name + class Color(AutoNameEnum): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 'blue') + self.assertEqual(Color.green.value, 'green') + + @unittest.skipIf( + python_version >= (3, 13), + 'mixed types with auto() no longer supported', + ) + def test_auto_garbage_ok(self): + with self.assertWarnsRegex(DeprecationWarning, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = auto() + self.assertEqual(Color.blue.value, 1) + + @unittest.skipIf( + python_version >= (3, 13), + 'mixed types with auto() no longer supported', + ) + def test_auto_garbage_corrected_ok(self): + with self.assertWarnsRegex(DeprecationWarning, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = 2 + green = auto() + yellow = auto() + + self.assertEqual(list(Color), + [Color.red, Color.blue, Color.green, Color.yellow]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 3) + self.assertEqual(Color.yellow.value, 4) + + @unittest.skipIf( + python_version < (3, 13), + 'mixed types with auto() will raise in 3.13', + ) + def test_auto_garbage_fail(self): + with self.assertRaisesRegex(TypeError, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = auto() + + @unittest.skipIf( + python_version < (3, 13), + 'mixed types with auto() will raise in 3.13', + ) + def test_auto_garbage_corrected_fail(self): + with self.assertRaisesRegex(TypeError, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = 2 + green = auto() + + def test_auto_order(self): + with self.assertRaises(TypeError): + class Color(Enum): + red = auto() + green = auto() + blue = auto() + def _generate_next_value_(name, start, count, last): + return name + + def test_auto_order_wierd(self): + weird_auto = auto() + weird_auto.value = 'pathological case' + class Color(Enum): + red = weird_auto + def _generate_next_value_(name, start, count, last): + return name + blue = auto() + self.assertEqual(list(Color), [Color.red, Color.blue]) + self.assertEqual(Color.red.value, 'pathological case') + self.assertEqual(Color.blue.value, 'blue') + + @unittest.skipIf( + python_version < (3, 13), + 'auto() will return highest value + 1 in 3.13', + ) + def test_auto_with_aliases(self): + class Color(Enum): + red = auto() + blue = auto() + oxford = blue + crimson = red + green = auto() + self.assertIs(Color.crimson, Color.red) + self.assertIs(Color.oxford, Color.blue) + self.assertIsNot(Color.green, Color.red) + self.assertIsNot(Color.green, Color.blue) + + def test_duplicate_auto(self): + class Dupes(Enum): + first = primero = auto() + second = auto() + third = auto() + self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) + def test_multiple_auto_on_line(self): + class Huh(Enum): + ONE = auto() + TWO = auto(), auto() + THREE = auto(), auto(), auto() + self.assertEqual(Huh.ONE.value, 1) + self.assertEqual(Huh.TWO.value, (2, 3)) + self.assertEqual(Huh.THREE.value, (4, 5, 6)) + # + class Hah(Enum): + def __new__(cls, value, abbr=None): + member = object.__new__(cls) + member._value_ = value + member.abbr = abbr or value[:3].lower() + return member + def _generate_next_value_(name, start, count, last): + return name + # + MONDAY = auto() + TUESDAY = auto() + WEDNESDAY = auto(), 'WED' + THURSDAY = auto(), 'Thu' + FRIDAY = auto() + self.assertEqual(Hah.MONDAY.value, 'MONDAY') + self.assertEqual(Hah.MONDAY.abbr, 'mon') + self.assertEqual(Hah.TUESDAY.value, 'TUESDAY') + self.assertEqual(Hah.TUESDAY.abbr, 'tue') + self.assertEqual(Hah.WEDNESDAY.value, 'WEDNESDAY') + self.assertEqual(Hah.WEDNESDAY.abbr, 'WED') + self.assertEqual(Hah.THURSDAY.value, 'THURSDAY') + self.assertEqual(Hah.THURSDAY.abbr, 'Thu') + self.assertEqual(Hah.FRIDAY.value, 'FRIDAY') + self.assertEqual(Hah.FRIDAY.abbr, 'fri') + # + class Huh(Enum): + def _generate_next_value_(name, start, count, last): + return count+1 + ONE = auto() + TWO = auto(), auto() + THREE = auto(), auto(), auto() + self.assertEqual(Huh.ONE.value, 1) + self.assertEqual(Huh.TWO.value, (2, 2)) + self.assertEqual(Huh.THREE.value, (3, 3, 3)) + +class TestEnumTypeSubclassing(unittest.TestCase): + pass expected_help_output_with_docs = """\ Help on class Color in module %s: class Color(enum.Enum) - | Color(value, names=None, *, module=None, qualname=None, type=None, start=1) + | Create a collection of name/value pairs. + |\x20\x20 + | Example enumeration: + |\x20\x20 + | >>> class Color(Enum): + | ... RED = 1 + | ... BLUE = 2 + | ... GREEN = 3 + |\x20\x20 + | Access them by: |\x20\x20 - | An enumeration. + | - attribute access:: + |\x20\x20 + | >>> Color.RED + | + |\x20\x20 + | - value lookup: + |\x20\x20 + | >>> Color(1) + | + |\x20\x20 + | - name lookup: + |\x20\x20 + | >>> Color['RED'] + | + |\x20\x20 + | Enumerations can be iterated over, and know how many members they have: + |\x20\x20 + | >>> len(Color) + | 3 + |\x20\x20 + | >>> list(Color) + | [, , ] + |\x20\x20 + | Methods can be added to enumerations, and members can have their own + | attributes -- see the documentation for details. |\x20\x20 | Method resolution order: | Color @@ -3008,11 +4460,11 @@ class Color(enum.Enum) |\x20\x20 | Data and other attributes defined here: |\x20\x20 - | blue = + | CYAN = |\x20\x20 - | green = + | MAGENTA = |\x20\x20 - | red = + | YELLOW = |\x20\x20 | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: @@ -3024,13 +4476,28 @@ class Color(enum.Enum) | The value of the Enum member. |\x20\x20 | ---------------------------------------------------------------------- - | Readonly properties inherited from enum.EnumMeta: + | Methods inherited from enum.EnumType: |\x20\x20 - | __members__ - | Returns a mapping of member name->value. + | __contains__(member) from enum.EnumType + | Return True if member is a member of this enum + | raises TypeError if member is not an enum member |\x20\x20\x20\x20\x20\x20 - | This mapping lists all enum members, including aliases. Note that this - | is a read-only view of the internal mapping.""" + | note: in 3.12 TypeError will no longer be raised, and True will also be + | returned if member is the value of a member in this enum + |\x20\x20 + | __getitem__(name) from enum.EnumType + | Return the member matching `name`. + |\x20\x20 + | __iter__() from enum.EnumType + | Return members in definition order. + |\x20\x20 + | __len__() from enum.EnumType + | Return the number of members (no aliases) + |\x20\x20 + | ---------------------------------------------------------------------- + | Data descriptors inherited from enum.EnumType: + |\x20\x20 + | __members__""" expected_help_output_without_docs = """\ Help on class Color in module %s: @@ -3045,11 +4512,11 @@ class Color(enum.Enum) |\x20\x20 | Data and other attributes defined here: |\x20\x20 - | blue = + | YELLOW = |\x20\x20 - | green = + | MAGENTA = |\x20\x20 - | red = + | CYAN = |\x20\x20 | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: @@ -3059,7 +4526,7 @@ class Color(enum.Enum) | value |\x20\x20 | ---------------------------------------------------------------------- - | Data descriptors inherited from enum.EnumMeta: + | Data descriptors inherited from enum.EnumType: |\x20\x20 | __members__""" @@ -3068,12 +4535,10 @@ class TestStdLib(unittest.TestCase): maxDiff = None class Color(Enum): - red = 1 - green = 2 - blue = 3 + CYAN = 1 + MAGENTA = 2 + YELLOW = 3 - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pydoc(self): # indirectly test __objclass__ if StrEnum.__doc__ is None: @@ -3084,26 +4549,34 @@ def test_pydoc(self): helper = pydoc.Helper(output=output) helper(self.Color) result = output.getvalue().strip() - self.assertEqual(result, expected_text) + self.assertEqual(result, expected_text, result) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_inspect_getmembers(self): values = dict(( - ('__class__', EnumMeta), - ('__doc__', 'An enumeration.'), + ('__class__', EnumType), + ('__doc__', '...'), ('__members__', self.Color.__members__), ('__module__', __name__), - ('blue', self.Color.blue), - ('green', self.Color.green), + ('YELLOW', self.Color.YELLOW), + ('MAGENTA', self.Color.MAGENTA), + ('CYAN', self.Color.CYAN), ('name', Enum.__dict__['name']), - ('red', self.Color.red), ('value', Enum.__dict__['value']), + ('__len__', self.Color.__len__), + ('__contains__', self.Color.__contains__), + ('__name__', 'Color'), + ('__getitem__', self.Color.__getitem__), + ('__qualname__', 'TestStdLib.Color'), + ('__init_subclass__', getattr(self.Color, '__init_subclass__')), + ('__iter__', self.Color.__iter__), )) result = dict(inspect.getmembers(self.Color)) - self.assertEqual(values.keys(), result.keys()) + self.assertEqual(set(values.keys()), set(result.keys())) failed = False for k in values.keys(): + if k == '__doc__': + # __doc__ is huge, not comparing + continue if result[k] != values[k]: print() print('\n%s\n key: %s\n result: %s\nexpected: %s\n%s\n' % @@ -3112,46 +4585,145 @@ def test_inspect_getmembers(self): if failed: self.fail("result does not equal expected, see print above") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_inspect_classify_class_attrs(self): # indirectly test __objclass__ from inspect import Attribute values = [ Attribute(name='__class__', kind='data', - defining_class=object, object=EnumMeta), + defining_class=object, object=EnumType), + Attribute(name='__contains__', kind='method', + defining_class=EnumType, object=self.Color.__contains__), Attribute(name='__doc__', kind='data', - defining_class=self.Color, object='An enumeration.'), + defining_class=self.Color, object='...'), + Attribute(name='__getitem__', kind='method', + defining_class=EnumType, object=self.Color.__getitem__), + Attribute(name='__iter__', kind='method', + defining_class=EnumType, object=self.Color.__iter__), + Attribute(name='__init_subclass__', kind='class method', + defining_class=object, object=getattr(self.Color, '__init_subclass__')), + Attribute(name='__len__', kind='method', + defining_class=EnumType, object=self.Color.__len__), Attribute(name='__members__', kind='property', - defining_class=EnumMeta, object=EnumMeta.__members__), + defining_class=EnumType, object=EnumType.__members__), Attribute(name='__module__', kind='data', defining_class=self.Color, object=__name__), - Attribute(name='blue', kind='data', - defining_class=self.Color, object=self.Color.blue), - Attribute(name='green', kind='data', - defining_class=self.Color, object=self.Color.green), - Attribute(name='red', kind='data', - defining_class=self.Color, object=self.Color.red), + Attribute(name='__name__', kind='data', + defining_class=self.Color, object='Color'), + Attribute(name='__qualname__', kind='data', + defining_class=self.Color, object='TestStdLib.Color'), + Attribute(name='YELLOW', kind='data', + defining_class=self.Color, object=self.Color.YELLOW), + Attribute(name='MAGENTA', kind='data', + defining_class=self.Color, object=self.Color.MAGENTA), + Attribute(name='CYAN', kind='data', + defining_class=self.Color, object=self.Color.CYAN), Attribute(name='name', kind='data', defining_class=Enum, object=Enum.__dict__['name']), Attribute(name='value', kind='data', defining_class=Enum, object=Enum.__dict__['value']), ] + for v in values: + try: + v.name + except AttributeError: + print(v) values.sort(key=lambda item: item.name) result = list(inspect.classify_class_attrs(self.Color)) result.sort(key=lambda item: item.name) + self.assertEqual( + len(values), len(result), + "%s != %s" % ([a.name for a in values], [a.name for a in result]) + ) failed = False for v, r in zip(values, result): - if r != v: + if r.name in ('__init_subclass__', '__doc__'): + # not sure how to make the __init_subclass_ Attributes match + # so as long as there is one, call it good + # __doc__ is too big to check exactly, so treat the same as __init_subclass__ + for name in ('name','kind','defining_class'): + if getattr(v, name) != getattr(r, name): + print('\n%s\n%s\n%s\n%s\n' % ('=' * 75, r, v, '=' * 75), sep='') + failed = True + elif r != v: print('\n%s\n%s\n%s\n%s\n' % ('=' * 75, r, v, '=' * 75), sep='') failed = True if failed: self.fail("result does not equal expected, see print above") + # TODO: RUSTPYTHON, len is often/always > 256 + @unittest.expectedFailure + def test_test_simple_enum(self): + @_simple_enum(Enum) + class SimpleColor: + CYAN = 1 + MAGENTA = 2 + YELLOW = 3 + @bltns.property + def zeroth(self): + return 'zeroed %s' % self.name + class CheckedColor(Enum): + CYAN = 1 + MAGENTA = 2 + YELLOW = 3 + @bltns.property + def zeroth(self): + return 'zeroed %s' % self.name + self.assertTrue(_test_simple_enum(CheckedColor, SimpleColor) is None) + SimpleColor.MAGENTA._value_ = 9 + self.assertRaisesRegex( + TypeError, "enum mismatch", + _test_simple_enum, CheckedColor, SimpleColor, + ) + class CheckedMissing(IntFlag, boundary=KEEP): + SIXTY_FOUR = 64 + ONE_TWENTY_EIGHT = 128 + TWENTY_FORTY_EIGHT = 2048 + ALL = 2048 + 128 + 64 + 12 + CM = CheckedMissing + self.assertEqual(list(CheckedMissing), [CM.SIXTY_FOUR, CM.ONE_TWENTY_EIGHT, CM.TWENTY_FORTY_EIGHT]) + # + @_simple_enum(IntFlag, boundary=KEEP) + class Missing: + SIXTY_FOUR = 64 + ONE_TWENTY_EIGHT = 128 + TWENTY_FORTY_EIGHT = 2048 + ALL = 2048 + 128 + 64 + 12 + M = Missing + self.assertEqual(list(CheckedMissing), [M.SIXTY_FOUR, M.ONE_TWENTY_EIGHT, M.TWENTY_FORTY_EIGHT]) + # + _test_simple_enum(CheckedMissing, Missing) + class MiscTestCase(unittest.TestCase): + def test__all__(self): - check__all__(self, enum) + support.check__all__(self, enum, not_exported={'bin', 'show_flag_values'}) + + def test_doc_1(self): + class Single(Enum): + ONE = 1 + self.assertEqual(Single.__doc__, None) + + def test_doc_2(self): + class Double(Enum): + ONE = 1 + TWO = 2 + self.assertEqual(Double.__doc__, None) + + def test_doc_3(self): + class Triple(Enum): + ONE = 1 + TWO = 2 + THREE = 3 + self.assertEqual(Triple.__doc__, None) + + def test_doc_4(self): + class Quadruple(Enum): + ONE = 1 + TWO = 2 + THREE = 3 + FOUR = 4 + self.assertEqual(Quadruple.__doc__, None) # These are unordered here on purpose to ensure that declaration order @@ -3163,21 +4735,61 @@ def test__all__(self): CONVERT_TEST_NAME_E = 5 CONVERT_TEST_NAME_F = 5 -class TestIntEnumConvert(unittest.TestCase): +CONVERT_STRING_TEST_NAME_D = 5 +CONVERT_STRING_TEST_NAME_C = 5 +CONVERT_STRING_TEST_NAME_B = 5 +CONVERT_STRING_TEST_NAME_A = 5 # This one should sort first. +CONVERT_STRING_TEST_NAME_E = 5 +CONVERT_STRING_TEST_NAME_F = 5 + +# global names for StrEnum._convert_ test +CONVERT_STR_TEST_2 = 'goodbye' +CONVERT_STR_TEST_1 = 'hello' + +# We also need values that cannot be compared: +UNCOMPARABLE_A = 5 +UNCOMPARABLE_C = (9, 1) # naming order is broken on purpose +UNCOMPARABLE_B = 'value' + +COMPLEX_C = 1j +COMPLEX_A = 2j +COMPLEX_B = 3j + +class _ModuleWrapper: + """We use this class as a namespace for swapping modules.""" + def __init__(self, module): + self.__dict__.update(module.__dict__) + +class TestConvert(unittest.TestCase): + def tearDown(self): + # Reset the module-level test variables to their original integer + # values, otherwise the already created enum values get converted + # instead. + g = globals() + for suffix in ['A', 'B', 'C', 'D', 'E', 'F']: + g['CONVERT_TEST_NAME_%s' % suffix] = 5 + g['CONVERT_STRING_TEST_NAME_%s' % suffix] = 5 + for suffix, value in (('A', 5), ('B', (9, 1)), ('C', 'value')): + g['UNCOMPARABLE_%s' % suffix] = value + for suffix, value in (('A', 2j), ('B', 3j), ('C', 1j)): + g['COMPLEX_%s' % suffix] = value + for suffix, value in (('1', 'hello'), ('2', 'goodbye')): + g['CONVERT_STR_TEST_%s' % suffix] = value + def test_convert_value_lookup_priority(self): test_type = enum.IntEnum._convert_( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], + MODULE, filter=lambda x: x.startswith('CONVERT_TEST_')) # We don't want the reverse lookup value to vary when there are # multiple possible names for a given value. It should always # report the first lexigraphical name in that case. self.assertEqual(test_type(5).name, 'CONVERT_TEST_NAME_A') - def test_convert(self): + def test_convert_int(self): test_type = enum.IntEnum._convert_( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], + MODULE, filter=lambda x: x.startswith('CONVERT_TEST_')) # Ensure that test_type has all of the desired names and values. self.assertEqual(test_type.CONVERT_TEST_NAME_F, @@ -3187,30 +4799,123 @@ def test_convert(self): self.assertEqual(test_type.CONVERT_TEST_NAME_D, 5) self.assertEqual(test_type.CONVERT_TEST_NAME_E, 5) # Ensure that test_type only picked up names matching the filter. - self.assertEqual([name for name in dir(test_type) - if name[0:2] not in ('CO', '__')], - [], msg='Names other than CONVERT_TEST_* found.') - - @unittest.skipUnless(sys.version_info[:2] == (3, 8), - '_convert was deprecated in 3.8') - def test_convert_warn(self): - with self.assertWarns(DeprecationWarning): - enum.IntEnum._convert( + int_dir = dir(int) + [ + 'CONVERT_TEST_NAME_A', 'CONVERT_TEST_NAME_B', 'CONVERT_TEST_NAME_C', + 'CONVERT_TEST_NAME_D', 'CONVERT_TEST_NAME_E', 'CONVERT_TEST_NAME_F', + 'CONVERT_TEST_SIGABRT', 'CONVERT_TEST_SIGIOT', + 'CONVERT_TEST_EIO', 'CONVERT_TEST_EBUS', + ] + extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] + missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] + self.assertEqual( + extra + missing, + [], + msg='extra names: %r; missing names: %r' % (extra, missing), + ) + + + def test_convert_uncomparable(self): + uncomp = enum.Enum._convert_( + 'Uncomparable', + MODULE, + filter=lambda x: x.startswith('UNCOMPARABLE_')) + # Should be ordered by `name` only: + self.assertEqual( + list(uncomp), + [uncomp.UNCOMPARABLE_A, uncomp.UNCOMPARABLE_B, uncomp.UNCOMPARABLE_C], + ) + + def test_convert_complex(self): + uncomp = enum.Enum._convert_( + 'Uncomparable', + MODULE, + filter=lambda x: x.startswith('COMPLEX_')) + # Should be ordered by `name` only: + self.assertEqual( + list(uncomp), + [uncomp.COMPLEX_A, uncomp.COMPLEX_B, uncomp.COMPLEX_C], + ) + + def test_convert_str(self): + test_type = enum.StrEnum._convert_( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], - filter=lambda x: x.startswith('CONVERT_TEST_')) + MODULE, + filter=lambda x: x.startswith('CONVERT_STR_'), + as_global=True) + # Ensure that test_type has all of the desired names and values. + self.assertEqual(test_type.CONVERT_STR_TEST_1, 'hello') + self.assertEqual(test_type.CONVERT_STR_TEST_2, 'goodbye') + # Ensure that test_type only picked up names matching the filter. + str_dir = dir(str) + ['CONVERT_STR_TEST_1', 'CONVERT_STR_TEST_2'] + extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] + missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] + self.assertEqual( + extra + missing, + [], + msg='extra names: %r; missing names: %r' % (extra, missing), + ) + self.assertEqual(repr(test_type.CONVERT_STR_TEST_1), '%s.CONVERT_STR_TEST_1' % SHORT_MODULE) + self.assertEqual(str(test_type.CONVERT_STR_TEST_2), 'goodbye') + self.assertEqual(format(test_type.CONVERT_STR_TEST_1), 'hello') - # TODO: RUSTPYTHON - @unittest.expectedFailure - @unittest.skipUnless(sys.version_info >= (3, 9), - '_convert was removed in 3.9') def test_convert_raise(self): with self.assertRaises(AttributeError): enum.IntEnum._convert( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], + MODULE, filter=lambda x: x.startswith('CONVERT_TEST_')) + def test_convert_repr_and_str(self): + test_type = enum.IntEnum._convert_( + 'UnittestConvert', + MODULE, + filter=lambda x: x.startswith('CONVERT_STRING_TEST_'), + as_global=True) + self.assertEqual(repr(test_type.CONVERT_STRING_TEST_NAME_A), '%s.CONVERT_STRING_TEST_NAME_A' % SHORT_MODULE) + self.assertEqual(str(test_type.CONVERT_STRING_TEST_NAME_A), '5') + self.assertEqual(format(test_type.CONVERT_STRING_TEST_NAME_A), '5') + + +# helpers + +def enum_dir(cls): + interesting = set([ + '__class__', '__contains__', '__doc__', '__getitem__', + '__iter__', '__len__', '__members__', '__module__', + '__name__', '__qualname__', + ] + + cls._member_names_ + ) + if cls._new_member_ is not object.__new__: + interesting.add('__new__') + if cls.__init_subclass__ is not object.__init_subclass__: + interesting.add('__init_subclass__') + if cls._member_type_ is object: + return sorted(interesting) + else: + # return whatever mixed-in data type has + return sorted(set(dir(cls._member_type_)) | interesting) + +def member_dir(member): + if member.__class__._member_type_ is object: + allowed = set(['__class__', '__doc__', '__eq__', '__hash__', '__module__', 'name', 'value']) + else: + allowed = set(dir(member)) + for cls in member.__class__.mro(): + for name, obj in cls.__dict__.items(): + if name[0] == '_': + continue + if isinstance(obj, enum.property): + if obj.fget is not None or name not in member._member_map_: + allowed.add(name) + else: + allowed.discard(name) + else: + allowed.add(name) + return sorted(allowed) + +missing = object() + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 03cb8172de..9b30b4137c 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -2304,6 +2304,8 @@ def test_long_pattern(self): self.assertEqual(r[:30], "re.compile('Very long long lon") self.assertEqual(r[-16:], ", re.IGNORECASE)") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_flags_repr(self): self.assertEqual(repr(re.I), "re.IGNORECASE") self.assertEqual(repr(re.I|re.S|re.X), diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 72b0f19275..35f94a4e22 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1493,6 +1493,8 @@ def test_sio_loopback_fast_path(self): raise self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None) + # TODO: RUSTPYTHON, AssertionError: '2' != 'AddressFamily.AF_INET' + @unittest.expectedFailure def testGetaddrinfo(self): try: socket.getaddrinfo('localhost', 80) @@ -1799,6 +1801,8 @@ def test_getnameinfo_ipv6_scopeid_numeric(self): nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV) self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + str(ifindex), '1234')) + # TODO: RUSTPYTHON, AssertionError: '2' != 'AddressFamily.AF_INET' + @unittest.expectedFailure def test_str_for_enums(self): # Make sure that the AF_* and SOCK_* constants have enum-like string # reprs. diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index cf1b14bacd..071a2a06c1 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -1509,6 +1509,9 @@ def __int__(self): self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi), self.assertRaises(TypeError, operator.mod, '%c', pi), + + # TODO: RUSTPYTHON, AssertionError: '...15...' != '...Int.IDES...' + @unittest.expectedFailure def test_formatting_with_enum(self): # issue18780 import enum