Skip to content

Commit

Permalink
cool#9833: Implement MaxConnection Limit (2b)
Browse files Browse the repository at this point in the history
Reimplementation of commit 80246f7 (post revert).

Adding external TCP connection limit to server-side TCP IPv4/IPv6 Sockets
- Counted at StreamSocket ctor
  - Only limits TCP connections for server-side IPv4 or IPv6 TCP connections.
- Rejected at ServerSocker::handlePoll
  - If exceeding net::Defaults.maxExtConnections, socket object and hence connection is dropped

net::Defaults
  - Renamed maxTCPConnections -> maxExtConnections

TODO:
- revise net::Defaults.maxExtConnections to match actual system settings

Signed-off-by: Sven Göthel <[email protected]>
Change-Id: Ib9f0ac17c05ffe65c2370490f68f581fa76730e7
  • Loading branch information
Sven Göthel committed Oct 30, 2024
1 parent 85dd972 commit 136e95a
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 15 deletions.
4 changes: 2 additions & 2 deletions net/NetUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class DefaultValues
/// WebSocketHandler ping interval in us (18s default), i.e. duration until next ping. Zero disables instrument.
std::chrono::microseconds wsPingInterval;

/// Maximum number of concurrent TCP connections. Zero disables instrument.
size_t maxTCPConnections;
/// Maximum number of concurrent external TCP connections. Zero disables instrument.
size_t maxExtConnections;
};
extern DefaultValues Defaults;

Expand Down
11 changes: 8 additions & 3 deletions net/ServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include "NetUtil.hpp"
#include "memory"

#include "Socket.hpp"
Expand Down Expand Up @@ -108,9 +109,13 @@ class ServerSocket : public Socket
const std::string msg = "Failed to accept. (errno: ";
throw std::runtime_error(msg + std::strerror(errno) + ')');
}

LOG_TRC("Accepted client #" << clientSocket->getFD());
_clientPoller.insertNewSocket(std::move(clientSocket));
const size_t extConnCount = StreamSocket::getExtConnCount();
if( extConnCount <= net::Defaults.maxExtConnections )
{
LOG_TRC("Accepted client #" << clientSocket->getFD());
_clientPoller.insertNewSocket(std::move(clientSocket));
} else
LOG_TRC("Limiter rejected extConn[" << extConnCount << "/" << net::Defaults.maxExtConnections << "]: " << *clientSocket);
}
}

Expand Down
7 changes: 4 additions & 3 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ std::atomic<bool> Socket::InhibitThreadChecks(false);

std::unique_ptr<Watchdog> SocketPoll::PollWatchdog;

std::atomic<size_t> StreamSocket::extConnCount = 0;

net::DefaultValues net::Defaults = { .inactivityTimeout = std::chrono::seconds(3600),
.wsPingAvgTimeout = std::chrono::seconds(12),
.wsPingInterval = std::chrono::seconds(18),
.maxTCPConnections = 200000 /* arbitrary value to be resolved */ };
.maxExtConnections = 200000 /* arbitrary value to be resolved */ };

#define SOCKET_ABSTRACT_UNIX_NAME "0coolwsd-"

Expand Down Expand Up @@ -1189,8 +1191,7 @@ std::shared_ptr<Socket> ServerSocket::accept()
_socket->setClientAddress(addrstr, clientInfo.sin6_port);

LOG_TRC("Accepted socket #" << _socket->getFD() << " has family "
<< clientInfo.sin6_family << " address "
<< _socket->clientAddress());
<< clientInfo.sin6_family << ", " << *_socket);
#else
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
#endif
Expand Down
31 changes: 26 additions & 5 deletions net/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include <common/StateEnum.hpp>
#include "Log.hpp"
#include "NetUtil.hpp"
#include "Util.hpp"
#include "Buffer.hpp"
#include "SigUtil.hpp"
Expand Down Expand Up @@ -409,6 +410,7 @@ class Socket
LOG_TRC("Ignore further input on socket.");
_ignoreInput = true;
}

protected:
/// Construct based on an existing socket fd.
/// Used by accept() only.
Expand Down Expand Up @@ -1045,7 +1047,7 @@ class StreamSocket : public Socket,
STATE_ENUM(ReadType, NormalRead, UseRecvmsgExpectFD);

