Skip to content

Commit

Permalink
cool#9833: Implement MaxConnection Limit (2b)
Browse files Browse the repository at this point in the history
Reimplementation of commit 80246f7 (post revert).

Adding external TCP connection limit to server-side TCP IPv4/IPv6 Sockets
- Counted at StreamSocket ctor
  - Only limits TCP connections for server-side IPv4 or IPv6 TCP connections.
- Rejected at ServerSocker::handlePoll
  - If exceeding net::Defaults.maxExtConnections, socket object and hence connection is dropped

net::Defaults
  - Renamed maxTCPConnections -> maxExtConnections

TODO:
- revise net::Defaults.maxExtConnections to match actual system settings

Signed-off-by: Sven Göthel <[email protected]>
Change-Id: Ib9f0ac17c05ffe65c2370490f68f581fa76730e7
  • Loading branch information
Sven Göthel authored and caolanm committed Nov 1, 2024
1 parent 8a24d0e commit 5cd0c39
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 16 deletions.
4 changes: 2 additions & 2 deletions net/NetUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class DefaultValues
/// WebSocketHandler ping interval in us (18s default), i.e. duration until next ping. Zero disables instrument.
std::chrono::microseconds wsPingInterval;

/// Maximum number of concurrent TCP connections. Zero disables instrument.
size_t maxTCPConnections;
/// Maximum number of concurrent external TCP connections. Zero disables instrument.
size_t maxExtConnections;
};
extern DefaultValues Defaults;

Expand Down
11 changes: 8 additions & 3 deletions net/ServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include "NetUtil.hpp"
#include "memory"

#include "Socket.hpp"
Expand Down Expand Up @@ -108,9 +109,13 @@ class ServerSocket : public Socket
const std::string msg = "Failed to accept. (errno: ";
throw std::runtime_error(msg + std::strerror(errno) + ')');
}

LOG_TRC("Accepted client #" << clientSocket->getFD());
_clientPoller.insertNewSocket(std::move(clientSocket));
const size_t extConnCount = StreamSocket::getExternalConnectionCount();
if( 0 == net::Defaults.maxExtConnections || extConnCount <= net::Defaults.maxExtConnections )
{
LOG_TRC("Accepted client #" << clientSocket->getFD());
_clientPoller.insertNewSocket(std::move(clientSocket));
} else
LOG_WRN("Limiter rejected extConn[" << extConnCount << "/" << net::Defaults.maxExtConnections << "]: " << *clientSocket);
}
}

Expand Down
7 changes: 4 additions & 3 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ std::atomic<bool> Socket::InhibitThreadChecks(false);

std::unique_ptr<Watchdog> SocketPoll::PollWatchdog;

std::atomic<size_t> StreamSocket::ExternalConnectionCount = 0;

net::DefaultValues net::Defaults = { .inactivityTimeout = std::chrono::seconds(3600),
.wsPingAvgTimeout = std::chrono::seconds(12),
.wsPingInterval = std::chrono::seconds(18),
.maxTCPConnections = 200000 /* arbitrary value to be resolved */ };
.maxExtConnections = 200000 /* arbitrary value to be resolved */ };

#define SOCKET_ABSTRACT_UNIX_NAME "0coolwsd-"

Expand Down Expand Up @@ -1189,8 +1191,7 @@ std::shared_ptr<Socket> ServerSocket::accept()
_socket->setClientAddress(addrstr, clientInfo.sin6_port);

LOG_TRC("Accepted socket #" << _socket->getFD() << " has family "
<< clientInfo.sin6_family << " address "
<< _socket->clientAddress());
<< clientInfo.sin6_family << ", " << *_socket);
#else
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
#endif
Expand Down
28 changes: 22 additions & 6 deletions net/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include <common/StateEnum.hpp>
#include "Log.hpp"
#include "NetUtil.hpp"
#include "Util.hpp"
#include "Buffer.hpp"
#include "SigUtil.hpp"
Expand Down Expand Up @@ -409,6 +410,7 @@ class Socket
LOG_TRC("Ignore further input on socket.");
_ignoreInput = true;
}

protected:
/// Construct based on an existing socket fd.
/// Used by accept() only.
Expand Down Expand Up @@ -1045,19 +1047,22 @@ class StreamSocket : public Socket,
STATE_ENUM(ReadType, NormalRead, UseRecvmsgExpectFD);

