Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 169 additions & 45 deletions pyrefly/lib/alt/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ use crate::types::callable::PropertyRole;
use crate::types::callable::Required;
use crate::types::class::ClassKind;
use crate::types::keywords::DataclassTransformMetadata;
use crate::types::types::AnyStyle;
use crate::types::types::CalleeKind;
use crate::types::types::Forall;
use crate::types::types::Forallable;
Expand All @@ -95,6 +96,50 @@ fn is_class_property_decorator_type(ty: &Type) -> bool {
}
}

fn is_unannotated_passthrough_callable(ty: &Type) -> bool {
let signature = match ty {
Type::Function(f) => &f.signature,
Type::Callable(c) => c,
_ => return false,
};
if !matches!(signature.ret, Type::Any(AnyStyle::Implicit)) {
return false;
}
match &signature.params {
Params::List(params) => {
let items = params.items();
// Require a purely variadic wrapper shape. A zero-arg callable or one that names
// positional parameters could be narrower than the original decoratee and should
// not be treated as a transparent passthrough.
!items.is_empty()
&& items.iter().all(|param| match param {
Param::VarArg(_, ty) | Param::Kwargs(_, ty) => {
matches!(ty, Type::Any(AnyStyle::Implicit))
}
_ => false,
})
}
_ => false,
}
}

enum DecoratorValueNormalization {
Unchanged(Type),
FactoryBranchesOnly(Type),
}

impl DecoratorValueNormalization {
fn into_type(self) -> Type {
match self {
Self::Unchanged(ty) | Self::FactoryBranchesOnly(ty) => ty,
}
}

fn narrowed_to_factory_branches(&self) -> bool {
matches!(self, Self::FactoryBranchesOnly(_))
}
}

/// Result of resolving a function parameter's type and requiredness.
struct ParamTypeResult {
/// The resolved type of the parameter.
Expand Down Expand Up @@ -1255,6 +1300,74 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
})
}

/// When a decorator application returns a union, drop members that still look like
/// decorator factory branches so the decorated symbol keeps only final callable values.
fn normalize_decorator_return_union(&self, returned_ty: Type, decoratee: &Type) -> Type {
let Type::Union(box Union { members, .. }) = &returned_ty else {
return returned_ty;
};

let mut kept = Vec::with_capacity(members.len());
let mut pruned = false;
for member in members {
if self.is_decorator_factory_branch(member, decoratee) {
pruned = true;
} else {
kept.push(member.clone());
}
}

if pruned && !kept.is_empty() {
if kept.len() == 1 && is_unannotated_passthrough_callable(&kept[0]) {
decoratee.clone()
} else {
self.unions(kept)
}
} else {
// If nothing was pruned, or every member looked factory-like, keep the original
// result rather than guessing which branch should represent the decorated value.
returned_ty
}
}

/// When the decorator value itself is a union, keep only the branches that still behave
/// like decorator factories for this decoratee so the subsequent call produces the wrapped
/// function type. This is the opposite direction from return-side normalization above.
fn normalize_decorator_value_union(
&self,
decorator: Type,
decoratee: &Type,
) -> DecoratorValueNormalization {
let Type::Union(box Union { members, .. }) = &decorator else {
return DecoratorValueNormalization::Unchanged(decorator);
};

let kept: Vec<_> = members
.iter()
.filter(|member| self.is_decorator_factory_branch(member, decoratee))
.cloned()
.collect();

if !kept.is_empty() && kept.len() < members.len() {
DecoratorValueNormalization::FactoryBranchesOnly(self.unions(kept))
} else {
DecoratorValueNormalization::Unchanged(decorator)
}
}

/// Returns whether this callable branch still behaves like a decorator factory for the
/// current decoratee: it accepts the decoratee as its first parameter and returns a callable.
fn is_decorator_factory_branch(&self, ty: &Type, decoratee: &Type) -> bool {
let Some(first_param) = ty.callable_first_param(self.heap) else {
return false;
};
let Some(ret) = ty.callable_return_type(self.heap) else {
return false;
};

ret.is_toplevel_callable() && self.is_subset_eq(decoratee, &first_param)
}

