diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index a86f7e60..513fc8c0 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -783,9 +783,14 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, /// /// The client instance used to communicate with the MCP server. /// The name of the tool to call on the server.. - /// Optional dictionary of arguments to pass to the tool. Each key represents a parameter name, + /// An optional dictionary of arguments to pass to the tool. Each key represents a parameter name, /// and its associated value represents the argument value. /// + /// + /// An optional to have progress notifications reported to it. Setting this to a non- + /// value will result in a progress token being included in the call, and any resulting progress notifications during the operation + /// routed to this instance. + /// /// /// The JSON serialization options governing argument serialization. If , the default serialization options will be used. /// @@ -812,6 +817,7 @@ public static Task CallToolAsync( this IMcpClient client, string toolName, IReadOnlyDictionary? arguments = null, + IProgress? progress = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) { @@ -820,12 +826,56 @@ public static Task CallToolAsync( serializerOptions ??= McpJsonUtilities.DefaultOptions; serializerOptions.MakeReadOnly(); + if (progress is not null) + { + return SendRequestWithProgressAsync(client, toolName, arguments, progress, serializerOptions, cancellationToken); + } + return client.SendRequestAsync( RequestMethods.ToolsCall, - new() { Name = toolName, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, + new() + { + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + }, McpJsonUtilities.JsonContext.Default.CallToolRequestParams, McpJsonUtilities.JsonContext.Default.CallToolResponse, cancellationToken: cancellationToken); + + static async Task SendRequestWithProgressAsync( + IMcpClient client, + string toolName, + IReadOnlyDictionary? arguments, + IProgress progress, + JsonSerializerOptions serializerOptions, + CancellationToken cancellationToken) + { + ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); + + await using var _ = client.RegisterNotificationHandler(NotificationMethods.ProgressNotification, + (notification, cancellationToken) => + { + if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotification) is { } pn && + pn.ProgressToken == progressToken) + { + progress.Report(pn.Progress); + } + + return default; + }).ConfigureAwait(false); + + return await client.SendRequestAsync( + RequestMethods.ToolsCall, + new() + { + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + Meta = new() { ProgressToken = progressToken }, + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResponse, + cancellationToken: cancellationToken).ConfigureAwait(false); + } } /// diff --git a/src/ModelContextProtocol/Client/McpClientTool.cs b/src/ModelContextProtocol/Client/McpClientTool.cs index 7917f0af..759b9f9c 100644 --- a/src/ModelContextProtocol/Client/McpClientTool.cs +++ b/src/ModelContextProtocol/Client/McpClientTool.cs @@ -1,8 +1,9 @@ +using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; -using Microsoft.Extensions.AI; -using System.Text.Json; using System.Collections.ObjectModel; +using System.Text.Json; namespace ModelContextProtocol.Client; @@ -36,14 +37,22 @@ public sealed class McpClientTool : AIFunction private readonly IMcpClient _client; private readonly string _name; private readonly string _description; + private readonly IProgress? _progress; - internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, string? description = null) + internal McpClientTool( + IMcpClient client, + Tool tool, + JsonSerializerOptions serializerOptions, + string? name = null, + string? description = null, + IProgress? progress = null) { _client = client; ProtocolTool = tool; JsonSerializerOptions = serializerOptions; _name = name ?? tool.Name; _description = description ?? tool.Description ?? string.Empty; + _progress = progress; } /// @@ -77,7 +86,7 @@ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions seria protected async override ValueTask InvokeCoreAsync( AIFunctionArguments arguments, CancellationToken cancellationToken) { - CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, arguments, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false); + CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, arguments, _progress, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse); } @@ -107,7 +116,7 @@ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions seria /// public McpClientTool WithName(string name) { - return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description); + return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description, _progress); } /// @@ -133,6 +142,30 @@ public McpClientTool WithName(string name) /// A new instance of with the provided description. public McpClientTool WithDescription(string description) { - return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description); + return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description, _progress); + } + + /// + /// Creates a new instance of the tool but modified to report progress via the specified . + /// + /// The to which progress notifications should be reported. + /// + /// + /// Adding an to the tool does not impact how it is reported to any AI model. + /// Rather, when the tool is invoked, the request to the MCP server will include a unique progress token, + /// and any progress notifications issued by the server with that progress token while the operation is in + /// flight will be reported to the instance. + /// + /// + /// Only one can be specified at a time. Calling again + /// will overwrite any previously specified progress instance. + /// + /// + /// A new instance of , configured with the provided progress instance. + public McpClientTool WithProgress(IProgress progress) + { + Throw.IfNull(progress); + + return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, _description, progress); } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 46c1879c..a5bf96fe 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -29,7 +29,6 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer } services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })); services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })); - } [Theory] @@ -209,7 +208,7 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() [Fact] public async Task ListToolsAsync_AllToolsReturned() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(12, tools.Count); @@ -235,7 +234,7 @@ public async Task ListToolsAsync_AllToolsReturned() [Fact] public async Task EnumerateToolsAsync_AllToolsReturned() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) { @@ -254,7 +253,7 @@ public async Task EnumerateToolsAsync_AllToolsReturned() public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); bool hasTools = false; await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) @@ -275,7 +274,7 @@ public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); @@ -285,7 +284,7 @@ public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() public async Task SendRequestAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(() => client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -294,7 +293,7 @@ public async Task SendRequestAsync_HonorsJsonSerializerOptions() public async Task SendNotificationAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -303,7 +302,7 @@ public async Task SendNotificationAsync_HonorsJsonSerializerOptions() public async Task GetPromptsAsync_HonorsJsonSerializerOptions() { JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(() => client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); } @@ -312,7 +311,7 @@ public async Task GetPromptsAsync_HonorsJsonSerializerOptions() public async Task WithName_ChangesToolName() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); var originalName = tool.Name; @@ -327,7 +326,7 @@ public async Task WithName_ChangesToolName() public async Task WithDescription_ChangesToolDescription() { JsonSerializerOptions options = new(JsonSerializerOptions.Default); - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); var originalDescription = tool?.Description; var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); @@ -336,10 +335,55 @@ public async Task WithDescription_ChangesToolDescription() Assert.Equal(originalDescription, tool?.Description); } + [Fact] + public async Task WithProgress_ProgressReported() + { + const int TotalNotifications = 3; + int remainingProgress = TotalNotifications; + TaskCompletionSource allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously); + + Server.ServerOptions.Capabilities?.Tools?.ToolCollection?.Add(McpServerTool.Create(async (IProgress progress) => + { + for (int i = 0; i < TotalNotifications; i++) + { + progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" }); + await Task.Delay(1); + } + + await allProgressReceived.Task; + + return 42; + }, new() { Name = "ProgressReporter" })); + + await using IMcpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); + + IProgress progress = new SynchronousProgress(value => + { + Assert.True(value.Progress >= 0 && value.Progress <= 100); + Assert.Equal("making progress", value.Message); + if (Interlocked.Decrement(ref remainingProgress) == 0) + { + allProgressReceived.SetResult(true); + } + }); + + Assert.Throws("progress", () => tool.WithProgress(null!)); + + var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Contains("42", result?.ToString()); + } + + private sealed class SynchronousProgress(Action callback) : IProgress + { + public void Report(ProgressNotificationValue value) => callback(value); + } + [Fact] public async Task AsClientLoggerProvider_MessagesSentToClient() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 289330ae..a66cbeca 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -96,7 +96,7 @@ public void Adds_Prompts_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Prompts() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -125,7 +125,7 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -166,7 +166,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() [Fact] public async Task Throws_When_Prompt_Fails() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ThrowsException), @@ -176,7 +176,7 @@ await Assert.ThrowsAsync(async () => await client.GetPromptAsync( [Fact] public async Task Throws_Exception_On_Unknown_Prompt() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "NotRegisteredPrompt", @@ -188,7 +188,7 @@ public async Task Throws_Exception_On_Unknown_Prompt() [Fact] public async Task Throws_Exception_Missing_Parameter() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ReturnsChatMessages), diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index d833356a..40ca2446 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -124,7 +124,7 @@ public void Adds_Tools_To_Server() [Fact] public async Task Can_List_Registered_Tools() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -190,7 +190,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -231,7 +231,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() [Fact] public async Task Can_Call_Registered_Tool() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "Echo", @@ -249,7 +249,7 @@ public async Task Can_Call_Registered_Tool() [Fact] public async Task Can_Call_Registered_Tool_With_Array_Result() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "EchoArray", @@ -273,7 +273,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Null_Result() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "ReturnNull", @@ -287,7 +287,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Json_Result() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "ReturnJson", @@ -304,7 +304,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Int_Result() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "ReturnInteger", @@ -320,7 +320,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() [Fact] public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "EchoComplex", @@ -338,7 +338,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() [Fact] public async Task Can_Call_Registered_Tool_With_Instance_Method() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); string[][] parts = new string[2][]; for (int i = 0; i < 2; i++) @@ -367,7 +367,7 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() [Fact] public async Task Returns_IsError_Content_When_Tool_Fails() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "ThrowException", @@ -382,7 +382,7 @@ public async Task Returns_IsError_Content_When_Tool_Fails() [Fact] public async Task Throws_Exception_On_Unknown_Tool() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "NotRegisteredTool", @@ -394,7 +394,7 @@ public async Task Throws_Exception_On_Unknown_Tool() [Fact] public async Task Returns_IsError_Missing_Parameter() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "Echo", @@ -511,7 +511,7 @@ public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime [Fact] public async Task Recognizes_Parameter_Types() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -584,7 +584,7 @@ public void Create_ExtractsToolAnnotations_SomeSet() [Fact] public async Task HandlesIProgressParameter() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -627,7 +627,7 @@ public async Task HandlesIProgressParameter() [Fact] public async Task CancellationNotificationsPropagateToToolTokens() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index 86aa16ab..5815a625 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -23,7 +23,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task InjectScopedServiceAsArgument() { - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); var tool = tools.First(t => t.Name == nameof(EchoTool.EchoComplex)); diff --git a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs index 81a0f87d..6aa05818 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs @@ -33,7 +33,7 @@ private static async Task WaitForCancellation(CancellationToken cancellationToke [Fact] public async Task PrecancelRequest_CancelsBeforeSending() { - var client = await CreateMcpClientForServer(); + await using var client = await CreateMcpClientForServer(); bool gotCancellation = false; await using (Server.RegisterNotificationHandler(NotificationMethods.CancelledNotification, (notification, cancellationToken) => @@ -51,7 +51,7 @@ public async Task PrecancelRequest_CancelsBeforeSending() [Fact] public async Task CancellationPropagation_RequestingCancellationCancelsPendingRequest() { - var client = await CreateMcpClientForServer(); + await using var client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var waitTool = tools.First(t => t.Name == nameof(WaitForCancellation)); diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 6aa7ccf0..1d42c3d8 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -13,7 +13,7 @@ public NotificationHandlerTests(ITestOutputHelper testOutputHelper) public async Task RegistrationsAreRemovedWhenDisposed() { const string NotificationName = "somethingsomething"; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); const int Iterations = 10; @@ -40,7 +40,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() public async Task MultipleRegistrationsResultInMultipleCallbacks() { const string NotificationName = "somethingsomething"; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -80,7 +80,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() public async Task MultipleHandlersRunEvenIfOneThrows() { const string NotificationName = "somethingsomething"; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -122,7 +122,7 @@ public async Task MultipleHandlersRunEvenIfOneThrows() public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -163,7 +163,7 @@ public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int nu public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - IMcpClient client = await CreateMcpClientForServer(); + await using IMcpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);