Skip to content

Commit 365cbe3

Browse files
committed
Fix #2101
1 parent 4a7aae5 commit 365cbe3

File tree

3 files changed

+104
-5
lines changed

3 files changed

+104
-5
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,22 @@ svr.set_post_routing_handler([](const auto& req, auto& res) {
285285
});
286286
```
287287

288+
### Pre request handler
289+
290+
```cpp
291+
svr.set_pre_request_handler([](const auto& req, auto& res) {
292+
if (req.matched_route == "/user/:user") {
293+
auto user = req.path_params.at("user");
294+
if (user != "john") {
295+
res.status = StatusCode::Forbidden_403;
296+
res.set_content("error", "text/html");
297+
return Server::HandlerResponse::Handled;
298+
}
299+
}
300+
return Server::HandlerResponse::Unhandled;
301+
});
302+
```
303+
288304
### 'multipart/form-data' POST data
289305

290306
```cpp

httplib.h

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ using Ranges = std::vector<Range>;
636636
struct Request {
637637
std::string method;
638638
std::string path;
639+
std::string matched_route;
639640
Params params;
640641
Headers headers;
641642
std::string body;
@@ -887,10 +888,16 @@ namespace detail {
887888

888889
class MatcherBase {
889890
public:
891+
MatcherBase(std::string pattern) : pattern_(pattern) {}
890892
virtual ~MatcherBase() = default;
891893

894+
const std::string &pattern() const { return pattern_; }
895+
892896
// Match request path and populate its matches and
893897
virtual bool match(Request &request) const = 0;
898+
899+
private:
900+
std::string pattern_;
894901
};
895902

896903
/**
@@ -942,7 +949,8 @@ class PathParamsMatcher final : public MatcherBase {
942949
*/
943950
class RegexMatcher final : public MatcherBase {
944951
public:
945-
RegexMatcher(const std::string &pattern) : regex_(pattern) {}
952+
RegexMatcher(const std::string &pattern)
953+
: MatcherBase(pattern), regex_(pattern) {}
946954

947955
bool match(Request &request) const override;
948956

@@ -1009,9 +1017,12 @@ class Server {
10091017
}
10101018

10111019
Server &set_exception_handler(ExceptionHandler handler);
1020+
10121021
Server &set_pre_routing_handler(HandlerWithResponse handler);
10131022
Server &set_post_routing_handler(Handler handler);
10141023

1024+
Server &set_pre_request_handler(HandlerWithResponse handler);
1025+
10151026
Server &set_expect_100_continue_handler(Expect100ContinueHandler handler);
10161027
Server &set_logger(Logger logger);
10171028

@@ -1153,6 +1164,7 @@ class Server {
11531164
ExceptionHandler exception_handler_;
11541165
HandlerWithResponse pre_routing_handler_;
11551166
Handler post_routing_handler_;
1167+
HandlerWithResponse pre_request_handler_;
11561168
Expect100ContinueHandler expect_100_continue_handler_;
11571169

11581170
Logger logger_;
@@ -6224,7 +6236,8 @@ inline time_t BufferStream::duration() const { return 0; }
62246236

62256237
inline const std::string &BufferStream::get_buffer() const { return buffer; }
62266238

6227-
inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) {
6239+
inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern)
6240+
: MatcherBase(pattern) {
62286241
constexpr const char marker[] = "/:";
62296242

62306243
// One past the last ending position of a path param substring
@@ -6475,6 +6488,11 @@ inline Server &Server::set_post_routing_handler(Handler handler) {
64756488
return *this;
64766489
}
64776490

6491+
inline Server &Server::set_pre_request_handler(HandlerWithResponse handler) {
6492+
pre_request_handler_ = std::move(handler);
6493+
return *this;
6494+
}
6495+
64786496
inline Server &Server::set_logger(Logger logger) {
64796497
logger_ = std::move(logger);
64806498
return *this;
@@ -7129,7 +7147,11 @@ inline bool Server::dispatch_request(Request &req, Response &res,
71297147
const auto &handler = x.second;
71307148

71317149
if (matcher->match(req)) {
7132-
handler(req, res);
7150+
req.matched_route = matcher->pattern();
7151+
if (!pre_request_handler_ ||
7152+
pre_request_handler_(req, res) != HandlerResponse::Handled) {
7153+
handler(req, res);
7154+
}
71337155
return true;
71347156
}
71357157
}
@@ -7256,7 +7278,11 @@ inline bool Server::dispatch_request_for_content_reader(
72567278
const auto &handler = x.second;
72577279

72587280
if (matcher->match(req)) {
7259-
handler(req, res, content_reader);
7281+
req.matched_route = matcher->pattern();
7282+
if (!pre_request_handler_ ||
7283+
pre_request_handler_(req, res) != HandlerResponse::Handled) {
7284+
handler(req, res, content_reader);
7285+
}
72607286
return true;
72617287
}
72627288
}

test/test.cc

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2263,7 +2263,7 @@ TEST(NoContentTest, ContentLength) {
22632263
}
22642264
}
22652265

2266-
TEST(RoutingHandlerTest, PreRoutingHandler) {
2266+
TEST(RoutingHandlerTest, PreAndPostRoutingHandlers) {
22672267
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
22682268
SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
22692269
ASSERT_TRUE(svr.is_valid());
@@ -2354,6 +2354,63 @@ TEST(RoutingHandlerTest, PreRoutingHandler) {
23542354
}
23552355
}
23562356

2357+
TEST(RequestHandlerTest, PreRequestHandler) {
2358+
auto route_path = "/user/:user";
2359+
2360+
Server svr;
2361+
2362+
svr.Get("/hi", [](const Request &, Response &res) {
2363+
res.set_content("hi", "text/plain");
2364+
});
2365+
2366+
svr.Get(route_path, [](const Request &req, Response &res) {
2367+
res.set_content(req.path_params.at("user"), "text/plain");
2368+
});
2369+
2370+
svr.set_pre_request_handler([&](const Request &req, Response &res) {
2371+
if (req.matched_route == route_path) {
2372+
auto user = req.path_params.at("user");
2373+
if (user != "john") {
2374+
res.status = StatusCode::Forbidden_403;
2375+
res.set_content("error", "text/html");
2376+
return Server::HandlerResponse::Handled;
2377+
}
2378+
}
2379+
return Server::HandlerResponse::Unhandled;
2380+
});
2381+
2382+
auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
2383+
auto se = detail::scope_exit([&] {
2384+
svr.stop();
2385+
thread.join();
2386+
ASSERT_FALSE(svr.is_running());
2387+
});
2388+
2389+
svr.wait_until_ready();
2390+
2391+
Client cli(HOST, PORT);
2392+
{
2393+
auto res = cli.Get("/hi");
2394+
ASSERT_TRUE(res);
2395+
EXPECT_EQ(StatusCode::OK_200, res->status);
2396+
EXPECT_EQ("hi", res->body);
2397+
}
2398+
2399+
{
2400+
auto res = cli.Get("/user/john");
2401+
ASSERT_TRUE(res);
2402+
EXPECT_EQ(StatusCode::OK_200, res->status);
2403+
EXPECT_EQ("john", res->body);
2404+
}
2405+
2406+
{
2407+
auto res = cli.Get("/user/invalid-user");
2408+
ASSERT_TRUE(res);
2409+
EXPECT_EQ(StatusCode::Forbidden_403, res->status);
2410+
EXPECT_EQ("error", res->body);
2411+
}
2412+
}
2413+
23572414
TEST(InvalidFormatTest, StatusCode) {
23582415
Server svr;
23592416

0 commit comments

Comments
 (0)