Skip to content

Commit

Permalink
feat: generic pipeline support
Browse files Browse the repository at this point in the history
  • Loading branch information
GerardSmit committed Dec 22, 2023
1 parent 4597fe3 commit a8a6bd2
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public static IMediatorBuilder AddMediator(this IServiceCollection services)
services.TryAddSingleton(typeof(GenericNotificationCache<>));
services.TryAddTransient(typeof(GenericNotificationHandler<>));

services.TryAddSingleton(typeof(GenericPipelineBehavior<,>));

services.TryAddSingleton(typeof(GenericStreamRequestCache<,>));
services.TryAddTransient(typeof(IStreamRequestHandler<,>), typeof(GenericStreamRequestHandler<,>));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=handlers/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=handlers_005Crequest/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=namespaces/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=namespaces_005Chandlers/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=namespaces_005Chandlers/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=pipeline/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=pipeline_005Cgeneric/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;

namespace Zapto.Mediator;

internal sealed record GenericPipelineBehaviorRegistration
{
public GenericPipelineBehaviorRegistration(Type requestType, Type? responseType, Type behaviorType)
{
RequestType = requestType;
ResponseType = responseType;
BehaviorType = behaviorType;
}

public Type RequestType { get; }

public Type? ResponseType { get; }

public Type BehaviorType { get; }
}

internal sealed class GenericPipelineBehavior<TRequest, TResponse>
where TRequest : notnull
{
private readonly List<Type> _handlerTypes;
private readonly IEnumerable<GenericPipelineBehaviorRegistration> _enumerable;

public GenericPipelineBehavior(IEnumerable<GenericPipelineBehaviorRegistration> enumerable)
{
_enumerable = enumerable;
_handlerTypes = CreateHandlerTypes();
}

public bool IsEmpty => _handlerTypes.Count == 0;

private List<Type> CreateHandlerTypes()
{
var handlerTypes = new List<Type>();

if (_enumerable is GenericPipelineBehaviorRegistration[] { Length: 0 })
{
return handlerTypes;
}

var requestType = typeof(TRequest);
var arguments = requestType.GetGenericArguments();

if (requestType.IsGenericType)
{
requestType = requestType.GetGenericTypeDefinition();
}

var responseType = typeof(TResponse);

if (responseType.IsGenericType)
{
responseType = responseType.GetGenericTypeDefinition();
}

foreach (var registration in _enumerable)
{
if (!registration.RequestType.IsAssignableFrom(requestType) ||
registration.ResponseType is not null && !registration.ResponseType.IsAssignableFrom(responseType))
{
continue;
}

var type = registration.BehaviorType.MakeGenericType(arguments);

handlerTypes.Add(type);
}

handlerTypes.Reverse();

return handlerTypes;
}

public ValueTask<TResponse> Handle(IServiceProvider provider, TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
foreach (var cachedType in _handlerTypes)
{
var behavior = (IPipelineBehavior<TRequest, TResponse>)provider.GetRequiredService(cachedType);
var nextPipeline = next;

next = () => behavior.Handle(provider, request, nextPipeline, cancellationToken);
}

return next();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Linq;
using MediatR;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;

namespace Zapto.Mediator;

Expand Down Expand Up @@ -48,10 +49,14 @@ public IMediatorBuilder AddPipelineBehavior(Type requestType, Type? responseType
{
if (requestType.IsGenericType || responseType is null || responseType.IsGenericTypeDefinition)
{
throw new NotSupportedException("Generic pipeline behaviors are not supported.");
_services.TryAdd(new ServiceDescriptor(behaviorType, behaviorType, GetLifetime(scope)));
_services.AddSingleton(new GenericPipelineBehaviorRegistration(requestType, responseType, behaviorType));
}
else
{
_services.Add(new ServiceDescriptor(typeof(IPipelineBehavior<,>).MakeGenericType(requestType, responseType), behaviorType, GetLifetime(scope)));
}

_services.Add(new ServiceDescriptor(typeof(IPipelineBehavior<,>).MakeGenericType(requestType, responseType), behaviorType, GetLifetime(scope)));
return this;
}

Expand Down
10 changes: 7 additions & 3 deletions src/Mediator.DependencyInjection/ServiceProviderMediator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ private ValueTask<TResponse> SendWithPipelineArray<TRequest, TResponse>(
CancellationToken cancellationToken
) where TRequest : IRequest<TResponse>
{
if (array.Length == 0)
var generic = _provider.GetRequiredService<GenericPipelineBehavior<TRequest, TResponse>>();

if (array.Length == 0 && generic.IsEmpty)
{
return handler.Handle(_provider, request, cancellationToken);
}
Expand All @@ -104,7 +106,7 @@ CancellationToken cancellationToken
next = () => pipelineBehavior.Handle(_provider, request, nextPipeline, cancellationToken);
}

return next();
return generic.Handle(_provider, request, next, cancellationToken);
}

private ValueTask<TResponse> SendWithPipelineEnumerable<TRequest, TResponse>(
Expand All @@ -114,6 +116,8 @@ private ValueTask<TResponse> SendWithPipelineEnumerable<TRequest, TResponse>(
CancellationToken cancellationToken
) where TRequest : IRequest<TResponse>
{
var generic = _provider.GetRequiredService<GenericPipelineBehavior<TRequest, TResponse>>();

RequestHandlerDelegate<TResponse> next = () => handler.Handle(_provider, request, cancellationToken);

foreach (var pipelineBehavior in pipeline.Reverse())
Expand All @@ -122,7 +126,7 @@ CancellationToken cancellationToken
next = () => pipelineBehavior.Handle(_provider, request, nextPipeline, cancellationToken);
}

return next();
return generic.Handle(_provider, request, next, cancellationToken);
}

public IAsyncEnumerable<TResponse> CreateStream<TResponse>(MediatorNamespace ns, IStreamRequest<TResponse> request,
Expand Down
15 changes: 10 additions & 5 deletions src/Mediator/IMediatorBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public interface IMediatorBuilder
{
IMediatorBuilder AddNamespace(MediatorNamespace ns);

IMediatorBuilder AddRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient) where THandler : IRequestHandler;
IMediatorBuilder AddRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient)
where THandler : IRequestHandler;

IMediatorBuilder AddRequestHandler(Type handlerType, RegistrationScope scope = RegistrationScope.Transient);

Expand Down Expand Up @@ -101,7 +102,8 @@ IMediatorBuilder AddNotificationHandler<TNotification>(Func<IServiceProvider, TN

IMediatorBuilder AddStreamRequestHandler(Type type, RegistrationScope scope = RegistrationScope.Transient);

IMediatorBuilder AddStreamRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient) where THandler : IStreamRequestHandler;
IMediatorBuilder AddStreamRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient)
where THandler : IStreamRequestHandler;

IMediatorBuilder AddStreamRequestHandler(Type requestType, Type? responseType, Type handlerType, RegistrationScope scope = RegistrationScope.Transient);

Expand All @@ -123,19 +125,22 @@ public IMediatorBuilder AddStreamRequestHandler<TRequest, TResponse>(Func<IServi

IMediatorBuilder AddDefaultRequestHandler(Type handlerType, RegistrationScope scope = RegistrationScope.Transient);

IMediatorBuilder AddDefaultRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient) where THandler : class, IDefaultRequestHandler;
IMediatorBuilder AddDefaultRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient)
where THandler : class, IDefaultRequestHandler;

