Skip to content

Commit c4a01b7

Browse files
authored
[BUG] Fixed LPL Refactorer (#521)
closes #387
1 parent b7c5853 commit c4a01b7

File tree

2 files changed

+973
-46
lines changed

2 files changed

+973
-46
lines changed

src/ecooptimizer/refactorers/concrete/long_parameter_list.py

Lines changed: 286 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -260,53 +260,195 @@ def update_parameter_usages(
260260
):
261261
"""
262262
Updates the function body to use encapsulated parameter objects.
263+
This method transforms parameter references in the function body to use new data_params
264+
and config_params objects.
265+
266+
Args:
267+
function_node: CST node of the function to transform
268+
classified_params: Dictionary mapping parameter groups ('data_params' or 'config_params')
269+
to lists of parameter names in each group
270+
271+
Returns:
272+
The transformed function node with updated parameter usages
263273
"""
274+
# Create a module with just the function to get metadata
275+
module = cst.Module(body=[function_node])
276+
wrapper = MetadataWrapper(module)
264277

265278
class ParameterUsageTransformer(cst.CSTTransformer):
266-
def __init__(self, classified_params: dict[str, list[str]]):
267-
self.param_to_group = {}
279+
"""
280+
A CST transformer that updates parameter references to use the new parameter objects.
281+
"""
268282

283+
METADATA_DEPENDENCIES = (ParentNodeProvider,)
284+
285+
def __init__(
286+
self, classified_params: dict[str, list[str]], metadata_wrapper: MetadataWrapper
287+
):
288+
super().__init__()
289+
# map each parameter to its group (data_params or config_params)
290+
self.param_to_group = {}
291+
self.parent_provider = metadata_wrapper.resolve(ParentNodeProvider)
269292
# flatten classified_params to map each param to its group (dataParams or configParams)
270293
for group, params in classified_params.items():
271294
for param in params:
272295
self.param_to_group[param] = group
273296

274-
def leave_Assign(
275-
self,
276-
original_node: cst.Assign, # noqa: ARG002
277-
updated_node: cst.Assign,
278-
) -> cst.Assign:
297+
def is_in_assignment_target(self, node: cst.CSTNode) -> bool:
279298
"""
280-
Transform only right-hand side references to parameters that need to be updated.
281-
Ensure left-hand side (self attributes) remain unchanged.
299+
Check if a node is part of an assignment target (left side of =).
300+
301+
Args:
302+
node: The CST node to check
303+
304+
Returns:
305+
True if the node is part of an assignment target that should not be transformed,
306+
False otherwise
307+
"""
308+
current = node
309+
while current:
310+
parent = self.parent_provider.get(current)
311+
312+
# if we're at an AssignTarget, check if it's a simple Name assignment
313+
if isinstance(parent, cst.AssignTarget):
314+
if isinstance(current, cst.Name):
315+
# allow transformation for simple parameter assignments
316+
return False
317+
return True
318+
319+
if isinstance(parent, cst.Assign):
320+
# if we reach an Assign node, check if we came from the targets
321+
for target in parent.targets:
322+
if target.target.deep_equals(current):
323+
if isinstance(current, cst.Name):
324+
# allow transformation for simple parameter assignments
325+
return False
326+
return True
327+
return False
328+
329+
if isinstance(parent, cst.Module):
330+
return False
331+
332+
current = parent
333+
return False
334+
335+
def leave_Name(
336+
self, original_node: cst.Name, updated_node: cst.Name
337+
) -> cst.BaseExpression:
282338
"""
283-
if not isinstance(updated_node.value, cst.Name):
339+
Transform standalone parameter references.
340+
341+
Skip transformation if:
342+
1. The name is part of an attribute access (eg: self.param)
343+
2. The name is part of a complex assignment target (eg: self.x = y)
344+
345+
Transform if:
346+
1. The name is a simple parameter being assigned (eg: param1 = value)
347+
2. The name is used as a value (eg: x = param1)
348+
349+
Args:
350+
original_node: The original Name node
351+
updated_node: The current state of the Name node
352+
353+
Returns:
354+
The transformed node or the original if no transformation is needed
355+
"""
356+
# dont't transform if this is part of a complex assignment target
357+
if self.is_in_assignment_target(original_node):
284358
return updated_node
285359

286-
var_name = updated_node.value.value
360+
# dont't transform if this is part of an attribute access (e.g., self.param)
361+
parent = self.parent_provider.get(original_node)
362+
if isinstance(parent, cst.Attribute) and original_node is parent.attr:
363+
return updated_node
287364

288-
if var_name in self.param_to_group:
289-
new_value = cst.Attribute(
290-
value=cst.Name(self.param_to_group[var_name]), attr=cst.Name(var_name)
365+
name_value = updated_node.value
366+
if name_value in self.param_to_group:
367+
# transform the name into an attribute access on the appropriate parameter object
368+
return cst.Attribute(
369+
value=cst.Name(self.param_to_group[name_value]), attr=cst.Name(name_value)
370+
)
371+
return updated_node
372+
373+
def leave_Attribute(
374+
self, original_node: cst.Attribute, updated_node: cst.Attribute
375+
) -> cst.BaseExpression:
376+
"""
377+
Handle method calls and attribute access on parameters.
378+
This method handles several cases:
379+
380+
1. Assignment targets (eg: self.x = y)
381+
2. Simple attribute access (eg: self.x or report.x)
382+
3. Nested attribute access (eg: data_params.user_id)
383+
4. Subscript access (eg: self.settings["timezone"])
384+
5. Parameter attribute access (eg: username.strip())
385+
386+
Args:
387+
original_node: The original Attribute node
388+
updated_node: The current state of the Attribute node
389+
390+
Returns:
391+
The transformed node or the original if no transformation is needed
392+
"""
393+
# don't transform if this is part of an assignment target
394+
if self.is_in_assignment_target(original_node):
395+
# if this is a simple attribute access (eg: self.x or report.x), don't transform it
396+
if isinstance(updated_node.value, cst.Name) and updated_node.value.value in {
397+
"self",
398+
"report",
399+
}:
400+
return original_node
401+
return updated_node
402+
403+
# if this is a nested attribute access (eg: data_params.user_id), don't transform it further
404+
if (
405+
isinstance(updated_node.value, cst.Attribute)
406+
and isinstance(updated_node.value.value, cst.Name)
407+
and updated_node.value.value.value in {"data_params", "config_params"}
408+
):
409+
return updated_node
410+
411+
# if this is a simple attribute access (eg: self.x or report.x), don't transform it
412+
if isinstance(updated_node.value, cst.Name) and updated_node.value.value in {
413+
"self",
414+
"report",
415+
}:
416+
# check if this is part of a subscript target (eg: self.settings["timezone"])
417+
parent = self.parent_provider.get(original_node)
418+
if isinstance(parent, cst.Subscript):
419+
return original_node
420+
# check if this is part of a subscript value
421+
if isinstance(parent, cst.SubscriptElement):
422+
return original_node
423+
return original_node
424+
425+
# if the attribute's value is a parameter name, update it to use the encapsulated parameter object
426+
if (
427+
isinstance(updated_node.value, cst.Name)
428+
and updated_node.value.value in self.param_to_group
429+
):
430+
param_name = updated_node.value.value
431+
return cst.Attribute(
432+
value=cst.Name(self.param_to_group[param_name]), attr=updated_node.attr
291433
)
292-
return updated_node.with_changes(value=new_value)
293434

294435
return updated_node
295436

296-
# wrap CST node in a MetadataWrapper to enable metadata analysis
297-
transformer = ParameterUsageTransformer(classified_params)
298-
return function_node.visit(transformer)
437+
# create transformer with metadata wrapper
438+
transformer = ParameterUsageTransformer(classified_params, wrapper)
439+
# transform the function body
440+
updated_module = module.visit(transformer)
441+
# return the transformed function
442+
return updated_module.body[0]
299443

300444
@staticmethod
301445
def get_enclosing_class_name(
302-
tree: cst.Module, # noqa: ARG004
303446
init_node: cst.FunctionDef,
304447
parent_metadata: Mapping[cst.CSTNode, cst.CSTNode],
305448
) -> Optional[str]:
306449
"""
307450
Finds the class name enclosing the given __init__ function node.
308451
"""
309-
# wrapper = MetadataWrapper(tree)
310452
current_node = init_node
311453
while current_node in parent_metadata:
312454
parent = parent_metadata[current_node]
@@ -324,15 +466,7 @@ def update_function_calls(
324466
classified_param_names: tuple[str, str],
325467
enclosing_class_name: str,
326468
) -> cst.Module:
327-
"""
328-
Updates all calls to a given function in the provided CST tree to reflect new encapsulated parameters
329-
:param tree: CST tree of the code.
330-
:param function_node: CST node of the function to update calls for.
331-
:param params: A dictionary containing 'data' and 'config' parameters.
332-
:return: The updated CST tree
333-
"""
334469
param_to_group = {}
335-
336470
for group_name, params in zip(classified_param_names, classified_params.values()):
337471
for param in params:
338472
param_to_group[param] = group_name
@@ -341,6 +475,15 @@ def update_function_calls(
341475
if function_name == "__init__":
342476
function_name = enclosing_class_name
343477

478+
# Get all parameter names from the function definition
479+
all_param_names = [p.name.value for p in function_node.params.params]
480+
# Find where variadic args start (if any)
481+
variadic_start = len(all_param_names)
482+
for i, param in enumerate(function_node.params.params):
483+
if param.star == "*" or param.star == "**":
484+
variadic_start = i
485+
break
486+
344487
class FunctionCallTransformer(cst.CSTTransformer):
345488
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002
346489
"""Transforms function calls to use grouped parameters."""
@@ -361,13 +504,27 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
361504

362505
positional_args = []
363506
keyword_args = {}
364-
365-
# Separate positional and keyword arguments
366-
for arg in updated_node.args:
367-
if arg.keyword is None:
368-
positional_args.append(arg.value)
369-
else:
370-
keyword_args[arg.keyword.value] = arg.value
507+
variadic_args = []
508+
variadic_kwargs = {}
509+
510+
# Separate positional, keyword, and variadic arguments
511+
for i, arg in enumerate(updated_node.args):
512+
if isinstance(arg, cst.Arg):
513+
if arg.keyword is None:
514+
# If this is a positional argument beyond the number of parameters,
515+
# it's a variadic arg
516+
if i >= variadic_start:
517+
variadic_args.append(arg.value)
518+
elif i < len(used_params):
519+
positional_args.append(arg.value)
520+
else:
521+
# If this is a keyword argument for a used parameter, keep it
522+
if arg.keyword.value in param_to_group:
523+
keyword_args[arg.keyword.value] = arg.value
524+
# If this is a keyword argument not in the original parameters,
525+
# it's a variadic kwarg
526+
elif arg.keyword.value not in all_param_names:
527+
variadic_kwargs[arg.keyword.value] = arg.value
371528

372529
# Group arguments based on classified_params
373530
grouped_args = {group: [] for group in classified_param_names}
@@ -397,6 +554,94 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
397554
if grouped_args[group_name] # Skip empty groups
398555
]
399556

557+
# Add variadic positional arguments
558+
new_args.extend([cst.Arg(value=arg) for arg in variadic_args])
559+
560+
# Add variadic keyword arguments
561+
new_args.extend(
562+
[
563+
cst.Arg(keyword=cst.Name(key), value=value)
564+
for key, value in variadic_kwargs.items()
565+
]
566+
)
567+
568+
return updated_node.with_changes(args=new_args)
569+
570+
transformer = FunctionCallTransformer()
571+
return tree.visit(transformer)
572+
573+
@staticmethod
574+
def update_function_calls_unclassified(
575+
tree: cst.Module,
576+
function_node: cst.FunctionDef,
577+
used_params: list[str],
578+
enclosing_class_name: str,
579+
) -> cst.Module:
580+
"""
581+
Updates all calls to a given function to only include used parameters.
582+
This is used when parameters are removed without being classified into objects.
583+
584+
Args:
585+
tree: CST tree of the code
586+
function_node: CST node of the function to update calls for
587+
used_params: List of parameter names that are actually used in the function
588+
enclosing_class_name: Name of the enclosing class if this is a method
589+
590+
Returns:
591+
Updated CST tree with modified function calls
592+
"""
593+
function_name = function_node.name.value
594+
if function_name == "__init__":
595+
function_name = enclosing_class_name
596+
597+
class FunctionCallTransformer(cst.CSTTransformer):
598+
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call: # noqa: ARG002
599+
"""Transforms function calls to only include used parameters."""
600+
# handle both standalone function calls and instance method calls
601+
if not isinstance(updated_node.func, (cst.Name, cst.Attribute)):
602+
return updated_node
603+
604+
# extract the function/method name
605+
func_name = (
606+
updated_node.func.attr.value
607+
if isinstance(updated_node.func, cst.Attribute)
608+
else updated_node.func.value
609+
)
610+
611+
# if not the target function, leave unchanged
612+
if func_name != function_name:
613+
return updated_node
614+
615+
# map original parameters to their positions
616+
param_positions = {
617+
param.name.value: i for i, param in enumerate(function_node.params.params)
618+
}
619+
620+
# keep track of which positions in the argument list correspond to used parameters
621+
used_positions = {i for param, i in param_positions.items() if param in used_params}
622+
623+
new_args = []
624+
pos_arg_count = 0
625+
626+
# process all arguments
627+
for arg in updated_node.args:
628+
if arg.keyword is None:
629+
# handle positional arguments
630+
if pos_arg_count in used_positions:
631+
new_args.append(arg)
632+
pos_arg_count += 1
633+
else:
634+
# handle keyword arguments
635+
if arg.keyword.value in used_params:
636+
# keep keyword arguments for used parameters
637+
new_args.append(arg)
638+
639+
# ensure the last argument does not have a trailing comma
640+
if new_args:
641+
final_args = new_args[:-1]
642+
final_args.append(new_args[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT))
643+
new_args = final_args
644+
400645
return updated_node.with_changes(args=new_args)
401646

402647
transformer = FunctionCallTransformer()
@@ -499,7 +744,7 @@ def refactor(
499744
self.is_constructor = self.function_node.name.value == "__init__"
500745
if self.is_constructor:
501746
self.enclosing_class_name = FunctionCallUpdater.get_enclosing_class_name(
502-
tree, self.function_node, parent_metadata
747+
self.function_node, parent_metadata
503748
)
504749
param_names = [
505750
param.name.value
@@ -562,6 +807,11 @@ def refactor(
562807
self.function_node, self.used_params, default_value_params
563808
)
564809

810+
# update all calls to match the new signature
811+
tree = self.function_updater.update_function_calls_unclassified(
812+
tree, self.function_node, self.used_params, self.enclosing_class_name
813+
)
814+
565815
class FunctionReplacer(cst.CSTTransformer):
566816
def __init__(
567817
self, original_function: cst.FunctionDef, updated_function: cst.FunctionDef

0 commit comments

Comments
 (0)