Skip to content

Commit

Permalink
cool#9833: Implement MaxConnection Limit (2)
Browse files Browse the repository at this point in the history
Reimplementation of commit 80246f7 (post revert).

Adding TCP connection limit to outside facing TCP IPv4/IPv6 Sockets
- validated and registered at ServerSocket::accept() (IPv4/IPv6 only)
- excluded from LocalServerSocket::accept() for local UDS connections
- unregistered at Socket dtor post socket closing.

TODO:
- revise net::Defaults.maxTCPConnections to match actual system settings

Signed-off-by: Sven Göthel <[email protected]>
Change-Id: Ib9f0ac17c05ffe65c2370490f68f581fa76730e7
  • Loading branch information
Sven Göthel committed Oct 30, 2024
1 parent abac9fa commit 7031ce0
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 39 deletions.
8 changes: 5 additions & 3 deletions net/ServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ class ServerSocket : public Socket
const std::string msg = "Failed to accept. (errno: ";
throw std::runtime_error(msg + std::strerror(errno) + ')');
}

LOG_TRC("Accepted client #" << clientSocket->getFD());
_clientPoller.insertNewSocket(std::move(clientSocket));
if( clientSocket->isOpen() )
{
LOG_TRC("Accepted client #" << clientSocket->getFD());
_clientPoller.insertNewSocket(std::move(clientSocket));
} // else intentionally cancelled accepted connection (e.g. connection limiter)
}
}

Expand Down
111 changes: 78 additions & 33 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ std::atomic<bool> Socket::InhibitThreadChecks(false);

std::unique_ptr<Watchdog> SocketPoll::PollWatchdog;

std::mutex Socket::statsMutex;
std::atomic<size_t> Socket::statsConnectionCount(0);

net::DefaultValues net::Defaults = { .inactivityTimeout = std::chrono::seconds(3600),
.wsPingAvgTimeout = std::chrono::seconds(12),
.wsPingInterval = std::chrono::seconds(18),
Expand Down Expand Up @@ -146,8 +149,10 @@ std::string Socket::getStatsString(const std::chrono::steady_clock::time_point &
std::ostream& Socket::streamImpl(std::ostream& os) const
{
os << "Socket[#" << getFD()
<< ", " << toString(type())
<< " @ ";
<< ", " << toString(type());
if (isCounted())
os << ", counted";
os << " @ ";
if (Type::IPv6 == type())
{
os << "[" << clientAddress() << "]:" << clientPort();
Expand Down Expand Up @@ -1159,47 +1164,85 @@ std::shared_ptr<Socket> ServerSocket::accept()
const int rc = fakeSocketAccept4(getFD());
#endif
LOG_TRC("Accepted socket #" << rc << ", creating socket object.");
try
if (rc != -1)
{
// Create a socket object using the factory.
if (rc != -1)
{
#if !MOBILEAPP
char addrstr[INET6_ADDRSTRLEN];
char addrstr[INET6_ADDRSTRLEN];

Socket::Type type;
const void *inAddr;
if (clientInfo.sin6_family == AF_INET)
Socket::Type type;
const void *inAddr;
if (clientInfo.sin6_family == AF_INET)
{
struct sockaddr_in *ipv4 = (struct sockaddr_in *)&clientInfo;
inAddr = &(ipv4->sin_addr);
type = Socket::Type::IPv4;
}
else
{
struct sockaddr_in6 *ipv6 = &clientInfo;
inAddr = &(ipv6->sin6_addr);
type = Socket::Type::IPv6;
}
::inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr));

if( !Socket::checkAndIncrConnectionCount(rc) )
{
// cancellation branch!
if constexpr (!Util::isMobileApp())
{
struct sockaddr_in *ipv4 = (struct sockaddr_in *)&clientInfo;
inAddr = &(ipv4->sin_addr);
type = Socket::Type::IPv4;
if (::close(rc))
LOG_SYS("Ignored error closing socket #" << rc);
} else
fakeSocketClose(rc);

try {
// return closed dummy socket to avoid handlePoll throw
std::shared_ptr<Socket> socket = StreamSocket::create<StreamSocket>(
std::string(), /*fd=*/-1 /* closed */, type, false, HostType::Other,
std::make_shared<NoOpSocketHandler>());
socket->setClientAddress(addrstr, clientInfo.sin6_port);
LOG_WRN("TCP Limiter: Rejecting accepted socket #" << rc << ", " << *socket);
return socket;
}
else
catch (const std::exception& ex)
{
struct sockaddr_in6 *ipv6 = &clientInfo;
inAddr = &(ipv6->sin6_addr);
type = Socket::Type::IPv6;
LOG_WRN("TCP Limiter: Rejecting accepted socket #" << rc);
LOG_ERR("Failed to create limited rejected client socket #" << rc << ". Error: " << ex.what());
return nullptr;
}
}
#endif
// Create a socket object using the factory.
bool hasDtor = false;
try
{
#if !MOBILEAPP
std::shared_ptr<Socket> socket = createSocketFromAccept(rc, type); // may throw
socket->_isCounted = true;
hasDtor = true;
socket->setClientAddress(addrstr, clientInfo.sin6_port);

std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, type);

::inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr));
_socket->setClientAddress(addrstr, clientInfo.sin6_port);

LOG_TRC("Accepted socket #" << _socket->getFD() << " has family "
<< clientInfo.sin6_family << " address "
<< _socket->clientAddress());
LOG_TRC("Accepted socket #" << socket->getFD() << " has family "
<< clientInfo.sin6_family << ", " << *socket);
#else
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
#endif
return _socket;
return socket;
}
catch (const std::exception& ex)
{
LOG_ERR("Failed to create client socket #" << rc << " (had socket << " << hasDtor << "). Error: " << ex.what());
if (!hasDtor)
{
if constexpr (!Util::isMobileApp())
{
if (::close(rc))
LOG_SYS("Ignored error closing socket #" << rc);
} else
fakeSocketClose(rc);
decrConnectionCount(rc);
}
}
return std::shared_ptr<Socket>(nullptr);
}
catch (const std::exception& ex)
{
LOG_ERR("Failed to create client socket #" << rc << ". Error: " << ex.what());
}

