# named_enum.py (4.93 KB)
import enum
from functools import lru_cache


class NamedEnum(enum.Enum):
    """Enum whose members are declared as ``(value, verbose_name)`` pairs.

    The first item of each declaration becomes the member's ``value``; the
    second is stored on the member as ``verbose_name`` (a human-readable
    label). Members can be looked up by value (``cls(value)``), by name
    (``cls[name]``) or by label (``from_verbose_name``), and the class offers
    helpers that export the pairs as dicts / lists.

    Declarations that reuse an existing value become enum aliases of the
    first member with that value; every export helper below reports only the
    canonical (non-alias) members, in definition order. Export helpers return
    a fresh container on every call, so callers may mutate the result.

    >>> @enum.unique
    ... class Status(NamedEnum):
    ...     CREATED = (3, '初始化')
    ...     CANCELED = (2, '损坏')
    ...     FINISHED = (1, '完成')
    ...     REJECTED = (4, '拒绝')
    ...     TEST = (6, '拒绝')
    >>> assert Status.CREATED != 3
    >>> assert Status.CREATED.value == 3
    >>> assert Status.CREATED.name == 'CREATED'
    >>> assert Status.CREATED.verbose_name == '初始化'
    >>> assert Status(3) == Status.CREATED
    >>> assert Status['CREATED'] == Status.CREATED
    >>> assert Status.get_verbose_name_or_raise(3) == '初始化'
    >>> assert Status.get_mappings() == {
    ...     3: '初始化', 2: '损坏', 1: '完成', 4: '拒绝', 6: '拒绝'}
    >>> assert Status.get_mapping_lst() == [
    ...     {'id': 3, 'name': '初始化'},
    ...     {'id': 2, 'name': '损坏'},
    ...     {'id': 1, 'name': '完成'},
    ...     {'id': 4, 'name': '拒绝'},
    ...     {'id': 6, 'name': '拒绝'}
    ... ]
    >>> assert Status.get_verbose_name(5, '默认') == '默认'
    >>> assert Status.get_value_or_raise('完成') == 1
    >>> assert Status.get_value('完成完成', 100) == 100
    >>> assert Status.get_value_or_raise('拒绝') == 4
    >>> # test the extend() helper
    >>> extend_values = {
    ...     'EXTENDED': (6, '扩展')
    ... }
    >>> ExtendedStatus = extend(Status, 'ExtendedStatus', extend_values)
    >>> assert ExtendedStatus.get_verbose_name(5, '默认') == '默认'
    >>> assert ExtendedStatus.get_value_or_raise('完成') == 1
    >>> assert ExtendedStatus.get_value('完成完成', 100) == 100
    >>> assert ExtendedStatus.get_value_or_raise('拒绝') == 4
    >>> assert ExtendedStatus.EXTENDED.value == 6
    >>> ExtendedStatus = extend(
    ...     Status, 'ExtendedStatus', [('EXTENDED', (6, '扩展'))])
    >>> assert ExtendedStatus.EXTENDED.value == 6
    >>> try:
    ...     ExtendedStatus = extend(
    ...         Status, 'ExtendedStatus', extend_values, unique=True)
    ... except ValueError as err:
    ...     assert 'duplicate' in str(err)
    ... else:
    ...     raise Exception('except ValueError')
    >>> # EXTENDED reuses value 6, so it is an alias of TEST and is not
    >>> # repeated in the exported choices.
    >>> assert ExtendedStatus.EXTENDED is ExtendedStatus.TEST
    >>> assert len(ExtendedStatus.get_choices_lst()) == 5
    """

    def __new__(cls, *args):
        # Unpacking (rather than indexing) rejects declarations that are not
        # exactly a (value, verbose_name) pair at class-creation time.
        value, _verbose_name = args
        member = object.__new__(cls)
        member._value_ = value
        return member

    def __init__(self, *args, **kwargs):
        # Enum passes the same positional args that __new__ received.
        self._value_ = args[0]
        self.verbose_name = args[1]

    @classmethod
    def get_verbose_name_or_raise(cls, value):
        """Return the verbose_name for *value*; raise ValueError if unknown."""
        return cls.get_verbose_name(value, raise_on_missing=True)

    @classmethod
    def get_verbose_name(cls, value, default=None, raise_on_missing=False):
        """Look up a member by *value* and return its verbose_name.

        :param value: the enum value to look up.
        :param default: returned when no member has *value*.
        :param raise_on_missing: propagate the ValueError instead of
            returning *default*.
        :raises ValueError: if *value* is unknown and *raise_on_missing*.
        """
        try:
            member = cls(value)
        except ValueError:
            if raise_on_missing:
                raise
            return default
        return member.verbose_name

    @classmethod
    def get_value_or_raise(cls, verbose_name):
        """Return the value for *verbose_name*; raise ValueError if unknown."""
        return cls.get_value(verbose_name, raise_on_missing=True)

    @classmethod
    def get_value(cls, verbose_name, default=None, raise_on_missing=False):
        """Look up a member by *verbose_name* and return its value.

        :param verbose_name: the label to look up.
        :param default: returned when no member has that label.
        :param raise_on_missing: propagate the ValueError instead of
            returning *default*.
        :raises ValueError: if the label is unknown and *raise_on_missing*.
        """
        try:
            member = cls.from_verbose_name(verbose_name)
        except ValueError:
            if raise_on_missing:
                raise
            return default
        return member.value

    @classmethod
    def from_verbose_name(cls, verbose_name):
        """Return the member whose verbose_name equals *verbose_name*.

        If several members share the label, the first one defined wins.
        Iterating ``cls`` (rather than ``_value2member_map_``) visits every
        canonical member in definition order, including members whose values
        are unhashable.

        :raises ValueError: if no member has that label.
        """
        for member in cls:
            if member.verbose_name == verbose_name:
                return member
        raise ValueError('%s is not a valid "%s"' % (
            verbose_name, cls.__name__))

    @classmethod
    @lru_cache()
    def _raw_value_pairs(cls):
        """Cached tuple of ``(value, verbose_name)`` for canonical members.

        Iterating an Enum class skips aliases, so each value appears once, in
        definition order. The cache holds an immutable tuple so callers can
        never corrupt it; the public exporters build fresh containers from it.
        """
        return tuple(member.raw_value for member in cls)

    @classmethod
    def get_mappings(cls):
        """Return a new ``{value: verbose_name}`` dict."""
        return dict(cls._raw_value_pairs())

    @classmethod
    def get_mapping_lst(cls):
        """Return a new list of ``{'id': value, 'name': verbose_name}`` dicts."""
        return [
            {'id': value, 'name': verbose_name}
            for value, verbose_name in cls._raw_value_pairs()
        ]

    @classmethod
    def get_value_lst(cls):
        """Return a new list of member values in definition order."""
        return [value for value, _ in cls._raw_value_pairs()]

    @property
    def raw_value(self):
        """The original ``(value, verbose_name)`` declaration of this member."""
        return (self.value, self.verbose_name)

    @classmethod
    def get_choices_lst(cls):
        """Return a new list of ``(value, verbose_name)`` tuples.

        Aliases are excluded so no choice is listed twice.
        """
        return list(cls._raw_value_pairs())


