From 961caa74f0d452861e91470a68d9a16c31ecce69 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Fri, 18 Oct 2024 12:45:42 +0300 Subject: [PATCH] msg interrupter --- .../lib/template_receiver/setup_connection.rs | 4 +- roles/tests-integration/tests/common/mod.rs | 9 +- .../tests-integration/tests/common/sniffer.rs | 90 +++++++++++++++++-- .../tests/pool_integration.rs | 37 +++++++- 4 files changed, 127 insertions(+), 13 deletions(-) diff --git a/roles/pool/src/lib/template_receiver/setup_connection.rs b/roles/pool/src/lib/template_receiver/setup_connection.rs index d88064ce8..47931892b 100644 --- a/roles/pool/src/lib/template_receiver/setup_connection.rs +++ b/roles/pool/src/lib/template_receiver/setup_connection.rs @@ -85,8 +85,8 @@ impl ParseUpstreamCommonMessages for SetupConnectionHandler { // let error_code = m.error_code.clone(); let message = SetupConnectionError { flags, - // this error code is currently a hack because there is a lifetime problem with - // `error_code`. + // this error code is currently a hack because there is a lifetime problem with + // `error_code`. error_code: "unsupported-feature-flags" .to_string() .into_bytes() diff --git a/roles/tests-integration/tests/common/mod.rs b/roles/tests-integration/tests/common/mod.rs index 191906984..49ecd1bf2 100644 --- a/roles/tests-integration/tests/common/mod.rs +++ b/roles/tests-integration/tests/common/mod.rs @@ -6,6 +6,7 @@ use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey}; use once_cell::sync::Lazy; use pool_sv2::PoolSv2; use sniffer::Sniffer; +pub use sniffer::{InterruptMessage, MessageDirection}; use std::{ collections::HashSet, convert::TryFrom, @@ -193,8 +194,12 @@ pub fn get_available_address() -> SocketAddr { SocketAddr::from(([127, 0, 0, 1], port)) } -pub async fn start_sniffer(listening_address: SocketAddr, upstream: SocketAddr) -> Sniffer { - let sniffer = Sniffer::new(listening_address, upstream).await; +pub async fn start_sniffer( + listening_address: SocketAddr, + upstream: SocketAddr, + interrupt_messages: Option>, +) -> Sniffer { + let sniffer = Sniffer::new(listening_address, upstream, interrupt_messages).await; let sniffer_clone = sniffer.clone(); tokio::spawn(async move { sniffer_clone.start().await; diff --git a/roles/tests-integration/tests/common/sniffer.rs b/roles/tests-integration/tests/common/sniffer.rs index 46e885e4e..01b78277c 100644 --- a/roles/tests-integration/tests/common/sniffer.rs +++ b/roles/tests-integration/tests/common/sniffer.rs @@ -1,6 +1,6 @@ use async_channel::{Receiver, Sender}; use codec_sv2::{ - framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, + framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, Sv2Frame, }; use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey}; use network_helpers_sv2::noise_connection_tokio::Connection; @@ -13,8 +13,8 @@ use roles_logic_sv2::{ IdentifyTransactionsSuccess, ProvideMissingTransactions, ProvideMissingTransactionsSuccess, SubmitSolution, }, - TemplateDistribution, - TemplateDistribution::CoinbaseOutputDataSize, + PoolMessages, + TemplateDistribution::{self, CoinbaseOutputDataSize}, }, utils::Mutex, }; @@ -22,6 +22,7 @@ use std::{collections::VecDeque, convert::TryInto, net::SocketAddr, sync::Arc}; use tokio::{ net::{TcpListener, TcpStream}, select, + time::sleep, }; type MessageFrame = StandardEitherFrame>; type MsgType = u8; @@ -30,6 +31,7 @@ type MsgType = u8; enum SnifferError { DownstreamClosed, UpstreamClosed, + MessageInterrupted, } /// Allows to intercept messages sent between two roles. @@ -50,17 +52,56 @@ pub struct Sniffer { upstream_address: SocketAddr, downstream_messages: MessagesAggregator, upstream_messages: MessagesAggregator, + interrupt_messages: Vec, +} + +#[derive(Debug, Clone)] +pub struct InterruptMessage { + direction: MessageDirection, + expected_message_type: MsgType, + response_message: PoolMessages<'static>, + response_message_type: MsgType, + break_on: bool, +} + +impl InterruptMessage { + pub fn new( + direction: MessageDirection, + expected_message_type: MsgType, + response_message: PoolMessages<'static>, + response_message_type: MsgType, + break_on: bool, + ) -> Self { + Self { + direction, + expected_message_type, + response_message, + response_message_type, + break_on, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MessageDirection { + ToDownstream, + ToUpstream, } impl Sniffer { /// Creates a new sniffer that listens on the given listening address and connects to the given /// upstream address. - pub async fn new(listening_address: SocketAddr, upstream_address: SocketAddr) -> Self { + pub async fn new( + listening_address: SocketAddr, + upstream_address: SocketAddr, + interrupt_messages: Option>, + ) -> Self { Self { listening_address, upstream_address, downstream_messages: MessagesAggregator::new(), upstream_messages: MessagesAggregator::new(), + interrupt_messages: interrupt_messages.unwrap_or_default(), } } @@ -82,10 +123,18 @@ impl Sniffer { .expect("Failed to create upstream"); let downstream_messages = self.downstream_messages.clone(); let upstream_messages = self.upstream_messages.clone(); + let interrupt_messages = self.interrupt_messages.clone(); let _ = select! { - r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages) => r, - r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages) => r, + r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, interrupt_messages.clone()) => r, + r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, interrupt_messages) => r, }; + // wait a bit so we dont drop the sniffer before the test has finished + sleep(std::time::Duration::from_secs(1)).await; + } + + pub fn list_all_messages(&self) { + println!("Downstream messages: {:?}", self.downstream_messages); + println!("Upstream messages: {:?}", self.upstream_messages); } /// Returns the oldest message sent by downstream. @@ -160,6 +209,7 @@ impl Sniffer { recv: Receiver, send: Sender, downstream_messages: MessagesAggregator, + _interrupt_messages: Vec, ) -> Result<(), SnifferError> { while let Ok(mut frame) = recv.recv().await { let (msg_type, msg) = Self::message_from_frame(&mut frame); @@ -175,13 +225,39 @@ impl Sniffer { recv: Receiver, send: Sender, upstream_messages: MessagesAggregator, + interrupt_messages: Vec, ) -> Result<(), SnifferError> { while let Ok(mut frame) = recv.recv().await { let (msg_type, msg) = Self::message_from_frame(&mut frame); - upstream_messages.add_message(msg_type, msg); + for interrupt_message in interrupt_messages.iter() { + if interrupt_message.direction == MessageDirection::ToDownstream + && interrupt_message.expected_message_type == msg_type + { + let extension_type = 0; + let channel_msg = false; + let frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + interrupt_message.response_message.clone(), + interrupt_message.response_message_type, + extension_type, + channel_msg, + ) + .expect("Failed to create the frame"), + ); + upstream_messages + .add_message(msg_type, interrupt_message.response_message.clone()); + let _ = send.send(frame).await; + if interrupt_message.break_on { + return Err(SnifferError::MessageInterrupted); + } else { + continue; + } + } + } if send.send(frame).await.is_err() { return Err(SnifferError::DownstreamClosed); }; + upstream_messages.add_message(msg_type, msg); } Err(SnifferError::UpstreamClosed) } diff --git a/roles/tests-integration/tests/pool_integration.rs b/roles/tests-integration/tests/pool_integration.rs index 9754ed1d9..9174d629b 100644 --- a/roles/tests-integration/tests/pool_integration.rs +++ b/roles/tests-integration/tests/pool_integration.rs @@ -1,7 +1,11 @@ mod common; +use std::convert::TryInto; + +use common::{InterruptMessage, MessageDirection}; +use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR; use roles_logic_sv2::{ - common_messages_sv2::{Protocol, SetupConnection}, + common_messages_sv2::{Protocol, SetupConnection, SetupConnectionError}, parsers::{CommonMessages, PoolMessages, TemplateDistribution}, }; @@ -15,7 +19,7 @@ async fn success_pool_template_provider_connection() { let tp_addr = common::get_available_address(); let pool_addr = common::get_available_address(); let _tp = common::start_template_provider(tp_addr.port()).await; - let sniffer = common::start_sniffer(sniffer_addr, tp_addr).await; + let sniffer = common::start_sniffer(sniffer_addr, tp_addr, None).await; let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await; // here we assert that the downstream(pool in this case) have sent `SetupConnection` message // with the correct parameters, protocol, flags, min_version and max_version. Note that the @@ -38,3 +42,32 @@ async fn success_pool_template_provider_connection() { assert_tp_message!(&sniffer.next_upstream_message(), NewTemplate); assert_tp_message!(sniffer.next_upstream_message(), SetNewPrevHash); } + +#[tokio::test] +async fn test_sniffer_interrupter() { + let sniffer_addr = common::get_available_address(); + let tp_addr = common::get_available_address(); + let pool_addr = common::get_available_address(); + let _tp = common::start_template_provider(tp_addr.port()).await; + use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS; + let message = + PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError { + flags: 0, + error_code: "unsupported-feature-flags" + .to_string() + .into_bytes() + .try_into() + .unwrap(), + })); + let interrupt_msgs = InterruptMessage::new( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, + message, + MESSAGE_TYPE_SETUP_CONNECTION_ERROR, + true, + ); + let sniffer = common::start_sniffer(sniffer_addr, tp_addr, Some(vec![interrupt_msgs])).await; + let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await; + assert_common_message!(&sniffer.next_downstream_message(), SetupConnection); + assert_common_message!(&sniffer.next_upstream_message(), SetupConnectionError); +}