Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/pyrefly_python/src/dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ pub const CLASS: Name = Name::new_static("__class__");
pub const CLASS_GETITEM: Name = Name::new_static("__class_getitem__");
pub const CONTAINS: Name = Name::new_static("__contains__");
pub const DATACLASS_FIELDS: Name = Name::new_static("__dataclass_fields__");
pub const DEFAULTS: Name = Name::new_static("__defaults__");
pub const DELATTR: Name = Name::new_static("__delattr__");
pub const DELITEM: Name = Name::new_static("__delitem__");
pub const DICT: Name = Name::new_static("__dict__");
Expand Down
43 changes: 42 additions & 1 deletion pyrefly/lib/binding/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pyrefly_graph::index::Idx;
use pyrefly_graph::index::Index;
use pyrefly_graph::index_map::IndexMap;
use pyrefly_python::ast::Ast;
use pyrefly_python::dunder;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::nesting_context::NestingContext;
use pyrefly_python::short_identifier::ShortIdentifier;
Expand Down Expand Up @@ -263,6 +264,9 @@ pub struct BindingsBuilder<'a> {
next_lambda_param_id: u32,
/// See `BindingsInner::subsequently_initialized`.
subsequently_initialized: SmallSet<Idx<KeyAnnotation>>,
/// Defaults extracted from an adjacent `__new__.__defaults__` assignment,
/// set by `stmts()` and consumed by namedtuple synthesis in `stmt()`.
pub adjacent_namedtuple_defaults: Option<Vec<Expr>>,
}

/// An enum tracking whether we are in a generator expression
Expand Down Expand Up @@ -533,6 +537,7 @@ impl Bindings {
lambda_yield_keys: Vec::new(),
next_lambda_param_id: 0,
subsequently_initialized: SmallSet::new(),
adjacent_namedtuple_defaults: None,
};
builder.init_static_scope(&x.body, true);
if module_info.name() != ModuleName::builtins() {
Expand Down Expand Up @@ -760,6 +765,25 @@ impl CurrentIdx {
}
}

fn extract_new_defaults(stmt: &Stmt, name: &str) -> Option<Vec<Expr>> {
if let Stmt::Assign(assign) = stmt
&& let [Expr::Attribute(outer)] = assign.targets.as_slice()
&& outer.attr.id == dunder::DEFAULTS
&& let Expr::Attribute(inner) = outer.value.as_ref()
&& inner.attr.id == dunder::NEW
&& let Expr::Name(target_name) = inner.value.as_ref()
&& target_name.id == name
{
match assign.value.as_ref() {
Expr::Tuple(tuple) => Some(tuple.elts.clone()),
Expr::NoneLiteral(_) => Some(vec![]),
_ => None,
}
} else {
None
}
}

