Skip to content

Commit be49f7e

Browse files
authored
fix(treeshake): make re-exports work (#39)
1 parent 18f6a64 commit be49f7e

File tree

8 files changed

+119
-11
lines changed

8 files changed

+119
-11
lines changed

rust/src/common/ast/providers/fully_qualified_name_provider.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ impl FullyQualifiedNameProvider {
3131
}
3232
}
3333

34-
fn maybe_format_with_name_context(&self, name: &str) -> String {
34+
pub fn resolve_qualified_name(&self, name: &str) -> String {
3535
if self.name_context.len() > 0 {
3636
format!("{}.{}", self.name_context, name)
3737
} else {
@@ -41,7 +41,7 @@ impl FullyQualifiedNameProvider {
4141

4242
fn get_expr_qualified_name(&self, expr: &Expr) -> Option<String> {
4343
get_full_name_for_expr(expr).map(|name| match expr {
44-
Expr::NamedExpr(_) => self.maybe_format_with_name_context(&name),
44+
Expr::NamedExpr(_) => self.resolve_qualified_name(&name),
4545
_ => name,
4646
})
4747
}
@@ -58,14 +58,14 @@ impl FullyQualifiedNameProvider {
5858
| Stmt::ClassDef(_)
5959
| Stmt::FunctionDef(_)
6060
| Stmt::AsyncFunctionDef(_) => {
61-
self.maybe_format_with_name_context(name)
61+
self.resolve_qualified_name(name)
6262
}
6363
_ => name.to_string(),
6464
})
6565
.collect()
6666
}
6767

68-
fn resolve_fully_qualified_name(&self, qualified_name: &str) -> Vec<String> {
68+
pub fn resolve_fully_qualified_name(&self, qualified_name: &str) -> Vec<String> {
6969
let mut result: Vec<String> = Vec::new();
7070

7171
for (key, value) in &self.imports_provider.active_imports {

rust/src/treeshake/references_counter.rs

+56-7
Original file line numberDiff line numberDiff line change
@@ -153,21 +153,32 @@ impl ReferencesCounter {
153153
self.names_provider.name_context.len() == 0
154154
}
155155

156-
fn maybe_increase_stmt(&mut self, stmt: &Stmt) {
156+
fn maybe_increase_stmt_selective<F>(&mut self, stmt: &Stmt, predicate: F)
157+
where
158+
F: Fn(&str) -> bool,
159+
{
157160
for fqn in self.names_provider.get_stmt_fully_qualified_name(stmt) {
158-
self.increase(&fqn);
161+
if predicate(&fqn) {
162+
self.increase(&fqn);
163+
}
159164
}
160165

161166
// bump for this node because it is a global name that could be imported somewhere else via star import
162167
if self.import_star_module_specs.len() > 0 {
163168
for module_spec in self.import_star_module_specs.to_owned() {
164169
for full_name in get_full_name_for_stmt(stmt, &self.get_parent_package()) {
165-
self.increase(&format!("{}.{}", module_spec, full_name));
170+
if predicate(&full_name) {
171+
self.increase(&format!("{}.{}", module_spec, full_name));
172+
}
166173
}
167174
}
168175
}
169176
}
170177

178+
fn maybe_increase_stmt(&mut self, stmt: &Stmt) {
179+
self.maybe_increase_stmt_selective(stmt, |_| true);
180+
}
181+
171182
fn maybe_increase_expr(&mut self, expr: &Expr) {
172183
for fqn in self.names_provider.get_expr_fully_qualified_name(expr) {
173184
self.increase(&fqn);
@@ -283,13 +294,51 @@ impl Visitor for ReferencesCounter {
283294
self.always_bump_context = true;
284295
}
285296
}
286-
Stmt::ImportFrom(import_from) => {
287-
if import_from.names.len() == 1
288-
&& import_from.names[0].name.as_str() == "*"
297+
Stmt::Import(stmt_import) => {
298+
// check if one of the names defined by this import was imported somewhere else
299+
// if yes, bump reference of this import
300+
for alias in &stmt_import.names {
301+
let defined_name = if let Some(alias_value) = &alias.asname {
302+
alias_value
303+
} else {
304+
&alias.name
305+
};
306+
for fqn in self.names_provider.resolve_fully_qualified_name(
307+
&self.names_provider.resolve_qualified_name(&defined_name),
308+
) {
309+
if self.has_references_for_str(&fqn) {
310+
self.maybe_increase_stmt_selective(&stmt, |n| n == alias.name.as_str());
311+
}
312+
}
313+
}
314+
}
315+
316+
Stmt::ImportFrom(stmt_import_from) => {
317+
// check if one of the names defined by this import was imported somewhere else
318+
// if yes, bump reference of this import
319+
for alias in &stmt_import_from.names {
320+
let defined_name = if let Some(alias_value) = &alias.asname {
321+
alias_value
322+
} else {
323+
&alias.name
324+
};
325+
for fqn in self.names_provider.resolve_fully_qualified_name(
326+
&self.names_provider.resolve_qualified_name(&defined_name),
327+
) {
328+
if self.has_references_for_str(&fqn) {
329+
self.maybe_increase_stmt_selective(&stmt, |n| {
330+
n.ends_with(alias.name.as_str())
331+
});
332+
}
333+
}
334+
}
335+
336+
if stmt_import_from.names.len() == 1
337+
&& stmt_import_from.names[0].name.as_str() == "*"
289338
&& self.module_spec_has_references()
290339
{
291340
if let Ok(module_specs) = get_import_from_absolute_module_spec(
292-
&import_from,
341+
&stmt_import_from,
293342
&self.get_parent_package(),
294343
) {
295344
self.import_star_module_specs.extend(module_specs);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .hello_world import hello_world
2+
from .hello_world import inner_hello_world_alias
3+
4+
def main() -> None:
5+
hello_world()
6+
inner_hello_world_alias.hello_world()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from . import main
2+
3+
4+
if __name__ == "__main__":
5+
main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .inner_hello_world import hello_world, useless_func
2+
from .moin_world import moin_world
3+
import re_exports.hello_world.inner_hello_world as inner_hello_world_alias, re_exports.hello_world.moin_world as moin_world_alias # type: ignore[import-not-found]
4+
5+
__all__ = ["hello_world", "inner_hello_world_alias", "moin_world", "moin_world_alias", "useless_func"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def hello_world() -> None:
2+
print("Hello World!")
3+
4+
def useless_func() -> None:
5+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def moin_world() -> None:
2+
print("Moin Welt!")

tests/test_treeshake/test_treeshake_package.py

+36
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,39 @@ def test_treeshake_package_preserve_with_decorators(
7878
init_file_content = init_file.read_text()
7979
assert "@dataclass\nclass MyClass:\n pass" in init_file_content
8080
assert "@contextmanager\ndef my_context_manager() ->" in init_file_content
81+
82+
83+
def test_treeshake_package_re_exports(
84+
run_treeshake_package: RunTreeshakePackageT,
85+
) -> None:
86+
source_path = TEST_PACKAGES_DIR / "re_exports"
87+
result_path = run_treeshake_package(source_path)
88+
89+
hello_world_init = result_path / "hello_world" / "__init__.py"
90+
assert hello_world_init.exists()
91+
hello_world_init_content = hello_world_init.read_text()
92+
assert "from .inner_hello_world import hello_world" in hello_world_init_content
93+
assert (
94+
"import re_exports.hello_world.inner_hello_world as inner_hello_world_alias"
95+
in hello_world_init_content
96+
)
97+
assert (
98+
"re_exports.hello_world.moin_world as moin_world_alias"
99+
not in hello_world_init_content
100+
)
101+
assert "useless_func" not in hello_world_init_content
102+
assert "from .moin_world import moin_world" not in hello_world_init_content
103+
104+
assert not (result_path / "hello_world" / "moin_world").exists()
105+
106+
inner_hello_world_init = hello_world_init = (
107+
result_path / "hello_world" / "inner_hello_world" / "__init__.py"
108+
)
109+
assert inner_hello_world_init.exists()
110+
inner_hello_world_init_content = inner_hello_world_init.read_text()
111+
assert (
112+
"def hello_world() -> None:\n print('Hello World!')"
113+
in inner_hello_world_init_content
114+
)
115+
116+
assert "def useless_func() -> None:" not in inner_hello_world_init_content

0 commit comments

Comments
 (0)