Skip to content

Commit 227f440

Browse files
authored
Support progress notifications with McpClientTool and McpClientExtensions.CallToolAsync (#312)
1 parent 90748b1 commit 227f440

File tree

8 files changed

+174
-47
lines changed

8 files changed

+174
-47
lines changed

src/ModelContextProtocol/Client/McpClientExtensions.cs

+52-2
Original file line numberDiff line numberDiff line change
@@ -783,9 +783,14 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri,
783783
/// </summary>
784784
/// <param name="client">The client instance used to communicate with the MCP server.</param>
785785
/// <param name="toolName">The name of the tool to call on the server..</param>
786-
/// <param name="arguments">Optional dictionary of arguments to pass to the tool. Each key represents a parameter name,
786+
/// <param name="arguments">An optional dictionary of arguments to pass to the tool. Each key represents a parameter name,
787787
/// and its associated value represents the argument value.
788788
/// </param>
789+
/// <param name="progress">
790+
/// An optional <see cref="IProgress{T}"/> to have progress notifications reported to it. Setting this to a non-<see langword="null"/>
791+
/// value will result in a progress token being included in the call, and any resulting progress notifications during the operation
792+
/// routed to this instance.
793+
/// </param>
789794
/// <param name="serializerOptions">
790795
/// The JSON serialization options governing argument serialization. If <see langword="null"/>, the default serialization options will be used.
791796
/// </param>
@@ -812,6 +817,7 @@ public static Task<CallToolResponse> CallToolAsync(
812817
this IMcpClient client,
813818
string toolName,
814819
IReadOnlyDictionary<string, object?>? arguments = null,
820+
IProgress<ProgressNotificationValue>? progress = null,
815821
JsonSerializerOptions? serializerOptions = null,
816822
CancellationToken cancellationToken = default)
817823
{
@@ -820,12 +826,56 @@ public static Task<CallToolResponse> CallToolAsync(
820826
serializerOptions ??= McpJsonUtilities.DefaultOptions;
821827
serializerOptions.MakeReadOnly();
822828

829+
if (progress is not null)
830+
{
831+
return SendRequestWithProgressAsync(client, toolName, arguments, progress, serializerOptions, cancellationToken);
832+
}
833+
823834
return client.SendRequestAsync(
824835
RequestMethods.ToolsCall,
825-
new() { Name = toolName, Arguments = ToArgumentsDictionary(arguments, serializerOptions) },
836+
new()
837+
{
838+
Name = toolName,
839+
Arguments = ToArgumentsDictionary(arguments, serializerOptions),
840+
},
826841
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
827842
McpJsonUtilities.JsonContext.Default.CallToolResponse,
828843
cancellationToken: cancellationToken);
844+
845+
static async Task<CallToolResponse> SendRequestWithProgressAsync(
846+
IMcpClient client,
847+
string toolName,
848+
IReadOnlyDictionary<string, object?>? arguments,
849+
IProgress<ProgressNotificationValue> progress,
850+
JsonSerializerOptions serializerOptions,
851+
CancellationToken cancellationToken)
852+
{
853+
ProgressToken progressToken = new(Guid.NewGuid().ToString("N"));
854+
855+
await using var _ = client.RegisterNotificationHandler(NotificationMethods.ProgressNotification,
856+
(notification, cancellationToken) =>
857+
{
858+
if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotification) is { } pn &&
859+
pn.ProgressToken == progressToken)
860+
{
861+
progress.Report(pn.Progress);
862+
}
863+
864+
return default;
865+
}).ConfigureAwait(false);
866+
867+
return await client.SendRequestAsync(
868+
RequestMethods.ToolsCall,
869+
new()
870+
{
871+
Name = toolName,
872+
Arguments = ToArgumentsDictionary(arguments, serializerOptions),
873+
Meta = new() { ProgressToken = progressToken },
874+
},
875+
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
876+
McpJsonUtilities.JsonContext.Default.CallToolResponse,
877+
cancellationToken: cancellationToken).ConfigureAwait(false);
878+
}
829879
}
830880