/// Create a StreamSocket from native FD.
StreamSocket(std::string host, const int fd, Type type, bool /* isClient */,
StreamSocket(std::string host, const int fd, Type type, bool isClient,
HostType hostType, ReadType readType = ReadType::NormalRead,
std::chrono::steady_clock::time_point creationTime = std::chrono::steady_clock::now() ) :
Socket(fd, type, creationTime),
_hostname(std::move(host)),
_wsState(WSState::HTTP),
_isClient(isClient),
_isLocalHost(hostType == LocalHost),
_sentHTTPContinue(false),
_shutdownSignalled(false),
_readType(readType),
_inputProcessingEnabled(true)
{
LOG_TRC("StreamSocket ctor");
if (isExternalCountedConnection())
++ExternalConnectionCount;
}

~StreamSocket() override
Expand All @@ -1078,6 +1083,8 @@ class StreamSocket : public Socket,
_shutdownSignalled = true;
StreamSocket::closeConnection();
}
if (isExternalCountedConnection())
--ExternalConnectionCount;
}

bool isWebSocket() const { return _wsState == WSState::WS; }
Expand Down Expand Up @@ -1322,11 +1329,12 @@ class StreamSocket : public Socket,
_socketHandler.reset();
}

/// Create a socket of type TSocket given an FD and a handler.
/// Create a socket of type TSocket derived from StreamSocket given an FD and a handler.
/// We need this helper since the handler needs a shared_ptr to the socket
/// but we can't have a shared_ptr in the ctor.
template <typename TSocket>
static std::shared_ptr<TSocket> create(std::string hostname, const int fd, Type type,
template <typename TSocket,
std::enable_if_t<std::is_base_of_v<StreamSocket, TSocket>, bool> = true>
static std::shared_ptr<TSocket> create(std::string hostname, int fd, Type type,
bool isClient, HostType hostType,
std::shared_ptr<ProtocolHandlerInterface> handler,
ReadType readType = ReadType::NormalRead,
Expand Down Expand Up @@ -1623,6 +1631,8 @@ class StreamSocket : public Socket,

void dumpState(std::ostream& os) override;

static size_t getExternalConnectionCount() { return ExternalConnectionCount; }

protected:
void handshakeFail()
{
Expand Down Expand Up @@ -1741,18 +1751,24 @@ class StreamSocket : public Socket,
STATE_ENUM(WSState, HTTP, WS);
WSState _wsState;

/// True if owner is in client role, otherwise false (server)
bool _isClient:1;

/// True if host is localhost
bool _isLocalHost;
bool _isLocalHost:1;

/// True if we've received a Continue in response to an Expect: 100-continue
bool _sentHTTPContinue;
bool _sentHTTPContinue:1;

/// True when shutdown was requested via shutdown().
/// It's accessed from different threads.
std::atomic_bool _shutdownSignalled;
std::vector<int> _incomingFDs;
ReadType _readType;
std::atomic_bool _inputProcessingEnabled;

bool isExternalCountedConnection() const { return !_isClient && isIPType(); }
static std::atomic<size_t> ExternalConnectionCount; // accepted external TCP IPv4/IPv6 socket count
};

enum class WSOpCode : unsigned char {
Expand Down
6 changes: 6 additions & 0 deletions test/UnitTimeoutBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testHttp(const size_t connectionLi
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << StreamSocket::getExternalConnectionCount() << " / " << net::Defaults.maxExtConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down Expand Up @@ -302,6 +304,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSPing(const size_t connection
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << StreamSocket::getExternalConnectionCount() << " / " << net::Defaults.maxExtConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down Expand Up @@ -397,6 +401,8 @@ inline UnitBase::TestResult UnitTimeoutBase1::testWSDChatPing(const size_t conne
sessions.clear();
TST_LOG("Clearing Poller: " << testname);
socketPollers.clear();
// TCP Connection Count: Just an estimation, no locking on server side
TST_LOG("TCP Connection Count: " << testname << ", " << StreamSocket::getExternalConnectionCount() << " / " << net::Defaults.maxExtConnections);
TST_LOG("Ending Test: " << testname);
return TestResult::Ok;
}
Expand Down
2 changes: 1 addition & 1 deletion test/UnitTimeoutConnections.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class UnitTimeoutConnections : public UnitTimeoutBase1
{
void configure(Poco::Util::LayeredConfiguration& /* config */) override
{
net::Defaults.maxTCPConnections = ConnectionLimit;
net::Defaults.maxExtConnections = ConnectionLimit;
}

public:
Expand Down
2 changes: 1 addition & 1 deletion wsd/COOLWSD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2762,7 +2762,7 @@ void COOLWSD::innerInitialize(Poco::Util::Application& self)
LOG_DBG("net::Defaults: WSPing[timeout "
<< net::Defaults.wsPingAvgTimeout << ", interval " << net::Defaults.wsPingInterval
<< "], Socket[inactivityTimeout " << net::Defaults.inactivityTimeout
<< ", maxTCPConnections " << net::Defaults.maxTCPConnections << "]");
<< ", maxExtConnections " << net::Defaults.maxExtConnections << "]");
}

#if !MOBILEAPP
Expand Down

0 comments on commit 5cd0c39

Please sign in to comment.