Skip to content

Commit

Permalink
feat(preimage): Decouple from kona-common (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
clabby authored Nov 15, 2024
1 parent 433d4d9 commit 1953014
Show file tree
Hide file tree
Showing 13 changed files with 266 additions and 161 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions bin/client/src/fault/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Contains FPVM-specific constructs for the `kona-client` program.
use kona_client::PipeHandle;
use kona_common::FileDescriptor;
use kona_preimage::{HintWriter, OracleReader, PipeHandle};
use kona_preimage::{HintWriter, OracleReader};

mod handler;
pub(crate) use handler::fpvm_handle_register;
Expand All @@ -15,7 +16,7 @@ static HINT_WRITER_PIPE: PipeHandle =
PipeHandle::new(FileDescriptor::HintRead, FileDescriptor::HintWrite);

/// The global preimage oracle reader.
pub(crate) static ORACLE_READER: OracleReader = OracleReader::new(ORACLE_READER_PIPE);
pub(crate) static ORACLE_READER: OracleReader<PipeHandle> = OracleReader::new(ORACLE_READER_PIPE);

/// The global hint writer.
pub(crate) static HINT_WRITER: HintWriter = HintWriter::new(HINT_WRITER_PIPE);
pub(crate) static HINT_WRITER: HintWriter<PipeHandle> = HintWriter::new(HINT_WRITER_PIPE);
3 changes: 3 additions & 0 deletions bin/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,8 @@ pub use hint::HintType;
pub mod boot;
pub use boot::BootInfo;

mod pipe;
pub use pipe::PipeHandle;

mod caching_oracle;
pub use caching_oracle::{CachingOracle, FlushableCache};
21 changes: 21 additions & 0 deletions crates/proof-sdk/preimage/src/pipe.rs → bin/client/src/pipe.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! This module contains a rudamentary pipe between two file descriptors, using [kona_common::io]
//! for reading and writing from the file descriptors.
use alloc::boxed::Box;
use async_trait::async_trait;
use core::{
cell::RefCell,
cmp::Ordering,
Expand All @@ -9,6 +11,10 @@ use core::{
task::{Context, Poll},
};
use kona_common::{errors::IOResult, io, FileDescriptor};
use kona_preimage::{
errors::{ChannelError, ChannelResult},
Channel,
};

/// [PipeHandle] is a handle for one end of a bidirectional pipe.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -51,6 +57,21 @@ impl PipeHandle {
}
}

#[async_trait]
impl Channel for PipeHandle {
async fn read(&self, buf: &mut [u8]) -> ChannelResult<usize> {
self.read(buf).map_err(|_| ChannelError::Closed)
}

async fn read_exact(&self, buf: &mut [u8]) -> ChannelResult<usize> {
self.read_exact(buf).await.map_err(|_| ChannelError::Closed)
}

async fn write(&self, buf: &[u8]) -> ChannelResult<usize> {
self.write(buf).await.map_err(|_| ChannelError::Closed)
}
}

