Skip to content
Closed
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ members = [
"tonic-health",
"tonic-types",
"tonic-reflection",
"tonic-web", # Non-published crates
"tonic-web", # Non-published crates
"examples",
"codegen",
"interop", # Tests
"interop", # Tests
"tests/disable_comments",
"tests/included_service",
"tests/same_name",
Expand All @@ -22,6 +22,7 @@ members = [
"tests/stream_conflict",
"tests/root-crate-path",
"tests/compression",
"tests/various_compression_formats",
"tests/web",
"tests/service_named_result",
"tests/use_arc_self",
Expand Down
13 changes: 13 additions & 0 deletions tests/various_compression_formats/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "various_compression_formats"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
prost = "0.13"
tonic = { path = "../../tonic", features = ["gzip","zstd"]}
tokio = { version = "1.36.2", features = ["macros", "rt-multi-thread"] }

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
4 changes: 4 additions & 0 deletions tests/various_compression_formats/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("proto/proto_box.proto")?;
Ok(())
}
15 changes: 15 additions & 0 deletions tests/various_compression_formats/proto/proto_box.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

package proto_box;

service ProtoService {
rpc Rpc(Input) returns (Output);
}

message Input {
string data = 1;
}

message Output {
string data = 1;
}
3 changes: 3 additions & 0 deletions tests/various_compression_formats/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub mod proto_box {
tonic::include_proto!("proto_box");
}
206 changes: 206 additions & 0 deletions tests/various_compression_formats/tests/auto_encoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use std::error::Error;

use tokio::net::TcpListener;
use tokio::sync::oneshot;

use tonic::codegen::CompressionEncoding;
use tonic::transport::{server::TcpIncoming, Channel, Server};
use tonic::{Request, Response, Status};

use various_compression_formats::proto_box::{
proto_service_client::ProtoServiceClient,
proto_service_server::{ProtoService, ProtoServiceServer},
Input, Output,
};

const LOCALHOST: &str = "127.0.0.1:0";

#[derive(Default)]
pub struct ServerTest;

#[tonic::async_trait]
impl ProtoService for ServerTest {
async fn rpc(&self, request: Request<Input>) -> Result<Response<Output>, Status> {
println!("Server received request: {:?}", request);

Ok(Response::new(Output {
data: format!("Received: {}", request.into_inner().data),
}))
}
}

struct ClientWrapper {
client: ProtoServiceClient<Channel>,
}

impl ClientWrapper {
async fn new(
address: &str,
accept: Option<CompressionEncoding>,
) -> Result<Self, Box<dyn Error + Send + Sync>> {
let channel = Channel::from_shared(address.to_string())?.connect().await?;
let mut client = ProtoServiceClient::new(channel);

if let Some(encoding) = accept {
client = client.accept_compressed(encoding);
}

Ok(Self { client })
}

async fn send_request(
&mut self,
data: String,
) -> Result<Response<Output>, Box<dyn Error + Send + Sync>> {
let request = Request::new(Input { data });

println!("Client sending request: {:?}", request);

let response = self.client.rpc(request).await?;

println!("Client response headers: {:?}", response.metadata());

Ok(response)
}
}

async fn start_server(
listener: TcpListener,
send: Option<CompressionEncoding>,
auto: bool,
) -> oneshot::Sender<()> {
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let srv = ServerTest::default();
let mut service = ProtoServiceServer::new(srv);

if let Some(encoding) = send {
service = service.send_compressed(encoding);
}

if auto {
service = service.auto_encoding();
}

let server = Server::builder()
.add_service(service)
.serve_with_incoming_shutdown(
TcpIncoming::from_listener(listener, true, None).unwrap(),
async {
shutdown_rx.await.ok();
},
);

tokio::spawn(async move {
server.await.expect("Server crashed");
});

shutdown_tx
}

async fn run_client_test(
address: &str,
client_accept: Option<CompressionEncoding>,
expected_encoding: Option<&str>,
data: &str,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let mut client = ClientWrapper::new(address, client_accept).await?;
let response = client.send_request(data.to_string()).await?;

match expected_encoding {
Some(encoding) => {
let grpc_encoding = response
.metadata()
.get("grpc-encoding")
.expect("Missing 'grpc-encoding' header");
assert_eq!(grpc_encoding, encoding);
}
None => {
assert!(
!response.metadata().contains_key("grpc-encoding"),
"Expected no 'grpc-encoding' header"
);
}
}

Ok(())
}

#[tokio::test]
async fn test_compression_behavior() -> Result<(), Box<dyn Error + Send + Sync>> {
let listener = TcpListener::bind(LOCALHOST).await?;
let address = format!("http://{}", listener.local_addr().unwrap());

// The server is not specified to send data with any compression
let shutdown_tx = start_server(listener, None, false).await;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;

tokio::try_join!(
// Client 1 can only accept gzip encoding or uncompressed,
// so all data must be returned uncompressed
run_client_test(&address, Some(CompressionEncoding::Gzip), None, "Client 1"),
// Client 2 can only accept non-compressed data,
// so all data must be returned uncompressed
run_client_test(&address, None, None, "Client 2")
)?;

shutdown_tx.send(()).unwrap();

let listener = TcpListener::bind(LOCALHOST).await?;
let address = format!("http://{}", listener.local_addr().unwrap());

// The server is specified to send data with zstd compression
let shutdown_tx = start_server(listener, Some(CompressionEncoding::Zstd), false).await;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;

tokio::try_join!(
// Client 3 can only accept zstd encoding or uncompressed,
// so all data must be returned compressed with zstd
run_client_test(
&address,
Some(CompressionEncoding::Zstd),
Some("zstd"),
"Client 3"
),
// Client 4 can only accept Gzip encoding or uncompressed,
// so all data must be returned uncompressed
run_client_test(&address, Some(CompressionEncoding::Gzip), None, "Client 4")
)?;

shutdown_tx.send(()).unwrap();

Ok(())
}

#[tokio::test]
async fn test_auto_encoding_behavior() -> Result<(), Box<dyn Error + Send + Sync>> {
let listener = TcpListener::bind(LOCALHOST).await?;
let address = format!("http://{}", listener.local_addr().unwrap());

// The server returns in the compression format that the client prefers
let shutdown_tx = start_server(listener, Some(CompressionEncoding::Gzip), true).await;

tokio::time::sleep(std::time::Duration::from_secs(1)).await;

tokio::try_join!(
// Client 5 can accept gzip encoding or uncompressed, so all data must be returned compressed with gzip
run_client_test(
&address,
Some(CompressionEncoding::Gzip),
Some("gzip"),
"Client 5"
),
// Client 6 can accept zstd encoding or uncompressed, so all data must be returned compressed with zstd
run_client_test(
&address,
Some(CompressionEncoding::Zstd),
Some("zstd"),
"Client 6"
)
)?;

shutdown_tx.send(()).unwrap();

Ok(())
}
31 changes: 31 additions & 0 deletions tonic-build/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ pub(crate) fn generate_internal<T: Service>(
self.send_compression_encodings.enable(encoding);
self
}

/// Automatically determine the encoding to use based on the request headers.
#[must_use]
pub fn auto_encoding(mut self) -> Self {
self.auto_encoding = true;
self
}
};

