Skip to content

Support progress notifications with McpClientTool and McpClientExtensions.CallToolAsync #312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/ModelContextProtocol/Client/McpClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -783,9 +783,14 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri,
/// </summary>
/// <param name="client">The client instance used to communicate with the MCP server.</param>
/// <param name="toolName">The name of the tool to call on the server..</param>
/// <param name="arguments">Optional dictionary of arguments to pass to the tool. Each key represents a parameter name,
/// <param name="arguments">An optional dictionary of arguments to pass to the tool. Each key represents a parameter name,
/// and its associated value represents the argument value.
/// </param>
/// <param name="progress">
/// An optional <see cref="IProgress{T}"/> to have progress notifications reported to it. Setting this to a non-<see langword="null"/>
/// value will result in a progress token being included in the call, and any resulting progress notifications during the operation
/// routed to this instance.
/// </param>
/// <param name="serializerOptions">
/// The JSON serialization options governing argument serialization. If <see langword="null"/>, the default serialization options will be used.
/// </param>
Expand All @@ -812,6 +817,7 @@ public static Task<CallToolResponse> CallToolAsync(
this IMcpClient client,
string toolName,
IReadOnlyDictionary<string, object?>? arguments = null,
IProgress<ProgressNotificationValue>? progress = null,
JsonSerializerOptions? serializerOptions = null,
CancellationToken cancellationToken = default)
{
Expand All @@ -820,12 +826,56 @@ public static Task<CallToolResponse> CallToolAsync(
serializerOptions ??= McpJsonUtilities.DefaultOptions;
serializerOptions.MakeReadOnly();

if (progress is not null)
{
return SendRequestWithProgressAsync(client, toolName, arguments, progress, serializerOptions, cancellationToken);
}

return client.SendRequestAsync(
RequestMethods.ToolsCall,
new() { Name = toolName, Arguments = ToArgumentsDictionary(arguments, serializerOptions) },
new()
{
Name = toolName,
Arguments = ToArgumentsDictionary(arguments, serializerOptions),
},
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
McpJsonUtilities.JsonContext.Default.CallToolResponse,
cancellationToken: cancellationToken);

static async Task<CallToolResponse> SendRequestWithProgressAsync(
IMcpClient client,
string toolName,
IReadOnlyDictionary<string, object?>? arguments,
IProgress<ProgressNotificationValue> progress,
JsonSerializerOptions serializerOptions,
CancellationToken cancellationToken)
{
ProgressToken progressToken = new(Guid.NewGuid().ToString("N"));

await using var _ = client.RegisterNotificationHandler(NotificationMethods.ProgressNotification,
(notification, cancellationToken) =>
{
if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotification) is { } pn &&
pn.ProgressToken == progressToken)
{
progress.Report(pn.Progress);
}

return default;
}).ConfigureAwait(false);

return await client.SendRequestAsync(
RequestMethods.ToolsCall,
new()
{
Name = toolName,
Arguments = ToArgumentsDictionary(arguments, serializerOptions),
Meta = new() { ProgressToken = progressToken },
},
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
McpJsonUtilities.JsonContext.Default.CallToolResponse,
cancellationToken: cancellationToken).ConfigureAwait(false);
}
}

/// <summary>
Expand Down
45 changes: 39 additions & 6 deletions src/ModelContextProtocol/Client/McpClientTool.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Microsoft.Extensions.AI;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using Microsoft.Extensions.AI;
using System.Text.Json;
using System.Collections.ObjectModel;
using System.Text.Json;

namespace ModelContextProtocol.Client;

Expand Down Expand Up @@ -36,14 +37,22 @@ public sealed class McpClientTool : AIFunction
private readonly IMcpClient _client;
private readonly string _name;
private readonly string _description;
private readonly IProgress<ProgressNotificationValue>? _progress;

internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, string? description = null)
internal McpClientTool(
IMcpClient client,
Tool tool,
JsonSerializerOptions serializerOptions,
string? name = null,
string? description = null,
IProgress<ProgressNotificationValue>? progress = null)
{
_client = client;
ProtocolTool = tool;
JsonSerializerOptions = serializerOptions;
_name = name ?? tool.Name;
_description = description ?? tool.Description ?? string.Empty;
_progress = progress;
}

/// <summary>
Expand Down Expand Up @@ -77,7 +86,7 @@ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions seria
protected async override ValueTask<object?> InvokeCoreAsync(
AIFunctionArguments arguments, CancellationToken cancellationToken)
{
CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, arguments, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, arguments, _progress, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse);
}

Expand Down Expand Up @@ -107,7 +116,7 @@ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions seria
/// </remarks>
public McpClientTool WithName(string name)
{
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description);
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description, _progress);
}

/// <summary>
Expand All @@ -133,6 +142,30 @@ public McpClientTool WithName(string name)
/// <returns>A new instance of <see cref="McpClientTool"/> with the provided description.</returns>
public McpClientTool WithDescription(string description)
{
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description);
return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description, _progress);
}

/// <summary>
/// Creates a new instance of the tool but modified to report progress via the specified <see cref="IProgress{T}"/>.
/// </summary>
/// <param name="progress">The <see cref="IProgress{T}"/> to which progress notifications should be reported.</param>
/// <remarks>
/// <para>
/// Adding an <see cref="IProgress{T}"/> to the tool does not impact how it is reported to any AI model.
/// Rather, when the tool is invoked, the request to the MCP server will include a unique progress token,
/// and any progress notifications issued by the server with that progress token while the operation is in
/// flight will be reported to the <paramref name="progress"/> instance.
/// </para>
/// <para>
/// Only one <see cref="IProgress{T}"/> can be specified at a time. Calling <see cref="WithProgress"/> again
/// will overwrite any previously specified progress instance.
/// </para>
/// </remarks>
/// <returns>A new instance of <see cref="McpClientTool"/>, configured with the provided progress instance.</returns>
public McpClientTool WithProgress(IProgress<ProgressNotificationValue> progress)
{
Throw.IfNull(progress);

return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, _description, progress);
}
}
66 changes: 55 additions & 11 deletions tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer
}
services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" }));
services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true }));

}