def extend(cls, sub_cls_name, names, unique=False, module=None,
           qualname=None):
    """Build a new NamedEnum containing *cls*'s members plus extra ones.

    An enum that already has members cannot be subclassed, so this creates a
    brand-new NamedEnum through the functional Enum API. The result is NOT a
    subclass of *cls*: custom methods/attributes defined on *cls* are not
    carried over, only its canonical members (iterating *cls* skips aliases,
    so aliases of *cls* are dropped too).

    :param cls: a NamedEnum subclass whose members are copied first.
    :param sub_cls_name: name of the generated enum class.
    :param names: extra members, either a mapping
        ``{name: (value, verbose_name)}`` or an iterable of
        ``(name, (value, verbose_name))`` pairs. A name already present in
        *cls* replaces that member's declaration. A new member reusing an
        existing value becomes an alias of the existing member.
    :param unique: if true, apply :func:`enum.unique`, which raises
        ValueError when any two members share a value.
    :param module: module name recorded on the generated class (pass the
        caller's ``__name__`` if the class must be picklable). When None the
        enum machinery infers it from the call stack, which may resolve to
        this module rather than the caller's.
    :param qualname: qualified name recorded on the generated class.
    :raises TypeError: if *cls* is not a NamedEnum subclass.
    :returns: the newly created enum class.
    """
    from collections.abc import Mapping

    # Explicit check (not ``assert``) so validation survives ``python -O``.
    if not (isinstance(cls, type) and issubclass(cls, NamedEnum)):
        raise TypeError(
            'extend() expects a NamedEnum subclass, got %r' % (cls,))
    target_names = {item.name: item.raw_value for item in cls}
    # Any Mapping (not just dict) is merged as-is; iterating a mapping would
    # yield only its keys and mis-parse them as (name, raw_value) pairs.
    if not isinstance(names, Mapping):
        names = {item[0]: item[1] for item in names}
    target_names.update(names)
    sub_cls = NamedEnum(sub_cls_name, names=target_names,
                        module=module, qualname=qualname)
    if unique:
        sub_cls = enum.unique(sub_cls)
    return sub_cls