831881
/// <summary>

src/ModelContextProtocol/Client/McpClientTool.cs

+39-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
using Microsoft.Extensions.AI;
12
using ModelContextProtocol.Protocol.Types;
3+
using ModelContextProtocol.Utils;
24
using ModelContextProtocol.Utils.Json;
3-
using Microsoft.Extensions.AI;
4-
using System.Text.Json;
55
using System.Collections.ObjectModel;
6+
using System.Text.Json;
67

78
namespace ModelContextProtocol.Client;
89

@@ -36,14 +37,22 @@ public sealed class McpClientTool : AIFunction
3637
private readonly IMcpClient _client;
3738
private readonly string _name;
3839
private readonly string _description;
40+
private readonly IProgress<ProgressNotificationValue>? _progress;
3941

40-
internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, string? description = null)
42+
internal McpClientTool(
43+
IMcpClient client,
44+
Tool tool,
45+
JsonSerializerOptions serializerOptions,
46+
string? name = null,
47+
string? description = null,
48+
IProgress<ProgressNotificationValue>? progress = null)
4149
{
4250
_client = client;
4351
ProtocolTool = tool;
4452
JsonSerializerOptions = serializerOptions;
4553
_name = name ?? tool.Name;
4654
_description = description ?? tool.Description ?? string.Empty;
55+
_progress = progress;
4756
}
4857

4958
/// <summary>
@@ -77,7 +86,7 @@ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions seria
7786
protected async override ValueTask<object?> InvokeCoreAsync(
7887
AIFunctionArguments arguments, CancellationToken cancellationToken)
7988
{
80-
CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, arguments, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
89+
CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, arguments, _progress, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
8190
return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse);
8291
}
8392

@@ -107,7 +116,7 @@ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions seria
107116
/// </remarks>
108117
public McpClientTool WithName(string name)
109118
{
110-
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description);
119+
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description, _progress);
111120
}
112121

113122
/// <summary>
@@ -133,6 +142,30 @@ public McpClientTool WithName(string name)
133142
/// <returns>A new instance of <see cref="McpClientTool"/> with the provided description.</returns>
134143
public McpClientTool WithDescription(string description)
135144
{
136-
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description);
145+
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description, _progress);
146+
}
147+
148+
/// <summary>
149+
/// Creates a new instance of the tool but modified to report progress via the specified <see cref="IProgress{T}"/>.
150+
/// </summary>
151+
/// <param name="progress">The <see cref="IProgress{T}"/> to which progress notifications should be reported.</param>
152+
/// <remarks>
153+
/// <para>
154+
/// Adding an <see cref="IProgress{T}"/> to the tool does not impact how it is reported to any AI model.
155+
/// Rather, when the tool is invoked, the request to the MCP server will include a unique progress token,
156+
/// and any progress notifications issued by the server with that progress token while the operation is in
157+
/// flight will be reported to the <paramref name="progress"/> instance.
158+
/// </para>
159+
/// <para>
160+
/// Only one <see cref="IProgress{T}"/> can be specified at a time. Calling <see cref="WithProgress"/> again
161+
/// will overwrite any previously specified progress instance.
162+
/// </para>
163+
/// </remarks>
164+
/// <returns>A new instance of <see cref="McpClientTool"/>, configured with the provided progress instance.</returns>
165+
public McpClientTool WithProgress(IProgress<ProgressNotificationValue> progress)
166+
{
167+
Throw.IfNull(progress);
168+
169+
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, _description, progress);
137170
}
138171
}

tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs

+55-11
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer
2929
}
3030
services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" }));
3131
services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true }));
32-
3332
}
3433

