Skip to content

Commit a05205a

Browse files
committed
Update generate_label_transfer_dict function by adding default_negative_value parameter.
1 parent f2c088a commit a05205a

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

spateo/alignment/methods/utils.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,8 @@ def generate_label_transfer_dict(
378378
cat2: List[str],
379379
positive_pairs: Optional[List[Dict[str, Union[List[str], float]]]] = None,
380380
negative_pairs: Optional[List[Dict[str, Union[List[str], float]]]] = None,
381-
default_positve_value: float = 10.0,
381+
default_positive_value: float = 10.0,
382+
default_negative_value: float = 1.0,
382383
) -> Dict[str, Dict[str, float]]:
383384
"""
384385
Generate a label transfer dictionary with normalized values.
@@ -394,20 +395,22 @@ def generate_label_transfer_dict(
394395
List of negative pairs with transfer values. Each dictionary should have 'left', 'right', and 'value' keys. Defaults to None.
395396
default_positive_value (float, optional):
396397
Default value for positive pairs if none are provided. Defaults to 10.0.
398+
default_negative_value (float, optional):
399+
Default value for negative pairs if none are provided. Defaults to 1.0.
397400
398401
Returns:
399402
Dict[str, Dict[str, float]]:
400403
A normalized label transfer dictionary.
401404
"""
402405

403406
# Initialize label transfer dictionary with default values
404-
# label_transfer_dict = {c2: {c1: 1.0 for c1 in cat1} for c2 in cat2}
405407
label_transfer_dict = {c1: {c2: 1.0 for c2 in cat2} for c1 in cat1}
406408

407409
# Generate default positive pairs if none provided
408410
if (positive_pairs is None) and (negative_pairs is None):
411+
label_transfer_dict = {c1: {c2: default_negative_value for c2 in cat2} for c1 in cat1}
409412
common_cat = np.union1d(cat1, cat2)
410-
positive_pairs = [{"left": [c], "right": [c], "value": default_positve_value} for c in common_cat]
413+
positive_pairs = [{"left": [c], "right": [c], "value": default_positive_value} for c in common_cat]
411414

412415
# Apply positive pairs to the dictionary
413416
if positive_pairs is not None:
@@ -431,11 +434,6 @@ def generate_label_transfer_dict(
431434
norm_c = np.array([label_transfer_dict[c1][c2] for c2 in cat2]).sum()
432435
norm_label_transfer_dict[c1] = {c2: label_transfer_dict[c1][c2] / norm_c for c2 in cat2}
433436

434-
# norm_label_transfer_dict = dict()
435-
# for c2 in cat2:
436-
# norm_c = np.array([label_transfer_dict[c2][c1] for c1 in cat1]).sum()
437-
# norm_label_transfer_dict[c2] = {c1: label_transfer_dict[c2][c1] / norm_c for c1 in cat1}
438-
439437
return norm_label_transfer_dict
440438

441439

0 commit comments

Comments
 (0)