let configure_max_message_size_methods = quote! {
Expand Down Expand Up @@ -117,6 +124,7 @@ pub(crate) fn generate_internal<T: Service>(
send_compression_encodings: EnabledCompressionEncodings,
max_decoding_message_size: Option<usize>,
max_encoding_message_size: Option<usize>,
auto_encoding: bool,
}

impl<T> #server_service<T> {
Expand All @@ -131,6 +139,7 @@ pub(crate) fn generate_internal<T: Service>(
send_compression_encodings: Default::default(),
max_decoding_message_size: None,
max_encoding_message_size: None,
auto_encoding: false,
}
}

Expand Down Expand Up @@ -184,6 +193,7 @@ pub(crate) fn generate_internal<T: Service>(
send_compression_encodings: self.send_compression_encodings,
max_decoding_message_size: self.max_decoding_message_size,
max_encoding_message_size: self.max_encoding_message_size,
auto_encoding: self.auto_encoding,
}
}
}
Expand Down Expand Up @@ -473,6 +483,7 @@ fn generate_unary<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();
let fut = async move {
let method = #service_ident(inner);
Expand All @@ -482,6 +493,10 @@ fn generate_unary<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.unary(method, req).await;
Ok(res)
};
Expand Down Expand Up @@ -540,6 +555,7 @@ fn generate_server_streaming<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();
let fut = async move {
let method = #service_ident(inner);
Expand All @@ -549,6 +565,10 @@ fn generate_server_streaming<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.server_streaming(method, req).await;
Ok(res)
};
Expand Down Expand Up @@ -598,6 +618,7 @@ fn generate_client_streaming<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();
let fut = async move {
let method = #service_ident(inner);
Expand All @@ -607,6 +628,10 @@ fn generate_client_streaming<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.client_streaming(method, req).await;
Ok(res)
};
Expand Down Expand Up @@ -666,7 +691,9 @@ fn generate_streaming<T: Method>(
let send_compression_encodings = self.send_compression_encodings;
let max_decoding_message_size = self.max_decoding_message_size;
let max_encoding_message_size = self.max_encoding_message_size;
let auto_encoding = self.auto_encoding;
let inner = self.inner.clone();

let fut = async move {
let method = #service_ident(inner);
let codec = #codec_name::default();
Expand All @@ -675,6 +702,10 @@ fn generate_streaming<T: Method>(
.apply_compression_config(accept_compression_encodings, send_compression_encodings)
.apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size);

if auto_encoding {
grpc = grpc.auto_encoding();
}

let res = grpc.streaming(method, req).await;
Ok(res)
};
Expand Down
Loading