diff --git a/src/common/transport.h b/src/common/transport.h new file mode 100644 index 0000000000000..3a7987b102416 --- /dev/null +++ b/src/common/transport.h @@ -0,0 +1,189 @@ +// Copyright (c) 2024 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_COMMON_TRANSPORT_H +#define BITCOIN_COMMON_TRANSPORT_H + +#include +#include +#include +#include +#include +#include + +/** Transport layer version */ +enum class TransportProtocolType : uint8_t { + DETECTING, //!< Peer could be v1 or v2 + V1, //!< Unencrypted, plaintext protocol + V2, //!< BIP324 protocol +}; + +/** Convert TransportProtocolType enum to a string value */ +std::string TransportTypeAsString(TransportProtocolType transport_type); + +/** Transport protocol agnostic message container. + * Ideally it should only contain receive time, payload, + * type and size. + */ +class CNetMessage +{ +public: + DataStream m_recv; //!< received message data + std::chrono::microseconds m_time{0}; //!< time of message receipt + uint32_t m_message_size{0}; //!< size of the payload + uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) + std::string m_type; + + explicit CNetMessage(DataStream&& recv_in) : m_recv(std::move(recv_in)) {} + // Only one CNetMessage object will exist for the same message on either + // the receive or processing queue. For performance reasons we therefore + // delete the copy constructor and assignment operator to avoid the + // possibility of copying CNetMessage objects. + CNetMessage(CNetMessage&&) = default; + CNetMessage(const CNetMessage&) = delete; + CNetMessage& operator=(CNetMessage&&) = default; + CNetMessage& operator=(const CNetMessage&) = delete; + + /** Compute total memory usage of this object (own memory + any dynamic memory). */ + size_t GetMemoryUsage() const noexcept; +}; + +struct CSerializedNetMsg { + CSerializedNetMsg() = default; + CSerializedNetMsg(CSerializedNetMsg&&) = default; + CSerializedNetMsg& operator=(CSerializedNetMsg&&) = default; + // No implicit copying, only moves. + CSerializedNetMsg(const CSerializedNetMsg& msg) = delete; + CSerializedNetMsg& operator=(const CSerializedNetMsg&) = delete; + + CSerializedNetMsg Copy() const + { + CSerializedNetMsg copy; + copy.data = data; + copy.m_type = m_type; + return copy; + } + + std::vector data; + std::string m_type; + + /** Compute total memory usage of this object (own memory + any dynamic memory). */ + size_t GetMemoryUsage() const noexcept; +}; + +/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */ +class Transport { +public: + virtual ~Transport() = default; + + struct Info + { + TransportProtocolType transport_type; + std::optional session_id; + }; + + /** Retrieve information about this transport. */ + virtual Info GetInfo() const noexcept = 0; + + // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol + // agnostic CNetMessage (message type & payload) objects. + + /** Returns true if the current message is complete (so GetReceivedMessage can be called). */ + virtual bool ReceivedMessageComplete() const = 0; + + /** Feed wire bytes to the transport. + * + * @return false if some bytes were invalid, in which case the transport can't be used anymore. + * + * Consumed bytes are chopped off the front of msg_bytes. + */ + virtual bool ReceivedBytes(Span& msg_bytes) = 0; + + /** Retrieve a completed message from transport. + * + * This can only be called when ReceivedMessageComplete() is true. + * + * If reject_message=true is returned the message itself is invalid, but (other than false + * returned by ReceivedBytes) the transport is not in an inconsistent state. + */ + virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; + + // 2. Sending side functions, for converting messages into bytes to be sent over the wire. + + /** Set the next message to send. + * + * If no message can currently be set (perhaps because the previous one is not yet done being + * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and + * possibly moved-from) and true is returned. + */ + virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; + + /** Return type for GetBytesToSend, consisting of: + * - Span to_send: span of bytes to be sent over the wire (possibly empty). + * - bool more: whether there will be more bytes to be sent after the ones in to_send are + * all sent (as signaled by MarkBytesSent()). + * - const std::string& m_type: message type on behalf of which this is being sent + * ("" for bytes that are not on behalf of any message). + */ + using BytesToSend = std::tuple< + Span /*to_send*/, + bool /*more*/, + const std::string& /*m_type*/ + >; + + /** Get bytes to send on the wire, if any, along with other information about it. + * + * As a const function, it does not modify the transport's observable state, and is thus safe + * to be called multiple times. + * + * @param[in] have_next_message If true, the "more" return value reports whether more will + * be sendable after a SetMessageToSend call. It is set by the caller when they know + * they have another message ready to send, and only care about what happens + * after that. The have_next_message argument only affects this "more" return value + * and nothing else. + * + * Effectively, there are three possible outcomes about whether there are more bytes + * to send: + * - Yes: the transport itself has more bytes to send later. For example, for + * V1Transport this happens during the sending of the header of a + * message, when there is a non-empty payload that follows. + * - No: the transport itself has no more bytes to send, but will have bytes to + * send if handed a message through SetMessageToSend. In V1Transport this + * happens when sending the payload of a message. + * - Blocked: the transport itself has no more bytes to send, and is also incapable + * of sending anything more at all now, if it were handed another + * message to send. This occurs in V2Transport before the handshake is + * complete, as the encryption ciphers are not set up for sending + * messages before that point. + * + * The boolean 'more' is true for Yes, false for Blocked, and have_next_message + * controls what is returned for No. + * + * @return a BytesToSend object. The to_send member returned acts as a stream which is only + * ever appended to. This means that with the exception of MarkBytesSent (which pops + * bytes off the front of later to_sends), operations on the transport can only append + * to what is being returned. Also note that m_type and to_send refer to data that is + * internal to the transport, and calling any non-const function on this object may + * invalidate them. + */ + virtual BytesToSend GetBytesToSend(bool have_next_message) const noexcept = 0; + + /** Report how many bytes returned by the last GetBytesToSend() have been sent. + * + * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. + * + * If bytes_sent=0, this call has no effect. + */ + virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; + + /** Return the memory usage of this transport attributable to buffered data to send. */ + virtual size_t GetSendMemoryUsage() const noexcept = 0; + + // 3. Miscellaneous functions. + + /** Whether upon disconnections, a reconnect with V1 is warranted. */ + virtual bool ShouldReconnectV1() const noexcept = 0; +}; + +#endif // BITCOIN_COMMON_TRANSPORT_H diff --git a/src/net.h b/src/net.h index fc096ff7b8680..a009bb683ecf5 100644 --- a/src/net.h +++ b/src/net.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -111,29 +112,6 @@ struct AddedNodeInfo { class CNodeStats; class CClientUIInterface; -struct CSerializedNetMsg { - CSerializedNetMsg() = default; - CSerializedNetMsg(CSerializedNetMsg&&) = default; - CSerializedNetMsg& operator=(CSerializedNetMsg&&) = default; - // No implicit copying, only moves. - CSerializedNetMsg(const CSerializedNetMsg& msg) = delete; - CSerializedNetMsg& operator=(const CSerializedNetMsg&) = delete; - - CSerializedNetMsg Copy() const - { - CSerializedNetMsg copy; - copy.data = data; - copy.m_type = m_type; - return copy; - } - - std::vector data; - std::string m_type; - - /** Compute total memory usage of this object (own memory + any dynamic memory). */ - size_t GetMemoryUsage() const noexcept; -}; - /** * Look up IP addresses from all interfaces on the machine and add them to the * list of local addresses to self-advertise. @@ -222,148 +200,6 @@ class CNodeStats std::string m_session_id; }; - -/** Transport protocol agnostic message container. - * Ideally it should only contain receive time, payload, - * type and size. - */ -class CNetMessage -{ -public: - DataStream m_recv; //!< received message data - std::chrono::microseconds m_time{0}; //!< time of message receipt - uint32_t m_message_size{0}; //!< size of the payload - uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) - std::string m_type; - - explicit CNetMessage(DataStream&& recv_in) : m_recv(std::move(recv_in)) {} - // Only one CNetMessage object will exist for the same message on either - // the receive or processing queue. For performance reasons we therefore - // delete the copy constructor and assignment operator to avoid the - // possibility of copying CNetMessage objects. - CNetMessage(CNetMessage&&) = default; - CNetMessage(const CNetMessage&) = delete; - CNetMessage& operator=(CNetMessage&&) = default; - CNetMessage& operator=(const CNetMessage&) = delete; - - /** Compute total memory usage of this object (own memory + any dynamic memory). */ - size_t GetMemoryUsage() const noexcept; -}; - -/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */ -class Transport { -public: - virtual ~Transport() = default; - - struct Info - { - TransportProtocolType transport_type; - std::optional session_id; - }; - - /** Retrieve information about this transport. */ - virtual Info GetInfo() const noexcept = 0; - - // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol - // agnostic CNetMessage (message type & payload) objects. - - /** Returns true if the current message is complete (so GetReceivedMessage can be called). */ - virtual bool ReceivedMessageComplete() const = 0; - - /** Feed wire bytes to the transport. - * - * @return false if some bytes were invalid, in which case the transport can't be used anymore. - * - * Consumed bytes are chopped off the front of msg_bytes. - */ - virtual bool ReceivedBytes(Span& msg_bytes) = 0; - - /** Retrieve a completed message from transport. - * - * This can only be called when ReceivedMessageComplete() is true. - * - * If reject_message=true is returned the message itself is invalid, but (other than false - * returned by ReceivedBytes) the transport is not in an inconsistent state. - */ - virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0; - - // 2. Sending side functions, for converting messages into bytes to be sent over the wire. - - /** Set the next message to send. - * - * If no message can currently be set (perhaps because the previous one is not yet done being - * sent), returns false, and msg will be unmodified. Otherwise msg is enqueued (and - * possibly moved-from) and true is returned. - */ - virtual bool SetMessageToSend(CSerializedNetMsg& msg) noexcept = 0; - - /** Return type for GetBytesToSend, consisting of: - * - Span to_send: span of bytes to be sent over the wire (possibly empty). - * - bool more: whether there will be more bytes to be sent after the ones in to_send are - * all sent (as signaled by MarkBytesSent()). - * - const std::string& m_type: message type on behalf of which this is being sent - * ("" for bytes that are not on behalf of any message). - */ - using BytesToSend = std::tuple< - Span /*to_send*/, - bool /*more*/, - const std::string& /*m_type*/ - >; - - /** Get bytes to send on the wire, if any, along with other information about it. - * - * As a const function, it does not modify the transport's observable state, and is thus safe - * to be called multiple times. - * - * @param[in] have_next_message If true, the "more" return value reports whether more will - * be sendable after a SetMessageToSend call. It is set by the caller when they know - * they have another message ready to send, and only care about what happens - * after that. The have_next_message argument only affects this "more" return value - * and nothing else. - * - * Effectively, there are three possible outcomes about whether there are more bytes - * to send: - * - Yes: the transport itself has more bytes to send later. For example, for - * V1Transport this happens during the sending of the header of a - * message, when there is a non-empty payload that follows. - * - No: the transport itself has no more bytes to send, but will have bytes to - * send if handed a message through SetMessageToSend. In V1Transport this - * happens when sending the payload of a message. - * - Blocked: the transport itself has no more bytes to send, and is also incapable - * of sending anything more at all now, if it were handed another - * message to send. This occurs in V2Transport before the handshake is - * complete, as the encryption ciphers are not set up for sending - * messages before that point. - * - * The boolean 'more' is true for Yes, false for Blocked, and have_next_message - * controls what is returned for No. - * - * @return a BytesToSend object. The to_send member returned acts as a stream which is only - * ever appended to. This means that with the exception of MarkBytesSent (which pops - * bytes off the front of later to_sends), operations on the transport can only append - * to what is being returned. Also note that m_type and to_send refer to data that is - * internal to the transport, and calling any non-const function on this object may - * invalidate them. - */ - virtual BytesToSend GetBytesToSend(bool have_next_message) const noexcept = 0; - - /** Report how many bytes returned by the last GetBytesToSend() have been sent. - * - * bytes_sent cannot exceed to_send.size() of the last GetBytesToSend() result. - * - * If bytes_sent=0, this call has no effect. - */ - virtual void MarkBytesSent(size_t bytes_sent) noexcept = 0; - - /** Return the memory usage of this transport attributable to buffered data to send. */ - virtual size_t GetSendMemoryUsage() const noexcept = 0; - - // 3. Miscellaneous functions. - - /** Whether upon disconnections, a reconnect with V1 is warranted. */ - virtual bool ShouldReconnectV1() const noexcept = 0; -}; - class V1Transport final : public Transport { private: diff --git a/src/node/connection_types.cpp b/src/node/connection_types.cpp index 5e4dc5bf2ef94..2d8dbec2f131c 100644 --- a/src/node/connection_types.cpp +++ b/src/node/connection_types.cpp @@ -2,6 +2,7 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include #include #include diff --git a/src/node/connection_types.h b/src/node/connection_types.h index a911b95f7e917..5e1abcace67d1 100644 --- a/src/node/connection_types.h +++ b/src/node/connection_types.h @@ -6,7 +6,6 @@ #define BITCOIN_NODE_CONNECTION_TYPES_H #include -#include /** Different types of connections to a peer. This enum encapsulates the * information we have available at the time of opening or accepting the @@ -80,14 +79,4 @@ enum class ConnectionType { /** Convert ConnectionType enum to a string value */ std::string ConnectionTypeAsString(ConnectionType conn_type); -/** Transport layer version */ -enum class TransportProtocolType : uint8_t { - DETECTING, //!< Peer could be v1 or v2 - V1, //!< Unencrypted, plaintext protocol - V2, //!< BIP324 protocol -}; - -/** Convert TransportProtocolType enum to a string value */ -std::string TransportTypeAsString(TransportProtocolType transport_type); - #endif // BITCOIN_NODE_CONNECTION_TYPES_H diff --git a/src/sv2/CMakeLists.txt b/src/sv2/CMakeLists.txt index d6e44842e8c87..e61f2f3560834 100644 --- a/src/sv2/CMakeLists.txt +++ b/src/sv2/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(bitcoin_sv2 STATIC EXCLUDE_FROM_ALL noise.cpp + transport.cpp ) target_link_libraries(bitcoin_sv2 diff --git a/src/sv2/messages.h b/src/sv2/messages.h new file mode 100644 index 0000000000000..fbbe68d63d6cf --- /dev/null +++ b/src/sv2/messages.h @@ -0,0 +1,209 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_SV2_MESSAGES_H +#define BITCOIN_SV2_MESSAGES_H + +#include // for CSerializedNetMsg and CNetMessage +#include +#include +#include +#include + +namespace node { +/** + * A type used as the message length field in stratum v2 messages. + */ +using u24_t = uint8_t[3]; + +/** + * All the stratum v2 message types handled by the template provider. + */ +enum class Sv2MsgType : uint8_t { + COINBASE_OUTPUT_DATA_SIZE = 0x70, +}; + +/** + * Set the coinbase outputs data len for the outputs that the client wants to add to the coinbase. + * The template provider MUST NOT provide NewWork messages which would represent consensus-invalid blocks once this + * additional size — along with a maximally-sized (100 byte) coinbase field — is added. + */ +struct Sv2CoinbaseOutputDataSizeMsg +{ + /** + * The default message type value for this Stratum V2 message. + */ + static constexpr auto m_msg_type = Sv2MsgType::COINBASE_OUTPUT_DATA_SIZE; + + /** + * The maximum additional serialized bytes which the pool will add in coinbase transaction outputs. + */ + uint32_t m_coinbase_output_max_additional_size; + + template + void Serialize(Stream& s) const + { + s << m_coinbase_output_max_additional_size; + }; + + + template + void Unserialize(Stream& s) + { + s >> m_coinbase_output_max_additional_size; + } +}; + +/** + * Header for all stratum v2 messages. Each header must contain the message type, + * the length of the serialized message and a 2 byte extension field currently + * not utilised by the template provider. + */ +class Sv2NetHeader +{ +public: + /** + * Unique identifier of the message. + */ + Sv2MsgType m_msg_type; + + /** + * Serialized length of the message. + */ + uint32_t m_msg_len; + + Sv2NetHeader() = default; + explicit Sv2NetHeader(Sv2MsgType msg_type, uint32_t msg_len) : m_msg_type{msg_type}, m_msg_len{msg_len} {}; + + template + void Serialize(Stream& s) const + { + // The template provider currently does not use the extension_type field, + // but the field is still required for all headers. + uint16_t extension_type = 0; + + u24_t msg_len; + msg_len[2] = (m_msg_len >> 16) & 0xff; + msg_len[1] = (m_msg_len >> 8) & 0xff; + msg_len[0] = m_msg_len & 0xff; + + s << extension_type + << static_cast(m_msg_type) + << msg_len; + }; + + template + void Unserialize(Stream& s) + { + // Ignore the first 2 bytes (extension type) as the template provider currently doesn't + // interpret this field. + s.ignore(2); + + uint8_t msg_type; + s >> msg_type; + m_msg_type = static_cast(msg_type); + + u24_t msg_len_bytes; + for (unsigned int i = 0; i < sizeof(u24_t); ++i) { + s >> msg_len_bytes[i]; + } + + m_msg_len = msg_len_bytes[2]; + m_msg_len = m_msg_len << 8 | msg_len_bytes[1]; + m_msg_len = m_msg_len << 8 | msg_len_bytes[0]; + } +}; + +/** + * The networked form for all stratum v2 messages, contains a header and a serialized + * payload from a referenced stratum v2 message. + */ +class Sv2NetMsg +{ +public: + Sv2MsgType m_msg_type; + std::vector m_msg; + + explicit Sv2NetMsg(const Sv2MsgType msg_type, const std::vector&& msg) : m_msg_type{msg_type}, m_msg{msg} {}; + + // Unwrap CSerializedNetMsg + Sv2NetMsg(CSerializedNetMsg&& net_msg) + { + Assume(net_msg.m_type == ""); + DataStream ss(MakeByteSpan(net_msg.data)); + Unserialize(ss); + }; + + // Unwrap CNetMsg + Sv2NetMsg(CNetMessage net_msg) + { + Unserialize(net_msg.m_recv); + }; + + operator CSerializedNetMsg() + { + CSerializedNetMsg net_msg; + net_msg.m_type = ""; + DataStream ser; + Serialize(ser); + net_msg.data.resize(ser.size()); + std::transform(ser.begin(), ser.end(), net_msg.data.begin(), + [](std::byte b) { return static_cast(b); }); + return net_msg; + } + + operator CNetMessage() + { + DataStream msg; + Serialize(msg); + CNetMessage ret{std::move(msg)}; + return ret; + } + + /** + * Serializes the message M and sets an Sv2 network header. + * @throws std::ios_base or std::out_of_range errors. + */ + template + explicit Sv2NetMsg(const M& msg) + { + m_msg_type = msg.m_msg_type; + + // Serialize the sv2 message. + VectorWriter{m_msg, 0, msg}; + } + + unsigned char* data() { return m_msg.data(); } + size_t size() { return m_msg.size(); } + + operator Sv2NetHeader() + { + Sv2NetHeader hdr; + hdr.m_msg_type = m_msg_type; + hdr.m_msg_len = static_cast(m_msg.size()); + return hdr; + } + + template + void Unserialize(Stream& s) + { + uint8_t msg_type; + s >> msg_type; + m_msg_type = static_cast(msg_type); + m_msg.resize(s.size()); + s.read(MakeWritableByteSpan(m_msg)); + } + + template + void Serialize(Stream& s) const + { + s << static_cast(m_msg_type); + s.write(MakeByteSpan(m_msg)); + } + +}; + +} + +#endif // BITCOIN_SV2_MESSAGES_H diff --git a/src/sv2/transport.cpp b/src/sv2/transport.cpp new file mode 100644 index 0000000000000..37a6e36ba19e0 --- /dev/null +++ b/src/sv2/transport.cpp @@ -0,0 +1,494 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +Sv2Transport::Sv2Transport(CKey static_key, Sv2SignatureNoiseMessage certificate) noexcept + : m_cipher{Sv2Cipher(std::move(static_key), std::move(certificate))}, m_initiating{false}, + m_recv_state{RecvState::HANDSHAKE_STEP_1}, + m_send_state{SendState::HANDSHAKE_STEP_2}, + m_message{Sv2NetMsg(Sv2NetHeader{})} +{ + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session receive state -> %s\n", + RecvStateAsString(m_recv_state)); +} + +Sv2Transport::Sv2Transport(CKey static_key, XOnlyPubKey responder_authority_key) noexcept + : m_cipher{Sv2Cipher(std::move(static_key), responder_authority_key)}, m_initiating{true}, + m_recv_state{RecvState::HANDSHAKE_STEP_2}, + m_send_state{SendState::HANDSHAKE_STEP_1}, + m_message{Sv2NetMsg(Sv2NetHeader{})} +{ + /** Start sending immediately since we're the initiator of the connection. + This only happens in test code. + */ + LOCK(m_send_mutex); + StartSendingHandshake(); + +} + +void Sv2Transport::SetReceiveState(RecvState recv_state) noexcept +{ + AssertLockHeld(m_recv_mutex); + // Enforce allowed state transitions. + switch (m_recv_state) { + case RecvState::HANDSHAKE_STEP_1: + Assume(recv_state == RecvState::HANDSHAKE_STEP_2); + break; + case RecvState::HANDSHAKE_STEP_2: + Assume(recv_state == RecvState::APP); + break; + case RecvState::APP: + Assume(recv_state == RecvState::APP_READY); + break; + case RecvState::APP_READY: + Assume(recv_state == RecvState::APP); + break; + } + // Change state. + m_recv_state = recv_state; + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session receive state -> %s\n", + RecvStateAsString(m_recv_state)); + +} + +void Sv2Transport::SetSendState(SendState send_state) noexcept +{ + AssertLockHeld(m_send_mutex); + // Enforce allowed state transitions. + switch (m_send_state) { + case SendState::HANDSHAKE_STEP_1: + Assume(send_state == SendState::HANDSHAKE_STEP_2); + break; + case SendState::HANDSHAKE_STEP_2: + Assume(send_state == SendState::READY); + break; + case SendState::READY: + Assume(false); // Final state + break; + } + // Change state. + m_send_state = send_state; + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Noise session send state -> %s\n", + SendStateAsString(m_send_state)); +} + +void Sv2Transport::StartSendingHandshake() noexcept +{ + AssertLockHeld(m_send_mutex); + AssertLockNotHeld(m_recv_mutex); + Assume(m_send_state == SendState::HANDSHAKE_STEP_1); + Assume(m_send_buffer.empty()); + + m_send_buffer.resize(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + m_cipher.GetHandshakeState().WriteMsgEphemeralPK(MakeWritableByteSpan(m_send_buffer)); + + m_send_state = SendState::HANDSHAKE_STEP_2; +} + +void Sv2Transport::SendHandshakeReply() noexcept +{ + AssertLockHeld(m_send_mutex); + AssertLockHeld(m_recv_mutex); + Assume(m_send_state == SendState::HANDSHAKE_STEP_2); + + Assume(m_send_buffer.empty()); + m_send_buffer.resize(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + m_cipher.GetHandshakeState().WriteMsgES(MakeWritableByteSpan(m_send_buffer)); + + m_cipher.FinishHandshake(); + + // We can send and receive stuff now, unless the other side hangs up + SetSendState(SendState::READY); + Assume(m_recv_state == RecvState::HANDSHAKE_STEP_2); + SetReceiveState(RecvState::APP); +} + +Transport::BytesToSend Sv2Transport::GetBytesToSend(bool have_next_message) const noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + const std::string dummy_m_type; // m_type is set to "" when wrapping Sv2NetMsg + + Assume(m_send_pos <= m_send_buffer.size()); + return { + Span{m_send_buffer}.subspan(m_send_pos), + // We only have more to send after the current m_send_buffer if there is a (next) + // message to be sent, and we're capable of sending packets. */ + have_next_message && m_send_state == SendState::READY, + dummy_m_type + }; +} + +void Sv2Transport::MarkBytesSent(size_t bytes_sent) noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + // if (m_send_state == SendState::AWAITING_KEY && m_send_pos == 0 && bytes_sent > 0) { + // LogPrint(BCLog::NET, "start sending v2 handshake to peer=%d\n", m_nodeid); + // } + + m_send_pos += bytes_sent; + Assume(m_send_pos <= m_send_buffer.size()); + // Wipe the buffer when everything is sent. + if (m_send_pos == m_send_buffer.size()) { + m_send_pos = 0; + ClearShrink(m_send_buffer); + } +} + +bool Sv2Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + // We only allow adding a new message to be sent when in the READY state (so the packet cipher + // is available) and the send buffer is empty. This limits the number of messages in the send + // buffer to just one, and leaves the responsibility for queueing them up to the caller. + if (m_send_state != SendState::READY) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "SendState is not READY\n"); + return false; + } + + if (!m_send_buffer.empty()) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send buffer is not empty\n"); + return false; + } + + // The Sv2NetMsg is wrapped inside a dummy CSerializedNetMsg, extract it: + Sv2NetMsg sv2_msg(std::move(msg)); + // Reconstruct the header: + Sv2NetHeader hdr(sv2_msg.m_msg_type, sv2_msg.size()); + + // Construct ciphertext in send buffer. + const size_t encrypted_msg_size = Sv2Cipher::EncryptedMessageSize(sv2_msg.size()); + m_send_buffer.resize(SV2_HEADER_ENCRYPTED_SIZE + encrypted_msg_size); + Span buffer_span{MakeWritableByteSpan(m_send_buffer)}; + + // Header + DataStream ss_header_plain{}; + ss_header_plain << hdr; + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(ss_header_plain)); + Span header_encrypted{buffer_span.subspan(0, SV2_HEADER_ENCRYPTED_SIZE)}; + if (!m_cipher.EncryptMessage(ss_header_plain, header_encrypted)) { + return false; + } + + // Payload + Span payload_plain = MakeByteSpan(sv2_msg); + // TODO: truncate very long messages, about 100 bytes at the start and end + // is probably enough for most debugging. + // LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload_plain)); + Span payload_encrypted{buffer_span.subspan(SV2_HEADER_ENCRYPTED_SIZE, encrypted_msg_size)}; + if (!m_cipher.EncryptMessage(payload_plain, payload_encrypted)) { + return false; + } + + // Release memory (not needed with std::move above) + // ClearShrink(msg.data); + + return true; +} + +size_t Sv2Transport::GetSendMemoryUsage() const noexcept +{ + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + + return sizeof(m_send_buffer) + memusage::DynamicUsage(m_send_buffer); +} + +bool Sv2Transport::ReceivedBytes(Span& msg_bytes) noexcept +{ + AssertLockNotHeld(m_send_mutex); + AssertLockNotHeld(m_recv_mutex); + /** How many bytes to allocate in the receive buffer at most above what is received so far. */ + static constexpr size_t MAX_RESERVE_AHEAD = 256 * 1024; // TODO: reduce to NOISE_MAX_CHUNK_SIZE? + + LOCK(m_recv_mutex); + // Process the provided bytes in msg_bytes in a loop. In each iteration a nonzero number of + // bytes (decided by GetMaxBytesToProcess) are taken from the beginning om msg_bytes, and + // appended to m_recv_buffer. Then, depending on the receiver state, one of the + // ProcessReceived*Bytes functions is called to process the bytes in that buffer. + while (!msg_bytes.empty()) { + // Decide how many bytes to copy from msg_bytes to m_recv_buffer. + size_t max_read = GetMaxBytesToProcess(); + + // Reserve space in the buffer if there is not enough. + if (m_recv_buffer.size() + std::min(msg_bytes.size(), max_read) > m_recv_buffer.capacity()) { + switch (m_recv_state) { + case RecvState::HANDSHAKE_STEP_1: + m_recv_buffer.reserve(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + break; + case RecvState::HANDSHAKE_STEP_2: + m_recv_buffer.reserve(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + break; + case RecvState::APP: { + // During states where a packet is being received, as much as is expected but never + // more than MAX_RESERVE_AHEAD bytes in addition to what is received so far. + // This means attackers that want to cause us to waste allocated memory are limited + // to MAX_RESERVE_AHEAD above the largest allowed message contents size, and to + // MAX_RESERVE_AHEAD more than they've actually sent us. + size_t alloc_add = std::min(max_read, msg_bytes.size() + MAX_RESERVE_AHEAD); + m_recv_buffer.reserve(m_recv_buffer.size() + alloc_add); + break; + } + case RecvState::APP_READY: + // The buffer is empty in this state. + Assume(m_recv_buffer.empty()); + break; + } + } + + // Can't read more than provided input. + max_read = std::min(msg_bytes.size(), max_read); + // Copy data to buffer. + m_recv_buffer.insert(m_recv_buffer.end(), UCharCast(msg_bytes.data()), UCharCast(msg_bytes.data() + max_read)); + msg_bytes = msg_bytes.subspan(max_read); + + // Process data in the buffer. + switch (m_recv_state) { + + case RecvState::HANDSHAKE_STEP_1: + if (!ProcessReceivedEphemeralKeyBytes()) return false; + break; + + case RecvState::HANDSHAKE_STEP_2: + if (!ProcessReceivedHandshakeReplyBytes()) return false; + break; + + case RecvState::APP: + if (!ProcessReceivedPacketBytes()) return false; + break; + + case RecvState::APP_READY: + return true; + + } + // Make sure we have made progress before continuing. + Assume(max_read > 0); + } + + return true; +} + +bool Sv2Transport::ProcessReceivedEphemeralKeyBytes() noexcept +{ + AssertLockHeld(m_recv_mutex); + AssertLockNotHeld(m_send_mutex); + Assume(m_recv_state == RecvState::HANDSHAKE_STEP_1); + Assume(m_recv_buffer.size() <= Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + + if (m_recv_buffer.size() == Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE) { + // Other side's key has been fully received, and can now be Diffie-Hellman + // combined with our key. This is act 1 of the Noise Protocol handshake. + // TODO handle failure + // TODO: MakeByteSpan instead of MakeWritableByteSpan + m_cipher.GetHandshakeState().ReadMsgEphemeralPK(MakeWritableByteSpan(m_recv_buffer)); + m_recv_buffer.clear(); + SetReceiveState(RecvState::HANDSHAKE_STEP_2); + + LOCK(m_send_mutex); + Assume(m_send_buffer.size() == 0); + + // Send our act 2 handshake + SendHandshakeReply(); + } else { + // We still have to receive more key bytes. + } + return true; +} + +bool Sv2Transport::ProcessReceivedHandshakeReplyBytes() noexcept +{ + AssertLockHeld(m_recv_mutex); + AssertLockNotHeld(m_send_mutex); + Assume(m_recv_state == RecvState::HANDSHAKE_STEP_2); + Assume(m_recv_buffer.size() <= Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + + if (m_recv_buffer.size() == Sv2HandshakeState::HANDSHAKE_STEP2_SIZE) { + // TODO handle failure + // TODO: MakeByteSpan instead of MakeWritableByteSpan + bool res = m_cipher.GetHandshakeState().ReadMsgES(MakeWritableByteSpan(m_recv_buffer)); + if (!res) return false; + m_recv_buffer.clear(); + m_cipher.FinishHandshake(); + SetReceiveState(RecvState::APP); + + LOCK(m_send_mutex); + Assume(m_send_buffer.size() == 0); + + SetSendState(SendState::READY); + } else { + // We still have to receive more key bytes. + } + return true; +} + +size_t Sv2Transport::GetMaxBytesToProcess() noexcept +{ + AssertLockHeld(m_recv_mutex); + switch (m_recv_state) { + case RecvState::HANDSHAKE_STEP_1: + // In this state, we only allow the 64-byte key into the receive buffer. + Assume(m_recv_buffer.size() <= Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + return Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE - m_recv_buffer.size(); + case RecvState::HANDSHAKE_STEP_2: + // In this state, we only allow the handshake reply into the receive buffer. + Assume(m_recv_buffer.size() <= Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + return Sv2HandshakeState::HANDSHAKE_STEP2_SIZE - m_recv_buffer.size(); + case RecvState::APP: + // Decode a packet. Process the header first, + // so that we know where the current packet ends (and we don't process bytes from the next + // packet yet). Then, process the ciphertext bytes of the current packet. + if (m_recv_buffer.size() < SV2_HEADER_ENCRYPTED_SIZE) { + return SV2_HEADER_ENCRYPTED_SIZE - m_recv_buffer.size(); + } else { + // When transitioning from receiving the packet length to receiving its ciphertext, + // the encrypted header is left in the receive buffer. + size_t expanded_size_with_header = SV2_HEADER_ENCRYPTED_SIZE + Sv2Cipher::EncryptedMessageSize(m_header.m_msg_len); + return expanded_size_with_header - m_recv_buffer.size(); + } + case RecvState::APP_READY: + // No bytes can be processed until GetMessage() is called. + return 0; + } + Assume(false); // unreachable + return 0; +} + +bool Sv2Transport::ProcessReceivedPacketBytes() noexcept +{ + AssertLockHeld(m_recv_mutex); + Assume(m_recv_state == RecvState::APP); + + // The maximum permitted decrypted payload size for a packet + static constexpr size_t MAX_CONTENTS_LEN = 16777215; // 24 bit unsigned; + + Assume(m_recv_buffer.size() <= SV2_HEADER_ENCRYPTED_SIZE || m_header.m_msg_len > 0); + + if (m_recv_buffer.size() == SV2_HEADER_ENCRYPTED_SIZE) { + // Header received, decrypt it. + std::array header_plain; + if (!m_cipher.DecryptMessage(MakeWritableByteSpan(m_recv_buffer), header_plain)) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Failed to decrypt header\n"); + return false; + } + + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(header_plain)); + + // Decode header + DataStream ss_header{header_plain}; + node::Sv2NetHeader header; + ss_header >> header; + m_header = std::move(header); + + // TODO: 16 MB is pretty large, maybe set lower limits for most or all message types? + if (m_header.m_msg_len > MAX_CONTENTS_LEN) { + LogTrace(BCLog::SV2, "Packet too large (%u bytes)\n", m_header.m_msg_len); + return false; + } + + // Disconnect for empty messages (TODO: check the spec) + if (m_header.m_msg_len == 0) { + LogTrace(BCLog::SV2, "Empty message\n"); + return false; + } + LogTrace(BCLog::SV2, "Expecting %d bytes payload (plain)\n", m_header.m_msg_len); + } else if (m_recv_buffer.size() > SV2_HEADER_ENCRYPTED_SIZE && + m_recv_buffer.size() == SV2_HEADER_ENCRYPTED_SIZE + Sv2Cipher::EncryptedMessageSize(m_header.m_msg_len)) { + /** Ciphertext received: decrypt into decode_buffer and deserialize into m_message. + * + * Note that it is impossible to reach this branch without hitting the + * branch above first, as GetMaxBytesToProcess only allows up to + * SV2_HEADER_ENCRYPTED_SIZE into the buffer before that point. */ + std::vector payload; + payload.resize(m_header.m_msg_len); + + Span recv_span{MakeWritableByteSpan(m_recv_buffer).subspan(SV2_HEADER_ENCRYPTED_SIZE)}; + if (!m_cipher.DecryptMessage(recv_span, MakeWritableByteSpan(payload))) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Failed to decrypt message payload\n"); + return false; + } + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload)); + + // Wipe the receive buffer where the next packet will be received into. + ClearShrink(m_recv_buffer); + + Sv2NetMsg message{m_header.m_msg_type, std::move(payload)}; + m_message = std::move(message); + + // At this point we have a valid message decrypted into m_message. + SetReceiveState(RecvState::APP_READY); + } else { + // We either have less than 22 bytes, so we don't know the packet's length yet, or more + // than 22 bytes but less than the packet's full ciphertext. Wait until those arrive. + LogTrace(BCLog::SV2, "Waiting for more bytes\n"); + } + return true; +} + +bool Sv2Transport::ReceivedMessageComplete() const noexcept +{ + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); + + return m_recv_state == RecvState::APP_READY; +} + +CNetMessage Sv2Transport::GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) noexcept +{ + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); + Assume(m_recv_state == RecvState::APP_READY); + + SetReceiveState(RecvState::APP); + return m_message; // Sv2NetMsg is wrapped in a CNetMessage +} + +Transport::Info Sv2Transport::GetInfo() const noexcept +{ + return {.transport_type = TransportProtocolType::V1, .session_id = {}}; +} + +std::string RecvStateAsString(Sv2Transport::RecvState state) +{ + switch (state) { + case Sv2Transport::RecvState::HANDSHAKE_STEP_1: + return "HANDSHAKE_STEP_1"; + case Sv2Transport::RecvState::HANDSHAKE_STEP_2: + return "HANDSHAKE_STEP_2"; + case Sv2Transport::RecvState::APP: + return "APP"; + case Sv2Transport::RecvState::APP_READY: + return "APP_READY"; + } // no default case, so the compiler can warn about missing cases + + assert(false); +} + +std::string SendStateAsString(Sv2Transport::SendState state) +{ + switch (state) { + case Sv2Transport::SendState::HANDSHAKE_STEP_1: + return "HANDSHAKE_STEP_1"; + case Sv2Transport::SendState::HANDSHAKE_STEP_2: + return "HANDSHAKE_STEP_2"; + case Sv2Transport::SendState::READY: + return "READY"; + } // no default case, so the compiler can warn about missing cases + + assert(false); +} diff --git a/src/sv2/transport.h b/src/sv2/transport.h new file mode 100644 index 0000000000000..b8948e71a9f42 --- /dev/null +++ b/src/sv2/transport.h @@ -0,0 +1,194 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_SV2_TRANSPORT_H +#define BITCOIN_SV2_TRANSPORT_H + +#include +#include +#include +#include + +static constexpr size_t SV2_HEADER_PLAIN_SIZE{6}; +static constexpr size_t SV2_HEADER_ENCRYPTED_SIZE{SV2_HEADER_PLAIN_SIZE + Poly1305::TAGLEN}; + +using node::Sv2NetHeader; +using node::Sv2NetMsg; + +class Sv2Transport final : public Transport +{ +public: + + // The sender side and receiver side of Sv2Transport are state machines that are transitioned + // through, based on what has been received. The receive state corresponds to the contents of, + // and bytes received to, the receive buffer. The send state controls what can be appended to + // the send buffer and what can be sent from it. + + /** State type that defines the current contents of the receive buffer and/or how the next + * received bytes added to it will be interpreted. + * + * Diagram: + * + * start(responder) + * | start(initiator) + * | | /---------\ + * | | | | + * v v v | + * HANDSHAKE_STEP_1 -> HANDSHAKE_STEP_2 -> APP -> APP_READY + */ + enum class RecvState : uint8_t { + /** Handshake Act 1: -> E */ + HANDSHAKE_STEP_1, + + /** Handshake Act 2: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE */ + HANDSHAKE_STEP_2, + + /** Application packet. + * + * A packet is received, and decrypted/verified. If that succeeds, the + * state becomes APP_READY and the decrypted message is kept in m_message + * until it is retrieved by GetMessage(). */ + APP, + + /** Nothing (an application packet is available for GetMessage()). + * + * Nothing can be received in this state. When the message is retrieved + * by GetMessage(), the state becomes APP again. */ + APP_READY, + }; + + /** State type that controls the sender side. + * + * Diagram: + * + * start(initiator) + * | start(responder) + * | | + * | | + * v v + * HANDSHAKE_STEP_1 -> HANDSHAKE_STEP_2 -> READY + */ + enum class SendState : uint8_t { + /** Handshake Act 1: -> E */ + HANDSHAKE_STEP_1, + + /** Handshake Act 2: <- e, ee, s, es, SIGNATURE_NOISE_MESSAGE */ + HANDSHAKE_STEP_2, + + /** Normal sending state. + * + * In this state, the ciphers are initialized, so packets can be sent. + * In this state a message can be provided if the send buffer is empty. */ + READY, + }; + +private: + + /** Cipher state. */ + Sv2Cipher m_cipher; + + /** Whether we are the initiator side. */ + const bool m_initiating; + + /** Lock for receiver-side fields. */ + mutable Mutex m_recv_mutex ACQUIRED_BEFORE(m_send_mutex); + /** Receive buffer; meaning is determined by m_recv_state. */ + std::vector m_recv_buffer GUARDED_BY(m_recv_mutex); + /** AAD expected in next received packet (currently used only for garbage). */ + std::vector m_recv_aad GUARDED_BY(m_recv_mutex); + /** Current receiver state. */ + RecvState m_recv_state GUARDED_BY(m_recv_mutex); + + /** Lock for sending-side fields. If both sending and receiving fields are accessed, + * m_recv_mutex must be acquired before m_send_mutex. */ + mutable Mutex m_send_mutex ACQUIRED_AFTER(m_recv_mutex); + /** The send buffer; meaning is determined by m_send_state. */ + std::vector m_send_buffer GUARDED_BY(m_send_mutex); + /** How many bytes from the send buffer have been sent so far. */ + uint32_t m_send_pos GUARDED_BY(m_send_mutex) {0}; + /** The garbage sent, or to be sent (MAYBE_V1 and AWAITING_KEY state only). */ + std::vector m_send_garbage GUARDED_BY(m_send_mutex); + /** Type of the message being sent. */ + std::string m_send_type GUARDED_BY(m_send_mutex); + /** Current sender state. */ + SendState m_send_state GUARDED_BY(m_send_mutex); + + /** Change the receive state. */ + void SetReceiveState(RecvState recv_state) noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + /** Change the send state. */ + void SetSendState(SendState send_state) noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex); + /** Given a packet's contents, find the message type (if valid), and strip it from contents. */ + static std::optional GetMessageType(Span& contents) noexcept; + /** Determine how many received bytes can be processed in one go (not allowed in V1 state). */ + size_t GetMaxBytesToProcess() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + /** Put our ephemeral public key in the send buffer. */ + void StartSendingHandshake() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex, !m_recv_mutex); + /** Put second part of the handshake in the send buffer. */ + void SendHandshakeReply() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex, m_recv_mutex); + /** Process bytes in m_recv_buffer, while in HANDSHAKE_STEP_1 state. */ + bool ProcessReceivedEphemeralKeyBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex, !m_send_mutex); + /** Process bytes in m_recv_buffer, while in HANDSHAKE_STEP_2 state. */ + bool ProcessReceivedHandshakeReplyBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex, !m_send_mutex); + + /** Process bytes in m_recv_buffer, while in VERSION/APP state. */ + bool ProcessReceivedPacketBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + + /** In APP, the decrypted header, if m_recv_buffer.size() >= + * SV2_HEADER_ENCRYPTED_SIZE. Unspecified otherwise. */ + Sv2NetHeader m_header GUARDED_BY(m_recv_mutex); + /* In APP_READY the last retrieved message. Unspecified otherwise */ + Sv2NetMsg m_message GUARDED_BY(m_recv_mutex); + +public: + /** Construct a Stratum v2 transport as the initiator + * + * @param[in] static_key a securely generated key + + */ + Sv2Transport(CKey static_key, XOnlyPubKey responder_authority_key) noexcept; + + /** Construct a Stratum v2 transport as the responder + * + * @param[in] static_key a securely generated key + + */ + Sv2Transport(CKey static_key, Sv2SignatureNoiseMessage certificate) noexcept; + + // Receive side functions. + bool ReceivedMessageComplete() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); + bool ReceivedBytes(Span& msg_bytes) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex, !m_send_mutex); + + CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); + + // Send side functions. + bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + + BytesToSend GetBytesToSend(bool have_next_message) const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + + void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + size_t GetSendMemoryUsage() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex); + + // Miscellaneous functions. + bool ShouldReconnectV1() const noexcept override { return false; }; + Info GetInfo() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); + + // Test only + uint256 NoiseHash() const { return m_cipher.GetHash(); }; + RecvState GetRecvState() EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); + return m_recv_state; + }; + SendState GetSendState() EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex) { + AssertLockNotHeld(m_send_mutex); + LOCK(m_send_mutex); + return m_send_state; + }; +}; + +/** Convert TransportProtocolType enum to a string value */ +std::string RecvStateAsString(Sv2Transport::RecvState state); +std::string SendStateAsString(Sv2Transport::SendState state); + +#endif // BITCOIN_SV2_TRANSPORT_H diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index dbbe5228267ae..0c2fc409453f3 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -179,6 +179,7 @@ if(WITH_SV2) target_sources(test_bitcoin PRIVATE sv2_noise_tests.cpp + sv2_transport_tests.cpp ) target_link_libraries(test_bitcoin bitcoin_sv2) endif() diff --git a/src/test/sv2_transport_tests.cpp b/src/test/sv2_transport_tests.cpp new file mode 100644 index 0000000000000..9a353d231ad3e --- /dev/null +++ b/src/test/sv2_transport_tests.cpp @@ -0,0 +1,389 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +using namespace std::literals; +using node::Sv2NetMsg; +using node::Sv2CoinbaseOutputDataSizeMsg; +using node::Sv2MsgType; + +BOOST_FIXTURE_TEST_SUITE(sv2_transport_tests, RegTestingSetup) + +namespace { + +/** A class for scenario-based tests of Sv2Transport + * + * Each Sv2TransportTester encapsulates a Sv2Transport (the one being tested), + * and can be told to interact with it. To do so, it also encapsulates a Sv2Cipher + * to act as the other side. A second Sv2Transport is not used, as doing so would + * not permit scenarios that involve sending invalid data. + */ +class Sv2TransportTester +{ + FastRandomContext& m_rng; + std::unique_ptr m_transport; //!< Sv2Transport being tested + std::unique_ptr m_peer_cipher; //!< Cipher to help with the other side + bool m_test_initiator; //!< Whether m_transport is the initiator (true) or responder (false) + + std::vector m_to_send; //!< Bytes we have queued up to send to m_transport-> + std::vector m_received; //!< Bytes we have received from m_transport-> + std::deque m_msg_to_send; //!< Messages to be sent *by* m_transport to us. + +public: + /** Construct a tester object. test_initiator: whether the tested transport is initiator. */ + + explicit Sv2TransportTester(FastRandomContext& rng, bool test_initiator) : m_rng{rng}, m_test_initiator(test_initiator) + { + auto initiator_static_key{GenerateRandomKey()}; + auto responder_static_key{GenerateRandomKey()}; + auto responder_authority_key{GenerateRandomKey()}; + + // Create certificates + auto epoch_now = std::chrono::system_clock::now().time_since_epoch(); + uint16_t version = 0; + uint32_t valid_from = static_cast(std::chrono::duration_cast(epoch_now).count()); + uint32_t valid_to = std::numeric_limits::max(); + + auto responder_certificate = Sv2SignatureNoiseMessage(version, valid_from, valid_to, + XOnlyPubKey(responder_static_key.GetPubKey()), responder_authority_key); + + if (test_initiator) { + m_transport = std::make_unique(initiator_static_key, XOnlyPubKey(responder_authority_key.GetPubKey())); + m_peer_cipher = std::make_unique(std::move(responder_static_key), std::move(responder_certificate)); + } else { + m_transport = std::make_unique(responder_static_key, responder_certificate); + m_peer_cipher = std::make_unique(std::move(initiator_static_key), XOnlyPubKey(responder_authority_key.GetPubKey())); + } + } + + /** Data type returned by Interact: + * + * - std::nullopt: transport error occurred + * - otherwise: a vector of + * - std::nullopt: invalid message received + * - otherwise: a Sv2NetMsg retrieved + */ + using InteractResult = std::optional>>; + + void LogProgress(bool should_progress, bool progress, bool pretend_no_progress) { + if (!should_progress) { + BOOST_TEST_MESSAGE("[Interact] !should_progress"); + } else if (!progress) { + BOOST_TEST_MESSAGE("[Interact] should_progress && !progress"); + } else if (pretend_no_progress) { + BOOST_TEST_MESSAGE("[Interact] pretend !progress"); + } + } + + /** Send/receive scheduled/available bytes and messages. + * + * This is the only function that interacts with the transport being tested; everything else is + * scheduling things done by Interact(), or processing things learned by it. + */ + InteractResult Interact() + { + std::vector> ret; + while (true) { + bool progress{false}; + // Send bytes from m_to_send to the transport. + if (!m_to_send.empty()) { + size_t n_bytes_to_send = 1 + m_rng.randrange(m_to_send.size()); + BOOST_TEST_MESSAGE(strprintf("[Interact] send %d of %d bytes", n_bytes_to_send, m_to_send.size())); + Span to_send = Span{m_to_send}.first(n_bytes_to_send); + size_t old_len = to_send.size(); + if (!m_transport->ReceivedBytes(to_send)) { + BOOST_TEST_MESSAGE("[Interact] transport error"); + return std::nullopt; + } + if (old_len != to_send.size()) { + progress = true; + m_to_send.erase(m_to_send.begin(), m_to_send.begin() + (old_len - to_send.size())); + } + } + // Retrieve messages received by the transport. + bool should_progress = m_transport->ReceivedMessageComplete(); + bool pretend_no_progress = m_rng.randbool(); + LogProgress(should_progress, progress, pretend_no_progress); + if (should_progress && (!progress || pretend_no_progress)) { + bool dummy_reject_message = false; + CNetMessage net_msg = m_transport->GetReceivedMessage(std::chrono::microseconds(0), dummy_reject_message); + Sv2NetMsg msg(std::move(net_msg)); + ret.emplace_back(std::move(msg)); + progress = true; + } + // Enqueue a message to be sent by the transport to us. + should_progress = !m_msg_to_send.empty(); + pretend_no_progress = m_rng.randbool(); + LogProgress(should_progress, progress, pretend_no_progress); + if (should_progress && (!progress || pretend_no_progress)) { + BOOST_TEST_MESSAGE("Shoehorn into CSerializedNetMsg"); + CSerializedNetMsg msg{m_msg_to_send.front()}; + BOOST_TEST_MESSAGE("Call SetMessageToSend"); + if (m_transport->SetMessageToSend(msg)) { + BOOST_TEST_MESSAGE("Finished SetMessageToSend"); + m_msg_to_send.pop_front(); + progress = true; + } + } + // Receive bytes from the transport. + const auto& [recv_bytes, _more, _m_type] = m_transport->GetBytesToSend(!m_msg_to_send.empty()); + should_progress = !recv_bytes.empty(); + pretend_no_progress = m_rng.randbool(); + LogProgress(should_progress, progress, pretend_no_progress); + if (should_progress && (!progress || pretend_no_progress)) { + size_t to_receive = 1 + m_rng.randrange(recv_bytes.size()); + BOOST_TEST_MESSAGE(strprintf("[Interact] receive %d of %d bytes", to_receive, recv_bytes.size())); + m_received.insert(m_received.end(), recv_bytes.begin(), recv_bytes.begin() + to_receive); + progress = true; + m_transport->MarkBytesSent(to_receive); + } + if (!progress) break; + } + return ret; + } + + /** Schedule bytes to be sent to the transport. */ + void Send(Span data) + { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send: %s\n", HexStr(data)); + m_to_send.insert(m_to_send.end(), data.begin(), data.end()); + } + + /** Schedule bytes to be sent to the transport. */ + void Send(Span data) { Send(MakeUCharSpan(data)); } + + /** Schedule a message to be sent to us by the transport. */ + void AddMessage(Sv2NetMsg msg) + { + m_msg_to_send.push_back(std::move(msg)); + } + + /** + * If we are the initiator, the send buffer should contain our ephemeral public + * key. Pass this to the peer cipher and clear the buffer. + * + * If we are the responder, put the peer ephemeral public key on our receive buffer. + */ + void ProcessHandshake1() { + if (m_test_initiator) { + BOOST_REQUIRE(m_received.size() == Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + m_peer_cipher->GetHandshakeState().ReadMsgEphemeralPK(MakeWritableByteSpan(m_received)); + m_received.clear(); + } else { + BOOST_REQUIRE(m_to_send.empty()); + m_to_send.resize(Sv2HandshakeState::ELLSWIFT_PUB_KEY_SIZE); + m_peer_cipher->GetHandshakeState().WriteMsgEphemeralPK(MakeWritableByteSpan(m_to_send)); + } + + } + + /** Expect key to have been received from transport and process it. + * + * Many other Sv2TransportTester functions cannot be called until after + * ProcessHandshake2() has been called, as no encryption keys are set up before that point. + */ + void ProcessHandshake2() + { + if (m_test_initiator) { + BOOST_REQUIRE(m_to_send.empty()); + + // Have the peer cypher write the second part of the handshake into our receive buffer + m_to_send.resize(Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + m_peer_cipher->GetHandshakeState().WriteMsgES(MakeWritableByteSpan(m_to_send)); + + // At this point the peer is done with the handshake: + m_peer_cipher->FinishHandshake(); + } else { + BOOST_REQUIRE(m_received.size() == Sv2HandshakeState::HANDSHAKE_STEP2_SIZE); + BOOST_REQUIRE(m_peer_cipher->GetHandshakeState().ReadMsgES(MakeWritableByteSpan(m_received))); + m_received.clear(); + + m_peer_cipher->FinishHandshake(); + } + } + + /** Schedule an encrypted packet with specified content to be sent to transport + * (only after ReceiveKey). */ + void SendPacket(Sv2NetMsg msg) + { + // TODO: randomly break stuff + + std::vector ciphertext; + const size_t encrypted_payload_size = Sv2Cipher::EncryptedMessageSize(msg.size()); + ciphertext.resize(SV2_HEADER_ENCRYPTED_SIZE + encrypted_payload_size); + Span buffer_span{MakeWritableByteSpan(ciphertext)}; + + // Header + DataStream ss_header_plain{}; + ss_header_plain << Sv2NetHeader(msg); + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Header: %s\n", HexStr(ss_header_plain)); + Span header_encrypted{buffer_span.subspan(0, SV2_HEADER_ENCRYPTED_SIZE)}; + BOOST_REQUIRE(m_peer_cipher->EncryptMessage(ss_header_plain, header_encrypted)); + + // Payload + Span payload_plain = MakeByteSpan(msg); + // TODO: truncate very long messages, about 100 bytes at the start and end + // is probably enough for most debugging. + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Payload: %s\n", HexStr(payload_plain)); + Span payload_encrypted{buffer_span.subspan(SV2_HEADER_ENCRYPTED_SIZE, encrypted_payload_size)}; + BOOST_REQUIRE(m_peer_cipher->EncryptMessage(payload_plain, payload_encrypted)); + + // Schedule it for sending. + Send(ciphertext); + } + + /** Expect application packet to have been received, with specified message type and payload. + * (only after ReceiveKey). */ + void ReceiveMessage(Sv2NetMsg expected_msg) + { + // When processing a packet, at least enough bytes for its length descriptor must be received. + BOOST_REQUIRE(m_received.size() >= SV2_HEADER_ENCRYPTED_SIZE); + + auto header_encrypted{MakeWritableByteSpan(m_received).subspan(0, SV2_HEADER_ENCRYPTED_SIZE)}; + std::array header_plain; + BOOST_REQUIRE(m_peer_cipher->DecryptMessage(header_encrypted, header_plain)); + + // Decode header + DataStream ss_header{header_plain}; + node::Sv2NetHeader header; + ss_header >> header; + + BOOST_CHECK(header.m_msg_type == expected_msg.m_msg_type); + + size_t expanded_size = Sv2Cipher::EncryptedMessageSize(header.m_msg_len); + BOOST_REQUIRE(m_received.size() >= SV2_HEADER_ENCRYPTED_SIZE + expanded_size); + + Span encrypted_payload{MakeWritableByteSpan(m_received).subspan(SV2_HEADER_ENCRYPTED_SIZE, expanded_size)}; + Span payload = encrypted_payload.subspan(0, header.m_msg_len); + + BOOST_REQUIRE(m_peer_cipher->DecryptMessage(encrypted_payload, payload)); + + std::vector decode_buffer; + decode_buffer.resize(header.m_msg_len); + + std::transform(payload.begin(), payload.end(), decode_buffer.begin(), + [](std::byte b) { return static_cast(b); }); + + // TODO: clear the m_received we used + + Sv2NetMsg message{header.m_msg_type, std::move(decode_buffer)}; + + // TODO: compare payload + } + + /** Test whether the transport's m_hash matches the other side. */ + void CompareHash() const + { + BOOST_REQUIRE(m_transport); + BOOST_CHECK(m_transport->NoiseHash() == m_peer_cipher->GetHash()); + } + + void CheckRecvState(Sv2Transport::RecvState state) { + BOOST_REQUIRE(m_transport); + BOOST_CHECK_EQUAL(RecvStateAsString(m_transport->GetRecvState()), RecvStateAsString(state)); + } + + void CheckSendState(Sv2Transport::SendState state) { + BOOST_REQUIRE(m_transport); + BOOST_CHECK_EQUAL(SendStateAsString(m_transport->GetSendState()), SendStateAsString(state)); + } + + /** Introduce a bit error in the data scheduled to be sent. */ + // void Damage() + // { + // BOOST_TEST_MESSAGE("[Interact] introduce a bit error"); + // m_to_send[m_rng.randrange(m_to_send.size())] ^= (uint8_t{1} << m_rng.randrange(8)); + // } +}; + +} // namespace + +BOOST_AUTO_TEST_CASE(sv2_transport_initiator_test) +{ + // A mostly normal scenario, testing a transport in initiator mode. + // Interact() introduces randomness, so run multiple times + for (int i = 0; i < 10; ++i) { + BOOST_TEST_MESSAGE(strprintf("\nIteration %d (initiator)", i)); + Sv2TransportTester tester(m_rng, true); + // As the initiator, our ephemeral public key is immedidately put + // onto the buffer. + tester.CheckSendState(Sv2Transport::SendState::HANDSHAKE_STEP_2); + tester.CheckRecvState(Sv2Transport::RecvState::HANDSHAKE_STEP_2); + auto ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.ProcessHandshake1(); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.ProcessHandshake2(); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.CheckSendState(Sv2Transport::SendState::READY); + tester.CheckRecvState(Sv2Transport::RecvState::APP); + tester.CompareHash(); + } +} + +BOOST_AUTO_TEST_CASE(sv2_transport_responder_test) +{ + // Normal scenario, with a transport in responder node. + for (int i = 0; i < 10; ++i) { + BOOST_TEST_MESSAGE(strprintf("\nIteration %d (responder)", i)); + Sv2TransportTester tester(m_rng, false); + tester.CheckSendState(Sv2Transport::SendState::HANDSHAKE_STEP_2); + tester.CheckRecvState(Sv2Transport::RecvState::HANDSHAKE_STEP_1); + tester.ProcessHandshake1(); + auto ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->empty()); + tester.CheckSendState(Sv2Transport::SendState::READY); + tester.CheckRecvState(Sv2Transport::RecvState::APP); + + // Have the test cypher process our handshake reply + tester.ProcessHandshake2(); + tester.CompareHash(); + + // Handshake complete, have the initiator send us a message: + Sv2CoinbaseOutputDataSizeMsg body{4000}; + Sv2NetMsg msg{body}; + BOOST_REQUIRE(msg.m_msg_type == Sv2MsgType::COINBASE_OUTPUT_DATA_SIZE); + + tester.SendPacket(msg); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->size() == 1); + BOOST_CHECK((*ret)[0] && + (*ret)[0]->m_msg_type == Sv2MsgType::COINBASE_OUTPUT_DATA_SIZE); + + tester.CompareHash(); + + // Send a message back to the initiator + tester.AddMessage(msg); + ret = tester.Interact(); + BOOST_REQUIRE(ret && ret->size() == 0); + tester.ReceiveMessage(msg); + + // TODO: send / receive message larger than the chunk size + } +} + + +BOOST_AUTO_TEST_SUITE_END()