diff --git a/Orange/widgets/data/oweditdomain.py b/Orange/widgets/data/oweditdomain.py index 567d8b5781d..a2a4b848f8c 100644 --- a/Orange/widgets/data/oweditdomain.py +++ b/Orange/widgets/data/oweditdomain.py @@ -680,7 +680,8 @@ class GroupItemsDialog(QDialog): DEFAULT_LABEL = "other" def __init__( - self, variable: Categorical, data: Union[np.ndarray, List], + self, variable: Categorical, + data: Union[np.ndarray, List, MArray], selected_attributes: List[str], dialog_settings: Dict[str, Any], parent: QWidget = None, flags: Qt.WindowFlags = Qt.Dialog, **kwargs ) -> None: @@ -814,10 +815,18 @@ def get_merge_attributes(self) -> List[str]: ------- List of attributes' to be merged names """ - counts = Counter(self.data) if self.selected_radio.isChecked(): return self.selected_attributes - elif self.n_values_radio.isChecked(): + + if isinstance(self.data, MArray): + non_nan = self.data[~self.data.mask] + elif isinstance(self.data, np.ndarray): + non_nan = self.data[~np.isnan(self.data)] + else: # list + non_nan = [x for x in self.data if x is not None] + + counts = Counter(non_nan) + if self.n_values_radio.isChecked(): keep_values = self.n_values_spin.value() values = counts.most_common()[keep_values:] indices = [i for i, _ in values] @@ -828,6 +837,8 @@ def get_merge_attributes(self) -> List[str]: n_all = sum(counts.values()) indices = [v for v, c in counts.most_common() if c / n_all * 100 < self.frequent_rel_spin.value()] + + indices = np.array(indices, dtype=int) # indices must be ints return np.array(self.variable.categories)[indices].tolist() def get_merged_value_name(self) -> str: diff --git a/Orange/widgets/data/tests/test_oweditdomain.py b/Orange/widgets/data/tests/test_oweditdomain.py index 4ef398bf342..e37cc5e2392 100644 --- a/Orange/widgets/data/tests/test_oweditdomain.py +++ b/Orange/widgets/data/tests/test_oweditdomain.py @@ -1026,6 +1026,40 @@ def test_group_keep_n(self): dialog.n_values_spin.setValue(3) self.assertListEqual(dialog.get_merge_attributes(), []) + def test_group_less_frequent_missing(self): + """ + Widget gives MaskedArray to GroupItemsDialog which can have missing + values. + gh-4599 + """ + def _test_correctness(): + dialog.frequent_abs_radio.setChecked(True) + dialog.frequent_abs_spin.setValue(3) + self.assertListEqual(dialog.get_merge_attributes(), ["b", "c"]) + + dialog.frequent_rel_radio.setChecked(True) + dialog.frequent_rel_spin.setValue(50) + self.assertListEqual(dialog.get_merge_attributes(), ["b", "c"]) + + dialog.n_values_radio.setChecked(True) + dialog.n_values_spin.setValue(1) + self.assertListEqual(dialog.get_merge_attributes(), ["b", "c"]) + + # masked array + data_masked = np.ma.array( + [0, 0, np.nan, 0, 1, 1, 2], mask=[0, 0, 1, 0, 0, 0, 0] + ) + dialog = GroupItemsDialog(self.v, data_masked, [], {}) + _test_correctness() + + data_array = np.array([0, 0, np.nan, 0, 1, 1, 2]) + dialog = GroupItemsDialog(self.v, data_array, [], {}) + _test_correctness() + + data_list = [0, 0, None, 0, 1, 1, 2] + dialog = GroupItemsDialog(self.v, data_list, [], {}) + _test_correctness() + if __name__ == '__main__': unittest.main()