We are currently using xgboost 1.6.2 and are trying to upgrade to 2.1.1. On the way through the versions, we observed the following prediction time averages:
1.6.2: 15ms
1.7.6: 17ms
2.0.3: 43ms
2.1.1: 110ms
As you can see, there is a big jump from 1.7 to 2.0, and an even bigger jump from 2.0 to 2.1. Unfortunately it's not easy for me to share the model, but I found this related bug report and updated its scripts to my use case: #8865
import time
import numpy as np
import pandas as pd
import xgboost
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier

MODEL_NAME = '/tmp/model.model'


def train_model():
    # Train a small model on the digits dataset and save it to disk.
    data = load_digits()
    X_train, X_test, y_train, y_test = train_test_split(data['data'], data['target'], test_size=.2)
    dtrain = xgboost.DMatrix(X_train, label=y_train)
    params = {'max_depth': 3, 'eta': 1, 'objective': 'reg:linear', 'eval_metric': 'rmse'}
    bst = xgboost.train(params, dtrain, 10, [(dtrain, 'train')])
    bst.save_model(MODEL_NAME)


def predict_np_array():
    # Time single-row predictions through the native Booster with a np.array input.
    bst = xgboost.Booster()
    bst.set_param({"nthread": 1})
    bst.load_model(fname=MODEL_NAME)
    times = []
    np.random.seed(7)
    iterations = 10000
    for _ in range(iterations):
        sample = np.random.uniform(-1, 10, size=(1, 64))
        start = time.time()
        bst.inplace_predict(sample)
        times.append(time.time() - start)
    # Average over the second half of the iterations to exclude warm-up.
    iter_time = sum(times[iterations // 2:]) / (iterations // 2)
    print("np.array iter_time: ", iter_time * 1000, "ms")


def predict_sklearn():
    # Time single-row predictions through the sklearn wrapper with a pd.DataFrame input.
    xgb = XGBClassifier()
    xgb.set_params(n_jobs=1, nthread=1)
    xgb.load_model(fname=MODEL_NAME)
    times = []
    np.random.seed(7)
    iterations = 500
    attrs = [str(i) for i in range(64)]  # list (not set) so the column order is deterministic
    for _ in range(iterations):
        sample = pd.DataFrame({ind: [np.random.uniform(-1, 10)] for ind in attrs})
        start = time.time()
        xgb.predict_proba(sample)
        times.append(time.time() - start)
    # Average over the second half of the iterations to exclude warm-up.
    iter_time = sum(times[iterations // 2:]) / (iterations // 2)
    print("DataFrame iter_time: ", iter_time * 1000, "ms")


if __name__ == "__main__":
    train_model()
    for _ in range(10):
        predict_np_array()
        predict_sklearn()
I get the following times when they stabilize:
1.7.6:
np.array iter_time: 0.012594342231750488 ms
DataFrame iter_time: 0.3071410655975342 ms
2.1.1:
np.array iter_time: 0.03231525421142578 ms
DataFrame iter_time: 1.8953888416290283 ms
While not as severe for this artificial model, it still looks like a significant performance degradation. I see now that using pd.DataFrame is a lot worse than np.array, so I think I can work around my issue. But it is still surprising to me that the performance regressed that significantly.
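For reference, a minimal sketch of the kind of workaround I mean: build the single-row input as a dense NumPy array (or convert the DataFrame once with to_numpy()) instead of passing a one-row pd.DataFrame on the hot path. It reuses the model file and feature count from the script above; the float32 conversion is just an assumption, any dense dtype xgboost accepts should work.

```python
import numpy as np
from xgboost import XGBClassifier

MODEL_NAME = '/tmp/model.model'  # same model file as in the script above

xgb = XGBClassifier()
xgb.set_params(n_jobs=1)
xgb.load_model(fname=MODEL_NAME)

# Same kind of single-row input as before, but as a dense float32 array
# instead of a one-row pd.DataFrame, so the per-call pandas handling is skipped.
sample = np.random.uniform(-1, 10, size=(1, 64)).astype(np.float32)
proba = xgb.predict_proba(sample)
print(proba.shape)
```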
Additional context
Our production model has the following attributes (extracted from the model.json, in case that is helpful):
The model was trained on xgboost 1.5.2 but then re-saved on 2.1.1.
The requirements.lock file: I used these version locks when measuring the numbers above; the only change to the file was xgboost==1.7.6 when testing that version.
All tests were run on Ubuntu 24.04.1 on an 11th Gen Intel(R) Core(TM) i7-11800H.
requirements_dev.zip
Thank you for opening the issue. Yeah, we have added some more inspection for pd.DataFrame to support its extension types. But the performance degradation looks bad; we will do some profiling.
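In case it is useful, here is a minimal profiling sketch along those lines, reusing the /tmp/model.model file written by the reproduction script above. The iteration count and the cProfile/pstats usage are just one way to slice it, not xgboost tooling.

```python
import cProfile
import pstats

import numpy as np
import pandas as pd
from xgboost import XGBClassifier

MODEL_NAME = '/tmp/model.model'  # written by train_model() in the script above

xgb = XGBClassifier()
xgb.set_params(n_jobs=1)
xgb.load_model(fname=MODEL_NAME)

# One fixed single-row DataFrame, predicted repeatedly so per-call overhead dominates.
sample = pd.DataFrame({str(i): [np.random.uniform(-1, 10)] for i in range(64)})

profiler = cProfile.Profile()
profiler.enable()
for _ in range(500):
    xgb.predict_proba(sample)
profiler.disable()

# Show the 20 most expensive calls by cumulative time; the DataFrame handling in
# xgboost (e.g. xgboost/data.py) should appear near the top if it dominates.
pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)
```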