@@ -378,7 +378,8 @@ def generate_label_transfer_dict(
378378 cat2 : List [str ],
379379 positive_pairs : Optional [List [Dict [str , Union [List [str ], float ]]]] = None ,
380380 negative_pairs : Optional [List [Dict [str , Union [List [str ], float ]]]] = None ,
381- default_positve_value : float = 10.0 ,
381+ default_positive_value : float = 10.0 ,
382+ default_negative_value : float = 1.0 ,
382383) -> Dict [str , Dict [str , float ]]:
383384 """
384385 Generate a label transfer dictionary with normalized values.
@@ -394,20 +395,22 @@ def generate_label_transfer_dict(
394395 List of negative pairs with transfer values. Each dictionary should have 'left', 'right', and 'value' keys. Defaults to None.
395396 default_positive_value (float, optional):
396397 Default value for positive pairs if none are provided. Defaults to 10.0.
398+ default_negative_value (float, optional):
399+ Default value for negative pairs if none are provided. Defaults to 1.0.
397400
398401 Returns:
399402 Dict[str, Dict[str, float]]:
400403 A normalized label transfer dictionary.
401404 """
402405
403406 # Initialize label transfer dictionary with default values
404- # label_transfer_dict = {c2: {c1: 1.0 for c1 in cat1} for c2 in cat2}
405407 label_transfer_dict = {c1 : {c2 : 1.0 for c2 in cat2 } for c1 in cat1 }
406408
407409 # Generate default positive pairs if none provided
408410 if (positive_pairs is None ) and (negative_pairs is None ):
411+ label_transfer_dict = {c1 : {c2 : default_negative_value for c2 in cat2 } for c1 in cat1 }
409412 common_cat = np .union1d (cat1 , cat2 )
410- positive_pairs = [{"left" : [c ], "right" : [c ], "value" : default_positve_value } for c in common_cat ]
413+ positive_pairs = [{"left" : [c ], "right" : [c ], "value" : default_positive_value } for c in common_cat ]
411414
412415 # Apply positive pairs to the dictionary
413416 if positive_pairs is not None :
@@ -431,11 +434,6 @@ def generate_label_transfer_dict(
431434 norm_c = np .array ([label_transfer_dict [c1 ][c2 ] for c2 in cat2 ]).sum ()
432435 norm_label_transfer_dict [c1 ] = {c2 : label_transfer_dict [c1 ][c2 ] / norm_c for c2 in cat2 }
433436
434- # norm_label_transfer_dict = dict()
435- # for c2 in cat2:
436- # norm_c = np.array([label_transfer_dict[c2][c1] for c1 in cat1]).sum()
437- # norm_label_transfer_dict[c2] = {c1: label_transfer_dict[c2][c1] / norm_c for c1 in cat1}
438-
439437 return norm_label_transfer_dict
440438
441439
0 commit comments