ColumnTransformer _hstack incompatible with scikit's version #1019

@avalanche-pwn

Describe the issue:
dask_ml's ColumnTransformer._hstack method has a different signature than the one in scikit-learn: it lacks the n_samples argument.

Minimal Complete Verifiable Example:

import dask.dataframe as dd
import pandas as pd
from dask_ml.compose import ColumnTransformer
from dask_ml.feature_extraction.text import HashingVectorizer

data = {
    "test1": ["example", "text"],
    "test2": ["lorem", "ipsum"]
}

df = pd.DataFrame(data)
df = dd.from_pandas(df).astype(str)

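pipeline = ColumnTransformer([
    ("test1", HashingVectorizer(), "test1"),
    ("test2", HashingVectorizer(), "test2"),
])

# fit() goes through scikit-learn's fit_transform, which calls
# self._hstack(list(Xs), n_samples=n_samples) and raises the TypeError below.
pipeline.fit(df)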

Anything else we need to know?:
This causes a crash:

Traceback (most recent call last):
  File "/home/antoni/Documents/projects/dask/reproducers/1/main.py", line 20, in <module>
    pipeline.fit(df)
    ~~~~~~~~~~~~^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 947, in fit
    self.fit_transform(X, y=y, **params)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/utils/_set_output.py", line 319, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/base.py", line 1389, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/home/antoni/Documents/projects/dask/reproducers/1/.venv/lib/python3.13/site-packages/sklearn/compose/_column_transformer.py", line 1031, in fit_transform
    return self._hstack(list(Xs), n_samples=n_samples)
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: ColumnTransformer._hstack() got an unexpected keyword argument 'n_samples'

The fix seems simple enough: add a check similar to the one in scikit-learn's version before returning. I can implement this; please let me know whether this kind of fix seems sufficient.
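To make the proposal concrete, here is a minimal sketch of the mismatch and of one possible fix. It uses stand-in classes, not the real scikit-learn or dask_ml code: `BaseTransformer`, `OldOverride`, and `FixedOverride` are hypothetical names, and the row-count check is only an assumption about what "a check similar to scikit's" could look like, not a copy of the upstream logic.

```python
import numpy as np


class BaseTransformer:
    """Hypothetical stand-in for the parent class.

    Like the call shown in the traceback, fit_transform passes
    n_samples to _hstack as a keyword argument.
    """

    def fit_transform(self, X):
        # Split X into one single-column block per column.
        Xs = [X[:, [i]] for i in range(X.shape[1])]
        return self._hstack(Xs, n_samples=X.shape[0])

    def _hstack(self, Xs, *, n_samples):
        return np.hstack(Xs)


class OldOverride(BaseTransformer):
    """Override with the outdated signature: no n_samples parameter."""

    def _hstack(self, Xs):
        return np.hstack(Xs)


class FixedOverride(BaseTransformer):
    """Override that accepts n_samples as an optional keyword."""

    def _hstack(self, Xs, *, n_samples=None):
        # Assumed check (illustrative only): every block must have
        # n_samples rows. Skipped when the caller does not pass it.
        if n_samples is not None:
            for block in Xs:
                if block.shape[0] != n_samples:
                    raise ValueError(
                        f"expected {n_samples} rows, got {block.shape[0]}"
                    )
        return np.hstack(Xs)


X = np.arange(6).reshape(3, 2)

try:
    OldOverride().fit_transform(X)
except TypeError as exc:
    print("old override fails:", exc)

print(FixedOverride().fit_transform(X).shape)
```

Giving n_samples a default of None in the override is a design choice in this sketch: it keeps the method callable by code that does not pass the argument, while still accepting the keyword from callers that do.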

Environment:

  • Dask version: 2025.5.0
  • Dask-ml version: 2025.1.0
  • scikit-learn: 1.6.1
  • Python version: 3.13.3
  • Operating System: Linux
  • Install method (conda, pip, source): pip
