Skip to content

Commit

Permalink
feat: add ConnectionAddress type and Connectable traits
Browse files Browse the repository at this point in the history
  • Loading branch information
pv42 authored and patrickelectric committed Jan 9, 2025
1 parent a973af3 commit a44abd2
Show file tree
Hide file tree
Showing 13 changed files with 394 additions and 283 deletions.
62 changes: 29 additions & 33 deletions mavlink-core/src/async_connection/direct_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
use core::ops::DerefMut;
use std::io;

use async_trait::async_trait;
use tokio::sync::Mutex;
use tokio_serial::{SerialPort, SerialPortBuilderExt, SerialStream};

use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message};
use super::AsyncConnectable;
use crate::{
async_peek_reader::AsyncPeekReader, connectable::SerialConnectable, MavHeader, MavlinkVersion,
Message,
};

#[cfg(not(feature = "signing"))]
use crate::{read_versioned_msg_async, write_versioned_msg_async};
Expand All @@ -17,38 +22,6 @@ use crate::{

use super::AsyncMavConnection;

pub fn open(settings: &str) -> io::Result<AsyncSerialConnection> {
let settings_toks: Vec<&str> = settings.split(':').collect();
if settings_toks.len() < 2 {
return Err(io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Incomplete port settings",
));
}

let Ok(baud) = settings_toks[1].parse::<u32>() else {
return Err(io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Invalid baud rate",
));
};

let port_name = settings_toks[0];
let mut port = tokio_serial::new(port_name, baud).open_native_async()?;
port.set_data_bits(tokio_serial::DataBits::Eight)?;
port.set_parity(tokio_serial::Parity::None)?;
port.set_stop_bits(tokio_serial::StopBits::One)?;
port.set_flow_control(tokio_serial::FlowControl::None)?;

Ok(AsyncSerialConnection {
port: Mutex::new(AsyncPeekReader::new(port)),
sequence: Mutex::new(0),
protocol_version: MavlinkVersion::V2,
#[cfg(feature = "signing")]
signing_data: None,
})
}

pub struct AsyncSerialConnection {
port: Mutex<AsyncPeekReader<SerialStream>>,
sequence: Mutex<u8>,
Expand Down Expand Up @@ -118,3 +91,26 @@ impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncSerialConnection {
self.signing_data = signing_data.map(SigningData::from_config)
}
}

#[async_trait]
impl AsyncConnectable for SerialConnectable {
async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
where
M: Message + Sync + Send,
{
let mut port =
tokio_serial::new(&self.port_name, self.baud_rate as u32).open_native_async()?;
port.set_data_bits(tokio_serial::DataBits::Eight)?;
port.set_parity(tokio_serial::Parity::None)?;
port.set_stop_bits(tokio_serial::StopBits::One)?;
port.set_flow_control(tokio_serial::FlowControl::None)?;

Ok(Box::new(AsyncSerialConnection {
port: Mutex::new(AsyncPeekReader::new(port)),
sequence: Mutex::new(0),
protocol_version: MavlinkVersion::V2,
#[cfg(feature = "signing")]
signing_data: None,
}))
}
}
14 changes: 13 additions & 1 deletion mavlink-core/src/async_connection/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
use core::ops::DerefMut;

use super::AsyncMavConnection;
use super::{AsyncConnectable, AsyncMavConnection};
use crate::connectable::FileConnectable;
use crate::error::{MessageReadError, MessageWriteError};

use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message};

use async_trait::async_trait;
use tokio::fs::File;
use tokio::io;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -81,3 +83,13 @@ impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncFileConnection {
self.signing_data = signing_data.map(SigningData::from_config)
}
}

#[async_trait]
impl AsyncConnectable for FileConnectable {
async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
where
M: Message + Sync + Send,
{
Ok(Box::new(open(&self.address).await?))
}
}
65 changes: 27 additions & 38 deletions mavlink-core/src/async_connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use async_trait::async_trait;
use tokio::io;

use crate::{MavFrame, MavHeader, MavlinkVersion, Message};
use crate::{connectable::ConnectionAddress, MavFrame, MavHeader, MavlinkVersion, Message};

#[cfg(feature = "tcp")]
mod tcp;
Expand Down Expand Up @@ -81,43 +82,9 @@ pub trait AsyncMavConnection<M: Message + Sync + Send> {
pub async fn connect_async<M: Message + Sync + Send>(
address: &str,
) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>> {
let protocol_err = Err(io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Protocol unsupported",
));

if cfg!(feature = "tcp") && address.starts_with("tcp") {
#[cfg(feature = "tcp")]
{
tcp::select_protocol(address).await
}
#[cfg(not(feature = "tcp"))]
{
protocol_err
}
} else if cfg!(feature = "udp") && address.starts_with("udp") {
#[cfg(feature = "udp")]
{
udp::select_protocol(address).await
}
#[cfg(not(feature = "udp"))]
{
protocol_err
}
} else if cfg!(feature = "direct-serial") && address.starts_with("serial") {
#[cfg(feature = "direct-serial")]
{
Ok(Box::new(direct_serial::open(&address["serial:".len()..])?))
}
#[cfg(not(feature = "direct-serial"))]
{
protocol_err
}
} else if address.starts_with("file") {
Ok(Box::new(file::open(&address["file:".len()..]).await?))
} else {
protocol_err
}
ConnectionAddress::parse_address(address)?
.connect_async::<M>()
.await
}

/// Returns the socket address for the given address.
Expand All @@ -135,3 +102,25 @@ pub(crate) fn get_socket_addr<T: std::net::ToSocketAddrs>(
};
Ok(addr)
}

