Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative fix for Linux race condition between onStart and onExit. #156

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/main/java/com/zaxxer/nuprocess/linux/LinuxProcess.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class LinuxProcess extends BasePosixProcess
{
private final EpollEvent epollEvent;

private boolean startCalled;

static {
LibEpoll.sigignore(LibEpoll.SIGPIPE);

Expand Down Expand Up @@ -76,12 +78,11 @@ public NuProcess start(List<String> command, String[] environment, Path cwd) {
afterStart();

registerProcess();

callStart();
}
catch (Exception e) {
// TODO remove from event processor pid map?
LOGGER.log(Level.WARNING, "Failed to start process", e);
// TODO we might have successfully spawned but failed on epoll registration, do we call destroy?
onExit(Integer.MIN_VALUE);
return null;
}
Expand All @@ -106,8 +107,6 @@ public void run(List<String> command, String[] environment, Path cwd)

myProcessor = (IEventProcessor) new ProcessEpoll(this);

callStart();

myProcessor.run();
}
catch (Exception e) {
Expand Down Expand Up @@ -229,4 +228,12 @@ private static byte[] toEnvironmentBlock(String[] environment) {

return block;
}

/**
 * Invokes the handler's {@code onStart} callback the first time this method is called;
 * every later call is a no-op.
 *
 * The {@code startCalled} flag is a plain, unsynchronized field, so this method is meant
 * to be invoked from the event processing thread only.
 */
public void ensureOnStartHandlerMethodCalled() {
    if (startCalled) {
        return;
    }

    // Flip the flag before running the callback so the handler is never notified twice.
    startCalled = true;
    processHandler.onStart(this);
}
}
13 changes: 13 additions & 0 deletions src/main/java/com/zaxxer/nuprocess/linux/ProcessEpoll.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.Queue;

import com.sun.jna.Native;
import com.sun.jna.ptr.IntByReference;
Expand All @@ -39,6 +41,7 @@ class ProcessEpoll extends BaseEventProcessor<LinuxProcess>
private final int epoll;
private final EpollEvent triggeredEvent;
private final List<LinuxProcess> deadPool;
private final Queue<LinuxProcess> callStartQueue = new ConcurrentLinkedQueue<>();

private LinuxProcess process;

Expand Down Expand Up @@ -112,6 +115,8 @@ public void registerProcess(LinuxProcess process)
int errno = Native.getLastError();
throw new RuntimeException("Unable to register new events to epoll, errno: " + errno);
}
// TODO do we need to kill the process if we failed epoll registration?
callStartQueue.add(process);
}
finally {
if (stdinFd != Integer.MIN_VALUE) {
Expand Down Expand Up @@ -186,6 +191,12 @@ public void closeStdin(LinuxProcess process)
@Override
public boolean process()
{
// handle case for processes that don't trigger epoll events
LinuxProcess p;
while ((p = callStartQueue.poll()) != null) {
p.ensureOnStartHandlerMethodCalled();
}

int stdinFd = Integer.MIN_VALUE;
int stdoutFd = Integer.MIN_VALUE;
int stderrFd = Integer.MIN_VALUE;
Expand Down Expand Up @@ -219,6 +230,8 @@ public boolean process()
return true;
}

linuxProcess.ensureOnStartHandlerMethodCalled();

stdinFd = linuxProcess.getStdin().acquire();
stdoutFd = linuxProcess.getStdout().acquire();
stderrFd = linuxProcess.getStderr().acquire();
Expand Down
185 changes: 185 additions & 0 deletions src/test/java/com/zaxxer/nuprocess/linux/LinuxProcessTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package com.zaxxer.nuprocess.linux;

import com.zaxxer.nuprocess.NuAbstractProcessHandler;
import com.zaxxer.nuprocess.NuProcess;
import com.zaxxer.nuprocess.NuProcessBuilder;
import org.junit.Assert;
import org.junit.Test;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;


/**
 * Linux-only regression tests for the ordering of the {@code onStart} and {@code onExit}
 * handler callbacks.
 *
 * Each scenario is repeated {@link #ITERATIONS} times because the failure being guarded
 * against is a race condition and may not show up on any single run. On other operating
 * systems every test returns immediately without asserting anything.
 */
public class LinuxProcessTest {

    /** Number of times each scenario is repeated to give a race a chance to surface. */
    private static final int ITERATIONS = 100;

    /**
     * Reports whether the JVM runs on Linux, judged by the {@code os.name} system property.
     * The comparison is done with {@link Locale#ROOT} so the result does not depend on the
     * default locale of the machine running the tests.
     *
     * @return {@code true} if {@code os.name} contains "linux", ignoring case
     */
    private static boolean isLinux() {
        return System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("linux");
    }

    /**
     * Runs a short-lived command ({@code echo foo}) and checks that both callbacks fire
     * exactly once, that {@code onExit} is not timestamped before {@code onStart}, and that
     * the reported exit code is 0.
     */
    @Test
    public void respectOnStartInRegardsToOnExit() throws Exception {
        if (!isLinux()) {
            return;
        }
        for (int i = 0; i < ITERATIONS; i++) {
            final CountDownLatch latchStart = new CountDownLatch(1);
            final CountDownLatch latchExit = new CountDownLatch(1);
            final AtomicInteger exitCode = new AtomicInteger(-1);
            final AtomicLong onStartNanoTime = new AtomicLong();
            final AtomicLong onExitNanoTime = new AtomicLong();
            final AtomicInteger methodsCallCount = new AtomicInteger();

            new NuProcessBuilder(new NuAbstractProcessHandler() {
                @Override
                public void onStart(NuProcess nuProcess) {
                    onStartNanoTime.set(System.nanoTime());
                    methodsCallCount.incrementAndGet();
                    nuProcess.wantWrite();
                    latchStart.countDown();
                }

                @Override
                public void onExit(int statusCode) {
                    exitCode.set(statusCode);
                    onExitNanoTime.set(System.nanoTime());
                    methodsCallCount.incrementAndGet();
                    latchExit.countDown();
                }
            }, "echo", "foo").start();

            boolean started = latchStart.await(1, TimeUnit.SECONDS);
            boolean exited = latchExit.await(1, TimeUnit.SECONDS);
            Assert.assertTrue("iteration " + i + ": onStart was not called within 1s", started);
            Assert.assertTrue("iteration " + i + ": onExit was not called within 1s", exited);
            Assert.assertEquals("iteration " + i + ": callback count", 2, methodsCallCount.get());
            Assert.assertTrue("iteration " + i + ": onExit ran before onStart",
                    onExitNanoTime.get() >= onStartNanoTime.get());
            Assert.assertEquals("iteration " + i + ": exit code", 0, exitCode.get());
        }
    }

    /**
     * Runs a command that produces no output and never exits on its own
     * ({@code sleep infinity}) and checks that {@code onStart} is still delivered, then
     * destroys the process and checks that {@code onExit} follows with a non-zero code.
     */
    @Test
    public void startShouldBeCalledEvenIfNoEpollEventsHappen() throws Exception {
        if (!isLinux()) {
            return;
        }
        for (int i = 0; i < ITERATIONS; i++) {
            final CountDownLatch latchStart = new CountDownLatch(1);
            final CountDownLatch latchExit = new CountDownLatch(1);
            final AtomicInteger exitCode = new AtomicInteger(-1);
            final AtomicLong onStartNanoTime = new AtomicLong();
            final AtomicLong onExitNanoTime = new AtomicLong();
            final AtomicInteger methodsCallCount = new AtomicInteger();

            NuProcess process = new NuProcessBuilder(new NuAbstractProcessHandler() {
                @Override
                public void onStart(NuProcess nuProcess) {
                    onStartNanoTime.set(System.nanoTime());
                    methodsCallCount.incrementAndGet();
                    nuProcess.wantWrite();
                    latchStart.countDown();
                }

                @Override
                public void onExit(int statusCode) {
                    exitCode.set(statusCode);
                    onExitNanoTime.set(System.nanoTime());
                    methodsCallCount.incrementAndGet();
                    latchExit.countDown();
                }
            }, "sleep", "infinity").start();

            // NOTE(review): presumably start() can hand back null when spawning fails;
            // fail with a clear message instead of an NPE in the cleanup below.
            Assert.assertNotNull("iteration " + i + ": process failed to start", process);

            // Destroy the child even when waiting fails, so a broken run does not leave
            // a "sleep infinity" process behind. The assertion comes after the cleanup.
            boolean started;
            try {
                started = latchStart.await(1, TimeUnit.SECONDS);
            }
            finally {
                process.destroy(false);
            }
            Assert.assertTrue("iteration " + i + ": onStart was not called within 1s", started);

            boolean exited = latchExit.await(1, TimeUnit.SECONDS);
            Assert.assertTrue("iteration " + i + ": onExit was not called within 1s", exited);

            Assert.assertEquals("iteration " + i + ": callback count", 2, methodsCallCount.get());
            Assert.assertTrue("iteration " + i + ": onExit ran before onStart",
                    onExitNanoTime.get() >= onStartNanoTime.get());
            Assert.assertTrue("iteration " + i + ": expected a non-zero exit code", exitCode.get() != 0);
        }
    }

    /**
     * Ping-pongs a fixed line through {@code cat}: every chunk read from stdout requests
     * another stdin write, which keeps the exchange going until the process is destroyed.
     * Checks callback ordering, that everything read back matches what was written, and
     * that the read and write counts differ by at most one.
     */
    @Test
    public void testStdinHandling() throws Exception {
        if (!isLinux()) {
            return;
        }
        // Explicit charset: the bytes must not depend on the platform default encoding.
        final byte[] testStringBytes = "TEST STRING\n".getBytes(StandardCharsets.UTF_8);
        for (int i = 0; i < ITERATIONS; i++) {
            final CountDownLatch latchStart = new CountDownLatch(1);
            final CountDownLatch latchExit = new CountDownLatch(1);
            final AtomicInteger exitCode = new AtomicInteger(-1);
            final AtomicLong onStartNanoTime = new AtomicLong();
            final AtomicLong onExitNanoTime = new AtomicLong();
            final AtomicInteger methodsCallCount = new AtomicInteger();
            final AtomicInteger problems = new AtomicInteger();
            final AtomicInteger readCount = new AtomicInteger();
            final AtomicInteger wroteCount = new AtomicInteger();

            NuProcess process = new NuProcessBuilder(new NuAbstractProcessHandler() {
                // Captured in onStart so onStdout can request further writes.
                private NuProcess startedProcess;

                @Override
                public void onStart(NuProcess nuProcess) {
                    onStartNanoTime.set(System.nanoTime());
                    methodsCallCount.incrementAndGet();
                    nuProcess.wantWrite();
                    this.startedProcess = nuProcess;
                    latchStart.countDown();
                }

                @Override
                public void onExit(int statusCode) {
                    exitCode.set(statusCode);
                    onExitNanoTime.set(System.nanoTime());
                    methodsCallCount.incrementAndGet();
                    latchExit.countDown();
                }

                @Override
                public void onStdout(ByteBuffer buffer, boolean closed) {
                    if (buffer.hasRemaining()) {
                        byte[] received = new byte[buffer.remaining()];
                        buffer.get(received);
                        if (!Arrays.equals(received, testStringBytes)) {
                            problems.incrementAndGet();
                        }
                        readCount.incrementAndGet();
                        // should produce an infinite loop, we exit via calling destroy
                        startedProcess.wantWrite();
                    }
                }

                @Override
                public boolean onStdinReady(ByteBuffer buffer) {
                    buffer.put(testStringBytes);
                    buffer.flip();
                    wroteCount.incrementAndGet();
                    return false;
                }
            }, "cat").start();

            // NOTE(review): presumably start() can hand back null when spawning fails;
            // fail with a clear message instead of an NPE in the cleanup below.
            Assert.assertNotNull("iteration " + i + ": process failed to start", process);

            // Destroy the child even when waiting fails, so a broken run does not leave
            // a "cat" process behind. The assertion comes after the cleanup.
            boolean started;
            try {
                started = latchStart.await(1, TimeUnit.SECONDS);
                // give the write/read loop some time to accumulate round trips
                Thread.sleep(100);
            }
            finally {
                process.destroy(false);
            }
            Assert.assertTrue("iteration " + i + ": onStart was not called within 1s", started);

            boolean exited = latchExit.await(1, TimeUnit.SECONDS);
            Assert.assertTrue("iteration " + i + ": onExit was not called within 1s", exited);
            Assert.assertEquals("iteration " + i + ": callback count", 2, methodsCallCount.get());
            Assert.assertTrue("iteration " + i + ": onExit ran before onStart",
                    onExitNanoTime.get() >= onStartNanoTime.get());
            Assert.assertEquals("iteration " + i + ": mismatching stdout chunks", 0, problems.get());
            Assert.assertTrue("iteration " + i + ": too few round trips", readCount.get() > 10);
            // it is OK to diff by 1
            Assert.assertEquals(readCount.get(), wroteCount.get(), 1);
            Assert.assertTrue("iteration " + i + ": expected a non-zero exit code", exitCode.get() != 0);
        }
    }
}