diff --git a/src/SignalR/clients/csharp/Client.SourceGenerator/src/HubClientProxyGenerator.Parser.cs b/src/SignalR/clients/csharp/Client.SourceGenerator/src/HubClientProxyGenerator.Parser.cs index 1dd1655dd864..6ed0b35de46d 100644 --- a/src/SignalR/clients/csharp/Client.SourceGenerator/src/HubClientProxyGenerator.Parser.cs +++ b/src/SignalR/clients/csharp/Client.SourceGenerator/src/HubClientProxyGenerator.Parser.cs @@ -134,13 +134,7 @@ private static bool IsExtensionClassSignatureValid(ClassDeclarationSyntax syntax return true; } - internal static bool IsSyntaxTargetForGeneration(SyntaxNode node) => node is MemberAccessExpressionSyntax - { - Name: GenericNameSyntax - { - Arity: 1 - } - }; + internal static bool IsSyntaxTargetForGeneration(SyntaxNode node) => node is MemberAccessExpressionSyntax{ Name: SimpleNameSyntax }; internal static MemberAccessExpressionSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context) { @@ -251,10 +245,10 @@ internal SourceGenerationSpec Parse(ImmutableArray meth var argModel = _compilation.GetSemanticModel(argType.SyntaxTree); symbol = (ITypeSymbol)argModel.GetSymbolInfo(argType).Symbol; } - else if (memberAccess.Name is not GenericNameSyntax - && memberAccess.Parent.ChildNodes().FirstOrDefault(x => x is ArgumentListSyntax) is - ArgumentListSyntax - { Arguments: { Count: 1 } } als) + else if (memberAccess.Name is SimpleNameSyntax + && memberAccess.Parent.ChildNodes().FirstOrDefault(x => x is ArgumentListSyntax) is + ArgumentListSyntax + { Arguments: { Count: 1 } } als) { // Method isn't using generic syntax so inspect first expression in arguments to deduce the type var argModel = _compilation.GetSemanticModel(als.Arguments[0].Expression.SyntaxTree); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubClientProxyGeneratorTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubClientProxyGeneratorTests.cs index a252a46a5ecf..df96bc893740 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubClientProxyGeneratorTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubClientProxyGeneratorTests.cs @@ -31,7 +31,7 @@ public interface IMyClient Task ReturnTask(); } - private class MyClient : IMyClient + internal class MyClient : IMyClient { public int CallsOfNoArg; public void NoArg() @@ -259,4 +259,52 @@ public async Task CallbacksGetTriggered() await returnTaskFunc(Array.Empty(), returnTaskState); Assert.Equal(1, myClient.CallsOfReturnTask); } + + [Fact] + public void RegistersCallbackProviderWithExplicitGeneric() + { + // Arrange + var mockConn = MockHubConnection.Get(); + var noArgReg = new Disposable(); + mockConn + .Setup(x => x.On( + "NoArg", + Array.Empty(), + It.IsAny>(), + It.IsAny())) + .Returns(noArgReg); + var conn = mockConn.Object; + var myClient = new MyClient(); + + // Act + var registration = conn.SetHubClient(myClient); + + // Assert + mockConn.VerifyAll(); + Assert.False(noArgReg.IsDisposed); + } + + [Fact] + public void RegistersCallbackProviderWithInferredGeneric() + { + // Arrange + var mockConn = MockHubConnection.Get(); + var noArgReg = new Disposable(); + mockConn + .Setup(x => x.On( + "NoArg", + Array.Empty(), + It.IsAny>(), + It.IsAny())) + .Returns(noArgReg); + var conn = mockConn.Object; + var myClient = new MyClient(); + + // Act + var registration = conn.SetHubClient(myClient); // Inferred generic type + + // Assert + mockConn.VerifyAll(); + Assert.False(noArgReg.IsDisposed); + } }