diff --git a/network/CMakeLists.txt b/network/CMakeLists.txt index 71e5e5e12..b7602f8ea 100644 --- a/network/CMakeLists.txt +++ b/network/CMakeLists.txt @@ -18,9 +18,9 @@ cmake_minimum_required(VERSION 3.10 FATAL_ERROR) project(network) -set(NETWORK_SRCS socket_connection.cc) +set(NETWORK_SRCS socket_connection.cc messages.cc) -set(NETWORK_HDRS socket_connection.h) +set(NETWORK_HDRS socket_connection.h serializable.h messages.h) add_library(network SHARED ${NETWORK_SRCS} ${NETWORK_HDRS}) diff --git a/network/messages.cc b/network/messages.cc new file mode 100644 index 000000000..60a7879fa --- /dev/null +++ b/network/messages.cc @@ -0,0 +1,336 @@ +/* +Copyright 2025 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "messages.h" + +#ifdef WIN32 +# include +#else +# include +#endif + +constexpr uint32_t kMaxPayloadSize = 16 * 1024 * 1024; + +namespace Network +{ + +void WriteUint32ToBuffer(uint32_t value, Buffer& dest) +{ + uint32_t net_val = htonl(value); + const uint8_t* p_val = reinterpret_cast(&net_val); + dest.insert(dest.end(), p_val, p_val + sizeof(uint32_t)); +} + +void WriteStringToBuffer(const std::string& str, Buffer& dest) +{ + WriteUint32ToBuffer(static_cast(str.length()), dest); + dest.insert(dest.end(), str.begin(), str.end()); +} + +absl::StatusOr ReadUint32FromBuffer(const Buffer& src, size_t& offset) +{ + if (src.size() < offset + sizeof(uint32_t)) + { + return absl::InvalidArgumentError("Buffer too small to read a uint32_t."); + } + uint32_t net_val; + std::memcpy(&net_val, src.data() + offset, sizeof(uint32_t)); + offset += sizeof(uint32_t); + return ntohl(net_val); +} + +absl::StatusOr ReadStringFromBuffer(const Buffer& src, size_t& offset) +{ + absl::StatusOr len_or = ReadUint32FromBuffer(src, offset); + if (!len_or.ok()) + { + return len_or.status(); + } + + uint32_t len = len_or.value(); + if (src.size() < offset + len) + { + return absl::InvalidArgumentError("Buffer too small for declared string length."); + } + std::string result(reinterpret_cast(src.data() + offset), len); + offset += len; + return result; +} + +absl::Status HandShakeMessage::Serialize(Buffer& dest) const +{ + dest.clear(); + WriteUint32ToBuffer(m_major_version, dest); + WriteUint32ToBuffer(m_minor_version, dest); + return absl::OkStatus(); +} + +absl::Status HandShakeMessage::Deserialize(const Buffer& src) +{ + size_t offset = 0; + absl::StatusOr major_or = ReadUint32FromBuffer(src, offset); + if (!major_or.ok()) + { + return major_or.status(); + } + m_major_version = major_or.value(); + + absl::StatusOr minor_or = ReadUint32FromBuffer(src, offset); + if (!minor_or.ok()) + { + return minor_or.status(); + } + m_minor_version = minor_or.value(); + + if (offset != src.size()) + { + return absl::InvalidArgumentError("Handshake message has unexpected trailing data."); + } + return absl::OkStatus(); +} + +absl::Status StringMessage::Serialize(Buffer& dest) const +{ + dest.clear(); + WriteStringToBuffer(m_str, dest); + return absl::OkStatus(); +} + +absl::Status StringMessage::Deserialize(const Buffer& src) +{ + size_t offset = 0; + absl::StatusOr str_or = ReadStringFromBuffer(src, offset); + if (!str_or.ok()) + { + return str_or.status(); + } + + m_str = std::move(str_or.value()); + if (offset != src.size()) + { + return absl::InvalidArgumentError("String message has unexpected trailing data."); + } + return absl::OkStatus(); +} + +absl::Status DownloadFileResponse::Serialize(Buffer& dest) const +{ + dest.push_back(static_cast(m_found)); + WriteStringToBuffer(m_error_reason, dest); + WriteStringToBuffer(m_file_path, dest); + WriteStringToBuffer(m_file_size_str, dest); + + return absl::OkStatus(); +} + +absl::Status DownloadFileResponse::Deserialize(const Buffer& src) +{ + size_t offset = 0; + + // Deserialize the 'found' boolean. + if (src.size() < offset + sizeof(uint8_t)) + { + return absl::InvalidArgumentError("Buffer too small for 'found' field."); + } + m_found = (src[offset] != 0); + offset += sizeof(uint8_t); + + // Deserialize the strings using the StatusOr-returning helper + absl::StatusOr error_reason_or = ReadStringFromBuffer(src, offset); + if (!error_reason_or.ok()) + { + return error_reason_or.status(); // Forward the error + } + m_error_reason = std::move(error_reason_or.value()); + + absl::StatusOr file_path_or = ReadStringFromBuffer(src, offset); + if (!file_path_or.ok()) + { + return file_path_or.status(); + } + m_file_path = std::move(file_path_or.value()); + + absl::StatusOr file_size_or = ReadStringFromBuffer(src, offset); + if (!file_size_or.ok()) + { + return file_size_or.status(); + } + m_file_size_str = std::move(file_size_or.value()); + + // Final check for trailing data. + if (offset != src.size()) + { + return absl::InvalidArgumentError("Message has unexpected trailing data."); + } + + return absl::OkStatus(); +} + +absl::Status ReceiveBuffer(SocketConnection* conn, uint8_t* buffer, size_t size) +{ + if (!conn) + { + return absl::InvalidArgumentError("Provided SocketConnection is null."); + } + size_t total_received = 0; + while (total_received < size) + { + absl::StatusOr received_or = conn->Recv(buffer + total_received, + size - total_received); + if (!received_or.ok()) + { + return received_or.status(); + } + total_received += received_or.value(); + } + return absl::OkStatus(); +} + +absl::Status SendBuffer(SocketConnection* conn, const uint8_t* buffer, size_t size) +{ + if (!conn) + { + return absl::InvalidArgumentError("Provided SocketConnection is null."); + } + return conn->Send(buffer, size); +} + +absl::StatusOr> ReceiveMessage(SocketConnection* conn) +{ + if (!conn) + { + return absl::InvalidArgumentError("Provided SocketConnection is null."); + } + + const size_t header_size = sizeof(uint32_t) * 2; + uint8_t header_buffer[header_size]; + + // Receive the message header. + absl::Status status = ReceiveBuffer(conn, header_buffer, header_size); + if (!status.ok()) + { + return status; + } + + // Parse header. + uint32_t net_type, net_length; + std::memcpy(&net_type, header_buffer, sizeof(uint32_t)); + std::memcpy(&net_length, header_buffer + sizeof(uint32_t), sizeof(uint32_t)); + uint32_t type = ntohl(net_type); + uint32_t payload_length = ntohl(net_length); + + if (payload_length > kMaxPayloadSize) + { + conn->Close(); + return absl::InvalidArgumentError( + absl::StrCat("Payload size ", payload_length, " exceeds limit.")); + } + + // Receive the message payload. + Buffer payload_buffer(payload_length); + status = ReceiveBuffer(conn, payload_buffer.data(), payload_length); + if (!status.ok()) + { + return status; + } + + // Create and deserialize the message object. + std::unique_ptr message; + switch (static_cast(type)) + { + case MessageType::HANDSHAKE_REQUEST: + message = std::make_unique(); + break; + case MessageType::HANDSHAKE_RESPONSE: + message = std::make_unique(); + break; + case MessageType::PING_MESSAGE: + message = std::make_unique(); + break; + case MessageType::PONG_MESSAGE: + message = std::make_unique(); + break; + case MessageType::PM4_CAPTURE_REQUEST: + message = std::make_unique(); + break; + case MessageType::PM4_CAPTURE_RESPONSE: + message = std::make_unique(); + break; + case MessageType::DOWNLOAD_FILE_REQUEST: + message = std::make_unique(); + break; + case MessageType::DOWNLOAD_FILE_RESPONSE: + message = std::make_unique(); + break; + default: + conn->Close(); + return absl::InvalidArgumentError(absl::StrCat("Unknown message type: ", type)); + } + + status = message->Deserialize(payload_buffer); + if (!status.ok()) + { + conn->Close(); + return status; + } + + return message; +} + +absl::Status SendMessage(SocketConnection* conn, const ISerializable& message) +{ + if (!conn) + { + return absl::InvalidArgumentError("Provided SocketConnection is null."); + } + + // Serialize the message payload. + Buffer payload_buffer; + absl::Status status = message.Serialize(payload_buffer); + if (!status.ok()) + { + return status; + } + if (payload_buffer.size() > kMaxPayloadSize) + { + return absl::InvalidArgumentError("Serialized payload size exceeds limit."); + } + + // Construct and send the header. + uint32_t net_type = htonl(message.GetMessageType()); + uint32_t net_payload_length = htonl(static_cast(payload_buffer.size())); + const size_t header_size = sizeof(net_type) + sizeof(net_payload_length); + uint8_t header_buffer[header_size]; + std::memcpy(header_buffer, &net_type, sizeof(uint32_t)); + std::memcpy(header_buffer + sizeof(uint32_t), &net_payload_length, sizeof(uint32_t)); + + status = SendBuffer(conn, header_buffer, header_size); + if (!status.ok()) + { + return status; + } + + // Send the payload. + status = SendBuffer(conn, payload_buffer.data(), payload_buffer.size()); + if (!status.ok()) + { + return status; + } + + return absl::OkStatus(); +} + +} // namespace Network diff --git a/network/messages.h b/network/messages.h new file mode 100644 index 000000000..adf6b6e0b --- /dev/null +++ b/network/messages.h @@ -0,0 +1,196 @@ +/* +Copyright 2025 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include "serializable.h" +#include "socket_connection.h" + +namespace Network +{ + +// Helper to write a uint32_t to a buffer. +void WriteUint32ToBuffer(uint32_t value, Buffer& dest); + +// Helper to write a string (length + data) to the buffer. +void WriteStringToBuffer(const std::string& str, Buffer& dest); + +// Helper to read a uint32_t from a buffer. +absl::StatusOr ReadUint32FromBuffer(const Buffer& src, size_t& offset); + +// Helper to read a string (length + data) from the buffer. +absl::StatusOr ReadStringFromBuffer(const Buffer& src, size_t& offset); + +enum class MessageType : uint32_t +{ + HANDSHAKE_REQUEST = 1, + HANDSHAKE_RESPONSE = 2, + PING_MESSAGE = 3, + PONG_MESSAGE = 4, + PM4_CAPTURE_REQUEST = 5, + PM4_CAPTURE_RESPONSE = 6, + DOWNLOAD_FILE_REQUEST = 7, + DOWNLOAD_FILE_RESPONSE = 8 +}; + +class HandShakeMessage : public ISerializable +{ +public: + absl::Status Serialize(Buffer& dest) const override; + absl::Status Deserialize(const Buffer& src) override; + + uint32_t GetMajorVersion() const { return m_major_version; } + uint32_t GetMinorVersion() const { return m_minor_version; } + void SetMajorVersion(uint32_t major) { m_major_version = major; } + void SetMinorVersion(uint32_t minor) { m_minor_version = minor; } + +private: + uint32_t m_major_version; + uint32_t m_minor_version; +}; + +class HandShakeRequest : public HandShakeMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::HANDSHAKE_REQUEST); +}; + +class HandShakeResponse : public HandShakeMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::HANDSHAKE_RESPONSE); +}; + +class EmptyMessage : public ISerializable +{ +public: + absl::Status Serialize(Buffer& dest) const override { return absl::OkStatus(); } + absl::Status Deserialize(const Buffer& src) override { return absl::OkStatus(); } +}; + +class Pm4CaptureRequest : public EmptyMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::PM4_CAPTURE_REQUEST); +}; + +class StringMessage : public ISerializable +{ +public: + absl::Status Serialize(Buffer& dest) const override; + absl::Status Deserialize(const Buffer& src) override; + + const std::string& GetString() const { return m_str; } + void SetString(std::string str) { m_str = std::move(str); } + +private: + std::string m_str; +}; + +class Pm4CaptureResponse : public StringMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::PM4_CAPTURE_RESPONSE); +}; + +class PingMessage : public StringMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::PING_MESSAGE); +}; + +class PongMessage : public StringMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::PONG_MESSAGE); +}; + +class DownloadFileRequest : public StringMessage +{ +public: + uint32_t GetMessageType() const override { return m_type; } + +private: + const uint32_t m_type = static_cast(MessageType::DOWNLOAD_FILE_REQUEST); +}; + +class DownloadFileResponse : public ISerializable +{ +public: + uint32_t GetMessageType() const override { return m_type; } + absl::Status Serialize(Buffer& dest) const override; + absl::Status Deserialize(const Buffer& src) override; + + bool GetFound() const { return m_found; } + void SetFound(bool found) { m_found = found; } + + const std::string& GetErrorReason() const { return m_error_reason; } + void SetErrorReason(std::string error_reason) { m_error_reason = std::move(error_reason); } + + const std::string& GetFilePath() const { return m_file_path; } + void SetFilePath(std::string file_path) { m_file_path = std::move(file_path); } + + const std::string& GetFileSizeStr() const { return m_file_size_str; } + void SetFileSizeStr(std::string file_size_str) { m_file_size_str = std::move(file_size_str); } + +private: + // Flag indicating whether the requested file was found on the server. + bool m_found; + // A description of the error if the download failed. Empty if successful. + std::string m_error_reason; + // The local path where the downloaded file has been saved on the server. + // It can be the same as the requested file path from client. + std::string m_file_path; + // A string representation of the downloaded file's size. + // It avoids to use uint64_t which requires custom implementation for htonll/ntohll. + std::string m_file_size_str; + + const uint32_t m_type = static_cast(MessageType::DOWNLOAD_FILE_RESPONSE); +}; + +// Message Helper Functions (TLV Framing). + +// Helper to receive an exact number of bytes. +absl::Status ReceiveBuffer(SocketConnection* conn, uint8_t* buffer, size_t size); + +// Helper to send an exact number of bytes. +absl::Status SendBuffer(SocketConnection* conn, const uint8_t* buffer, size_t size); + +// Returns a fully-formed message or an error status. +absl::StatusOr> ReceiveMessage(SocketConnection* conn); + +// Sends a full message (header + payload). +absl::Status SendMessage(SocketConnection* conn, const ISerializable& message); + +} // namespace Network \ No newline at end of file diff --git a/network/serializable.h b/network/serializable.h new file mode 100644 index 000000000..d799f1dc1 --- /dev/null +++ b/network/serializable.h @@ -0,0 +1,46 @@ +/* +Copyright 2025 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include + +#include "absl/status/statusor.h" + +namespace Network +{ + +using Buffer = std::vector; + +class ISerializable +{ +public: + virtual ~ISerializable() = default; + + // Returns the specific type identifier for this message. + virtual uint32_t GetMessageType() const = 0; + + // Serializes the object's payload into the destination buffer. + // Returns absl::OkStatus() on success, or an error status on failure. + virtual absl::Status Serialize(Buffer& dest) const = 0; + + // Deserializes the object's state from the source buffer. + // Returns absl::OkStatus() on success, or an error status on failure. + virtual absl::Status Deserialize(const Buffer& src) = 0; +}; + +} // namespace Network \ No newline at end of file diff --git a/network/socket_connection.cc b/network/socket_connection.cc index 325e8f4d1..8ce0903b0 100644 --- a/network/socket_connection.cc +++ b/network/socket_connection.cc @@ -24,7 +24,7 @@ limitations under the License. # define NOMINMAX # include # pragma comment(lib, "Ws2_32.lib") -typedef SSIZE_T ssize_t; +using ssize_t = SSIZE_T; #else # include # include diff --git a/network/socket_connection.h b/network/socket_connection.h index 3211d87d2..7dcf43b7d 100644 --- a/network/socket_connection.h +++ b/network/socket_connection.h @@ -23,12 +23,12 @@ limitations under the License. #ifdef WIN32 // On Windows, a socket is a pointer-sized handle to ensure 32/64-bit compatibility. -typedef uintptr_t SocketType; +using SocketType = uintptr_t; // The value for an invalid socket on Windows is INVALID_SOCKET (~0). constexpr SocketType kInvalidSocketValue = ~static_cast(0); #else // On POSIX systems, a socket is a file descriptor (`int`). -typedef int SocketType; +using SocketType = int; // Functions that return a file descriptor use -1 to indicate an error. constexpr SocketType kInvalidSocketValue = -1; #endif