fn apply_function_decorator(
&self,
decorator: Type,
Expand All @@ -1269,63 +1382,74 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
{
return decoratee;
}
// Preserve function metadata, so things like method binding still work.
let call_target =
self.as_call_target_or_error(decorator, CallStyle::FreeForm, range, errors, None);
// If the decoratee is generic, unwrap the `Forall` so that `call_infer` can treat the
// type parameters as concrete in the raw inferred result; this avoids us replacing the
// type vars with partial types.
let (tparams_opt, decoratee_arg) = match &decoratee {
Type::Forall(forall) => (Some(forall.tparams.clone()), forall.body.clone().as_type()),
_ => (None, decoratee.clone()),
};
let normalized_decorator = self.normalize_decorator_value_union(decorator, &decoratee_arg);
let normalized_decorator_union = normalized_decorator.narrowed_to_factory_branches();
let decorator = normalized_decorator.into_type();
// Preserve function metadata, so things like method binding still work.
let call_target =
self.as_call_target_or_error(decorator, CallStyle::FreeForm, range, errors, None);
let arg = CallArg::ty(&decoratee_arg, range);
// Compute the raw return type - this may need tweaks to handle Forall well.
let inferred_ty =
match self.call_infer(call_target, &[arg], &[], range, errors, None, None, None) {
Type::Callable(c) => self.heap.mk_function(Function {
signature: *c,
metadata: metadata.clone(),
}),
Type::Forall(box Forall {
tparams,
body: Forallable::Callable(c),
}) => Forallable::Function(Function {
signature: c,
metadata: metadata.clone(),
})
.forall(tparams),
// Callback protocol. We convert it to a function so we can add function metadata.
Type::ClassType(cls)
if self
.get_metadata_for_class(cls.class_object())
.is_protocol() =>
{
let call_attr = self.instance_as_dunder_call(&cls).and_then(|call_attr| {
if let Type::BoundMethod(m) = call_attr {
Some(
self.bind_boundmethod(&m, &mut |a, b| self.is_subset_eq(a, b))
.unwrap_or(m.func.as_type()),
)
} else {
None
}
});
if let Some(mut call_attr) = call_attr {
call_attr.transform_toplevel_func_metadata(|m| {
*m = FuncMetadata {
kind: FunctionKind::CallbackProtocol(Box::new(cls.clone())),
flags: metadata.flags.clone(),
};
});
call_attr
let returned_ty =
self.call_infer(call_target, &[arg], &[], range, errors, None, None, None);
let returned_ty = self.normalize_decorator_return_union(returned_ty, &decoratee_arg);
let returned_ty =
if normalized_decorator_union && is_unannotated_passthrough_callable(&returned_ty) {
decoratee.clone()
} else {
returned_ty
};
let inferred_ty = match returned_ty {
Type::Callable(c) => self.heap.mk_function(Function {
signature: *c,
metadata: metadata.clone(),
}),
Type::Forall(box Forall {
tparams,
body: Forallable::Callable(c),
}) => Forallable::Function(Function {
signature: c,
metadata: metadata.clone(),
})
.forall(tparams),
// Callback protocol. We convert it to a function so we can add function metadata.
Type::ClassType(cls)
if self
.get_metadata_for_class(cls.class_object())
.is_protocol() =>
{
let call_attr = self.instance_as_dunder_call(&cls).and_then(|call_attr| {
if let Type::BoundMethod(m) = call_attr {
Some(
self.bind_boundmethod(&m, &mut |a, b| self.is_subset_eq(a, b))
.unwrap_or(m.func.as_type()),
)
} else {
self.heap.mk_class_type(cls)
None
}
});
if let Some(mut call_attr) = call_attr {
call_attr.transform_toplevel_func_metadata(|m| {
*m = FuncMetadata {
kind: FunctionKind::CallbackProtocol(Box::new(cls.clone())),
flags: metadata.flags.clone(),
};
});
call_attr
} else {
self.heap.mk_class_type(cls)
}
Type::ClassType(cls) if cls.has_qname("functools", "_Wrapped") => decoratee.clone(),
returned_ty => returned_ty,
};
}
Type::ClassType(cls) if cls.has_qname("functools", "_Wrapped") => decoratee.clone(),
returned_ty => returned_ty,
};

// Given the raw `inferred_ty`, which may include `Type::Quantified` type variables coming from a
// `Forall` in the original decoratee, we need to create the proper output type:
Expand Down
70 changes: 70 additions & 0 deletions pyrefly/lib/test/decorators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -717,3 +717,73 @@ assert_type(test1(1, 2), int)
assert_type(test2(1, 2), int)
"#,
);

// Dual-use decorator: can be used as @decorator or @decorator(flag).
// The decorator function returns a union of the wrapper and the decorator factory,
// but since the wrapper is an unannotated passthrough, we should preserve the
// original function's type.
testcase!(
test_dual_use_decorator,
r#"
from functools import wraps
from typing import assert_type

def optional_debug(func_or_flag=None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
if callable(func_or_flag):
return decorator(func_or_flag)
return decorator

@optional_debug
def compute(x: int, y: int, z: int) -> int:
return x + y + z

assert_type(compute(1, 2, 3), int)
"#,
);

testcase!(
test_dual_use_decorator_factory_call,
r#"
from functools import wraps
from typing import assert_type

def optional_debug(func_or_flag=None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
if callable(func_or_flag):
return decorator(func_or_flag)
return decorator

@optional_debug(True)
def compute(x: int, y: int, z: int) -> int:
return x + y + z

assert_type(compute(1, 2, 3), int)
"#,
);

testcase!(
test_dual_use_decorator_typed_wrapper_branch,
r#"
from typing import Callable, assert_type

def stringify_decorator(
func: Callable[[int], int],
) -> Callable[[int], str] | Callable[[Callable[[int], int]], Callable[[int], str]]:
...

@stringify_decorator
def compute(x: int) -> int:
return x + 1

assert_type(compute(1), str)
"#,
);
Loading