Skip to content

Commit 669468c

Browse files
committed
fix(treeshake): default arg values are visited
RustPython/Parser#133
1 parent b647bfe commit 669468c

File tree

5 files changed

+98
-2
lines changed

5 files changed

+98
-2
lines changed

rust/src/common/ast/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
pub mod full_name;
22
pub mod providers;
3-
3+
pub mod visitor_patch;
44
use pyo3::{
55
exceptions::{PyImportError, PyValueError},
66
PyResult,

rust/src/common/ast/visitor_patch.rs

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
use rustpython_ast::{Arg, ArgWithDefault, Arguments, Visitor};
2+
// patch for missing visits in RustPython visitor
3+
// https://github.com/RustPython/Parser/issues/133
4+
pub trait VisitorPatch: Visitor {
5+
fn generic_visit_arg_patch(&mut self, arg: Arg) {
6+
if let Some(annotation) = arg.annotation {
7+
self.visit_expr(*annotation);
8+
}
9+
}
10+
11+
fn generic_visit_arguments_patch(&mut self, args: Arguments) {
12+
for arg in args.args {
13+
self.visit_arg_with_default(arg);
14+
}
15+
16+
for posonly in args.posonlyargs {
17+
self.visit_arg_with_default(posonly);
18+
}
19+
20+
for kwonly in args.kwonlyargs {
21+
self.visit_arg_with_default(kwonly);
22+
}
23+
if let Some(vararg) = args.vararg {
24+
self.visit_arg(*vararg);
25+
}
26+
if let Some(kwarg) = args.kwarg {
27+
self.visit_arg(*kwarg);
28+
}
29+
}
30+
31+
fn visit_arg_with_default(&mut self, arg: ArgWithDefault) {
32+
self.generic_visit_arg_with_default(arg);
33+
}
34+
35+
fn generic_visit_arg_with_default(&mut self, arg: ArgWithDefault) {
36+
self.visit_arg(arg.def);
37+
if let Some(default) = arg.default {
38+
self.visit_expr(*default);
39+
}
40+
}
41+
}

rust/src/treeshake/references_counter.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@ use std::{
55
};
66

77
use pyo3::{pyclass, pymethods};
8-
use rustpython_ast::{Expr, ExprCompare, Stmt, StmtImport, StmtImportFrom, Suite, Visitor};
8+
use rustpython_ast::{
9+
Arg, Arguments, Expr, ExprCompare, Stmt, StmtImport, StmtImportFrom, Suite, Visitor,
10+
};
911
use rustpython_parser::Parse;
1012

1113
use crate::common::{
1214
ast::{
1315
full_name::{get_full_name_for_expr, get_full_name_for_stmt},
1416
get_import_from_absolute_module_spec,
1517
providers::fully_qualified_name_provider::FullyQualifiedNameProvider,
18+
visitor_patch::VisitorPatch,
1619
},
1720
module_spec::get_parent_package,
1821
};
@@ -372,6 +375,13 @@ impl Visitor for ReferencesCounter {
372375

373376
match expr {
374377
Expr::Call(_) => {
378+
if let Some(full_name) = get_full_name_for_expr(&expr) {
379+
println!(
380+
"Default {} {}",
381+
full_name,
382+
self.is_global_scope() && self.module_spec_has_references()
383+
)
384+
}
375385
if self.is_global_scope() && self.module_spec_has_references() {
376386
self.maybe_increase_expr(&expr);
377387
self.always_bump_context = true;
@@ -395,4 +405,14 @@ impl Visitor for ReferencesCounter {
395405
self.names_provider.visit_import_from(&node);
396406
self.generic_visit_stmt_import_from(node);
397407
}
408+
409+
fn generic_visit_arg(&mut self, node: Arg) {
410+
self.generic_visit_arg_patch(node);
411+
}
412+
413+
fn generic_visit_arguments(&mut self, node: Arguments) {
414+
self.generic_visit_arguments_patch(node);
415+
}
398416
}
417+
418+
impl VisitorPatch for ReferencesCounter {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import typing as t
2+
import dataclasses
3+
4+
@dataclasses.dataclass
5+
class DefaultPlaceholder:
6+
value: t.Any
7+
8+
def Default(value: t.Any) -> DefaultPlaceholder:
9+
return DefaultPlaceholder(value=value)
10+
11+
12+
class DefaultUser:
13+
def __init__(self, value: t.Any = Default(42)) -> None:
14+
self.value = value
15+
16+
17+
def main() -> None:
18+
user = DefaultUser()
19+
print(user.value)
20+
21+
if __name__ == "__main__":
22+
main()

tests/test_treeshake/test_treeshake_package.py

+13
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,16 @@ def test_treeshake_package_re_exports(
114114
)
115115

116116
assert "def useless_func() -> None:" not in inner_hello_world_init_content
117+
118+
119+
def test_treeshake_package_param_default_value(
120+
run_treeshake_package: RunTreeshakePackageT,
121+
) -> None:
122+
source_path = TEST_PACKAGES_DIR / "param_default_value"
123+
result_path = run_treeshake_package(source_path)
124+
init_file = result_path / "__init__.py"
125+
init_file_content = init_file.read_text()
126+
assert (
127+
"def Default(value: t.Any) -> DefaultPlaceholder:\n return DefaultPlaceholder(value=value)"
128+
in init_file_content
129+
)

0 commit comments

Comments
 (0)