Skip to content

Commit 9782f45

Browse files
refactor(handler): extract RowSet formatting from scheduler to handler (#5802)
* lift to_pg_rows
* cleanup `format: bool`
* no enum wrapping local & distributed
* enum stream without boxing for local or distributed
* comments

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 527b656 commit 9782f45

File tree

7 files changed

+93
-61
lines changed

7 files changed

+93
-61
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/frontend/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ parking_lot = "0.12"
2828
parse-display = "0.6"
2929
paste = "1"
3030
pgwire = { path = "../utils/pgwire" }
31+
pin-project-lite = "0.2"
3132
prometheus = { version = "0.13", features = ["process"] }
3233
prost = "0.11"
3334
rand = "0.8"

src/frontend/src/handler/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use pgwire::types::Row;
2525
use risingwave_common::error::{ErrorCode, Result};
2626
use risingwave_sqlparser::ast::{DropStatement, ObjectType, Statement};
2727

28+
use self::util::DataChunkToRowSetAdapter;
2829
use crate::scheduler::{DistributedQueryStream, LocalQueryStream};
2930
use crate::session::{OptimizerContext, SessionImpl};
3031
use crate::utils::WithOptions;
@@ -60,8 +61,8 @@ pub mod variable;
6061
pub type RwPgResponse = PgResponse<PgResponseStream>;
6162

6263
pub enum PgResponseStream {
63-
LocalQuery(LocalQueryStream),
64-
DistributedQuery(DistributedQueryStream),
64+
LocalQuery(DataChunkToRowSetAdapter<LocalQueryStream>),
65+
DistributedQuery(DataChunkToRowSetAdapter<DistributedQueryStream>),
6566
Rows(BoxStream<'static, RowSetResult>),
6667
}
6768

src/frontend/src/handler/query.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,16 @@ use risingwave_sqlparser::ast::Statement;
2525
use super::{PgResponseStream, RwPgResponse};
2626
use crate::binder::{Binder, BoundSetExpr, BoundStatement};
2727
use crate::handler::privilege::{check_privileges, resolve_privileges};
28-
use crate::handler::util::to_pg_field;
28+
use crate::handler::util::{to_pg_field, DataChunkToRowSetAdapter};
2929
use crate::planner::Planner;
3030
use crate::scheduler::plan_fragmenter::Query;
3131
use crate::scheduler::{
32-
BatchPlanFragmenter, ExecutionContext, ExecutionContextRef, LocalQueryExecution,
32+
BatchPlanFragmenter, DistributedQueryStream, ExecutionContext, ExecutionContextRef,
33+
LocalQueryExecution, LocalQueryStream,
3334
};
3435
use crate::session::{OptimizerContext, OptimizerContextRef, SessionImpl};
3536
use crate::PlanRef;
3637

37-
pub type QueryResultSet = PgResponseStream;
38-
3938
pub fn gen_batch_query_plan(
4039
session: &SessionImpl,
4140
context: OptimizerContextRef,
@@ -116,9 +115,17 @@ pub async fn handle_query(
116115
tracing::trace!("Generated query after plan fragmenter: {:?}", &query);
117116

118117
let mut row_stream = match query_mode {
119-
QueryMode::Local => local_execute(session.clone(), query, format).await?,
118+
QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new(
119+
local_execute(session.clone(), query).await?,
120+
format,
121+
)),
120122
// Local mode does not support cancelling tasks.
121-
QueryMode::Distributed => distribute_execute(session.clone(), query, format).await?,
123+
QueryMode::Distributed => {
124+
PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new(
125+
distribute_execute(session.clone(), query).await?,
126+
format,
127+
))
128+
}
122129
};
123130

124131
let rows_count = match stmt_type {
@@ -183,21 +190,16 @@ fn to_statement_type(stmt: &Statement) -> StatementType {
183190
pub async fn distribute_execute(
184191
session: Arc<SessionImpl>,
185192
query: Query,
186-
format: bool,
187-
) -> Result<QueryResultSet> {
193+
) -> Result<DistributedQueryStream> {
188194
let execution_context: ExecutionContextRef = ExecutionContext::new(session.clone()).into();
189195
let query_manager = execution_context.session().env().query_manager().clone();
190196
query_manager
191-
.schedule(execution_context, query, format)
197+
.schedule(execution_context, query)
192198
.await
193199
.map_err(|err| err.into())
194200
}
195201

196-
async fn local_execute(
197-
session: Arc<SessionImpl>,
198-
query: Query,
199-
format: bool,
200-
) -> Result<QueryResultSet> {
202+
async fn local_execute(session: Arc<SessionImpl>, query: Query) -> Result<LocalQueryStream> {
201203
let front_env = session.env();
202204

203205
// Acquire hummock snapshot for local execution.
@@ -211,7 +213,7 @@ async fn local_execute(
211213
// TODO: Pass the SQL text here
212214
let execution =
213215
LocalQueryExecution::new(query, front_env.clone(), "", epoch, session.auth_context());
214-
let rsp = Ok(execution.stream_rows(format));
216+
let rsp = Ok(execution.stream_rows());
215217

216218
// Release hummock snapshot for local execution.
217219
hummock_snapshot_manager.release(epoch, &query_id).await;

src/frontend/src/handler/util.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,70 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::pin::Pin;
16+
use std::task::{Context, Poll};
17+
1518
use bytes::Bytes;
19+
use futures::Stream;
1620
use itertools::Itertools;
1721
use pgwire::pg_field_descriptor::{PgFieldDescriptor, TypeOid};
22+
use pgwire::pg_response::RowSetResult;
23+
use pgwire::pg_server::BoxedError;
1824
use pgwire::types::Row;
25+
use pin_project_lite::pin_project;
1926
use risingwave_common::array::DataChunk;
2027
use risingwave_common::catalog::{ColumnDesc, Field};
2128
use risingwave_common::types::{DataType, ScalarRefImpl};
2229

30+
pin_project! {
31+
/// Wrapper struct that converts a stream of DataChunk to a stream of RowSet based on formatting
32+
/// parameters.
33+
///
34+
/// This is essentially `StreamExt::map(self, move |res| res.map(|chunk| to_pg_rows(chunk,
35+
/// format)))`, but we need a nameable type as part of [`super::PgResponseStream`], and we cannot
36+
/// name the type of a closure.
37+
pub struct DataChunkToRowSetAdapter<VS>
38+
where
39+
VS: Stream<Item = Result<DataChunk, BoxedError>>,
40+
{
41+
#[pin]
42+
chunk_stream: VS,
43+
format: bool,
44+
}
45+
}
46+
impl<VS> DataChunkToRowSetAdapter<VS>
47+
where
48+
VS: Stream<Item = Result<DataChunk, BoxedError>>,
49+
{
50+
pub fn new(chunk_stream: VS, format: bool) -> Self {
51+
Self {
52+
chunk_stream,
53+
format,
54+
}
55+
}
56+
}
57+
58+
impl<VS> Stream for DataChunkToRowSetAdapter<VS>
59+
where
60+
VS: Stream<Item = Result<DataChunk, BoxedError>>,
61+
{
62+
type Item = RowSetResult;
63+
64+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
65+
let mut this = self.project();
66+
match this.chunk_stream.as_mut().poll_next(cx) {
67+
Poll::Pending => Poll::Pending,
68+
Poll::Ready(chunk) => match chunk {
69+
Some(chunk_result) => match chunk_result {
70+
Ok(chunk) => Poll::Ready(Some(Ok(to_pg_rows(chunk, *this.format)))),
71+
Err(err) => Poll::Ready(Some(Err(err))),
72+
},
73+
None => Poll::Ready(None),
74+
},
75+
}
76+
}
77+
}
78+
2379
/// Format scalars according to postgres convention.
2480
fn pg_value_format(d: ScalarRefImpl<'_>, format: bool) -> Bytes {
2581
// format == false means TEXT format
@@ -34,7 +90,7 @@ fn pg_value_format(d: ScalarRefImpl<'_>, format: bool) -> Bytes {
3490
}
3591
}
3692

37-
pub fn to_pg_rows(chunk: DataChunk, format: bool) -> Vec<Row> {
93+
fn to_pg_rows(chunk: DataChunk, format: bool) -> Vec<Row> {
3894
chunk
3995
.rows()
4096
.map(|r| {

src/frontend/src/scheduler/distributed/query_manager.rs

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ use std::task::{Context, Poll};
2020

2121
use futures::{Stream, StreamExt};
2222
use futures_async_stream::try_stream;
23-
use pgwire::pg_response::RowSetResult;
24-
use pgwire::pg_server::{Session, SessionId};
23+
use pgwire::pg_server::{BoxedError, Session, SessionId};
2524
use risingwave_batch::executor::BoxedDataChunkStream;
2625
use risingwave_common::array::DataChunk;
2726
use risingwave_common::error::RwError;
@@ -32,36 +31,23 @@ use tracing::debug;
3231

3332
use super::QueryExecution;
3433
use crate::catalog::catalog_service::CatalogReader;
35-
use crate::handler::query::QueryResultSet;
36-
use crate::handler::util::to_pg_rows;
37-
use crate::handler::PgResponseStream;
3834
use crate::scheduler::plan_fragmenter::{Query, QueryId};
3935
use crate::scheduler::worker_node_manager::WorkerNodeManagerRef;
4036
use crate::scheduler::{ExecutionContextRef, HummockSnapshotManagerRef, SchedulerResult};
4137

4238
pub struct DistributedQueryStream {
4339
chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
44-
format: bool,
45-
}
46-
47-
impl DistributedQueryStream {
48-
pub fn new(
49-
chunk_rx: tokio::sync::mpsc::Receiver<SchedulerResult<DataChunk>>,
50-
format: bool,
51-
) -> Self {
52-
Self { chunk_rx, format }
53-
}
5440
}
5541

5642
impl Stream for DistributedQueryStream {
57-
type Item = RowSetResult;
43+
type Item = Result<DataChunk, BoxedError>;
5844

5945
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
6046
match self.chunk_rx.poll_recv(cx) {
6147
Poll::Pending => Poll::Pending,
6248
Poll::Ready(chunk) => match chunk {
6349
Some(chunk_result) => match chunk_result {
64-
Ok(chunk) => Poll::Ready(Some(Ok(to_pg_rows(chunk, self.format)))),
50+
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
6551
Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
6652
},
6753
None => Poll::Ready(None),
@@ -117,8 +103,7 @@ impl QueryManager {
117103
&self,
118104
context: ExecutionContextRef,
119105
query: Query,
120-
format: bool,
121-
) -> SchedulerResult<QueryResultSet> {
106+
) -> SchedulerResult<DistributedQueryStream> {
122107
let query_id = query.query_id().clone();
123108
let epoch = self
124109
.hummock_snapshot_manager
@@ -157,7 +142,7 @@ impl QueryManager {
157142

158143
// TODO: Clean up queries status when ends. This should be done lazily.
159144

160-
Ok(query_result_fetcher.stream_from_channel(format))
145+
Ok(query_result_fetcher.stream_from_channel())
161146
}
162147

163148
pub fn cancel_queries_in_session(&self, session_id: SessionId) {
@@ -225,11 +210,10 @@ impl QueryResultFetcher {
225210
Box::pin(self.run_inner())
226211
}
227212

228-
fn stream_from_channel(self, format: bool) -> QueryResultSet {
229-
PgResponseStream::DistributedQuery(DistributedQueryStream {
213+
fn stream_from_channel(self) -> DistributedQueryStream {
214+
DistributedQueryStream {
230215
chunk_rx: self.chunk_rx,
231-
format,
232-
})
216+
}
233217
}
234218
}
235219

src/frontend/src/scheduler/local.rs

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::task::{Context, Poll};
2121
use futures::Stream;
2222
use futures_async_stream::try_stream;
2323
use itertools::Itertools;
24-
use pgwire::pg_response::RowSetResult;
24+
use pgwire::pg_server::BoxedError;
2525
use risingwave_batch::executor::{BoxedDataChunkStream, ExecutorBuilder};
2626
use risingwave_batch::task::TaskId;
2727
use risingwave_common::array::DataChunk;
@@ -38,8 +38,6 @@ use tracing::debug;
3838
use uuid::Uuid;
3939

4040
use super::plan_fragmenter::{PartitionInfo, QueryStageRef};
41-
use crate::handler::query::QueryResultSet;
42-
use crate::handler::util::to_pg_rows;
4341
use crate::optimizer::plan_node::PlanNodeType;
4442
use crate::scheduler::plan_fragmenter::{ExecutionPlanNode, Query, StageId};
4543
use crate::scheduler::task_context::FrontendBatchTaskContext;
@@ -48,27 +46,17 @@ use crate::session::{AuthContext, FrontendEnv};
4846

4947
pub struct LocalQueryStream {
5048
data_stream: BoxedDataChunkStream,
51-
format: bool,
52-
}
53-
54-
impl LocalQueryStream {
55-
pub fn new(data_stream: BoxedDataChunkStream, format: bool) -> Self {
56-
Self {
57-
data_stream,
58-
format,
59-
}
60-
}
6149
}
6250

6351
impl Stream for LocalQueryStream {
64-
type Item = RowSetResult;
52+
type Item = Result<DataChunk, BoxedError>;
6553

6654
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
6755
match self.data_stream.as_mut().poll_next(cx) {
6856
Poll::Pending => Poll::Pending,
6957
Poll::Ready(chunk) => match chunk {
7058
Some(chunk_result) => match chunk_result {
71-
Ok(chunk) => Poll::Ready(Some(Ok(to_pg_rows(chunk, self.format)))),
59+
Ok(chunk) => Poll::Ready(Some(Ok(chunk))),
7260
Err(err) => Poll::Ready(Some(Err(Box::new(err)))),
7361
},
7462
None => Poll::Ready(None),
@@ -134,11 +122,10 @@ impl LocalQueryExecution {
134122
Box::pin(self.run_inner())
135123
}
136124

137-
pub fn stream_rows(self, format: bool) -> QueryResultSet {
138-
QueryResultSet::LocalQuery(LocalQueryStream {
125+
pub fn stream_rows(self) -> LocalQueryStream {
126+
LocalQueryStream {
139127
data_stream: self.run(),
140-
format,
141-
})
128+
}
142129
}
143130

144131
/// Convert query to plan fragment.

0 commit comments

Comments
 (0)