From d90e53e6b35320426da2664155dbeb3c7e0c57d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alen=20Vre=C4=8Dko?= <332217+avrecko@users.noreply.github.com> Date: Wed, 1 May 2024 18:58:51 +0200 Subject: [PATCH] Fixing race condition between event processing thread doing onExit and caller thread doing onStart. We moved onStart processing to event processing thread. This way we make sure onStart is called first before any other EPoll event processing is done. --- .../zaxxer/nuprocess/linux/LinuxProcess.java | 15 +- .../zaxxer/nuprocess/linux/ProcessEpoll.java | 13 ++ .../nuprocess/linux/LinuxProcessTest.java | 185 ++++++++++++++++++ 3 files changed, 209 insertions(+), 4 deletions(-) create mode 100644 src/test/java/com/zaxxer/nuprocess/linux/LinuxProcessTest.java diff --git a/src/main/java/com/zaxxer/nuprocess/linux/LinuxProcess.java b/src/main/java/com/zaxxer/nuprocess/linux/LinuxProcess.java index d4c6bca8..58e26196 100644 --- a/src/main/java/com/zaxxer/nuprocess/linux/LinuxProcess.java +++ b/src/main/java/com/zaxxer/nuprocess/linux/LinuxProcess.java @@ -37,6 +37,8 @@ public class LinuxProcess extends BasePosixProcess { private final EpollEvent epollEvent; + private boolean startCalled; + static { LibEpoll.sigignore(LibEpoll.SIGPIPE); @@ -76,12 +78,11 @@ public NuProcess start(List command, String[] environment, Path cwd) { afterStart(); registerProcess(); - - callStart(); } catch (Exception e) { // TODO remove from event processor pid map? LOGGER.log(Level.WARNING, "Failed to start process", e); + // TODO we might has successfully spawned but failed on epoll registration, do we call destroy? onExit(Integer.MIN_VALUE); return null; } @@ -106,8 +107,6 @@ public void run(List command, String[] environment, Path cwd) myProcessor = (IEventProcessor) new ProcessEpoll(this); - callStart(); - myProcessor.run(); } catch (Exception e) { @@ -229,4 +228,12 @@ private static byte[] toEnvironmentBlock(String[] environment) { return block; } + + public void ensureOnStartHandlerMethodCalled() { + // only called via event processing thread + if (!startCalled) { + startCalled = true; + processHandler.onStart(this); + } + } } diff --git a/src/main/java/com/zaxxer/nuprocess/linux/ProcessEpoll.java b/src/main/java/com/zaxxer/nuprocess/linux/ProcessEpoll.java index 54c09fc0..a8236c4b 100644 --- a/src/main/java/com/zaxxer/nuprocess/linux/ProcessEpoll.java +++ b/src/main/java/com/zaxxer/nuprocess/linux/ProcessEpoll.java @@ -19,6 +19,8 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.Queue; import com.sun.jna.Native; import com.sun.jna.ptr.IntByReference; @@ -39,6 +41,7 @@ class ProcessEpoll extends BaseEventProcessor private final int epoll; private final EpollEvent triggeredEvent; private final List deadPool; + private final Queue callStartQueue = new ConcurrentLinkedQueue<>(); private LinuxProcess process; @@ -112,6 +115,8 @@ public void registerProcess(LinuxProcess process) int errno = Native.getLastError(); throw new RuntimeException("Unable to register new events to epoll, errno: " + errno); } + // TODO do we need to kill the process if we failed epoll registration? + callStartQueue.add(process); } finally { if (stdinFd != Integer.MIN_VALUE) { @@ -186,6 +191,12 @@ public void closeStdin(LinuxProcess process) @Override public boolean process() { + // handle case for processes that don't trigger epoll events + LinuxProcess p; + while ((p = callStartQueue.poll()) != null) { + p.ensureOnStartHandlerMethodCalled(); + } + int stdinFd = Integer.MIN_VALUE; int stdoutFd = Integer.MIN_VALUE; int stderrFd = Integer.MIN_VALUE; @@ -219,6 +230,8 @@ public boolean process() return true; } + linuxProcess.ensureOnStartHandlerMethodCalled(); + stdinFd = linuxProcess.getStdin().acquire(); stdoutFd = linuxProcess.getStdout().acquire(); stderrFd = linuxProcess.getStderr().acquire(); diff --git a/src/test/java/com/zaxxer/nuprocess/linux/LinuxProcessTest.java b/src/test/java/com/zaxxer/nuprocess/linux/LinuxProcessTest.java new file mode 100644 index 00000000..a32d18de --- /dev/null +++ b/src/test/java/com/zaxxer/nuprocess/linux/LinuxProcessTest.java @@ -0,0 +1,185 @@ +package com.zaxxer.nuprocess.linux; + +import com.zaxxer.nuprocess.NuAbstractProcessHandler; +import com.zaxxer.nuprocess.NuProcess; +import com.zaxxer.nuprocess.NuProcessBuilder; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + + +/** + * Need to make sure there are no race conditions. + */ +public class LinuxProcessTest { + + @Test + public void respectOnStartInRegardsToOnExit() throws Exception { + if (!System.getProperty("os.name").toLowerCase().contains("linux")) { + return; + } + for (int i = 0; i < 100; i++) { + + final CountDownLatch latchStart = new CountDownLatch(1); + final CountDownLatch latchExit = new CountDownLatch(1); + final AtomicInteger exitCode = new AtomicInteger(-1); + final AtomicLong onStartNanoTime = new AtomicLong(); + final AtomicLong onExitNanoTime = new AtomicLong(); + final AtomicInteger methodsCallCount = new AtomicInteger(); + + new NuProcessBuilder(new NuAbstractProcessHandler() { + @Override + public void onStart(NuProcess nuProcess) { + onStartNanoTime.set(System.nanoTime()); + methodsCallCount.incrementAndGet(); + nuProcess.wantWrite(); + latchStart.countDown(); + } + + @Override + public void onExit(int statusCode) { + exitCode.set(statusCode); + onExitNanoTime.set(System.nanoTime()); + methodsCallCount.incrementAndGet(); + latchExit.countDown(); + } + }, "echo", "foo").start(); + + boolean await1 = latchStart.await(1, TimeUnit.SECONDS); + boolean await2 = latchExit.await(1, TimeUnit.SECONDS); + Assert.assertTrue(await1); + Assert.assertTrue(await2); + Assert.assertEquals(2, methodsCallCount.get()); + Assert.assertTrue(onExitNanoTime.get() >= onStartNanoTime.get()); + Assert.assertEquals(0, exitCode.get()); + } + } + + @Test + public void startShouldBeCalledEvenIfNoEpollEventsHappen() throws Exception { + if (!System.getProperty("os.name").toLowerCase().contains("linux")) { + return; + } + for (int i = 0; i < 100; i++) { + final CountDownLatch latchStart = new CountDownLatch(1); + final CountDownLatch latchExit = new CountDownLatch(1); + final AtomicInteger exitCode = new AtomicInteger(-1); + final AtomicLong onStartNanoTime = new AtomicLong(); + final AtomicLong onExitNanoTime = new AtomicLong(); + final AtomicInteger methodsCallCount = new AtomicInteger(); + + NuProcess process = new NuProcessBuilder(new NuAbstractProcessHandler() { + @Override + public void onStart(NuProcess nuProcess) { + onStartNanoTime.set(System.nanoTime()); + methodsCallCount.incrementAndGet(); + nuProcess.wantWrite(); + latchStart.countDown(); + } + + @Override + public void onExit(int statusCode) { + exitCode.set(statusCode); + onExitNanoTime.set(System.nanoTime()); + methodsCallCount.incrementAndGet(); + latchExit.countDown(); + } + + }, "sleep", "infinity").start(); + + boolean await1 = latchStart.await(1, TimeUnit.SECONDS); + Assert.assertTrue(await1); + + process.destroy(false); + + boolean await2 = latchExit.await(1, TimeUnit.SECONDS); + Assert.assertTrue(await2); + + Assert.assertEquals(2, methodsCallCount.get()); + Assert.assertTrue(onExitNanoTime.get() >= onStartNanoTime.get()); + Assert.assertTrue(exitCode.get() != 0); + } + } + + @Test + public void testStdinHandling() throws Exception { + if (!System.getProperty("os.name").toLowerCase().contains("linux")) { + return; + } + final byte[] testStringBytes = "TEST STRING\n".getBytes(); + for (int i = 0; i < 100; i++) { + final CountDownLatch latchStart = new CountDownLatch(1); + final CountDownLatch latchExit = new CountDownLatch(1); + final AtomicInteger exitCode = new AtomicInteger(-1); + final AtomicLong onStartNanoTime = new AtomicLong(); + final AtomicLong onExitNanoTime = new AtomicLong(); + final AtomicInteger methodsCallCount = new AtomicInteger(); + final AtomicInteger problems = new AtomicInteger(); + final AtomicInteger readCount = new AtomicInteger(); + final AtomicInteger wroteCount = new AtomicInteger(); + + NuProcess process = new NuProcessBuilder(new NuAbstractProcessHandler() { + private NuProcess nuProcess; + + @Override + public void onStart(NuProcess nuProcess) { + onStartNanoTime.set(System.nanoTime()); + methodsCallCount.incrementAndGet(); + nuProcess.wantWrite(); + this.nuProcess = nuProcess; + latchStart.countDown(); + } + + @Override + public void onExit(int statusCode) { + exitCode.set(statusCode); + onExitNanoTime.set(System.nanoTime()); + methodsCallCount.incrementAndGet(); + latchExit.countDown(); + } + + @Override + public void onStdout(ByteBuffer buffer, boolean closed) { + if (buffer.remaining() > 0) { + byte[] dst = new byte[buffer.remaining()]; + ByteBuffer put = buffer.get(dst); + if (!Arrays.equals(dst, testStringBytes)) { + problems.incrementAndGet(); + } + readCount.incrementAndGet(); + // should produce an infinite loop, we exit via calling destroy + nuProcess.wantWrite(); + } + } + + @Override + public boolean onStdinReady(ByteBuffer buffer) { + buffer.put(testStringBytes); + buffer.flip(); + wroteCount.incrementAndGet(); + return false; + } + }, "cat").start(); + + boolean await1 = latchStart.await(1, TimeUnit.SECONDS); + Assert.assertTrue(await1); + Thread.sleep(100); + process.destroy(false); + boolean await2 = latchExit.await(1, TimeUnit.SECONDS); + Assert.assertTrue(await2); + Assert.assertEquals(2, methodsCallCount.get()); + Assert.assertTrue(onExitNanoTime.get() >= onStartNanoTime.get()); + Assert.assertEquals(0, problems.get()); + Assert.assertTrue(readCount.get() > 10); + // it is OK to diff by 1 + Assert.assertEquals(readCount.get(), wroteCount.get(), 1); + Assert.assertTrue(exitCode.get() != 0); + } + } +} \ No newline at end of file