
Unexpected behavior in train_test_split with shuffle=False #992

@divir94

Description


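When using train_test_split with shuffle=False on a Dask DataFrame, I notice two issues: 1) the rows are not split in their original order (rows 8 and 9 end up in the training set while rows 4 to 7 end up in the test set), and 2) the train/test sizes are wrong (6 and 4 rows instead of 5 and 5 for test_size=0.5). The behavior doesn't match sklearn, nor dask_ml itself when it is given a pandas DataFrame.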

Minimal Complete Verifiable Example:
Setup

import pandas as pd
import numpy as np
import dask.dataframe as dd

from sklearn.model_selection import train_test_split as sk_train_test_split
from dask_ml.model_selection import train_test_split as dd_train_test_split

df = pd.DataFrame(np.random.rand(10, 3), columns=["y", "x1", "x2"])
ddf = dd.from_pandas(df, npartitions=5)

With sklearn.model_selection, order is maintained (i.e., no shuffle):

y = df["y"]
X = df[["x1", "x2"]]

X_train, X_test, y_train, y_test = sk_train_test_split(X, y, test_size=0.5, shuffle=False)
y_train, y_test
Output:
(0    0.166713
 1    0.961016
 2    0.483907
 3    0.979503
 4    0.553724
 Name: y, dtype: float64,
 5    0.158432
 6    0.078795
 7    0.440427
 8    0.673160
 9    0.657797
 Name: y, dtype: float64)
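For comparison, with shuffle=False scikit-learn splits purely by position: the test set is the last ceil(test_size * n) rows and the training set is everything before it (when train_size is not given). The following is a minimal numpy sketch of that index computation. `ordered_split_indices` is a hypothetical helper written for illustration, not part of scikit-learn's API, and the rounding shown is an assumption based on scikit-learn's behavior for a float test_size.

```python
import math

import numpy as np


def ordered_split_indices(n_samples, test_size):
    """Hypothetical helper (not part of scikit-learn): positional split, no shuffling.

    Assumes scikit-learn's rounding for a float test_size with no train_size:
    the test set gets ceil(test_size * n_samples) rows, the training set the rest.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    train_idx = np.arange(n_train)             # first n_train rows, original order
    test_idx = np.arange(n_train, n_samples)   # remaining rows, original order
    return train_idx, test_idx


train_idx, test_idx = ordered_split_indices(10, test_size=0.5)
print(train_idx.tolist(), test_idx.tolist())  # [0, 1, 2, 3, 4] [5, 6, 7, 8, 9]
```

This reproduces the index split in the sklearn output above.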

With dask_ml.model_selection and a pandas DataFrame, order is also maintained (i.e., no shuffle):

y = df["y"]
X = df[["x1", "x2"]]

X_train, X_test, y_train, y_test = dd_train_test_split(X, y, test_size=0.5, shuffle=False)
y_train, y_test
Output:
(0    0.166713
 1    0.961016
 2    0.483907
 3    0.979503
 4    0.553724
 Name: y, dtype: float64,
 5    0.158432
 6    0.078795
 7    0.440427
 8    0.673160
 9    0.657797
 Name: y, dtype: float64)

With dask_ml.model_selection and a Dask DataFrame, order is NOT maintained and the train/test sizes are incorrect (6 and 4 rows instead of 5 and 5):

y = ddf["y"]
X = ddf[["x1", "x2"]]

X_train, X_test, y_train, y_test = dd_train_test_split(X, y, test_size=0.5, shuffle=False)
y_train.compute(), y_test.compute()
Output:
(0    0.166713
 1    0.961016
 2    0.483907
 3    0.979503
 8    0.673160
 9    0.657797
 Name: y, dtype: float64,
 4    0.553724
 5    0.158432
 6    0.078795
 7    0.440427
 Name: y, dtype: float64)

Environment:

  • Dask version: 2023.11.0
  • Python version: 3.11.8
  • Operating System: macOS
  • Install method (conda, pip, source): micromamba
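Since train_test_split is implemented in dask-ml rather than in dask itself, the dask-ml version is also relevant here. The following is a small standard-library sketch that collects it along with the other versions. The package names are the distribution (PyPI) names, and `collect_versions` is a hypothetical helper written for illustration.

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("dask", "dask-ml", "scikit-learn", "pandas", "numpy")):
    """Hypothetical helper: installed versions of the given distributions plus Python/OS info."""
    info = {"python": sys.version.split()[0], "os": platform.platform()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = "not installed"
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```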
