Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact
/// <summary>
/// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="handler"/>.
/// </summary>
/// <param name="handler">A request handler with any number of custom parameters that often produces a response with its return value.</param>
/// <param name="handler">
/// A request handler with any number of custom parameters that often produces a response with its return value.
/// If delegate points to instance method, but <see cref="Delegate.Target"/> is set to <see langword="null"/>, target will be fetched from <seealso cref="HttpContext.RequestServices"/>.
/// </param>
/// <param name="options">The <see cref="RequestDelegateFactoryOptions"/> used to configure the behavior of the handler.</param>
/// <param name="metadataResult">
/// The result returned from <see cref="InferMetadata(MethodInfo, RequestDelegateFactoryOptions?)"/> if that was used to inferring metadata before creating the final RequestDelegate.
Expand All @@ -178,23 +181,37 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact
{
ArgumentNullException.ThrowIfNull(handler);

var targetExpression = handler.Target switch
UnaryExpression? targetExpression = null;
Func<HttpContext, object?>? targetFactory = null;
Expression<Func<HttpContext, object?>>? targetFactoryExpression = null;

switch (handler.Target)
{
object => Expression.Convert(TargetExpr, handler.Target.GetType()),
null => null,
};
case object:
targetExpression = Expression.Convert(TargetExpr, handler.Target.GetType());
targetFactory = (httpContext) => handler.Target;
targetFactoryExpression = (httpContext) => handler.Target;

break;

case null when !handler.Method.IsStatic:
targetExpression = Expression.Convert(TargetExpr, handler.Method.ReflectedType!);
targetFactory = (httpContext) => httpContext.RequestServices.GetRequiredService(handler.Method.ReflectedType!);
targetFactoryExpression = (httpContext) => httpContext.RequestServices.GetRequiredService(handler.Method.ReflectedType!);

break;
}

var factoryContext = CreateFactoryContext(options, metadataResult, handler);

Expression<Func<HttpContext, object?>> targetFactory = (httpContext) => handler.Target;
var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext, targetFactory);
var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext, targetFactoryExpression);

RequestDelegate finalRequestDelegate = targetableRequestDelegate switch
{
// handler is a RequestDelegate that has not been modified by a filter. Short-circuit and return the original RequestDelegate back.
// It's possible a filter factory has still modified the endpoint metadata though.
null => (RequestDelegate)handler,
_ => httpContext => targetableRequestDelegate(handler.Target, httpContext),
_ => httpContext => targetableRequestDelegate(targetFactory?.Invoke(httpContext), httpContext),
};

return CreateRequestDelegateResult(finalRequestDelegate, factoryContext.EndpointBuilder);
Expand Down Expand Up @@ -369,8 +386,9 @@ private static IReadOnlyList<object> AsReadOnlyList(IList<object> metadata)
}
}

// return null for plain RequestDelegates that have not been modified by filters so we can just pass back the original RequestDelegate.
if (filterPipeline is null && factoryContext.Handler is RequestDelegate)
// return null for plain RequestDelegates that have not been modified by filters so we can just pass back the original RequestDelegate
// but only when target is not injected
if (filterPipeline is null && factoryContext.Handler is RequestDelegate && (factoryContext.Handler.Method.IsStatic || factoryContext.Handler.Target is not null))
{
return null;
}
Expand Down
59 changes: 57 additions & 2 deletions src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,56 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt
Assert.Same(myOriginalService, httpContext.Items["service"]);
}

[Fact]
public async Task RequestDelegateInjectingHandlerForUnboundCustomDelegate()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton(LoggerFactory);
serviceCollection.AddScoped<HttpHandler>();

var services = serviceCollection.BuildServiceProvider();

using var requestScoped = services.CreateScope();

var httpContext = CreateHttpContext();
httpContext.RequestServices = requestScoped.ServiceProvider;

var requestMethod = typeof(HttpHandler).GetMethod(nameof(HttpHandler.Handle))!;
var requestMethodDelegate = requestMethod.CreateDelegate<Func<HttpHandler, HttpContext, Task>>();

var factoryResult = RequestDelegateFactory.Create(requestMethodDelegate, options: new() { ServiceProvider = services });
var requestDelegate = factoryResult.RequestDelegate;

await requestDelegate(httpContext);

Assert.Equal(1, httpContext.Items["calls"]);
}

[Fact]
public async Task RequestDelegateInjectingHandlerForUnboundRequestDelegate()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton(LoggerFactory);
serviceCollection.AddScoped<HttpHandler>();

var services = serviceCollection.BuildServiceProvider();

using var requestScoped = services.CreateScope();

var httpContext = CreateHttpContext();
httpContext.RequestServices = requestScoped.ServiceProvider;

var requestMethod = typeof(HttpHandler).GetMethod(nameof(HttpHandler.Handle))!;
var requestMethodDelegate = requestMethod.CreateDelegate<RequestDelegate>(null);

var factoryResult = RequestDelegateFactory.Create(requestMethodDelegate, options: new() { ServiceProvider = services });
var requestDelegate = factoryResult.RequestDelegate;

await requestDelegate(httpContext);

Assert.Equal(1, httpContext.Items["calls"]);
}

[Fact]
public async Task RequestDelegatePopulatesHttpContextParameterWithoutAttribute()
{
Expand Down Expand Up @@ -3659,14 +3709,19 @@ private class FromServiceAttribute : Attribute, IFromServiceMetadata
{
}

class HttpHandler
private class HttpHandler
{
private int _calls;

public void Handle(HttpContext httpContext)
/// <remarks>
/// Method in form of <see cref="RequestDelegate"/>.
/// </remarks>
public Task Handle(HttpContext httpContext)
{
_calls++;
httpContext.Items["calls"] = _calls;

return Task.CompletedTask;
}
}

Expand Down
Loading