Skip to content

Commit

Permalink
The protocol of the TaskQueue unschedule method is racy, since someti…
Browse files Browse the repository at this point in the history
…mes the task resume might happen before the thread is suspended. The protocol should be modified to implement a race free interaction.
  • Loading branch information
vietj committed Oct 18, 2023
1 parent 6dcac21 commit 3af79e3
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 56 deletions.
54 changes: 36 additions & 18 deletions src/main/java/io/vertx/core/impl/TaskQueue.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
import io.vertx.core.impl.logging.LoggerFactory;

import java.util.LinkedList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Consumer;
import java.util.concurrent.TimeUnit;

/**
* A task queue that always run all tasks in order. The executor to run the tasks is passed
Expand Down Expand Up @@ -82,17 +83,12 @@ private void run() {
}

/**
* Unschedule the current task from execution, the next task in the queue will be executed
* when there is one.
* Return a controller for the current task.
*
* <p>When the current task wants to be resumed, it should call the returned consumer with a command
* to unpark the thread (e.g most likely yielding a latch), this task will be executed immediately if there
* is no tasks being executed, otherwise it will be added first in the queue.
*
* @return a mean to signal to resume the thread when it shall be resumed
* @return the controller
* @throws IllegalStateException if the current thread is not currently being executed by the queue
*/
public Consumer<Runnable> unschedule() {
public WorkerExecutor.TaskController current() {
Thread thread;
Executor executor;
synchronized (tasks) {
Expand All @@ -101,20 +97,42 @@ public Consumer<Runnable> unschedule() {
}
thread = currentThread;
executor = currentExecutor;
currentThread = null;
}
executor.execute(runner);
return r -> {
synchronized (tasks) {
if (currentExecutor != null) {
tasks.addFirst(new ResumeTask(r, executor, thread));
return;
} else {
return new WorkerExecutor.TaskController() {

final CountDownLatch latch = new CountDownLatch(1);

@Override
public void resume(Runnable callback) {
Runnable task = () -> {
callback.run();
latch.countDown();
};
synchronized (tasks) {
if (currentExecutor != null) {
tasks.addFirst(new ResumeTask(task, executor, thread));
return;
}
currentExecutor = executor;
currentThread = thread;
}
task.run();
}

@Override
public CountDownLatch suspend() {
if (Thread.currentThread() != thread) {
throw new IllegalStateException();
}
synchronized (tasks) {
if (currentThread == null || currentThread != Thread.currentThread()) {
throw new IllegalStateException();
}
currentThread = null;
}
executor.execute(runner);
return latch;
}
r.run();
};
}

Expand Down
46 changes: 42 additions & 4 deletions src/main/java/io/vertx/core/impl/WorkerExecutor.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import io.vertx.core.spi.metrics.PoolMetrics;

import java.util.function.Consumer;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
* Execute events on a worker pool.
Expand Down Expand Up @@ -60,9 +61,46 @@ public void execute(Runnable command) {
}

/**
* See {@link TaskQueue#unschedule()}.
* See {@link TaskQueue#current()}.
*/
public Consumer<Runnable> unschedule() {
return orderedTasks.unschedule();
public TaskController current() {
return orderedTasks.current();
}

public interface TaskController {

/**
* Resume the task, the {@code callback} will be executed when the task is resumed, before the task thread
* is unparked.
*
* @param callback called when the task is resumed
*/
void resume(Runnable callback);

/**
* Like {@link #resume(Runnable)}.
*/
default void resume() {
resume(() -> {});
}

/**
* Suspend the task execution and park the current thread until the task is resumed.
* The next task in the queue will be executed, when there is one.
*
* <p>When the task wants to be resumed, it should call {@link #resume}, this will be executed immediately if there
* is no other tasks being executed, otherwise it will be added first in the queue.
*/
default void suspendAndAwaitResume() throws InterruptedException {
suspend().await();
}

/**
* Like {@link #suspendAndAwaitResume()} but does not await the task to be resumed.
*
* @return the latch to await
*/
CountDownLatch suspend();

}
}
113 changes: 79 additions & 34 deletions src/test/java/io/vertx/core/TaskQueueTest.java
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package io.vertx.core;

import io.vertx.core.impl.TaskQueue;
import io.vertx.core.impl.WorkerExecutor;
import io.vertx.test.core.AsyncTestBase;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

public class TaskQueueTest extends AsyncTestBase {

Expand Down Expand Up @@ -40,6 +41,14 @@ protected void tearDown() throws Exception {
super.tearDown();
}

private void suspendAndAwaitResume(WorkerExecutor.TaskController controller) {
try {
controller.suspendAndAwaitResume();
} catch (InterruptedException e) {
fail(e);
}
}

@Test
public void testCreateThread() throws Exception {
AtomicReference<Thread> thread = new AtomicReference<>();
Expand All @@ -57,41 +66,30 @@ public void testCreateThread() throws Exception {

@Test
public void testAwaitSchedulesOnNewThread() {
CountDownLatch latch = new CountDownLatch(1);
taskQueue.execute(() -> {
Thread current = Thread.currentThread();
taskQueue.execute(() -> {
assertNotSame(current, Thread.currentThread());
testComplete();
}, executor);
taskQueue.unschedule();
try {
latch.await();
} catch (InterruptedException ignore) {
}
WorkerExecutor.TaskController cont = taskQueue.current();
suspendAndAwaitResume(cont);
}, executor);
await();
}

@Test
public void testResumeFromAnotherThread() {
taskQueue.execute(() -> {
CountDownLatch latch = new CountDownLatch(1);
Consumer<Runnable> detach = taskQueue.unschedule();
WorkerExecutor.TaskController continuation = taskQueue.current();
new Thread(() -> {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
detach.accept(() -> {
latch.countDown();
});
try {
latch.await(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
fail(e);
}
continuation.resume();
suspendAndAwaitResume(continuation);
}).start();
testComplete();
}, executor);
Expand All @@ -101,24 +99,17 @@ public void testResumeFromAnotherThread() {
@Test
public void testResumeFromContextThread() {
taskQueue.execute(() -> {
CountDownLatch latch = new CountDownLatch(1);
Consumer<Runnable> detach = taskQueue.unschedule();
WorkerExecutor.TaskController continuation = taskQueue.current();
taskQueue.execute(() -> {
// Make sure the awaiting thread will block on the internal future before resolving it (could use thread status)
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
detach.accept(() -> {
latch.countDown();
});
continuation.resume();
}, executor);
try {
latch.await(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
fail(e);
}
suspendAndAwaitResume(continuation);
testComplete();
}, executor);
await();
Expand All @@ -127,8 +118,7 @@ public void testResumeFromContextThread() {
@Test
public void testResumeWhenIdle() {
taskQueue.execute(() -> {
CountDownLatch latch = new CountDownLatch(1);
Consumer<Runnable> l = taskQueue.unschedule();
WorkerExecutor.TaskController cont = taskQueue.current();
AtomicReference<Thread> ref = new AtomicReference<>();
new Thread(() -> {
Thread th;
Expand All @@ -143,15 +133,70 @@ public void testResumeWhenIdle() {
} catch (InterruptedException ignore) {
ignore.printStackTrace(System.out);
}
l.accept(latch::countDown);
cont.resume();
}).start();
taskQueue.execute(() -> ref.set(Thread.currentThread()), executor);
try {
latch.await();
} catch (InterruptedException ignore) {
}
suspendAndAwaitResume(cont);
testComplete();
}, executor);
await();
}

@Test
public void testRaceResumeBeforeSuspend() {
AtomicInteger seq = new AtomicInteger();
taskQueue.execute(() -> {
taskQueue.execute(() -> {
WorkerExecutor.TaskController cont = taskQueue.current();
cont.resume(() -> {
assertEquals(1, seq.getAndIncrement());
});
assertEquals(0, seq.getAndIncrement());
suspendAndAwaitResume(cont);
assertEquals(2, seq.getAndIncrement());
}, executor);
taskQueue.execute(() -> {
assertEquals(3, seq.getAndIncrement());
testComplete();
}, executor);
}, executor);
await();
}

// Need to do unschedule when nested test!

@Test
public void testUnscheduleRace2() {
AtomicInteger seq = new AtomicInteger();
taskQueue.execute(() -> {
CompletableFuture<Void> cf = new CompletableFuture<>();
taskQueue.execute(() -> {
assertEquals("vert.x-0", Thread.currentThread().getName());
assertEquals(0, seq.getAndIncrement());
WorkerExecutor.TaskController cont = taskQueue.current();
cf.whenComplete((v, e) -> cont.resume(() -> {
assertEquals("vert.x-1", Thread.currentThread().getName());
assertEquals(2, seq.getAndIncrement());
}));
suspendAndAwaitResume(cont);
}, executor);
AtomicBoolean enqueued = new AtomicBoolean();
taskQueue.execute(() -> {
assertEquals("vert.x-1", Thread.currentThread().getName());
assertEquals(1, seq.getAndIncrement());
while (!enqueued.get()) {
// Wait until next task is enqueued
}
cf.complete(null);
}, executor);
taskQueue.execute(() -> {
assertEquals("vert.x-0", Thread.currentThread().getName());
assertEquals(3, seq.getAndIncrement());
testComplete();
}, executor);
enqueued.set(true);
}, executor);

await();
}
}

0 comments on commit 3af79e3

Please sign in to comment.