diff --git a/core/include/moveit/task_constructor/stage_p.h b/core/include/moveit/task_constructor/stage_p.h index 5d359b5c6..b965ac198 100644 --- a/core/include/moveit/task_constructor/stage_p.h +++ b/core/include/moveit/task_constructor/stage_p.h @@ -57,6 +57,13 @@ namespace moveit { namespace task_constructor { +/// exception thrown by StagePrivate::runCompute() +class PreemptStageException : public std::exception +{ +public: + explicit PreemptStageException() {} +}; + class ContainerBase; class StagePrivate { @@ -146,6 +153,10 @@ class StagePrivate bool storeFailures() const { return introspection_ != nullptr; } void runCompute() { ROS_DEBUG_STREAM_NAMED("Stage", fmt::format("Computing stage '{}'", name())); + + if (preempted()) + throw PreemptStageException(); + auto compute_start_time = std::chrono::steady_clock::now(); try { compute(); @@ -159,6 +170,11 @@ class StagePrivate /** compute cost for solution through configured CostTerm */ void computeCost(const InterfaceState& from, const InterfaceState& to, SolutionBase& solution); + void setPreemptRequestedMember(const std::atomic* preempt_requested) { + preempt_requested_ = preempt_requested; + } + bool preempted() const { return preempt_requested_ != nullptr && *preempt_requested_; } + protected: StagePrivate& operator=(StagePrivate&& other); @@ -197,6 +213,8 @@ class StagePrivate InterfaceWeakPtr next_starts_; // interface to be used for sendForward() Introspection* introspection_; // task's introspection instance + + const std::atomic* preempt_requested_; }; PIMPL_FUNCTIONS(Stage) std::ostream& operator<<(std::ostream& os, const StagePrivate& stage); diff --git a/core/src/stage.cpp b/core/src/stage.cpp index 97c99826b..9fbf4d27e 100644 --- a/core/src/stage.cpp +++ b/core/src/stage.cpp @@ -102,7 +102,8 @@ StagePrivate::StagePrivate(Stage* me, const std::string& name) , cost_term_{ std::make_unique() } , total_compute_time_{} , parent_{ nullptr } - , introspection_{ nullptr } {} + , introspection_{ nullptr } + , preempt_requested_{ nullptr } {} StagePrivate& StagePrivate::operator=(StagePrivate&& other) { assert(typeid(*this) == typeid(other)); diff --git a/core/src/task.cpp b/core/src/task.cpp index dbba73890..ab375d2c7 100644 --- a/core/src/task.cpp +++ b/core/src/task.cpp @@ -213,11 +213,12 @@ void Task::init() { // task expects its wrapped child to push to both ends, this triggers interface resolution stages()->pimpl()->resolveInterface(InterfaceFlags({ GENERATE })); - // provide introspection instance to all stages + // provide introspection instance and preempt_requested to all stages auto* introspection = impl->introspection_.get(); impl->traverseStages( - [introspection](Stage& stage, int /*depth*/) { + [introspection, impl](Stage& stage, int /*depth*/) { stage.pimpl()->setIntrospection(introspection); + stage.pimpl()->setPreemptRequestedMember(&impl->preempt_requested_); return true; }, 1, UINT_MAX); @@ -232,7 +233,11 @@ bool Task::canCompute() const { } void Task::compute() { - stages()->pimpl()->runCompute(); + try { + stages()->pimpl()->runCompute(); + } catch (const PreemptStageException& e) { + // do nothing, needed for early stop + } } moveit::core::MoveItErrorCode Task::plan(size_t max_solutions) { diff --git a/core/test/stage_mockups.h b/core/test/stage_mockups.h index fc75c633e..35d01e659 100644 --- a/core/test/stage_mockups.h +++ b/core/test/stage_mockups.h @@ -65,6 +65,7 @@ struct GeneratorMockup : public Generator void init(const moveit::core::RobotModelConstPtr& robot_model) override; bool canCompute() const override; void compute() override; + virtual void reset() override { runs_ = 0; }; }; struct MonitoringGeneratorMockup : public MonitoringGenerator @@ -81,6 +82,7 @@ struct MonitoringGeneratorMockup : public MonitoringGenerator bool canCompute() const override { return false; } void compute() override {} void onNewSolution(const SolutionBase& s) override; + virtual void reset() override { runs_ = 0; }; }; struct ConnectMockup : public Connecting @@ -97,6 +99,7 @@ struct ConnectMockup : public Connecting using Connecting::compatible; // make this accessible for testing void compute(const InterfaceState& from, const InterfaceState& to) override; + virtual void reset() override { runs_ = 0; }; }; struct PropagatorMockup : public PropagatingEitherWay @@ -113,6 +116,7 @@ struct PropagatorMockup : public PropagatingEitherWay void computeForward(const InterfaceState& from) override; void computeBackward(const InterfaceState& to) override; + virtual void reset() override { runs_ = 0; }; }; struct ForwardMockup : public PropagatorMockup diff --git a/core/test/test_container.cpp b/core/test/test_container.cpp index 3746c2390..5d7b722c9 100644 --- a/core/test/test_container.cpp +++ b/core/test/test_container.cpp @@ -674,21 +674,20 @@ TEST(Task, timeout) { } // https://github.com/moveit/moveit_task_constructor/pull/597 +// https://github.com/moveit/moveit_task_constructor/pull/598 // start planning in another thread, then preempt it in this thread -TEST(Task, preempt) { +TEST_F(TaskTestBase, preempt) { moveit::core::MoveItErrorCode ec; resetMockupIds(); - Task t; - t.setRobotModel(getModel()); - auto timeout = std::chrono::milliseconds(10); - t.add(std::make_unique(PredefinedCosts::constant(0.0))); - t.add(std::make_unique(timeout)); + auto gen1 = add(t, new GeneratorMockup(PredefinedCosts::constant(0.0))); + auto fwd1 = add(t, new TimedForwardMockup(timeout)); + auto fwd2 = add(t, new TimedForwardMockup(timeout)); // preempt before preempt_request_ is reset in plan() { - std::thread thread{ [&ec, &t, timeout] { + std::thread thread{ [&ec, this, timeout] { std::this_thread::sleep_for(timeout); ec = t.plan(1); } }; @@ -698,5 +697,23 @@ TEST(Task, preempt) { EXPECT_EQ(ec, moveit::core::MoveItErrorCode::PREEMPTED); EXPECT_EQ(t.solutions().size(), 0u); + EXPECT_EQ(gen1->runs_, 0u); + EXPECT_EQ(fwd1->runs_, 0u); + EXPECT_EQ(fwd2->runs_, 0u); + EXPECT_TRUE(t.plan(1)); // make sure the preempt request has been resetted on the previous call to plan() + + t.reset(); + { + std::thread thread{ [&ec, this] { ec = t.plan(1); } }; + std::this_thread::sleep_for(timeout / 2.0); + t.preempt(); + thread.join(); + } + + EXPECT_EQ(ec, moveit::core::MoveItErrorCode::PREEMPTED); + EXPECT_EQ(t.solutions().size(), 0u); + EXPECT_EQ(gen1->runs_, 1u); + EXPECT_EQ(fwd1->runs_, 1u); + EXPECT_EQ(fwd2->runs_, 0u); EXPECT_TRUE(t.plan(1)); // make sure the preempt request has been resetted on the previous call to plan() }