/// Create a StreamSocket from native FD.
StreamSocket(std::string host, const int fd, Type type, bool /* isClient */,
StreamSocket(std::string host, const int fd, Type type, bool isClient,
HostType hostType, ReadType readType = ReadType::NormalRead,
std::chrono::steady_clock::time_point creationTime = std::chrono::steady_clock::now() ) :
Socket(fd, type, creationTime),
Expand All @@ -1055,7 +1057,8 @@ class StreamSocket : public Socket,
_sentHTTPContinue(false),
_shutdownSignalled(false),
_readType(readType),
_inputProcessingEnabled(true)
_inputProcessingEnabled(true),
_isExtCountedConn(evalExtCountedConn(fd, type, isClient))
{
LOG_TRC("StreamSocket ctor");
}
Expand All @@ -1078,6 +1081,8 @@ class StreamSocket : public Socket,
_shutdownSignalled = true;
StreamSocket::closeConnection();
}
if( _isExtCountedConn )
--extConnCount;
}

bool isWebSocket() const { return _wsState == WSState::WS; }
Expand Down Expand Up @@ -1322,11 +1327,12 @@ class StreamSocket : public Socket,
_socketHandler.reset();
}

/// Create a socket of type TSocket given an FD and a handler.
/// Create a socket of type TSocket derived from StreamSocket given an FD and a handler.
/// We need this helper since the handler needs a shared_ptr to the socket
/// but we can't have a shared_ptr in the ctor.
template <typename TSocket>
static std::shared_ptr<TSocket> create(std::string hostname, const int fd, Type type,
template <typename TSocket,
std::enable_if_t<std::is_base_of_v<StreamSocket, TSocket>, bool> = true>
static std::shared_ptr<TSocket> create(std::string hostname, int fd, Type type,
bool isClient, HostType hostType,
std::shared_ptr<ProtocolHandlerInterface> handler,
ReadType readType = ReadType::NormalRead,
Expand Down Expand Up @@ -1623,6 +1629,8 @@ class StreamSocket : public Socket,

void dumpState(std::ostream& os) override;

static size_t getExtConnCount() { return extConnCount; }

protected:
void handshakeFail()
{
Expand Down Expand Up @@ -1753,6 +1761,19 @@ class StreamSocket : public Socket,
std::vector<int> _incomingFDs;
ReadType _readType;
std::atomic_bool _inputProcessingEnabled;

const bool _isExtCountedConn;
static bool evalExtCountedConn(int fd, Type type, bool isClient)
{
if(!Util::isMobileApp() && !isClient &&
fd >= 0 && (type == Type::IPv4 || type == Type::IPv6))
{
++extConnCount;
return true;
}
return false;
}
static std::atomic<size_t> extConnCount; // accepted external TCP IPv4/IPv6 socket count
};

enum class WSOpCode : unsigned char {
Expand Down
6 changes: 6 additions & 0 deletions test/UnitTimeoutBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testHttp(const size_t connectionLi
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << StreamSocket::getExtConnCount() << " / " << net::Defaults.maxExtConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down Expand Up @@ -302,6 +304,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSPing(const size_t connection
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << StreamSocket::getExtConnCount() << " / " << net::Defaults.maxExtConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down Expand Up @@ -397,6 +401,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSDChatPing(const size_t conne
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << StreamSocket::getExtConnCount() << " / " << net::Defaults.maxExtConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down
2 changes: 1 addition & 1 deletion test/UnitTimeoutConnections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class UnitTimeoutConnections : public UnitTimeoutBase1
{
void configure(Poco::Util::LayeredConfiguration& /* config */) override
{
net::Defaults.maxTCPConnections = ConnectionLimit;
net::Defaults.maxExtConnections = ConnectionLimit;
}

public:
Expand Down
2 changes: 1 addition & 1 deletion wsd/COOLWSD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2762,7 +2762,7 @@ void COOLWSD::innerInitialize(Poco::Util::Application& self)
LOG_DBG("net::Defaults: WSPing[timeout "
<< net::Defaults.wsPingAvgTimeout << ", interval " << net::Defaults.wsPingInterval
<< "], Socket[inactivityTimeout " << net::Defaults.inactivityTimeout
<< ", maxTCPConnections " << net::Defaults.maxTCPConnections << "]");
<< ", maxExtConnections " << net::Defaults.maxExtConnections << "]");
}

#if !MOBILEAPP
Expand Down

0 comments on commit 136e95a

Please sign in to comment.