Skip to content

Commit

Permalink
refactor(net/virtio): migrate NetDevCfgRaw to virtio-spec
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Kröning <[email protected]>
  • Loading branch information
mkroening committed Jun 10, 2024
1 parent 5a67fe4 commit 7c6056e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 105 deletions.
51 changes: 4 additions & 47 deletions src/drivers/net/virtio/mmio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use alloc::rc::Rc;
use alloc::vec::Vec;
use core::ptr::read_volatile;
use core::str::FromStr;

use smoltcp::phy::ChecksumCapabilities;
Expand All @@ -16,65 +15,23 @@ use crate::drivers::virtio::error::{VirtioError, VirtioNetError};
use crate::drivers::virtio::transport::mmio::{ComCfg, IsrStatus, NotifCfg};
use crate::drivers::virtio::virtqueue::Virtq;

/// Virtio's network device configuration structure.
/// See specification v1.1. - 5.1.4
///
#[repr(C)]
pub struct NetDevCfgRaw {
// Specifies Mac address, only Valid if VIRTIO_NET_F_MAC is set
mac: [u8; 6],
// Indicates status of device. Only valid if VIRTIO_NET_F_STATUS is set
status: u16,
// Indicates number of allowed vq-pairs. Only valid if VIRTIO_NET_F_MQ is set.
max_virtqueue_pairs: u16,
// Indicates the maximum MTU driver should use. Only valid if VIRTIONET_F_MTU is set.
mtu: u16,
}

impl NetDevCfgRaw {
pub fn get_mtu(&self) -> u16 {
// see Virtio specification v1.1 - 2.4.1
unsafe { read_volatile(&self.mtu) }
}

pub fn get_mac(&self) -> [u8; 6] {
let mut mac: [u8; 6] = [0u8; 6];

// see Virtio specification v1.1 - 2.4.1
unsafe {
let mut src = self.mac.iter();
mac.fill_with(|| read_volatile(src.next().unwrap()));
mac
}
}

pub fn get_status(&self) -> u16 {
// see Virtio specification v1.1 - 2.4.1
unsafe { read_volatile(&self.status) }
}

pub fn get_max_virtqueue_pairs(&self) -> u16 {
// see Virtio specification v1.1 - 2.4.1
unsafe { read_volatile(&self.max_virtqueue_pairs) }
}
}

