Skip to content

Commit b647bfe

Browse files
authored
fix(bundle): transform annotation literals (#40)
1 parent be49f7e commit b647bfe

File tree

5 files changed

+74
-7
lines changed

5 files changed

+74
-7
lines changed

Cargo.lock

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ crate-type = ["cdylib"]
1212
pyo3 = { version = "0.23.4", features = ["extension-module"] }
1313
rustpython-ast = { version = "0.4.0", features = ["visitor"] }
1414
rustpython-parser = { version = "0.4.0" }
15-
rustpython-unparser = { version = "0.2.0", features = ["transformer"] }
15+
rustpython-unparser = { version = "0.2.1", features = ["transformer"] }

rust/src/bundle/imports_transformer.rs

+35-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use crate::common::module_spec::{get_top_level_package, is_in_std_lib};
22
use pyo3::pyfunction;
33
use rustpython_ast::{
4-
text_size::TextRange, Alias, Expr, ExprAttribute, ExprName, Identifier, StmtImport,
5-
StmtImportFrom, Suite,
4+
text_size::TextRange, Alias, Constant, Expr, ExprAttribute, ExprConstant, ExprName, Identifier,
5+
StmtImport, StmtImportFrom, Suite,
66
};
77
use rustpython_parser::Parse;
88
use rustpython_unparser::{transformer::Transformer, Unparser};
@@ -12,6 +12,7 @@ struct ImportsTransformer {
1212
top_level_package: String,
1313
vendor_module_name: String,
1414
affected_names: HashSet<String>,
15+
is_in_annotation: bool,
1516
}
1617

1718
impl ImportsTransformer {
@@ -20,6 +21,7 @@ impl ImportsTransformer {
2021
top_level_package,
2122
vendor_module_name,
2223
affected_names: HashSet::new(),
24+
is_in_annotation: false,
2325
}
2426
}
2527

@@ -51,7 +53,37 @@ impl ImportsTransformer {
5153
}
5254

5355
impl Transformer for ImportsTransformer {
54-
fn generic_visit_stmt_import(&mut self, stmt: StmtImport) -> Option<StmtImport> {
56+
fn on_enter_annotation(&mut self, _: &Expr) {
57+
self.is_in_annotation = true;
58+
}
59+
60+
fn on_exit_annotation(&mut self, _: &Option<Expr>) {
61+
self.is_in_annotation = false;
62+
}
63+
64+
fn visit_expr_constant(&mut self, mut expr: ExprConstant) -> Option<ExprConstant> {
65+
if self.is_in_annotation {
66+
match &expr.value {
67+
Constant::Str(str_value) => {
68+
if str_value.contains(".") {
69+
let name_parts: Vec<&str> = str_value.splitn(2, ".").collect();
70+
let module_part = name_parts[0];
71+
if self.affected_names.contains(module_part) {
72+
expr.value =
73+
Constant::Str(format!("{}.{}", self.get_vendor_string(), str_value))
74+
}
75+
}
76+
77+
Some(expr)
78+
}
79+
_ => self.generic_visit_expr_constant(expr),
80+
}
81+
} else {
82+
self.generic_visit_expr_constant(expr)
83+
}
84+
}
85+
86+
fn visit_stmt_import(&mut self, stmt: StmtImport) -> Option<StmtImport> {
5587
Some(StmtImport {
5688
names: stmt
5789
.names
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import typer
2+
import typer as typerino
3+
4+
random_literal = "typer.Typer"
5+
6+
def modify_app(app: "typer.Typer") -> "typer.Typer":
7+
return app
8+
9+
def modify_app2(app2: "typerino.Typer") -> "typerino.Typer":
10+
return app2

tests/test_bundle/test_bundle_package.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def test_bundle_relative_imports(run_bundle_package: RunBundlePackageT) -> None:
121121
assert (result_path / "module" / "hello_world.py").exists()
122122

123123

124-
def test_correct_transformer_recursion(run_bundle_package: RunBundlePackageT) -> None:
124+
def test_bundle_correct_transformer_recursion(
125+
run_bundle_package: RunBundlePackageT,
126+
) -> None:
125127
_, result_path = run_bundle_package(
126128
"correct_transformer_recursion", "correct_transformer_recursion"
127129
)
@@ -133,3 +135,26 @@ def test_correct_transformer_recursion(run_bundle_package: RunBundlePackageT) ->
133135
"correct_transformer_recursion._vendor.typer.style('Hello World!', fg=correct_transformer_recursion._vendor.typer.colors.BRIGHT_MAGENTA)"
134136
in init_file_content
135137
)
138+
139+
140+
def test_bundle_annotation_string_literals(
141+
run_bundle_package: RunBundlePackageT,
142+
) -> None:
143+
_, result_path = run_bundle_package(
144+
"annotation_string_literals", "annotation_string_literals"
145+
)
146+
147+
init_file = result_path / "__init__.py"
148+
init_file_content = init_file.read_text()
149+
150+
assert "random_literal = 'typer.Typer'" in init_file_content
151+
152+
assert (
153+
"def modify_app(app: 'annotation_string_literals._vendor.typer.Typer') -> 'annotation_string_literals._vendor.typer.Typer':\n return app"
154+
in init_file_content
155+
)
156+
157+
assert (
158+
"def modify_app2(app2: 'typerino.Typer') -> 'typerino.Typer':\n return app2"
159+
in init_file_content
160+
)

0 commit comments

Comments
 (0)