diff --git a/src/engine/Server.cpp b/src/engine/Server.cpp index ffa25ac664..8994d1c5a2 100644 --- a/src/engine/Server.cpp +++ b/src/engine/Server.cpp @@ -760,16 +760,15 @@ MediaType Server::determineMediaType( } // ____________________________________________________________________________ -std::pair, - ad_utility::websocket::MessageSender> -Server::createMessageSender(auto& queryHub_, const auto& request, - const string& operation) { - auto queryHub = queryHub_.lock(); - AD_CORRECTNESS_CHECK(queryHub); +ad_utility::websocket::MessageSender Server::createMessageSender( + const std::weak_ptr& queryHub, + const ad_utility::httpUtils::HttpRequest auto& request, + const string& operation) { + auto queryHubLock = queryHub.lock(); + AD_CORRECTNESS_CHECK(queryHubLock); ad_utility::websocket::MessageSender messageSender{ - getQueryId(request, operation), *queryHub}; - // TODO is it required to keep the queryHub alive? - return std::make_pair(std::move(queryHub), std::move(messageSender)); + getQueryId(request, operation), *queryHubLock}; + return messageSender; } // ____________________________________________________________________________ @@ -782,10 +781,8 @@ Awaitable Server::processQuery( LOG(INFO) << "Requested media type of result is \"" << ad_utility::toString(mediaType) << "\"" << std::endl; - auto queryHub = queryHub_.lock(); - AD_CORRECTNESS_CHECK(queryHub); - ad_utility::websocket::MessageSender messageSender{getQueryId(request, query), - *queryHub}; + ad_utility::websocket::MessageSender messageSender = + createMessageSender(queryHub_, request, query); auto [cancellationHandle, cancelTimeoutOnDestruction] = setupCancellationHandle(messageSender.getQueryId(), timeLimit); @@ -862,8 +859,7 @@ Awaitable Server::processUpdate( ad_utility::Timer& requestTimer, const ad_utility::httpUtils::HttpRequest auto& request, auto&& send, TimeLimit timeLimit) { - auto [queryHub, messageSender] = - createMessageSender(queryHub_, request, update); + auto messageSender = createMessageSender(queryHub_, request, update); auto [cancellationHandle, cancelTimeoutOnDestruction] = setupCancellationHandle(messageSender.getQueryId(), timeLimit); @@ -893,10 +889,11 @@ Awaitable Server::processUpdate( // work fine when a new update is sent only after the previous one has // finished. auto& deltaTriplesManager = index_.deltaTriplesManager(); - deltaTriplesManager.modify([&](auto& deltaTriples) { - ExecuteUpdate::executeUpdate(index_, plannedQuery.parsedQuery_, qet, - deltaTriples, cancellationHandle); - }); + deltaTriplesManager.modify( + [this, &plannedQuery, &qet, &cancellationHandle](auto& deltaTriples) { + ExecuteUpdate::executeUpdate(index_, plannedQuery.parsedQuery_, qet, + deltaTriples, cancellationHandle); + }); LOG(INFO) << "Done processing update" << ", total time was " << requestTimer.msecs().count() << " ms" diff --git a/src/engine/Server.h b/src/engine/Server.h index e2f70c08d4..a0fcfbd063 100644 --- a/src/engine/Server.h +++ b/src/engine/Server.h @@ -32,6 +32,8 @@ using std::vector; //! The HTTP Server used. class Server { FRIEND_TEST(ServerTest, parseHttpRequest); + FRIEND_TEST(ServerTest, getQueryId); + FRIEND_TEST(ServerTest, createMessageSender); public: explicit Server(unsigned short port, size_t numThreads, @@ -172,10 +174,11 @@ class Server { SharedCancellationHandle handle, TimeLimit timeLimit, const ad_utility::Timer& requestTimer); - std::pair, - ad_utility::websocket::MessageSender> - createMessageSender(auto& queryHub_, const auto& request, - const string& operation); + // Creates a `MessageSender` for the given operation. + ad_utility::websocket::MessageSender createMessageSender( + const std::weak_ptr& queryHub, + const ad_utility::httpUtils::HttpRequest auto& request, + const string& operation); static json composeErrorResponseJson( const string& query, const std::string& errorMsg, diff --git a/test/ServerTest.cpp b/test/ServerTest.cpp index 5ed589fc42..292d77f12b 100644 --- a/test/ServerTest.cpp +++ b/test/ServerTest.cpp @@ -27,27 +27,27 @@ auto ParsedRequestIs = [](const std::string& path, AD_FIELD(ad_utility::url_parser::ParsedRequest, operation_, testing::Eq(operation))); }; +auto MakeBasicRequest = [](http::verb method, const std::string& target) { + // version 11 stands for HTTP/1.1 + return http::request{method, target, 11}; +}; +auto MakeGetRequest = [](const std::string& target) { + return MakeBasicRequest(http::verb::get, target); +}; +auto MakePostRequest = [](const std::string& target, + const std::string& contentType, + const std::string& body) { + auto req = MakeBasicRequest(http::verb::post, target); + req.set(http::field::content_type, contentType); + req.body() = body; + req.prepare_payload(); + return req; +}; } // namespace TEST(ServerTest, parseHttpRequest) { namespace http = boost::beast::http; - auto MakeBasicRequest = [](http::verb method, const std::string& target) { - // version 11 stands for HTTP/1.1 - return http::request{method, target, 11}; - }; - auto MakeGetRequest = [&MakeBasicRequest](const std::string& target) { - return MakeBasicRequest(http::verb::get, target); - }; - auto MakePostRequest = [&MakeBasicRequest](const std::string& target, - const std::string& contentType, - const std::string& body) { - auto req = MakeBasicRequest(http::verb::post, target); - req.set(http::field::content_type, contentType); - req.body() = body; - req.prepare_payload(); - return req; - }; auto parse = [](const ad_utility::httpUtils::HttpRequest auto& request) { return Server::parseHttpRequest(request); }; @@ -222,3 +222,65 @@ TEST(ServerTest, determineMediaType) { EXPECT_THAT(Server::determineMediaType({}, MakeRequest("")), testing::Eq(ad_utility::MediaType::sparqlJson)); } + +TEST(ServerTest, getQueryId) { + using namespace ad_utility::websocket; + Server server{9999, 1, ad_utility::MemorySize::megabytes(1), "accessToken"}; + auto reqWithExplicitQueryId = MakeGetRequest("/"); + reqWithExplicitQueryId.set("Query-Id", "100"); + const auto req = MakeGetRequest("/"); + { + // A request with a custom query id. + auto queryId1 = server.getQueryId(reqWithExplicitQueryId, + "SELECT * WHERE { ?a ?b ?c }"); + // Another request with the same custom query id. This throws an error, + // because query id cannot be used for multiple queries at the same time. + AD_EXPECT_THROW_WITH_MESSAGE( + server.getQueryId(reqWithExplicitQueryId, + "SELECT * WHERE { ?a ?b ?c }"), + testing::HasSubstr("Query id '100' is already in use!")); + } + // The custom query id can be reused, once the query is finished. + auto queryId1 = + server.getQueryId(reqWithExplicitQueryId, "SELECT * WHERE { ?a ?b ?c }"); + // Without custom query ids, unique ids are generated. + auto queryId2 = server.getQueryId(req, "SELECT * WHERE { ?a ?b ?c }"); + auto queryId3 = server.getQueryId(req, "SELECT * WHERE { ?a ?b ?c }"); +} + +TEST(ServerTest, createMessageSender) { + Server server{9999, 1, ad_utility::MemorySize::megabytes(1), "accessToken"}; + auto reqWithExplicitQueryId = MakeGetRequest("/"); + std::string customQueryId = "100"; + reqWithExplicitQueryId.set("Query-Id", customQueryId); + const auto req = MakeGetRequest("/"); + // The query hub is only valid once, the server has been started. + AD_EXPECT_THROW_WITH_MESSAGE( + server.createMessageSender(server.queryHub_, req, + "SELECT * WHERE { ?a ?b ?c }"), + testing::HasSubstr("Assertion `queryHubLock` failed.")); + { + // Set a dummy query hub. + boost::asio::io_context io_context; + auto queryHub = + std::make_shared(io_context); + server.queryHub_ = queryHub; + // MessageSenders are created normally. + server.createMessageSender(server.queryHub_, req, + "SELECT * WHERE { ?a ?b ?c }"); + server.createMessageSender(server.queryHub_, req, + "INSERT DATA { }"); + EXPECT_THAT( + server.createMessageSender(server.queryHub_, reqWithExplicitQueryId, + "INSERT DATA { }"), + AD_PROPERTY(ad_utility::websocket::MessageSender, getQueryId, + testing::Eq(ad_utility::websocket::QueryId::idFromString( + customQueryId)))); + } + // Once the query hub expires (e.g. because the io context dies), message + // senders can no longer be created. + AD_EXPECT_THROW_WITH_MESSAGE( + server.createMessageSender(server.queryHub_, req, + "SELECT * WHERE { ?a ?b ?c }"), + testing::HasSubstr("Assertion `queryHubLock` failed.")); +}