// Backend-dependent interface for Virtio network driver
impl VirtioNetDriver {
pub fn new(
dev_id: u16,
mut registers: VolatileRef<'static, DeviceRegisters>,
irq: u8,
) -> Result<Self, VirtioNetError> {
let dev_cfg_raw: &'static NetDevCfgRaw = unsafe {
let dev_cfg_raw: &'static virtio_spec::net::Config = unsafe {
&*registers
.borrow_mut()
.as_mut_ptr()
.config()
.as_raw_ptr()
.cast::<NetDevCfgRaw>()
.cast::<virtio_spec::net::Config>()
.as_ptr()
};
let dev_cfg_raw = VolatileRef::from_ref(dev_cfg_raw);
let dev_cfg = NetDevCfg {
raw: dev_cfg_raw,
dev_id,
Expand Down Expand Up @@ -107,7 +64,7 @@ impl VirtioNetDriver {

pub fn print_information(&mut self) {
self.com_cfg.print_information();
if self.dev_status() == virtio_spec::net::S::LINK_UP.bits().to_ne() {
if self.dev_status() == virtio_spec::net::S::LINK_UP {
info!("The link of the network device is up!");
}
}
Expand Down
68 changes: 47 additions & 21 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
cfg_if::cfg_if! {
if #[cfg(feature = "pci")] {
mod pci;

use self::pci::NetDevCfgRaw;
} else {
mod mmio;

use self::mmio::NetDevCfgRaw;
}
}

Expand All @@ -24,8 +20,10 @@ use align_address::Align;
use pci_types::InterruptLine;
use smoltcp::phy::{Checksum, ChecksumCapabilities};
use smoltcp::wire::{EthernetFrame, Ipv4Packet, Ipv6Packet, ETHERNET_HEADER_LEN};
use virtio_spec::net::{Hdr, HdrF};
use virtio_spec::net::{ConfigVolatileFieldAccess, Hdr, HdrF};
use virtio_spec::FeatureBits;
use volatile::access::ReadOnly;
use volatile::VolatileRef;

use self::constants::MAX_NUM_VQ;
use self::error::VirtioNetError;
Expand All @@ -46,7 +44,7 @@ use crate::executor::device::{RxToken, TxToken};
/// Handling the right access to fields, as some are read-only
/// for the driver.
pub(crate) struct NetDevCfg {
pub raw: &'static NetDevCfgRaw,
pub raw: VolatileRef<'static, virtio_spec::net::Config, ReadOnly>,
pub dev_id: u16,
pub features: virtio_spec::net::F,
}
Expand Down Expand Up @@ -162,7 +160,7 @@ impl RxQueues {
(1514 + mem::size_of::<Hdr>())
.align_up(core::mem::size_of::<crossbeam_utils::CachePadded<u8>>())
} else {
dev_cfg.raw.get_mtu() as usize + mem::size_of::<Hdr>()
dev_cfg.raw.as_ptr().mtu().read().to_ne() as usize + mem::size_of::<Hdr>()
};

// See Virtio specification v1.1 - 5.1.6.3.1
Expand Down Expand Up @@ -326,8 +324,10 @@ impl TxQueues {
// Header and data are added as ONE output descriptor to the transmitvq.
// Hence we are interpreting this, as the fact, that send packets must be inside a single descriptor.
// As usize is currently safe as the minimal usize is defined as 16bit in rust.
let buff_def =
Bytes::new(mem::size_of::<Hdr>() + dev_cfg.raw.get_mtu() as usize).unwrap();
let buff_def = Bytes::new(
mem::size_of::<Hdr>() + dev_cfg.raw.as_ptr().mtu().read().to_ne() as usize,
)
.unwrap();
let spec = BuffSpec::Single(buff_def);

let num_buff: u16 = vq.size().into();
Expand Down Expand Up @@ -431,7 +431,7 @@ impl NetworkDriver for VirtioNetDriver {
if self.dev_cfg.features.contains(virtio_spec::net::F::MAC) {
loop {
let before = self.com_cfg.config_generation();
let mac = self.dev_cfg.raw.get_mac();
let mac = self.dev_cfg.raw.as_ptr().mac().read();
let after = self.com_cfg.config_generation();
if before == after {
break mac;
Expand Down Expand Up @@ -652,11 +652,11 @@ impl VirtioNetDriver {
/// Returns the current status of the device, if VIRTIO_NET_F_STATUS
/// has been negotiated. Otherwise assumes an active device.
#[cfg(not(feature = "pci"))]
pub fn dev_status(&self) -> u16 {
pub fn dev_status(&self) -> virtio_spec::net::S {
if self.dev_cfg.features.contains(virtio_spec::net::F::STATUS) {
self.dev_cfg.raw.get_status()
self.dev_cfg.raw.as_ptr().status().read()
} else {
virtio_spec::net::S::LINK_UP.bits().to_ne()
virtio_spec::net::S::LINK_UP
}
}

Expand All @@ -665,8 +665,12 @@ impl VirtioNetDriver {
#[cfg(feature = "pci")]
pub fn is_link_up(&self) -> bool {
if self.dev_cfg.features.contains(virtio_spec::net::F::STATUS) {
self.dev_cfg.raw.get_status() & virtio_spec::net::S::LINK_UP.bits().to_ne()
== virtio_spec::net::S::LINK_UP.bits().to_ne()
self.dev_cfg
.raw
.as_ptr()
.status()
.read()
.contains(virtio_spec::net::S::LINK_UP)
} else {
true
}
Expand All @@ -675,8 +679,12 @@ impl VirtioNetDriver {
#[allow(dead_code)]
pub fn is_announce(&self) -> bool {
if self.dev_cfg.features.contains(virtio_spec::net::F::STATUS) {
self.dev_cfg.raw.get_status() & virtio_spec::net::S::ANNOUNCE.bits().to_ne()
== virtio_spec::net::S::ANNOUNCE.bits().to_ne()
self.dev_cfg
.raw
.as_ptr()
.status()
.read()
.contains(virtio_spec::net::S::ANNOUNCE)
} else {
false
}
Expand All @@ -690,7 +698,12 @@ impl VirtioNetDriver {
#[allow(dead_code)]
pub fn get_max_vq_pairs(&self) -> u16 {
if self.dev_cfg.features.contains(virtio_spec::net::F::MQ) {
self.dev_cfg.raw.get_max_virtqueue_pairs()
self.dev_cfg
.raw
.as_ptr()
.max_virtqueue_pairs()
.read()
.to_ne()
} else {
1
}
Expand Down Expand Up @@ -853,7 +866,7 @@ impl VirtioNetDriver {
debug!("{:?}", self.checksums);

if self.dev_cfg.features.contains(virtio_spec::net::F::MTU) {
self.mtu = self.dev_cfg.raw.get_mtu();
self.mtu = self.dev_cfg.raw.as_ptr().mtu().read().to_ne();
}

Ok(())
Expand Down Expand Up @@ -939,10 +952,23 @@ impl VirtioNetDriver {
// - the num_queues is found in the ComCfg struct of the device and defines the maximal number
// of supported queues.
if self.dev_cfg.features.contains(virtio_spec::net::F::MQ) {
if self.dev_cfg.raw.get_max_virtqueue_pairs() * 2 >= MAX_NUM_VQ {
if self
.dev_cfg
.raw
.as_ptr()
.max_virtqueue_pairs()
.read()
.to_ne() * 2 >= MAX_NUM_VQ
{
self.num_vqs = MAX_NUM_VQ;
} else {
self.num_vqs = self.dev_cfg.raw.get_max_virtqueue_pairs() * 2;
self.num_vqs = self
.dev_cfg
.raw
.as_ptr()
.max_virtqueue_pairs()
.read()
.to_ne() * 2;
}
} else {
// Minimal number of virtqueues defined in the standard v1.1. - 5.1.5 Step 1
Expand Down
45 changes: 8 additions & 37 deletions src/drivers/net/virtio/pci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use core::str::FromStr;

use pci_types::CommandRegister;
use smoltcp::phy::ChecksumCapabilities;
use volatile::VolatileRef;

use crate::arch::pci::PciConfigRegion;
use crate::drivers::net::virtio::{CtrlQueue, NetDevCfg, RxQueues, TxQueues, VirtioNetDriver};
Expand All @@ -15,46 +16,16 @@ use crate::drivers::virtio::error::{self, VirtioError};
use crate::drivers::virtio::transport::pci;
use crate::drivers::virtio::transport::pci::{PciCap, UniCapsColl};

/// Virtio's network device configuration structure.
/// See specification v1.1. - 5.1.4
///
#[repr(C)]
pub(crate) struct NetDevCfgRaw {
// Specifies Mac address, only Valid if VIRTIO_NET_F_MAC is set
mac: [u8; 6],
// Indicates status of device. Only valid if VIRTIO_NET_F_STATUS is set
status: u16,
// Indicates number of allowed vq-pairs. Only valid if VIRTIO_NET_F_MQ is set.
max_virtqueue_pairs: u16,
// Indicates the maximum MTU driver should use. Only valid if VIRTIONET_F_MTU is set.
mtu: u16,
}

impl NetDevCfgRaw {
pub fn get_mtu(&self) -> u16 {
self.mtu
}

pub fn get_mac(&self) -> [u8; 6] {
self.mac
}

pub fn get_status(&self) -> u16 {
self.status
}

pub fn get_max_virtqueue_pairs(&self) -> u16 {
self.max_virtqueue_pairs
}
}

// Backend-dependent interface for Virtio network driver
impl VirtioNetDriver {
fn map_cfg(cap: &PciCap) -> Option<NetDevCfg> {
let dev_cfg: &'static NetDevCfgRaw = match pci::map_dev_cfg::<NetDevCfgRaw>(cap) {
Some(cfg) => cfg,
None => return None,
};
let dev_cfg: &'static virtio_spec::net::Config =
match pci::map_dev_cfg::<virtio_spec::net::Config>(cap) {
Some(cfg) => cfg,
None => return None,
};

let dev_cfg = VolatileRef::from_ref(dev_cfg);

Some(NetDevCfg {
raw: dev_cfg,
Expand Down

0 comments on commit 7c6056e

Please sign in to comment.