diff --git a/net/ServerSocket.hpp b/net/ServerSocket.hpp index 7a47030b80d46..69cd2a2291a29 100644 --- a/net/ServerSocket.hpp +++ b/net/ServerSocket.hpp @@ -108,9 +108,11 @@ class ServerSocket : public Socket const std::string msg = "Failed to accept. (errno: "; throw std::runtime_error(msg + std::strerror(errno) + ')'); } - - LOG_TRC("Accepted client #" << clientSocket->getFD()); - _clientPoller.insertNewSocket(std::move(clientSocket)); + if( clientSocket->isOpen() ) + { + LOG_TRC("Accepted client #" << clientSocket->getFD()); + _clientPoller.insertNewSocket(std::move(clientSocket)); + } // else intentionally cancelled accepted connection (e.g. connection limiter) } } diff --git a/net/Socket.cpp b/net/Socket.cpp index f51178743e7b2..489a4463856f1 100644 --- a/net/Socket.cpp +++ b/net/Socket.cpp @@ -65,6 +65,9 @@ std::atomic Socket::InhibitThreadChecks(false); std::unique_ptr SocketPoll::PollWatchdog; +std::mutex Socket::statsMutex; +std::atomic Socket::statsConnectionCount(0); + net::DefaultValues net::Defaults = { .inactivityTimeout = std::chrono::seconds(3600), .wsPingAvgTimeout = std::chrono::seconds(12), .wsPingInterval = std::chrono::seconds(18), @@ -146,8 +149,10 @@ std::string Socket::getStatsString(const std::chrono::steady_clock::time_point & std::ostream& Socket::streamImpl(std::ostream& os) const { os << "Socket[#" << getFD() - << ", " << toString(type()) - << " @ "; + << ", " << toString(type()); + if (isCounted()) + os << ", counted"; + os << " @ "; if (Type::IPv6 == type()) { os << "[" << clientAddress() << "]:" << clientPort(); @@ -1159,47 +1164,85 @@ std::shared_ptr ServerSocket::accept() const int rc = fakeSocketAccept4(getFD()); #endif LOG_TRC("Accepted socket #" << rc << ", creating socket object."); - try + if (rc != -1) { - // Create a socket object using the factory. - if (rc != -1) - { #if !MOBILEAPP - char addrstr[INET6_ADDRSTRLEN]; + char addrstr[INET6_ADDRSTRLEN]; - Socket::Type type; - const void *inAddr; - if (clientInfo.sin6_family == AF_INET) + Socket::Type type; + const void *inAddr; + if (clientInfo.sin6_family == AF_INET) + { + struct sockaddr_in *ipv4 = (struct sockaddr_in *)&clientInfo; + inAddr = &(ipv4->sin_addr); + type = Socket::Type::IPv4; + } + else + { + struct sockaddr_in6 *ipv6 = &clientInfo; + inAddr = &(ipv6->sin6_addr); + type = Socket::Type::IPv6; + } + ::inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr)); + + if( !Socket::checkAndIncrConnectionCount(rc) ) + { + // cancellation branch! + if constexpr (!Util::isMobileApp()) { - struct sockaddr_in *ipv4 = (struct sockaddr_in *)&clientInfo; - inAddr = &(ipv4->sin_addr); - type = Socket::Type::IPv4; + if (::close(rc)) + LOG_SYS("Ignored error closing socket #" << rc); + } else + fakeSocketClose(rc); + + try { + // return closed dummy socket to avoid handlePoll throw + std::shared_ptr socket = StreamSocket::create( + std::string(), /*fd=*/-1 /* closed */, type, false, HostType::Other, + std::make_shared()); + socket->setClientAddress(addrstr, clientInfo.sin6_port); + LOG_WRN("TCP Limiter: Rejecting accepted socket #" << rc << ", " << *socket); + return socket; } - else + catch (const std::exception& ex) { - struct sockaddr_in6 *ipv6 = &clientInfo; - inAddr = &(ipv6->sin6_addr); - type = Socket::Type::IPv6; + LOG_WRN("TCP Limiter: Rejecting accepted socket #" << rc); + LOG_ERR("Failed to create limited rejected client socket #" << rc << ". Error: " << ex.what()); + return nullptr; } + } +#endif + // Create a socket object using the factory. + bool hasDtor = false; + try + { +#if !MOBILEAPP + std::shared_ptr socket = createSocketFromAccept(rc, type); // may throw + socket->_isCounted = true; + hasDtor = true; + socket->setClientAddress(addrstr, clientInfo.sin6_port); - std::shared_ptr _socket = createSocketFromAccept(rc, type); - - ::inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr)); - _socket->setClientAddress(addrstr, clientInfo.sin6_port); - - LOG_TRC("Accepted socket #" << _socket->getFD() << " has family " - << clientInfo.sin6_family << " address " - << _socket->clientAddress()); + LOG_TRC("Accepted socket #" << socket->getFD() << " has family " + << clientInfo.sin6_family << ", " << *socket); #else std::shared_ptr _socket = createSocketFromAccept(rc, Socket::Type::Unix); #endif - return _socket; + return socket; + } + catch (const std::exception& ex) + { + LOG_ERR("Failed to create client socket #" << rc << " (had socket << " << hasDtor << "). Error: " << ex.what()); + if (!hasDtor) + { + if constexpr (!Util::isMobileApp()) + { + if (::close(rc)) + LOG_SYS("Ignored error closing socket #" << rc); + } else + fakeSocketClose(rc); + decrConnectionCount(rc); + } } - return std::shared_ptr(nullptr); - } - catch (const std::exception& ex) - { - LOG_ERR("Failed to create client socket #" << rc << ". Error: " << ex.what()); } return nullptr; @@ -1392,8 +1435,10 @@ std::ostream& StreamSocket::stream(std::ostream& os) const { os << "StreamSocket[#" << getFD() << ", " << toStringShort(_wsState) - << ", " << Socket::toString(type()) - << " @ "; + << ", " << Socket::toString(type()); + if (isCounted()) + os << ", counted"; + os << " @ "; if (Type::IPv6 == type()) { os << "[" << clientAddress() << "]:" << clientPort(); diff --git a/net/Socket.hpp b/net/Socket.hpp index ba72c7cf5f6b9..e56a1a71f66d1 100644 --- a/net/Socket.hpp +++ b/net/Socket.hpp @@ -11,6 +11,7 @@ #pragma once +#include #include #include #include @@ -33,6 +34,7 @@ #include #include "Log.hpp" +#include "NetUtil.hpp" #include "Util.hpp" #include "Buffer.hpp" #include "SigUtil.hpp" @@ -129,6 +131,8 @@ class SocketDisposition final std::shared_ptr _socket; }; +class ServerSocket; // fwd + /// A non-blocking, streaming socket. class Socket { @@ -151,6 +155,7 @@ class Socket , _lastSeenTime(_creationTime) , _bytesSent(0) , _bytesRcvd(0) + , _isCounted(false) { init(); } @@ -162,12 +167,16 @@ class Socket // Doesn't block on sockets; no error handling needed. if constexpr (!Util::isMobileApp()) { - ::close(_fd); + if (::close(_fd)) + LOG_SYS("Ignored error closing socket #" << _fd); LOG_DBG("Closed socket " << toStringImpl()); } else - { fakeSocketClose(_fd); + if (_isCounted) + { + _isCounted = false; + decrConnectionCount(_fd); } } @@ -175,6 +184,7 @@ class Socket bool isOpen() const { return _open; } /// Returns true if this socket has been closed, i.e. rejected from polling and potentially shutdown bool isClosed() const { return !_open; } + bool isCounted() const { return _isCounted; } constexpr Type type() const { return _type; } constexpr bool isIPType() const { return Type::IPv4 == _type || Type::IPv6 == _type; } @@ -249,7 +259,7 @@ class Socket if constexpr (!Util::isMobileApp()) { const int val = 1; - if (::setsockopt(_fd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(val)) == -1) + if (isOpen() && ::setsockopt(_fd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(val)) == -1) { static std::once_flag once; std::call_once(once, @@ -409,6 +419,10 @@ class Socket LOG_TRC("Ignore further input on socket."); _ignoreInput = true; } + + /// Returns connection count, lock-free + static size_t connectionCount() { return statsConnectionCount; } + protected: /// Construct based on an existing socket fd. /// Used by accept() only. @@ -422,6 +436,7 @@ class Socket , _lastSeenTime(_creationTime) , _bytesSent(0) , _bytesRcvd(0) + , _isCounted(false) { init(); } @@ -493,6 +508,50 @@ class Socket /// We check the owner even in the release builds, needs to be always correct. std::thread::id _owner; + + bool _isCounted; // if true, must call `decrConnectionCount` in connection `shutdown` + + /// Decrements global connection count + /// @param fd the related file descriptor for logging + static void decrConnectionCount(int fd) + { + std::lock_guard lock(statsMutex); + const size_t u = statsConnectionCount; + auto logPrefix = [fd](std::ostream& os) { os << '#' << fd << ": "; }; + if (u > 0) + { + const size_t v = --statsConnectionCount; + LOG_TRC("TCP Limiter: Count decremented: " << u << " -> " << v); + } + else + LOG_WRN("TCP Limiter: Count decrement underflow: " << u); + } + + /// Increments global connection counter if not exceeding net::DefaultValue::maxTCPConnections. + /// Returns true if free connections were available, otherwise false. + /// No limitation is applied if net::DefaultValue::maxTCPConnections == 0. + /// @param fd the related file descriptor for logging + static bool checkAndIncrConnectionCount(int fd) + { + std::lock_guard lock(statsMutex); + const size_t u = statsConnectionCount; + auto logPrefix = [fd](std::ostream& os) { os << '#' << fd << ": "; }; + if (net::Defaults.maxTCPConnections == 0 || u < net::Defaults.maxTCPConnections) + { + const size_t v = ++statsConnectionCount; + LOG_TRC("TCP Limiter: Count incremented: " << u << " -> " << v); + return true; + } + else + { + LOG_WRN("TCP Limiter: Limit reached: " << u); + return false; + } + } + friend class ServerSocket; // allow `checkAndIncrConnectionCount` for `ServerSocket::accept()` + + static std::mutex statsMutex; + static std::atomic statsConnectionCount; // accepted TCP IPv4/IPv6 socket count }; inline std::ostream& operator<<(std::ostream& os, const Socket &s) { return s.stream(os); } @@ -617,6 +676,18 @@ class SimpleSocketHandler : public ProtocolHandlerInterface void getIOStats(uint64_t &, uint64_t &) override {} }; +/// A no-operation ProtocolHandlerInterface with dummy API. +class NoOpSocketHandler : public SimpleSocketHandler +{ +public: + NoOpSocketHandler() = default; + + void onConnect(const std::shared_ptr& /*socket*/) override {} + void handleIncomingMessage(SocketDisposition &) override {} + int getPollEvents(std::chrono::steady_clock::time_point /*now*/, int64_t &/*timeoutMaxMicroS*/) override { return 0; } + void performWrites(std::size_t /*capacity*/) override {} +}; + /// Interface that receives and sends incoming messages. class MessageHandlerInterface : public std::enable_shared_from_this diff --git a/test/UnitTimeoutBase.hpp b/test/UnitTimeoutBase.hpp index 06f02d09578f2..f6ce9a308dc88 100644 --- a/test/UnitTimeoutBase.hpp +++ b/test/UnitTimeoutBase.hpp @@ -204,6 +204,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testHttp(const size_t connectionLi sessions.clear(); TST_LOG("Clearing Poller: " << testname); socketPollers.clear(); + // TCP Connection Count: Just an estimation, no locking on server side + TST_LOG("TCP Connection Count: " << testname << ", " << Socket::connectionCount() << " / " << net::Defaults.maxTCPConnections); TST_LOG("Ending Test: " << testname); return TestResult::Ok; } @@ -302,6 +304,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSPing(const size_t connection sessions.clear(); TST_LOG("Clearing Poller: " << testname); socketPollers.clear(); + // TCP Connection Count: Just an estimation, no locking on server side + TST_LOG("TCP Connection Count: " << testname << ", " << Socket::connectionCount() << " / " << net::Defaults.maxTCPConnections); TST_LOG("Ending Test: " << testname); return TestResult::Ok; } @@ -397,6 +401,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSDChatPing(const size_t conne sessions.clear(); TST_LOG("Clearing Poller: " << testname); socketPollers.clear(); + // TCP Connection Count: Just an estimation, no locking on server side + TST_LOG("TCP Connection Count: " << testname << ", " << Socket::connectionCount() << " / " << net::Defaults.maxTCPConnections); TST_LOG("Ending Test: " << testname); return TestResult::Ok; }