diff --git a/folktexts/evaluation.py b/folktexts/evaluation.py
index baa953b..1abb088 100644
--- a/folktexts/evaluation.py
+++ b/folktexts/evaluation.py
@@ -81,7 +81,8 @@ def evaluate_binary_predictions_fairness(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     sensitive_attribute: np.ndarray,
-    return_groupwise_metrics: Optional[bool] = False,
+    return_groupwise_metrics: bool = False,
+    min_group_size: float = 0.04,
 ) -> dict:
     """Evaluates fairness of the given predictions.
 
@@ -96,9 +97,14 @@ def evaluate_binary_predictions_fairness(
         The discretized predictions.
     sensitive_attribute : np.ndarray
         The sensitive attribute (protected group membership).
-    return_groupwise_metrics : Optional[bool], optional
+    return_groupwise_metrics : bool, optional
         Whether to return group-wise performance metrics (bool: True) or
         only the ratios between these metrics (bool: False), by default False.
+    min_group_size : float, optional
+        The minimum size a group must have, as a fraction of the total number
+        of samples, to be considered in the fairness evaluation; by default
+        0.04. This avoids evaluating metrics on very small groups, which
+        leads to noisy and inconsistent results.
 
     Returns
     -------
@@ -124,6 +130,10 @@ def group_metric_name(metric_name, group_name):
         # Indices of samples that belong to the current group
         group_indices = np.argwhere(sensitive_attribute == s_value).flatten()
 
+        if len(group_indices) < min_group_size * len(y_true):
+            logging.warning(f"Skipping group {s_value} with {len(group_indices)} samples")
+            continue
+
         # Filter labels and predictions for samples of the current group
         group_labels = y_true[group_indices]
         group_preds = y_pred[group_indices]
@@ -144,8 +154,9 @@ def group_metric_name(metric_name, group_name):
     # Compute ratios and absolute diffs
    for metric_name in unique_metrics:
         curr_metric_results = [
-            groupwise_metrics[group_metric_name(metric_name, group_name)]
+            groupwise_metrics[curr_group_metric_name]
             for group_name in unique_groups
+            if (curr_group_metric_name := group_metric_name(metric_name, group_name)) in groupwise_metrics
         ]
 
         # Metrics' ratio
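
For context, a minimal usage sketch of the changed function with the new `min_group_size` parameter. The data here is synthetic and purely illustrative; the import path is assumed from the file touched by the diff (`folktexts/evaluation.py`), and the exact keys of the returned dict are defined elsewhere in the module, so they are not enumerated here.

```python
import numpy as np

from folktexts.evaluation import evaluate_binary_predictions_fairness

rng = np.random.default_rng(42)
n = 1_000

# Synthetic labels and predictions, for illustration only.
y_true = rng.integers(0, 2, size=n)
y_pred = rng.integers(0, 2, size=n)

# Three groups; group 2 holds ~2% of samples, below the 4% default,
# so it should be skipped with a warning and left out of the ratios.
sensitive_attribute = rng.choice([0, 1, 2], size=n, p=[0.59, 0.39, 0.02])

results = evaluate_binary_predictions_fairness(
    y_true,
    y_pred,
    sensitive_attribute,
    return_groupwise_metrics=True,
    min_group_size=0.04,  # the new default; groups under 4% are ignored
)
print(sorted(results))
```

Note that the walrus-operator filter in the last hunk is what keeps the ratio computation consistent with this skipping: ratios and diffs are taken only over groups whose metrics were actually computed, instead of raising a `KeyError` for groups that were skipped.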