From 4b55006c997582afde2ca0a8460c6466184c81e3 Mon Sep 17 00:00:00 2001 From: substack Date: Sat, 19 Mar 2022 15:53:15 -1000 Subject: [PATCH] fix deadlocks in insert post --- examples/chat.rs | 1 + src/store.rs | 94 ++++++++++++++++++++++++++---------------------- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/examples/chat.rs b/examples/chat.rs index e483e44..6f4c4fc 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -36,6 +36,7 @@ fn main() -> Result<(),Error> { let channel = "default".as_bytes(); let text = line.trim_end().as_bytes(); client.post_text(channel, &text).await.unwrap(); + line.clear(); } }); } diff --git a/src/store.rs b/src/store.rs index d03c7a1..8c84a90 100644 --- a/src/store.rs +++ b/src/store.rs @@ -135,34 +135,39 @@ impl Store for MemoryStore { Ok([0;32]) } async fn insert_post(&mut self, post: &Post) -> Result<(),Error> { - println!["insert {:?}", post]; match &post.body { PostBody::Text { channel, timestamp, .. } => { - if let Some(post_map) = self.posts.write().await.get_mut(channel) { - if let Some(posts) = post_map.get_mut(timestamp) { - posts.push(post.clone()); + { + let mut posts = self.posts.write().await; + if let Some(post_map) = posts.get_mut(channel) { + if let Some(posts) = post_map.get_mut(timestamp) { + posts.push(post.clone()); + } else { + post_map.insert(*timestamp, vec![post.clone()]); + } } else { + let mut post_map = BTreeMap::new(); post_map.insert(*timestamp, vec![post.clone()]); + posts.insert(channel.to_vec(), post_map); } - } else { - let mut post_map = BTreeMap::new(); - post_map.insert(*timestamp, vec![post.clone()]); - self.posts.write().await.insert(channel.to_vec(), post_map); } - if let Some(hash_map) = self.post_hashes.write().await.get_mut(channel) { - if let Some(hashes) = hash_map.get_mut(timestamp) { - hashes.push(post.hash()?); + { + let mut post_hashes = self.post_hashes.write().await; + if let Some(hash_map) = post_hashes.get_mut(channel) { + if let Some(hashes) = hash_map.get_mut(timestamp) { + hashes.push(post.hash()?); + } else { + let hash = post.hash()?; + hash_map.insert(*timestamp, vec![hash.clone()]); + self.data.write().await.insert(hash, post.to_bytes()?); + } } else { + let mut hash_map = BTreeMap::new(); let hash = post.hash()?; hash_map.insert(*timestamp, vec![hash.clone()]); + post_hashes.insert(channel.to_vec(), hash_map); self.data.write().await.insert(hash, post.to_bytes()?); } - } else { - let mut hash_map = BTreeMap::new(); - let hash = post.hash()?; - hash_map.insert(*timestamp, vec![hash.clone()]); - self.post_hashes.write().await.insert(channel.to_vec(), hash_map); - self.data.write().await.insert(hash, post.to_bytes()?); } if let Some(senders) = self.live_streams.read().await.get(channel) { for stream in senders.read().await.iter() { @@ -185,33 +190,36 @@ impl Store for MemoryStore { Ok(Box::new(stream::from_iter(posts.into_iter()))) } async fn get_posts_live(&mut self, opts: &GetPostOptions) -> Result { - let live_stream = if let Some(live_streams) = self.live_streams.write().await.get_mut(&opts.channel) { - let live_stream = { - let mut id = self.live_stream_id.lock().await; - *id += 1; - LiveStream::new(*id, opts.clone(), live_streams.clone()) - }; - let live = live_stream.clone(); - task::block_on(async move { - live_streams.write().await.push(live); - }); - live_stream - } else { - let live_streams = Arc::new(RwLock::new(vec![])); - let live_stream_id = { - let mut id_r = self.live_stream_id.lock().await; - let id = *id_r; - *id_r += 1; - id - }; - let live_streams_c = live_streams.clone(); - let live_stream = task::block_on(async move { - let live_stream = LiveStream::new(live_stream_id, opts.clone(), live_streams_c.clone()); - live_streams_c.write().await.push(live_stream.clone()); + let live_stream = { + let mut live_streams = self.live_streams.write().await; + if let Some(streams) = live_streams.get_mut(&opts.channel) { + let live_stream = { + let mut id = self.live_stream_id.lock().await; + *id += 1; + LiveStream::new(*id, opts.clone(), streams.clone()) + }; + let live = live_stream.clone(); + task::block_on(async move { + streams.write().await.push(live); + }); + live_stream + } else { + let streams = Arc::new(RwLock::new(vec![])); + let live_stream_id = { + let mut id_r = self.live_stream_id.lock().await; + let id = *id_r; + *id_r += 1; + id + }; + let streams_c = streams.clone(); + let live_stream = task::block_on(async move { + let live_stream = LiveStream::new(live_stream_id, opts.clone(), streams_c.clone()); + streams_c.write().await.push(live_stream.clone()); + live_stream + }); + live_streams.insert(opts.channel.clone(), streams); live_stream - }); - self.live_streams.write().await.insert(opts.channel.clone(), live_streams); - live_stream + } }; let post_stream = self.get_posts(opts).await?; Ok(Box::new(post_stream.merge(live_stream)))