Fit modes give different results #631

@oscarkey

Describe the bug

Different settings of the fit_mode option of TabPFNRegressor give different predictions for the same training and test data. The fit mode should not change the output beyond floating-point precision differences.
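
The expected property can be written as a check: predictions from every fit mode should match a reference mode within a small relative tolerance. This is a minimal sketch. It assumes the predictions have already been collected into a dict keyed by fit-mode name. The helper name check_fit_mode_invariance and the tolerance value are hypothetical choices for illustration, not part of TabPFN, and the arrays in the example are synthetic, not TabPFN output.

```python
import numpy as np


def check_fit_mode_invariance(
    preds_by_mode: dict[str, np.ndarray],
    reference: str = "low_memory",
    rtol: float = 1e-8,
) -> dict[str, bool]:
    """For each non-reference fit mode, report whether its predictions match
    the reference mode's predictions within relative tolerance `rtol`.

    Hypothetical helper; `rtol` is an illustrative choice for float64 inference.
    """
    ref = preds_by_mode[reference]
    return {
        # np.allclose checks |preds - ref| <= atol + rtol * |ref| element-wise.
        mode: bool(np.allclose(preds, ref, rtol=rtol, atol=0.0))
        for mode, preds in preds_by_mode.items()
        if mode != reference
    }


# Synthetic predictions for illustration only.
preds = {
    "low_memory": np.array([150.0, 200.0]),
    "fit_preprocessors": np.array([150.0, 200.0]),
    "fit_with_cache": np.array([150.0, 210.0]),
}
result = check_fit_mode_invariance(preds)
print(result)
```

With real predictions from the script below, every entry should be True if the fit mode only affects caching and not the computed result.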

Steps/Code to Reproduce

import random
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
import numpy as np
import torch

from tabpfn import TabPFNRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42
)


def _set_seeds() -> None:
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)


_set_seeds()
reg = TabPFNRegressor(fit_mode="low_memory", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_no_cache = reg.predict(X_test)

reg = TabPFNRegressor(fit_mode="low_memory", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_no_cache_repeat = reg.predict(X_test)

_set_seeds()
reg = TabPFNRegressor(fit_mode="fit_preprocessors", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_cache_preproc = reg.predict(X_test)

_set_seeds()
reg = TabPFNRegressor(fit_mode="fit_with_cache", inference_precision=torch.float64)
reg.fit(X_train, y_train)
preds_kv_cache = reg.predict(X_test)


def _max_diff(a: np.ndarray, b: np.ndarray) -> float:
    # Maximum element-wise relative difference, max(|a - b| / |a|), using a as the reference.
    return np.max(np.abs(a - b) / np.abs(a))


print("max relative diffs")
print("no_cache vs no_cache_repeat:", _max_diff(preds_no_cache, preds_no_cache_repeat))
print("no_cache vs cache_preproc:", _max_diff(preds_no_cache, preds_cache_preproc))
print("no_cache vs kv_cache:", _max_diff(preds_no_cache, preds_kv_cache))

I set the environment variable TABPFN_EXCLUDE_DEVICES=mps so that inference runs on the CPU, because PyTorch's MPS backend does not support float64 on macOS.
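
The same setting can also be applied from inside the script. This is a minimal sketch. It assumes TabPFN reads TABPFN_EXCLUDE_DEVICES from the environment when it selects a device. Setting the variable before importing tabpfn is the conservative choice, since the exact point at which TabPFN reads it is not confirmed here.

```python
import os

# Assumption: TabPFN reads TABPFN_EXCLUDE_DEVICES when choosing a device.
# Excluding "mps" skips the Apple GPU backend (which lacks float64 support),
# so on a machine without CUDA inference falls back to the CPU.
os.environ["TABPFN_EXCLUDE_DEVICES"] = "mps"

# Import TabPFN only after the variable is set:
# from tabpfn import TabPFNRegressor
```

Equivalently, the variable can be set on the command line, e.g. TABPFN_EXCLUDE_DEVICES=mps python repro.py (the script name here is hypothetical).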

Expected Results

no_cache vs no_cache_repeat: 0.0
no_cache vs cache_preproc: ~0.0
no_cache vs kv_cache: ~0.0

Actual Results

no_cache vs no_cache_repeat: True (identical predictions)
no_cache vs cache_preproc: 0.044
no_cache vs kv_cache: 1.423
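
As a worked reading of these numbers: _max_diff returns max(|a - b| / |a|), so 0.044 means the largest fit_preprocessors deviation is 4.4% of the low_memory prediction. A value of 1.423 means the largest fit_with_cache deviation is about 142% of it. The script's metric would also be undefined if a reference prediction were exactly zero. The sketch below is a zero-safe variant. The name max_relative_diff and the eps floor are hypothetical choices, not part of the original script.

```python
import numpy as np


def max_relative_diff(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """Largest |a - b| / |a| over all elements, with |a| floored at `eps`
    so that zero-valued reference predictions do not divide by zero."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    denom = np.maximum(np.abs(a), eps)
    return float(np.max(np.abs(a - b) / denom))


# Worked example: the second element differs by 4 / 200 = 0.02, i.e. 2%.
print(max_relative_diff(np.array([100.0, 200.0]), np.array([100.0, 204.0])))
```

Flooring the denominator keeps the metric finite. Very small reference values can still inflate the ratio, so an absolute tolerance alongside the relative one may be worth reporting too.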

Versions

PyTorch version: 2.9.0
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.4.4.1)
CMake version: version 4.1.2
Libc version: N/A

Python version: 3.10.17 (main, Apr  9 2025, 03:47:39) [Clang 20.1.0 ] (64-bit runtime)
Python platform: macOS-26.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M4

Dependency Versions:
--------------------
tabpfn: 6.0.6
torch: 2.9.0
numpy: 2.2.6
scipy: 1.15.3
pandas: 2.3.3
scikit-learn: 1.7.2
typing_extensions: 4.15.0
einops: 0.8.1
huggingface-hub: 1.1.2