/// A future that reads from a pipe, returning [Poll::Ready] when the buffer is full.
struct ReadFuture<'a> {
/// The pipe handle to read from
Expand Down
3 changes: 2 additions & 1 deletion bin/host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ use server::PreimageServer;
use anyhow::{anyhow, bail, Result};
use command_fds::{CommandFdExt, FdMapping};
use futures::FutureExt;
use kona_client::PipeHandle;
use kona_common::FileDescriptor;
use kona_preimage::{HintReader, OracleServer, PipeHandle};
use kona_preimage::{HintReader, OracleServer};
use kv::KeyValueStore;
use std::{
io::{stderr, stdin, stdout},
Expand Down
6 changes: 3 additions & 3 deletions crates/proof-sdk/preimage/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ thiserror.workspace = true
async-trait.workspace = true
alloy-primitives.workspace = true

# Workspace
kona-common.workspace = true
# `std` feature dependencies
tokio = { workspace = true, features = ["full"], optional = true }

# `rkyv` feature dependencies
rkyv = { workspace = true, optional = true }
Expand All @@ -28,10 +28,10 @@ rkyv = { workspace = true, optional = true }
serde = { workspace = true, optional = true, features = ["derive"] }

[dev-dependencies]
os_pipe.workspace = true
tokio = { workspace = true, features = ["full"] }

[features]
default = []
std = ["dep:tokio"]
rkyv = ["dep:rkyv"]
serde = ["dep:serde"]
19 changes: 17 additions & 2 deletions crates/proof-sdk/preimage/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! Errors for the `kona-preimage` crate.
use alloc::string::String;
use kona_common::errors::IOError;
use thiserror::Error;

/// A [PreimageOracleError] is an enum that differentiates pipe-related errors from other errors
Expand All @@ -13,7 +12,7 @@ use thiserror::Error;
pub enum PreimageOracleError {
/// The pipe has been broken.
#[error(transparent)]
IOError(#[from] IOError),
IOError(#[from] ChannelError),
/// The preimage key is invalid.
#[error("Invalid preimage key.")]
InvalidPreimageKey,
Expand All @@ -30,3 +29,19 @@ pub enum PreimageOracleError {

/// A [Result] type for the [PreimageOracleError] enum.
pub type PreimageOracleResult<T> = Result<T, PreimageOracleError>;

/// A [ChannelError] is an enum that describes the error cases of a [Channel] trait implementation.
///
/// [Channel]: crate::Channel
#[derive(Error, Debug)]
pub enum ChannelError {
/// The channel is closed.
#[error("Channel is closed.")]
Closed,
/// Unexpected EOF.
#[error("Unexpected EOF in channel read operation.")]
UnexpectedEOF,
}

/// A [Result] type for the [ChannelError] enum.
pub type ChannelResult<T> = Result<T, ChannelError>;
111 changes: 48 additions & 63 deletions crates/proof-sdk/preimage/src/hint.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,46 @@
use crate::{
errors::{PreimageOracleError, PreimageOracleResult},
traits::{HintRouter, HintWriterClient},
HintReaderServer, PipeHandle,
Channel, HintReaderServer,
};
use alloc::{boxed::Box, format, string::String, vec};
use async_trait::async_trait;
use tracing::{error, trace};

/// A [HintWriter] is a high-level interface to the hint pipe. It provides a way to write hints to
/// the host.
/// A [HintWriter] is a high-level interface to the hint channel. It provides a way to write hints
/// to the host.
#[derive(Debug, Clone, Copy)]
pub struct HintWriter {
pipe_handle: PipeHandle,
pub struct HintWriter<C> {
channel: C,
}

impl HintWriter {
/// Create a new [HintWriter] from a [PipeHandle].
pub const fn new(pipe_handle: PipeHandle) -> Self {
Self { pipe_handle }
impl<C> HintWriter<C> {
/// Create a new [HintWriter] from a [Channel].
pub const fn new(channel: C) -> Self {
Self { channel }
}
}

#[async_trait]
impl HintWriterClient for HintWriter {
/// Write a hint to the host. This will overwrite any existing hint in the pipe, and block until
/// all data has been written.
impl<C> HintWriterClient for HintWriter<C>
where
C: Channel + Send + Sync,
{
/// Write a hint to the host. This will overwrite any existing hint in the channel, and block
/// until all data has been written.
async fn write(&self, hint: &str) -> PreimageOracleResult<()> {
// Form the hint into a byte buffer. The format is a 4-byte big-endian length prefix
// followed by the hint string.
let mut hint_bytes = vec![0u8; hint.len() + 4];
hint_bytes[0..4].copy_from_slice(u32::to_be_bytes(hint.len() as u32).as_ref());
hint_bytes[4..].copy_from_slice(hint.as_bytes());

trace!(target: "hint_writer", "Writing hint \"{hint}\"");

// Write the hint to the host.
self.pipe_handle.write(&hint_bytes).await?;
// Form the hint into a byte buffer. The format is a 4-byte big-endian length prefix
// followed by the hint string.
self.channel.write(u32::to_be_bytes(hint.len() as u32).as_ref()).await?;
self.channel.write(hint.as_bytes()).await?;

trace!(target: "hint_writer", "Successfully wrote hint");

// Read the hint acknowledgement from the host.
let mut hint_ack = [0u8; 1];
self.pipe_handle.read_exact(&mut hint_ack).await?;
self.channel.read_exact(&mut hint_ack).await?;

trace!(target: "hint_writer", "Received hint acknowledgement");

Expand All @@ -52,36 +51,42 @@ impl HintWriterClient for HintWriter {
/// A [HintReader] is a router for hints sent by the [HintWriter] from the client program. It
/// provides a way for the host to prepare preimages for reading.
#[derive(Debug, Clone, Copy)]
pub struct HintReader {
pipe_handle: PipeHandle,
pub struct HintReader<C> {
channel: C,
}

impl HintReader {
/// Create a new [HintReader] from a [PipeHandle].
pub const fn new(pipe_handle: PipeHandle) -> Self {
Self { pipe_handle }
impl<C> HintReader<C>
where
C: Channel,
{
/// Create a new [HintReader] from a [Channel].
pub const fn new(channel: C) -> Self {
Self { channel }
}
}

#[async_trait]
impl HintReaderServer for HintReader {
impl<C> HintReaderServer for HintReader<C>
where
C: Channel + Send + Sync,
{
async fn next_hint<R>(&self, hint_router: &R) -> PreimageOracleResult<()>
where
R: HintRouter + Send + Sync,
{
// Read the length of the raw hint payload.
let mut len_buf = [0u8; 4];
self.pipe_handle.read_exact(&mut len_buf).await?;
self.channel.read_exact(&mut len_buf).await?;
let len = u32::from_be_bytes(len_buf);

// Read the raw hint payload.
let mut raw_payload = vec![0u8; len as usize];
self.pipe_handle.read_exact(raw_payload.as_mut_slice()).await?;
self.channel.read_exact(raw_payload.as_mut_slice()).await?;
let payload = match String::from_utf8(raw_payload) {
Ok(p) => p,
Err(e) => {
// Write back on error to prevent blocking the client.
self.pipe_handle.write(&[0x00]).await?;
self.channel.write(&[0x00]).await?;

return Err(PreimageOracleError::Other(format!(
"Failed to decode hint payload: {e}"
Expand All @@ -94,14 +99,14 @@ impl HintReaderServer for HintReader {
// Route the hint
if let Err(e) = hint_router.route_hint(payload).await {
// Write back on error to prevent blocking the client.
self.pipe_handle.write(&[0x00]).await?;
self.channel.write(&[0x00]).await?;

error!("Failed to route hint: {e}");
return Err(e);
}

// Write back an acknowledgement to the client to unblock their process.
self.pipe_handle.write(&[0x00]).await?;
self.channel.write(&[0x00]).await?;

trace!(target: "hint_reader", "Successfully routed and acknowledged hint");

Expand All @@ -112,10 +117,8 @@ impl HintReaderServer for HintReader {
#[cfg(test)]
mod test {
use super::*;
use crate::test_utils::bidirectional_pipe;
use crate::native_channel::BidirectionalChannel;
use alloc::{sync::Arc, vec::Vec};
use kona_common::FileDescriptor;
use std::os::unix::io::AsRawFd;
use tokio::sync::Mutex;

struct TestRouter {
Expand Down Expand Up @@ -143,24 +146,18 @@ mod test {
async fn test_unblock_on_bad_utf8() {
let mock_data = [0xf0, 0x90, 0x28, 0xbc];

let hint_pipe = bidirectional_pipe().unwrap();
let hint_channel = BidirectionalChannel::new::<2>().unwrap();

let client = tokio::task::spawn(async move {
let hint_writer = HintWriter::new(PipeHandle::new(
FileDescriptor::Wildcard(hint_pipe.client.read.as_raw_fd() as usize),
FileDescriptor::Wildcard(hint_pipe.client.write.as_raw_fd() as usize),
));
let hint_writer = HintWriter::new(hint_channel.client);

#[allow(invalid_from_utf8_unchecked)]
hint_writer.write(unsafe { alloc::str::from_utf8_unchecked(&mock_data) }).await
});
let host = tokio::task::spawn(async move {
let router = TestRouter { incoming_hints: Default::default() };

let hint_reader = HintReader::new(PipeHandle::new(
FileDescriptor::Wildcard(hint_pipe.host.read.as_raw_fd() as usize),
FileDescriptor::Wildcard(hint_pipe.host.write.as_raw_fd() as usize),
));
let hint_reader = HintReader::new(hint_channel.host);
hint_reader.next_hint(&router).await
});

Expand All @@ -178,21 +175,15 @@ mod test {
async fn test_unblock_on_fetch_failure() {
const MOCK_DATA: &str = "test-hint 0xfacade";

let hint_pipe = bidirectional_pipe().unwrap();
let hint_channel = BidirectionalChannel::new::<2>().unwrap();

let client = tokio::task::spawn(async move {
let hint_writer = HintWriter::new(PipeHandle::new(
FileDescriptor::Wildcard(hint_pipe.client.read.as_raw_fd() as usize),
FileDescriptor::Wildcard(hint_pipe.client.write.as_raw_fd() as usize),
));
let hint_writer = HintWriter::new(hint_channel.client);

hint_writer.write(MOCK_DATA).await
});
let host = tokio::task::spawn(async move {
let hint_reader = HintReader::new(PipeHandle::new(
FileDescriptor::Wildcard(hint_pipe.host.read.as_raw_fd() as usize),
FileDescriptor::Wildcard(hint_pipe.host.write.as_raw_fd() as usize),
));
let hint_reader = HintReader::new(hint_channel.host);
hint_reader.next_hint(&TestFailRouter).await
});

Expand All @@ -206,13 +197,10 @@ mod test {
const MOCK_DATA: &str = "test-hint 0xfacade";

let incoming_hints = Arc::new(Mutex::new(Vec::new()));
let hint_pipe = bidirectional_pipe().unwrap();
let hint_channel = BidirectionalChannel::new::<2>().unwrap();

let client = tokio::task::spawn(async move {
let hint_writer = HintWriter::new(PipeHandle::new(
FileDescriptor::Wildcard(hint_pipe.client.read.as_raw_fd() as usize),
FileDescriptor::Wildcard(hint_pipe.client.write.as_raw_fd() as usize),
));
let hint_writer = HintWriter::new(hint_channel.client);

hint_writer.write(MOCK_DATA).await
});
Expand All @@ -221,10 +209,7 @@ mod test {
async move {
let router = TestRouter { incoming_hints: incoming_hints_ref };

let hint_reader = HintReader::new(PipeHandle::new(
FileDescriptor::Wildcard(hint_pipe.host.read.as_raw_fd() as usize),
FileDescriptor::Wildcard(hint_pipe.host.write.as_raw_fd() as usize),
));
let hint_reader = HintReader::new(hint_channel.host);
hint_reader.next_hint(&router).await.unwrap();
}
});
Expand Down
Loading

0 comments on commit 1953014

Please sign in to comment.