Skip to content

Commit

Permalink
msg interrupter
Browse files Browse the repository at this point in the history
  • Loading branch information
jbesraa committed Oct 18, 2024
1 parent e53b419 commit 961caa7
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 13 deletions.
4 changes: 2 additions & 2 deletions roles/pool/src/lib/template_receiver/setup_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ impl ParseUpstreamCommonMessages<NoRouting> for SetupConnectionHandler {
// let error_code = m.error_code.clone();
let message = SetupConnectionError {
flags,
// this error code is currently a hack because there is a lifetime problem with
// `error_code`.
// this error code is currently a hack because there is a lifetime problem with
// `error_code`.
error_code: "unsupported-feature-flags"
.to_string()
.into_bytes()
Expand Down
9 changes: 7 additions & 2 deletions roles/tests-integration/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey};
use once_cell::sync::Lazy;
use pool_sv2::PoolSv2;
use sniffer::Sniffer;
pub use sniffer::{InterruptMessage, MessageDirection};
use std::{
collections::HashSet,
convert::TryFrom,
Expand Down Expand Up @@ -193,8 +194,12 @@ pub fn get_available_address() -> SocketAddr {
SocketAddr::from(([127, 0, 0, 1], port))
}

pub async fn start_sniffer(listening_address: SocketAddr, upstream: SocketAddr) -> Sniffer {
let sniffer = Sniffer::new(listening_address, upstream).await;
pub async fn start_sniffer(
listening_address: SocketAddr,
upstream: SocketAddr,
interrupt_messages: Option<Vec<InterruptMessage>>,
) -> Sniffer {
let sniffer = Sniffer::new(listening_address, upstream, interrupt_messages).await;
let sniffer_clone = sniffer.clone();
tokio::spawn(async move {
sniffer_clone.start().await;
Expand Down
90 changes: 83 additions & 7 deletions roles/tests-integration/tests/common/sniffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_channel::{Receiver, Sender};
use codec_sv2::{
framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame,
framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, Sv2Frame,
};
use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey};
use network_helpers_sv2::noise_connection_tokio::Connection;
Expand All @@ -13,15 +13,16 @@ use roles_logic_sv2::{
IdentifyTransactionsSuccess, ProvideMissingTransactions,
ProvideMissingTransactionsSuccess, SubmitSolution,
},
TemplateDistribution,
TemplateDistribution::CoinbaseOutputDataSize,
PoolMessages,
TemplateDistribution::{self, CoinbaseOutputDataSize},
},
utils::Mutex,
};
use std::{collections::VecDeque, convert::TryInto, net::SocketAddr, sync::Arc};
use tokio::{
net::{TcpListener, TcpStream},
select,
time::sleep,
};
type MessageFrame = StandardEitherFrame<AnyMessage<'static>>;
type MsgType = u8;
Expand All @@ -30,6 +31,7 @@ type MsgType = u8;
enum SnifferError {
DownstreamClosed,
UpstreamClosed,
MessageInterrupted,
}

/// Allows to intercept messages sent between two roles.
Expand All @@ -50,17 +52,56 @@ pub struct Sniffer {
upstream_address: SocketAddr,
downstream_messages: MessagesAggregator,
upstream_messages: MessagesAggregator,
interrupt_messages: Vec<InterruptMessage>,
}

#[derive(Debug, Clone)]
pub struct InterruptMessage {
direction: MessageDirection,
expected_message_type: MsgType,
response_message: PoolMessages<'static>,
response_message_type: MsgType,
break_on: bool,
}

impl InterruptMessage {
pub fn new(
direction: MessageDirection,
expected_message_type: MsgType,
response_message: PoolMessages<'static>,
response_message_type: MsgType,
break_on: bool,
) -> Self {
Self {
direction,
expected_message_type,
response_message,
response_message_type,
break_on,
}
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageDirection {
ToDownstream,
ToUpstream,
}

impl Sniffer {
/// Creates a new sniffer that listens on the given listening address and connects to the given
/// upstream address.
pub async fn new(listening_address: SocketAddr, upstream_address: SocketAddr) -> Self {
pub async fn new(
listening_address: SocketAddr,
upstream_address: SocketAddr,
interrupt_messages: Option<Vec<InterruptMessage>>,
) -> Self {
Self {
listening_address,
upstream_address,
downstream_messages: MessagesAggregator::new(),
upstream_messages: MessagesAggregator::new(),
interrupt_messages: interrupt_messages.unwrap_or_default(),
}
}

Expand All @@ -82,10 +123,18 @@ impl Sniffer {
.expect("Failed to create upstream");
let downstream_messages = self.downstream_messages.clone();
let upstream_messages = self.upstream_messages.clone();
let interrupt_messages = self.interrupt_messages.clone();
let _ = select! {
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages) => r,
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, interrupt_messages.clone()) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, interrupt_messages) => r,
};
// wait a bit so we dont drop the sniffer before the test has finished
sleep(std::time::Duration::from_secs(1)).await;
}

pub fn list_all_messages(&self) {
println!("Downstream messages: {:?}", self.downstream_messages);
println!("Upstream messages: {:?}", self.upstream_messages);
}

/// Returns the oldest message sent by downstream.
Expand Down Expand Up @@ -160,6 +209,7 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
downstream_messages: MessagesAggregator,
_interrupt_messages: Vec<InterruptMessage>,
) -> Result<(), SnifferError> {
while let Ok(mut frame) = recv.recv().await {
let (msg_type, msg) = Self::message_from_frame(&mut frame);
Expand All @@ -175,13 +225,39 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
upstream_messages: MessagesAggregator,
interrupt_messages: Vec<InterruptMessage>,
) -> Result<(), SnifferError> {
while let Ok(mut frame) = recv.recv().await {
let (msg_type, msg) = Self::message_from_frame(&mut frame);
upstream_messages.add_message(msg_type, msg);
for interrupt_message in interrupt_messages.iter() {
if interrupt_message.direction == MessageDirection::ToDownstream
&& interrupt_message.expected_message_type == msg_type
{
let extension_type = 0;
let channel_msg = false;
let frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
interrupt_message.response_message.clone(),
interrupt_message.response_message_type,
extension_type,
channel_msg,
)
.expect("Failed to create the frame"),
);
upstream_messages
.add_message(msg_type, interrupt_message.response_message.clone());
let _ = send.send(frame).await;
if interrupt_message.break_on {
return Err(SnifferError::MessageInterrupted);
} else {
continue;
}
}
}
if send.send(frame).await.is_err() {
return Err(SnifferError::DownstreamClosed);
};
upstream_messages.add_message(msg_type, msg);
}
Err(SnifferError::UpstreamClosed)
}
Expand Down
37 changes: 35 additions & 2 deletions roles/tests-integration/tests/pool_integration.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
mod common;

use std::convert::TryInto;

use common::{InterruptMessage, MessageDirection};
use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR;
use roles_logic_sv2::{
common_messages_sv2::{Protocol, SetupConnection},
common_messages_sv2::{Protocol, SetupConnection, SetupConnectionError},
parsers::{CommonMessages, PoolMessages, TemplateDistribution},
};

Expand All @@ -15,7 +19,7 @@ async fn success_pool_template_provider_connection() {
let tp_addr = common::get_available_address();
let pool_addr = common::get_available_address();
let _tp = common::start_template_provider(tp_addr.port()).await;
let sniffer = common::start_sniffer(sniffer_addr, tp_addr).await;
let sniffer = common::start_sniffer(sniffer_addr, tp_addr, None).await;
let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await;
// here we assert that the downstream(pool in this case) have sent `SetupConnection` message
// with the correct parameters, protocol, flags, min_version and max_version. Note that the
Expand All @@ -38,3 +42,32 @@ async fn success_pool_template_provider_connection() {
assert_tp_message!(&sniffer.next_upstream_message(), NewTemplate);
assert_tp_message!(sniffer.next_upstream_message(), SetNewPrevHash);
}

#[tokio::test]
async fn test_sniffer_interrupter() {
let sniffer_addr = common::get_available_address();
let tp_addr = common::get_available_address();
let pool_addr = common::get_available_address();
let _tp = common::start_template_provider(tp_addr.port()).await;
use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS;
let message =
PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError {
flags: 0,
error_code: "unsupported-feature-flags"
.to_string()
.into_bytes()
.try_into()
.unwrap(),
}));
let interrupt_msgs = InterruptMessage::new(
MessageDirection::ToDownstream,
MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS,
message,
MESSAGE_TYPE_SETUP_CONNECTION_ERROR,
true,
);
let sniffer = common::start_sniffer(sniffer_addr, tp_addr, Some(vec![interrupt_msgs])).await;
let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await;
assert_common_message!(&sniffer.next_downstream_message(), SetupConnection);
assert_common_message!(&sniffer.next_upstream_message(), SetupConnectionError);
}

0 comments on commit 961caa7

Please sign in to comment.