TableShift is a benchmarking library for machine learning with tabular data under distribution shift.
You can read more about TableShift at tableshift.org or read the full paper (published in the NeurIPS 2023 Datasets & Benchmarks Track) on arXiv. If you use the benchmark in your research, please cite the paper:
@article{gardner2023tableshift,
title={Benchmarking Distribution Shift in Tabular Data with TableShift},
author={Gardner, Josh and Popovic, Zoran and Schmidt, Ludwig},
journal={Advances in Neural Information Processing Systems},
year={2023}
}
If you find an issue, please file a GitHub issue.
Environment setup: We recommend using Docker with TableShift. Our dataset construction and model pipelines have a diverse set of dependencies, including non-Python files required to make some libraries work. We therefore recommend using the provided Docker image to run the benchmark, and suggest forking this image for your own development.
# fetch the docker image
docker pull ghcr.io/jpgard/tableshift:latest
# run it to test your setup; this automatically launches examples/run_expt.py
docker run ghcr.io/jpgard/tableshift:latest --model xgb
# optionally, use the container interactively
docker run -it --entrypoint=/bin/bash ghcr.io/jpgard/tableshift:latest
Conda: We recommend using Docker with TableShift when running training or using any of the pretrained modeling code, as the libraries used for training contain a complex and subtle set of dependencies that can be difficult to configure outside Docker. However, Conda might provide a more lightweight environment for basic development and exploration with TableShift, so we describe how to set up Conda here.
To create a conda environment, simply clone this repo, enter the root directory, and run the following commands to create and test a local execution environment:
# set up the environment
conda env create -f environment.yml
conda activate tableshift
# test the install by running the training script
python examples/run_expt.py
The final line above will print detailed logging output as the script executes. When you see `training completed! test accuracy: 0.6221`, your environment is ready to go! (Accuracy may vary slightly due to randomness.)
Accessing datasets: If you simply want to load and use a standard version of one of the public TableShift datasets, it's as simple as:
from tableshift import get_dataset
dataset_name = "diabetes_readmission"
dset = get_dataset(dataset_name)
The full list of identifiers for all available datasets is below; simply swap any of these for `dataset_name` to access the relevant data.
If you would like to use a dataset without a domain split, replace `get_dataset()` with `get_iid_dataset()`.
The call to `get_dataset()` returns a `TabularDataset` that you can use to easily load tabular data in several formats, including Pandas DataFrame and PyTorch DataLoaders:
# Fetch a pandas DataFrame of the training set
X_tr, y_tr, _, _ = dset.get_pandas("train")
# Fetch and use a pytorch DataLoader
train_loader = dset.get_dataloader("train", batch_size=1024)
for X, y, _, _ in train_loader:
    ...
For all TableShift datasets, the following splits are available: `train`, `validation`, `id_test`, `ood_validation`, `ood_test`.
For IID datasets (those without a domain split) these splits are available: `train`, `validation`, `test`.
There is a complete example of a training script in `examples/run_expt.py`.
tl;dr: if you want to get started exploring ASAP, use datasets marked as "Public" below.
All of the datasets used in the TableShift benchmark are either publicly available or provide open credentialized access. The datasets with open credentialized access require signing a data use agreement; as a result, some datasets must be manually fetched and stored locally. TableShift makes this process as simple as possible.
A list of datasets, their names in TableShift, and the corresponding access levels are below. The string identifier is the value that should be passed as the `experiment` parameter to `get_dataset()` or the `--experiment` flag of `run_expt.py` and other training scripts.
Dataset | String Identifier | Availability | Source |
---|---|---|---|
Voting | `anes` | Public Credentialized Access (source) | American National Election Studies (ANES) |
ASSISTments | `assistments` | Public | Kaggle |
Childhood Lead | `nhanes_lead` | Public | National Health and Nutrition Examination Survey (NHANES) |
College Scorecard | `college_scorecard` | Public | College Scorecard |
Diabetes | `brfss_diabetes` | Public | Behavioral Risk Factor Surveillance System (BRFSS) |
Food Stamps | `acsfoodstamps` | Public | American Community Survey (via folktables) |
HELOC | `heloc` | Public Credentialized Access (source) | FICO |
Hospital Readmission | `diabetes_readmission` | Public | UCI |
Hypertension | `brfss_blood_pressure` | Public | Behavioral Risk Factor Surveillance System (BRFSS) |
ICU Length of Stay | `mimic_extract_los_3` | Public Credentialized Access (source) | MIMIC-III via MIMIC-Extract |
ICU Mortality | `mimic_extract_mort_hosp` | Public Credentialized Access (source) | MIMIC-III via MIMIC-Extract |
Income | `acsincome` | Public | American Community Survey (via folktables) |
Public Health Insurance | `acspubcov` | Public | American Community Survey (via folktables) |
Sepsis | `physionet` | Public | Physionet |
Unemployment | `acsunemployment` | Public | American Community Survey (via folktables) |
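To make the identifier column concrete, here is a minimal sketch that assembles a command line for `examples/run_expt.py` from a string identifier. The `--experiment` and `--model` flags both appear in this README; combining them in a single invocation is an assumption, and the `build_command` helper and `PUBLIC_IDENTIFIERS` constant (copied from the rows marked Public above) are hypothetical names used only for illustration.

```python
import shlex

# Identifiers from the rows marked "Public" in the table above.
PUBLIC_IDENTIFIERS = (
    "assistments", "nhanes_lead", "college_scorecard", "brfss_diabetes",
    "acsfoodstamps", "diabetes_readmission", "brfss_blood_pressure",
    "acsincome", "acspubcov", "physionet", "acsunemployment",
)


def build_command(experiment, model="xgb"):
    """Build a shell command string for the example training script.

    Assumes run_expt.py accepts --experiment and --model together.
    """
    argv = ["python", "examples/run_expt.py",
            "--experiment", experiment, "--model", model]
    return shlex.join(argv)


cmd = build_command("diabetes_readmission")
print(cmd)
```

Looping over `PUBLIC_IDENTIFIERS` with this helper gives one command per dataset that needs no data use agreement.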
Note that details on the data source, which files to load, and the feature codings are provided in the TableShift source code for each dataset and data source (see `data_sources.py` and the `tableshift.datasets` module).
For additional, non-benchmark datasets (possibly with only IID splits, not a distribution shift), see `tableshift.configs.non_benchmark_configs.py`.
More information about the tasks, datasets, splitting variables, data sources, and motivation is available in the TableShift paper; we provide a summary below.
Task | Target | Shift | Domain | Baseline | Total Observations |
---|---|---|---|---|---|
ASSISTments | Next Answer Correct | School | ✓ | -34.5% | 2,667,776 |
College Scorecard | Low Degree Completion Rate | Institution Type | ✓ | -11.2% | 124,699 |
ICU Mortality | ICU patient expires in hospital during current visit | Insurance Type | ✓ | -6.3% | 23,944 |
Hospital Readmission | 30-day readmission of diabetic hospital patients | Admission source | ✓ | -5.9% | 99,493 |
Diabetes | Diabetes diagnosis | Race | ✓ | -4.5% | 1,444,176 |
ICU Length of Stay | Length of stay >= 3 hrs in ICU | Insurance Type | ✓ | -3.4% | 23,944 |
Voting | Voted in U.S. presidential election | Geographic Region | ✓ | -2.6% | 8,280 |
Food Stamps | Food stamp recipiency in past year for households with child | Geographic Region | ✓ | -2.4% | 840,582 |
Unemployment | Unemployment for non-social security-eligible adults | Education Level | ✓ | -1.3% | 1,795,434 |
Income | Income >= 56k for employed adults | Geographic Region | ✓ | -1.3% | 1,664,500 |
HELOC | Repayment of Home Equity Line of Credit loan | Est. third-party risk level | | -22.6% | 10,459 |
Public Health Insurance | Coverage of non-Medicare eligible low-income individuals | Disability Status | | -14.5% | 5,916,565 |
Sepsis | Sepsis onset within next 6hrs for hospital patients | Length of Stay | | -6.0% | 1,552,210 |
Childhood Lead | Blood lead levels above CDC Blood Level Reference Value | Poverty level | | -5.1% | 27,499 |
Hypertension | Hypertension diagnosis for high-risk age (50+) | BMI Category | | -4.4% | 846,761 |
A sample training script is located at `examples/run_expt.py`. However, training a scikit-learn model is as simple as:
from tableshift import get_dataset
from sklearn.ensemble import GradientBoostingClassifier
dset = get_dataset("diabetes_readmission")
X_train, y_train, _, _ = dset.get_pandas("train")
# Train
estimator = GradientBoostingClassifier()
trained_estimator = estimator.fit(X_train, y_train)
# Test
for split in ('id_test', 'ood_test'):
    X, y, _, _ = dset.get_pandas(split)
    preds = estimator.predict(X)
    acc = (preds == y).mean()
    print(f'accuracy on split {split} is: {acc:.3f}')
The code should output the following:
accuracy on split id_test is: 0.655
accuracy on split ood_test is: 0.619
Now, please close that domain gap!
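The domain gap referred to here can be read off the two accuracies above. As a small worked example, assuming the gap is defined as out-of-distribution accuracy minus in-distribution accuracy (the function name `domain_gap` is hypothetical and not part of TableShift):

```python
def domain_gap(id_acc, ood_acc):
    """Accuracy change when moving from the ID test set to the OOD test set."""
    return ood_acc - id_acc


# Accuracies taken from the example output above.
gap = domain_gap(id_acc=0.655, ood_acc=0.619)

# A negative value means accuracy drops under distribution shift.
print(f"domain gap: {gap:+.3f}")
```

With these numbers the gap is about 3.6 percentage points; a method that closes the gap would bring this value closer to zero without lowering in-distribution accuracy.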
We also have several tabular datasets available in TableShift which are not part of the official TableShift benchmark, but which may still be useful for tabular data research. We are continuously adding datasets to the package. These datasets support all of the same functionality provided for the TableShift benchmark datasets, but they are not an official part of the benchmark and are mostly intended for convenience and for our own internal use.
For a list of the non-benchmark datasets, see the file `tableshift.configs.non_benchmark_configs.py`.