@@ -37,13 +37,14 @@ def __init__(self, *, partial_order: list[Any] | None = None) -> None:
3737 self ._partial_order = partial_order
3838
3939 # Internal state
40- self ._mapping : dict [str , dict [Any , int ]] | None = None
41- self ._inverse_mapping : dict [str , dict [int , Any ]] | None = None
40+ self ._mapping : dict [str , dict [Any , int ]] | None = None # Column name -> value -> label
41+ self ._inverse_mapping : dict [str , dict [int , Any ]] | None = None # Column name -> label -> value
4242
4343 def __hash__ (self ) -> int :
4444 return _structural_hash (
4545 super ().__hash__ (),
4646 self ._partial_order ,
47+ # Leave out the internal state for faster hashing
4748 )
4849
4950 # ------------------------------------------------------------------------------------------------------------------
@@ -61,7 +62,7 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
6162 table:
6263 The table used to fit the transformer.
6364 column_names:
64- The list of columns from the table used to fit the transformer. If `None`, all columns are used.
65+ The list of columns from the table used to fit the transformer. If `None`, all non-numeric columns are used.
6566
6667 Returns
6768 -------
@@ -76,14 +77,13 @@ def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
7677 If the table contains 0 rows.
7778 """
7879 if column_names is None :
79- column_names = table .column_names
80+ column_names = [ name for name in table .column_names if not table . get_column_type ( name ). is_numeric ]
8081 else :
8182 _check_columns_exist (table , column_names )
83+ _warn_if_columns_are_numeric (table , column_names )
8284
8385 if table .number_of_rows == 0 :
84- raise ValueError ("The LabelEncoder cannot transform the table because it contains 0 rows" )
85-
86- _warn_if_columns_are_numeric (table , column_names )
86+ raise ValueError ("The LabelEncoder cannot be fitted because the table contains 0 rows" )
8787
8888 # Learn the transformation
8989 mapping = {}
@@ -142,7 +142,10 @@ def transform(self, table: Table) -> Table:
142142
143143 _check_columns_exist (table , self ._column_names )
144144
145- columns = [pl .col (name ).replace (self ._mapping [name ], return_dtype = pl .UInt32 ) for name in self ._column_names ]
145+ columns = [
146+ pl .col (name ).replace (self ._mapping [name ], default = None , return_dtype = pl .UInt32 )
147+ for name in self ._column_names
148+ ]
146149
147150 return Table ._from_polars_lazy_frame (
148151 table ._lazy_frame .with_columns (columns ),
@@ -186,7 +189,7 @@ def inverse_transform(self, transformed_table: Table) -> Table:
186189 operation = "inverse-transform with a LabelEncoder" ,
187190 )
188191
189- columns = [pl .col (name ).replace (self ._inverse_mapping [name ]) for name in self ._column_names ]
192+ columns = [pl .col (name ).replace (self ._inverse_mapping [name ], default = None ) for name in self ._column_names ]
190193
191194 return Table ._from_polars_lazy_frame (
192195 transformed_table ._lazy_frame .with_columns (columns ),
0 commit comments