diff --git a/connector/src/main/scala/com/datastax/spark/connector/writer/AsyncExecutor.scala b/connector/src/main/scala/com/datastax/spark/connector/writer/AsyncExecutor.scala index eef2cd128..a52da60da 100644 --- a/connector/src/main/scala/com/datastax/spark/connector/writer/AsyncExecutor.scala +++ b/connector/src/main/scala/com/datastax/spark/connector/writer/AsyncExecutor.scala @@ -1,8 +1,7 @@ package com.datastax.spark.connector.writer -import java.util.concurrent.{CompletionStage, Semaphore} +import java.util.concurrent.{CompletableFuture, CompletionStage, Semaphore} import java.util.function.BiConsumer - import com.datastax.spark.connector.util.Logging import scala.jdk.CollectionConverters._ @@ -41,9 +40,14 @@ class AsyncExecutor[T, R](asyncAction: T => CompletionStage[R], maxConcurrentTas val executionTimestamp = System.nanoTime() def tryFuture(): Future[R] = { - val value = asyncAction(task) - - value.whenComplete(new BiConsumer[R, Throwable] { + val value = Try(asyncAction(task)) recover { + case e => + val future = new CompletableFuture[R]() + future.completeExceptionally(e) + future + } + + value.get.whenComplete(new BiConsumer[R, Throwable] { private def release() { semaphore.release() pendingFutures.remove(promise.future) diff --git a/connector/src/test/scala/com/datastax/spark/connector/writer/AsyncExecutorTest.scala b/connector/src/test/scala/com/datastax/spark/connector/writer/AsyncExecutorTest.scala index 9ab34696f..5f4c0adeb 100644 --- a/connector/src/test/scala/com/datastax/spark/connector/writer/AsyncExecutorTest.scala +++ b/connector/src/test/scala/com/datastax/spark/connector/writer/AsyncExecutorTest.scala @@ -1,5 +1,7 @@ package com.datastax.spark.connector.writer +import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, SimpleStatement, Statement} + import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{Callable, CompletableFuture, CompletionStage} @@ -55,4 +57,20 @@ class AsyncExecutorTest { totalFinishedExecutionsCounter.get() shouldBe taskCount asyncExecutor.getLatestException() shouldBe None } + + @Test + def testGracefullyHandleCqlSessionExecuteExceptions() { + val executor = new AsyncExecutor[Statement[_], AsyncResultSet]( + _ => { + // simulate exception returned by session.executeAsync() (not future) + throw new IllegalStateException("something bad happened") + }, 10, None, None + ) + val stmt = SimpleStatement.newInstance("INSERT INTO table1 (key, value) VALUES (1, '100')"); + val future = executor.executeAsync(stmt) + assertTrue(future.isCompleted) + val value = future.value.get + assertTrue(value.isInstanceOf[Failure[_]]) + assertTrue(value.asInstanceOf[Failure[_]].exception.isInstanceOf[IllegalStateException]) + } }