Skip to content

Commit 01f69ce

Browse files
committed
Continue DataActor in Rust
1 parent 4992cf3 commit 01f69ce

File tree

5 files changed

+276
-18
lines changed

5 files changed

+276
-18
lines changed

crates/adapters/demo/src/data_client.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,13 @@ impl MockDataClient {
125125
.parse::<i32>()
126126
.unwrap();
127127
println!("Received positive value: {value}");
128+
128129
let response = DataResponse::Data(CustomDataResponse::new(
129130
req.request_id,
130131
req.client_id,
131132
Venue::new("http positive stream"),
132133
DataType::new("positive_stream", None),
133-
value,
134+
Arc::new(value),
134135
UnixNanos::new(0),
135136
None,
136137
));
@@ -167,7 +168,7 @@ impl MockDataClient {
167168
req.client_id,
168169
Venue::new("http positive stream"),
169170
DataType::new("positive_stream", None),
170-
value,
171+
Arc::new(value),
171172
UnixNanos::new(0),
172173
None,
173174
));

crates/common/src/actor/data_actor.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -519,12 +519,12 @@ pub trait DataActor: Actor {
519519
fn handle_data_response(&mut self, response: &CustomDataResponse) {
520520
log_received(&response);
521521

522-
if let Some(data) = response.data.downcast_ref::<Vec<&dyn Any>>() {
523-
for d in data {
524-
self.handle_historical_data(d);
522+
if let Some(list) = response.data.downcast_ref::<Vec<&dyn Any>>() {
523+
for item in list {
524+
self.handle_historical_data(item);
525525
}
526-
} else if let Some(data) = response.data.downcast_ref::<&dyn Any>() {
527-
self.handle_historical_data(data);
526+
} else {
527+
self.handle_historical_data(response.data.as_ref());
528528
}
529529
}
530530

crates/common/src/actor/registry.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,10 @@ pub fn get_actor_unchecked<T: Actor>(id: &Ustr) -> &mut T {
8080
let actor = get_actor(id).unwrap_or_else(|| panic!("Actor for {id} not found"));
8181
unsafe { &mut *(actor.get() as *mut _ as *mut T) }
8282
}
83+
84+
// Clears the global actor registry (for test isolation).
85+
#[cfg(test)]
86+
pub fn clear_actor_registry() {
87+
// SAFETY: Clearing registry actors; tests should run single-threaded for actor registry
88+
get_actor_registry().actors.borrow_mut().clear();
89+
}

crates/common/src/actor/tests.rs

Lines changed: 256 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use std::{
2323
sync::Arc,
2424
};
2525

26+
use bytes::Bytes;
2627
use log::LevelFilter;
2728
use nautilus_core::UnixNanos;
2829
use nautilus_model::{
@@ -52,10 +53,11 @@ use crate::{
5253
enums::ComponentState,
5354
logging::{logger::LogGuard, logging_is_initialized},
5455
messages::data::{
55-
BarsResponse, BookResponse, InstrumentsResponse, QuotesResponse, TradesResponse,
56+
BarsResponse, BookResponse, CustomDataResponse, InstrumentsResponse, QuotesResponse,
57+
TradesResponse,
5658
},
5759
msgbus::{
58-
self,
60+
self, MessageBus, get_message_bus,
5961
switchboard::{
6062
MessagingSwitchboard, get_bars_topic, get_book_deltas_topic, get_book_snapshots_topic,
6163
get_custom_topic, get_index_price_topic, get_instrument_close_topic,
@@ -182,15 +184,7 @@ impl DataActor for TestDataActor {
182184
}
183185

184186
fn on_historical_data(&mut self, data: &dyn Any) -> anyhow::Result<()> {
185-
// Capture raw historical data items
186-
// Attempt to downcast to &str or String
187-
if let Some(s) = data.downcast_ref::<&str>() {
188-
self.received_data.push(s.to_string());
189-
} else if let Some(s) = data.downcast_ref::<String>() {
190-
self.received_data.push(s.clone());
191-
} else {
192-
self.received_data.push(format!("{:?}", data));
193-
}
187+
self.received_data.push(format!("{:?}", data));
194188
Ok(())
195189
}
196190

@@ -302,6 +296,9 @@ fn register_data_actor(
302296
trader_id: TraderId,
303297
) -> Ustr {
304298
let config = DataActorConfig::default();
299+
// Ensure clean message bus state for this actor's subscriptions
300+
let bus = get_message_bus();
301+
*bus.borrow_mut() = MessageBus::default();
305302
let mut actor = TestDataActor::new(config, cache, clock);
306303
let actor_id = actor.actor_id;
307304
actor.set_trader_id(trader_id);
@@ -1033,6 +1030,214 @@ fn test_subscribe_and_receive_instrument_close(
10331030
assert_eq!(actor.received_closes[0], stub_instrument_close);
10341031
}
10351032

1033+
// Unsubscribe tests for various data types
1034+
#[rstest]
1035+
fn test_unsubscribe_instruments(
1036+
clock: Rc<RefCell<TestClock>>,
1037+
cache: Rc<RefCell<Cache>>,
1038+
trader_id: TraderId,
1039+
audusd_sim: CurrencyPair,
1040+
gbpusd_sim: CurrencyPair,
1041+
) {
1042+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1043+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1044+
actor.start().unwrap();
1045+
1046+
let venue = Venue::from("SIM");
1047+
actor.subscribe_instruments::<TestDataActor>(venue, None, None);
1048+
1049+
let topic = get_instruments_topic(venue);
1050+
let inst1 = InstrumentAny::CurrencyPair(audusd_sim.clone());
1051+
msgbus::publish(&topic, &inst1);
1052+
let inst2 = InstrumentAny::CurrencyPair(gbpusd_sim.clone());
1053+
msgbus::publish(&topic, &inst2);
1054+
1055+
assert_eq!(actor.received_instruments.len(), 2);
1056+
1057+
actor.unsubscribe_instruments::<TestDataActor>(venue, None, None);
1058+
1059+
let inst3 = InstrumentAny::CurrencyPair(audusd_sim.clone());
1060+
msgbus::publish(&topic, &inst3);
1061+
let inst4 = InstrumentAny::CurrencyPair(gbpusd_sim.clone());
1062+
msgbus::publish(&topic, &inst4);
1063+
1064+
assert_eq!(actor.received_instruments.len(), 2);
1065+
}
1066+
1067+
#[rstest]
1068+
fn test_unsubscribe_instrument(
1069+
clock: Rc<RefCell<TestClock>>,
1070+
cache: Rc<RefCell<Cache>>,
1071+
trader_id: TraderId,
1072+
audusd_sim: CurrencyPair,
1073+
gbpusd_sim: CurrencyPair,
1074+
) {
1075+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1076+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1077+
actor.start().unwrap();
1078+
1079+
actor.subscribe_instrument::<TestDataActor>(audusd_sim.id, None, None);
1080+
1081+
let topic = get_instrument_topic(audusd_sim.id);
1082+
let inst1 = InstrumentAny::CurrencyPair(audusd_sim.clone());
1083+
msgbus::publish(&topic, &inst1);
1084+
let inst2 = InstrumentAny::CurrencyPair(gbpusd_sim.clone());
1085+
msgbus::publish(&topic, &inst2);
1086+
1087+
assert_eq!(actor.received_instruments.len(), 2);
1088+
1089+
actor.unsubscribe_instrument::<TestDataActor>(audusd_sim.id, None, None);
1090+
1091+
let inst3 = InstrumentAny::CurrencyPair(audusd_sim.clone());
1092+
msgbus::publish(&topic, &inst3);
1093+
let inst4 = InstrumentAny::CurrencyPair(gbpusd_sim.clone());
1094+
msgbus::publish(&topic, &inst4);
1095+
1096+
assert_eq!(actor.received_instruments.len(), 2);
1097+
}
1098+
1099+
#[rstest]
1100+
fn test_unsubscribe_mark_prices(
1101+
clock: Rc<RefCell<TestClock>>,
1102+
cache: Rc<RefCell<Cache>>,
1103+
trader_id: TraderId,
1104+
audusd_sim: CurrencyPair,
1105+
) {
1106+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1107+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1108+
actor.start().unwrap();
1109+
1110+
actor.subscribe_mark_prices::<TestDataActor>(audusd_sim.id, None, None);
1111+
1112+
let topic = get_mark_price_topic(audusd_sim.id);
1113+
let mp1 = MarkPriceUpdate::new(
1114+
audusd_sim.id,
1115+
Price::from("1.00000"),
1116+
UnixNanos::from(1),
1117+
UnixNanos::from(2),
1118+
);
1119+
msgbus::publish(&topic, &mp1);
1120+
let mp2 = MarkPriceUpdate::new(
1121+
audusd_sim.id,
1122+
Price::from("1.00010"),
1123+
UnixNanos::from(3),
1124+
UnixNanos::from(4),
1125+
);
1126+
msgbus::publish(&topic, &mp2);
1127+
1128+
assert_eq!(actor.received_mark_prices.len(), 2);
1129+
1130+
actor.unsubscribe_mark_prices::<TestDataActor>(audusd_sim.id, None, None);
1131+
1132+
let mp3 = MarkPriceUpdate::new(
1133+
audusd_sim.id,
1134+
Price::from("1.00020"),
1135+
UnixNanos::from(5),
1136+
UnixNanos::from(6),
1137+
);
1138+
msgbus::publish(&topic, &mp3);
1139+
let mp4 = MarkPriceUpdate::new(
1140+
audusd_sim.id,
1141+
Price::from("1.00030"),
1142+
UnixNanos::from(7),
1143+
UnixNanos::from(8),
1144+
);
1145+
msgbus::publish(&topic, &mp4);
1146+
1147+
assert_eq!(actor.received_mark_prices.len(), 2);
1148+
}
1149+
1150+
#[rstest]
1151+
fn test_unsubscribe_index_prices(
1152+
clock: Rc<RefCell<TestClock>>,
1153+
cache: Rc<RefCell<Cache>>,
1154+
trader_id: TraderId,
1155+
audusd_sim: CurrencyPair,
1156+
) {
1157+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1158+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1159+
actor.start().unwrap();
1160+
1161+
actor.subscribe_index_prices::<TestDataActor>(audusd_sim.id, None, None);
1162+
1163+
let topic = get_index_price_topic(audusd_sim.id);
1164+
let ip1 = IndexPriceUpdate::new(
1165+
audusd_sim.id,
1166+
Price::from("1.00000"),
1167+
UnixNanos::from(1),
1168+
UnixNanos::from(2),
1169+
);
1170+
msgbus::publish(&topic, &ip1);
1171+
1172+
assert_eq!(actor.received_index_prices.len(), 1);
1173+
1174+
actor.unsubscribe_index_prices::<TestDataActor>(audusd_sim.id, None, None);
1175+
1176+
let ip2 = IndexPriceUpdate::new(
1177+
audusd_sim.id,
1178+
Price::from("1.00010"),
1179+
UnixNanos::from(3),
1180+
UnixNanos::from(4),
1181+
);
1182+
msgbus::publish(&topic, &ip2);
1183+
1184+
assert_eq!(actor.received_index_prices.len(), 1);
1185+
}
1186+
1187+
#[rstest]
1188+
fn test_unsubscribe_instrument_status(
1189+
clock: Rc<RefCell<TestClock>>,
1190+
cache: Rc<RefCell<Cache>>,
1191+
trader_id: TraderId,
1192+
stub_instrument_status: InstrumentStatus,
1193+
) {
1194+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1195+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1196+
actor.start().unwrap();
1197+
1198+
let instrument_id = stub_instrument_status.instrument_id;
1199+
actor.subscribe_instrument_status::<TestDataActor>(instrument_id, None, None);
1200+
1201+
let topic = get_instrument_status_topic(instrument_id);
1202+
msgbus::publish(&topic, &stub_instrument_status);
1203+
1204+
assert_eq!(actor.received_status.len(), 1);
1205+
1206+
actor.unsubscribe_instrument_status::<TestDataActor>(instrument_id, None, None);
1207+
1208+
let stub2 = stub_instrument_status.clone();
1209+
msgbus::publish(&topic, &stub2);
1210+
1211+
assert_eq!(actor.received_status.len(), 1);
1212+
}
1213+
1214+
#[rstest]
1215+
fn test_unsubscribe_instrument_close(
1216+
clock: Rc<RefCell<TestClock>>,
1217+
cache: Rc<RefCell<Cache>>,
1218+
trader_id: TraderId,
1219+
stub_instrument_close: InstrumentClose,
1220+
) {
1221+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1222+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1223+
actor.start().unwrap();
1224+
1225+
let instrument_id = stub_instrument_close.instrument_id;
1226+
actor.subscribe_instrument_close::<TestDataActor>(instrument_id, None, None);
1227+
1228+
let topic = get_instrument_close_topic(instrument_id);
1229+
msgbus::publish(&topic, &stub_instrument_close);
1230+
1231+
assert_eq!(actor.received_closes.len(), 1);
1232+
1233+
actor.unsubscribe_instrument_close::<TestDataActor>(instrument_id, None, None);
1234+
1235+
let stub2 = stub_instrument_close.clone();
1236+
msgbus::publish(&topic, &stub2);
1237+
1238+
assert_eq!(actor.received_closes.len(), 1);
1239+
}
1240+
10361241
#[rstest]
10371242
fn test_request_book_snapshot(
10381243
clock: Rc<RefCell<TestClock>>,
@@ -1068,3 +1273,43 @@ fn test_request_book_snapshot(
10681273
assert_eq!(actor.received_books.len(), 1);
10691274
assert_eq!(actor.received_books[0], book);
10701275
}
1276+
1277+
#[rstest]
1278+
fn test_request_data(
1279+
clock: Rc<RefCell<TestClock>>,
1280+
cache: Rc<RefCell<Cache>>,
1281+
trader_id: TraderId,
1282+
) {
1283+
test_logging();
1284+
1285+
let actor_id = register_data_actor(clock.clone(), cache.clone(), trader_id);
1286+
let actor = get_actor_unchecked::<TestDataActor>(&actor_id);
1287+
actor.start().unwrap();
1288+
1289+
// Request custom data
1290+
let data_type = DataType::new("TestData", None);
1291+
let client_id = ClientId::new("TestClient");
1292+
let request_id = actor
1293+
.request_data::<TestDataActor>(data_type.clone(), client_id.clone(), None, None, None, None)
1294+
.unwrap();
1295+
1296+
// Build a response payload containing a String
1297+
let payload = Arc::new(Bytes::from("Data-001"));
1298+
let ts_init = UnixNanos::default();
1299+
// Create response with payload type String
1300+
let response = CustomDataResponse::new(
1301+
request_id,
1302+
client_id.clone(),
1303+
Venue::from("SIM"),
1304+
data_type.clone(),
1305+
payload,
1306+
ts_init,
1307+
None,
1308+
);
1309+
// Publish the response
1310+
msgbus::response(&request_id, response.as_any());
1311+
1312+
// Actor should receive the custom data
1313+
assert_eq!(actor.received_data.len(), 1);
1314+
assert_eq!(actor.received_data[0], "Any { .. }");
1315+
}

crates/common/src/messages/data.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,11 @@ impl CustomDataResponse {
13381338
params,
13391339
}
13401340
}
1341+
1342+
/// Converts the response to a dyn Any trait object for messaging.
1343+
pub fn as_any(&self) -> &dyn Any {
1344+
self
1345+
}
13411346
}
13421347

13431348
#[derive(Clone, Debug)]

0 commit comments

Comments
 (0)