IMediatorBuilder AddDefaultNotificationHandler(IDefaultNotificationHandler handler);

IMediatorBuilder AddDefaultNotificationHandler(Type handlerType, RegistrationScope scope = RegistrationScope.Transient);

IMediatorBuilder AddDefaultNotificationHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient) where THandler : class, IDefaultNotificationHandler;
IMediatorBuilder AddDefaultNotificationHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient)
where THandler : class, IDefaultNotificationHandler;

IMediatorBuilder AddDefaultStreamRequestHandler(IDefaultStreamRequestHandler handler);

IMediatorBuilder AddDefaultStreamRequestHandler(Type handlerType, RegistrationScope scope = RegistrationScope.Transient);

IMediatorBuilder AddDefaultStreamRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient) where THandler : class, IDefaultStreamRequestHandler;
IMediatorBuilder AddDefaultStreamRequestHandler<THandler>(RegistrationScope scope = RegistrationScope.Transient)
where THandler : class, IDefaultStreamRequestHandler;

IMediatorBuilder AddPipelineBehavior<TRequest, TResponse>(IPipelineBehavior<TRequest, TResponse> behavior)
where TRequest : notnull;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#if NET7_0_OR_GREATER
using System;
using System.Numerics;
using System.Threading;
using System.Threading.Tasks;
using MediatR;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
using Zapto.Mediator;

namespace Mediator.DependencyInjection.Tests.Generics;

public record struct ReturnNumberRequest<TSelf>(TSelf Value) : IRequest<TSelf>
where TSelf : INumber<TSelf>;

public class ReturnNumberRequestHandler<TSelf> : IRequestHandler<ReturnNumberRequest<TSelf>, TSelf>
where TSelf : INumber<TSelf>
{
public ValueTask<TSelf> Handle(IServiceProvider provider, ReturnNumberRequest<TSelf> request,
CancellationToken cancellationToken)
{
return new ValueTask<TSelf>(request.Value);
}
}

public class AddOnePipelineBehavior<TSelf> : IPipelineBehavior<ReturnNumberRequest<TSelf>, TSelf>
where TSelf : INumber<TSelf>, IAdditionOperators<TSelf, TSelf, TSelf>
{
public async ValueTask<TSelf> Handle(IServiceProvider provider, ReturnNumberRequest<TSelf> request, RequestHandlerDelegate<TSelf> next,
CancellationToken cancellationToken)
{
var result = await next();

return result + TSelf.One;
}
}

public class BehaviorNumberTest
{
[Fact]
public async Task ReturnSelf()
{
await using var provider = new ServiceCollection()
.AddMediator(b =>
{
b.AddRequestHandler(typeof(ReturnNumberRequestHandler<>));
})
.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();

Assert.Equal(0, await mediator.Send(new ReturnNumberRequest<int>(0)));
Assert.Equal(10L, await mediator.Send(new ReturnNumberRequest<long>(10L)));
}

[Fact]
public async Task AddOneBehavior()
{
await using var provider = new ServiceCollection()
.AddMediator(b =>
{
b.AddRequestHandler(typeof(ReturnNumberRequestHandler<>));
b.AddPipelineBehavior(typeof(AddOnePipelineBehavior<>));
})
.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();

Assert.Equal(1, await mediator.Send(new ReturnNumberRequest<int>(0)));
Assert.Equal(11L, await mediator.Send(new ReturnNumberRequest<long>(10L)));
}
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<TargetFrameworks>net461;net6.0;net7.0;net8.0</TargetFrameworks>
<LangVersion>10</LangVersion>
<LangVersion>12</LangVersion>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
Expand Down

0 comments on commit a8a6bd2

Please sign in to comment.