
Commit e78de48

updated group metrics calculation

1 parent 27191b2

File tree: 1 file changed, +14 −3 lines


folktexts/evaluation.py

Lines changed: 14 additions & 3 deletions
@@ -81,7 +81,8 @@ def evaluate_binary_predictions_fairness(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     sensitive_attribute: np.ndarray,
-    return_groupwise_metrics: Optional[bool] = False,
+    return_groupwise_metrics: bool = False,
+    min_group_size: float = 0.04,
 ) -> dict:
     """Evaluates fairness of the given predictions.

@@ -96,9 +97,14 @@ def evaluate_binary_predictions_fairness(
         The discretized predictions.
     sensitive_attribute : np.ndarray
         The sensitive attribute (protected group membership).
-    return_groupwise_metrics : Optional[bool], optional
+    return_groupwise_metrics : bool, optional
         Whether to return group-wise performance metrics (bool: True) or only
         the ratios between these metrics (bool: False), by default False.
+    min_group_size : float, optional
+        The minimum fraction of samples (as a fraction of the total number of
+        samples) that a group must have to be considered for fairness
+        evaluation, by default 0.04. This is meant to avoid evaluating metrics
+        on very small groups which leads to noisy and inconsistent results.

     Returns
     -------
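For context, a minimal usage sketch of the updated signature, assuming folktexts is installed and that evaluate_binary_predictions_fairness can be imported from folktexts.evaluation (the sample data and group labels below are invented for illustration):

import numpy as np
from folktexts.evaluation import evaluate_binary_predictions_fairness

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=1_000)
y_pred = rng.integers(0, 2, size=1_000)
# Three groups; group 2 is deliberately tiny (~1% of samples), so it falls
# under the default 4% cutoff and gets skipped.
sensitive_attribute = rng.choice([0, 1, 2], size=1_000, p=[0.50, 0.49, 0.01])

results = evaluate_binary_predictions_fairness(
    y_true=y_true,
    y_pred=y_pred,
    sensitive_attribute=sensitive_attribute,
    return_groupwise_metrics=True,   # also return per-group metrics
    min_group_size=0.04,             # skip groups below 4% of all samples
)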
@@ -124,6 +130,10 @@ def group_metric_name(metric_name, group_name):
         # Indices of samples that belong to the current group
         group_indices = np.argwhere(sensitive_attribute == s_value).flatten()

+        if len(group_indices) < min_group_size * len(y_true):
+            logging.warning(f"Skipping group {s_value} with {len(group_indices)} samples")
+            continue
+
         # Filter labels and predictions for samples of the current group
         group_labels = y_true[group_indices]
         group_preds = y_pred[group_indices]
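For intuition on the new cutoff: with the default min_group_size of 0.04, a group is only evaluated if it holds at least 4% of all samples. A small standalone sketch of that arithmetic (the sample count and group size here are made up):

import logging
import numpy as np

y_true = np.zeros(5_000)                    # pretend there are 5,000 samples
min_group_size = 0.04                       # default value from the diff above
threshold = min_group_size * len(y_true)    # 0.04 * 5,000 = 200 samples

group_indices = np.arange(120)              # a hypothetical group of 120 members
if len(group_indices) < threshold:
    # Mirrors the behaviour added in this commit: warn and skip the group.
    logging.warning(f"Skipping group with {len(group_indices)} samples")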
@@ -144,8 +154,9 @@ def group_metric_name(metric_name, group_name):
     # Compute ratios and absolute diffs
     for metric_name in unique_metrics:
         curr_metric_results = [
-            groupwise_metrics[group_metric_name(metric_name, group_name)]
+            groupwise_metrics[curr_group_metric_name]
             for group_name in unique_groups
+            if (curr_group_metric_name := group_metric_name(metric_name, group_name)) in groupwise_metrics
         ]

         # Metrics' ratio
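This last hunk pairs with the group-skipping logic above: skipped groups never receive entries in groupwise_metrics, so the ratio computation now binds each candidate key with an assignment expression (walrus operator) and drops the missing ones. A toy illustration of the same pattern, with invented metric and group names and a stand-in for the group_metric_name helper:

groupwise_metrics = {"accuracy_groupA": 0.91, "accuracy_groupB": 0.84}
unique_groups = ["groupA", "groupB", "groupC"]   # groupC was skipped upstream

def group_metric_name(metric_name, group_name):
    # Stand-in for the helper of the same name in evaluation.py.
    return f"{metric_name}_{group_name}"

curr_metric_results = [
    groupwise_metrics[key]
    for group_name in unique_groups
    if (key := group_metric_name("accuracy", group_name)) in groupwise_metrics
]
print(curr_metric_results)   # [0.91, 0.84] -- the skipped group is excluded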
