Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyrefly/lib/alt/class/class_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1733,6 +1733,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.determine_read_only_reason(name, annotation.as_ref(), &metadata, field_definition);

// Determine the final type, promoting literals when appropriate.
let original_value_ty = value_ty.clone();
// Skip literal promotion for NNModule types: their fields are captured
// constructor args that must preserve literal types for shape inference.
let ty = if matches!(value_ty, Type::NNModule(_)) {
Expand Down Expand Up @@ -1821,6 +1822,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
class,
name,
direct_annotation.as_ref(),
&original_value_ty,
&ty,
field_definition,
descriptor.is_some(),
Expand Down Expand Up @@ -2062,6 +2064,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
class: &Class,
name: &Name,
direct_annotation: Option<&Annotation>,
value_ty: &Type,
ty: &Type,
field_definition: &ClassFieldDefinition,
is_descriptor: bool,
Expand All @@ -2072,6 +2075,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
class,
name,
direct_annotation,
value_ty,
ty,
field_definition,
is_descriptor,
Expand Down
12 changes: 6 additions & 6 deletions pyrefly/lib/alt/class/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
/// - Check whether the field is a member (which depends only on its type and name)
/// - Validate that a member should not have an annotation, and should respect any explicit annotation on `_value_`
///
/// TODO(stroxler, yangdanny): We currently operate on promoted types, which means we do not infer `Literal[...]`
/// types for the `.value` / `._value_` attributes of literals. This is permitted in the spec although not optimal
/// for most cases; we are handling it this way in part because generic enum behavior is not yet well-specified.
/// We preserve the original inferred member value type here so enum `.value` lookups on the
/// class can recover unions of member literals instead of only their promoted base classes.
///
/// We currently skip the check for `_value_` if the class defines `__new__`, since that can
/// change the value of the enum member. https://docs.python.org/3/howto/enum.html#when-to-use-new-vs-init
Expand All @@ -261,6 +260,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
class: &Class,
name: &Name,
direct_annotation: Option<&Annotation>,
value_ty: &Type,
ty: &Type,
field_definition: &ClassFieldDefinition,
is_descriptor: bool,
Expand Down Expand Up @@ -290,9 +290,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&& !self
.get_class_fields(class)
.is_some_and(|f| f.contains(&dunder::NEW))
&& (!matches!(ty, Type::Ellipsis) || !self.module().path().is_interface())
&& (!matches!(value_ty, Type::Ellipsis) || !self.module().path().is_interface())
{
self.check_enum_value_annotation(ty, &enum_value_ty, name, range, errors);
self.check_enum_value_annotation(value_ty, &enum_value_ty, name, range, errors);
}
// If this field is an alias (value is a simple name referring to another field),
// look up the aliased member and return its type instead of creating a new enum literal.
Expand All @@ -305,7 +305,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
Lit::Enum(Box::new(LitEnum {
class: enum_.cls.clone(),
member: name.clone(),
ty: ty.clone(),
ty: value_ty.clone(),
}))
.to_implicit_type(),
)
Expand Down
4 changes: 2 additions & 2 deletions pyrefly/lib/test/django/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ django_testcase!(
test_overwrite_value,
r#"
from django.db.models import Choices
from typing import Any, assert_type
from typing import Any, Literal, assert_type

class A(Choices):
X = 1
Expand All @@ -133,7 +133,7 @@ class B(Choices):
assert_type(A.X._value_, str)
assert_type(A.X.value, str)
assert_type(A.values, list[str])
assert_type(B.X._value_, int)
assert_type(B.X._value_, Literal[1])
assert_type(B.X.value, str)
assert_type(B.values, list[str])
"#,
Expand Down
49 changes: 34 additions & 15 deletions pyrefly/lib/test/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ assert_type(MyEnum["X"], Literal[MyEnum.X])
assert_type(MyEnum.__PRIVATE, int) # E: Private attribute `__PRIVATE` cannot be accessed outside of its defining class
assert_type(MyEnum.X.name, Literal["X"])
assert_type(MyEnum.X._name_, Literal["X"])
assert_type(MyEnum.X.value, int)
assert_type(MyEnum.X._value_, int)
assert_type(MyEnum.X.value, Literal[1])
assert_type(MyEnum.X._value_, Literal[1])

MyEnum["FOO"] # E: Enum `MyEnum` does not have a member named `FOO`

Expand All @@ -63,8 +63,8 @@ def bar(member: int) -> None:

def foo(member: MyEnum) -> None:
assert_type(member.name, str)
assert_type(member.value, int)
assert_type(member._value_, int)
assert_type(member.value, Literal[1, 2])
assert_type(member._value_, Literal[1, 2])
"#,
);

