diff --git a/examples/chat.rs b/examples/chat.rs index 6f4c4fc..0e3bb11 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -1,64 +1,69 @@ -use async_std::{prelude::*,io,task,net}; -use cable::{Cable,MemoryStore,ChannelOptions}; +use async_std::{io, net, prelude::*, task}; +use cable::{Cable, ChannelOptions, MemoryStore}; -type Error = Box; +type Error = Box; -fn main() -> Result<(),Error> { - let (args,argv) = argmap::parse(std::env::args()); +fn main() -> Result<(), Error> { + let (args, argv) = argmap::parse(std::env::args()); - task::block_on(async move { - let store = MemoryStore::default(); - let cable = Cable::new(store); - { - let opts = ChannelOptions { - channel: "default".as_bytes().to_vec(), - time_start: 0, - //time_end: now(), - time_end: 0, - limit: 20, - }; - let mut client = cable.clone(); - task::spawn(async move { - let mut msg_stream = client.open_channel(&opts).await.unwrap(); - while let Some(msg) = msg_stream.next().await { - println!["msg={:?}", msg]; + task::block_on(async move { + let store = MemoryStore::default(); + let cable = Cable::new(store); + { + let opts = ChannelOptions { + channel: "default".as_bytes().to_vec(), + time_start: 0, + //time_end: now(), + time_end: 0, + limit: 20, + }; + let mut client = cable.clone(); + task::spawn(async move { + let mut msg_stream = client.open_channel(&opts).await.unwrap(); + while let Some(msg) = msg_stream.next().await { + println!["msg={:?}", msg]; + } + }); } - }); - } - { - let mut client = cable.clone(); - task::spawn(async move { - let stdin = io::stdin(); - let mut line = String::new(); - loop { - stdin.read_line(&mut line).await.unwrap(); - if line.is_empty() { break } - let channel = "default".as_bytes(); - let text = line.trim_end().as_bytes(); - client.post_text(channel, &text).await.unwrap(); - line.clear(); + { + let mut client = cable.clone(); + task::spawn(async move { + let stdin = io::stdin(); + let mut line = String::new(); + loop { + stdin.read_line(&mut line).await.unwrap(); + if line.is_empty() { + break; + } + let channel = "default".as_bytes(); + let text = line.trim_end().as_bytes(); + client.post_text(channel, &text).await.unwrap(); + line.clear(); + } + }); } - }); - } - if let Some(port) = argv.get("l").and_then(|x| x.first()) { - let listener = net::TcpListener::bind(format!["0.0.0.0:{}",port]).await?; - let mut incoming = listener.incoming(); - while let Some(rstream) = incoming.next().await { - let stream = rstream.unwrap(); - let client = cable.clone(); - task::spawn(async move { - client.listen(stream).await.unwrap(); - }); - } - } else if let Some(addr) = args.get(1) { - let stream = net::TcpStream::connect(addr).await?; - cable.listen(stream).await?; - } - Ok(()) - }) + if let Some(port) = argv.get("l").and_then(|x| x.first()) { + let listener = net::TcpListener::bind(format!["0.0.0.0:{}", port]).await?; + let mut incoming = listener.incoming(); + while let Some(rstream) = incoming.next().await { + let stream = rstream.unwrap(); + let client = cable.clone(); + task::spawn(async move { + client.listen(stream).await.unwrap(); + }); + } + } else if let Some(addr) = args.get(1) { + let stream = net::TcpStream::connect(addr).await?; + cable.listen(stream).await?; + } + Ok(()) + }) } fn _now() -> u64 { - std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() } diff --git a/src/error.rs b/src/error.rs index 7740ed8..f40171e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,65 +3,80 @@ use std::backtrace::Backtrace; #[derive(Debug)] pub struct CableError { - kind: CableErrorKind, - backtrace: Backtrace, + kind: CableErrorKind, + backtrace: Backtrace, } #[derive(Debug)] pub enum CableErrorKind { - DstTooSmall { provided: usize, required: usize }, - MessageEmpty {}, - MessageWriteUnrecognizedType { msg_type: u64 }, - MessageHashResponseEnd {}, - MessageDataResponseEnd {}, - MessageHashRequestEnd {}, - MessageCancelRequestEnd {}, - MessageChannelTimeRangeRequestEnd {}, - MessageChannelStateRequestEnd {}, - MessageChannelListRequestEnd {}, - PostWriteUnrecognizedType { post_type: u64 }, + DstTooSmall { provided: usize, required: usize }, + MessageEmpty {}, + MessageWriteUnrecognizedType { msg_type: u64 }, + MessageHashResponseEnd {}, + MessageDataResponseEnd {}, + MessageHashRequestEnd {}, + MessageCancelRequestEnd {}, + MessageChannelTimeRangeRequestEnd {}, + MessageChannelStateRequestEnd {}, + MessageChannelListRequestEnd {}, + PostWriteUnrecognizedType { post_type: u64 }, } impl CableErrorKind { - pub fn raise(self) -> Result { - Err(Box::new(CableError { - kind: self, - backtrace: Backtrace::capture(), - })) - } + pub fn raise(self) -> Result { + Err(Box::new(CableError { + kind: self, + backtrace: Backtrace::capture(), + })) + } } impl std::error::Error for CableError { - fn backtrace<'a>(&'a self) -> Option<&'a Backtrace> { - Some(&self.backtrace) - } + fn backtrace<'a>(&'a self) -> Option<&'a Backtrace> { + Some(&self.backtrace) + } } impl std::fmt::Display for CableError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.kind { - CableErrorKind::MessageEmpty {} => { - write![f, "empty message"] - }, - CableErrorKind::MessageWriteUnrecognizedType { msg_type } => { - write![f, "cannot write unrecognized msg_type={}", msg_type] - }, - CableErrorKind::DstTooSmall { provided, required } => { - write![f, "destination buffer too small. {} bytes required, {} provided", - required, provided] - }, - CableErrorKind::MessageHashResponseEnd {} => { write![f, "unexpected end of HashResponse"] }, - CableErrorKind::MessageDataResponseEnd {} => { write![f, "unexpected end of DataResponse"] }, - CableErrorKind::MessageHashRequestEnd {} => { write![f, "unexpected end of HashRequest"] }, - CableErrorKind::MessageCancelRequestEnd {} => { write![f, "unexpected end of CancelRequest"] }, - CableErrorKind::MessageChannelTimeRangeRequestEnd {} => { - write![f, "unexpected end of ChannelTimeRangeRequest"] - }, - CableErrorKind::MessageChannelStateRequestEnd {} => { write![f, "unexpected end of ChannelStateRequest"] }, - CableErrorKind::MessageChannelListRequestEnd {} => { write![f, "unexpected end of ChannelListRequest"] }, - CableErrorKind::PostWriteUnrecognizedType { post_type } => { - write![f, "cannot write unrecognized post_type={}", post_type] - }, + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + CableErrorKind::MessageEmpty {} => { + write![f, "empty message"] + } + CableErrorKind::MessageWriteUnrecognizedType { msg_type } => { + write![f, "cannot write unrecognized msg_type={}", msg_type] + } + CableErrorKind::DstTooSmall { provided, required } => { + write![ + f, + "destination buffer too small. {} bytes required, {} provided", + required, provided + ] + } + CableErrorKind::MessageHashResponseEnd {} => { + write![f, "unexpected end of HashResponse"] + } + CableErrorKind::MessageDataResponseEnd {} => { + write![f, "unexpected end of DataResponse"] + } + CableErrorKind::MessageHashRequestEnd {} => { + write![f, "unexpected end of HashRequest"] + } + CableErrorKind::MessageCancelRequestEnd {} => { + write![f, "unexpected end of CancelRequest"] + } + CableErrorKind::MessageChannelTimeRangeRequestEnd {} => { + write![f, "unexpected end of ChannelTimeRangeRequest"] + } + CableErrorKind::MessageChannelStateRequestEnd {} => { + write![f, "unexpected end of ChannelStateRequest"] + } + CableErrorKind::MessageChannelListRequestEnd {} => { + write![f, "unexpected end of ChannelListRequest"] + } + CableErrorKind::PostWriteUnrecognizedType { post_type } => { + write![f, "cannot write unrecognized post_type={}", post_type] + } + } } - } } diff --git a/src/lib.rs b/src/lib.rs index 7c6ae7f..e4b46a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,23 @@ -#![feature(backtrace,async_closure,drain_filter)] +#![feature(backtrace, async_closure, drain_filter)] -use async_std::{prelude::*,sync::{Arc,RwLock},channel,task}; -use std::collections::{HashMap,HashSet}; -use futures::{io::{AsyncRead,AsyncWrite}}; -use desert::{ToBytes,FromBytes}; +use async_std::{ + channel, + prelude::*, + sync::{Arc, RwLock}, + task, +}; +use desert::{FromBytes, ToBytes}; +use futures::io::{AsyncRead, AsyncWrite}; +use std::collections::{HashMap, HashSet}; use std::convert::TryInto; -pub type ReqId = [u8;4]; -pub type ReplyId = [u8;4]; +pub type ReqId = [u8; 4]; +pub type ReplyId = [u8; 4]; pub type PeerId = usize; -pub type Hash = [u8;32]; +pub type Hash = [u8; 32]; pub type Payload = Vec; pub type Channel = Vec; -pub type Error = Box; +pub type Error = Box; mod message; pub use message::*; @@ -23,248 +28,292 @@ pub use store::*; mod error; pub use error::*; mod stream; +use length_prefixed_stream::{decode_with_options, DecodeOptions}; pub use stream::*; -use length_prefixed_stream::{decode_with_options,DecodeOptions}; -#[derive(Clone,Debug,PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct ChannelOptions { - pub channel: Vec, - pub time_start: u64, - pub time_end: u64, - pub limit: usize, + pub channel: Vec, + pub time_start: u64, + pub time_end: u64, + pub limit: usize, } #[derive(Clone)] pub struct Cable { - pub store: S, - peers: Arc>>>, - next_peer_id: Arc>, - next_req_id: Arc>, - listening: Arc>>>, - requested: Arc>>, - open_requests: Arc>>, + pub store: S, + peers: Arc>>>, + next_peer_id: Arc>, + next_req_id: Arc>, + listening: Arc>>>, + requested: Arc>>, + open_requests: Arc>>, } -impl Cable where S: Store { - pub fn new(store: S) -> Self { - Self { - store, - peers: Arc::new(RwLock::new(HashMap::new())), - next_peer_id: Arc::new(RwLock::new(0)), - next_req_id: Arc::new(RwLock::new(0)), - listening: Arc::new(RwLock::new(HashMap::new())), - requested: Arc::new(RwLock::new(HashSet::new())), - open_requests: Arc::new(RwLock::new(HashMap::new())), +impl Cable +where + S: Store, +{ + pub fn new(store: S) -> Self { + Self { + store, + peers: Arc::new(RwLock::new(HashMap::new())), + next_peer_id: Arc::new(RwLock::new(0)), + next_req_id: Arc::new(RwLock::new(0)), + listening: Arc::new(RwLock::new(HashMap::new())), + requested: Arc::new(RwLock::new(HashSet::new())), + open_requests: Arc::new(RwLock::new(HashMap::new())), + } } - } } -impl Cable where S: Store { - pub async fn post_text(&mut self, channel: &[u8], text: &[u8]) -> Result<(),Error> { - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH)?.as_secs(); - let post = Post { - header: PostHeader { - public_key: self.get_public_key().await?, - signature: [0;64], - link: self.get_link(channel).await?, - }, - body: PostBody::Text { - channel: channel.to_vec(), - timestamp, - text: text.to_vec(), - } - }; - self.post(post).await - } - pub async fn post(&mut self, mut post: Post) -> Result<(),Error> { - if !post.is_signed() { - post.sign(&self.get_secret_key().await?)?; +impl Cable +where + S: Store, +{ + pub async fn post_text(&mut self, channel: &[u8], text: &[u8]) -> Result<(), Error> { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH)? + .as_secs(); + let post = Post { + header: PostHeader { + public_key: self.get_public_key().await?, + signature: [0; 64], + link: self.get_link(channel).await?, + }, + body: PostBody::Text { + channel: channel.to_vec(), + timestamp, + text: text.to_vec(), + }, + }; + self.post(post).await } - self.store.insert_post(&post).await?; - for (peer_id,reqs) in self.listening.read().await.iter() { - for (req_id,opts) in reqs { - let n_limit = opts.limit.min(4096); - let mut hashes = vec![]; - { - let mut stream = self.store.get_post_hashes(&opts).await?; - while let Some(result) = stream.next().await { - hashes.push(result?); - if hashes.len() >= n_limit { break } - } + pub async fn post(&mut self, mut post: Post) -> Result<(), Error> { + if !post.is_signed() { + post.sign(&self.get_secret_key().await?)?; } - let response = Message::HashResponse { req_id: req_id.clone(), hashes }; - self.send(*peer_id, &response).await?; - } - } - Ok(()) - } - pub async fn broadcast(&self, message: &Message) -> Result<(),Error> { - for ch in self.peers.read().await.values() { - ch.send(message.clone()).await?; - } - Ok(()) - } - pub async fn send(&self, peer_id: usize, msg: &Message) -> Result<(),Error> { - if let Some(ch) = self.peers.read().await.get(&peer_id) { - ch.send(msg.clone()).await?; + self.store.insert_post(&post).await?; + for (peer_id, reqs) in self.listening.read().await.iter() { + for (req_id, opts) in reqs { + let n_limit = opts.limit.min(4096); + let mut hashes = vec![]; + { + let mut stream = self.store.get_post_hashes(&opts).await?; + while let Some(result) = stream.next().await { + hashes.push(result?); + if hashes.len() >= n_limit { + break; + } + } + } + let response = Message::HashResponse { + req_id: req_id.clone(), + hashes, + }; + self.send(*peer_id, &response).await?; + } + } + Ok(()) } - Ok(()) - } - pub async fn handle(&mut self, peer_id: usize, msg: &Message) -> Result<(),Error> { - // todo: forward requests - match msg { - Message::ChannelTimeRangeRequest { req_id, channel, time_start, time_end, limit, .. } => { - let opts = GetPostOptions { - channel: channel.to_vec(), - time_start: *time_start, - time_end: *time_end, - limit: *limit, - }; - let n_limit = (*limit).min(4096); - let mut hashes = vec![]; - { - let mut stream = self.store.get_post_hashes(&opts).await?; - while let Some(result) = stream.next().await { - hashes.push(result?); - if hashes.len() >= n_limit { break } - } + pub async fn broadcast(&self, message: &Message) -> Result<(), Error> { + for ch in self.peers.read().await.values() { + ch.send(message.clone()).await?; } - let response = Message::HashResponse { req_id: *req_id, hashes }; - { - let mut w = self.listening.write().await; - if let Some(listeners) = w.get_mut(&peer_id) { - listeners.push((req_id.clone(),opts)); - } else { - w.insert(peer_id, vec![(req_id.clone(),opts)]); - } + Ok(()) + } + pub async fn send(&self, peer_id: usize, msg: &Message) -> Result<(), Error> { + if let Some(ch) = self.peers.read().await.get(&peer_id) { + ch.send(msg.clone()).await?; } - self.send(peer_id, &response).await?; - }, - Message::HashResponse { req_id, hashes } => { - let want = self.store.want(hashes).await?; - if !want.is_empty() { - { - let mut mreq = self.requested.write().await; - for hash in &want { - mreq.insert(hash.clone()); + Ok(()) + } + pub async fn handle(&mut self, peer_id: usize, msg: &Message) -> Result<(), Error> { + // todo: forward requests + match msg { + Message::ChannelTimeRangeRequest { + req_id, + channel, + time_start, + time_end, + limit, + .. + } => { + let opts = GetPostOptions { + channel: channel.to_vec(), + time_start: *time_start, + time_end: *time_end, + limit: *limit, + }; + let n_limit = (*limit).min(4096); + let mut hashes = vec![]; + { + let mut stream = self.store.get_post_hashes(&opts).await?; + while let Some(result) = stream.next().await { + hashes.push(result?); + if hashes.len() >= n_limit { + break; + } + } + } + let response = Message::HashResponse { + req_id: *req_id, + hashes, + }; + { + let mut w = self.listening.write().await; + if let Some(listeners) = w.get_mut(&peer_id) { + listeners.push((req_id.clone(), opts)); + } else { + w.insert(peer_id, vec![(req_id.clone(), opts)]); + } + } + self.send(peer_id, &response).await?; + } + Message::HashResponse { req_id, hashes } => { + let want = self.store.want(hashes).await?; + if !want.is_empty() { + { + let mut mreq = self.requested.write().await; + for hash in &want { + mreq.insert(hash.clone()); + } + } + let hreq = Message::HashRequest { + req_id: *req_id, + ttl: 1, + hashes: want, + }; + self.send(peer_id, &hreq).await?; + } + } + Message::HashRequest { + req_id, + ttl: _, + hashes, + } => { + let response = Message::DataResponse { + req_id: *req_id, + data: self.store.get_data(hashes).await?, + }; + self.send(peer_id, &response).await? + } + Message::DataResponse { req_id: _, data } => { + for buf in data { + if !Post::verify(&buf) { + continue; + } + let (s, post) = Post::from_bytes(&buf)?; + if s != buf.len() { + continue; + } + let h = post.hash()?; + { + let mut mreq = self.requested.write().await; + if !mreq.contains(&h) { + continue; + } // didn't request this response + mreq.remove(&h); + } + self.store.insert_post(&post).await?; + } + } + _ => { + //println!["other message type: todo"]; } - } - let hreq = Message::HashRequest { - req_id: *req_id, - ttl: 1, - hashes: want, - }; - self.send(peer_id, &hreq).await?; } - }, - Message::HashRequest { req_id, ttl: _, hashes } => { - let response = Message::DataResponse { - req_id: *req_id, - data: self.store.get_data(hashes).await?, + Ok(()) + } + async fn req_id(&self) -> (u32, ReqId) { + let mut n = self.next_req_id.write().await; + let r = *n; + *n = if *n == u32::MAX { 0 } else { *n + 1 }; + (r, r.to_bytes().unwrap().try_into().unwrap()) + } + pub async fn open_channel( + &mut self, + options: &ChannelOptions, + ) -> Result, Error> { + let (req_id, req_id_bytes) = self.req_id().await; + let m = Message::ChannelTimeRangeRequest { + req_id: req_id_bytes, + ttl: 1, + channel: options.channel.to_vec(), + time_start: options.time_start, + time_end: options.time_end, + limit: options.limit, }; - self.send(peer_id, &response).await? - }, - Message::DataResponse { req_id: _, data } => { - for buf in data { - if !Post::verify(&buf) { continue } - let (s,post) = Post::from_bytes(&buf)?; - if s != buf.len() { continue } - let h = post.hash()?; - { - let mut mreq = self.requested.write().await; - if !mreq.contains(&h) { continue } // didn't request this response - mreq.remove(&h); - } - self.store.insert_post(&post).await?; - } - }, - _ => { - //println!["other message type: todo"]; - }, + self.open_requests.write().await.insert(req_id, m.clone()); + self.broadcast(&m).await?; + Ok(self.store.get_posts_live(options).await?) } - Ok(()) - } - async fn req_id(&self) -> (u32,ReqId) { - let mut n = self.next_req_id.write().await; - let r = *n; - *n = if *n == u32::MAX { 0 } else { *n + 1 }; - (r,r.to_bytes().unwrap().try_into().unwrap()) - } - pub async fn open_channel(&mut self, options: &ChannelOptions) -> Result,Error> { - let (req_id,req_id_bytes) = self.req_id().await; - let m = Message::ChannelTimeRangeRequest { - req_id: req_id_bytes, - ttl: 1, - channel: options.channel.to_vec(), - time_start: options.time_start, - time_end: options.time_end, - limit: options.limit, - }; - self.open_requests.write().await.insert(req_id, m.clone()); - self.broadcast(&m).await?; - Ok(self.store.get_posts_live(options).await?) - } - pub async fn close_channel(&self, _channel: &[u8]) { - unimplemented![] - } - pub async fn get_peer_ids(&self) -> Vec { - self.peers.read().await.keys().copied().collect::>() - } - pub async fn get_link(&mut self, channel: &[u8]) -> Result<[u8;32],Error> { - let link = self.store.get_latest_hash(channel).await?; - Ok(link) - } - pub async fn get_public_key(&mut self) -> Result<[u8;32],Error> { - let (pk,_sk) = self.store.get_or_create_keypair().await?; - Ok(pk) - } - pub async fn get_secret_key(&mut self) -> Result<[u8;64],Error> { - let (_pk,sk) = self.store.get_or_create_keypair().await?; - Ok(sk) - } - pub async fn listen(&self, mut stream: T) -> Result<(),Error> - where T: AsyncRead+AsyncWrite+Clone+Unpin+Send+Sync+'static { - let peer_id = { - let mut n = self.next_peer_id.write().await; - let peer_id = *n; - *n += 1; - peer_id - }; - let (send,recv) = channel::bounded(100); - self.peers.write().await.insert(peer_id, send); - - for msg in self.open_requests.read().await.values() { - stream.write_all(&msg.to_bytes()?).await?; + pub async fn close_channel(&self, _channel: &[u8]) { + unimplemented![] + } + pub async fn get_peer_ids(&self) -> Vec { + self.peers + .read() + .await + .keys() + .copied() + .collect::>() + } + pub async fn get_link(&mut self, channel: &[u8]) -> Result<[u8; 32], Error> { + let link = self.store.get_latest_hash(channel).await?; + Ok(link) } + pub async fn get_public_key(&mut self) -> Result<[u8; 32], Error> { + let (pk, _sk) = self.store.get_or_create_keypair().await?; + Ok(pk) + } + pub async fn get_secret_key(&mut self) -> Result<[u8; 64], Error> { + let (_pk, sk) = self.store.get_or_create_keypair().await?; + Ok(sk) + } + pub async fn listen(&self, mut stream: T) -> Result<(), Error> + where + T: AsyncRead + AsyncWrite + Clone + Unpin + Send + Sync + 'static, + { + let peer_id = { + let mut n = self.next_peer_id.write().await; + let peer_id = *n; + *n += 1; + peer_id + }; + let (send, recv) = channel::bounded(100); + self.peers.write().await.insert(peer_id, send); - let w = { - let mut cstream = stream.clone(); - task::spawn(async move { - while let Ok(msg) = recv.recv().await { - cstream.write_all(&msg.to_bytes().unwrap()).await.unwrap(); + for msg in self.open_requests.read().await.values() { + stream.write_all(&msg.to_bytes()?).await?; } - let res: Result<(),Error> = Ok(()); - res - }) - }; - let mut options = DecodeOptions::default(); - options.include_len = true; - let mut lps = decode_with_options(stream, options); - while let Some(rbuf) = lps.next().await { - let buf = rbuf?; - let (_,msg) = Message::from_bytes(&buf)?; - let mut this = self.clone(); - task::spawn(async move { - if let Err(_e) = this.handle(peer_id, &msg).await { - //eprintln!["{}", e]; + let w = { + let mut cstream = stream.clone(); + task::spawn(async move { + while let Ok(msg) = recv.recv().await { + cstream.write_all(&msg.to_bytes().unwrap()).await.unwrap(); + } + let res: Result<(), Error> = Ok(()); + res + }) + }; + + let mut options = DecodeOptions::default(); + options.include_len = true; + let mut lps = decode_with_options(stream, options); + while let Some(rbuf) = lps.next().await { + let buf = rbuf?; + let (_, msg) = Message::from_bytes(&buf)?; + let mut this = self.clone(); + task::spawn(async move { + if let Err(_e) = this.handle(peer_id, &msg).await { + //eprintln!["{}", e]; + } + }); } - }); - } - w.await?; - self.peers.write().await.remove(&peer_id); - Ok(()) - } + w.await?; + self.peers.write().await.remove(&peer_id); + Ok(()) + } } diff --git a/src/message.rs b/src/message.rs index b1a754d..86a7596 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,336 +1,440 @@ -use desert::{FromBytes,ToBytes,CountBytes,varint}; -use crate::{ReqId,Hash,Payload,Channel,Error,error::CableErrorKind as E}; +use crate::{error::CableErrorKind as E, Channel, Error, Hash, Payload, ReqId}; +use desert::{varint, CountBytes, FromBytes, ToBytes}; -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub enum Message { - HashResponse { - req_id: ReqId, - //reply_id: ReplyId, - hashes: Vec, - }, - DataResponse { - req_id: ReqId, - //reply_id: ReplyId, - data: Vec, - }, - HashRequest { - req_id: ReqId, - //reply_id: ReplyId, - ttl: usize, - hashes: Vec, - }, - CancelRequest { - req_id: ReqId, - }, - ChannelTimeRangeRequest { - req_id: ReqId, - //reply_id: ReplyId, - ttl: usize, - channel: Channel, - time_start: u64, - time_end: u64, - limit: usize, - }, - ChannelStateRequest { - req_id: ReqId, - //reply_id: ReplyId, - ttl: usize, - channel: Channel, - limit: usize, - updates: usize, - }, - ChannelListRequest { - req_id: ReqId, - //reply_id: ReplyId, - ttl: usize, - limit: usize, - }, - Unrecognized { - msg_type: u64, - } + HashResponse { + req_id: ReqId, + //reply_id: ReplyId, + hashes: Vec, + }, + DataResponse { + req_id: ReqId, + //reply_id: ReplyId, + data: Vec, + }, + HashRequest { + req_id: ReqId, + //reply_id: ReplyId, + ttl: usize, + hashes: Vec, + }, + CancelRequest { + req_id: ReqId, + }, + ChannelTimeRangeRequest { + req_id: ReqId, + //reply_id: ReplyId, + ttl: usize, + channel: Channel, + time_start: u64, + time_end: u64, + limit: usize, + }, + ChannelStateRequest { + req_id: ReqId, + //reply_id: ReplyId, + ttl: usize, + channel: Channel, + limit: usize, + updates: usize, + }, + ChannelListRequest { + req_id: ReqId, + //reply_id: ReplyId, + ttl: usize, + limit: usize, + }, + Unrecognized { + msg_type: u64, + }, } impl CountBytes for Message { - fn count_bytes(&self) -> usize { - let size = match self { - Self::HashResponse { hashes, .. } => { - varint::length(0) + 4 + varint::length(hashes.len() as u64) + hashes.len()*32 - }, - Self::DataResponse { data, .. } => { - varint::length(1) + 4 - + data.iter().fold(0, |sum,d| sum + varint::length(d.len() as u64) + d.len()) - + varint::length(0) - }, - Self::HashRequest { ttl, hashes, .. } => { - varint::length(2) + 4 + varint::length(*ttl as u64) - + varint::length(hashes.len() as u64) + hashes.len()*32 - }, - Self::CancelRequest { .. } => { - varint::length(3) + 4 - }, - Self::ChannelTimeRangeRequest { ttl, channel, time_start, time_end, limit, .. } => { - varint::length(4) + 4 + varint::length(*ttl as u64) - + varint::length(channel.len() as u64) + channel.len() - + varint::length(*time_start) + varint::length(*time_end) - + varint::length(*limit as u64) - }, - Self::ChannelStateRequest { ttl, channel, limit, updates, .. } => { - varint::length(5) + 4 + varint::length(*ttl as u64) - + varint::length(channel.len() as u64) + channel.len() - + varint::length(*limit as u64) + varint::length(*updates as u64) - }, - Self::ChannelListRequest { ttl, limit, .. } => { - varint::length(6) + 4 + varint::length(*ttl as u64) + varint::length(*limit as u64) - }, - Self::Unrecognized { .. } => 0, - }; - varint::length(size as u64) + size - } - fn count_from_bytes(buf: &[u8]) -> Result { - if buf.is_empty() { return E::MessageEmpty {}.raise() } - let (s,msg_len) = varint::decode(buf)?; - Ok(s + (msg_len as usize)) - } + fn count_bytes(&self) -> usize { + let size = match self { + Self::HashResponse { hashes, .. } => { + varint::length(0) + 4 + varint::length(hashes.len() as u64) + hashes.len() * 32 + } + Self::DataResponse { data, .. } => { + varint::length(1) + + 4 + + data + .iter() + .fold(0, |sum, d| sum + varint::length(d.len() as u64) + d.len()) + + varint::length(0) + } + Self::HashRequest { ttl, hashes, .. } => { + varint::length(2) + + 4 + + varint::length(*ttl as u64) + + varint::length(hashes.len() as u64) + + hashes.len() * 32 + } + Self::CancelRequest { .. } => varint::length(3) + 4, + Self::ChannelTimeRangeRequest { + ttl, + channel, + time_start, + time_end, + limit, + .. + } => { + varint::length(4) + + 4 + + varint::length(*ttl as u64) + + varint::length(channel.len() as u64) + + channel.len() + + varint::length(*time_start) + + varint::length(*time_end) + + varint::length(*limit as u64) + } + Self::ChannelStateRequest { + ttl, + channel, + limit, + updates, + .. + } => { + varint::length(5) + + 4 + + varint::length(*ttl as u64) + + varint::length(channel.len() as u64) + + channel.len() + + varint::length(*limit as u64) + + varint::length(*updates as u64) + } + Self::ChannelListRequest { ttl, limit, .. } => { + varint::length(6) + 4 + varint::length(*ttl as u64) + varint::length(*limit as u64) + } + Self::Unrecognized { .. } => 0, + }; + varint::length(size as u64) + size + } + fn count_from_bytes(buf: &[u8]) -> Result { + if buf.is_empty() { + return E::MessageEmpty {}.raise(); + } + let (s, msg_len) = varint::decode(buf)?; + Ok(s + (msg_len as usize)) + } } impl ToBytes for Message { - fn to_bytes(&self) -> Result,Error> { - let mut buf = vec![0;self.count_bytes()]; - self.write_bytes(&mut buf)?; - Ok(buf) - } - fn write_bytes(&self, buf: &mut [u8]) -> Result { - let mut offset = 0; - let mut msg_len = self.count_bytes(); - msg_len -= varint::length(msg_len as u64); - offset += varint::encode(msg_len as u64, &mut buf[offset..])?; - let msg_type = match self { - Self::HashResponse { .. } => 0, - Self::DataResponse { .. } => 1, - Self::HashRequest { .. } => 2, - Self::CancelRequest { .. } => 3, - Self::ChannelTimeRangeRequest { .. } => 4, - Self::ChannelStateRequest { .. } => 5, - Self::ChannelListRequest { .. } => 6, - Self::Unrecognized { msg_type } => { - return E::MessageWriteUnrecognizedType { msg_type: *msg_type }.raise() - }, - }; - offset += varint::encode(msg_type, &mut buf[offset..])?; - Ok(match self { - Self::HashResponse { req_id, hashes } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - offset += varint::encode(hashes.len() as u64, &mut buf[offset..])?; - for hash in hashes.iter() { - if offset+hash.len() > buf.len() { - return E::DstTooSmall { - required: offset+hash.len(), - provided: buf.len(), - }.raise(); - } - buf[offset..offset+hash.len()].copy_from_slice(hash); - offset += hash.len(); - } - offset - }, - Self::DataResponse { req_id, data } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - for d in data.iter() { - offset += varint::encode(d.len() as u64, &mut buf[offset..])?; - if offset+d.len() > buf.len() { - return E::DstTooSmall { - required: offset+d.len(), - provided: buf.len(), - }.raise(); - } - buf[offset..offset+d.len()].copy_from_slice(d); - offset += d.len(); - } - offset += varint::encode(0, &mut buf[offset..])?; - offset - }, - Self::HashRequest { req_id, ttl, hashes } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - offset += varint::encode(*ttl as u64, &mut buf[offset..])?; - offset += varint::encode(hashes.len() as u64, &mut buf[offset..])?; - for hash in hashes.iter() { - if offset+hash.len() > buf.len() { - return E::DstTooSmall { - required: offset+hash.len(), - provided: buf.len(), - }.raise(); - } - buf[offset..offset+hash.len()].copy_from_slice(hash); - offset += hash.len(); - } - offset - }, - Self::CancelRequest { req_id } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - offset - }, - Self::ChannelTimeRangeRequest { req_id, ttl, channel, time_start, time_end, limit } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - offset += varint::encode(*ttl as u64, &mut buf[offset..])?; - if offset+channel.len() > buf.len() { - return E::DstTooSmall { - required: offset+channel.len(), - provided: buf.len(), - }.raise(); - } - offset += channel.write_bytes(&mut buf[offset..])?; - offset += varint::encode(*time_start, &mut buf[offset..])?; - offset += varint::encode(*time_end, &mut buf[offset..])?; - offset += varint::encode(*limit as u64, &mut buf[offset..])?; - offset - }, - Self::ChannelStateRequest { req_id, ttl, channel, limit, updates } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - offset += varint::encode(*ttl as u64, &mut buf[offset..])?; - if offset+channel.len() > buf.len() { - return E::DstTooSmall { - required: offset+channel.len(), - provided: buf.len(), - }.raise(); - } - offset += channel.write_bytes(&mut buf[offset..])?; - offset += varint::encode(*limit as u64, &mut buf[offset..])?; - offset += varint::encode(*updates as u64, &mut buf[offset..])?; - offset - }, - Self::ChannelListRequest { req_id, ttl, limit } => { - offset += req_id.write_bytes(&mut buf[offset..])?; - offset += varint::encode(*ttl as u64, &mut buf[offset..])?; - offset += varint::encode(*limit as u64, &mut buf[offset..])?; - offset - }, - Self::Unrecognized { msg_type } => { - return E::MessageWriteUnrecognizedType { msg_type: *msg_type }.raise(); - } - }) - } + fn to_bytes(&self) -> Result, Error> { + let mut buf = vec![0; self.count_bytes()]; + self.write_bytes(&mut buf)?; + Ok(buf) + } + fn write_bytes(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + let mut msg_len = self.count_bytes(); + msg_len -= varint::length(msg_len as u64); + offset += varint::encode(msg_len as u64, &mut buf[offset..])?; + let msg_type = match self { + Self::HashResponse { .. } => 0, + Self::DataResponse { .. } => 1, + Self::HashRequest { .. } => 2, + Self::CancelRequest { .. } => 3, + Self::ChannelTimeRangeRequest { .. } => 4, + Self::ChannelStateRequest { .. } => 5, + Self::ChannelListRequest { .. } => 6, + Self::Unrecognized { msg_type } => { + return E::MessageWriteUnrecognizedType { + msg_type: *msg_type, + } + .raise() + } + }; + offset += varint::encode(msg_type, &mut buf[offset..])?; + Ok(match self { + Self::HashResponse { req_id, hashes } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + offset += varint::encode(hashes.len() as u64, &mut buf[offset..])?; + for hash in hashes.iter() { + if offset + hash.len() > buf.len() { + return E::DstTooSmall { + required: offset + hash.len(), + provided: buf.len(), + } + .raise(); + } + buf[offset..offset + hash.len()].copy_from_slice(hash); + offset += hash.len(); + } + offset + } + Self::DataResponse { req_id, data } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + for d in data.iter() { + offset += varint::encode(d.len() as u64, &mut buf[offset..])?; + if offset + d.len() > buf.len() { + return E::DstTooSmall { + required: offset + d.len(), + provided: buf.len(), + } + .raise(); + } + buf[offset..offset + d.len()].copy_from_slice(d); + offset += d.len(); + } + offset += varint::encode(0, &mut buf[offset..])?; + offset + } + Self::HashRequest { + req_id, + ttl, + hashes, + } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + offset += varint::encode(*ttl as u64, &mut buf[offset..])?; + offset += varint::encode(hashes.len() as u64, &mut buf[offset..])?; + for hash in hashes.iter() { + if offset + hash.len() > buf.len() { + return E::DstTooSmall { + required: offset + hash.len(), + provided: buf.len(), + } + .raise(); + } + buf[offset..offset + hash.len()].copy_from_slice(hash); + offset += hash.len(); + } + offset + } + Self::CancelRequest { req_id } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + offset + } + Self::ChannelTimeRangeRequest { + req_id, + ttl, + channel, + time_start, + time_end, + limit, + } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + offset += varint::encode(*ttl as u64, &mut buf[offset..])?; + if offset + channel.len() > buf.len() { + return E::DstTooSmall { + required: offset + channel.len(), + provided: buf.len(), + } + .raise(); + } + offset += channel.write_bytes(&mut buf[offset..])?; + offset += varint::encode(*time_start, &mut buf[offset..])?; + offset += varint::encode(*time_end, &mut buf[offset..])?; + offset += varint::encode(*limit as u64, &mut buf[offset..])?; + offset + } + Self::ChannelStateRequest { + req_id, + ttl, + channel, + limit, + updates, + } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + offset += varint::encode(*ttl as u64, &mut buf[offset..])?; + if offset + channel.len() > buf.len() { + return E::DstTooSmall { + required: offset + channel.len(), + provided: buf.len(), + } + .raise(); + } + offset += channel.write_bytes(&mut buf[offset..])?; + offset += varint::encode(*limit as u64, &mut buf[offset..])?; + offset += varint::encode(*updates as u64, &mut buf[offset..])?; + offset + } + Self::ChannelListRequest { req_id, ttl, limit } => { + offset += req_id.write_bytes(&mut buf[offset..])?; + offset += varint::encode(*ttl as u64, &mut buf[offset..])?; + offset += varint::encode(*limit as u64, &mut buf[offset..])?; + offset + } + Self::Unrecognized { msg_type } => { + return E::MessageWriteUnrecognizedType { + msg_type: *msg_type, + } + .raise(); + } + }) + } } impl FromBytes for Message { - fn from_bytes(buf: &[u8]) -> Result<(usize,Self),Error> { - if buf.is_empty() { return E::MessageEmpty {}.raise() } - let mut offset = 0; - let (s,nbytes) = varint::decode(&buf[offset..])?; - offset += s; - let msg_len = (nbytes as usize) + s; - let (s,msg_type) = varint::decode(&buf[offset..])?; - offset += s; - Ok(match msg_type { - 0 => { - if offset+4 > buf.len() { return E::MessageHashResponseEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - offset += 4; - let (s,hash_count) = varint::decode(&buf[offset..])?; - offset += s; - let mut hashes = Vec::with_capacity(hash_count as usize); - for _ in 0..hash_count { - if offset+32 > buf.len() { return E::MessageHashResponseEnd {}.raise() } - let mut hash = [0;32]; - hash.copy_from_slice(&buf[offset..offset+32]); - offset += 32; - hashes.push(hash); - } - (msg_len, Self::HashResponse { req_id, hashes }) - }, - 1 => { - if offset+4 > buf.len() { return E::MessageDataResponseEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - offset += 4; - let mut data = vec![]; - loop { - let (s,data_len) = varint::decode(&buf[offset..])?; - offset += s; - if data_len == 0 { break } - data.push(buf[offset..offset+(data_len as usize)].to_vec()); - offset += data_len as usize; - } - (msg_len, Self::DataResponse { req_id, data }) - }, - 2 => { - if offset+4 > buf.len() { return E::MessageHashRequestEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - offset += 4; - let (s,ttl) = varint::decode(&buf[offset..])?; - offset += s; - let (s,hash_count) = varint::decode(&buf[offset..])?; - offset += s; - let mut hashes = Vec::with_capacity(hash_count as usize); - for _ in 0..hash_count { - if offset+32 > buf.len() { return E::MessageHashRequestEnd {}.raise() } - let mut hash = [0;32]; - hash.copy_from_slice(&buf[offset..offset+32]); - offset += 32; - hashes.push(hash); + fn from_bytes(buf: &[u8]) -> Result<(usize, Self), Error> { + if buf.is_empty() { + return E::MessageEmpty {}.raise(); } - (msg_len, Self::HashRequest { req_id, ttl: ttl as usize, hashes }) - }, - 3 => { - if offset+4 > buf.len() { return E::MessageCancelRequestEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - (msg_len, Self::CancelRequest { req_id }) - }, - 4 => { - if offset+4 > buf.len() { return E::MessageChannelTimeRangeRequestEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - offset += 4; - let (s,ttl) = varint::decode(&buf[offset..])?; - offset += s; - let (s,channel_len) = varint::decode(&buf[offset..])?; - offset += s; - let channel = buf[offset..offset+channel_len as usize].to_vec(); - offset += channel_len as usize; - let (s,time_start) = varint::decode(&buf[offset..])?; - offset += s; - let (s,time_end) = varint::decode(&buf[offset..])?; + let mut offset = 0; + let (s, nbytes) = varint::decode(&buf[offset..])?; offset += s; - let (_,limit) = varint::decode(&buf[offset..])?; - //offset += s; - (msg_len, Self::ChannelTimeRangeRequest { - req_id, ttl: ttl as usize, channel, time_start, time_end, limit: limit as usize - }) - }, - 5 => { - if offset+4 > buf.len() { return E::MessageChannelStateRequestEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - offset += 4; - let (s,ttl) = varint::decode(&buf[offset..])?; - offset += s; - let (s,channel_len) = varint::decode(&buf[offset..])?; - offset += s; - let channel = buf[offset..offset+channel_len as usize].to_vec(); - offset += channel_len as usize; - let (s,limit) = varint::decode(&buf[offset..])?; - offset += s; - let (_,updates) = varint::decode(&buf[offset..])?; - (msg_len, Self::ChannelStateRequest { - req_id, ttl: ttl as usize, channel, limit: limit as usize, updates: updates as usize - }) - }, - 6 => { - if offset+4 > buf.len() { return E::MessageChannelListRequestEnd {}.raise() } - let mut req_id = [0;4]; - req_id.copy_from_slice(&buf[offset..offset+4]); - offset += 4; - let (s,ttl) = varint::decode(&buf[offset..])?; + let msg_len = (nbytes as usize) + s; + let (s, msg_type) = varint::decode(&buf[offset..])?; offset += s; - let (_,limit) = varint::decode(&buf[offset..])?; - //offset += s; - (msg_len, Self::ChannelListRequest { - req_id, ttl: ttl as usize, limit: limit as usize + Ok(match msg_type { + 0 => { + if offset + 4 > buf.len() { + return E::MessageHashResponseEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + offset += 4; + let (s, hash_count) = varint::decode(&buf[offset..])?; + offset += s; + let mut hashes = Vec::with_capacity(hash_count as usize); + for _ in 0..hash_count { + if offset + 32 > buf.len() { + return E::MessageHashResponseEnd {}.raise(); + } + let mut hash = [0; 32]; + hash.copy_from_slice(&buf[offset..offset + 32]); + offset += 32; + hashes.push(hash); + } + (msg_len, Self::HashResponse { req_id, hashes }) + } + 1 => { + if offset + 4 > buf.len() { + return E::MessageDataResponseEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + offset += 4; + let mut data = vec![]; + loop { + let (s, data_len) = varint::decode(&buf[offset..])?; + offset += s; + if data_len == 0 { + break; + } + data.push(buf[offset..offset + (data_len as usize)].to_vec()); + offset += data_len as usize; + } + (msg_len, Self::DataResponse { req_id, data }) + } + 2 => { + if offset + 4 > buf.len() { + return E::MessageHashRequestEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + offset += 4; + let (s, ttl) = varint::decode(&buf[offset..])?; + offset += s; + let (s, hash_count) = varint::decode(&buf[offset..])?; + offset += s; + let mut hashes = Vec::with_capacity(hash_count as usize); + for _ in 0..hash_count { + if offset + 32 > buf.len() { + return E::MessageHashRequestEnd {}.raise(); + } + let mut hash = [0; 32]; + hash.copy_from_slice(&buf[offset..offset + 32]); + offset += 32; + hashes.push(hash); + } + ( + msg_len, + Self::HashRequest { + req_id, + ttl: ttl as usize, + hashes, + }, + ) + } + 3 => { + if offset + 4 > buf.len() { + return E::MessageCancelRequestEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + (msg_len, Self::CancelRequest { req_id }) + } + 4 => { + if offset + 4 > buf.len() { + return E::MessageChannelTimeRangeRequestEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + offset += 4; + let (s, ttl) = varint::decode(&buf[offset..])?; + offset += s; + let (s, channel_len) = varint::decode(&buf[offset..])?; + offset += s; + let channel = buf[offset..offset + channel_len as usize].to_vec(); + offset += channel_len as usize; + let (s, time_start) = varint::decode(&buf[offset..])?; + offset += s; + let (s, time_end) = varint::decode(&buf[offset..])?; + offset += s; + let (_, limit) = varint::decode(&buf[offset..])?; + //offset += s; + ( + msg_len, + Self::ChannelTimeRangeRequest { + req_id, + ttl: ttl as usize, + channel, + time_start, + time_end, + limit: limit as usize, + }, + ) + } + 5 => { + if offset + 4 > buf.len() { + return E::MessageChannelStateRequestEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + offset += 4; + let (s, ttl) = varint::decode(&buf[offset..])?; + offset += s; + let (s, channel_len) = varint::decode(&buf[offset..])?; + offset += s; + let channel = buf[offset..offset + channel_len as usize].to_vec(); + offset += channel_len as usize; + let (s, limit) = varint::decode(&buf[offset..])?; + offset += s; + let (_, updates) = varint::decode(&buf[offset..])?; + ( + msg_len, + Self::ChannelStateRequest { + req_id, + ttl: ttl as usize, + channel, + limit: limit as usize, + updates: updates as usize, + }, + ) + } + 6 => { + if offset + 4 > buf.len() { + return E::MessageChannelListRequestEnd {}.raise(); + } + let mut req_id = [0; 4]; + req_id.copy_from_slice(&buf[offset..offset + 4]); + offset += 4; + let (s, ttl) = varint::decode(&buf[offset..])?; + offset += s; + let (_, limit) = varint::decode(&buf[offset..])?; + //offset += s; + ( + msg_len, + Self::ChannelListRequest { + req_id, + ttl: ttl as usize, + limit: limit as usize, + }, + ) + } + msg_type => (msg_len, Self::Unrecognized { msg_type }), }) - }, - msg_type => (msg_len, Self::Unrecognized { msg_type }) - }) - } + } } diff --git a/src/post.rs b/src/post.rs index a478a06..5fc7639 100644 --- a/src/post.rs +++ b/src/post.rs @@ -1,316 +1,368 @@ -use desert::{FromBytes,ToBytes,CountBytes,varint}; -use crate::{Error,Hash,Channel,error::CableErrorKind as E}; +use crate::{error::CableErrorKind as E, Channel, Error, Hash}; +use desert::{varint, CountBytes, FromBytes, ToBytes}; use sodiumoxide::crypto; use std::convert::TryInto; -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct Post { - pub header: PostHeader, - pub body: PostBody, + pub header: PostHeader, + pub body: PostBody, } -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct PostHeader { - pub public_key: [u8;32], - pub signature: [u8;64], - pub link: [u8;32], + pub public_key: [u8; 32], + pub signature: [u8; 64], + pub link: [u8; 32], } -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub enum PostBody { - Text { - channel: Channel, - timestamp: u64, - text: Vec, - }, - Delete { - timestamp: u64, - hash: Hash, - }, - Info { - timestamp: u64, - key: Vec, - value: Vec, - }, - Topic { - channel: Channel, - timestamp: u64, - topic: Vec, - }, - Join { - channel: Channel, - timestamp: u64, - }, - Leave { - channel: Channel, - timestamp: u64, - }, - Unrecognized { - post_type: u64 - }, + Text { + channel: Channel, + timestamp: u64, + text: Vec, + }, + Delete { + timestamp: u64, + hash: Hash, + }, + Info { + timestamp: u64, + key: Vec, + value: Vec, + }, + Topic { + channel: Channel, + timestamp: u64, + topic: Vec, + }, + Join { + channel: Channel, + timestamp: u64, + }, + Leave { + channel: Channel, + timestamp: u64, + }, + Unrecognized { + post_type: u64, + }, } impl Post { - pub fn post_type(&self) -> u64 { - match &self.body { - PostBody::Text { .. } => 0, - PostBody::Delete { .. } => 1, - PostBody::Info { .. } => 2, - PostBody::Topic { .. } => 3, - PostBody::Join { .. } => 4, - PostBody::Leave { .. } => 5, - PostBody::Unrecognized { post_type } => *post_type, + pub fn post_type(&self) -> u64 { + match &self.body { + PostBody::Text { .. } => 0, + PostBody::Delete { .. } => 1, + PostBody::Info { .. } => 2, + PostBody::Topic { .. } => 3, + PostBody::Join { .. } => 4, + PostBody::Leave { .. } => 5, + PostBody::Unrecognized { post_type } => *post_type, + } } - } - pub fn verify(buf: &[u8]) -> bool { - if buf.len() < 32+64 { return false } - let o_pk = crypto::sign::PublicKey::from_slice(&buf[0..32]); - let o_sig = crypto::sign::Signature::from_bytes(&buf[32..32+64]); - match (o_pk,o_sig) { - (Some(pk),Ok(sig)) => crypto::sign::verify_detached(&sig, &buf[32+64..], &pk), - _ => false, + pub fn verify(buf: &[u8]) -> bool { + if buf.len() < 32 + 64 { + return false; + } + let o_pk = crypto::sign::PublicKey::from_slice(&buf[0..32]); + let o_sig = crypto::sign::Signature::from_bytes(&buf[32..32 + 64]); + match (o_pk, o_sig) { + (Some(pk), Ok(sig)) => crypto::sign::verify_detached(&sig, &buf[32 + 64..], &pk), + _ => false, + } } - } - pub fn sign(&mut self, secret_key: &[u8;64]) -> Result<(),Error> { - let buf = self.to_bytes()?; - let sk = crypto::sign::SecretKey::from_slice(secret_key).unwrap(); - // todo: return NoneError - self.header.signature = crypto::sign::sign_detached(&buf[32+64..], &sk).to_bytes(); - Ok(()) - } - pub fn is_signed(&self) -> bool { - for i in 0..self.header.signature.len() { - if self.header.signature[i] != 0 { - return true; - } + pub fn sign(&mut self, secret_key: &[u8; 64]) -> Result<(), Error> { + let buf = self.to_bytes()?; + let sk = crypto::sign::SecretKey::from_slice(secret_key).unwrap(); + // todo: return NoneError + self.header.signature = crypto::sign::sign_detached(&buf[32 + 64..], &sk).to_bytes(); + Ok(()) } - return false; - } - pub fn hash(&self) -> Result { - let buf = self.to_bytes()?; - let digest = crypto::generichash::hash(&buf, Some(32), None).unwrap(); - Ok(digest.as_ref().try_into()?) - } - pub fn get_timestamp(&self) -> Option { - match &self.body { - PostBody::Text { timestamp, .. } => Some(*timestamp), - PostBody::Delete { timestamp, .. } => Some(*timestamp), - PostBody::Info { timestamp, .. } => Some(*timestamp), - PostBody::Topic { timestamp, .. } => Some(*timestamp), - PostBody::Join { timestamp, .. } => Some(*timestamp), - PostBody::Leave { timestamp, .. } => Some(*timestamp), - PostBody::Unrecognized { .. } => None, + pub fn is_signed(&self) -> bool { + for i in 0..self.header.signature.len() { + if self.header.signature[i] != 0 { + return true; + } + } + return false; } - } - pub fn get_channel<'a>(&'a self) -> Option<&'a Channel> { - match &self.body { - PostBody::Text { channel, .. } => Some(channel), - PostBody::Delete { .. } => None, - PostBody::Info { .. } => None, - PostBody::Topic { channel, .. } => Some(channel), - PostBody::Join { channel, .. } => Some(channel), - PostBody::Leave { channel, .. } => Some(channel), - PostBody::Unrecognized { .. } => None, + pub fn hash(&self) -> Result { + let buf = self.to_bytes()?; + let digest = crypto::generichash::hash(&buf, Some(32), None).unwrap(); + Ok(digest.as_ref().try_into()?) + } + pub fn get_timestamp(&self) -> Option { + match &self.body { + PostBody::Text { timestamp, .. } => Some(*timestamp), + PostBody::Delete { timestamp, .. } => Some(*timestamp), + PostBody::Info { timestamp, .. } => Some(*timestamp), + PostBody::Topic { timestamp, .. } => Some(*timestamp), + PostBody::Join { timestamp, .. } => Some(*timestamp), + PostBody::Leave { timestamp, .. } => Some(*timestamp), + PostBody::Unrecognized { .. } => None, + } + } + pub fn get_channel<'a>(&'a self) -> Option<&'a Channel> { + match &self.body { + PostBody::Text { channel, .. } => Some(channel), + PostBody::Delete { .. } => None, + PostBody::Info { .. } => None, + PostBody::Topic { channel, .. } => Some(channel), + PostBody::Join { channel, .. } => Some(channel), + PostBody::Leave { channel, .. } => Some(channel), + PostBody::Unrecognized { .. } => None, + } } - } } impl CountBytes for Post { - fn count_bytes(&self) -> usize { - let post_type = self.post_type(); - let header_size = 32 + 64 + 32; - let body_size = varint::length(post_type) + match &self.body { - PostBody::Text { channel, timestamp, text } => { - varint::length(channel.len() as u64) + channel.len() - + varint::length(*timestamp) - + varint::length(text.len() as u64) + text.len() - }, - PostBody::Delete { timestamp, hash } => { - varint::length(*timestamp) + hash.len() - }, - PostBody::Info { timestamp, key, value } => { - varint::length(*timestamp) - + varint::length(key.len() as u64) + key.len() - + varint::length(value.len() as u64) + value.len() - }, - PostBody::Topic { channel, timestamp, topic } => { - varint::length(channel.len() as u64) + channel.len() - + varint::length(*timestamp) - + varint::length(topic.len() as u64) + topic.len() - }, - PostBody::Join { channel, timestamp } => { - varint::length(channel.len() as u64) + channel.len() + varint::length(*timestamp) - }, - PostBody::Leave { channel, timestamp } => { - varint::length(channel.len() as u64) + channel.len() + varint::length(*timestamp) - }, - PostBody::Unrecognized { .. } => 0, - }; - header_size + body_size - } - fn count_from_bytes(_buf: &[u8]) -> Result { - unimplemented![] - } + fn count_bytes(&self) -> usize { + let post_type = self.post_type(); + let header_size = 32 + 64 + 32; + let body_size = varint::length(post_type) + + match &self.body { + PostBody::Text { + channel, + timestamp, + text, + } => { + varint::length(channel.len() as u64) + + channel.len() + + varint::length(*timestamp) + + varint::length(text.len() as u64) + + text.len() + } + PostBody::Delete { timestamp, hash } => varint::length(*timestamp) + hash.len(), + PostBody::Info { + timestamp, + key, + value, + } => { + varint::length(*timestamp) + + varint::length(key.len() as u64) + + key.len() + + varint::length(value.len() as u64) + + value.len() + } + PostBody::Topic { + channel, + timestamp, + topic, + } => { + varint::length(channel.len() as u64) + + channel.len() + + varint::length(*timestamp) + + varint::length(topic.len() as u64) + + topic.len() + } + PostBody::Join { channel, timestamp } => { + varint::length(channel.len() as u64) + + channel.len() + + varint::length(*timestamp) + } + PostBody::Leave { channel, timestamp } => { + varint::length(channel.len() as u64) + + channel.len() + + varint::length(*timestamp) + } + PostBody::Unrecognized { .. } => 0, + }; + header_size + body_size + } + fn count_from_bytes(_buf: &[u8]) -> Result { + unimplemented![] + } } impl ToBytes for Post { - fn to_bytes(&self) -> Result,Error> { - let mut buf = vec![0;self.count_bytes()]; - self.write_bytes(&mut buf)?; - Ok(buf) - } - fn write_bytes(&self, buf: &mut [u8]) -> Result { - let mut offset = 0; - assert_eq![self.header.public_key.len(), 32]; - assert_eq![self.header.signature.len(), 64]; - assert_eq![self.header.link.len(), 32]; - buf[offset..offset+32].copy_from_slice(&self.header.public_key); - offset += self.header.public_key.len(); - buf[offset..offset+64].copy_from_slice(&self.header.signature); - offset += self.header.signature.len(); - buf[offset..offset+32].copy_from_slice(&self.header.link); - offset += self.header.link.len(); - offset += varint::encode(self.post_type(), &mut buf[offset..])?; - match &self.body { - PostBody::Text { channel, timestamp, text } => { - offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; - buf[offset..offset+channel.len()].copy_from_slice(channel); - offset += channel.len(); - offset += varint::encode(*timestamp, &mut buf[offset..])?; - offset += varint::encode(text.len() as u64, &mut buf[offset..])?; - buf[offset..offset+text.len()].copy_from_slice(text); - offset += text.len(); - }, - PostBody::Delete { timestamp, hash } => { - offset += varint::encode(*timestamp, &mut buf[offset..])?; - buf[offset..offset+hash.len()].copy_from_slice(hash); - offset += hash.len(); - }, - PostBody::Info { timestamp, key, value } => { - offset += varint::encode(*timestamp, &mut buf[offset..])?; - offset += varint::encode(key.len() as u64, &mut buf[offset..])?; - buf[offset..offset+key.len()].copy_from_slice(key); - offset += key.len(); - offset += varint::encode(value.len() as u64, &mut buf[offset..])?; - buf[offset..offset+value.len()].copy_from_slice(value); - offset += value.len(); - }, - PostBody::Topic { channel, timestamp, topic } => { - offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; - buf[offset..offset+channel.len()].copy_from_slice(channel); - offset += channel.len(); - offset += varint::encode(*timestamp, &mut buf[offset..])?; - offset += varint::encode(topic.len() as u64, &mut buf[offset..])?; - buf[offset..offset+topic.len()].copy_from_slice(topic); - offset += topic.len(); - }, - PostBody::Join { channel, timestamp } => { - offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; - buf[offset..offset+channel.len()].copy_from_slice(channel); - offset += channel.len(); - offset += varint::encode(*timestamp, &mut buf[offset..])?; - }, - PostBody::Leave { channel, timestamp } => { - offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; - buf[offset..offset+channel.len()].copy_from_slice(channel); - offset += channel.len(); - offset += varint::encode(*timestamp, &mut buf[offset..])?; - }, - PostBody::Unrecognized { post_type } => { - return E::PostWriteUnrecognizedType { post_type: *post_type }.raise(); - }, + fn to_bytes(&self) -> Result, Error> { + let mut buf = vec![0; self.count_bytes()]; + self.write_bytes(&mut buf)?; + Ok(buf) + } + fn write_bytes(&self, buf: &mut [u8]) -> Result { + let mut offset = 0; + assert_eq![self.header.public_key.len(), 32]; + assert_eq![self.header.signature.len(), 64]; + assert_eq![self.header.link.len(), 32]; + buf[offset..offset + 32].copy_from_slice(&self.header.public_key); + offset += self.header.public_key.len(); + buf[offset..offset + 64].copy_from_slice(&self.header.signature); + offset += self.header.signature.len(); + buf[offset..offset + 32].copy_from_slice(&self.header.link); + offset += self.header.link.len(); + offset += varint::encode(self.post_type(), &mut buf[offset..])?; + match &self.body { + PostBody::Text { + channel, + timestamp, + text, + } => { + offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; + buf[offset..offset + channel.len()].copy_from_slice(channel); + offset += channel.len(); + offset += varint::encode(*timestamp, &mut buf[offset..])?; + offset += varint::encode(text.len() as u64, &mut buf[offset..])?; + buf[offset..offset + text.len()].copy_from_slice(text); + offset += text.len(); + } + PostBody::Delete { timestamp, hash } => { + offset += varint::encode(*timestamp, &mut buf[offset..])?; + buf[offset..offset + hash.len()].copy_from_slice(hash); + offset += hash.len(); + } + PostBody::Info { + timestamp, + key, + value, + } => { + offset += varint::encode(*timestamp, &mut buf[offset..])?; + offset += varint::encode(key.len() as u64, &mut buf[offset..])?; + buf[offset..offset + key.len()].copy_from_slice(key); + offset += key.len(); + offset += varint::encode(value.len() as u64, &mut buf[offset..])?; + buf[offset..offset + value.len()].copy_from_slice(value); + offset += value.len(); + } + PostBody::Topic { + channel, + timestamp, + topic, + } => { + offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; + buf[offset..offset + channel.len()].copy_from_slice(channel); + offset += channel.len(); + offset += varint::encode(*timestamp, &mut buf[offset..])?; + offset += varint::encode(topic.len() as u64, &mut buf[offset..])?; + buf[offset..offset + topic.len()].copy_from_slice(topic); + offset += topic.len(); + } + PostBody::Join { channel, timestamp } => { + offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; + buf[offset..offset + channel.len()].copy_from_slice(channel); + offset += channel.len(); + offset += varint::encode(*timestamp, &mut buf[offset..])?; + } + PostBody::Leave { channel, timestamp } => { + offset += varint::encode(channel.len() as u64, &mut buf[offset..])?; + buf[offset..offset + channel.len()].copy_from_slice(channel); + offset += channel.len(); + offset += varint::encode(*timestamp, &mut buf[offset..])?; + } + PostBody::Unrecognized { post_type } => { + return E::PostWriteUnrecognizedType { + post_type: *post_type, + } + .raise(); + } + } + Ok(offset) } - Ok(offset) - } } impl FromBytes for Post { - fn from_bytes(buf: &[u8]) -> Result<(usize,Self),Error> { - let mut offset = 0; - let header = { - let mut public_key = [0;32]; - public_key.copy_from_slice(&buf[offset..offset+32]); - offset += 32; - let mut signature = [0;64]; - signature.copy_from_slice(&buf[offset..offset+64]); - offset += 64; - let mut link = [0;32]; - link.copy_from_slice(&buf[offset..offset+32]); - offset += 32; - PostHeader { public_key, signature, link } - }; - let (s,post_type) = varint::decode(&buf[offset..])?; - offset += s; - let body = match post_type { - 0 => { - let (s,channel_len) = varint::decode(&buf[offset..])?; - offset += s; - let channel = buf[offset..offset+channel_len as usize].to_vec(); - offset += channel_len as usize; - let (s,timestamp) = varint::decode(&buf[offset..])?; - offset += s; - let (s,text_len) = varint::decode(&buf[offset..])?; - offset += s; - let text = buf[offset..offset+text_len as usize].to_vec(); - offset += text_len as usize; - PostBody::Text { channel, timestamp, text } - }, - 1 => { - let (s,timestamp) = varint::decode(&buf[offset..])?; - offset += s; - let mut hash = [0;32]; - hash.copy_from_slice(&buf[offset..offset+32]); - offset += 32; - PostBody::Delete { timestamp, hash } - }, - 2 => { - let (s,timestamp) = varint::decode(&buf[offset..])?; - offset += s; - let (s,key_len) = varint::decode(&buf[offset..])?; + fn from_bytes(buf: &[u8]) -> Result<(usize, Self), Error> { + let mut offset = 0; + let header = { + let mut public_key = [0; 32]; + public_key.copy_from_slice(&buf[offset..offset + 32]); + offset += 32; + let mut signature = [0; 64]; + signature.copy_from_slice(&buf[offset..offset + 64]); + offset += 64; + let mut link = [0; 32]; + link.copy_from_slice(&buf[offset..offset + 32]); + offset += 32; + PostHeader { + public_key, + signature, + link, + } + }; + let (s, post_type) = varint::decode(&buf[offset..])?; offset += s; - let key = buf[offset..offset+key_len as usize].to_vec(); - offset += key_len as usize; - let (s,value_len) = varint::decode(&buf[offset..])?; - offset += s; - let value = buf[offset..offset+value_len as usize].to_vec(); - offset += value_len as usize; - PostBody::Info { timestamp, key, value } - }, - 3 => { - let (s,channel_len) = varint::decode(&buf[offset..])?; - offset += s; - let channel = buf[offset..offset+channel_len as usize].to_vec(); - offset += channel_len as usize; - let (s,timestamp) = varint::decode(&buf[offset..])?; - offset += s; - let (s,topic_len) = varint::decode(&buf[offset..])?; - offset += s; - let topic = buf[offset..offset+topic_len as usize].to_vec(); - offset += topic_len as usize; - PostBody::Topic { channel, timestamp, topic } - }, - 4 => { - let (s,channel_len) = varint::decode(&buf[offset..])?; - offset += s; - let channel = buf[offset..offset+channel_len as usize].to_vec(); - offset += channel_len as usize; - let (s,timestamp) = varint::decode(&buf[offset..])?; - offset += s; - PostBody::Join { channel, timestamp } - }, - 5 => { - let (s,channel_len) = varint::decode(&buf[offset..])?; - offset += s; - let channel = buf[offset..offset+channel_len as usize].to_vec(); - offset += channel_len as usize; - let (s,timestamp) = varint::decode(&buf[offset..])?; - offset += s; - PostBody::Leave { channel, timestamp } - }, - post_type => { - PostBody::Unrecognized { post_type } - }, - }; - Ok((offset, Post { header, body })) - } + let body = match post_type { + 0 => { + let (s, channel_len) = varint::decode(&buf[offset..])?; + offset += s; + let channel = buf[offset..offset + channel_len as usize].to_vec(); + offset += channel_len as usize; + let (s, timestamp) = varint::decode(&buf[offset..])?; + offset += s; + let (s, text_len) = varint::decode(&buf[offset..])?; + offset += s; + let text = buf[offset..offset + text_len as usize].to_vec(); + offset += text_len as usize; + PostBody::Text { + channel, + timestamp, + text, + } + } + 1 => { + let (s, timestamp) = varint::decode(&buf[offset..])?; + offset += s; + let mut hash = [0; 32]; + hash.copy_from_slice(&buf[offset..offset + 32]); + offset += 32; + PostBody::Delete { timestamp, hash } + } + 2 => { + let (s, timestamp) = varint::decode(&buf[offset..])?; + offset += s; + let (s, key_len) = varint::decode(&buf[offset..])?; + offset += s; + let key = buf[offset..offset + key_len as usize].to_vec(); + offset += key_len as usize; + let (s, value_len) = varint::decode(&buf[offset..])?; + offset += s; + let value = buf[offset..offset + value_len as usize].to_vec(); + offset += value_len as usize; + PostBody::Info { + timestamp, + key, + value, + } + } + 3 => { + let (s, channel_len) = varint::decode(&buf[offset..])?; + offset += s; + let channel = buf[offset..offset + channel_len as usize].to_vec(); + offset += channel_len as usize; + let (s, timestamp) = varint::decode(&buf[offset..])?; + offset += s; + let (s, topic_len) = varint::decode(&buf[offset..])?; + offset += s; + let topic = buf[offset..offset + topic_len as usize].to_vec(); + offset += topic_len as usize; + PostBody::Topic { + channel, + timestamp, + topic, + } + } + 4 => { + let (s, channel_len) = varint::decode(&buf[offset..])?; + offset += s; + let channel = buf[offset..offset + channel_len as usize].to_vec(); + offset += channel_len as usize; + let (s, timestamp) = varint::decode(&buf[offset..])?; + offset += s; + PostBody::Join { channel, timestamp } + } + 5 => { + let (s, channel_len) = varint::decode(&buf[offset..])?; + offset += s; + let channel = buf[offset..offset + channel_len as usize].to_vec(); + offset += channel_len as usize; + let (s, timestamp) = varint::decode(&buf[offset..])?; + offset += s; + PostBody::Leave { channel, timestamp } + } + post_type => PostBody::Unrecognized { post_type }, + }; + Ok((offset, Post { header, body })) + } } diff --git a/src/store.rs b/src/store.rs index bcf30fb..0f5181b 100644 --- a/src/store.rs +++ b/src/store.rs @@ -1,199 +1,221 @@ use crate::{ - Error,Post,PostBody,Channel,Hash,Payload,ChannelOptions, - stream::{LiveStream,PostStream,HashStream}, + stream::{HashStream, LiveStream, PostStream}, + Channel, ChannelOptions, Error, Hash, Payload, Post, PostBody, }; +use async_std::{ + prelude::*, + stream, + sync::{Arc, Mutex, RwLock}, + task, +}; +use desert::ToBytes; use sodiumoxide::crypto; +use std::collections::{BTreeMap, HashMap}; use std::convert::TryInto; -use std::collections::{HashMap,BTreeMap}; -use async_std::{prelude::*,task,stream,sync::{Arc,RwLock,Mutex}}; -use desert::ToBytes; -pub type Keypair = ([u8;32],[u8;64]); +pub type Keypair = ([u8; 32], [u8; 64]); pub type GetPostOptions = ChannelOptions; #[async_trait::async_trait] -pub trait Store: Clone+Send+Sync+Unpin+'static { - async fn get_keypair(&mut self) -> Result,Error>; - async fn set_keypair(&mut self, keypair: Keypair) -> Result<(),Error>; - async fn get_or_create_keypair(&mut self) -> Result { - if let Some(kp) = self.get_keypair().await? { - Ok(kp) - } else { - let (pk,sk) = crypto::sign::gen_keypair(); - let kp = ( - pk.as_ref().try_into().unwrap(), - sk.as_ref().try_into().unwrap() - ); - self.set_keypair(kp.clone()).await?; - Ok(kp) +pub trait Store: Clone + Send + Sync + Unpin + 'static { + async fn get_keypair(&mut self) -> Result, Error>; + async fn set_keypair(&mut self, keypair: Keypair) -> Result<(), Error>; + async fn get_or_create_keypair(&mut self) -> Result { + if let Some(kp) = self.get_keypair().await? { + Ok(kp) + } else { + let (pk, sk) = crypto::sign::gen_keypair(); + let kp = ( + pk.as_ref().try_into().unwrap(), + sk.as_ref().try_into().unwrap(), + ); + self.set_keypair(kp.clone()).await?; + Ok(kp) + } } - } - async fn get_latest_hash(&mut self, channel: &[u8]) -> Result<[u8;32],Error>; - async fn insert_post(&mut self, post: &Post) -> Result<(),Error>; - async fn get_posts<'a>(&'a mut self, opts: &GetPostOptions) -> Result; - async fn get_posts_live<'a>(&'a mut self, opts: &GetPostOptions) -> Result; - async fn get_post_hashes<'a>(&'a mut self, opts: &GetPostOptions) -> Result; - async fn want(&mut self, hashes: &[Hash]) -> Result,Error>; - async fn get_data(&mut self, hashes: &[Hash]) -> Result,Error>; + async fn get_latest_hash(&mut self, channel: &[u8]) -> Result<[u8; 32], Error>; + async fn insert_post(&mut self, post: &Post) -> Result<(), Error>; + async fn get_posts<'a>(&'a mut self, opts: &GetPostOptions) -> Result; + async fn get_posts_live<'a>(&'a mut self, opts: &GetPostOptions) -> Result; + async fn get_post_hashes<'a>(&'a mut self, opts: &GetPostOptions) -> Result; + async fn want(&mut self, hashes: &[Hash]) -> Result, Error>; + async fn get_data(&mut self, hashes: &[Hash]) -> Result, Error>; } #[derive(Clone)] pub struct MemoryStore { - keypair: Keypair, - posts: Arc>>>>, - post_hashes: Arc>>>>, - data: Arc>>, - empty_post_bt: BTreeMap>, - empty_hash_bt: BTreeMap>, - live_streams: Arc>>>>>, - live_stream_id: Arc>, + keypair: Keypair, + posts: Arc>>>>, + post_hashes: Arc>>>>, + data: Arc>>, + empty_post_bt: BTreeMap>, + empty_hash_bt: BTreeMap>, + live_streams: Arc>>>>>, + live_stream_id: Arc>, } impl Default for MemoryStore { - fn default() -> Self { - let (pk,sk) = crypto::sign::gen_keypair(); - Self { - keypair: ( - pk.as_ref().try_into().unwrap(), - sk.as_ref().try_into().unwrap() - ), - posts: Arc::new(RwLock::new(HashMap::new())), - post_hashes: Arc::new(RwLock::new(HashMap::new())), - data: Arc::new(RwLock::new(HashMap::new())), - empty_post_bt: BTreeMap::new(), - empty_hash_bt: BTreeMap::new(), - live_streams: Arc::new(RwLock::new(HashMap::new())), - live_stream_id: Arc::new(Mutex::new(0)), + fn default() -> Self { + let (pk, sk) = crypto::sign::gen_keypair(); + Self { + keypair: ( + pk.as_ref().try_into().unwrap(), + sk.as_ref().try_into().unwrap(), + ), + posts: Arc::new(RwLock::new(HashMap::new())), + post_hashes: Arc::new(RwLock::new(HashMap::new())), + data: Arc::new(RwLock::new(HashMap::new())), + empty_post_bt: BTreeMap::new(), + empty_hash_bt: BTreeMap::new(), + live_streams: Arc::new(RwLock::new(HashMap::new())), + live_stream_id: Arc::new(Mutex::new(0)), + } } - } } #[async_trait::async_trait] impl Store for MemoryStore { - async fn get_keypair(&mut self) -> Result,Error> { - Ok(Some(self.keypair.clone())) - } - async fn set_keypair(&mut self, keypair: Keypair) -> Result<(),Error> { - self.keypair = keypair; - Ok(()) - } - async fn get_latest_hash(&mut self, _channel: &[u8]) -> Result<[u8;32],Error> { - // todo: actually use latest message if available instead of zeros - Ok([0;32]) - } - async fn insert_post(&mut self, post: &Post) -> Result<(),Error> { - match &post.body { - PostBody::Text { channel, timestamp, .. } => { - { - let mut posts = self.posts.write().await; - if let Some(post_map) = posts.get_mut(channel) { - if let Some(posts) = post_map.get_mut(timestamp) { - posts.push(post.clone()); - } else { - post_map.insert(*timestamp, vec![post.clone()]); - } - } else { - let mut post_map = BTreeMap::new(); - post_map.insert(*timestamp, vec![post.clone()]); - posts.insert(channel.to_vec(), post_map); - } - } - { - let mut post_hashes = self.post_hashes.write().await; - if let Some(hash_map) = post_hashes.get_mut(channel) { - if let Some(hashes) = hash_map.get_mut(timestamp) { - hashes.push(post.hash()?); - } else { - let hash = post.hash()?; - hash_map.insert(*timestamp, vec![hash.clone()]); - self.data.write().await.insert(hash, post.to_bytes()?); - } - } else { - let mut hash_map = BTreeMap::new(); - let hash = post.hash()?; - hash_map.insert(*timestamp, vec![hash.clone()]); - post_hashes.insert(channel.to_vec(), hash_map); - self.data.write().await.insert(hash, post.to_bytes()?); - } - } - if let Some(senders) = self.live_streams.read().await.get(channel) { - for stream in senders.write().await.iter_mut() { - if stream.matches(&post) { - stream.send(post.clone()).await; + async fn get_keypair(&mut self) -> Result, Error> { + Ok(Some(self.keypair.clone())) + } + async fn set_keypair(&mut self, keypair: Keypair) -> Result<(), Error> { + self.keypair = keypair; + Ok(()) + } + async fn get_latest_hash(&mut self, _channel: &[u8]) -> Result<[u8; 32], Error> { + // todo: actually use latest message if available instead of zeros + Ok([0; 32]) + } + async fn insert_post(&mut self, post: &Post) -> Result<(), Error> { + match &post.body { + PostBody::Text { + channel, timestamp, .. + } => { + { + let mut posts = self.posts.write().await; + if let Some(post_map) = posts.get_mut(channel) { + if let Some(posts) = post_map.get_mut(timestamp) { + posts.push(post.clone()); + } else { + post_map.insert(*timestamp, vec![post.clone()]); + } + } else { + let mut post_map = BTreeMap::new(); + post_map.insert(*timestamp, vec![post.clone()]); + posts.insert(channel.to_vec(), post_map); + } + } + { + let mut post_hashes = self.post_hashes.write().await; + if let Some(hash_map) = post_hashes.get_mut(channel) { + if let Some(hashes) = hash_map.get_mut(timestamp) { + hashes.push(post.hash()?); + } else { + let hash = post.hash()?; + hash_map.insert(*timestamp, vec![hash.clone()]); + self.data.write().await.insert(hash, post.to_bytes()?); + } + } else { + let mut hash_map = BTreeMap::new(); + let hash = post.hash()?; + hash_map.insert(*timestamp, vec![hash.clone()]); + post_hashes.insert(channel.to_vec(), hash_map); + self.data.write().await.insert(hash, post.to_bytes()?); + } + } + if let Some(senders) = self.live_streams.read().await.get(channel) { + for stream in senders.write().await.iter_mut() { + if stream.matches(&post) { + stream.send(post.clone()).await; + } + } + } } - } + _ => {} } - }, - _ => {}, + Ok(()) + } + async fn get_posts(&mut self, opts: &GetPostOptions) -> Result { + let posts = self + .posts + .write() + .await + .get(&opts.channel) + .unwrap_or(&self.empty_post_bt) + .range(opts.time_start..opts.time_end) + .flat_map(|(_time, posts)| posts.iter().map(|post| Ok(post.clone()))) + .collect::>>(); + Ok(Box::new(stream::from_iter(posts.into_iter()))) } - Ok(()) - } - async fn get_posts(&mut self, opts: &GetPostOptions) -> Result { - let posts = self.posts.write().await.get(&opts.channel) - .unwrap_or(&self.empty_post_bt) - .range(opts.time_start..opts.time_end) - .flat_map(|(_time,posts)| posts.iter().map(|post| Ok(post.clone()))) - .collect::>>(); - Ok(Box::new(stream::from_iter(posts.into_iter()))) - } - async fn get_posts_live(&mut self, opts: &GetPostOptions) -> Result { - let live_stream = { - let mut live_streams = self.live_streams.write().await; - if let Some(streams) = live_streams.get_mut(&opts.channel) { + async fn get_posts_live(&mut self, opts: &GetPostOptions) -> Result { let live_stream = { - let mut id = self.live_stream_id.lock().await; - *id += 1; - LiveStream::new(*id, opts.clone(), streams.clone()) - }; - let live = live_stream.clone(); - task::block_on(async move { - streams.write().await.push(live); - }); - live_stream - } else { - let streams = Arc::new(RwLock::new(vec![])); - let live_stream_id = { - let mut id_r = self.live_stream_id.lock().await; - let id = *id_r; - *id_r += 1; - id + let mut live_streams = self.live_streams.write().await; + if let Some(streams) = live_streams.get_mut(&opts.channel) { + let live_stream = { + let mut id = self.live_stream_id.lock().await; + *id += 1; + LiveStream::new(*id, opts.clone(), streams.clone()) + }; + let live = live_stream.clone(); + task::block_on(async move { + streams.write().await.push(live); + }); + live_stream + } else { + let streams = Arc::new(RwLock::new(vec![])); + let live_stream_id = { + let mut id_r = self.live_stream_id.lock().await; + let id = *id_r; + *id_r += 1; + id + }; + let streams_c = streams.clone(); + let live_stream = task::block_on(async move { + let live_stream = + LiveStream::new(live_stream_id, opts.clone(), streams_c.clone()); + streams_c.write().await.push(live_stream.clone()); + live_stream + }); + live_streams.insert(opts.channel.clone(), streams); + live_stream + } }; - let streams_c = streams.clone(); - let live_stream = task::block_on(async move { - let live_stream = LiveStream::new(live_stream_id, opts.clone(), streams_c.clone()); - streams_c.write().await.push(live_stream.clone()); - live_stream - }); - live_streams.insert(opts.channel.clone(), streams); - live_stream - } - }; - let post_stream = self.get_posts(opts).await?; - Ok(Box::new(post_stream.merge(live_stream))) - } - async fn get_post_hashes(&mut self, opts: &GetPostOptions) -> Result { - let start = opts.time_start; - let end = opts.time_end; - let empty = self.empty_hash_bt.range(..); - let hashes = self.post_hashes.read().await.get(&opts.channel) - .map(|x| { - match (start,end) { - (0,0) => x.range(..), - (0,end) => x.range(..end), - (start,0) => x.range(start..), - _ => x.range(start..end), - } - }) - .unwrap_or(empty) - .flat_map(|(_time,hashes)| hashes.iter().map(|hash| Ok(*hash))) - .collect::>>(); - Ok(Box::new(stream::from_iter(hashes.into_iter()))) - } - async fn want(&mut self, hashes: &[Hash]) -> Result,Error> { - let data = self.data.read().await; - Ok(hashes.iter().filter(|hash| !data.contains_key(hash.clone())).cloned().collect()) - } - async fn get_data(&mut self, hashes: &[Hash]) -> Result,Error> { - let data = self.data.read().await; - Ok(hashes.iter().filter_map(|hash| data.get(hash)).cloned().collect()) - } + let post_stream = self.get_posts(opts).await?; + Ok(Box::new(post_stream.merge(live_stream))) + } + async fn get_post_hashes(&mut self, opts: &GetPostOptions) -> Result { + let start = opts.time_start; + let end = opts.time_end; + let empty = self.empty_hash_bt.range(..); + let hashes = self + .post_hashes + .read() + .await + .get(&opts.channel) + .map(|x| match (start, end) { + (0, 0) => x.range(..), + (0, end) => x.range(..end), + (start, 0) => x.range(start..), + _ => x.range(start..end), + }) + .unwrap_or(empty) + .flat_map(|(_time, hashes)| hashes.iter().map(|hash| Ok(*hash))) + .collect::>>(); + Ok(Box::new(stream::from_iter(hashes.into_iter()))) + } + async fn want(&mut self, hashes: &[Hash]) -> Result, Error> { + let data = self.data.read().await; + Ok(hashes + .iter() + .filter(|hash| !data.contains_key(hash.clone())) + .cloned() + .collect()) + } + async fn get_data(&mut self, hashes: &[Hash]) -> Result, Error> { + let data = self.data.read().await; + Ok(hashes + .iter() + .filter_map(|hash| data.get(hash)) + .cloned() + .collect()) + } } diff --git a/src/stream.rs b/src/stream.rs index a1c6d3c..9277210 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,80 +1,98 @@ -use crate::{Error,ChannelOptions,Post,Hash}; +use crate::{ChannelOptions, Error, Hash, Post}; use async_std::{ - prelude::*, - task,channel,stream::Stream,pin::Pin, - task::{Waker,Context,Poll},sync::{Arc,RwLock,Mutex}, + channel, + pin::Pin, + prelude::*, + stream::Stream, + sync::{Arc, Mutex, RwLock}, + task, + task::{Context, Poll, Waker}, }; -pub type PostStream<'a> = Box>+Unpin+Send+'a>; -pub type HashStream<'a> = Box>+Unpin+Send+'a>; +pub type PostStream<'a> = Box> + Unpin + Send + 'a>; +pub type HashStream<'a> = Box> + Unpin + Send + 'a>; #[derive(Clone)] pub struct LiveStream { - id: usize, - options: ChannelOptions, - sender: channel::Sender, - receiver: channel::Receiver, - live_streams: Arc>>, - waker: Arc>>, -} - -impl LiveStream { - pub fn new( id: usize, options: ChannelOptions, + sender: channel::Sender, + receiver: channel::Receiver, live_streams: Arc>>, - ) -> Self { - let (sender,receiver) = channel::bounded(options.limit); - Self { id, options, sender, receiver, live_streams, waker: Arc::new(Mutex::new(None)) } - } - pub async fn send(&mut self, post: Post) { - if let Err(_) = self.sender.try_send(post) {} - if let Some(waker) = self.waker.lock().await.as_ref() { - waker.wake_by_ref(); + waker: Arc>>, +} + +impl LiveStream { + pub fn new(id: usize, options: ChannelOptions, live_streams: Arc>>) -> Self { + let (sender, receiver) = channel::bounded(options.limit); + Self { + id, + options, + sender, + receiver, + live_streams, + waker: Arc::new(Mutex::new(None)), + } } - } - pub fn matches(&self, post: &Post) -> bool { - if Some(&self.options.channel) != post.get_channel() { return false } - match (self.options.time_start, self.options.time_end) { - (0,0) => true, - (0,end) => post.get_timestamp().map(|t| t <= end).unwrap_or(false), - (start,0) => post.get_timestamp().map(|t| start <= t).unwrap_or(false), - (start,end) => post.get_timestamp().map(|t| start <= t && t <= end).unwrap_or(false), + pub async fn send(&mut self, post: Post) { + if let Err(_) = self.sender.try_send(post) {} + if let Some(waker) = self.waker.lock().await.as_ref() { + waker.wake_by_ref(); + } + } + pub fn matches(&self, post: &Post) -> bool { + if Some(&self.options.channel) != post.get_channel() { + return false; + } + match (self.options.time_start, self.options.time_end) { + (0, 0) => true, + (0, end) => post.get_timestamp().map(|t| t <= end).unwrap_or(false), + (start, 0) => post.get_timestamp().map(|t| start <= t).unwrap_or(false), + (start, end) => post + .get_timestamp() + .map(|t| start <= t && t <= end) + .unwrap_or(false), + } } - } } impl Stream for LiveStream { - type Item = Result; - fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { - let r = Pin::new(&mut self.receiver.recv()).poll(ctx); - match r { - Poll::Ready(Ok(x)) => { - let m_waker = self.waker.clone(); - task::block_on(async move { *m_waker.lock().await = None; }); - Poll::Ready(Some(Ok(x))) - }, - Poll::Ready(Err(x)) => { - let m_waker = self.waker.clone(); - task::block_on(async move { *m_waker.lock().await = None; }); - Poll::Ready(Some(Err(x.into()))) - }, - Poll::Pending => { - let m_waker = self.waker.clone(); - let waker = ctx.waker().clone(); - task::block_on(async move { *m_waker.lock().await = Some(waker); }); - Poll::Pending - }, + type Item = Result; + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + let r = Pin::new(&mut self.receiver.recv()).poll(ctx); + match r { + Poll::Ready(Ok(x)) => { + let m_waker = self.waker.clone(); + task::block_on(async move { + *m_waker.lock().await = None; + }); + Poll::Ready(Some(Ok(x))) + } + Poll::Ready(Err(x)) => { + let m_waker = self.waker.clone(); + task::block_on(async move { + *m_waker.lock().await = None; + }); + Poll::Ready(Some(Err(x.into()))) + } + Poll::Pending => { + let m_waker = self.waker.clone(); + let waker = ctx.waker().clone(); + task::block_on(async move { + *m_waker.lock().await = Some(waker); + }); + Poll::Pending + } + } } - } } impl Drop for LiveStream { - fn drop(&mut self) { - let live_streams = self.live_streams.clone(); - let id = self.id; - task::block_on(async move { - live_streams.write().await.drain_filter(|s| s.id == id); - }); - } + fn drop(&mut self) { + let live_streams = self.live_streams.clone(); + let id = self.id; + task::block_on(async move { + live_streams.write().await.drain_filter(|s| s.id == id); + }); + } }