@@ -81,7 +81,8 @@ def evaluate_binary_predictions_fairness(
81
81
y_true: np.ndarray,
82
82
y_pred: np.ndarray,
83
83
sensitive_attribute: np.ndarray,
84
- return_groupwise_metrics: Optional[bool] = False,
84
+ return_groupwise_metrics: bool = False,
85
+ min_group_size: float = 0.04,
85
86
) -> dict:
86
87
"""Evaluates fairness of the given predictions.
87
88
@@ -96,9 +97,14 @@ def evaluate_binary_predictions_fairness(
96
97
The discretized predictions.
97
98
sensitive_attribute : np.ndarray
98
99
The sensitive attribute (protected group membership).
99
- return_groupwise_metrics : Optional[bool], optional
100
+ return_groupwise_metrics : bool, optional
100
101
Whether to return group-wise performance metrics (bool: True) or only
101
102
the ratios between these metrics (bool: False), by default False.
103
+ min_group_size : float, optional
104
+ The minimum fraction of samples (as a fraction of the total number of
105
+ samples) that a group must have to be considered for fairness
106
+ evaluation, by default 0.04. This is meant to avoid evaluating metrics
107
+ on very small groups which leads to noisy and inconsistent results.
102
108
103
109
Returns
104
110
-------
@@ -124,6 +130,10 @@ def group_metric_name(metric_name, group_name):
124
130
# Indices of samples that belong to the current group
125
131
group_indices = np.argwhere(sensitive_attribute == s_value).flatten()
126
132
133
+ if len(group_indices) < min_group_size * len(y_true):
134
+ logging.warning(f"Skipping group {s_value} with {len(group_indices)} samples")
135
+ continue
136
+
127
137
# Filter labels and predictions for samples of the current group
128
138
group_labels = y_true[group_indices]
129
139
group_preds = y_pred[group_indices]
@@ -144,8 +154,9 @@ def group_metric_name(metric_name, group_name):
144
154
# Compute ratios and absolute diffs
145
155
for metric_name in unique_metrics :
146
156
curr_metric_results = [
147
- groupwise_metrics[group_metric_name(metric_name, group_name)]
157
+ groupwise_metrics[curr_group_metric_name]
148
158
for group_name in unique_groups
159
+ if (curr_group_metric_name := group_metric_name(metric_name, group_name)) in groupwise_metrics
149
160
]
150
161
151
162
# Metrics' ratio
0 commit comments