diff --git a/graphene/types/enum.py b/graphene/types/enum.py index 70e8ee8e0..2bba903e8 100644 --- a/graphene/types/enum.py +++ b/graphene/types/enum.py @@ -1,4 +1,5 @@ from enum import Enum as PyEnum +from typing import Iterable, Type, Optional from graphene.utils.subclass_with_meta import SubclassWithMeta_Meta @@ -18,11 +19,18 @@ def eq_enum(self, other): class EnumOptions(BaseOptions): enum = None # type: Enum deprecation_reason = None + enums = () # type: Iterable[Type[Enum]] class EnumMeta(SubclassWithMeta_Meta): def __new__(cls, name_, bases, classdict, **options): - enum_members = dict(classdict, __eq__=eq_enum) + enum_members = dict(__eq__=eq_enum) + meta = classdict.get("Meta", None) # type: Optional[EnumOptions] + if meta and hasattr(meta, 'enums'): + for enum in meta.enums: + enum_members.update(enum.as_dict()) + enum_members.update(classdict) + # We remove the Meta attribute from the class to not collide # with the enum values. enum_members.pop("Meta", None) @@ -106,3 +114,10 @@ def get_type(cls): is mounted (as a Field, InputField or Argument) """ return cls + + @classmethod + def as_dict(cls): + return { + enum_meta.name: enum_meta.value + for _, enum_meta in cls._meta.enum.__members__.items() + } diff --git a/graphene/types/tests/test_enum.py b/graphene/types/tests/test_enum.py index 6e204aa9c..b5448cd6c 100644 --- a/graphene/types/tests/test_enum.py +++ b/graphene/types/tests/test_enum.py @@ -471,3 +471,77 @@ class Query(ObjectType): assert result.data == {"createPaint": {"color": "RED"}} assert color_input_value == RGB.RED + + +def test_enum_inheritance(): + class ParentRGB(Enum): + RED = 1 + + class ChildRGB(Enum): + BLUE = 2 + + class Meta: + enums = (ParentRGB,) + + class Query(ObjectType): + color = ChildRGB(required=True) + + schema = Schema(query=Query) + assert str(schema) == dedent( + '''\ + type Query { + color: ChildRGB! + } + + enum ChildRGB { + RED + BLUE + } + ''' + ) + + +def test_multiple_enum_inheritance(): + class Parent1RGB(Enum): + RED = 1 + + class Parent2RGB(Enum): + BLUE = 2 + + class ChildRGB(Enum): + GREEN = 3 + + class Meta: + enums = (Parent1RGB, Parent2RGB,) + + class Query(ObjectType): + color = ChildRGB(required=True) + + schema = Schema(query=Query) + assert str(schema) == dedent( + '''\ + type Query { + color: ChildRGB! + } + + enum ChildRGB { + RED + BLUE + GREEN + } + ''' + ) + + +def test_override_enum_inheritance(): + class ParentRGB(Enum): + RED = 1 + BLUE = 2 + + class ChildRGB(Enum): + BLUE = 3 + + class Meta: + enums = (ParentRGB,) + + assert ChildRGB.get(3) != ParentRGB.BLUE