diff --git a/Cargo.toml b/Cargo.toml index 95a3860e5..ed8d01e93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "tests/web", "tests/service_named_result", "tests/use_arc_self", + "tests/use_generic_streaming_requests", "tests/default_stubs", "tests/deprecated_methods", "tests/skip_debug", diff --git a/tests/use_generic_streaming_requests/Cargo.toml b/tests/use_generic_streaming_requests/Cargo.toml new file mode 100644 index 000000000..97101f137 --- /dev/null +++ b/tests/use_generic_streaming_requests/Cargo.toml @@ -0,0 +1,17 @@ +[package] +authors = ["Yotam Ofek "] +edition = "2021" +license = "MIT" +name = "use_generic_streaming_requests" + +[dependencies] +tokio-stream = "0.1" +prost = "0.13" +tonic = {path = "../../tonic"} +tokio = {version = "1.0", features = ["macros"]} + +[build-dependencies] +tonic-build = {path = "../../tonic-build" } + +[package.metadata.cargo-machete] +ignored = ["prost"] diff --git a/tests/use_generic_streaming_requests/LICENSE b/tests/use_generic_streaming_requests/LICENSE new file mode 100644 index 000000000..307709840 --- /dev/null +++ b/tests/use_generic_streaming_requests/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2020 Lucio Franco + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/tests/use_generic_streaming_requests/build.rs b/tests/use_generic_streaming_requests/build.rs new file mode 100644 index 000000000..17ed5884b --- /dev/null +++ b/tests/use_generic_streaming_requests/build.rs @@ -0,0 +1,6 @@ +fn main() { + tonic_build::configure() + .use_generic_streaming_requests(true) + .compile_protos(&["proto/test.proto"], &["proto"]) + .unwrap(); +} diff --git a/tests/use_generic_streaming_requests/proto/test.proto b/tests/use_generic_streaming_requests/proto/test.proto new file mode 100644 index 000000000..0659ab452 --- /dev/null +++ b/tests/use_generic_streaming_requests/proto/test.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package test; + +service Test { + rpc TestRequest(stream Message) returns (Message); +} + +message Message {} diff --git a/tests/use_generic_streaming_requests/src/lib.rs b/tests/use_generic_streaming_requests/src/lib.rs new file mode 100644 index 000000000..6be7f9581 --- /dev/null +++ b/tests/use_generic_streaming_requests/src/lib.rs @@ -0,0 +1,41 @@ +use tokio_stream::StreamExt; +use tonic::{Response, Status}; + +tonic::include_proto!("test"); + +#[derive(Debug, Default)] +pub struct Svc; + +#[tonic::async_trait] +impl test_server::Test for Svc { + async fn test_request( + &self, + req: tonic::Request< + impl tokio_stream::Stream> + Send + Unpin, + >, + ) -> Result, Status> { + let mut req = req.into_inner(); + while let Some(message) = req.try_next().await? { + println!("Got message: {message:?}") + } + + Ok(Response::new(Message {})) + } +} + +#[cfg(test)] +mod tests { + use tonic::Request; + + use super::test_server::Test; + use super::*; + + #[tokio::test] + async fn test_request_handler() { + let incoming_messages = tokio_stream::iter([Message {}, Message {}].map(Ok)); + let svc = Svc; + svc.test_request(Request::new(incoming_messages)) + .await + .unwrap(); + } +} diff --git a/tonic-build/src/code_gen.rs b/tonic-build/src/code_gen.rs index 0306b753a..942ff27ca 100644 --- a/tonic-build/src/code_gen.rs +++ b/tonic-build/src/code_gen.rs @@ -14,6 +14,7 @@ pub struct CodeGenBuilder { disable_comments: HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, } impl CodeGenBuilder { @@ -46,7 +47,7 @@ impl CodeGenBuilder { self } - /// Enable compiling well knonw types, this will force codegen to not + /// Enable compiling well known types, this will force codegen to not /// use the well known types from `prost-types`. pub fn compile_well_known_types(&mut self, enable: bool) -> &mut Self { self.compile_well_known_types = enable; @@ -71,6 +72,19 @@ impl CodeGenBuilder { self } + /// Enable or disable using `Request` instead of `Request>` + /// as the parameter type for generated trait methods of client-streaming functions. + /// + /// This allows calling those trait methods with a `Request` containing any object that implements + /// `Stream>`, which can be helpful for testing request handler logic. + pub fn use_generic_streaming_requests( + &mut self, + use_generic_streaming_requests: bool, + ) -> &mut Self { + self.use_generic_streaming_requests = use_generic_streaming_requests; + self + } + /// Generate client code based on `Service`. /// /// This takes some `Service` and will generate a `TokenStream` that contains @@ -101,6 +115,7 @@ impl CodeGenBuilder { &self.disable_comments, self.use_arc_self, self.generate_default_stubs, + self.use_generic_streaming_requests, ) } } @@ -115,6 +130,7 @@ impl Default for CodeGenBuilder { disable_comments: HashSet::default(), use_arc_self: false, generate_default_stubs: false, + use_generic_streaming_requests: false, } } } diff --git a/tonic-build/src/prost.rs b/tonic-build/src/prost.rs index 7cfb6ad08..0aa9aa809 100644 --- a/tonic-build/src/prost.rs +++ b/tonic-build/src/prost.rs @@ -41,6 +41,7 @@ pub fn configure() -> Builder { disable_comments: HashSet::default(), use_arc_self: false, generate_default_stubs: false, + use_generic_streaming_requests: false, compile_settings: CompileSettings::default(), skip_debug: HashSet::default(), } @@ -228,6 +229,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { .disable_comments(self.builder.disable_comments.clone()) .use_arc_self(self.builder.use_arc_self) .generate_default_stubs(self.builder.generate_default_stubs) + .use_generic_streaming_requests(self.builder.use_generic_streaming_requests) .generate_server( &TonicBuildService::new(service.clone(), self.builder.compile_settings.clone()), &self.builder.proto_path, @@ -310,6 +312,7 @@ pub struct Builder { pub(crate) disable_comments: HashSet, pub(crate) use_arc_self: bool, pub(crate) generate_default_stubs: bool, + pub(crate) use_generic_streaming_requests: bool, pub(crate) compile_settings: CompileSettings, pub(crate) skip_debug: HashSet, @@ -584,6 +587,18 @@ impl Builder { self } + /// Enable or disable using `Request` instead of `Request>` + /// as the parameter type for generated trait methods of client-streaming functions. + /// + /// This allows calling those trait methods with a `Request` containing any object that implements + /// `Stream>`, which can be helpful for testing request handler logic. + /// + /// This defaults to `false`. + pub fn use_generic_streaming_requests(mut self, use_generic_streaming_requests: bool) -> Self { + self.use_generic_streaming_requests = use_generic_streaming_requests; + self + } + /// Override the default codec. /// /// If set, writes `{codec_path}::default()` in generated code wherever a codec is created. diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index e2d0aacd9..b2a36aa42 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -19,6 +19,7 @@ pub(crate) fn generate_internal( disable_comments: &HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, ) -> TokenStream { let methods = generate_methods( service, @@ -41,6 +42,7 @@ pub(crate) fn generate_internal( disable_comments, use_arc_self, generate_default_stubs, + use_generic_streaming_requests, ); let package = if emit_package { service.package() } else { "" }; // Transport based implementations @@ -203,6 +205,7 @@ fn generate_trait( disable_comments: &HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, ) -> TokenStream { let methods = generate_trait_methods( service, @@ -212,6 +215,7 @@ fn generate_trait( disable_comments, use_arc_self, generate_default_stubs, + use_generic_streaming_requests, ); let trait_doc = generate_doc_comment(format!( " Generated trait containing gRPC methods that should be implemented for use with {}Server.", @@ -227,6 +231,7 @@ fn generate_trait( } } +#[allow(clippy::too_many_arguments)] fn generate_trait_methods( service: &T, emit_package: bool, @@ -235,6 +240,7 @@ fn generate_trait_methods( disable_comments: &HashSet, use_arc_self: bool, generate_default_stubs: bool, + use_generic_streaming_requests: bool, ) -> TokenStream { let mut stream = TokenStream::new(); @@ -257,92 +263,61 @@ fn generate_trait_methods( quote!(&self) }; - let method = match ( - method.client_streaming(), - method.server_streaming(), - generate_default_stubs, - ) { - (false, false, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) - } - } - } - (false, false, false) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result, tonic::Status>; - } - } - (true, false, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) - } - } - } - (true, false, false) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result, tonic::Status>; - } - } - (false, true, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result>, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) - } + let result = |ok| quote!(std::result::Result<#ok, tonic::Status>); + let response_result = |message| result(quote!(tonic::Response<#message>)); + + let req_param_type = { + let inner_ty = if !method.client_streaming() { + req_message + } else if !use_generic_streaming_requests { + quote!(tonic::Streaming<#req_message>) + } else { + let message_ty = result(req_message); + quote!(impl tokio_stream::Stream + std::marker::Send + std::marker::Unpin) + }; + + quote!(tonic::Request<#inner_ty>) + }; + + let partial_sig = quote! { + #method_doc + async fn #name(#self_param, request: #req_param_type) + }; + + let body_or_semicolon = if generate_default_stubs { + quote! { + { + Err(tonic::Status::unimplemented("Not yet implemented")) } } - (false, true, false) => { - let stream = quote::format_ident!("{}Stream", method.identifier()); - let stream_doc = generate_doc_comment(format!( - " Server streaming response type for the {} method.", - method.identifier() - )); - - quote! { - #stream_doc - type #stream: tonic::codegen::tokio_stream::Stream> + std::marker::Send + 'static; - - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result, tonic::Status>; - } + } else { + quote!(;) + }; + + let method = if !method.server_streaming() { + let return_ty = response_result(res_message); + quote! { + #partial_sig -> #return_ty #body_or_semicolon } - (true, true, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result>, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) - } - } + } else if generate_default_stubs { + let return_ty = response_result(quote!(BoxStream<#res_message>)); + quote! { + #partial_sig -> #return_ty #body_or_semicolon } - (true, true, false) => { - let stream = quote::format_ident!("{}Stream", method.identifier()); - let stream_doc = generate_doc_comment(format!( - " Server streaming response type for the {} method.", - method.identifier() - )); - - quote! { - #stream_doc - type #stream: tonic::codegen::tokio_stream::Stream> + std::marker::Send + 'static; - - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result, tonic::Status>; - } + } else { + let stream = quote::format_ident!("{}Stream", method.identifier()); + let stream_doc = generate_doc_comment(format!( + " Server streaming response type for the {} method.", + method.identifier() + )); + let stream_item_ty = result(res_message); + let stream_ty = quote!(tonic::codegen::tokio_stream::Stream + std::marker::Send + 'static); + let return_ty = response_result(quote!(Self::#stream)); + quote! { + #stream_doc + type #stream: #stream_ty; + + #partial_sig -> #return_ty #body_or_semicolon } };