Skip to content

Commit eaca23e

Browse files
committed
Handle channel requests returning a single item
1 parent 91251b6 commit eaca23e

File tree

2 files changed

+62
-23
lines changed

2 files changed

+62
-23
lines changed

RSocket.Rpc.Protobuf/src/csharp_generator.cc

+61-22
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,14 @@ void GenerateDocCommentMethod(google::protobuf::io::Printer* printer,
132132
}
133133
}
134134

135+
inline string CapitalizeFirstLetter(string s) {
136+
if (s.empty()) {
137+
return s;
138+
}
139+
s[0] = ::toupper(s[0]);
140+
return s;
141+
}
142+
135143
std::string GetServiceClassName(const ServiceDescriptor* service) {
136144
return service->name();
137145
}
@@ -151,7 +159,7 @@ std::string GetServerClassName(const ServiceDescriptor* service) {
151159
std::string GetServiceFieldName() { return "__Service"; }
152160

153161
std::string GetMethodFieldName(const MethodDescriptor* method) {
154-
return "__Method_" + method->name();
162+
return "__Method_" + CapitalizeFirstLetter(method->name());
155163
}
156164

157165
// Gets vector of all messages used as input or output types.
@@ -176,8 +184,8 @@ std::vector<const Descriptor*> GetUsedMessages(
176184

177185
void GenerateStaticMethodField(Printer* out, const MethodDescriptor* method) {
178186
out->Print("public const string $methodfield$ = \"$methodname$\";\n",
179-
"methodfield", GetMethodFieldName(method), "methodname",
180-
method->name());
187+
"methodfield", GetMethodFieldName(method),
188+
"methodname", CapitalizeFirstLetter(method->name()));
181189
}
182190

183191
void GenerateServiceDescriptorProperty(Printer* out,
@@ -216,16 +224,20 @@ void GenerateInterface(Printer* out, const ServiceDescriptor* service) {
216224

217225
if (server_streaming) {
218226
out->Print("IAsyncEnumerable<$output_type$> $method_name$",
219-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
227+
"output_type", GetClassName(method->output_type()),
228+
"method_name", CapitalizeFirstLetter(method->name()));
220229
} else if (client_streaming) {
221230
out->Print("Task<$output_type$> $method_name$",
222-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
231+
"output_type", GetClassName(method->output_type()),
232+
"method_name", CapitalizeFirstLetter(method->name()));
223233
} else {
224234
if (options.fire_and_forget()) {
225-
out->Print("Task $method_name$", "method_name", method->name());
235+
out->Print("Task $method_name$",
236+
"method_name", CapitalizeFirstLetter(method->name()));
226237
} else {
227238
out->Print("Task<$output_type$> $method_name$",
228-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
239+
"output_type", GetClassName(method->output_type()),
240+
"method_name", CapitalizeFirstLetter(method->name()));
229241
}
230242
}
231243

@@ -266,16 +278,20 @@ void GenerateClientClass(Printer* out, const ServiceDescriptor* service) {
266278

267279
if (server_streaming) {
268280
out->Print("public IAsyncEnumerable<$output_type$> $method_name$",
269-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
281+
"output_type", GetClassName(method->output_type()),
282+
"method_name", CapitalizeFirstLetter(method->name()));
270283
} else if (client_streaming) {
271284
out->Print("public Task<$output_type$> $method_name$",
272-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
285+
"output_type", GetClassName(method->output_type()),
286+
"method_name", CapitalizeFirstLetter(method->name()));
273287
} else {
274288
if (options.fire_and_forget()) {
275-
out->Print("public Task $method_name$", "method_name", method->name());
289+
out->Print("public Task $method_name$",
290+
"method_name", CapitalizeFirstLetter(method->name()));
276291
} else {
277292
out->Print("public Task<$output_type$> $method_name$",
278-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
293+
"output_type", GetClassName(method->output_type()),
294+
"method_name", CapitalizeFirstLetter(method->name()));
279295
}
280296
}
281297

@@ -298,7 +314,11 @@ void GenerateClientClass(Printer* out, const ServiceDescriptor* service) {
298314
"servicefield", GetServiceFieldName(),
299315
"methodfield", GetMethodFieldName(method));
300316
} else {
301-
//TODO
317+
out->Print("__RequestChannel(messages, $intransform$, $outtransform$, metadata, service: $servicefield$, method: $methodfield$).SingleAsync().AsTask();\n",
318+
"intransform", "Google.Protobuf.MessageExtensions.ToByteArray",
319+
"outtransform", GetClassName(method->output_type()) + ".Parser.ParseFrom",
320+
"servicefield", GetServiceFieldName(),
321+
"methodfield", GetMethodFieldName(method));
302322
}
303323
} else if (server_streaming) {
304324
out->Print("__RequestStream(message, $intransform$, $outtransform$, metadata, service: $servicefield$, method: $methodfield$);\n",
@@ -366,16 +386,20 @@ void GenerateServerClass(Printer* out, const ServiceDescriptor* service) {
366386

367387
if (server_streaming) {
368388
out->Print("public abstract IAsyncEnumerable<$output_type$> $method_name$",
369-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
389+
"output_type", GetClassName(method->output_type()),
390+
"method_name", CapitalizeFirstLetter(method->name()));
370391
} else if (client_streaming) {
371392
out->Print("public abstract Task<$output_type$> $method_name$",
372-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
393+
"output_type", GetClassName(method->output_type()),
394+
"method_name", CapitalizeFirstLetter(method->name()));
373395
} else {
374396
if (options.fire_and_forget()) {
375-
out->Print("public abstract Task $method_name$", "method_name", method->name());
397+
out->Print("public abstract Task $method_name$",
398+
"method_name", CapitalizeFirstLetter(method->name()));
376399
} else {
377400
out->Print("public abstract Task<$output_type$> $method_name$",
378-
"output_type", GetClassName(method->output_type()), "method_name", method->name());
401+
"output_type", GetClassName(method->output_type()),
402+
"method_name", CapitalizeFirstLetter(method->name()));
379403
}
380404
}
381405

@@ -406,18 +430,33 @@ void GenerateServerClass(Printer* out, const ServiceDescriptor* service) {
406430
bool server_streaming = method->server_streaming();
407431

408432
if (client_streaming) {
409-
out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
410-
"methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type()));
433+
if (server_streaming) {
434+
out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
435+
"methodfield", GetMethodFieldName(method),
436+
"method_name", CapitalizeFirstLetter(method->name()),
437+
"input_type", GetClassName(method->input_type()));
438+
} else {
439+
out->Print("case $methodfield$: return from result in service.$method_name$(from message in messages select $input_type$.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
440+
"methodfield", GetMethodFieldName(method),
441+
"method_name", CapitalizeFirstLetter(method->name()),
442+
"input_type", GetClassName(method->input_type()));
443+
}
411444
} else if (server_streaming) {
412445
out->Print("case $methodfield$: return from result in service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
413-
"methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type()));
446+
"methodfield", GetMethodFieldName(method),
447+
"method_name", CapitalizeFirstLetter(method->name()),
448+
"input_type", GetClassName(method->input_type()));
414449
} else {
415450
if (options.fire_and_forget()) {
416-
out->Print("case $methodfield$: return AsyncEnumerable.Empty<byte[]>();\n",
417-
"methodfield", GetMethodFieldName(method));
451+
out->Print("case $methodfield$: service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata); return AsyncEnumerable.Empty<byte[]>();\n",
452+
"methodfield", GetMethodFieldName(method),
453+
"method_name", CapitalizeFirstLetter(method->name()),
454+
"input_type", GetClassName(method->input_type()));
418455
} else {
419456
out->Print("case $methodfield$: return from result in service.$method_name$($input_type$.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);\n",
420-
"methodfield", GetMethodFieldName(method), "method_name", method->name(), "input_type", GetClassName(method->input_type()));
457+
"methodfield", GetMethodFieldName(method),
458+
"method_name", CapitalizeFirstLetter(method->name()),
459+
"input_type", GetClassName(method->input_type()));
421460
}
422461
}
423462
}

RSocket.Rpc.Sample/EchoService.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ static IAsyncEnumerable<byte[]> Dispatch(IEchoService service, ReadOnlySequence<
113113
{
114114
switch (method)
115115
{
116-
case Method_fireAndForget: return AsyncEnumerable.Empty<byte[]>();
116+
case Method_fireAndForget: service.FireAndForget(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata); return AsyncEnumerable.Empty<byte[]>();
117117
case Method_requestResponse: return from result in service.RequestResponse(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata).ToAsyncEnumerable() select Google.Protobuf.MessageExtensions.ToByteArray(result);
118118
case Method_requestStream: return from result in service.RequestStream(Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(data.ToArray()), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);
119119
case Method_requestChannel: return from result in service.RequestChannel(from message in messages select Google.Protobuf.WellKnownTypes.Value.Parser.ParseFrom(message), metadata) select Google.Protobuf.MessageExtensions.ToByteArray(result);

0 commit comments

Comments
 (0)