updated group metrics calculation
AndreFCruz committed Sep 1, 2024
1 parent 27191b2 commit e78de48
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions folktexts/evaluation.py
@@ -81,7 +81,8 @@ def evaluate_binary_predictions_fairness(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     sensitive_attribute: np.ndarray,
-    return_groupwise_metrics: Optional[bool] = False,
+    return_groupwise_metrics: bool = False,
+    min_group_size: float = 0.04,
 ) -> dict:
     """Evaluates fairness of the given predictions.
@@ -96,9 +97,14 @@ def evaluate_binary_predictions_fairness(
         The discretized predictions.
     sensitive_attribute : np.ndarray
         The sensitive attribute (protected group membership).
-    return_groupwise_metrics : Optional[bool], optional
+    return_groupwise_metrics : bool, optional
         Whether to return group-wise performance metrics (bool: True) or only
         the ratios between these metrics (bool: False), by default False.
+    min_group_size : float, optional
+        The minimum size a group must have to be considered in the fairness
+        evaluation, as a fraction of the total number of samples; by default
+        0.04. This avoids evaluating metrics on very small groups, which
+        leads to noisy and inconsistent results.
 
     Returns
     -------
@@ -124,6 +130,10 @@ def group_metric_name(metric_name, group_name):
         # Indices of samples that belong to the current group
         group_indices = np.argwhere(sensitive_attribute == s_value).flatten()
 
+        if len(group_indices) < min_group_size * len(y_true):
+            logging.warning(f"Skipping group {s_value} with {len(group_indices)} samples")
+            continue
+
         # Filter labels and predictions for samples of the current group
         group_labels = y_true[group_indices]
         group_preds = y_pred[group_indices]
@@ -144,8 +154,9 @@ def group_metric_name(metric_name, group_name):
     # Compute ratios and absolute diffs
     for metric_name in unique_metrics:
         curr_metric_results = [
-            groupwise_metrics[group_metric_name(metric_name, group_name)]
+            groupwise_metrics[curr_group_metric_name]
             for group_name in unique_groups
+            if (curr_group_metric_name := group_metric_name(metric_name, group_name)) in groupwise_metrics
         ]
 
         # Metrics' ratio
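For context, here is a minimal usage sketch of the updated function. This is an illustration, not part of the commit: the import path is inferred from the changed file (folktexts/evaluation.py), the call follows the new signature shown above, and the exact keys of the returned dict are not visible in this diff.

import logging
import numpy as np

from folktexts.evaluation import evaluate_binary_predictions_fairness

logging.basicConfig(level=logging.WARNING)

rng = np.random.default_rng(42)
n = 1_000

y_true = rng.integers(0, 2, size=n)
y_pred = rng.integers(0, 2, size=n)

# Three groups; group "C" holds ~2% of samples. With the default
# min_group_size=0.04 and n=1000, any group under 40 samples is skipped.
sensitive_attribute = rng.choice(["A", "B", "C"], size=n, p=[0.60, 0.38, 0.02])

# Group "C" triggers the new logging warning and is left out, so metric
# ratios and diffs are computed only between groups "A" and "B".
results = evaluate_binary_predictions_fairness(
    y_true,
    y_pred,
    sensitive_attribute,
    min_group_size=0.04,
)
print(results)

The walrus-operator filter in the last hunk is what makes this work: metric names for skipped groups are simply absent from groupwise_metrics, so the ratio computation filters them out instead of raising a KeyError.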