Expand Down Expand Up @@ -147,13 +147,14 @@ testcase!(
test_value_annotation,
r#"
from enum import Enum, member, auto
from typing import Literal

class MyEnum(Enum):
_value_: int
V = member(1)
W = auto()
X = 1
Y = "FOO" # E: Enum member `Y` has type `str`, must match the `_value_` attribute annotation of `int`
Y = "FOO" # E: Enum member `Y` has type `Literal['FOO']`, must match the `_value_` attribute annotation of `int`
Z = member("FOO") # E: Enum member `Z` has type `str`, must match the `_value_` attribute annotation of `int`

def get_value(self) -> int:
Expand All @@ -168,14 +169,14 @@ testcase!(
test_infer_value,
r#"
from enum import Enum
from typing import assert_type
from typing import assert_type, Literal

class MyEnum(Enum):
X = 1
Y = "foo"
def test(e: MyEnum):
# the inferred type use promoted types, for performance reasons
assert_type(e.value, int | str)
# the inferred type of `e.value` is the Literal union of all member values
assert_type(e.value, Literal['foo', 1])
"#,
);

Expand All @@ -191,7 +192,7 @@ class MyEnumUnannotated(Enum):
def mutate(ea: MyEnumAnnotated, eu: MyEnumUnannotated) -> None:
ea._value_ = 2 # Allowed for now, because it must be permitted in `__init__`
ea.value = 2 # E: Cannot set field `value`
eu._value_ = 2 # Allowed for now, because it must be permitted in `__init__`
eu._value_ = 2 # E: `Literal[2]` is not assignable to attribute `_value_` with type `Literal[1]`
eu.value = 2 # E: Cannot set field `value`
"#,
);
Expand Down Expand Up @@ -257,6 +258,22 @@ def f(e: Literal[E.X, E.Y]) -> int:
"#,
);

testcase!(
test_value_of_enum_is_union_of_member_literals,
r#"
from enum import Enum
from typing import Literal

class Priority(Enum):
LOW = 1
MEDIUM = 2
HIGH = 3

def get_priority_level(p: Priority) -> Literal[1, 2, 3]:
return p.value
"#,
);

testcase!(
test_enum_union_simplification,
r#"
Expand Down Expand Up @@ -320,7 +337,7 @@ def foo(f: MyFlag) -> None:
testcase!(
test_enum_instance_only_attr,
r#"
from typing import assert_type, Any
from typing import assert_type, Any, Literal
from enum import Enum

class MyEnum(Enum):
Expand All @@ -331,7 +348,7 @@ class MyEnum(Enum):
assert_type(MyEnum.Y, int)

for x in MyEnum:
assert_type(x.value, str) # Y is not an enum member
assert_type(x.value, Literal['bar', 'foo']) # Y is not an enum member
"#,
);

Expand Down Expand Up @@ -516,18 +533,20 @@ fn env_enum_dots() -> TestEnv {
let mut env = TestEnv::new();
env.add_with_path("py", "py.py", r#"
from enum import IntEnum
from typing import Literal

class Color(IntEnum):
RED = ... # E: Enum member `RED` has type `Ellipsis`, must match the `_value_` attribute annotation of `int`
GREEN = "wrong" # E: Enum member `GREEN` has type `str`, must match the `_value_` attribute annotation of `int`
GREEN = "wrong" # E: Enum member `GREEN` has type `Literal['wrong']`, must match the `_value_` attribute annotation of `int`
"#
);
env.add_with_path("pyi", "pyi.pyi", r#"
from enum import IntEnum
from typing import Literal

class Color(IntEnum):
RED = ...
GREEN = "wrong" # E: Enum member `GREEN` has type `str`, must match the `_value_` attribute annotation of `int`
GREEN = "wrong" # E: Enum member `GREEN` has type `Literal['wrong']`, must match the `_value_` attribute annotation of `int`
"#
);
env
Expand Down Expand Up @@ -655,12 +674,12 @@ testcase!(
test_override_value_prop,
r#"
from enum import Enum
from typing import assert_type
from typing import assert_type, Literal
class E(Enum):
X = 1
@property
def value(self) -> str: ...
assert_type(E.X._value_, int)
assert_type(E.X._value_, Literal[1])
assert_type(E.X.value, str)
"#,
);
Expand Down
Loading