Skip to content

Commit e993c17

Browse files
perf: implement one hot encoder and imputer using polars (#768)
### Summary of Changes

The one-hot encoder and imputer are now also implemented using polars, providing better performance. Tests should pass again now. We'll maximize coverage over the coming days.

---------

Co-authored-by: megalinter-bot <[email protected]>
1 parent 6fbe537 commit e993c17

25 files changed

+428
-579
lines changed

.mega-linter.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ JSON_PRETTIER_FILE_EXTENSIONS:
1515
- .html
1616
# - .md
1717

18+
PYTHON_RUFF_CONFIG_FILE: pyproject.toml
19+
1820
# Commands
1921
PRE_COMMANDS:
2022
- command: npm i @lars-reimann/prettier-config

poetry.lock

Lines changed: 53 additions & 53 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def shuffle(self) -> ImageDataset[T]:
290290

291291
class _TableAsTensor:
292292
def __init__(self, table: Table) -> None:
293+
import polars as pl
293294
import torch
294295

295296
_init_default_device()
@@ -298,7 +299,7 @@ def __init__(self, table: Table) -> None:
298299
if table.number_of_rows == 0:
299300
self._tensor = torch.empty((0, table.number_of_columns), dtype=torch.float32).to(_get_device())
300301
else:
301-
self._tensor = table._data_frame.to_torch().to(_get_device())
302+
self._tensor = table._data_frame.to_torch(dtype=pl.Float32).to(_get_device())
302303

303304
if not torch.all(self._tensor.sum(dim=1) == torch.ones(self._tensor.size(dim=0))):
304305
raise ValueError(
@@ -345,6 +346,7 @@ def _to_table(self) -> Table:
345346

346347
class _ColumnAsTensor:
347348
def __init__(self, column: Column) -> None:
349+
import polars as pl
348350
import torch
349351

350352
_init_default_device()
@@ -360,9 +362,9 @@ def __init__(self, column: Column) -> None:
360362
# TODO: should not one-hot-encode the target. label encoding without order is sufficient. should also not
361363
# be done automatically?
362364
self._one_hot_encoder = OneHotEncoder().fit(column_as_table, [self._column_name])
363-
self._tensor = torch.Tensor(self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch()).to(
364-
_get_device(),
365-
)
365+
self._tensor = torch.Tensor(
366+
self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch(dtype=pl.Float32),
367+
).to(_get_device())
366368

367369
def __eq__(self, other: object) -> bool:
368370
import torch

src/safeds/data/labeled/containers/_tabular_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class TabularDataset(Dataset):
5353
5454
Examples
5555
--------
56-
>>> from safeds.data.labeled.containers import TabularDataset
56+
>>> from safeds.data.tabular.containers import Table
5757
>>> table = Table(
5858
... {
5959
... "id": [1, 2, 3],

src/safeds/data/tabular/containers/_column.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,14 +1009,7 @@ def mode(
10091009
>>> from safeds.data.tabular.containers import Column
10101010
>>> column = Column("test", [3, 1, 2, 1, 3])
10111011
>>> column.mode()
1012-
+------+
1013-
| test |
1014-
| --- |
1015-
| i64 |
1016-
+======+
1017-
| 1 |
1018-
| 3 |
1019-
+------+
1012+
[1, 3]
10201013
"""
10211014
import polars as pl
10221015

src/safeds/data/tabular/containers/_table.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def __eq__(self, other: object) -> bool:
331331
if self is other:
332332
return True
333333

334-
return self._data_frame.frame_equal(other._data_frame)
334+
return self._data_frame.equals(other._data_frame)
335335

336336
def __hash__(self) -> int:
337337
return _structural_hash(self.schema, self.number_of_rows)
@@ -859,7 +859,7 @@ def rename_column(self, old_name: str, new_name: str) -> Table:
859859
def replace_column(
860860
self,
861861
old_name: str,
862-
new_columns: Column | list[Column],
862+
new_columns: Column | list[Column] | Table,
863863
) -> Table:
864864
"""
865865
Return a new table with a column replaced by zero or more columns.
@@ -871,7 +871,7 @@ def replace_column(
871871
old_name:
872872
The name of the column to replace.
873873
new_columns:
874-
The new column or columns.
874+
The new columns.
875875
876876
Returns
877877
-------
@@ -922,11 +922,13 @@ def replace_column(
922922
| 9 | 12 | 6 |
923923
+-----+-----+-----+
924924
"""
925-
_check_columns_exist(self, old_name)
926-
_check_columns_dont_exist(self, [column.name for column in new_columns], old_name=old_name)
927-
928925
if isinstance(new_columns, Column):
929926
new_columns = [new_columns]
927+
elif isinstance(new_columns, Table):
928+
new_columns = new_columns.to_columns()
929+
930+
_check_columns_exist(self, old_name)
931+
_check_columns_dont_exist(self, [column.name for column in new_columns], old_name=old_name)
930932

931933
if len(new_columns) == 0:
932934
return self.remove_columns(old_name)
@@ -1033,9 +1035,6 @@ def remove_duplicate_rows(self) -> Table:
10331035
| 2 | 5 |
10341036
+-----+-----+
10351037
"""
1036-
if self.number_of_columns == 0:
1037-
return self # Workaround for https://github.com/pola-rs/polars/issues/16207
1038-
10391038
return Table._from_polars_lazy_frame(
10401039
self._lazy_frame.unique(maintain_order=True),
10411040
)

src/safeds/data/tabular/plotting/_table_plotter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class TablePlotter:
2424
Examples
2525
--------
2626
>>> from safeds.data.tabular.containers import Table
27-
>>> table = Table("test", [1, 2, 3])
27+
>>> table = Table({"test": [1, 2, 3]})
2828
>>> plotter = table.plot
2929
"""
3030

src/safeds/data/tabular/transformation/_invertible_table_transformer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ class InvertibleTableTransformer(TableTransformer):
1515
@abstractmethod
1616
def inverse_transform(self, transformed_table: Table) -> Table:
1717
"""
18-
Undo the learned transformation.
18+
Undo the learned transformation as well as possible.
1919
20-
The table is not modified.
20+
Column order and types may differ from the original table. Likewise, some values might not be restored.
21+
22+
**Note:** The given table is not modified.
2123
2224
Parameters
2325
----------

src/safeds/data/tabular/transformation/_label_encoder.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ def __init__(self, *, partial_order: list[Any] | None = None) -> None:
3737
self._partial_order = partial_order
3838

3939
# Internal state
40-
self._mapping: dict[str, dict[Any, int]] | None = None
41-
self._inverse_mapping: dict[str, dict[int, Any]] | None = None
40+
self._mapping: dict[str, dict[Any, int]] | None = None # Column name -> value -> label
41+
self._inverse_mapping: dict[str, dict[int, Any]] | None = None # Column name -> label -> value
4242

4343
def __hash__(self) -> int:
4444
return _structural_hash(
4545
super().__hash__(),
4646
self._partial_order,
47+
# Leave out the internal state for faster hashing
4748
)
4849

4950
# ------------------------------------------------------------------------------------------------------------------
@@ -61,7 +62,7 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
6162
table:
6263
The table used to fit the transformer.
6364
column_names:
64-
The list of columns from the table used to fit the transformer. If `None`, all columns are used.
65+
The list of columns from the table used to fit the transformer. If `None`, all non-numeric columns are used.
6566
6667
Returns
6768
-------
@@ -76,14 +77,13 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
7677
If the table contains 0 rows.
7778
"""
7879
if column_names is None:
79-
column_names = table.column_names
80+
column_names = [name for name in table.column_names if not table.get_column_type(name).is_numeric]
8081
else:
8182
_check_columns_exist(table, column_names)
83+
_warn_if_columns_are_numeric(table, column_names)
8284

8385
if table.number_of_rows == 0:
84-
raise ValueError("The LabelEncoder cannot transform the table because it contains 0 rows")
85-
86-
_warn_if_columns_are_numeric(table, column_names)
86+
raise ValueError("The LabelEncoder cannot be fitted because the table contains 0 rows")
8787

8888
# Learn the transformation
8989
mapping = {}
@@ -142,7 +142,10 @@ def transform(self, table: Table) -> Table:
142142

143143
_check_columns_exist(table, self._column_names)
144144

145-
columns = [pl.col(name).replace(self._mapping[name], return_dtype=pl.UInt32) for name in self._column_names]
145+
columns = [
146+
pl.col(name).replace(self._mapping[name], default=None, return_dtype=pl.UInt32)
147+
for name in self._column_names
148+
]
146149

147150
return Table._from_polars_lazy_frame(
148151
table._lazy_frame.with_columns(columns),
@@ -186,7 +189,7 @@ def inverse_transform(self, transformed_table: Table) -> Table:
186189
operation="inverse-transform with a LabelEncoder",
187190
)
188191

189-
columns = [pl.col(name).replace(self._inverse_mapping[name]) for name in self._column_names]
192+
columns = [pl.col(name).replace(self._inverse_mapping[name], default=None) for name in self._column_names]
190193

191194
return Table._from_polars_lazy_frame(
192195
transformed_table._lazy_frame.with_columns(columns),

0 commit comments

Comments
 (0)