return nullptr;
Expand Down Expand Up @@ -1392,8 +1435,10 @@ std::ostream& StreamSocket::stream(std::ostream& os) const
{
os << "StreamSocket[#" << getFD()
<< ", " << toStringShort(_wsState)
<< ", " << Socket::toString(type())
<< " @ ";
<< ", " << Socket::toString(type());
if (isCounted())
os << ", counted";
os << " @ ";
if (Type::IPv6 == type())
{
os << "[" << clientAddress() << "]:" << clientPort();
Expand Down
77 changes: 74 additions & 3 deletions net/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include <limits>
#include <poll.h>
#include <unistd.h>
#include <sys/types.h>
Expand All @@ -33,6 +34,7 @@

#include <common/StateEnum.hpp>
#include "Log.hpp"
#include "NetUtil.hpp"
#include "Util.hpp"
#include "Buffer.hpp"
#include "SigUtil.hpp"
Expand Down Expand Up @@ -129,6 +131,8 @@ class SocketDisposition final
std::shared_ptr<Socket> _socket;
};

class ServerSocket; // fwd

/// A non-blocking, streaming socket.
class Socket
{
Expand All @@ -151,6 +155,7 @@ class Socket
, _lastSeenTime(_creationTime)
, _bytesSent(0)
, _bytesRcvd(0)
, _isCounted(false)
{
init();
}
Expand All @@ -162,19 +167,24 @@ class Socket
// Doesn't block on sockets; no error handling needed.
if constexpr (!Util::isMobileApp())
{
::close(_fd);
if (::close(_fd))
LOG_SYS("Ignored error closing socket #" << _fd);
LOG_DBG("Closed socket " << toStringImpl());
}
else
{
fakeSocketClose(_fd);
if (_isCounted)
{
_isCounted = false;
decrConnectionCount(_fd);
}
}

/// Returns true if this socket is open, i.e. allowed to be polled and not shutdown
bool isOpen() const { return _open; }
/// Returns true if this socket has been closed, i.e. rejected from polling and potentially shutdown
bool isClosed() const { return !_open; }
bool isCounted() const { return _isCounted; }

constexpr Type type() const { return _type; }
constexpr bool isIPType() const { return Type::IPv4 == _type || Type::IPv6 == _type; }
Expand Down Expand Up @@ -249,7 +259,7 @@ class Socket
if constexpr (!Util::isMobileApp())
{
const int val = 1;
if (::setsockopt(_fd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(val)) == -1)
if (isOpen() && ::setsockopt(_fd, IPPROTO_TCP, TCP_NODELAY, (char*)&val, sizeof(val)) == -1)
{
static std::once_flag once;
std::call_once(once,
Expand Down Expand Up @@ -409,6 +419,10 @@ class Socket
LOG_TRC("Ignore further input on socket.");
_ignoreInput = true;
}

/// Returns connection count, lock-free
static size_t connectionCount() { return statsConnectionCount; }

protected:
/// Construct based on an existing socket fd.
/// Used by accept() only.
Expand All @@ -422,6 +436,7 @@ class Socket
, _lastSeenTime(_creationTime)
, _bytesSent(0)
, _bytesRcvd(0)
, _isCounted(false)
{
init();
}
Expand Down Expand Up @@ -493,6 +508,50 @@ class Socket

/// We check the owner even in the release builds, needs to be always correct.
std::thread::id _owner;

bool _isCounted; // if true, must call `decrConnectionCount` in connection `shutdown`

/// Decrements global connection count
/// @param fd the related file descriptor for logging
static void decrConnectionCount(int fd)
{
std::lock_guard<std::mutex> lock(statsMutex);
const size_t u = statsConnectionCount;
auto logPrefix = [fd](std::ostream& os) { os << '#' << fd << ": "; };
if (u > 0)
{
const size_t v = --statsConnectionCount;
LOG_TRC("TCP Limiter: Count decremented: " << u << " -> " << v);
}
else
LOG_WRN("TCP Limiter: Count decrement underflow: " << u);
}

/// Increments global connection counter if not exceeding net::DefaultValue::maxTCPConnections.
/// Returns true if free connections were available, otherwise false.
/// No limitation is applied if net::DefaultValue::maxTCPConnections == 0.
/// @param fd the related file descriptor for logging
static bool checkAndIncrConnectionCount(int fd)
{
std::lock_guard<std::mutex> lock(statsMutex);
const size_t u = statsConnectionCount;
auto logPrefix = [fd](std::ostream& os) { os << '#' << fd << ": "; };
if (net::Defaults.maxTCPConnections == 0 || u < net::Defaults.maxTCPConnections)
{
const size_t v = ++statsConnectionCount;
LOG_TRC("TCP Limiter: Count incremented: " << u << " -> " << v);
return true;
}
else
{
LOG_WRN("TCP Limiter: Limit reached: " << u);
return false;
}
}
friend class ServerSocket; // allow `checkAndIncrConnectionCount` for `ServerSocket::accept()`

static std::mutex statsMutex;
static std::atomic<size_t> statsConnectionCount; // accepted TCP IPv4/IPv6 socket count
};

inline std::ostream& operator<<(std::ostream& os, const Socket &s) { return s.stream(os); }
Expand Down Expand Up @@ -617,6 +676,18 @@ class SimpleSocketHandler : public ProtocolHandlerInterface
void getIOStats(uint64_t &, uint64_t &) override {}
};

/// A no-operation ProtocolHandlerInterface with dummy API.
class NoOpSocketHandler : public SimpleSocketHandler
{
public:
NoOpSocketHandler() = default;

void onConnect(const std::shared_ptr<StreamSocket>& /*socket*/) override {}
void handleIncomingMessage(SocketDisposition &) override {}
int getPollEvents(std::chrono::steady_clock::time_point /*now*/, int64_t &/*timeoutMaxMicroS*/) override { return 0; }
void performWrites(std::size_t /*capacity*/) override {}
};

/// Interface that receives and sends incoming messages.
class MessageHandlerInterface :
public std::enable_shared_from_this<MessageHandlerInterface>
Expand Down
6 changes: 6 additions & 0 deletions test/UnitTimeoutBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testHttp(const size_t connectionLi
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << Socket::connectionCount() << " / " << net::Defaults.maxTCPConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down Expand Up @@ -302,6 +304,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSPing(const size_t connection
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << Socket::connectionCount() << " / " << net::Defaults.maxTCPConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down Expand Up @@ -397,6 +401,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSDChatPing(const size_t conne
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << Socket::connectionCount() << " / " << net::Defaults.maxTCPConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down

0 comments on commit 7031ce0

Please sign in to comment.