@@ -260,53 +260,195 @@ def update_parameter_usages(
260260 ):
261261 """
262262 Updates the function body to use encapsulated parameter objects.
263+ This method transforms parameter references in the function body to use new data_params
264+ and config_params objects.
265+
266+ Args:
267+ function_node: CST node of the function to transform
268+ classified_params: Dictionary mapping parameter groups ('data_params' or 'config_params')
269+ to lists of parameter names in each group
270+
271+ Returns:
272+ The transformed function node with updated parameter usages
263273 """
274+ # Create a module with just the function to get metadata
275+ module = cst .Module (body = [function_node ])
276+ wrapper = MetadataWrapper (module )
264277
265278 class ParameterUsageTransformer (cst .CSTTransformer ):
266- def __init__ (self , classified_params : dict [str , list [str ]]):
267- self .param_to_group = {}
279+ """
280+ A CST transformer that updates parameter references to use the new parameter objects.
281+ """
268282
283+ METADATA_DEPENDENCIES = (ParentNodeProvider ,)
284+
285+ def __init__ (
286+ self , classified_params : dict [str , list [str ]], metadata_wrapper : MetadataWrapper
287+ ):
288+ super ().__init__ ()
289+ # map each parameter to its group (data_params or config_params)
290+ self .param_to_group = {}
291+ self .parent_provider = metadata_wrapper .resolve (ParentNodeProvider )
269292 # flatten classified_params to map each param to its group (dataParams or configParams)
270293 for group , params in classified_params .items ():
271294 for param in params :
272295 self .param_to_group [param ] = group
273296
274- def leave_Assign (
275- self ,
276- original_node : cst .Assign , # noqa: ARG002
277- updated_node : cst .Assign ,
278- ) -> cst .Assign :
297+ def is_in_assignment_target (self , node : cst .CSTNode ) -> bool :
279298 """
280- Transform only right-hand side references to parameters that need to be updated.
281- Ensure left-hand side (self attributes) remain unchanged.
299+ Check if a node is part of an assignment target (left side of =).
300+
301+ Args:
302+ node: The CST node to check
303+
304+ Returns:
305+ True if the node is part of an assignment target that should not be transformed,
306+ False otherwise
307+ """
308+ current = node
309+ while current :
310+ parent = self .parent_provider .get (current )
311+
312+ # if we're at an AssignTarget, check if it's a simple Name assignment
313+ if isinstance (parent , cst .AssignTarget ):
314+ if isinstance (current , cst .Name ):
315+ # allow transformation for simple parameter assignments
316+ return False
317+ return True
318+
319+ if isinstance (parent , cst .Assign ):
320+ # if we reach an Assign node, check if we came from the targets
321+ for target in parent .targets :
322+ if target .target .deep_equals (current ):
323+ if isinstance (current , cst .Name ):
324+ # allow transformation for simple parameter assignments
325+ return False
326+ return True
327+ return False
328+
329+ if isinstance (parent , cst .Module ):
330+ return False
331+
332+ current = parent
333+ return False
334+
335+ def leave_Name (
336+ self , original_node : cst .Name , updated_node : cst .Name
337+ ) -> cst .BaseExpression :
282338 """
283- if not isinstance (updated_node .value , cst .Name ):
339+ Transform standalone parameter references.
340+
341+ Skip transformation if:
342+ 1. The name is part of an attribute access (eg: self.param)
343+ 2. The name is part of a complex assignment target (eg: self.x = y)
344+
345+ Transform if:
346+ 1. The name is a simple parameter being assigned (eg: param1 = value)
347+ 2. The name is used as a value (eg: x = param1)
348+
349+ Args:
350+ original_node: The original Name node
351+ updated_node: The current state of the Name node
352+
353+ Returns:
354+ The transformed node or the original if no transformation is needed
355+ """
356+ # dont't transform if this is part of a complex assignment target
357+ if self .is_in_assignment_target (original_node ):
284358 return updated_node
285359
286- var_name = updated_node .value .value
360+ # dont't transform if this is part of an attribute access (e.g., self.param)
361+ parent = self .parent_provider .get (original_node )
362+ if isinstance (parent , cst .Attribute ) and original_node is parent .attr :
363+ return updated_node
287364
288- if var_name in self .param_to_group :
289- new_value = cst .Attribute (
290- value = cst .Name (self .param_to_group [var_name ]), attr = cst .Name (var_name )
365+ name_value = updated_node .value
366+ if name_value in self .param_to_group :
367+ # transform the name into an attribute access on the appropriate parameter object
368+ return cst .Attribute (
369+ value = cst .Name (self .param_to_group [name_value ]), attr = cst .Name (name_value )
370+ )
371+ return updated_node
372+
373+ def leave_Attribute (
374+ self , original_node : cst .Attribute , updated_node : cst .Attribute
375+ ) -> cst .BaseExpression :
376+ """
377+ Handle method calls and attribute access on parameters.
378+ This method handles several cases:
379+
380+ 1. Assignment targets (eg: self.x = y)
381+ 2. Simple attribute access (eg: self.x or report.x)
382+ 3. Nested attribute access (eg: data_params.user_id)
383+ 4. Subscript access (eg: self.settings["timezone"])
384+ 5. Parameter attribute access (eg: username.strip())
385+
386+ Args:
387+ original_node: The original Attribute node
388+ updated_node: The current state of the Attribute node
389+
390+ Returns:
391+ The transformed node or the original if no transformation is needed
392+ """
393+ # don't transform if this is part of an assignment target
394+ if self .is_in_assignment_target (original_node ):
395+ # if this is a simple attribute access (eg: self.x or report.x), don't transform it
396+ if isinstance (updated_node .value , cst .Name ) and updated_node .value .value in {
397+ "self" ,
398+ "report" ,
399+ }:
400+ return original_node
401+ return updated_node
402+
403+ # if this is a nested attribute access (eg: data_params.user_id), don't transform it further
404+ if (
405+ isinstance (updated_node .value , cst .Attribute )
406+ and isinstance (updated_node .value .value , cst .Name )
407+ and updated_node .value .value .value in {"data_params" , "config_params" }
408+ ):
409+ return updated_node
410+
411+ # if this is a simple attribute access (eg: self.x or report.x), don't transform it
412+ if isinstance (updated_node .value , cst .Name ) and updated_node .value .value in {
413+ "self" ,
414+ "report" ,
415+ }:
416+ # check if this is part of a subscript target (eg: self.settings["timezone"])
417+ parent = self .parent_provider .get (original_node )
418+ if isinstance (parent , cst .Subscript ):
419+ return original_node
420+ # check if this is part of a subscript value
421+ if isinstance (parent , cst .SubscriptElement ):
422+ return original_node
423+ return original_node
424+
425+ # if the attribute's value is a parameter name, update it to use the encapsulated parameter object
426+ if (
427+ isinstance (updated_node .value , cst .Name )
428+ and updated_node .value .value in self .param_to_group
429+ ):
430+ param_name = updated_node .value .value
431+ return cst .Attribute (
432+ value = cst .Name (self .param_to_group [param_name ]), attr = updated_node .attr
291433 )
292- return updated_node .with_changes (value = new_value )
293434
294435 return updated_node
295436
296- # wrap CST node in a MetadataWrapper to enable metadata analysis
297- transformer = ParameterUsageTransformer (classified_params )
298- return function_node .visit (transformer )
437+ # create transformer with metadata wrapper
438+ transformer = ParameterUsageTransformer (classified_params , wrapper )
439+ # transform the function body
440+ updated_module = module .visit (transformer )
441+ # return the transformed function
442+ return updated_module .body [0 ]
299443
300444 @staticmethod
301445 def get_enclosing_class_name (
302- tree : cst .Module , # noqa: ARG004
303446 init_node : cst .FunctionDef ,
304447 parent_metadata : Mapping [cst .CSTNode , cst .CSTNode ],
305448 ) -> Optional [str ]:
306449 """
307450 Finds the class name enclosing the given __init__ function node.
308451 """
309- # wrapper = MetadataWrapper(tree)
310452 current_node = init_node
311453 while current_node in parent_metadata :
312454 parent = parent_metadata [current_node ]
@@ -324,15 +466,7 @@ def update_function_calls(
324466 classified_param_names : tuple [str , str ],
325467 enclosing_class_name : str ,
326468 ) -> cst .Module :
327- """
328- Updates all calls to a given function in the provided CST tree to reflect new encapsulated parameters
329- :param tree: CST tree of the code.
330- :param function_node: CST node of the function to update calls for.
331- :param params: A dictionary containing 'data' and 'config' parameters.
332- :return: The updated CST tree
333- """
334469 param_to_group = {}
335-
336470 for group_name , params in zip (classified_param_names , classified_params .values ()):
337471 for param in params :
338472 param_to_group [param ] = group_name
@@ -341,6 +475,15 @@ def update_function_calls(
341475 if function_name == "__init__" :
342476 function_name = enclosing_class_name
343477
478+ # Get all parameter names from the function definition
479+ all_param_names = [p .name .value for p in function_node .params .params ]
480+ # Find where variadic args start (if any)
481+ variadic_start = len (all_param_names )
482+ for i , param in enumerate (function_node .params .params ):
483+ if param .star == "*" or param .star == "**" :
484+ variadic_start = i
485+ break
486+
344487 class FunctionCallTransformer (cst .CSTTransformer ):
345488 def leave_Call (self , original_node : cst .Call , updated_node : cst .Call ) -> cst .Call : # noqa: ARG002
346489 """Transforms function calls to use grouped parameters."""
@@ -361,13 +504,27 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
361504
362505 positional_args = []
363506 keyword_args = {}
364-
365- # Separate positional and keyword arguments
366- for arg in updated_node .args :
367- if arg .keyword is None :
368- positional_args .append (arg .value )
369- else :
370- keyword_args [arg .keyword .value ] = arg .value
507+ variadic_args = []
508+ variadic_kwargs = {}
509+
510+ # Separate positional, keyword, and variadic arguments
511+ for i , arg in enumerate (updated_node .args ):
512+ if isinstance (arg , cst .Arg ):
513+ if arg .keyword is None :
514+ # If this is a positional argument beyond the number of parameters,
515+ # it's a variadic arg
516+ if i >= variadic_start :
517+ variadic_args .append (arg .value )
518+ elif i < len (used_params ):
519+ positional_args .append (arg .value )
520+ else :
521+ # If this is a keyword argument for a used parameter, keep it
522+ if arg .keyword .value in param_to_group :
523+ keyword_args [arg .keyword .value ] = arg .value
524+ # If this is a keyword argument not in the original parameters,
525+ # it's a variadic kwarg
526+ elif arg .keyword .value not in all_param_names :
527+ variadic_kwargs [arg .keyword .value ] = arg .value
371528
372529 # Group arguments based on classified_params
373530 grouped_args = {group : [] for group in classified_param_names }
@@ -397,6 +554,94 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal
397554 if grouped_args [group_name ] # Skip empty groups
398555 ]
399556
557+ # Add variadic positional arguments
558+ new_args .extend ([cst .Arg (value = arg ) for arg in variadic_args ])
559+
560+ # Add variadic keyword arguments
561+ new_args .extend (
562+ [
563+ cst .Arg (keyword = cst .Name (key ), value = value )
564+ for key , value in variadic_kwargs .items ()
565+ ]
566+ )
567+
568+ return updated_node .with_changes (args = new_args )
569+
570+ transformer = FunctionCallTransformer ()
571+ return tree .visit (transformer )
572+
573+ @staticmethod
574+ def update_function_calls_unclassified (
575+ tree : cst .Module ,
576+ function_node : cst .FunctionDef ,
577+ used_params : list [str ],
578+ enclosing_class_name : str ,
579+ ) -> cst .Module :
580+ """
581+ Updates all calls to a given function to only include used parameters.
582+ This is used when parameters are removed without being classified into objects.
583+
584+ Args:
585+ tree: CST tree of the code
586+ function_node: CST node of the function to update calls for
587+ used_params: List of parameter names that are actually used in the function
588+ enclosing_class_name: Name of the enclosing class if this is a method
589+
590+ Returns:
591+ Updated CST tree with modified function calls
592+ """
593+ function_name = function_node .name .value
594+ if function_name == "__init__" :
595+ function_name = enclosing_class_name
596+
597+ class FunctionCallTransformer (cst .CSTTransformer ):
598+ def leave_Call (self , original_node : cst .Call , updated_node : cst .Call ) -> cst .Call : # noqa: ARG002
599+ """Transforms function calls to only include used parameters."""
600+ # handle both standalone function calls and instance method calls
601+ if not isinstance (updated_node .func , (cst .Name , cst .Attribute )):
602+ return updated_node
603+
604+ # extract the function/method name
605+ func_name = (
606+ updated_node .func .attr .value
607+ if isinstance (updated_node .func , cst .Attribute )
608+ else updated_node .func .value
609+ )
610+
611+ # if not the target function, leave unchanged
612+ if func_name != function_name :
613+ return updated_node
614+
615+ # map original parameters to their positions
616+ param_positions = {
617+ param .name .value : i for i , param in enumerate (function_node .params .params )
618+ }
619+
620+ # keep track of which positions in the argument list correspond to used parameters
621+ used_positions = {i for param , i in param_positions .items () if param in used_params }
622+
623+ new_args = []
624+ pos_arg_count = 0
625+
626+ # process all arguments
627+ for arg in updated_node .args :
628+ if arg .keyword is None :
629+ # handle positional arguments
630+ if pos_arg_count in used_positions :
631+ new_args .append (arg )
632+ pos_arg_count += 1
633+ else :
634+ # handle keyword arguments
635+ if arg .keyword .value in used_params :
636+ # keep keyword arguments for used parameters
637+ new_args .append (arg )
638+
639+ # ensure the last argument does not have a trailing comma
640+ if new_args :
641+ final_args = new_args [:- 1 ]
642+ final_args .append (new_args [- 1 ].with_changes (comma = cst .MaybeSentinel .DEFAULT ))
643+ new_args = final_args
644+
400645 return updated_node .with_changes (args = new_args )
401646
402647 transformer = FunctionCallTransformer ()
@@ -499,7 +744,7 @@ def refactor(
499744 self .is_constructor = self .function_node .name .value == "__init__"
500745 if self .is_constructor :
501746 self .enclosing_class_name = FunctionCallUpdater .get_enclosing_class_name (
502- tree , self .function_node , parent_metadata
747+ self .function_node , parent_metadata
503748 )
504749 param_names = [
505750 param .name .value
@@ -562,6 +807,11 @@ def refactor(
562807 self .function_node , self .used_params , default_value_params
563808 )
564809
810+ # update all calls to match the new signature
811+ tree = self .function_updater .update_function_calls_unclassified (
812+ tree , self .function_node , self .used_params , self .enclosing_class_name
813+ )
814+
565815 class FunctionReplacer (cst .CSTTransformer ):
566816 def __init__ (
567817 self , original_function : cst .FunctionDef , updated_function : cst .FunctionDef
0 commit comments