Describe the issue:
The _hstack method of dask_ml's ColumnTransformer has a different signature from scikit-learn's: it lacks the n_samples keyword argument.
Minimal Complete Verifiable Example:
from dask_ml.wrappers import Incremental
from dask_ml.feature_extraction.text import HashingVectorizer
import dask.dataframe as dd
import pandas as pd
from dask_ml.compose import ColumnTransformer

data = {
    "test1": ["example", "text"],
    "test2": ["lorem", "ipsum"]
}
df = pd.DataFrame(data)
df = dd.from_pandas(df).astype(str)

pipeline = ColumnTransformer([
    ("test1", HashingVectorizer(), "test1"),
    ("test2", HashingVectorizer(), "test2"),
])
pipeline.fit(df)
Anything else we need to know?:
This causes a crash:
Traceback (most recent call last):
  File "/home/antoni/Documents/projects/dask/reproducers/1/main.py", line 20, in <module>
    pipeline.fit(df)
    ~~~~~~~~~~~~^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 947, in fit
    self.fit_transform(X, y=y, **params)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/utils/_set_output.py", line 319, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/base.py", line 1389, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 1031, in fit_transform
    return self._hstack(list(Xs), n_samples=n_samples)
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: ColumnTransformer._hstack() got an unexpected keyword argument 'n_samples'
The fix seems simple enough: it would just mean adding a check similar to the one in scikit-learn's version before returning. I can implement this; please let me know whether this kind of fix would be sufficient.
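To illustrate the mechanism without depending on either library, here is a minimal sketch using hypothetical stand-in classes (UpstreamTransformer, OutdatedOverride, CompatibleOverride are not real sklearn or dask_ml names). The base class passes n_samples as a keyword to _hstack, as the traceback shows scikit-learn doing; an override with the old signature fails, and one that accepts the keyword works. The row-count check is an assumption made for illustration, not a copy of scikit-learn's logic.

```python
# Hypothetical stand-ins only; these are NOT the real sklearn / dask_ml classes.


def stack_rows(Xs):
    """Concatenate per-transformer blocks row by row (lists of feature lists)."""
    return [sum(row, []) for row in zip(*Xs)]


class UpstreamTransformer:
    """Mimics a parent whose fit_transform passes n_samples to _hstack."""

    def fit_transform(self, Xs):
        n_samples = len(Xs[0]) if Xs else 0
        return self._hstack(list(Xs), n_samples=n_samples)

    def _hstack(self, Xs, *, n_samples):
        return stack_rows(Xs)


class OutdatedOverride(UpstreamTransformer):
    """Override that still uses the old signature without n_samples."""

    def _hstack(self, Xs):
        return stack_rows(Xs)


class CompatibleOverride(UpstreamTransformer):
    """Override that accepts n_samples and validates the result."""

    def _hstack(self, Xs, *, n_samples=None):
        stacked = stack_rows(Xs)
        # Assumed check for illustration: output rows must match the input.
        if n_samples is not None and len(stacked) != n_samples:
            raise ValueError(
                f"expected {n_samples} rows after stacking, got {len(stacked)}"
            )
        return stacked


# Two transformer outputs, each with two rows of one feature.
blocks = [[[1], [2]], [[3], [4]]]

try:
    OutdatedOverride().fit_transform(blocks)
    outdated_error = None
except TypeError as exc:
    outdated_error = str(exc)  # same kind of failure as in the traceback

result = CompatibleOverride().fit_transform(blocks)
```

Defaulting n_samples to None in the override is one possible design choice: it would keep the method callable by any older caller that does not pass the argument, which may matter if dask_ml has to support several scikit-learn versions.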
Environment:
- Dask version: 2025.5.0
- Dask-ml version: 2025.1.0
- scikit-learn: 1.6.1
- Python version: 3.13.3
- Operating System: Linux
- Install method (conda, pip, source): pip