Skip to content

Commit d4033fa

Browse files
Updated callback mechanism, termination callback working
1 parent 670c922 commit d4033fa

File tree

11 files changed

+479
-44
lines changed

11 files changed

+479
-44
lines changed

src/Enums.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,12 @@ enum class E_DualSolutionSource
5151

5252
enum class E_EventType
5353
{
54+
ExternalDualBound,
55+
ExternalHyperplaneSelection,
56+
ExternalPrimalSolution,
5457
NewPrimalSolution,
58+
PrimalSolutionCandidateSelection,
5559
UserTerminationCheck,
56-
ExternalHyperplaneSelection,
5760
};
5861

5962
enum class E_HyperplaneSource

src/EventHandler.h

Lines changed: 195 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,230 @@
1111
#pragma once
1212
#include "Environment.h"
1313
#include "Enums.h"
14+
#include "Output.h"
1415

1516
#include <any>
1617
#include <functional>
1718
#include <map>
1819
#include <vector>
1920
#include <utility>
21+
#include <optional>
22+
#include <type_traits>
23+
#include <stdexcept>
2024

2125
namespace SHOT
2226
{
2327

28+
/**
29+
* @brief EventHandler that supports both notification callbacks and data providers
30+
*
31+
* This class provides a single registration interface that automatically detects whether
32+
* a callback is a notification callback (returns void) or a data provider (returns a value).
33+
*
34+
* Usage examples:
35+
*
36+
* // Data provider for user termination check - returns bool
37+
* eventHandler.registerCallback(E_EventType::UserTerminationCheck, []() {
38+
* return shouldTerminate(); // Returns bool -> data provider for termination check
39+
* });
40+
*
41+
* // Data provider for dual bound - returns double
42+
* eventHandler.registerCallback(E_EventType::ExternalDualBound, []() {
43+
* return computeDualBound(); // Returns double -> data provider
44+
* });
45+
*
46+
* // Data provider for primal solution - returns std::vector<VectorDouble>
47+
* eventHandler.registerCallback(E_EventType::ExternalPrimalSolution, []() {
48+
* return getSolution(); // Returns std::vector<VectorDouble> -> data provider
49+
* });
50+
*
51+
* // Notification with parameters - returns void
52+
* eventHandler.registerCallback(E_EventType::NewPrimalSolution, [](std::any solution) {
53+
* processSolution(solution); // Returns void -> notification callback
54+
* });
55+
*/
2456
class EventHandler
2557
{
2658
public:
2759
inline EventHandler(EnvironmentPtr envPtr) : env(envPtr) {};
2860

29-
// Register a callback for a specific event type
61+
/**
62+
* @brief Unified callback registration method
63+
*
64+
* This method automatically detects the callback type based on its signature:
65+
* - If callback returns void: Registered as notification callback
66+
* - If callback returns a value: Registered as data provider
67+
* - If callback takes no arguments: Compatible with both types
68+
* - If callback takes std::any argument: Notification callback with data
69+
*
70+
* @tparam Callback The callback function type (lambda, function pointer, etc.)
71+
* @param event The event type to register for
72+
* @param callback The callback function to register
73+
*/
3074
template <typename Callback> void registerCallback(const E_EventType& event, Callback&& callback)
3175
{
32-
registeredCallbacks[event].push_back([callback](std::any args) { callback(args); });
76+
77+
using CallbackType = std::decay_t<Callback>;
78+
79+
// Check if callback can be called with no arguments
80+
if constexpr(std::is_invocable_v<CallbackType>)
81+
{
82+
using ReturnType = std::invoke_result_t<CallbackType>;
83+
84+
if constexpr(std::is_void_v<ReturnType>)
85+
{
86+
// Notification callback - returns void
87+
notificationCallbacks[event].push_back([callback](std::any) { callback(); });
88+
89+
env->output->outputCritical("Registering callback for event: " + std::to_string(static_cast<int>(event))
90+
+ " (no args, no return)");
91+
}
92+
else
93+
{
94+
// Data provider - returns a value
95+
dataProviders[event] = [callback]() -> std::any { return std::any(callback()); };
96+
97+
env->output->outputCritical("Registering callback for event: " + std::to_string(static_cast<int>(event))
98+
+ " (no args, returns value)");
99+
}
100+
}
101+
// Check if callback can be called with std::any argument
102+
else if constexpr(std::is_invocable_v<CallbackType, std::any>)
103+
{
104+
using ReturnType = std::invoke_result_t<CallbackType, std::any>;
105+
106+
if constexpr(std::is_void_v<ReturnType>)
107+
{
108+
// Notification callback with arguments
109+
notificationCallbacks[event].push_back([callback](std::any args) { callback(args); });
110+
111+
env->output->outputCritical("Registering callback for event: " + std::to_string(static_cast<int>(event))
112+
+ " (args, no return)");
113+
}
114+
else
115+
{
116+
// Data provider with arguments
117+
parameterizedDataProviders[event]
118+
= [callback](std::any args) -> std::any { return std::any(callback(args)); };
119+
120+
env->output->outputCritical("Registering callback for event: " + std::to_string(static_cast<int>(event))
121+
+ " (args, returns value)");
122+
}
123+
}
124+
else
125+
{
126+
static_assert(std::is_invocable_v<CallbackType> || std::is_invocable_v<CallbackType, std::any>,
127+
"Callback must be invocable with either no arguments or std::any argument");
128+
}
33129
}
34130

35-
// Notify all callbacks registered for a specific event type
131+
/**
132+
* @brief Notify all callbacks registered for a specific event type
133+
*
134+
* @param event The event type to notify
135+
* @param args Arguments to pass to the callbacks (wrapped in std::any)
136+
*/
36137
void notify(const E_EventType& event, std::any args) const
37138
{
38-
if(registeredCallbacks.empty())
39-
return;
139+
env->output->outputTrace(
140+
"Notifying callbacks for event: " + std::to_string(static_cast<int>(event)) + " (args)");
141+
auto it = notificationCallbacks.find(event);
142+
if(it != notificationCallbacks.end())
143+
{
144+
for(const auto& callback : it->second)
145+
{
146+
callback(args);
147+
}
148+
}
149+
}
150+
151+
/**
152+
* @brief Notify callbacks with no arguments
153+
*
154+
* @param event The event type to notify
155+
*/
156+
void notify(const E_EventType& event) const
157+
{
158+
env->output->outputTrace(
159+
"Notifying callbacks for event: " + std::to_string(static_cast<int>(event)) + " (no args)");
160+
notify(event, std::any());
161+
}
40162

41-
auto it = registeredCallbacks.find(event);
42-
if(it == registeredCallbacks.end())
43-
return;
163+
/**
164+
* @brief Request data from a registered data provider (no arguments)
165+
*
166+
* @tparam ReturnType The expected return type
167+
* @param event The event type to request data for
168+
* @return std::optional<ReturnType> The data if available, std::nullopt otherwise
169+
*/
170+
template <typename ReturnType> std::optional<ReturnType> requestData(const E_EventType& event) const
171+
{
172+
auto it = dataProviders.find(event);
173+
if(it == dataProviders.end())
174+
return std::nullopt;
44175

45-
for(const auto& callback : it->second)
176+
try
177+
{
178+
std::any result = it->second();
179+
return std::any_cast<ReturnType>(result);
180+
}
181+
catch(const std::bad_any_cast&)
46182
{
47-
callback(args);
183+
return std::nullopt;
48184
}
49185
}
50186

187+
/**
188+
* @brief Request data from a registered data provider (with arguments)
189+
*
190+
* @tparam ReturnType The expected return type
191+
* @tparam ArgType The argument type to pass to the provider
192+
* @param event The event type to request data for
193+
* @param arg The argument to pass to the data provider
194+
* @return std::optional<ReturnType> The data if available, std::nullopt otherwise
195+
*/
196+
template <typename ReturnType, typename ArgType>
197+
std::optional<ReturnType> requestData(const E_EventType& event, const ArgType& arg) const
198+
{
199+
auto it = parameterizedDataProviders.find(event);
200+
if(it == parameterizedDataProviders.end())
201+
return std::nullopt;
202+
203+
try
204+
{
205+
std::any result = it->second(std::any(arg));
206+
return std::any_cast<ReturnType>(result);
207+
}
208+
catch(const std::bad_any_cast&)
209+
{
210+
return std::nullopt;
211+
}
212+
}
213+
214+
/**
215+
* @brief Check if a data provider is registered for an event
216+
*
217+
* @param event The event type to check
218+
* @return true if a data provider is registered, false otherwise
219+
*/
220+
bool hasDataProvider(const E_EventType& event) const
221+
{
222+
return dataProviders.find(event) != dataProviders.end()
223+
|| parameterizedDataProviders.find(event) != parameterizedDataProviders.end();
224+
}
225+
51226
private:
52-
// Map of event types to their registered callbacks
53-
std::map<E_EventType, std::vector<std::function<void(std::any)>>> registeredCallbacks;
227+
/// Map of event types to their registered notification callbacks
228+
std::map<E_EventType, std::vector<std::function<void(std::any)>>> notificationCallbacks;
229+
230+
/// Map of event types to data provider callbacks (no parameters)
231+
std::map<E_EventType, std::function<std::any()>> dataProviders;
54232

233+
/// Map of event types to parameterized data provider callbacks
234+
std::map<E_EventType, std::function<std::any(std::any)>> parameterizedDataProviders;
235+
236+
/// Pointer to the environment
55237
EnvironmentPtr env;
56238
};
239+
57240
} // namespace SHOT

