Skip to content
This repository was archived by the owner on May 27, 2025. It is now read-only.

Commit 04d600a

Browse files
committed
Support Multiple Upstream WebSocket Connections for HA Setups
1 parent 1460721 commit 04d600a

File tree

3 files changed

+112
-20
lines changed

3 files changed

+112
-20
lines changed

src/main.rs

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct Args {
3636
listen_addr: SocketAddr,
3737

3838
#[arg(long, env, help = "WebSocket URI of the upstream server to connect to")]
39-
upstream_ws: Uri,
39+
upstream_ws: Vec<Uri>,
4040

4141
#[arg(
4242
long,
@@ -144,6 +144,14 @@ async fn main() {
144144
.expect("failed to setup Prometheus endpoint")
145145
}
146146

147+
// Validate that we have at least one upstream URI
148+
if args.upstream_ws.is_empty() {
149+
error!(message = "no upstream URIs provided");
150+
panic!("No upstream URIs provided");
151+
}
152+
153+
info!(message = "using upstream URIs", uris = ?args.upstream_ws);
154+
147155
let metrics = Arc::new(Metrics::default());
148156
let metrics_clone = metrics.clone();
149157

@@ -165,14 +173,29 @@ async fn main() {
165173
};
166174

167175
let token = CancellationToken::new();
176+
let mut subscriber_tasks = Vec::new();
177+
178+
// Start a subscriber for each upstream URI
179+
for (index, uri) in args.upstream_ws.iter().enumerate() {
180+
let uri_clone = uri.clone();
181+
let listener_clone = listener.clone();
182+
let token_clone = token.clone();
183+
let metrics_clone = metrics.clone();
184+
185+
let mut subscriber = WebsocketSubscriber::new(
186+
uri_clone,
187+
listener_clone,
188+
args.subscriber_max_interval,
189+
metrics_clone,
190+
);
191+
192+
let task = tokio::spawn(async move {
193+
info!(message = "starting subscriber", index = index, uri = uri_clone.to_string());
194+
subscriber.run(token_clone).await;
195+
});
168196

169-
let mut subscriber = WebsocketSubscriber::new(
170-
args.upstream_ws,
171-
listener,
172-
args.subscriber_max_interval,
173-
metrics.clone(),
174-
);
175-
let subscriber_task = subscriber.run(token.clone());
197+
subscriber_tasks.push(task);
198+
}
176199

177200
let registry = Registry::new(sender, metrics.clone());
178201

@@ -194,8 +217,8 @@ async fn main() {
194217
let mut terminate = signal(SignalKind::terminate()).unwrap();
195218

196219
tokio::select! {
197-
_ = subscriber_task => {
198-
info!("subscriber task terminated");
220+
_ = futures::future::join_all(subscriber_tasks) => {
221+
info!("all subscriber tasks terminated");
199222
token.cancel();
200223
},
201224
_ = server_task => {

src/metrics.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,17 @@ pub struct Metrics {
2929

3030
#[metric(describe = "Count of messages received from the upstream source")]
3131
pub upstream_messages: Gauge,
32+
33+
// New metrics for multiple upstream connections
34+
#[metric(describe = "Number of active upstream connections")]
35+
pub upstream_connections: Gauge,
36+
37+
#[metric(describe = "Number of upstream connection attempts")]
38+
pub upstream_connection_attempts: Counter,
39+
40+
#[metric(describe = "Number of successful upstream connections")]
41+
pub upstream_connection_successes: Counter,
42+
43+
#[metric(describe = "Number of failed upstream connection attempts")]
44+
pub upstream_connection_failures: Counter,
3245
}

src/subscriber.rs

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,52 @@ where
4040
}
4141

4242
pub async fn run(&mut self, token: CancellationToken) {
43-
info!("starting upstream subscription");
43+
// Include the URI so log lines from different upstream subscribers can be told apart
44+
info!(
45+
message = "starting upstream subscription",
46+
uri = self.uri.to_string()
47+
);
4448
loop {
4549
select! {
4650
_ = token.cancelled() => {
47-
info!("cancelled upstream subscription");
51+
info!(
52+
message = "cancelled upstream subscription",
53+
uri = self.uri.to_string()
54+
);
4855
return;
4956
}
5057
result = self.connect_and_listen() => {
5158
match result {
5259
Ok(()) => {
53-
info!(message="upstream connection closed");
60+
info!(
61+
message = "upstream connection closed",
62+
uri = self.uri.to_string()
63+
);
5464
}
5565
Err(e) => {
56-
error!(message="upstream websocket error", error=e.to_string());
66+
// Include the URI so the failing upstream can be identified in the logs
67+
error!(
68+
message = "upstream websocket error",
69+
uri = self.uri.to_string(),
70+
error = e.to_string()
71+
);
5772
self.metrics.upstream_errors.increment(1);
73+
// NOTE(review): runs on every error, including failed connects that never incremented the gauge — confirm it cannot drift negative
74+
self.metrics.upstream_connections.decrement(1);
5875

5976
if let Some(duration) = self.backoff.next_backoff() {
60-
warn!(message="recconecting", seconds=duration.as_secs());
77+
// Include the URI so it is clear which upstream is being retried
78+
warn!(
79+
message = "reconnecting",
80+
uri = self.uri.to_string(),
81+
seconds = duration.as_secs()
82+
);
6183
select! {
6284
_ = token.cancelled() => {
63-
info!(message="cancelled subscriber during backoff");
85+
info!(
86+
message = "cancelled subscriber during backoff",
87+
uri = self.uri.to_string()
88+
);
6489
return
6590
}
6691
_ = tokio::time::sleep(duration) => {}
@@ -79,8 +104,31 @@ where
79104
uri = self.uri.to_string()
80105
);
81106

82-
let (ws_stream, _) = connect_async(&self.uri).await?;
83-
info!(message = "websocket connection established");
107+
// Increment connection attempts counter for metrics
108+
self.metrics.upstream_connection_attempts.increment(1);
109+
110+
// Connect to the upstream, recording the outcome in the success/failure counters
111+
let (ws_stream, _) = match connect_async(&self.uri).await {
112+
Ok(connection) => {
113+
// Track successful connections
114+
self.metrics.upstream_connection_successes.increment(1);
115+
connection
116+
},
117+
Err(e) => {
118+
// Track failed connections
119+
self.metrics.upstream_connection_failures.increment(1);
120+
return Err(e);
121+
}
122+
};
123+
124+
info!(
125+
message = "websocket connection established",
126+
uri = self.uri.to_string()
127+
);
128+
129+
// Increment active connections counter
130+
self.metrics.upstream_connections.increment(1);
131+
// Reset backoff timer on successful connection
84132
self.backoff.reset();
85133

86134
let (_, mut read) = ws_stream.split();
@@ -89,12 +137,20 @@ where
89137
match message {
90138
Ok(msg) => {
91139
let text = msg.to_text()?;
92-
trace!(message = "received message", payload = text);
140+
trace!(
141+
message = "received message",
142+
uri = self.uri.to_string(),
143+
payload = text
144+
);
93145
self.metrics.upstream_messages.increment(1);
94146
(self.handler)(text.into());
95147
}
96148
Err(e) => {
97-
error!(message = "error receiving message", error = e.to_string());
149+
error!(
150+
message = "error receiving message",
151+
uri = self.uri.to_string(),
152+
error = e.to_string()
153+
);
98154
return Err(e);
99155
}
100156
}

0 commit comments

Comments
 (0)