diff --git a/include/brynet/base/Timer.hpp b/include/brynet/base/Timer.hpp index 4876794..7c2041a 100644 --- a/include/brynet/base/Timer.hpp +++ b/include/brynet/base/Timer.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -75,7 +76,26 @@ class Timer final friend class TimerMgr; }; -class TimerMgr final +class RepeatTimer +{ +public: + using Ptr = std::shared_ptr; + + void cancel() + { + mCancel.store(true); + } + + bool isCancel() const + { + return mCancel.load(); + } + +private: + std::atomic_bool mCancel = {false}; +}; + +class TimerMgr final : public std::enable_shared_from_this { public: using Ptr = std::shared_ptr; @@ -95,6 +115,22 @@ class TimerMgr final return timer; } + template + RepeatTimer::Ptr addIntervalTimer( + std::chrono::nanoseconds interval, + F&& callback, + TArgs&&... args) + { + auto sharedThis = shared_from_this(); + auto repeatTimer = std::make_shared(); + auto wrapperCallback = std::bind(std::forward(callback), std::forward(args)...); + addTimer(interval, [sharedThis, interval, wrapperCallback, repeatTimer]() { + stubRepeatTimerCallback(sharedThis, interval, wrapperCallback, repeatTimer); + }); + + return repeatTimer; + } + void addTimer(const Timer::Ptr& timer) { mTimers.push(timer); @@ -145,6 +181,22 @@ class TimerMgr final } } +private: + static void stubRepeatTimerCallback(TimerMgr::Ptr timerMgr, + std::chrono::nanoseconds interval, + std::function callback, + RepeatTimer::Ptr repeatTimer) + { + if (repeatTimer->isCancel()) + { + return; + } + callback(); + timerMgr->addTimer(interval, [timerMgr, interval, callback, repeatTimer]() { + stubRepeatTimerCallback(timerMgr, interval, callback, repeatTimer); + }); + } + private: class CompareTimer { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index db2345a..877dcc9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,6 +3,12 @@ enable_testing() include_directories("${PROJECT_SOURCE_DIR}/include/") add_executable(test_timer test_timer.cpp) +if(WIN32) + target_link_libraries(test_timer ws2_32) +elseif(UNIX) + find_package(Threads REQUIRED) + target_link_libraries(test_timer pthread) +endif() add_test(TestTimer test_timer) add_executable(test_wait_group test_wait_group.cpp) diff --git a/tests/test_timer.cpp b/tests/test_timer.cpp index a7f2d0b..4c527df 100644 --- a/tests/test_timer.cpp +++ b/tests/test_timer.cpp @@ -1,7 +1,9 @@ #define CATCH_CONFIG_MAIN// This tells Catch to provide a main() - only do this in one cpp file #include +#include #include #include +#include #include "catch.hpp" @@ -50,3 +52,33 @@ TEST_CASE("Timer are computed", "[timer]") REQUIRE(upvalue == 2); } + +TEST_CASE("repeat timer are computed", "[repeat timer]") +{ + auto timerMgr = std::make_shared(); + auto wg = brynet::base::WaitGroup::Create(); + wg->add(1); + + std::atomic_int value = 0; + auto timer = timerMgr->addIntervalTimer(std::chrono::milliseconds(100), [&]() { + if (value.load() < 10) + { + value.fetch_add(1); + } + else + { + wg->done(); + } + }); + + std::thread t([&]() { + wg->wait(); + timer->cancel(); + }); + while (!timerMgr->isEmpty()) + { + timerMgr->schedule(); + } + t.join(); + REQUIRE(value.load() == 10); +}