src/ModelingSystem/EntryPointsGAMS.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,11 @@ extern "C"
191191
env->timing->stopTimer("ProblemInitialization");
192192

193193
solver.registerCallback(
194-
E_EventType::UserTerminationCheck, [&env, gev = (gevHandle_t)gmoEnvironment(gs->gmo)](std::any args) {
194+
E_EventType::UserTerminationCheck, [gev = (gevHandle_t)gmoEnvironment(gs->gmo)]() -> bool {
195195
if(gevTerminateGet(gev))
196-
env->tasks->terminate();
196+
return (true);
197+
198+
return (false);
197199
});
198200

199201
if(!solver.setProblem(problem, modelingSystem))

src/Solver.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,37 @@ class DllExport Solver
7575

7676
void finalizeSolution();
7777

78+
/**
79+
* @brief Callback registration method
80+
*
81+
* This method automatically detects whether the callback is a notification callback
82+
* or a data provider based on its return type:
83+
* - Returns void: Notification callback
84+
* - Returns a value: Data provider
85+
*
86+
* Examples:
87+
* // Data provider for dual bound
88+
* solver.registerCallback(E_EventType::ExternalDualBound, []() {
89+
* return computeDualBound(); // Returns double -> data provider
90+
* });
91+
*
92+
* // User termination check
93+
* solver.registerCallback(E_EventType::UserTerminationCheck, []() {
94+
* return shouldTerminate(); // Returns bool -> data provider
95+
* });
96+
*
97+
* // Notification callback
98+
* solver.registerCallback(E_EventType::NewPrimalSolution, [](std::any solution) {
99+
* processSolution(solution); // Returns void -> notification
100+
* });
101+
*
102+
* @tparam Callback The callback function type
103+
* @param event The event type to register for
104+
* @param callback The callback function
105+
*/
78106
template <typename Callback> inline void registerCallback(const E_EventType& event, Callback&& callback)
79107
{
80-
env->events->registerCallback(event, callback);
108+
env->events->registerCallback(event, std::forward<Callback>(callback));
81109
}
82110

83111
std::string getOptionsOSoL();

src/Tasks/TaskCheckUserTermination.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "TaskCheckUserTermination.h"
1212

1313
#include "../EventHandler.h"
14+
#include "../Output.h"
1415
#include "../Results.h"
1516
#include "../TaskHandler.h"
1617

@@ -28,8 +29,21 @@ TaskCheckUserTermination::~TaskCheckUserTermination() = default;
2829

2930
void TaskCheckUserTermination::run()
3031
{
32+
if(env->tasks->isTerminated())
33+
return;
34+
35+
// Notify callback
3136
env->events->notify(E_EventType::UserTerminationCheck, std::any());
3237

38+
// Check if user termination was requested
39+
if(env->events->hasDataProvider(E_EventType::UserTerminationCheck))
40+
{
41+
auto shouldTerminate = env->events->requestData<bool>(E_EventType::UserTerminationCheck);
42+
43+
if(shouldTerminate.has_value() && *shouldTerminate)
44+
env->tasks->terminate();
45+
}
46+
3347
if(env->tasks->isTerminated()
3448
|| env->results->getCurrentIteration()->solutionStatus == E_ProblemSolutionStatus::Abort)
3549
{

test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ set(Solver_parts
5757
3
5858
4
5959
5
60-
6)
60+
6
61+
7)
6162
set(cpptests ${cpptests} Solver)
6263

6364
if(HAS_IPOPT)

test/CbcTest.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,14 @@ bool CbcTerminationCallbackTest(std::string filename)
125125
return (false);
126126
}
127127

128-
// Registers a callback that terminates in the third iteration
129-
solver->registerCallback(E_EventType::UserTerminationCheck, [&env](std::any args) {
128+
// Registers a callback that terminates after the third iteration
129+
solver->registerCallback(E_EventType::UserTerminationCheck, [&env]() -> bool {
130130
std::cout << "Callback activated. Terminating.\n";
131131

132-
if(env->results->getNumberOfIterations() == 3)
133-
env->tasks->terminate();
132+
if(env->results->getNumberOfIterations() > 3)
133+
return (true);
134+
135+
return (false);
134136
});
135137

136138
// Solving the problem
@@ -140,7 +142,7 @@ bool CbcTerminationCallbackTest(std::string filename)
140142
return (false);
141143
}
142144

143-
if(env->results->getNumberOfIterations() != 3)
145+
if(env->results->terminationReason != E_TerminationReason::UserAbort)
144146
{
145147
std::cout << "Termination callback did not seem to work as expected\n";
146148
return (false);

0 commit comments

Comments
 (0)