[Theory]
Expand Down Expand Up @@ -209,7 +208,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
[Fact]
public async Task ListToolsAsync_AllToolsReturned()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Expand All @@ -235,7 +234,7 @@ public async Task ListToolsAsync_AllToolsReturned()
[Fact]
public async Task EnumerateToolsAsync_AllToolsReturned()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken))
{
Expand All @@ -254,7 +253,7 @@ public async Task EnumerateToolsAsync_AllToolsReturned()
public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions()
{
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();
bool hasTools = false;

await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken))
Expand All @@ -275,7 +274,7 @@ public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions()
public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions()
{
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First();
await Assert.ThrowsAsync<NotSupportedException>(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken));
Expand All @@ -285,7 +284,7 @@ public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions()
public async Task SendRequestAsync_HonorsJsonSerializerOptions()
{
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

await Assert.ThrowsAsync<NotSupportedException>(() => client.SendRequestAsync<CallToolRequestParams, CallToolResponse>("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken));
}
Expand All @@ -294,7 +293,7 @@ public async Task SendRequestAsync_HonorsJsonSerializerOptions()
public async Task SendNotificationAsync_HonorsJsonSerializerOptions()
{
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

await Assert.ThrowsAsync<NotSupportedException>(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken));
}
Expand All @@ -303,7 +302,7 @@ public async Task SendNotificationAsync_HonorsJsonSerializerOptions()
public async Task GetPromptsAsync_HonorsJsonSerializerOptions()
{
JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() };
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

await Assert.ThrowsAsync<NotSupportedException>(() => client.GetPromptAsync("Prompt", new Dictionary<string, object?> { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken));
}
Expand All @@ -312,7 +311,7 @@ public async Task GetPromptsAsync_HonorsJsonSerializerOptions()
public async Task WithName_ChangesToolName()
{
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First();
var originalName = tool.Name;
Expand All @@ -327,7 +326,7 @@ public async Task WithName_ChangesToolName()
public async Task WithDescription_ChangesToolDescription()
{
JsonSerializerOptions options = new(JsonSerializerOptions.Default);
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();
var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault();
var originalDescription = tool?.Description;
var redescribedTool = tool?.WithDescription("ToolWithNewDescription");
Expand All @@ -336,10 +335,55 @@ public async Task WithDescription_ChangesToolDescription()
Assert.Equal(originalDescription, tool?.Description);
}

[Fact]
public async Task WithProgress_ProgressReported()
{
const int TotalNotifications = 3;
int remainingProgress = TotalNotifications;
TaskCompletionSource<bool> allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously);

Server.ServerOptions.Capabilities?.Tools?.ToolCollection?.Add(McpServerTool.Create(async (IProgress<ProgressNotificationValue> progress) =>
{
for (int i = 0; i < TotalNotifications; i++)
{
progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" });
await Task.Delay(1);
}

await allProgressReceived.Task;

return 42;
}, new() { Name = "ProgressReporter" }));

await using IMcpClient client = await CreateMcpClientForServer();

var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter");

IProgress<ProgressNotificationValue> progress = new SynchronousProgress(value =>
{
Assert.True(value.Progress >= 0 && value.Progress <= 100);
Assert.Equal("making progress", value.Message);
if (Interlocked.Decrement(ref remainingProgress) == 0)
{
allProgressReceived.SetResult(true);
}
});

Assert.Throws<ArgumentNullException>("progress", () => tool.WithProgress(null!));

var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.Contains("42", result?.ToString());
}

private sealed class SynchronousProgress(Action<ProgressNotificationValue> callback) : IProgress<ProgressNotificationValue>
{
public void Report(ProgressNotificationValue value) => callback(value);
}

[Fact]
public async Task AsClientLoggerProvider_MessagesSentToClient()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

ILoggerProvider loggerProvider = Server.AsClientLoggerProvider();
Assert.Throws<ArgumentNullException>("categoryName", () => loggerProvider.CreateLogger(null!));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void Adds_Prompts_To_Server()
[Fact]
public async Task Can_List_And_Call_Registered_Prompts()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(6, prompts.Count);
Expand Down Expand Up @@ -125,7 +125,7 @@ public async Task Can_List_And_Call_Registered_Prompts()
[Fact]
public async Task Can_Be_Notified_Of_Prompt_Changes()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(6, prompts.Count);
Expand Down Expand Up @@ -166,7 +166,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
[Fact]
public async Task Throws_When_Prompt_Fails()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
nameof(SimplePrompts.ThrowsException),
Expand All @@ -176,7 +176,7 @@ await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
[Fact]
public async Task Throws_Exception_On_Unknown_Prompt()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var e = await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
"NotRegisteredPrompt",
Expand All @@ -188,7 +188,7 @@ public async Task Throws_Exception_On_Unknown_Prompt()
[Fact]
public async Task Throws_Exception_Missing_Parameter()
{
IMcpClient client = await CreateMcpClientForServer();
await using IMcpClient client = await CreateMcpClientForServer();

var e = await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
nameof(SimplePrompts.ReturnsChatMessages),
Expand Down
Loading