3534
[Theory]
@@ -209,7 +208,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
209208
[Fact]
210209
public async Task ListToolsAsync_AllToolsReturned()
211210
{
212-
IMcpClient client = await CreateMcpClientForServer();
211+
await using IMcpClient client = await CreateMcpClientForServer();
213212

214213
var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
215214
Assert.Equal(12, tools.Count);
@@ -235,7 +234,7 @@ public async Task ListToolsAsync_AllToolsReturned()
235234
[Fact]
236235
public async Task EnumerateToolsAsync_AllToolsReturned()
237236
{
238-
IMcpClient client = await CreateMcpClientForServer();
237+
await using IMcpClient client = await CreateMcpClientForServer();
239238

240239
await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken))
241240
{
@@ -254,7 +253,7 @@ public async Task EnumerateToolsAsync_AllToolsReturned()
254253
public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions()
255254
{
256255
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
257-
IMcpClient client = await CreateMcpClientForServer();
256+
await using IMcpClient client = await CreateMcpClientForServer();
258257
bool hasTools = false;
259258

260259
await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken))
@@ -275,7 +274,7 @@ public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions()
275274
public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions()
276275
{
277276
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
278-
IMcpClient client = await CreateMcpClientForServer();
277+
await using IMcpClient client = await CreateMcpClientForServer();
279278

280279
var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First();
281280
await Assert.ThrowsAsync<NotSupportedException>(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken));
@@ -285,7 +284,7 @@ public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions()
285284
public async Task SendRequestAsync_HonorsJsonSerializerOptions()
286285
{
287286
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
288-
IMcpClient client = await CreateMcpClientForServer();
287+
await using IMcpClient client = await CreateMcpClientForServer();
289288

290289
await Assert.ThrowsAsync<NotSupportedException>(() => client.SendRequestAsync<CallToolRequestParams, CallToolResponse>("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken));
291290
}
@@ -294,7 +293,7 @@ public async Task SendRequestAsync_HonorsJsonSerializerOptions()
294293
public async Task SendNotificationAsync_HonorsJsonSerializerOptions()
295294
{
296295
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
297-
IMcpClient client = await CreateMcpClientForServer();
296+
await using IMcpClient client = await CreateMcpClientForServer();
298297

299298
await Assert.ThrowsAsync<NotSupportedException>(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken));
300299
}
@@ -303,7 +302,7 @@ public async Task SendNotificationAsync_HonorsJsonSerializerOptions()
303302
public async Task GetPromptsAsync_HonorsJsonSerializerOptions()
304303
{
305304
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
306-
IMcpClient client = await CreateMcpClientForServer();
305+
await using IMcpClient client = await CreateMcpClientForServer();
307306

308307
await Assert.ThrowsAsync<NotSupportedException>(() => client.GetPromptAsync("Prompt", new Dictionary<string, object?> { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken));
309308
}
@@ -312,7 +311,7 @@ public async Task GetPromptsAsync_HonorsJsonSerializerOptions()
312311
public async Task WithName_ChangesToolName()
313312
{
314313
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
315-
IMcpClient client = await CreateMcpClientForServer();
314+
await using IMcpClient client = await CreateMcpClientForServer();
316315

317316
var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First();
318317
var originalName = tool.Name;
@@ -327,7 +326,7 @@ public async Task WithName_ChangesToolName()
327326
public async Task WithDescription_ChangesToolDescription()
328327
{
329328
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
330-
IMcpClient client = await CreateMcpClientForServer();
329+
await using IMcpClient client = await CreateMcpClientForServer();
331330
var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault();
332331
var originalDescription = tool?.Description;
333332
var redescribedTool = tool?.WithDescription("ToolWithNewDescription");
@@ -336,10 +335,55 @@ public async Task WithDescription_ChangesToolDescription()
336335
Assert.Equal(originalDescription, tool?.Description);
337336
}
338337

338+
[Fact]
339+
public async Task WithProgress_ProgressReported()
340+
{
341+
const int TotalNotifications = 3;
342+
int remainingProgress = TotalNotifications;
343+
TaskCompletionSource<bool> allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously);
344+
345+
Server.ServerOptions.Capabilities?.Tools?.ToolCollection?.Add(McpServerTool.Create(async (IProgress<ProgressNotificationValue> progress) =>
346+
{
347+
for (int i = 0; i < TotalNotifications; i++)
348+
{
349+
progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" });
350+
await Task.Delay(1);
351+
}
352+
353+
await allProgressReceived.Task;
354+
355+
return 42;
356+
}, new() { Name = "ProgressReporter" }));
357+
358+
await using IMcpClient client = await CreateMcpClientForServer();
359+
360+
var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter");
361+
362+
IProgress<ProgressNotificationValue> progress = new SynchronousProgress(value =>
363+
{
364+
Assert.True(value.Progress >= 0 && value.Progress <= 100);
365+
Assert.Equal("making progress", value.Message);
366+
if (Interlocked.Decrement(ref remainingProgress) == 0)
367+
{
368+
allProgressReceived.SetResult(true);
369+
}
370+
});
371+
372+
Assert.Throws<ArgumentNullException>("progress", () => tool.WithProgress(null!));
373+
374+
var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken);
375+
Assert.Contains("42", result?.ToString());
376+
}
377+
378+
private sealed class SynchronousProgress(Action<ProgressNotificationValue> callback) : IProgress<ProgressNotificationValue>
379+
{
380+
public void Report(ProgressNotificationValue value) => callback(value);
381+
}
382+
339383
[Fact]
340384
public async Task AsClientLoggerProvider_MessagesSentToClient()
341385
{
342-
IMcpClient client = await CreateMcpClientForServer();
386+
await using IMcpClient client = await CreateMcpClientForServer();
343387

344388
ILoggerProvider loggerProvider = Server.AsClientLoggerProvider();
345389
Assert.Throws<ArgumentNullException>("categoryName", () => loggerProvider.CreateLogger(null!));

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ public void Adds_Prompts_To_Server()
9696
[Fact]
9797
public async Task Can_List_And_Call_Registered_Prompts()
9898
{
99-
IMcpClient client = await CreateMcpClientForServer();
99+
await using IMcpClient client = await CreateMcpClientForServer();
100100

101101
var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
102102
Assert.Equal(6, prompts.Count);
@@ -125,7 +125,7 @@ public async Task Can_List_And_Call_Registered_Prompts()
125125
[Fact]
126126
public async Task Can_Be_Notified_Of_Prompt_Changes()
127127
{
128-
IMcpClient client = await CreateMcpClientForServer();
128+
await using IMcpClient client = await CreateMcpClientForServer();
129129

130130
var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
131131
Assert.Equal(6, prompts.Count);
@@ -166,7 +166,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
166166
[Fact]
167167
public async Task Throws_When_Prompt_Fails()
168168
{
169-
IMcpClient client = await CreateMcpClientForServer();
169+
await using IMcpClient client = await CreateMcpClientForServer();
170170

171171
await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
172172
nameof(SimplePrompts.ThrowsException),
@@ -176,7 +176,7 @@ await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
176176
[Fact]
177177
public async Task Throws_Exception_On_Unknown_Prompt()
178178
{
179-
IMcpClient client = await CreateMcpClientForServer();
179+
await using IMcpClient client = await CreateMcpClientForServer();
180180

181181
var e = await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
182182
"NotRegisteredPrompt",
@@ -188,7 +188,7 @@ public async Task Throws_Exception_On_Unknown_Prompt()
188188
[Fact]
189189
public async Task Throws_Exception_Missing_Parameter()
190190
{
191-
IMcpClient client = await CreateMcpClientForServer();
191+
await using IMcpClient client = await CreateMcpClientForServer();
192192

193193
var e = await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
194194
nameof(SimplePrompts.ReturnsChatMessages),

0 commit comments

Comments
 (0)