44
55from safeds ._utils import _structural_hash
66from safeds ._validation import _check_bounds , _check_columns_exist , _ClosedBound
7+ from safeds ._validation ._check_columns_are_numeric import _check_columns_are_numeric
78from safeds .data .tabular .containers import Table
89from safeds .exceptions import (
910 NonNumericColumnError ,
@@ -24,6 +25,8 @@ class Discretizer(TableTransformer):
2425 ----------
2526 bin_count:
2627 The number of bins to be created.
28+ column_names:
29+ The list of columns used to fit the transformer. If `None`, all numeric columns are used.
2730
2831 Raises
2932 ------
@@ -35,8 +38,13 @@ class Discretizer(TableTransformer):
3538 # Dunder methods
3639 # ------------------------------------------------------------------------------------------------------------------
3740
38- def __init__ (self , bin_count : int = 5 ) -> None :
39- TableTransformer .__init__ (self )
41+ def __init__ (
42+ self ,
43+ bin_count : int = 5 ,
44+ * ,
45+ column_names : str | list [str ] | None = None ,
46+ ) -> None :
47+ TableTransformer .__init__ (self , column_names )
4048
4149 _check_bounds ("bin_count" , bin_count , lower_bound = _ClosedBound (2 ))
4250
@@ -53,6 +61,10 @@ def __hash__(self) -> int:
5361 # Properties
5462 # ------------------------------------------------------------------------------------------------------------------
5563
64+ @property
65+ def is_fitted (self ) -> bool :
66+ return self ._wrapped_transformer is not None
67+
5668 @property
5769 def bin_count (self ) -> int :
5870 return self ._bin_count
@@ -61,7 +73,7 @@ def bin_count(self) -> int:
6173 # Learning and transformation
6274 # ------------------------------------------------------------------------------------------------------------------
6375
64- def fit (self , table : Table , column_names : list [ str ] | None ) -> Discretizer :
76+ def fit (self , table : Table ) -> Discretizer :
6577 """
6678 Learn a transformation for a set of columns in a table.
6779
@@ -71,8 +83,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
7183 ----------
7284 table:
7385 The table used to fit the transformer.
74- column_names:
75- The list of columns from the table used to fit the transformer. If `None`, all columns are used.
7686
7787 Returns
7888 -------
@@ -93,24 +103,21 @@ def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
93103 if table .row_count == 0 :
94104 raise ValueError ("The Discretizer cannot be fitted because the table contains 0 rows" )
95105
96- if column_names is None :
97- column_names = table .column_names
106+ if self . _column_names is None :
107+ column_names = [ name for name in table .column_names if table . get_column_type ( name ). is_numeric ]
98108 else :
109+ column_names = self ._column_names
99110 _check_columns_exist (table , column_names )
100-
101- for column in column_names :
102- if not table .get_column (column ).type .is_numeric :
103- raise NonNumericColumnError (f"{ column } is of type { table .get_column (column ).type } ." )
111+ _check_columns_are_numeric (table , column_names , operation = "fit a Discretizer" )
104112
105113 wrapped_transformer = sk_KBinsDiscretizer (n_bins = self ._bin_count , encode = "ordinal" )
106114 wrapped_transformer .set_output (transform = "polars" )
107115 wrapped_transformer .fit (
108116 table .remove_columns_except (column_names )._data_frame ,
109117 )
110118
111- result = Discretizer (self ._bin_count )
119+ result = Discretizer (self ._bin_count , column_names = column_names )
112120 result ._wrapped_transformer = wrapped_transformer
113- result ._column_names = column_names
114121
115122 return result
116123
0 commit comments