diff --git a/src/facade/conn_context.cc b/src/facade/conn_context.cc index fa10135246ea..fb68f18c3022 100644 --- a/src/facade/conn_context.cc +++ b/src/facade/conn_context.cc @@ -25,7 +25,7 @@ ConnectionContext::ConnectionContext(Connection* owner) : owner_(owner) { journal_emulated = false; paused = false; blocked = false; - + subscriber = false; subscriptions = 0; } diff --git a/src/facade/conn_context.h b/src/facade/conn_context.h index 0817a7291efa..3f157bb8bac2 100644 --- a/src/facade/conn_context.h +++ b/src/facade/conn_context.h @@ -47,6 +47,7 @@ class ConnectionContext { bool async_dispatch : 1; // whether this connection is amid an async dispatch bool sync_dispatch : 1; // whether this connection is amid a sync dispatch bool journal_emulated : 1; // whether it is used to dispatch journal commands + bool subscriber : 1; // whether this connection is a subscriber to pub/sub channels bool paused = false; // whether this connection is paused due to CLIENT PAUSE // whether it's blocked on blocking commands like BLPOP, needs to be addressable diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index d8ca3e455340..880c3ce81b80 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -506,17 +506,22 @@ void Connection::AsyncOperations::operator()(const AclUpdateMessage& msg) { void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) { RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; + facade::ConnectionContext* cntx = self->cntx(); if (pub_msg.should_unsubscribe) { rbuilder->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH); rbuilder->SendBulkString("unsubscribe"); rbuilder->SendBulkString(pub_msg.channel); rbuilder->SendLong(0); - auto* cntx = self->cntx(); + cntx->Unsubscribe(pub_msg.channel); return; } + if (!cntx->subscriber) { + LOG(DFATAL) << "PubMessage received on non-subscriber connection: " << self->DebugInfo(); + } + unsigned i = 0; array arr; if (pub_msg.pattern.empty()) { diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 04f3e6926dbc..b15d841fb4f0 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -14,6 +14,7 @@ #include "core/heap_size.h" #include "facade/error.h" #include "util/fibers/proactor_base.h" +#include "util/fibers/stacktrace.h" #ifdef __APPLE__ #ifndef UIO_MAXIOV @@ -99,6 +100,7 @@ void SinkReplyBuilder::CloseConnection() { } template void SinkReplyBuilder::WritePieces(Ts&&... pieces) { + CHECK_EQ(0u, send_time_ns_); if (size_t required = (piece_size(pieces) + ...); buffer_.AppendLen() <= required) Flush(required); @@ -115,11 +117,12 @@ template void SinkReplyBuilder::WritePieces(Ts&&... pieces) { vecs_.push_back(iovec{dest, 0}); } - DCHECK(iovec_end(vecs_.back()) == dest); + CHECK(iovec_end(vecs_.back()) == dest); char* ptr = dest; ([&]() { ptr = write_piece(pieces, ptr); }(), ...); size_t written = ptr - dest; + CHECK_LE(written, buffer_.AppendLen()); buffer_.CommitWrite(written); vecs_.back().iov_len += written; total_size_ += written; @@ -133,6 +136,7 @@ void SinkReplyBuilder::WriteRef(std::string_view str) { } void SinkReplyBuilder::Flush(size_t expected_buffer_cap) { + CHECK_EQ(0u, send_time_ns_); if (!vecs_.empty()) Send(); @@ -167,6 +171,25 @@ void SinkReplyBuilder::Send() { reply_stats.io_write_cnt++; reply_stats.io_write_bytes += total_size_; + + // char needle[32] = {0}; + size_t total = 0; + for (unsigned j = 0; j < vecs_.size(); j++) { + auto& v = vecs_[j]; + total += v.iov_len; +#if 0 + void* found = memmem(v.iov_base, v.iov_len, needle, sizeof(needle)); + if (found) { + size_t offset = reinterpret_cast(found) - reinterpret_cast(v.iov_base); + LOG(ERROR) << "Found zero in iovec " << j << " of size " << v.iov_len << " at offset " + << offset << ":\n " << util::fb2::GetStacktrace() << "\n:" + << absl::CHexEscape( + {reinterpret_cast(v.iov_base), offset + sizeof(needle)}); + } +#endif + } + CHECK_EQ(total, total_size_); + DVLOG(2) << "Writing " << total_size_ << " bytes"; if (auto ec = sink_->Write(vecs_.data(), vecs_.size()); ec) ec_ = ec; diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 4cd9b0e6e694..cab2f362bf16 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -259,6 +259,7 @@ void ConnectionContext::Unsubscribe(std::string_view channel) { conn_state.subscribe_info.reset(); DCHECK_GE(subscriptions, 1u); --subscriptions; + subscriber = false; // If we have no subscriptions, we are not a subscriber. } } @@ -273,10 +274,11 @@ vector ConnectionContext::ChangeSubscriptions(CmdArgList channels, boo DCHECK(to_add); conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo); + subscriber = true; subscriptions++; } - auto& sinfo = *conn_state.subscribe_info.get(); + auto& sinfo = *conn_state.subscribe_info; auto& local_store = pattern ? sinfo.patterns : sinfo.channels; int32_t tid = util::ProactorBase::me()->GetPoolIndex(); diff --git a/tests/dragonfly/connection_test.py b/tests/dragonfly/connection_test.py index fabfdece92f2..95ada880ba60 100755 --- a/tests/dragonfly/connection_test.py +++ b/tests/dragonfly/connection_test.py @@ -1279,3 +1279,46 @@ async def test_client_detached_crash(df_factory): async_client = server.client() await async_client.client_pause(2, all=False) server.stop() + + +async def subscriber_task(pubsub, stop_flag: asyncio.Event): + """Handles receiving messages from the subscribed channel.""" + while not stop_flag.is_set(): + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1) + if message: + logging.info(f"Received message: {message}") + await asyncio.sleep(0.1) # Small delay to prevent busy-waiting + + +async def ping_task(client: aioredis.Redis, stop_flag: asyncio.Event): + while not stop_flag.is_set(): + pipe = client.pipeline() + pipe.ping() + pipe.ping() + pipe.ping() + await pipe.execute() + await asyncio.sleep(0.1) + + +async def publish_task(client: aioredis.Redis, stop_flag: asyncio.Event): + while not stop_flag.is_set(): + pipe = client.pipeline() + for i in range(10): + pipe.publish("channel1", "x" * 575 + f"{i}") + await pipe.execute() + await asyncio.sleep(0.1) + + +async def test_client_pubsub(df_server: DflyInstance, async_client: aioredis.Redis): + pubsub = async_client.pubsub() + + stop_flag = asyncio.Event() + await pubsub.subscribe("channel1") + task1 = asyncio.create_task(ping_task(async_client, stop_flag)) + task2 = asyncio.create_task(subscriber_task(pubsub, stop_flag)) + publish_client = df_server.client() + task3 = asyncio.create_task(publish_task(publish_client, stop_flag)) + await asyncio.sleep(5) # Let the tasks run for a while + stop_flag.set() # Signal the tasks to stop + await asyncio.gather(task1, task2, task3) + await pubsub.unsubscribe("channel1")