Skip to content

Commit 6fbe537

Browse files
feat: specify partial order in label encoder (#763)
Closes #639

### Summary of Changes

* Optionally specify a partial order of labels in the label encoder
* Performance: Implement RangeScaler, StandardScaler, LabelEncoder with polars

---------

Co-authored-by: megalinter-bot <[email protected]>
1 parent 74cc701 commit 6fbe537

File tree

60 files changed

+672
-1242
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+672
-1242
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ omit = [
7878
]
7979

8080
[tool.pytest.ini_options]
81-
addopts = "--snapshot-warn-unused"
81+
addopts = "--snapshot-warn-unused --tb=short"
8282
filterwarnings = [
8383
"ignore:Deprecated call to `pkg_resources.declare_namespace",
8484
"ignore:Jupyter is migrating its paths to use standard platformdirs"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from safeds.exceptions import ColumnTypeError
6+
7+
if TYPE_CHECKING:
8+
from collections.abc import Container
9+
10+
from safeds.data.tabular.containers import Table
11+
from safeds.data.tabular.typing import Schema
12+
13+
14+
def _check_columns_are_numeric(
    table_or_schema: Table | Schema,
    column_names: str | list[str],
    *,
    operation: str = "do a numeric operation",
) -> None:
    """
    Ensure that every listed column that exists is numeric; raise an error otherwise.

    Names that do not occur in the table or schema are skipped silently. Use `_check_columns_exist` to detect
    missing columns.

    Parameters
    ----------
    table_or_schema:
        The table or schema whose columns are inspected. A table is reduced to its schema first.
    column_names:
        A single column name or a list of column names to inspect.
    operation:
        Description of the operation that requires numeric columns. It is only used to build the error message.

    Raises
    ------
    ColumnTypeError
        If at least one of the named columns exists but is not numeric.
    """
    from safeds.data.tabular.containers import Table  # circular import

    # Normalize both arguments: always work on a schema and on a list of names.
    schema = table_or_schema.schema if isinstance(table_or_schema, Table) else table_or_schema
    requested = [column_names] if isinstance(column_names, str) else column_names

    # Building a set only pays off when more than one containment check follows.
    existing: Container = set(schema.column_names) if len(requested) > 1 else schema.column_names

    offending = []
    for name in requested:
        if name not in existing:
            continue  # missing columns are deliberately ignored here
        if schema.get_column_type(name).is_numeric:
            continue
        offending.append(name)

    if not offending:
        return

    raise ColumnTypeError(_build_error_message(offending, operation))
58+
59+
60+
def _build_error_message(non_numeric_names: list[str], operation: str) -> str:
61+
return f"Tried to {operation} on non-numeric columns {non_numeric_names}."

src/safeds/data/image/containers/_image.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import io
4-
import os.path
54
import sys
65
import warnings
76
from pathlib import Path
@@ -79,9 +78,12 @@ def from_file(path: str | Path) -> Image:
7978
"""
8079
from torchvision.io import read_image
8180

81+
if isinstance(path, str):
82+
path = Path(path)
83+
8284
_init_default_device()
8385

84-
if not os.path.isfile(path):
86+
if not path.is_file():
8587
raise FileNotFoundError(f"No such file or directory: '{path}'")
8688

8789
return Image(image_tensor=read_image(str(path)).to(_get_device()))

src/safeds/data/image/containers/_image_list.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,6 @@ def from_files(
283283
return image_list
284284

285285
class _FromFileThreadPackage:
286-
287286
def __init__(
288287
self,
289288
im_files: list[str],
@@ -323,7 +322,6 @@ def __len__(self) -> int:
323322
return len(self._im_files)
324323

325324
class _FromImageThread(Thread):
326-
327325
def __init__(self, packages: list[ImageList._FromFileThreadPackage]) -> None:
328326
super().__init__()
329327
self._packages = packages

src/safeds/data/image/containers/_multi_size_image_list.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _create_from_single_sized_image_lists(single_size_image_lists: list[_SingleS
6666
single_size_image_list._indices_to_tensor_positions.keys(),
6767
[image_size] * len(single_size_image_list),
6868
strict=False,
69-
)
69+
),
7070
)
7171
if max_channel is None:
7272
max_channel = single_size_image_list.channel
@@ -80,7 +80,7 @@ def _create_from_single_sized_image_lists(single_size_image_lists: list[_SingleS
8080
for size in image_list._image_list_dict:
8181
if max_channel is not None and image_list._image_list_dict[size].channel != max_channel:
8282
image_list._image_list_dict[size] = image_list._image_list_dict[size].change_channel(
83-
int(max_channel)
83+
int(max_channel),
8484
)
8585
return image_list
8686

src/safeds/data/image/containers/_single_size_image_list.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77
from typing import TYPE_CHECKING
88

9-
from safeds._config import _init_default_device, _get_device
9+
from safeds._config import _get_device, _init_default_device
1010
from safeds._utils import _structural_hash
1111
from safeds.data.image._utils._image_transformation_error_and_warning_checks import (
1212
_check_add_noise_errors,
@@ -82,7 +82,12 @@ def _create_image_list_from_files(
8282
image_list = _SingleSizeImageList()
8383

8484
images_tensor = torch.empty(
85-
number_of_images, max_channel, height, width, dtype=torch.uint8, device=_get_device()
85+
number_of_images,
86+
max_channel,
87+
height,
88+
width,
89+
dtype=torch.uint8,
90+
device=_get_device(),
8691
)
8792

8893
thread_packages: list[ImageList._FromFileThreadPackage] = []

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(self, input_data: ImageList, output_data: T, batch_size: int = 1, s
8989
_output_size: int | ImageSize = output_data.number_of_columns
9090
elif isinstance(output_data, Column):
9191
_column_as_tensor = _ColumnAsTensor(output_data)
92-
_output_size = len(_column_as_tensor._one_hot_encoder.get_names_of_added_columns())
92+
_output_size = len(_column_as_tensor._one_hot_encoder._get_names_of_added_columns())
9393
_output = _column_as_tensor
9494
elif isinstance(output_data, _SingleSizeImageList):
9595
_output = output_data._clone()._as_single_size_image_list()
@@ -289,7 +289,6 @@ def shuffle(self) -> ImageDataset[T]:
289289

290290

291291
class _TableAsTensor:
292-
293292
def __init__(self, table: Table) -> None:
294293
import torch
295294

@@ -345,7 +344,6 @@ def _to_table(self) -> Table:
345344

346345

347346
class _ColumnAsTensor:
348-
349347
def __init__(self, column: Column) -> None:
350348
import torch
351349

@@ -359,6 +357,8 @@ def __init__(self, column: Column) -> None:
359357
message=rf"The columns \['{self._column_name}'\] contain numerical data. The OneHotEncoder is designed to encode non-numerical values into numerical values",
360358
category=UserWarning,
361359
)
360+
# TODO: should not one-hot-encode the target. label encoding without order is sufficient. should also not
361+
# be done automatically?
362362
self._one_hot_encoder = OneHotEncoder().fit(column_as_table, [self._column_name])
363363
self._tensor = torch.Tensor(self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch()).to(
364364
_get_device(),
@@ -394,9 +394,9 @@ def _from_tensor(tensor: Tensor, column_name: str, one_hot_encoder: OneHotEncode
394394
raise ValueError(f"Tensor has an invalid amount of dimensions. Needed 2 dimensions but got {tensor.dim()}.")
395395
if not one_hot_encoder.is_fitted:
396396
raise TransformerNotFittedError
397-
if tensor.size(dim=1) != len(one_hot_encoder.get_names_of_added_columns()):
397+
if tensor.size(dim=1) != len(one_hot_encoder._get_names_of_added_columns()):
398398
raise ValueError(
399-
f"Tensor and one_hot_encoder have different amounts of classes ({tensor.size(dim=1)}!={len(one_hot_encoder.get_names_of_added_columns())}).",
399+
f"Tensor and one_hot_encoder have different amounts of classes ({tensor.size(dim=1)}!={len(one_hot_encoder._get_names_of_added_columns())}).",
400400
)
401401
table_as_tensor = _ColumnAsTensor.__new__(_ColumnAsTensor)
402402
table_as_tensor._tensor = tensor
@@ -406,6 +406,6 @@ def _from_tensor(tensor: Tensor, column_name: str, one_hot_encoder: OneHotEncode
406406

407407
def _to_column(self) -> Column:
408408
table = Table(
409-
dict(zip(self._one_hot_encoder.get_names_of_added_columns(), self._tensor.T.tolist(), strict=False)),
409+
dict(zip(self._one_hot_encoder._get_names_of_added_columns(), self._tensor.T.tolist(), strict=False)),
410410
)
411411
return self._one_hot_encoder.inverse_transform(table).get_column(self._column_name)

src/safeds/data/tabular/containers/_table.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(self, data: Mapping[str, Sequence[Any]] | None = None) -> None:
323323

324324
# Implementation
325325
self._lazy_frame: pl.LazyFrame = pl.LazyFrame(data)
326-
self.__data_frame_cache: pl.DataFrame | None = None
326+
self.__data_frame_cache: pl.DataFrame | None = None # Scramble the name to prevent access from outside
327327

328328
def __eq__(self, other: object) -> bool:
329329
if not isinstance(other, Table):
@@ -1033,6 +1033,9 @@ def remove_duplicate_rows(self) -> Table:
10331033
| 2 | 5 |
10341034
+-----+-----+
10351035
"""
1036+
if self.number_of_columns == 0:
1037+
return self # Workaround for https://github.com/pola-rs/polars/issues/16207
1038+
10361039
return Table._from_polars_lazy_frame(
10371040
self._lazy_frame.unique(maintain_order=True),
10381041
)
@@ -1221,6 +1224,8 @@ def remove_rows_with_outliers(
12211224
| null | 8 |
12221225
+------+-----+
12231226
"""
1227+
if self.number_of_rows == 0:
1228+
return self # polars raises a ComputeError for tables without rows
12241229
if column_names is None:
12251230
column_names = self.column_names
12261231

@@ -1440,7 +1445,10 @@ def split_rows(
14401445
The first table contains a percentage of the rows specified by `percentage_in_first`, and the second table
14411446
contains the remaining rows.
14421447
1443-
**Note:** The original table is not modified.
1448+
**Notes:**
1449+
1450+
- The original table is not modified.
1451+
- By default, the rows are shuffled before splitting. You can disable this by setting `shuffle` to False.
14441452
14451453
Parameters
14461454
----------

src/safeds/data/tabular/transformation/_discretizer.py

Lines changed: 26 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5+
from safeds._utils import _structural_hash
56
from safeds._validation import _check_bounds, _check_columns_exist, _ClosedBound
67
from safeds.data.tabular.containers import Table
78
from safeds.exceptions import (
@@ -30,13 +31,36 @@ class Discretizer(TableTransformer):
3031
If the given number_of_bins is less than 2.
3132
"""
3233

33-
def __init__(self, number_of_bins: int = 5):
34+
# ------------------------------------------------------------------------------------------------------------------
35+
# Dunder methods
36+
# ------------------------------------------------------------------------------------------------------------------
37+
38+
def __init__(self, number_of_bins: int = 5) -> None:
39+
TableTransformer.__init__(self)
40+
3441
_check_bounds("number_of_bins", number_of_bins, lower_bound=_ClosedBound(2))
3542

36-
self._column_names: list[str] | None = None
3743
self._wrapped_transformer: sk_KBinsDiscretizer | None = None
3844
self._number_of_bins = number_of_bins
3945

46+
def __hash__(self) -> int:
    """Return a hash built from the base transformer's hash and the configured number of bins."""
    base_hash = TableTransformer.__hash__(self)
    return _structural_hash(base_hash, self._number_of_bins)
51+
52+
# ------------------------------------------------------------------------------------------------------------------
53+
# Properties
54+
# ------------------------------------------------------------------------------------------------------------------
55+
56+
@property
def number_of_bins(self) -> int:
    """The number of bins that was passed to the constructor."""
    bins = self._number_of_bins
    return bins
59+
60+
# ------------------------------------------------------------------------------------------------------------------
61+
# Learning and transformation
62+
# ------------------------------------------------------------------------------------------------------------------
63+
4064
def fit(self, table: Table, column_names: list[str] | None) -> Discretizer:
4165
"""
4266
Learn a transformation for a set of columns in a table.
@@ -137,62 +161,3 @@ def transform(self, table: Table) -> Table:
137161
return Table._from_polars_lazy_frame(
138162
table._lazy_frame.update(new_data.lazy()),
139163
)
140-
141-
@property
142-
def is_fitted(self) -> bool:
143-
"""Whether the transformer is fitted."""
144-
return self._wrapped_transformer is not None
145-
146-
def get_names_of_added_columns(self) -> list[str]:
147-
"""
148-
Get the names of all new columns that have been added by the Discretizer.
149-
150-
Returns
151-
-------
152-
added_columns:
153-
A list of names of the added columns, ordered as they will appear in the table.
154-
155-
Raises
156-
------
157-
TransformerNotFittedError
158-
If the transformer has not been fitted yet.
159-
"""
160-
if not self.is_fitted:
161-
raise TransformerNotFittedError
162-
return []
163-
164-
def get_names_of_changed_columns(self) -> list[str]:
165-
"""
166-
Get the names of all columns that may have been changed by the Discretizer.
167-
168-
Returns
169-
-------
170-
changed_columns:
171-
The list of (potentially) changed column names, as passed to fit.
172-
173-
Raises
174-
------
175-
TransformerNotFittedError
176-
If the transformer has not been fitted yet.
177-
"""
178-
if self._column_names is None:
179-
raise TransformerNotFittedError
180-
return self._column_names
181-
182-
def get_names_of_removed_columns(self) -> list[str]:
183-
"""
184-
Get the names of all columns that have been removed by the Discretizer.
185-
186-
Returns
187-
-------
188-
removed_columns:
189-
A list of names of the removed columns, ordered as they appear in the table the Discretizer was fitted on.
190-
191-
Raises
192-
------
193-
TransformerNotFittedError
194-
If the transformer has not been fitted yet.
195-
"""
196-
if not self.is_fitted:
197-
raise TransformerNotFittedError
198-
return []

0 commit comments

Comments
 (0)