Skip to content

Commit

Permalink
Reduce stop time due to preempt (#598)
Browse files Browse the repository at this point in the history
The preempt_request_ flag was only checked at the top-level task container before each compute iteration.
As a single sweep might take a while, we should check the flag before computing each stage.
  • Loading branch information
captain-yoshi authored Jul 19, 2024
1 parent 4f69a22 commit fdc06c3
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 11 deletions.
18 changes: 18 additions & 0 deletions core/include/moveit/task_constructor/stage_p.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@
namespace moveit {
namespace task_constructor {

/// exception thrown by StagePrivate::runCompute()
class PreemptStageException : public std::exception
{
public:
explicit PreemptStageException() {}
};

class ContainerBase;
class StagePrivate
{
Expand Down Expand Up @@ -146,6 +153,10 @@ class StagePrivate
bool storeFailures() const { return introspection_ != nullptr; }
void runCompute() {
ROS_DEBUG_STREAM_NAMED("Stage", fmt::format("Computing stage '{}'", name()));

if (preempted())
throw PreemptStageException();

auto compute_start_time = std::chrono::steady_clock::now();
try {
compute();
Expand All @@ -159,6 +170,11 @@ class StagePrivate
/** compute cost for solution through configured CostTerm */
void computeCost(const InterfaceState& from, const InterfaceState& to, SolutionBase& solution);

void setPreemptRequestedMember(const std::atomic<bool>* preempt_requested) {
preempt_requested_ = preempt_requested;
}
bool preempted() const { return preempt_requested_ != nullptr && *preempt_requested_; }

protected:
StagePrivate& operator=(StagePrivate&& other);

Expand Down Expand Up @@ -197,6 +213,8 @@ class StagePrivate
InterfaceWeakPtr next_starts_; // interface to be used for sendForward()

Introspection* introspection_; // task's introspection instance

const std::atomic<bool>* preempt_requested_;
};
PIMPL_FUNCTIONS(Stage)
std::ostream& operator<<(std::ostream& os, const StagePrivate& stage);
Expand Down
3 changes: 2 additions & 1 deletion core/src/stage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ StagePrivate::StagePrivate(Stage* me, const std::string& name)
, cost_term_{ std::make_unique<CostTerm>() }
, total_compute_time_{}
, parent_{ nullptr }
, introspection_{ nullptr } {}
, introspection_{ nullptr }
, preempt_requested_{ nullptr } {}

StagePrivate& StagePrivate::operator=(StagePrivate&& other) {
assert(typeid(*this) == typeid(other));
Expand Down
11 changes: 8 additions & 3 deletions core/src/task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@ void Task::init() {
// task expects its wrapped child to push to both ends, this triggers interface resolution
stages()->pimpl()->resolveInterface(InterfaceFlags({ GENERATE }));

// provide introspection instance to all stages
// provide introspection instance and preempt_requested to all stages
auto* introspection = impl->introspection_.get();
impl->traverseStages(
[introspection](Stage& stage, int /*depth*/) {
[introspection, impl](Stage& stage, int /*depth*/) {
stage.pimpl()->setIntrospection(introspection);
stage.pimpl()->setPreemptRequestedMember(&impl->preempt_requested_);
return true;
},
1, UINT_MAX);
Expand All @@ -232,7 +233,11 @@ bool Task::canCompute() const {
}

void Task::compute() {
stages()->pimpl()->runCompute();
try {
stages()->pimpl()->runCompute();
} catch (const PreemptStageException& e) {
// do nothing, needed for early stop
}
}

moveit::core::MoveItErrorCode Task::plan(size_t max_solutions) {
Expand Down
4 changes: 4 additions & 0 deletions core/test/stage_mockups.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ struct GeneratorMockup : public Generator
void init(const moveit::core::RobotModelConstPtr& robot_model) override;
bool canCompute() const override;
void compute() override;
virtual void reset() override { runs_ = 0; };
};

struct MonitoringGeneratorMockup : public MonitoringGenerator
Expand All @@ -81,6 +82,7 @@ struct MonitoringGeneratorMockup : public MonitoringGenerator
bool canCompute() const override { return false; }
void compute() override {}
void onNewSolution(const SolutionBase& s) override;
virtual void reset() override { runs_ = 0; };
};

struct ConnectMockup : public Connecting
Expand All @@ -97,6 +99,7 @@ struct ConnectMockup : public Connecting
using Connecting::compatible; // make this accessible for testing

void compute(const InterfaceState& from, const InterfaceState& to) override;
virtual void reset() override { runs_ = 0; };
};

struct PropagatorMockup : public PropagatingEitherWay
Expand All @@ -113,6 +116,7 @@ struct PropagatorMockup : public PropagatingEitherWay

void computeForward(const InterfaceState& from) override;
void computeBackward(const InterfaceState& to) override;
virtual void reset() override { runs_ = 0; };
};

struct ForwardMockup : public PropagatorMockup
Expand Down
31 changes: 24 additions & 7 deletions core/test/test_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,21 +674,20 @@ TEST(Task, timeout) {
}

// https://github.com/moveit/moveit_task_constructor/pull/597
// https://github.com/moveit/moveit_task_constructor/pull/598
// start planning in another thread, then preempt it in this thread
TEST(Task, preempt) {
TEST_F(TaskTestBase, preempt) {
moveit::core::MoveItErrorCode ec;
resetMockupIds();

Task t;
t.setRobotModel(getModel());

auto timeout = std::chrono::milliseconds(10);
t.add(std::make_unique<GeneratorMockup>(PredefinedCosts::constant(0.0)));
t.add(std::make_unique<TimedForwardMockup>(timeout));
auto gen1 = add(t, new GeneratorMockup(PredefinedCosts::constant(0.0)));
auto fwd1 = add(t, new TimedForwardMockup(timeout));
auto fwd2 = add(t, new TimedForwardMockup(timeout));

// preempt before preempt_request_ is reset in plan()
{
std::thread thread{ [&ec, &t, timeout] {
std::thread thread{ [&ec, this, timeout] {
std::this_thread::sleep_for(timeout);
ec = t.plan(1);
} };
Expand All @@ -698,5 +697,23 @@ TEST(Task, preempt) {

EXPECT_EQ(ec, moveit::core::MoveItErrorCode::PREEMPTED);
EXPECT_EQ(t.solutions().size(), 0u);
EXPECT_EQ(gen1->runs_, 0u);
EXPECT_EQ(fwd1->runs_, 0u);
EXPECT_EQ(fwd2->runs_, 0u);
EXPECT_TRUE(t.plan(1)); // make sure the preempt request has been resetted on the previous call to plan()

t.reset();
{
std::thread thread{ [&ec, this] { ec = t.plan(1); } };
std::this_thread::sleep_for(timeout / 2.0);
t.preempt();
thread.join();
}

EXPECT_EQ(ec, moveit::core::MoveItErrorCode::PREEMPTED);
EXPECT_EQ(t.solutions().size(), 0u);
EXPECT_EQ(gen1->runs_, 1u);
EXPECT_EQ(fwd1->runs_, 1u);
EXPECT_EQ(fwd2->runs_, 0u);
EXPECT_TRUE(t.plan(1)); // make sure the preempt request has been resetted on the previous call to plan()
}

0 comments on commit fdc06c3

Please sign in to comment.