#[async_trait]
pub trait AsyncConnectable {
async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
where
M: Message + Sync + Send;
}

#[async_trait]
impl AsyncConnectable for ConnectionAddress {
async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
where
M: Message + Sync + Send,
{
match self {
Self::Tcp(connectable) => connectable.connect_async::<M>().await,
Self::Udp(connectable) => connectable.connect_async::<M>().await,
Self::Serial(connectable) => connectable.connect_async::<M>().await,
Self::File(connectable) => connectable.connect_async::<M>().await,
}
}
}
36 changes: 18 additions & 18 deletions mavlink-core/src/async_connection/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Async TCP MAVLink connection
use super::{get_socket_addr, AsyncMavConnection};
use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};
use crate::async_peek_reader::AsyncPeekReader;
use crate::connectable::TcpConnectable;
use crate::{MavHeader, MavlinkVersion, Message};

use async_trait::async_trait;
use core::ops::DerefMut;
use tokio::io;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
Expand All @@ -17,23 +19,6 @@ use crate::{
read_versioned_msg_async_signed, write_versioned_msg_async_signed, SigningConfig, SigningData,
};

pub async fn select_protocol<M: Message + Sync + Send>(
address: &str,
) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>> {
let connection = if let Some(address) = address.strip_prefix("tcpout:") {
tcpout(address).await
} else if let Some(address) = address.strip_prefix("tcpin:") {
tcpin(address).await
} else {
Err(io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Protocol unsupported",
))
};

Ok(Box::new(connection?))
}

pub async fn tcpout<T: std::net::ToSocketAddrs>(address: T) -> io::Result<AsyncTcpConnection> {
let addr = get_socket_addr(address)?;

Expand Down Expand Up @@ -154,3 +139,18 @@ impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncTcpConnection {
self.signing_data = signing_data.map(SigningData::from_config)
}
}

#[async_trait]
impl AsyncConnectable for TcpConnectable {
async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
where
M: Message + Sync + Send,
{
let conn = if self.is_out {
tcpout(&self.address).await
} else {
tcpin(&self.address).await
};
Ok(Box::new(conn?))
}
}
71 changes: 25 additions & 46 deletions mavlink-core/src/async_connection/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
use core::{ops::DerefMut, task::Poll};
use std::{collections::VecDeque, io::Read, sync::Arc};

use async_trait::async_trait;
use tokio::{
io::{self, AsyncRead, ReadBuf},
net::UdpSocket,
sync::Mutex,
};

use crate::{async_peek_reader::AsyncPeekReader, MavHeader, MavlinkVersion, Message};
use crate::{
async_peek_reader::AsyncPeekReader,
connectable::{UdpConnectable, UdpMode},
MavHeader, MavlinkVersion, Message,
};

use super::{get_socket_addr, AsyncMavConnection};
use super::{get_socket_addr, AsyncConnectable, AsyncMavConnection};

#[cfg(not(feature = "signing"))]
use crate::{read_versioned_msg_async, write_versioned_msg_async};
Expand All @@ -20,50 +25,6 @@ use crate::{
read_versioned_msg_async_signed, write_versioned_msg_signed, SigningConfig, SigningData,
};

pub async fn select_protocol<M: Message + Sync + Send>(
address: &str,
) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>> {
let connection = if let Some(address) = address.strip_prefix("udpin:") {
udpin(address).await
} else if let Some(address) = address.strip_prefix("udpout:") {
udpout(address).await
} else if let Some(address) = address.strip_prefix("udpbcast:") {
udpbcast(address).await
} else {
Err(io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Protocol unsupported",
))
};

Ok(Box::new(connection?))
}

pub async fn udpbcast<T: std::net::ToSocketAddrs>(address: T) -> io::Result<AsyncUdpConnection> {
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind("0.0.0.0:0").await?;
socket
.set_broadcast(true)
.expect("Couldn't bind to broadcast address.");
AsyncUdpConnection::new(socket, false, Some(addr))
}

pub async fn udpout<T: std::net::ToSocketAddrs>(address: T) -> io::Result<AsyncUdpConnection> {
let addr = get_socket_addr(address)?;
let socket = UdpSocket::bind("0.0.0.0:0").await?;
AsyncUdpConnection::new(socket, false, Some(addr))
}

pub async fn udpin<T: std::net::ToSocketAddrs>(address: T) -> io::Result<AsyncUdpConnection> {
let addr = address
.to_socket_addrs()
.unwrap()
.next()
.expect("Invalid address");
let socket = UdpSocket::bind(addr).await?;
AsyncUdpConnection::new(socket, true, None)
}

struct UdpRead {
socket: Arc<UdpSocket>,
buffer: VecDeque<u8>,
Expand Down Expand Up @@ -235,6 +196,24 @@ impl<M: Message + Sync + Send> AsyncMavConnection<M> for AsyncUdpConnection {
}
}

#[async_trait]
impl AsyncConnectable for UdpConnectable {
async fn connect_async<M>(&self) -> io::Result<Box<dyn AsyncMavConnection<M> + Sync + Send>>
where
M: Message + Sync + Send,
{
let (addr, server, dest): (&str, _, _) = match self.mode {
UdpMode::Udpin => (&self.address, true, None),
_ => ("0.0.0.0:0", false, Some(get_socket_addr(&self.address)?)),
};
let socket = UdpSocket::bind(addr).await?;
if matches!(self.mode, UdpMode::Udpcast) {
socket.set_broadcast(true)?;
}
Ok(Box::new(AsyncUdpConnection::new(socket, server, dest)?))
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit a44abd2

Please sign in to comment.