From 3af79e3f84dc762e05f07d488eaf8ca19864aa86 Mon Sep 17 00:00:00 2001 From: Julien Viet Date: Tue, 17 Oct 2023 14:18:16 +0200 Subject: [PATCH] The protocol of the TaskQueue unschedule method is racy, since sometimes the task resume might happen before the thread is suspended. The protocol should be modified to implement a race free interaction. --- .../java/io/vertx/core/impl/TaskQueue.java | 54 ++++++--- .../io/vertx/core/impl/WorkerExecutor.java | 46 ++++++- .../java/io/vertx/core/TaskQueueTest.java | 113 ++++++++++++------ 3 files changed, 157 insertions(+), 56 deletions(-) diff --git a/src/main/java/io/vertx/core/impl/TaskQueue.java b/src/main/java/io/vertx/core/impl/TaskQueue.java index 7898a66d78f..fd4632d44c5 100644 --- a/src/main/java/io/vertx/core/impl/TaskQueue.java +++ b/src/main/java/io/vertx/core/impl/TaskQueue.java @@ -15,9 +15,10 @@ import io.vertx.core.impl.logging.LoggerFactory; import java.util.LinkedList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.RejectedExecutionException; -import java.util.function.Consumer; +import java.util.concurrent.TimeUnit; /** * A task queue that always run all tasks in order. The executor to run the tasks is passed @@ -82,17 +83,12 @@ private void run() { } /** - * Unschedule the current task from execution, the next task in the queue will be executed - * when there is one. + * Return a controller for the current task. * - *

When the current task wants to be resumed, it should call the returned consumer with a command - * to unpark the thread (e.g most likely yielding a latch), this task will be executed immediately if there - * is no tasks being executed, otherwise it will be added first in the queue. - * - * @return a mean to signal to resume the thread when it shall be resumed + * @return the controller * @throws IllegalStateException if the current thread is not currently being executed by the queue */ - public Consumer unschedule() { + public WorkerExecutor.TaskController current() { Thread thread; Executor executor; synchronized (tasks) { @@ -101,20 +97,42 @@ public Consumer unschedule() { } thread = currentThread; executor = currentExecutor; - currentThread = null; } - executor.execute(runner); - return r -> { - synchronized (tasks) { - if (currentExecutor != null) { - tasks.addFirst(new ResumeTask(r, executor, thread)); - return; - } else { + return new WorkerExecutor.TaskController() { + + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void resume(Runnable callback) { + Runnable task = () -> { + callback.run(); + latch.countDown(); + }; + synchronized (tasks) { + if (currentExecutor != null) { + tasks.addFirst(new ResumeTask(task, executor, thread)); + return; + } currentExecutor = executor; currentThread = thread; } + task.run(); + } + + @Override + public CountDownLatch suspend() { + if (Thread.currentThread() != thread) { + throw new IllegalStateException(); + } + synchronized (tasks) { + if (currentThread == null || currentThread != Thread.currentThread()) { + throw new IllegalStateException(); + } + currentThread = null; + } + executor.execute(runner); + return latch; } - r.run(); }; } diff --git a/src/main/java/io/vertx/core/impl/WorkerExecutor.java b/src/main/java/io/vertx/core/impl/WorkerExecutor.java index 8853c6d4ca4..33e2bc28a69 100644 --- a/src/main/java/io/vertx/core/impl/WorkerExecutor.java +++ b/src/main/java/io/vertx/core/impl/WorkerExecutor.java @@ -12,7 +12,8 @@ import io.vertx.core.spi.metrics.PoolMetrics; -import java.util.function.Consumer; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; /** * Execute events on a worker pool. @@ -60,9 +61,46 @@ public void execute(Runnable command) { } /** - * See {@link TaskQueue#unschedule()}. + * See {@link TaskQueue#current()}. */ - public Consumer unschedule() { - return orderedTasks.unschedule(); + public TaskController current() { + return orderedTasks.current(); + } + + public interface TaskController { + + /** + * Resume the task, the {@code callback} will be executed when the task is resumed, before the task thread + * is unparked. + * + * @param callback called when the task is resumed + */ + void resume(Runnable callback); + + /** + * Like {@link #resume(Runnable)}. + */ + default void resume() { + resume(() -> {}); + } + + /** + * Suspend the task execution and park the current thread until the task is resumed. + * The next task in the queue will be executed, when there is one. + * + *

When the task wants to be resumed, it should call {@link #resume}, this will be executed immediately if there + * is no other tasks being executed, otherwise it will be added first in the queue. + */ + default void suspendAndAwaitResume() throws InterruptedException { + suspend().await(); + } + + /** + * Like {@link #suspendAndAwaitResume()} but does not await the task to be resumed. + * + * @return the latch to await + */ + CountDownLatch suspend(); + } } diff --git a/src/test/java/io/vertx/core/TaskQueueTest.java b/src/test/java/io/vertx/core/TaskQueueTest.java index 55f06b97181..7a4b21fdccc 100644 --- a/src/test/java/io/vertx/core/TaskQueueTest.java +++ b/src/test/java/io/vertx/core/TaskQueueTest.java @@ -1,6 +1,7 @@ package io.vertx.core; import io.vertx.core.impl.TaskQueue; +import io.vertx.core.impl.WorkerExecutor; import io.vertx.test.core.AsyncTestBase; import org.junit.Test; @@ -8,9 +9,9 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; public class TaskQueueTest extends AsyncTestBase { @@ -40,6 +41,14 @@ protected void tearDown() throws Exception { super.tearDown(); } + private void suspendAndAwaitResume(WorkerExecutor.TaskController controller) { + try { + controller.suspendAndAwaitResume(); + } catch (InterruptedException e) { + fail(e); + } + } + @Test public void testCreateThread() throws Exception { AtomicReference thread = new AtomicReference<>(); @@ -57,18 +66,14 @@ public void testCreateThread() throws Exception { @Test public void testAwaitSchedulesOnNewThread() { - CountDownLatch latch = new CountDownLatch(1); taskQueue.execute(() -> { Thread current = Thread.currentThread(); taskQueue.execute(() -> { assertNotSame(current, Thread.currentThread()); testComplete(); }, executor); - taskQueue.unschedule(); - try { - latch.await(); - } catch (InterruptedException ignore) { - } + WorkerExecutor.TaskController cont = taskQueue.current(); + suspendAndAwaitResume(cont); }, executor); await(); } @@ -76,22 +81,15 @@ public void testAwaitSchedulesOnNewThread() { @Test public void testResumeFromAnotherThread() { taskQueue.execute(() -> { - CountDownLatch latch = new CountDownLatch(1); - Consumer detach = taskQueue.unschedule(); + WorkerExecutor.TaskController continuation = taskQueue.current(); new Thread(() -> { try { Thread.sleep(100); } catch (InterruptedException e) { throw new RuntimeException(e); } - detach.accept(() -> { - latch.countDown(); - }); - try { - latch.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - fail(e); - } + continuation.resume(); + suspendAndAwaitResume(continuation); }).start(); testComplete(); }, executor); @@ -101,8 +99,7 @@ public void testResumeFromAnotherThread() { @Test public void testResumeFromContextThread() { taskQueue.execute(() -> { - CountDownLatch latch = new CountDownLatch(1); - Consumer detach = taskQueue.unschedule(); + WorkerExecutor.TaskController continuation = taskQueue.current(); taskQueue.execute(() -> { // Make sure the awaiting thread will block on the internal future before resolving it (could use thread status) try { @@ -110,15 +107,9 @@ public void testResumeFromContextThread() { } catch (InterruptedException e) { throw new RuntimeException(e); } - detach.accept(() -> { - latch.countDown(); - }); + continuation.resume(); }, executor); - try { - latch.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - fail(e); - } + suspendAndAwaitResume(continuation); testComplete(); }, executor); await(); @@ -127,8 +118,7 @@ public void testResumeFromContextThread() { @Test public void testResumeWhenIdle() { taskQueue.execute(() -> { - CountDownLatch latch = new CountDownLatch(1); - Consumer l = taskQueue.unschedule(); + WorkerExecutor.TaskController cont = taskQueue.current(); AtomicReference ref = new AtomicReference<>(); new Thread(() -> { Thread th; @@ -143,15 +133,70 @@ public void testResumeWhenIdle() { } catch (InterruptedException ignore) { ignore.printStackTrace(System.out); } - l.accept(latch::countDown); + cont.resume(); }).start(); taskQueue.execute(() -> ref.set(Thread.currentThread()), executor); - try { - latch.await(); - } catch (InterruptedException ignore) { - } + suspendAndAwaitResume(cont); testComplete(); }, executor); await(); } + + @Test + public void testRaceResumeBeforeSuspend() { + AtomicInteger seq = new AtomicInteger(); + taskQueue.execute(() -> { + taskQueue.execute(() -> { + WorkerExecutor.TaskController cont = taskQueue.current(); + cont.resume(() -> { + assertEquals(1, seq.getAndIncrement()); + }); + assertEquals(0, seq.getAndIncrement()); + suspendAndAwaitResume(cont); + assertEquals(2, seq.getAndIncrement()); + }, executor); + taskQueue.execute(() -> { + assertEquals(3, seq.getAndIncrement()); + testComplete(); + }, executor); + }, executor); + await(); + } + + // Need to do unschedule when nested test! + + @Test + public void testUnscheduleRace2() { + AtomicInteger seq = new AtomicInteger(); + taskQueue.execute(() -> { + CompletableFuture cf = new CompletableFuture<>(); + taskQueue.execute(() -> { + assertEquals("vert.x-0", Thread.currentThread().getName()); + assertEquals(0, seq.getAndIncrement()); + WorkerExecutor.TaskController cont = taskQueue.current(); + cf.whenComplete((v, e) -> cont.resume(() -> { + assertEquals("vert.x-1", Thread.currentThread().getName()); + assertEquals(2, seq.getAndIncrement()); + })); + suspendAndAwaitResume(cont); + }, executor); + AtomicBoolean enqueued = new AtomicBoolean(); + taskQueue.execute(() -> { + assertEquals("vert.x-1", Thread.currentThread().getName()); + assertEquals(1, seq.getAndIncrement()); + while (!enqueued.get()) { + // Wait until next task is enqueued + } + cf.complete(null); + }, executor); + taskQueue.execute(() -> { + assertEquals("vert.x-0", Thread.currentThread().getName()); + assertEquals(3, seq.getAndIncrement()); + testComplete(); + }, executor); + enqueued.set(true); + }, executor); + + await(); + } }