Skip to content

Commit 69a780c

Browse files
feat: specify column names in constructor of table transformers (#795)
### Summary of Changes

Specify the names of the columns that a table transformer should be applied to in its constructor instead of its `fit` method. This allows easier composition.

---

Co-authored-by: megalinter-bot <[email protected]>
1 parent f07bc5a commit 69a780c

25 files changed

+242
-171
lines changed

docs/tutorials/classification.ipynb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
"source": [
7676
"from safeds.data.tabular.transformation import OneHotEncoder\n",
7777
"\n",
78-
"encoder = OneHotEncoder().fit(train_table, [\"sex\"])"
78+
"encoder = OneHotEncoder(column_names=\"sex\").fit(train_table)"
7979
],
8080
"metadata": {
8181
"collapsed": false
@@ -155,7 +155,6 @@
155155
{
156156
"cell_type": "code",
157157
"source": [
158-
"encoder = OneHotEncoder().fit(test_table, [\"sex\"])\n",
159158
"transformed_test_table = encoder.transform(test_table)\n",
160159
"\n",
161160
"prediction = fitted_model.predict(\n",
@@ -182,7 +181,6 @@
182181
{
183182
"cell_type": "code",
184183
"source": [
185-
"encoder = OneHotEncoder().fit(test_table, [\"sex\"])\n",
186184
"testing_table = encoder.transform(testing_table)\n",
187185
"\n",
188186
"test_tabular_dataset = testing_table.to_tabular_dataset(\"survived\", extra_names=extra_names)\n",

docs/tutorials/data_processing.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@
183183
"source": [
184184
"from safeds.data.tabular.transformation import SimpleImputer\n",
185185
"\n",
186-
"imputer = SimpleImputer(SimpleImputer.Strategy.constant(0)).fit(titanic, [\"age\", \"fare\", \"cabin\", \"port_embarked\"])\n",
186+
"imputer = SimpleImputer(SimpleImputer.Strategy.constant(0), column_names=[\"age\", \"fare\", \"cabin\", \"port_embarked\"]).fit(titanic)\n",
187187
"imputer.transform(titanic_slice)"
188188
],
189189
"metadata": {
@@ -206,7 +206,7 @@
206206
"source": [
207207
"from safeds.data.tabular.transformation import LabelEncoder\n",
208208
"\n",
209-
"encoder = LabelEncoder().fit(titanic, [\"sex\", \"port_embarked\"])\n",
209+
"encoder = LabelEncoder(column_names=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
210210
"encoder.transform(titanic_slice)"
211211
],
212212
"metadata": {
@@ -229,7 +229,7 @@
229229
"source": [
230230
"from safeds.data.tabular.transformation import OneHotEncoder\n",
231231
"\n",
232-
"encoder = OneHotEncoder().fit(titanic, [\"sex\", \"port_embarked\"])\n",
232+
"encoder = OneHotEncoder(column_names=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
233233
"encoder.transform(titanic_slice)"
234234
],
235235
"metadata": {
@@ -252,7 +252,7 @@
252252
"source": [
253253
"from safeds.data.tabular.transformation import RangeScaler\n",
254254
"\n",
255-
"scaler = RangeScaler(0.0, 1.0).fit(titanic, [\"age\"])\n",
255+
"scaler = RangeScaler(0.0, 1.0, column_names=\"age\").fit(titanic)\n",
256256
"scaler.transform(titanic_slice)"
257257
],
258258
"metadata": {
@@ -275,7 +275,7 @@
275275
"source": [
276276
"from safeds.data.tabular.transformation import StandardScaler\n",
277277
"\n",
278-
"scaler = StandardScaler().fit(titanic, [\"age\", \"travel_class\"])\n",
278+
"scaler = StandardScaler(column_names=[\"age\", \"travel_class\"]).fit(titanic)\n",
279279
"scaler.transform(titanic_slice)"
280280
],
281281
"metadata": {

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def __init__(self, column: Column) -> None:
374374
)
375375
# TODO: should not one-hot-encode the target. label encoding without order is sufficient. should also not
376376
# be done automatically?
377-
self._one_hot_encoder = OneHotEncoder().fit(column_as_table, [self._column_name])
377+
self._one_hot_encoder = OneHotEncoder(column_names=self._column_name).fit(column_as_table)
378378
self._tensor = torch.Tensor(
379379
self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch(dtype=pl.Float32),
380380
).to(_get_device())

src/safeds/data/tabular/containers/_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,7 +1688,7 @@ def inverse_transform_table(self, fitted_transformer: InvertibleTableTransformer
16881688
>>> from safeds.data.tabular.containers import Table
16891689
>>> from safeds.data.tabular.transformation import RangeScaler
16901690
>>> table = Table({"a": [1, 2, 3]})
1691-
>>> transformer, transformed_table = RangeScaler(min_=0, max_=1).fit_and_transform(table, ["a"])
1691+
>>> transformer, transformed_table = RangeScaler(min_=0, max_=1, column_names="a").fit_and_transform(table)
16921692
>>> transformed_table.inverse_transform_table(transformer)
16931693
+---------+
16941694
| a |
@@ -1726,7 +1726,7 @@ def transform_table(self, fitted_transformer: TableTransformer) -> Table:
17261726
>>> from safeds.data.tabular.containers import Table
17271727
>>> from safeds.data.tabular.transformation import RangeScaler
17281728
>>> table = Table({"a": [1, 2, 3]})
1729-
>>> transformer = RangeScaler(min_=0, max_=1).fit(table, ["a"])
1729+
>>> transformer = RangeScaler(min_=0, max_=1, column_names="a").fit(table)
17301730
>>> table.transform_table(transformer)
17311731
+---------+
17321732
| a |

src/safeds/data/tabular/transformation/_discretizer.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from safeds._utils import _structural_hash
66
from safeds._validation import _check_bounds, _check_columns_exist, _ClosedBound
7+
from safeds._validation._check_columns_are_numeric import _check_columns_are_numeric
78
from safeds.data.tabular.containers import Table
89
from safeds.exceptions import (
910
NonNumericColumnError,
@@ -24,6 +25,8 @@ class Discretizer(TableTransformer):
2425
----------
2526
bin_count:
2627
The number of bins to be created.
28+
column_names:
29+
The list of columns used to fit the transformer. If `None`, all numeric columns are used.
2730
2831
Raises
2932
------
@@ -35,8 +38,13 @@ class Discretizer(TableTransformer):
3538
# Dunder methods
3639
# ------------------------------------------------------------------------------------------------------------------
3740

38-
def __init__(self, bin_count: int = 5) -> None:
39-
TableTransformer.__init__(self)
41+
def __init__(
42+
self,
43+
bin_count: int = 5,
44+
*,
45+
column_names: str | list[str] | None = None,
46+
) -> None:
47+
TableTransformer.__init__(self, column_names)
4048

4149
_check_bounds("bin_count", bin_count, lower_bound=_ClosedBound(2))
4250

@@ -53,6 +61,10 @@ def __hash__(self) -> int:
5361
# Properties
5462
# ------------------------------------------------------------------------------------------------------------------
5563

64+
@property
65+
def is_fitted(self) -> bool:
66+
return self._wrapped_transformer is not None
67+
5668
@property
5769
def bin_count(self) -> int:
5870
return self._bin_count
@@ -61,7 +73,7 @@ def bin_count(self) -> int:
6173
# Learning and transformation
6274
# ------------------------------------------------------------------------------------------------------------------
6375

64-
def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
76+
def fit(self, table: Table) -> Discretizer:
6577
"""
6678
Learn a transformation for a set of columns in a table.
6779
@@ -71,8 +83,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
7183
----------
7284
table:
7385
The table used to fit the transformer.
74-
column_names:
75-
The list of columns from the table used to fit the transformer. If `None`, all columns are used.
7686
7787
Returns
7888
-------
@@ -93,24 +103,21 @@ def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
93103
if table.row_count == 0:
94104
raise ValueError("The Discretizer cannot be fitted because the table contains 0 rows")
95105

96-
if column_names is None:
97-
column_names = table.column_names
106+
if self._column_names is None:
107+
column_names = [name for name in table.column_names if table.get_column_type(name).is_numeric]
98108
else:
109+
column_names = self._column_names
99110
_check_columns_exist(table, column_names)
100-
101-
for column in column_names:
102-
if not table.get_column(column).type.is_numeric:
103-
raise NonNumericColumnError(f"{column} is of type {table.get_column(column).type}.")
111+
_check_columns_are_numeric(table, column_names, operation="fit a Discretizer")
104112

105113
wrapped_transformer = sk_KBinsDiscretizer(n_bins=self._bin_count, encode="ordinal")
106114
wrapped_transformer.set_output(transform="polars")
107115
wrapped_transformer.fit(
108116
table.remove_columns_except(column_names)._data_frame,
109117
)
110118

111-
result = Discretizer(self._bin_count)
119+
result = Discretizer(self._bin_count, column_names=column_names)
112120
result._wrapped_transformer = wrapped_transformer
113-
result._column_names = column_names
114121

115122
return result
116123

src/safeds/data/tabular/transformation/_label_encoder.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class LabelEncoder(InvertibleTableTransformer):
1818
1919
Parameters
2020
----------
21+
column_names:
22+
The list of columns used to fit the transformer. If `None`, all non-numeric columns are used.
2123
partial_order:
2224
The partial order of the labels. The labels are encoded in the order of the given list. Additional values are
2325
assigned labels in the order they are encountered during fitting.
@@ -27,8 +29,13 @@ class LabelEncoder(InvertibleTableTransformer):
2729
# Dunder methods
2830
# ------------------------------------------------------------------------------------------------------------------
2931

30-
def __init__(self, *, partial_order: list[Any] | None = None) -> None:
31-
super().__init__()
32+
def __init__(
33+
self,
34+
*,
35+
column_names: str | list[str] | None = None,
36+
partial_order: list[Any] | None = None,
37+
) -> None:
38+
super().__init__(column_names)
3239

3340
if partial_order is None:
3441
partial_order = []
@@ -51,6 +58,10 @@ def __hash__(self) -> int:
5158
# Properties
5259
# ------------------------------------------------------------------------------------------------------------------
5360

61+
@property
62+
def is_fitted(self) -> bool:
63+
return self._mapping is not None and self._inverse_mapping is not None
64+
5465
@property
5566
def partial_order(self) -> list[Any]:
5667
"""The partial order of the labels."""
@@ -60,7 +71,7 @@ def partial_order(self) -> list[Any]:
6071
# Learning and transformation
6172
# ------------------------------------------------------------------------------------------------------------------
6273

63-
def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
74+
def fit(self, table: Table) -> LabelEncoder:
6475
"""
6576
Learn a transformation for a set of columns in a table.
6677
@@ -70,8 +81,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
7081
----------
7182
table:
7283
The table used to fit the transformer.
73-
column_names:
74-
The list of columns from the table used to fit the transformer. If `None`, all non-numeric columns are used.
7584
7685
Returns
7786
-------
@@ -85,9 +94,10 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
8594
ValueError
8695
If the table contains 0 rows.
8796
"""
88-
if column_names is None:
97+
if self._column_names is None:
8998
column_names = [name for name in table.column_names if not table.get_column_type(name).is_numeric]
9099
else:
100+
column_names = self._column_names
91101
_check_columns_exist(table, column_names)
92102
_warn_if_columns_are_numeric(table, column_names)
93103

@@ -111,8 +121,7 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
111121
reverse_mapping[name][label] = value
112122

113123
# Create a copy with the learned transformation
114-
result = LabelEncoder(partial_order=self._partial_order)
115-
result._column_names = column_names
124+
result = LabelEncoder(column_names=column_names, partial_order=self._partial_order)
116125
result._mapping = mapping
117126
result._inverse_mapping = reverse_mapping
118127

src/safeds/data/tabular/transformation/_one_hot_encoder.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class OneHotEncoder(InvertibleTableTransformer):
4343
4444
Parameters
4545
----------
46+
column_names:
47+
The list of columns used to fit the transformer. If `None`, all non-numeric columns are used.
4648
separator:
4749
The separator used to separate the original column name from the value in the new column names.
4850
@@ -52,7 +54,7 @@ class OneHotEncoder(InvertibleTableTransformer):
5254
>>> from safeds.data.tabular.transformation import OneHotEncoder
5355
>>> table = Table({"col1": ["a", "b", "c", "a"]})
5456
>>> transformer = OneHotEncoder()
55-
>>> transformer.fit_and_transform(table, ["col1"])[1]
57+
>>> transformer.fit_and_transform(table)[1]
5658
+---------+---------+---------+
5759
| col1__a | col1__b | col1__c |
5860
| --- | --- | --- |
@@ -72,9 +74,10 @@ class OneHotEncoder(InvertibleTableTransformer):
7274
def __init__(
7375
self,
7476
*,
77+
column_names: str | list[str] | None = None,
7578
separator: str = "__",
7679
) -> None:
77-
super().__init__()
80+
super().__init__(column_names)
7881

7982
# Parameters
8083
self._separator = separator
@@ -103,6 +106,10 @@ def __hash__(self) -> int:
103106
# Properties
104107
# ------------------------------------------------------------------------------------------------------------------
105108

109+
@property
110+
def is_fitted(self) -> bool:
111+
return self._mapping is not None
112+
106113
@property
107114
def separator(self) -> str:
108115
"""The separator used to separate the original column name from the value in the new column names."""
@@ -112,7 +119,7 @@ def separator(self) -> str:
112119
# Learning and transformation
113120
# ------------------------------------------------------------------------------------------------------------------
114121

115-
def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
122+
def fit(self, table: Table) -> OneHotEncoder:
116123
"""
117124
Learn a transformation for a set of columns in a table.
118125
@@ -122,8 +129,6 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
122129
----------
123130
table:
124131
The table used to fit the transformer.
125-
column_names:
126-
The list of columns from the table used to fit the transformer. If `None`, all columns are used.
127132
128133
Returns
129134
-------
@@ -137,9 +142,10 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
137142
ValueError
138143
If the table contains 0 rows.
139144
"""
140-
if column_names is None:
145+
if self._column_names is None:
141146
column_names = [name for name in table.column_names if not table.get_column_type(name).is_numeric]
142147
else:
148+
column_names = self._column_names
143149
_check_columns_exist(table, column_names)
144150
_warn_if_columns_are_numeric(table, column_names)
145151

@@ -169,8 +175,7 @@ def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
169175
mapping[name].append((new_name, value))
170176

171177
# Create a copy with the learned transformation
172-
result = OneHotEncoder()
173-
result._column_names = column_names
178+
result = OneHotEncoder(column_names=column_names, separator=self._separator)
174179
result._new_column_names = new_column_names
175180
result._mapping = mapping
176181

0 commit comments

Comments
 (0)