impl<'a> BindingsBuilder<'a> {
/// Whether to infer empty container types and unsolved type variables based on first use.
pub fn infer_with_first_use(&self) -> bool {
Expand Down Expand Up @@ -962,8 +986,25 @@ impl<'a> BindingsBuilder<'a> {
}

pub fn stmts(&mut self, xs: Vec<Stmt>, parent: &NestingContext) {
for x in xs {
let mut iter = xs.into_iter().peekable();
while let Some(x) = iter.next() {
if let Stmt::Assign(assign) = &x
&& let [Expr::Name(name)] = assign.targets.as_slice()
&& let Expr::Call(call) = assign.value.as_ref()
&& let Some(defaults) = iter
.peek()
.and_then(|next| extract_new_defaults(next, &name.id))
&& let Some(special) = self.as_special_export(&call.func)
&& matches!(
special,
SpecialExport::TypingNamedTuple | SpecialExport::CollectionsNamedTuple
)
{
iter.next();
self.adjacent_namedtuple_defaults = Some(defaults);
}
self.stmt(x, parent);
self.adjacent_namedtuple_defaults = None;
}
}

Expand Down
109 changes: 73 additions & 36 deletions pyrefly/lib/binding/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ enum SynthesizedClassKind {
NewType,
}

/// Right-align `default_elts` into `defaults`: a slice of N elements makes the last N fields
/// optional. An empty slice clears all defaults.
fn apply_adjacent_defaults(default_elts: &[Expr], n_members: usize, defaults: &mut [Option<Expr>]) {
defaults.iter_mut().for_each(|d| *d = None);
let n_defaults = default_elts.len().min(n_members);
// Right-align: skip leading elements if more defaults than fields
let start = default_elts.len() - n_defaults;
for (i, elt) in default_elts[start..].iter().enumerate() {
defaults[n_members - n_defaults + i] = Some(elt.clone());
}
}

impl<'a> BindingsBuilder<'a> {
fn def_index(&mut self) -> ClassDefIndex {
self.metadata.push_class()
Expand Down Expand Up @@ -179,6 +191,7 @@ impl<'a> BindingsBuilder<'a> {
members,
&mut call.arguments.keywords,
false,
None,
))
} else {
None
Expand All @@ -194,6 +207,7 @@ impl<'a> BindingsBuilder<'a> {
&mut call.func,
members,
false,
None,
))
} else {
None
Expand Down Expand Up @@ -764,7 +778,7 @@ impl<'a> BindingsBuilder<'a> {
/// and field definition bindings.
fn insert_synthesized_fields(
&mut self,
member_definitions: Vec<(String, TextRange, Option<Expr>, Option<Expr>)>,
member_definitions: Vec<(String, TextRange, Option<Expr>, Option<ExprOrBinding>)>,
fields: &mut SmallMap<Name, ClassFieldProperties>,
class_indices: &ClassIndices,
illegal_identifier_handling: IllegalIdentifierHandling,
Expand Down Expand Up @@ -841,7 +855,7 @@ impl<'a> BindingsBuilder<'a> {
});
let definition = match (member_value, force_class_initialization) {
(Some(value), _) => ClassFieldDefinition::AssignedInBody {
value: Box::new(ExprOrBinding::Expr(value)),
value: Box::new(value),
annotation,
alias_of: None,
},
Expand Down Expand Up @@ -879,7 +893,7 @@ impl<'a> BindingsBuilder<'a> {
base: Option<Expr>,
keywords: Box<[(Name, Expr)]>,
// name, position, annotation, value
member_definitions: Vec<(String, TextRange, Option<Expr>, Option<Expr>)>,
member_definitions: Vec<(String, TextRange, Option<Expr>, Option<ExprOrBinding>)>,
illegal_identifier_handling: IllegalIdentifierHandling,
force_class_initialization: bool,
class_kind: SynthesizedClassKind,
Expand Down Expand Up @@ -993,7 +1007,7 @@ impl<'a> BindingsBuilder<'a> {
for arg in &mut *members {
self.ensure_expr(arg, class_object.usage());
}
let member_definitions: Vec<(String, TextRange, Option<Expr>, Option<Expr>)> =
let member_definitions: Vec<(String, TextRange, Option<Expr>, Option<ExprOrBinding>)> =
match members {
// Enum('Color', 'RED, GREEN, BLUE')
// Enum('Color', 'RED GREEN BLUE')
Expand Down Expand Up @@ -1077,7 +1091,7 @@ impl<'a> BindingsBuilder<'a> {
}
}
.into_iter()
.map(|(name, range, value)| (name, range, None, value))
.map(|(name, range, value)| (name, range, None, value.map(ExprOrBinding::Expr)))
.collect();
self.synthesize_class_def(
class_name,
Expand Down Expand Up @@ -1105,6 +1119,7 @@ impl<'a> BindingsBuilder<'a> {
members: &mut [Expr],
keywords: &mut [Keyword],
bind_to_name: bool,
adjacent_defaults: Option<Vec<Expr>>,
) -> Idx<KeyClass> {
let (mut class_object, class_indices) = if bind_to_name {
self.class_object_and_indices(&class_name)
Expand Down Expand Up @@ -1158,12 +1173,25 @@ impl<'a> BindingsBuilder<'a> {
);
}
}
let member_definitions_with_defaults: Vec<(String, TextRange, Option<Expr>, Option<Expr>)> =
member_definitions
.into_iter()
.zip(defaults)
.map(|((name, range, annotation), default)| (name, range, annotation, default))
.collect();
if let Some(ref default_elts) = adjacent_defaults {
apply_adjacent_defaults(default_elts, n_members, &mut defaults);
}
let member_definitions_with_defaults: Vec<(
String,
TextRange,
Option<Expr>,
Option<ExprOrBinding>,
)> = member_definitions
.into_iter()
.zip(defaults)
.map(|((name, range, annotation), default)| {
// collections.namedtuple fields are untyped: defaults only
// mark optionality, not the field type.
let value =
default.map(|_| ExprOrBinding::Binding(Binding::Any(AnyStyle::Implicit)));
(name, range, annotation, value)
})
.collect();
let range = class_name.range();
self.synthesize_class_def(
class_name,
Expand All @@ -1190,26 +1218,34 @@ impl<'a> BindingsBuilder<'a> {
func: &mut Expr,
members: &[Expr],
bind_to_name: bool,
adjacent_defaults: Option<Vec<Expr>>,
) -> Idx<KeyClass> {
let (mut class_object, class_indices) = if bind_to_name {
self.class_object_and_indices(&class_name)
} else {
self.anon_class_object_and_indices(&class_name)
};
self.ensure_expr(func, class_object.usage());
let member_definitions: Vec<(String, TextRange, Option<Expr>, Option<Expr>)> = self
.parse_typing_namedtuple_fields(members, class_name.range)
.0
.into_iter()
.map(|(name, range, annotation)| {
if let Some(mut ann) = annotation {
self.ensure_type(&mut ann, &mut None);
(name, range, Some(ann), None)
} else {
(name, range, None, None)
}
})
.collect();
let (parsed_fields, _has_dynamic) =
self.parse_typing_namedtuple_fields(members, class_name.range);
let n_members = parsed_fields.len();
let mut defaults: Vec<Option<Expr>> = vec![None; n_members];
if let Some(ref default_elts) = adjacent_defaults {
apply_adjacent_defaults(default_elts, n_members, &mut defaults);
}
let member_definitions: Vec<(String, TextRange, Option<Expr>, Option<ExprOrBinding>)> =
parsed_fields
.into_iter()
.zip(defaults)
.map(|((name, range, annotation), default)| {
if let Some(mut ann) = annotation {
self.ensure_type(&mut ann, &mut None);
(name, range, Some(ann), default.map(ExprOrBinding::Expr))
} else {
(name, range, None, default.map(ExprOrBinding::Expr))
}
})
.collect();
self.synthesize_class_def(
class_name,
class_object,
Expand Down Expand Up @@ -1296,19 +1332,20 @@ impl<'a> BindingsBuilder<'a> {
);
}
}
let member_definitions: Vec<(String, TextRange, Option<Expr>, Option<Expr>)> = match args {
// Movie = TypedDict('Movie', {'name': str, 'year': int})
[Expr::Dict(ExprDict { items, .. })] => items
.iter_mut()
.filter_map(|item| {
if let Some(key) = &mut item.key {
self.ensure_expr(key, class_object.usage());
}
self.ensure_type(&mut item.value, &mut None);
match (&item.key, &item.value) {
(Some(Expr::StringLiteral(k)), v) => {
Some((k.value.to_string(), k.range(), Some(v.clone()), None))
let member_definitions: Vec<(String, TextRange, Option<Expr>, Option<ExprOrBinding>)> =
match args {
// Movie = TypedDict('Movie', {'name': str, 'year': int})
[Expr::Dict(ExprDict { items, .. })] => items
.iter_mut()
.filter_map(|item| {
if let Some(key) = &mut item.key {
self.ensure_expr(key, class_object.usage());
}
self.ensure_type(&mut item.value, &mut None);
match (&item.key, &item.value) {
(Some(Expr::StringLiteral(k)), v) => {
Some((k.value.to_string(), k.range(), Some(v.clone()), None))
}
(Some(k), _) => {
self.error(
k.range(),
Expand Down
6 changes: 6 additions & 0 deletions pyrefly/lib/binding/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,15 @@ impl<'a> BindingsBuilder<'a> {
call.arguments.args.split_first_mut()
{
self.check_functional_definition_name(&name.id, arg_name);
let adjacent_defaults =
self.adjacent_namedtuple_defaults.take();
self.synthesize_typing_named_tuple_def(
Ast::expr_name_identifier(name.clone()),
parent,
&mut call.func,
members,
true,
adjacent_defaults,
);
return;
}
Expand All @@ -622,13 +625,16 @@ impl<'a> BindingsBuilder<'a> {
call.arguments.args.split_first_mut()
{
self.check_functional_definition_name(&name.id, arg_name);
let adjacent_defaults =
self.adjacent_namedtuple_defaults.take();
self.synthesize_collections_named_tuple_def(
Ast::expr_name_identifier(name.clone()),
parent,
&mut call.func,
members,
&mut call.arguments.keywords,
true,
adjacent_defaults,
);
return;
}
Expand Down
Loading
Loading