Description
Hi, I'm getting the error below while training a SparkXGBRegressor with the survival:aft objective. Please help!

Here is my code and sample data.
```python
xgb_aft = SparkXGBRegressor(
    features_col="features",
    label_col="label_lower",
    label_lower_bound_col="label_lower",
    label_upper_bound_col="label_upper",
    objective="survival:aft",
    eval_metric="aft-nloglik",
    aft_loss_distribution="normal",
    aft_loss_distribution_scale=1.0,
    num_workers=2,
    max_depth=6,
    eta=0.1,
    num_round=1000,
    # Early stopping (currently disabled)
    # validation_indicator_col="is_validation",
    # early_stopping_rounds=30
)
aft_model = xgb_aft.fit(train_df)
```
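To sanity-check the objective settings outside Spark, here is a minimal local sketch with the native XGBoost API (the tiny arrays are made up for illustration; the AFT bounds are attached via `set_float_info`, as in the XGBoost survival-analysis tutorial):

```python
import numpy as np
import xgboost as xgb

# Minimal local AFT sketch with the same objective settings,
# just to check the configuration itself is valid outside Spark.
X = np.array([[28, 14, 9999, 2100, 2, 0],
              [24, 2, 9999, 14000, 7, 0]], dtype=float)
dtrain = xgb.DMatrix(X)
# Right-censored row: upper bound is +inf; observed event: lower == upper.
dtrain.set_float_info("label_lower_bound", np.array([90.0, 87.0]))
dtrain.set_float_info("label_upper_bound", np.array([np.inf, 87.0]))

params = {
    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1.0,
    "max_depth": 6,
    "learning_rate": 0.1,
}
bst = xgb.train(params, dtrain, num_boost_round=10)
```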
Sample data:
train_df:
masked_client_code | snapshot_date | time_to_event | event_observed | label_lower | label_upper | features |
---|---|---|---|---|---|---|
81e2 | 2024-06-30 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[28,14,9999,2100,2,0]} |
81e2 | 2024-07-31 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[21,21,9999,2200,3,0]} |
81e2 | 2024-08-31 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[19,19,9999,4300,4,0]} |
81e2 | 2024-09-30 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[19,19,9999,6400,3,0]} |
81e2 | 2024-10-31 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[21,21,9999,8500,3,0]} |
81e2 | 2024-11-30 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[18,18,9999,10600,3,0]} |
81e2 | 2024-12-31 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[20,20,9999,12700,3,0]} |
81e2 | 2025-01-31 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[16,16,9999,14800,3,0]} |
3fcf | 2025-01-31 | 90 | 0 | 90 | Infinity | {"vectorType":"dense","length":6,"values":[26,26,9999,200,1,0]} |
cdfa | 2024-08-31 | 87 | 1 | 87 | 87 | {"vectorType":"dense","length":6,"values":[24,2,9999,14000,7,0]} |
cdfa | 2024-09-30 | 57 | 1 | 57 | 57 | {"vectorType":"dense","length":6,"values":[25,3,9999,16000,7,0]} |
cdfa | 2024-10-31 | 26 | 1 | 26 | 26 | {"vectorType":"dense","length":6,"values":[24,2,9999,18000,7,0]} |
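In case it matters, here is roughly how a DataFrame of this shape is put together (an illustrative sketch, not my exact pipeline; `float("inf")` is what produces the Infinity upper bounds for right-censored rows):

```python
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Illustrative reconstruction of two rows from the table above.
# Censored rows (event_observed == 0) get label_upper = inf;
# observed events get label_lower == label_upper == time_to_event.
rows = [
    ("81e2", "2024-06-30", 90, 0, 90.0, float("inf"),
     Vectors.dense([28, 14, 9999, 2100, 2, 0])),
    ("cdfa", "2024-08-31", 87, 1, 87.0, 87.0,
     Vectors.dense([24, 2, 9999, 14000, 7, 0])),
]
train_df = spark.createDataFrame(
    rows,
    ["masked_client_code", "snapshot_date", "time_to_event",
     "event_observed", "label_lower", "label_upper", "features"],
)
```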
Error stack trace:

```
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(262, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/databricks/python/lib/python3.10/site-packages/xgboost/spark/core.py", line 836, in _train_booster
booster = worker_train(
File "/databricks/python/lib/python3.10/site-packages/xgboost/core.py", line 620, in inner_f
return func(**kwargs)
File "/databricks/python/lib/python3.10/site-packages/xgboost/training.py", line 185, in train
bst.update(dtrain, i, obj)
File "/databricks/python/lib/python3.10/site-packages/xgboost/core.py", line 1918, in update
_check_call(_LIB.XGBoosterUpdateOneIter(self.handle,
File "/databricks/python/lib/python3.10/site-packages/xgboost/core.py", line 279, in _check_call
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
xgboost.core.XGBoostError: [09:51:43] ../src/objective/aft_obj.cu:76: Check failed: info.labels_lower_bound.Size() == ndata (0 vs. 4208258) :
File , line 1
----> 1 aft_model = xgb_aft.fit(train_df)
2 # CORRECTED: Get validation score from training summary
3 summary = aft_model.summary
File /databricks/python_shell/lib/dbruntime/MLWorkloadsInstrumentation/_pyspark.py:30, in _create_patch_function.<locals>.patched_method(self, *args, **kwargs)
28 call_succeeded = False
29 try:
---> 30 result = original_method(self, *args, **kwargs)
31 call_succeeded = True
32 return result
File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:571, in safe_patch.<locals>.safe_patch_function(*args, **kwargs)
569 patch_function.call(call_original, *args, **kwargs)
570 else:
--> 571 patch_function(call_original, *args, **kwargs)
573 session.state = "succeeded"
575 try_log_autologging_event(
576 AutologgingEventLogger.get_logger().log_patch_function_success,
577 session,
(...)
581 kwargs,
582 )
File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:250, in with_managed_run.<locals>.patch_with_managed_run(original, *args, **kwargs)
247 managed_run = create_managed_run()
249 try:
--> 250 result = patch_function(original, *args, **kwargs)
251 except (Exception, KeyboardInterrupt):
252 # In addition to standard Python exceptions, handle keyboard interrupts to ensure
253 # that runs are terminated if a user prematurely interrupts training execution
254 # (e.g. via sigint / ctrl-c)
255 if managed_run:
File /databricks/python/lib/python3.10/site-packages/mlflow/pyspark/ml/__init__.py:1139, in autolog.<locals>.patched_fit(original, self, *args, **kwargs)
1137 if t.should_log():
1138 with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
-> 1139 fit_result = fit_mlflow(original, self, *args, **kwargs)
1140 # In some cases the fit_result may be an iterator of spark models.
1141 if should_log_post_training_metrics and isinstance(fit_result, Model):
File /databricks/python/lib/python3.10/site-packages/mlflow/pyspark/ml/__init__.py:1125, in autolog.<locals>.fit_mlflow(original, self, *args, **kwargs)
1123 input_training_df = args[0].persist(StorageLevel.MEMORY_AND_DISK)
1124 _log_pretraining_metadata(estimator, params, input_training_df)
-> 1125 spark_model = original(self, *args, **kwargs)
1126 _log_posttraining_metadata(estimator, spark_model, params, input_training_df)
1127 input_training_df.unpersist()
File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:552, in safe_patch.<locals>.safe_patch_function.<locals>.call_original(*og_args, **og_kwargs)
549 original_result = original(*_og_args, **_og_kwargs)
550 return original_result
--> 552 return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)
File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:487, in safe_patch.<locals>.safe_patch_function.<locals>.call_original_fn_with_event_logging(original_fn, og_args, og_kwargs)
478 try:
479 try_log_autologging_event(
480 AutologgingEventLogger.get_logger().log_original_function_start,
481 session,
(...)
485 og_kwargs,
486 )
--> 487 original_fn_result = original_fn(*og_args, **og_kwargs)
489 try_log_autologging_event(
490 AutologgingEventLogger.get_logger().log_original_function_success,
491 session,
(...)
495 og_kwargs,
496 )
497 return original_fn_result
File /databricks/python/lib/python3.10/site-packages/mlflow/utils/autologging_utils/safety.py:549, in safe_patch.<locals>.safe_patch_function.<locals>.call_original.<locals>._original_fn(*_og_args, **_og_kwargs)
541 # Show all non-MLflow warnings as normal (i.e. not as event logs)
542 # during original function execution, even if silent mode is enabled
543 # (silent=True), since these warnings originate from the ML framework
544 # or one of its dependencies and are likely relevant to the caller
545 with set_non_mlflow_warnings_behavior_for_current_thread(
546 disable_warnings=False,
547 reroute_warnings=False,
548 ):
--> 549 original_result = original(*_og_args, **_og_kwargs)
550 return original_result
File /databricks/spark/python/pyspark/ml/base.py:203, in Estimator.fit(self, dataset, params)
201 return self.copy(params)._fit(dataset)
202 else:
--> 203 return self._fit(dataset)
204 else:
205 raise TypeError(
206 "Params must be either a param map or a list/tuple of param maps, "
207 "but got %s." % type(params)
208 )
File /databricks/python/lib/python3.10/site-packages/xgboost/spark/core.py:864, in _SparkXGBEstimator._fit(self, dataset)
854 ret = (
855 dataset.mapInPandas(
856 _train_booster, schema="config string, booster string"
(...)
860 .collect()[0]
861 )
862 return ret[0], ret[1]
--> 864 (config, booster) = _run_job()
866 result_xgb_model = self._convert_to_sklearn_model(
867 bytearray(booster, "utf-8"), config
868 )
869 spark_model = self._create_pyspark_model(result_xgb_model)
File /databricks/python/lib/python3.10/site-packages/xgboost/spark/core.py:860, in _SparkXGBEstimator._fit.<locals>._run_job()
853 def _run_job():
854 ret = (
855 dataset.mapInPandas(
856 _train_booster, schema="config string, booster string"
857 )
858 .rdd.barrier()
859 .mapPartitions(lambda x: x)
--> 860 .collect()[0]
861 )
862 return ret[0], ret[1]
File /databricks/spark/python/pyspark/instrumentation_utils.py:47, in _wrap_function.<locals>.wrapper(*args, **kwargs)
45 start = time.perf_counter()
46 try:
---> 47 res = func(*args, **kwargs)
48 logger.log_success(
49 module_name, class_name, function_name, time.perf_counter() - start, signature
50 )
51 return res
File /databricks/spark/python/pyspark/rdd.py:1856, in RDD.collect(self)
1854 with SCCallSiteSync(self.context):
1855 assert self.ctx._jvm is not None
-> 1856 sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
1857 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py:1355, in JavaMember.__call__(self, *args)
1349 command = proto.CALL_COMMAND_NAME +\
1350 self.command_header +\
1351 args_command +\
1352 proto.END_COMMAND_PART
1354 answer = self.gateway_client.send_command(command)
-> 1355 return_value = get_return_value(
1356 answer, self.gateway_client, self.target_id, self.name)
1358 for temp_arg in temp_args:
1359 if hasattr(temp_arg, "_detach"):
File /databricks/spark/python/pyspark/errors/exceptions/captured.py:224, in capture_sql_exception.<locals>.deco(*a, **kw)
222 def deco(*a: Any, **kw: Any) -> Any:
223 try:
--> 224 return f(*a, **kw)
225 except Py4JJavaError as e:
226 converted = convert_exception(e.java_exception)
File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
331 "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
332 format(target_id, ".", name, value))
```
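For what it's worth, the failing check (aft_obj.cu:76) compares the size of the lower-bound label vector against the number of rows, and it sees 0 bounds for 4,208,258 rows, so it looks like the bound columns never reach the workers. A quick driver-side sanity check I can run (my own debugging sketch, not from any library):

```python
from pyspark.sql import functions as F

# Debugging sketch: confirm the AFT bound columns exist, are numeric,
# and contain no nulls, since the failed check reports 0 lower bounds
# on the worker side.
for col in ["label_lower", "label_upper"]:
    n_null = train_df.filter(F.col(col).isNull()).count()
    print(col, train_df.schema[col].dataType, "nulls:", n_null)
print("total rows:", train_df.count())
```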