From bb534b6cacb52ba0cf29c256e54057a8fc118123 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 01:33:34 +0100 Subject: [PATCH 01/16] Generate code that knows how to track cancellation tokens --- .../AnalyzerReleases.Unshipped.md | 7 +- ...leCancellationTokenParametersDiagnostic.cs | 16 + .../InvokableGenerator.cs | 299 +++++++++++++++--- src/Orleans.CodeGenerator/LibraryTypes.cs | 18 +- .../Model/InvokableMethodDescription.cs | 6 + src/Orleans.CodeGenerator/ProxyGenerator.cs | 45 +++ .../SerializerGenerator.cs | 56 ++++ .../ICancellableInvokableGrainExtension.cs | 19 ++ .../CancellableInvokableGrainExtension.cs | 12 + .../Cancellation/CancellationRuntime.cs | 48 +++ .../Invocation/ICancellableInvokable.cs | 16 + .../Invocation/ICancellationRuntime.cs | 28 ++ 12 files changed, 522 insertions(+), 48 deletions(-) create mode 100644 src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs create mode 100644 src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs create mode 100644 src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs create mode 100644 src/Orleans.Runtime/Cancellation/CancellationRuntime.cs create mode 100644 src/Orleans.Serialization/Invocation/ICancellableInvokable.cs create mode 100644 src/Orleans.Serialization/Invocation/ICancellationRuntime.cs diff --git a/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md b/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md index b1b99aaf26..d5a7efc0b7 100644 --- a/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md +++ b/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md @@ -1,3 +1,8 @@ -; Unshipped analyzer release +; Unshipped analyzer release ; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md +### New Rules + +Rule ID | Category | Severity | Notes +--------|----------|----------|-------------------- +ORLEANS0109 | Usage | Error | Method has multiple CancellationToken parameters diff --git a/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs new file mode 100644 index 0000000000..526c74fc4e --- /dev/null +++ b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs @@ -0,0 +1,16 @@ +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Orleans.CodeGenerator.Diagnostics; + +public static class MultipleCancellationTokenParametersDiagnostic +{ + public const string DiagnosticId = "ORLEANS0109"; + public const string Title = "Grain method has multiple parameters of type CancellationToken"; + public const string MessageFormat = "The type {0} contains method {1} which has multiple CancellationToken parameters. Only a single CancellationToken parameter is supported."; + public const string Category = "Usage"; + + private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true); + + internal static Diagnostic CreateDiagnostic(IMethodSymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), symbol.Name); +} \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index a09f75a5be..b60473e1e9 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -30,6 +30,7 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab var generatedClassName = GetSimpleClassName(invokableMethodInfo); var baseClassType = GetBaseClassType(invokableMethodInfo); + var additionalInterfaceTypes = GetAdditionalInterfaceTypes(invokableMethodInfo); var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo); var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions); var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType); @@ -46,7 +47,7 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab } } - var targetField = fieldDescriptions.OfType().Single(); + var holderField = fieldDescriptions.OfType().Single(); var accessibilityKind = accessibility switch { @@ -58,11 +59,12 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab invokableMethodInfo, generatedClassName, baseClassType, + additionalInterfaceTypes, fieldDescriptions, fields, ctor, compoundTypeAliases, - targetField, + holderField, accessibilityKind); string returnValueInitializerMethod = null; @@ -112,11 +114,12 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( InvokableMethodDescription method, string generatedClassName, INamedTypeSymbol baseClassType, + INamedTypeSymbol[] additionalInterfaceTypes, List fieldDescriptions, MemberDeclarationSyntax[] fields, ConstructorDeclarationSyntax ctor, List compoundTypeAliases, - TargetFieldDescription targetField, + HolderFieldDescription holderField, SyntaxKind accessibilityKind) { var classDeclaration = ClassDeclaration(generatedClassName) @@ -125,6 +128,14 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) .AddMembers(fields); + if (additionalInterfaceTypes.Length > 0) + { + foreach (var interfaceType in additionalInterfaceTypes) + { + classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(interfaceType.ToTypeSyntax())); + } + } + foreach (var alias in compoundTypeAliases) { classDeclaration = classDeclaration.AddAttributeLists( @@ -148,12 +159,13 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( GenerateGetActivityName(method), GenerateGetInterfaceType(method), GenerateGetMethod(), - GenerateSetTargetMethod(method, targetField), - GenerateGetTargetMethod(targetField), + GenerateSetTargetMethod(holderField), + GenerateGetTargetMethod(method, holderField), GenerateDisposeMethod(fieldDescriptions, baseClassType), GenerateGetArgumentMethod(method, fieldDescriptions), GenerateSetArgumentMethod(method, fieldDescriptions), - GenerateInvokeInnerMethod(method, fieldDescriptions, targetField)); + GenerateInvokeInnerMethod(method, fieldDescriptions, holderField), + GenerateGetCancellableTokenIdMember(method)); if (method.AllTypeParameters.Count > 0) { @@ -182,7 +194,7 @@ private MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(long va .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue"))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); -; + ; return new MemberDeclarationSyntax[] { timespanField, responseTimeoutProperty }; } @@ -274,46 +286,78 @@ private INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method) throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method)); } - private MemberDeclarationSyntax GenerateSetTargetMethod( - InvokableMethodDescription methodDescription, - TargetFieldDescription targetField) + private INamedTypeSymbol[] GetAdditionalInterfaceTypes(InvokableMethodDescription method) + { + if (method.IsCancellable) + { + var cancellationTokensCount = method.Method.Parameters.Count(parameterSymbol => SymbolEqualityComparer.Default.Equals(method.CodeGenerator.LibraryTypes.CancellationToken, parameterSymbol.Type)); + if (cancellationTokensCount is > 1) + { + throw new OrleansGeneratorDiagnosticAnalysisException(MultipleCancellationTokenParametersDiagnostic.CreateDiagnostic(method.Method)); + } + + return [LibraryTypes.ICancellableInvokable]; + } + + return []; + } + + private MemberDeclarationSyntax GenerateSetTargetMethod(HolderFieldDescription holderField) { var holder = IdentifierName("holder"); var holderParameter = holder.Identifier; - var containingInterface = methodDescription.ContainingInterface; - var isExtension = methodDescription.Key.ProxyBase.IsExtension; - var getTarget = InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - holder, - GenericName(isExtension ? "GetComponent" : "GetTarget") - .WithTypeArgumentList( - TypeArgumentList( - SingletonSeparatedList(containingInterface.ToTypeSyntax()))))) - .WithArgumentList(ArgumentList()); - - var body = - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(targetField.FieldName), - getTarget); return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget") .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax())))) - .WithExpressionBody(ArrowExpressionClause(body)) + .WithExpressionBody(ArrowExpressionClause( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ThisExpression(), + IdentifierName(holderField.FieldName) + ), holder))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); } - private MemberDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField) + private MemberDeclarationSyntax GenerateGetTargetMethod( + InvokableMethodDescription methodDescription, + HolderFieldDescription holderField) { + var isExtension = methodDescription.Key.ProxyBase.IsExtension; + var body = ConditionalAccessExpression( + holderField.FieldName.ToIdentifierName(), + InvocationExpression( + MemberBindingExpression( + GenericName(isExtension ? "GetComponent" : "GetTarget") + .WithTypeArgumentList( + TypeArgumentList( + SingletonSeparatedList(methodDescription.Method.ContainingType.ToTypeSyntax()))))) + .WithArgumentList(ArgumentList())); + return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") .WithParameterList(ParameterList()) - .WithExpressionBody(ArrowExpressionClause(IdentifierName(targetField.FieldName))) + .WithExpressionBody(ArrowExpressionClause(body)) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); } + private MemberDeclarationSyntax GenerateGetCancellableTokenIdMember(InvokableMethodDescription method) + { + if (!method.IsCancellable) + { + return null; + } + + // Method to get the CancellableTokenId + var cancellableRequestIdMethod = MethodDeclaration(LibraryTypes.Guid.ToTypeSyntax(), "GetCancellableTokenId") + .WithBody(Block( + ReturnStatement(IdentifierName("cancellableTokenId")) + )) + .AddModifiers(Token(SyntaxKind.PublicKeyword)); + + return cancellableRequestIdMethod; + } + private MemberDeclarationSyntax GenerateGetArgumentMethod( InvokableMethodDescription methodDescription, List fields) @@ -456,7 +500,7 @@ private MemberDeclarationSyntax GenerateSetArgumentMethod( private MemberDeclarationSyntax GenerateInvokeInnerMethod( InvokableMethodDescription method, List fields, - TargetFieldDescription target) + HolderFieldDescription holder) { var resultTask = IdentifierName("resultTask"); @@ -464,13 +508,28 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( var args = SeparatedList( fields.OfType() .OrderBy(p => p.ParameterOrdinal) - .Select(p => Argument(IdentifierName(p.FieldName)))); + .Select(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type) + ? Argument(IdentifierName("cancellationToken")) + : Argument(IdentifierName(p.FieldName)))); + + var isExtension = method.Key.ProxyBase.IsExtension; + var getTarget = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + holder.FieldName.ToIdentifierName(), + GenericName(isExtension ? "GetComponent" : "GetTarget") + .WithTypeArgumentList( + TypeArgumentList( + SingletonSeparatedList(method.Method.ContainingType.ToTypeSyntax()))))) + .WithArgumentList(ArgumentList()); + + ExpressionSyntax methodCall; if (method.MethodTypeParameters.Count > 0) { methodCall = MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(target.FieldName), + getTarget, GenericName( Identifier(method.Method.Name), TypeArgumentList( @@ -479,14 +538,125 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( } else { - methodCall = IdentifierName(target.FieldName).Member(method.Method.Name); + methodCall = getTarget.Member(method.Method.Name); + } + + BlockSyntax body; + + if (method.Method.ReturnsVoid) + { + body = Block( + ExpressionStatement( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ); + } + else if (method.IsCancellable) + { + body = Block( + LocalDeclarationStatement( + VariableDeclaration( + LibraryTypes.ICancellationRuntime.ToTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(Identifier("cancellationRuntime")).WithInitializer( + EqualsValueClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("holder"), + GenericName("GetComponent") + .WithTypeArgumentList( + TypeArgumentList( + SingletonSeparatedList( + LibraryTypes.ICancellationRuntime.ToTypeSyntax() + ) + ) + ) + ) + ) + ) + ) + ) + ) + ), + LocalDeclarationStatement( + VariableDeclaration( + LibraryTypes.CancellationToken.ToTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(Identifier("cancellationToken")).WithInitializer( + EqualsValueClause( + BinaryExpression( + SyntaxKind.CoalesceExpression, + ConditionalAccessExpression( + IdentifierName("cancellationRuntime"), + InvocationExpression( + MemberBindingExpression( + IdentifierName("RegisterCancellableToken"))) + .AddArgumentListArguments( + Argument( + IdentifierName("cancellableTokenId")))), + DefaultExpression(LibraryTypes.CancellationToken.ToTypeSyntax()) + ) + ) + ) + ) + ) + ), + TryStatement().WithBlock( + Block( + ((INamedTypeSymbol)method.Method.ReturnType).ConstructedFrom is { IsGenericType: true } + ? ReturnStatement( + AwaitExpression( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ) + : ExpressionStatement( + AwaitExpression( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ) + ) + ) + .WithFinally( + FinallyClause( + Block( + ExpressionStatement( + ConditionalAccessExpression( + IdentifierName("cancellationRuntime"), + InvocationExpression( + MemberBindingExpression( + IdentifierName("Cancel"))) + .AddArgumentListArguments( + Argument(IdentifierName("cancellableTokenId")), + Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)) + ) + ) + ) + ) + ) + ) + ); + } + else + { + body = Block( + ReturnStatement( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ); } - return MethodDeclaration(method.Method.ReturnType.ToTypeSyntax(method.TypeParameterSubstitutions), "InvokeInner") + var methodDeclaration = MethodDeclaration(method.Method.ReturnType.ToTypeSyntax(method.TypeParameterSubstitutions), "InvokeInner") .WithParameterList(ParameterList()) - .WithExpressionBody(ArrowExpressionClause(InvocationExpression(methodCall, ArgumentList(args)))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .WithBody(body) .WithModifiers(TokenList(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword))); + + if (!method.Method.ReturnsVoid && method.IsCancellable) + { + methodDeclaration = methodDeclaration.AddModifiers(Token(SyntaxKind.AsyncKeyword)); + } + + return methodDeclaration; } private MemberDeclarationSyntax GenerateDisposeMethod( @@ -635,6 +805,9 @@ MemberDeclarationSyntax GetFieldDeclaration(InvokerFieldDescription description) case MethodParameterFieldDescription _: field = field.AddModifiers(Token(SyntaxKind.PublicKeyword)); break; + case CancellableTokenFieldDescription _: + field = field.AddModifiers(Token(SyntaxKind.PublicKeyword)); + break; } return field; @@ -690,6 +863,16 @@ private ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnume body.Add(ExpressionStatement(InvocationExpression(IdentifierName(methodName), ArgumentList(SeparatedList(new[] { Argument(argumentExpression) }))))); } + if (method.IsCancellable) + { + body.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName("cancellableTokenId"), + InvocationExpression(LibraryTypes.Guid.ToTypeSyntax().Member("NewGuid"))))); + } + if (body.Count == 0 && parameters.Count == 0) return default; @@ -708,17 +891,22 @@ private ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnume private List GetFieldDescriptions(InvokableMethodDescription method) { var fields = new List(); - uint fieldId = 0; + foreach (var parameter in method.Method.Parameters) { fields.Add(new MethodParameterFieldDescription(method.CodeGenerator, parameter, $"arg{fieldId}", fieldId, method.TypeParameterSubstitutions)); fieldId++; } - fields.Add(new TargetFieldDescription(method.Method.ContainingType)); + fields.Add(new HolderFieldDescription(LibraryTypes.ITargetHolder)); fields.Add(new MethodInfoFieldDescription(LibraryTypes.MethodInfo, "MethodBackingField")); + if (method.IsCancellable) + { + fields.Add(new CancellableTokenFieldDescription(LibraryTypes.Guid, "cancellableTokenId", fieldId, method.ContainingInterface)); + } + return fields; } @@ -736,9 +924,9 @@ protected InvokerFieldDescription(ITypeSymbol fieldType, string fieldName) public abstract bool IsInstanceField { get; } } - internal sealed class TargetFieldDescription : InvokerFieldDescription + internal sealed class HolderFieldDescription : InvokerFieldDescription { - public TargetFieldDescription(ITypeSymbol fieldType) : base(fieldType, "target") { } + public HolderFieldDescription(ITypeSymbol fieldType) : base(fieldType, "holder") { } public override bool IsSerializable => false; public override bool IsInstanceField => true; @@ -813,5 +1001,34 @@ public MethodInfoFieldDescription(ITypeSymbol fieldType, string fieldName) : bas public override bool IsSerializable => false; public override bool IsInstanceField => false; } + + internal sealed class CancellableTokenFieldDescription : InvokerFieldDescription, IMemberDescription + { + public CancellableTokenFieldDescription( + ITypeSymbol fieldType, + string fieldName, + uint fieldId, + INamedTypeSymbol containingType) : base(fieldType, fieldName) + { + FieldId = fieldId; + ContainingType = containingType; + } + + public IFieldSymbol Field => null; + public uint FieldId { get; } + public ISymbol Symbol => FieldType; + public ITypeSymbol Type => FieldType; + public INamedTypeSymbol ContainingType { get; } + public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); + public TypeSyntax TypeSyntax => Type.ToTypeSyntax(); + public string TypeName => Type.ToDisplayName(); + public string TypeNameIdentifier => Type.GetValidIdentifier(); + public bool IsPrimaryConstructorParameter => false; + + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => TypeSyntax; + + public override bool IsSerializable => true; + public override bool IsInstanceField => true; + } } } diff --git a/src/Orleans.CodeGenerator/LibraryTypes.cs b/src/Orleans.CodeGenerator/LibraryTypes.cs index a432ab96a1..091a5182d7 100644 --- a/src/Orleans.CodeGenerator/LibraryTypes.cs +++ b/src/Orleans.CodeGenerator/LibraryTypes.cs @@ -40,6 +40,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) ConstructorAttributeTypes = options.ConstructorAttributes.Select(Type).ToArray(); AliasAttribute = Type("Orleans.AliasAttribute"); IInvokable = Type("Orleans.Serialization.Invocation.IInvokable"); + ICancellableInvokable = Type("Orleans.Serialization.Invocation.ICancellableInvokable"); InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute"); RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers"); InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute"); @@ -58,6 +59,8 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute"); OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute"); ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder"); + ICancellationRuntime = Type("Orleans.Serialization.Invocation.ICancellationRuntime"); + ICancellableInvokableGrainExtension = TypeOrDefault("Orleans.Runtime.ICancellableInvokableGrainExtension"); TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute"); NonSerializedAttribute = Type("System.NonSerializedAttribute"); ObsoleteAttribute = Type("System.ObsoleteAttribute"); @@ -77,11 +80,11 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) _dateOnly = TypeOrDefault("System.DateOnly"); _dateTimeOffset = Type("System.DateTimeOffset"); _bitVector32 = Type("System.Collections.Specialized.BitVector32"); - _guid = Type("System.Guid"); _compareInfo = Type("System.Globalization.CompareInfo"); _cultureInfo = Type("System.Globalization.CultureInfo"); _version = Type("System.Version"); _timeOnly = TypeOrDefault("System.TimeOnly"); + Guid = Type("System.Guid"); ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider"); ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1"); ValueTask = Type("System.Threading.Tasks.ValueTask"); @@ -153,7 +156,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) TimeSpan = Type("System.TimeSpan"); _ipAddress = Type("System.Net.IPAddress"); _ipEndPoint = Type("System.Net.IPEndPoint"); - _cancellationToken = Type("System.Threading.CancellationToken"); + CancellationToken = Type("System.Threading.CancellationToken"); _immutableContainerTypes = new[] { compilation.GetSpecialType(SpecialType.System_Nullable_T), @@ -218,7 +221,10 @@ INamedTypeSymbol Type(string metadataName) public INamedTypeSymbol IActivator_1 { get; private set; } public INamedTypeSymbol IBufferWriter { get; private set; } public INamedTypeSymbol IInvokable { get; private set; } + public INamedTypeSymbol ICancellableInvokable { get; private set; } public INamedTypeSymbol ITargetHolder { get; private set; } + public INamedTypeSymbol ICancellationRuntime { get; private set; } + public INamedTypeSymbol? ICancellableInvokableGrainExtension { get; private set; } public INamedTypeSymbol TypeManifestProviderAttribute { get; private set; } public INamedTypeSymbol NonSerializedAttribute { get; private set; } public INamedTypeSymbol ObsoleteAttribute { get; private set; } @@ -259,13 +265,13 @@ INamedTypeSymbol Type(string metadataName) public INamedTypeSymbol SuppressReferenceTrackingAttribute { get; private set; } public INamedTypeSymbol OmitDefaultMemberValuesAttribute { get; private set; } public INamedTypeSymbol CopyContext { get; private set; } + public INamedTypeSymbol CancellationToken { get; private set; } + public INamedTypeSymbol Guid { get; private set; } public Compilation Compilation { get; private set; } public INamedTypeSymbol TimeSpan { get; private set; } private INamedTypeSymbol _ipAddress; private INamedTypeSymbol _ipEndPoint; - private INamedTypeSymbol _cancellationToken; private INamedTypeSymbol[] _immutableContainerTypes; - private INamedTypeSymbol _guid; private INamedTypeSymbol _bitVector32; private INamedTypeSymbol _compareInfo; private INamedTypeSymbol _cultureInfo; @@ -280,14 +286,14 @@ INamedTypeSymbol Type(string metadataName) _dateOnly, _timeOnly, _dateTimeOffset, - _guid, + Guid, _bitVector32, _compareInfo, _cultureInfo, _version, _ipAddress, _ipEndPoint, - _cancellationToken, + CancellationToken, Type, _uri, _uInt128, diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs index ec6b30c831..6b72db6ce1 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Globalization; +using System.Linq; namespace Orleans.CodeGenerator { @@ -206,6 +207,11 @@ static bool TryGetNamedArgument(ImmutableArray public INamedTypeSymbol ContainingInterface { get; } + /// + /// Gets a value indicating whether this method is cancellable. + /// + public bool IsCancellable => Method.Parameters.Any(parameterSymbol => SymbolEqualityComparer.Default.Equals(CodeGenerator.LibraryTypes.CancellationToken, parameterSymbol.Type)); + public bool Equals(InvokableMethodDescription other) => Key.Equals(other.Key); public override bool Equals(object obj) => obj is InvokableMethodDescription imd && Equals(imd); public override int GetHashCode() => Key.GetHashCode(); diff --git a/src/Orleans.CodeGenerator/ProxyGenerator.cs b/src/Orleans.CodeGenerator/ProxyGenerator.cs index cdb66ecc4c..2b0ebbb02b 100644 --- a/src/Orleans.CodeGenerator/ProxyGenerator.cs +++ b/src/Orleans.CodeGenerator/ProxyGenerator.cs @@ -151,6 +151,51 @@ MethodDeclarationSyntax CreateProxyMethod(ProxyMethodDescription methodDescripti .Concat(_codeGenerator.LibraryTypes.StaticCopiers) .ToList(); + // Ensure to hook up the cancellation token if the method has one + var cancellationTokenParameter = methodSymbol.Parameters.SingleOrDefault(parameter => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type)); + if (cancellationTokenParameter is not null) + { + // Throw aggressively if cancellation is already requested + statements.Add( + ExpressionStatement( + InvocationExpression( + IdentifierName($"arg{cancellationTokenParameter.Ordinal}").Member("ThrowIfCancellationRequested"), + ArgumentList() + ) + ) + ); + + // Register for cancellation + statements.Add( + ExpressionStatement( + InvocationExpression( + IdentifierName($"arg{cancellationTokenParameter.Ordinal}").Member("Register")) + .WithArgumentList( + ArgumentList(SeparatedList(new[] + { + Argument( + SimpleLambdaExpression( + Parameter(Identifier("arg")), + InvocationExpression( + InvocationExpression(ThisExpression().Member("AsReference", LibraryTypes.ICancellableInvokableGrainExtension.ToTypeSyntax())).Member("CancelRemoteToken"), + ArgumentList(SeparatedList(new[] + { + Argument( + CastExpression( + ParseTypeName(_codeGenerator.LibraryTypes.Guid.ToDisplayName()), + IdentifierName("arg") + ) + ), + })) + ) + ) + ), + Argument( + InvocationExpression( + IdentifierName("request").Member(IdentifierName("GetCancellableTokenId")))) + }))))); + } + // Set request object fields from method parameters. var parameterIndex = 0; var parameters = invokable.Members.OfType().Select(member => new SerializableMethodMember(member)); diff --git a/src/Orleans.CodeGenerator/SerializerGenerator.cs b/src/Orleans.CodeGenerator/SerializerGenerator.cs index 3f131687b4..62d7f9f819 100644 --- a/src/Orleans.CodeGenerator/SerializerGenerator.cs +++ b/src/Orleans.CodeGenerator/SerializerGenerator.cs @@ -47,6 +47,10 @@ public ClassDeclarationSyntax Generate(ISerializableTypeDescription type) { members.Add(new SerializableMethodMember(methodParameter)); } + else if (member is CancellableTokenFieldDescription cancellableTokenField) + { + members.Add(new SerializableCancellableTokenMember(_codeGenerator, cancellableTokenField)); + } } var fieldDescriptions = GetFieldDescriptions(type, members); @@ -1137,6 +1141,8 @@ public SerializableMethodMember(MethodParameterFieldDescription member) public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Parameter.Type) || _member.Parameter.HasAnyAttribute(LibraryTypes.ImmutableAttributes); + public bool IsCancellationToken => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, _member.Parameter.Type); + /// /// Gets syntax representing the type of this field. /// @@ -1168,6 +1174,56 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va public FieldAccessorDescription GetSetterFieldDescription() => null; } + internal class SerializableCancellableTokenMember : ISerializableMember + { + private readonly CodeGenerator _codeGenerator; + private readonly CancellableTokenFieldDescription _member; + + public SerializableCancellableTokenMember(CodeGenerator codeGenerator, CancellableTokenFieldDescription member) + { + _codeGenerator = codeGenerator; + _member = member; + } + + public IMemberDescription Member => _member; + + private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; + + public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Type); + + public bool IsCancellationToken => false; + + /// + /// Gets syntax representing the type of this field. + /// + public TypeSyntax TypeSyntax => _member.TypeSyntax; + + public bool IsValueType => _member.Type.IsValueType; + + public bool IsPrimaryConstructorParameter => false; + + /// + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. + /// + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_member.FieldName); + + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) => AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + instance.Member(_member.FieldName), + value); + + public FieldAccessorDescription GetGetterFieldDescription() => null; + public FieldAccessorDescription GetSetterFieldDescription() => null; + } + /// /// Represents a serializable member (field/property) of a type. /// diff --git a/src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs b/src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs new file mode 100644 index 0000000000..75337bba81 --- /dev/null +++ b/src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs @@ -0,0 +1,19 @@ +using System; +using System.Threading.Tasks; +using Orleans.Concurrency; + +namespace Orleans.Runtime; + +public interface ICancellableInvokableGrainExtension : IGrainExtension +{ + /// + /// Indicates that a cancellation token has been canceled. + /// + /// + /// The token id + /// + /// A representing the operation. + /// + [AlwaysInterleave] + Task CancelRemoteToken(Guid tokenId); +} \ No newline at end of file diff --git a/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs new file mode 100644 index 0000000000..916fd6e64f --- /dev/null +++ b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs @@ -0,0 +1,12 @@ +using System; +using System.Threading.Tasks; + +namespace Orleans.Runtime.Cancellation; + +internal class CancellableInvokableGrainExtension : ICancellableInvokableGrainExtension +{ + public Task CancelRemoteToken(Guid tokenId) + { + throw new NotImplementedException(); + } +} diff --git a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs new file mode 100644 index 0000000000..be6da144cf --- /dev/null +++ b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Concurrent; +using System.Threading; +using Orleans.Serialization.Invocation; + +namespace Orleans.Runtime.Cancellation; + +internal class CancellationRuntime : ICancellationRuntime +{ + readonly ConcurrentDictionary _cancellationTokens = new ConcurrentDictionary(); + + TokenEntry GetOrCreateEntry(Guid tokenId) + { + return _cancellationTokens.GetOrAdd(tokenId, _ => new TokenEntry(new CancellationTokenSource())); + } + + public void Cancel(Guid tokenId, bool lastCall) + { + if (lastCall) + { + // On a last call, we can remove the token entry and dispose of it. If no entry exists then we can ignore the call. + if (_cancellationTokens.TryRemove(tokenId, out var entry)) + { + entry.CancellationTokenSource.Cancel(); + entry.Dispose(); + } + } + else + { + // If our invokable has yet to complete, we can cancel the token and leave the entry in place. + var entry = GetOrCreateEntry(tokenId); + entry.CancellationTokenSource.Cancel(); + } + } + + public CancellationToken RegisterCancellableToken(Guid tokenId) + { + var entry = GetOrCreateEntry(tokenId); + return entry.CancellationTokenSource.Token; + } + + readonly record struct TokenEntry(CancellationTokenSource CancellationTokenSource) : IDisposable + { + // TODO: Expire the entry after a certain amount of time + + public void Dispose() => CancellationTokenSource.Dispose(); + } +} \ No newline at end of file diff --git a/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs b/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs new file mode 100644 index 0000000000..55fce390ce --- /dev/null +++ b/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs @@ -0,0 +1,16 @@ +#nullable enable +using System; + +namespace Orleans.Serialization.Invocation +{ + /// + /// Represents an invokable that can be canceled + /// + public interface ICancellableInvokable : IInvokable + { + /// + /// Returns an id that uniquely identifies this invokable + /// + Guid GetCancellableTokenId(); + } +} \ No newline at end of file diff --git a/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs new file mode 100644 index 0000000000..a92ffe07bc --- /dev/null +++ b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Orleans.Serialization.Invocation; + +/// +/// An optional runtime to fascilitate in cancelling invokables +/// +public interface ICancellationRuntime +{ + /// + /// Registers the token and returns a cancellation token linked to the token id + /// + /// The token id to register + /// A cancellationToken that will be cancelled once Cancel for the token has been called + CancellationToken RegisterCancellableToken(Guid tokenId); + + /// + /// Cancels the invokable with the specified token id + /// + /// The token id to cancel + /// Whether this is the last call associated with the token + void Cancel(Guid tokenId, bool lastCall); +} \ No newline at end of file From d2334d74e2bc4aaf49450e45f2c67e7d07d1c0e9 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 01:54:59 +0100 Subject: [PATCH 02/16] Remove cancellationToken as an invokable argument --- .../InvokableGenerator.cs | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index b60473e1e9..2bc88d859e 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -504,13 +504,13 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( { var resultTask = IdentifierName("resultTask"); + // C# var resultTask = this.target.{Method}({params}); var args = SeparatedList( - fields.OfType() - .OrderBy(p => p.ParameterOrdinal) - .Select(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type) + method.Method.Parameters + .Select(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Type) ? Argument(IdentifierName("cancellationToken")) - : Argument(IdentifierName(p.FieldName)))); + : Argument(IdentifierName($"arg{p.Ordinal}")))); var isExtension = method.Key.ProxyBase.IsExtension; var getTarget = InvocationExpression( @@ -691,11 +691,19 @@ private MemberDeclarationSyntax GenerateDisposeMethod( } private MemberDeclarationSyntax GenerateGetArgumentCount(InvokableMethodDescription methodDescription) - => methodDescription.Method.Parameters.Length is var count and not 0 ? + { + var count = methodDescription.Method.Parameters.Length; + if (methodDescription.IsCancellable) + { + count -= 1; + } + + return count is not 0 ? MethodDeclaration(PredefinedType(Token(SyntaxKind.IntKeyword)), "GetArgumentCount") .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(count)))) .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) : null; + } private MemberDeclarationSyntax GenerateGetActivityName(InvokableMethodDescription methodDescription) { @@ -895,6 +903,11 @@ private List GetFieldDescriptions(InvokableMethodDescri foreach (var parameter in method.Method.Parameters) { + if (SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type)) + { + continue; + } + fields.Add(new MethodParameterFieldDescription(method.CodeGenerator, parameter, $"arg{fieldId}", fieldId, method.TypeParameterSubstitutions)); fieldId++; } From 669bd5905b7cfad836da432bdb6548cffb4783a7 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 02:59:17 +0100 Subject: [PATCH 03/16] Implemented runtime --- .../CancellableInvokableGrainExtension.cs | 25 ++++- .../Cancellation/CancellationRuntime.cs | 92 +++++++++++++++---- .../Hosting/DefaultSiloServices.cs | 5 + .../Invocation/ICancellationRuntime.cs | 5 + 4 files changed, 107 insertions(+), 20 deletions(-) diff --git a/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs index 916fd6e64f..b418c810c8 100644 --- a/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs +++ b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs @@ -1,12 +1,33 @@ using System; +using System.Threading; using System.Threading.Tasks; +using Orleans.Serialization.Invocation; namespace Orleans.Runtime.Cancellation; -internal class CancellableInvokableGrainExtension : ICancellableInvokableGrainExtension +internal class CancellableInvokableGrainExtension : ICancellableInvokableGrainExtension, IDisposable { + readonly ICancellationRuntime _runtime; + readonly Timer _cleanupTimer; + + public CancellableInvokableGrainExtension(IGrainContext grainContext) + { + _runtime = grainContext.GetComponent(); + _cleanupTimer = new Timer(obj => ((CancellableInvokableGrainExtension)obj)._runtime.ExpireTokens(), this, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30)); + } + public Task CancelRemoteToken(Guid tokenId) { - throw new NotImplementedException(); + if (_runtime is not null) + { + _runtime.Cancel(tokenId, lastCall: false); + } + + return Task.CompletedTask; + } + + public void Dispose() + { + _cleanupTimer.Dispose(); } } diff --git a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs index be6da144cf..3e5652230a 100644 --- a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs +++ b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs @@ -1,48 +1,104 @@ using System; -using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; using System.Threading; +using Microsoft.Extensions.Logging; using Orleans.Serialization.Invocation; namespace Orleans.Runtime.Cancellation; internal class CancellationRuntime : ICancellationRuntime { - readonly ConcurrentDictionary _cancellationTokens = new ConcurrentDictionary(); + private static readonly TimeSpan _cleanupFrequency = TimeSpan.FromMinutes(7); - TokenEntry GetOrCreateEntry(Guid tokenId) + readonly Dictionary _cancellationTokens = new Dictionary(); + + CancellationTokenSource _reusableCancellationTokenSource; + + ref TokenEntry GetOrCreateEntry(Guid tokenId) { - return _cancellationTokens.GetOrAdd(tokenId, _ => new TokenEntry(new CancellationTokenSource())); + lock (_cancellationTokens) + { + ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_cancellationTokens, tokenId, out var exists); + + if (!exists) + { + var cancellationTokenSource = _reusableCancellationTokenSource; + if (cancellationTokenSource is not null) + { + _reusableCancellationTokenSource = null; + } + else + { + cancellationTokenSource = new CancellationTokenSource(); + } + entry.SetSource(cancellationTokenSource); + } + + entry.Touch(); + return ref entry; + } } public void Cancel(Guid tokenId, bool lastCall) { + var entry = GetOrCreateEntry(tokenId); + entry.Source.Cancel(); + if (lastCall) { - // On a last call, we can remove the token entry and dispose of it. If no entry exists then we can ignore the call. - if (_cancellationTokens.TryRemove(tokenId, out var entry)) + // Cancel the source on the last call + entry.Source.Cancel(); + + // Try and reuse the source + if (_reusableCancellationTokenSource is not null || entry.Source.TryReset() is false || Interlocked.CompareExchange(ref _reusableCancellationTokenSource, entry.Source, null) != entry.Source) { - entry.CancellationTokenSource.Cancel(); - entry.Dispose(); + // Dispose if we failed to reuse + entry.Source.Dispose(); } } - else - { - // If our invokable has yet to complete, we can cancel the token and leave the entry in place. - var entry = GetOrCreateEntry(tokenId); - entry.CancellationTokenSource.Cancel(); - } } public CancellationToken RegisterCancellableToken(Guid tokenId) { var entry = GetOrCreateEntry(tokenId); - return entry.CancellationTokenSource.Token; + return entry.Source.Token; + } + + public void ExpireTokens() + { + var now = Stopwatch.GetTimestamp(); + lock (_cancellationTokens) + { + foreach (var token in _cancellationTokens) + { + if (token.Value.IsExpired(_cleanupFrequency, now)) + { + _cancellationTokens.Remove(token.Key); + } + } + } } - readonly record struct TokenEntry(CancellationTokenSource CancellationTokenSource) : IDisposable + struct TokenEntry { - // TODO: Expire the entry after a certain amount of time + private long _createdTime; + + public void Touch() => _createdTime = Stopwatch.GetTimestamp(); - public void Dispose() => CancellationTokenSource.Dispose(); + public void SetSource(CancellationTokenSource source) + { + Source = source; + } + + public CancellationTokenSource Source { get; private set; } + + public bool IsExpired(TimeSpan expiry, long nowTimestamp) + { + var untouchedTime = TimeSpan.FromSeconds((nowTimestamp - _createdTime) / (double)Stopwatch.Frequency); + + return untouchedTime >= expiry; + } } } \ No newline at end of file diff --git a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs index 614eb90430..43b5a0482e 100644 --- a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs +++ b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs @@ -43,6 +43,8 @@ using Orleans.Serialization.Internal; using Orleans.Core; using Orleans.Placement.Repartitioning; +using Orleans.Runtime.Cancellation; +using Orleans.Serialization.Invocation; namespace Orleans.Hosting { @@ -99,6 +101,9 @@ internal static void AddDefaultServices(ISiloBuilder builder) services.TryAddSingleton(); services.AddTransient(); services.AddKeyedTransient(typeof(ICancellationSourcesExtension), (sp, _) => sp.GetRequiredService()); + services.TryAddTransient(); + services.AddTransient(); + services.AddKeyedTransient(typeof(ICancellableInvokableGrainExtension), (sp, _) => sp.GetRequiredService()); services.TryAddSingleton(sp => sp.GetRequiredService().ConcreteGrainFactory); services.TryAddSingleton(); services.TryAddSingleton(); diff --git a/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs index a92ffe07bc..64a80f8468 100644 --- a/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs +++ b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs @@ -25,4 +25,9 @@ public interface ICancellationRuntime /// The token id to cancel /// Whether this is the last call associated with the token void Cancel(Guid tokenId, bool lastCall); + + /// + /// Expires any tokens that have not yet been cancelled and have been inactive for a period of time + /// + void ExpireTokens(); } \ No newline at end of file From 790e92621cc54597d3d977fb14e9797daf8cbc9e Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 03:00:55 +0100 Subject: [PATCH 04/16] Revert "Remove cancellationToken as an invokable argument" This reverts commit d2334d74e2bc4aaf49450e45f2c67e7d07d1c0e9. --- .../InvokableGenerator.cs | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index 2bc88d859e..b60473e1e9 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -504,13 +504,13 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( { var resultTask = IdentifierName("resultTask"); - // C# var resultTask = this.target.{Method}({params}); var args = SeparatedList( - method.Method.Parameters - .Select(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Type) + fields.OfType() + .OrderBy(p => p.ParameterOrdinal) + .Select(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type) ? Argument(IdentifierName("cancellationToken")) - : Argument(IdentifierName($"arg{p.Ordinal}")))); + : Argument(IdentifierName(p.FieldName)))); var isExtension = method.Key.ProxyBase.IsExtension; var getTarget = InvocationExpression( @@ -691,19 +691,11 @@ private MemberDeclarationSyntax GenerateDisposeMethod( } private MemberDeclarationSyntax GenerateGetArgumentCount(InvokableMethodDescription methodDescription) - { - var count = methodDescription.Method.Parameters.Length; - if (methodDescription.IsCancellable) - { - count -= 1; - } - - return count is not 0 ? + => methodDescription.Method.Parameters.Length is var count and not 0 ? MethodDeclaration(PredefinedType(Token(SyntaxKind.IntKeyword)), "GetArgumentCount") .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(count)))) .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) : null; - } private MemberDeclarationSyntax GenerateGetActivityName(InvokableMethodDescription methodDescription) { @@ -903,11 +895,6 @@ private List GetFieldDescriptions(InvokableMethodDescri foreach (var parameter in method.Method.Parameters) { - if (SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type)) - { - continue; - } - fields.Add(new MethodParameterFieldDescription(method.CodeGenerator, parameter, $"arg{fieldId}", fieldId, method.TypeParameterSubstitutions)); fieldId++; } From aec371017316fbedafad965e12dd7ad04bf872a9 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 03:22:47 +0100 Subject: [PATCH 05/16] Enable cancellationTokens to travel over the wire --- src/Orleans.CodeGenerator/LibraryTypes.cs | 1 + .../Codecs/CancellationTokenCodec.cs | 55 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs diff --git a/src/Orleans.CodeGenerator/LibraryTypes.cs b/src/Orleans.CodeGenerator/LibraryTypes.cs index 091a5182d7..653ad83076 100644 --- a/src/Orleans.CodeGenerator/LibraryTypes.cs +++ b/src/Orleans.CodeGenerator/LibraryTypes.cs @@ -127,6 +127,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) new(TypeOrDefault("System.Int128"), TypeOrDefault("Orleans.Serialization.Codecs.Int128Codec")), new(TypeOrDefault("System.Half"), TypeOrDefault("Orleans.Serialization.Codecs.HalfCodec")), new(Type("System.Uri"), Type("Orleans.Serialization.Codecs.UriCodec")), + new(Type("System.Threading.CancellationToken"), Type("Orleans.Serialization.Codecs.CancellationTokenCodec")), }.Where(desc => desc.UnderlyingType is { } && desc.CodecType is { }).ToArray(); WellKnownCodecs = new WellKnownCodecDescription[] { diff --git a/src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs b/src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs new file mode 100644 index 0000000000..5f452428b8 --- /dev/null +++ b/src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs @@ -0,0 +1,55 @@ +using System; +using System.Buffers; +using System.Runtime.CompilerServices; +using System.Threading; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.WireProtocol; + +namespace Orleans.Serialization.Codecs +{ + /// + /// Serializer for . + /// + [RegisterSerializer] + public sealed class CancellationTokenCodec : IFieldCodec + { + void IFieldCodec.WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, CancellationToken value) + { + ReferenceCodec.MarkValueField(writer.Session); + writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(CancellationToken), WireType.Fixed32); + writer.WriteInt32(value.IsCancellationRequested ? 1 : 0); + } + + /// + /// Writes a field without type info (expected type is statically known). + /// + /// The buffer writer type. + /// The writer. + /// The field identifier delta. + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteField(ref Writer writer, uint fieldIdDelta, CancellationToken value) where TBufferWriter : IBufferWriter + { + ReferenceCodec.MarkValueField(writer.Session); + writer.WriteFieldHeaderExpected(fieldIdDelta, WireType.Fixed32); + writer.WriteInt32(value.IsCancellationRequested ? 1 : 0); + } + + /// + CancellationToken IFieldCodec.ReadValue(ref Reader reader, Field field) => ReadValue(ref reader, field); + + /// + /// Reads a value. + /// + /// The reader input type. + /// The reader. + /// The field. + /// The value. + public static CancellationToken ReadValue(ref Reader reader, Field field) + { + ReferenceCodec.MarkValueField(reader.Session); + field.EnsureWireType(WireType.Fixed32); + return new CancellationToken(reader.ReadInt32() == 1 ? true : false); + } + } +} \ No newline at end of file From 5d2f005bf71f8abd337a6209ca0c0d026cb074be Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 03:24:05 +0100 Subject: [PATCH 06/16] Added tests --- .../TestGrainInterfaces/IGenericInterfaces.cs | 4 + test/Grains/TestGrains/GenericGrains.cs | 39 +++ .../CancellationTokenTests.cs | 247 ++++++++++++++++++ 3 files changed, 290 insertions(+) create mode 100644 test/Tester/CancellationTests/CancellationTokenTests.cs diff --git a/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs b/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs index 9a84c549ce..1a500fa27e 100644 --- a/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs +++ b/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs @@ -193,15 +193,19 @@ public interface ILongRunningTaskGrain : IGrainWithGuidKey Task GetRuntimeInstanceIdWithDelay(TimeSpan delay); Task LongWait(GrainCancellationToken tc, TimeSpan delay); + Task LongWait(CancellationToken tc, TimeSpan delay); Task LongRunningTask(T t, TimeSpan delay); Task CallOtherLongRunningTask(ILongRunningTaskGrain target, T t, TimeSpan delay); Task FanOutOtherLongRunningTask(ILongRunningTaskGrain target, T t, TimeSpan delay, int degreeOfParallelism); Task CallOtherLongRunningTask(ILongRunningTaskGrain target, GrainCancellationToken tc, TimeSpan delay); + Task CallOtherLongRunningTask(ILongRunningTaskGrain target, CancellationToken tc, TimeSpan delay); Task CallOtherLongRunningTaskWithLocalToken(ILongRunningTaskGrain target, TimeSpan delay, TimeSpan delayBeforeCancel); Task CancellationTokenCallbackResolve(GrainCancellationToken tc); + Task CancellationTokenCallbackResolve(CancellationToken tc); Task CallOtherCancellationTokenCallbackResolve(ILongRunningTaskGrain target); Task CancellationTokenCallbackThrow(GrainCancellationToken tc); + Task CancellationTokenCallbackThrow(CancellationToken tc); Task GetLastValue(); } diff --git a/test/Grains/TestGrains/GenericGrains.cs b/test/Grains/TestGrains/GenericGrains.cs index ac7196f59f..d69575e26a 100644 --- a/test/Grains/TestGrains/GenericGrains.cs +++ b/test/Grains/TestGrains/GenericGrains.cs @@ -627,6 +627,16 @@ public Task CancellationTokenCallbackThrow(GrainCancellationToken tc) return Task.CompletedTask; } + public Task CancellationTokenCallbackThrow(CancellationToken tc) + { + tc.Register(() => + { + throw new InvalidOperationException("From cancellation token callback"); + }); + + return Task.CompletedTask; + } + public Task GetLastValue() { return Task.FromResult(lastValue); @@ -660,6 +670,25 @@ public Task CancellationTokenCallbackResolve(GrainCancellationToken tc) return tcs.Task; } + public Task CancellationTokenCallbackResolve(CancellationToken tc) + { + var tcs = new TaskCompletionSource(); + var orleansTs = TaskScheduler.Current; + tc.Register(() => + { + if (TaskScheduler.Current != orleansTs) + { + tcs.SetException(new Exception("Callback executed on wrong thread")); + } + else + { + tcs.SetResult(true); + } + }); + + return tcs.Task; + } + public async Task CallOtherLongRunningTask(ILongRunningTaskGrain target, T t, TimeSpan delay) { return await target.LongRunningTask(t, delay); @@ -681,6 +710,11 @@ public async Task CallOtherLongRunningTask(ILongRunningTaskGrain target, Grai await target.LongWait(tc, delay); } + public async Task CallOtherLongRunningTask(ILongRunningTaskGrain target, CancellationToken tc, TimeSpan delay) + { + await target.LongWait(tc, delay); + } + public async Task CallOtherLongRunningTaskWithLocalToken(ILongRunningTaskGrain target, TimeSpan delay, TimeSpan delayBeforeCancel) { var tcs = new GrainCancellationTokenSource(); @@ -695,6 +729,11 @@ public async Task LongWait(GrainCancellationToken tc, TimeSpan delay) await Task.Delay(delay, tc.CancellationToken); } + public async Task LongWait(CancellationToken tc, TimeSpan delay) + { + await Task.Delay(delay, tc); + } + public async Task LongRunningTask(T t, TimeSpan delay) { await Task.Delay(delay); diff --git a/test/Tester/CancellationTests/CancellationTokenTests.cs b/test/Tester/CancellationTests/CancellationTokenTests.cs new file mode 100644 index 0000000000..a1c742444b --- /dev/null +++ b/test/Tester/CancellationTests/CancellationTokenTests.cs @@ -0,0 +1,247 @@ +using Microsoft.Extensions.Logging; +using Orleans.TestingHost; +using TestExtensions; +using UnitTests.GrainInterfaces; +using Xunit; + +namespace UnitTests.CancellationTests +{ + public class CancellationTokenTests : OrleansTestingBase, IClassFixture + { + private readonly Fixture fixture; + + public class Fixture : BaseTestClusterFixture + { + protected override void ConfigureTestCluster(TestClusterBuilder builder) + { + base.ConfigureTestCluster(builder); + builder.AddSiloBuilderConfigurator(); + } + + private class SiloConfig : ISiloConfigurator + { + public void Configure(ISiloBuilder siloBuilder) + { + siloBuilder.ConfigureLogging(logging => logging.AddDebug()); + } + } + } + + public CancellationTokenTests(Fixture fixture) + { + this.fixture = fixture; + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task GrainTaskCancellation(int delay) + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + var grainTask = grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)); + await Task.Delay(TimeSpan.FromMilliseconds(delay)); + await tcs.CancelAsync(); + await Assert.ThrowsAsync(() => grainTask); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task MultipleGrainsTaskCancellation(int delay) + { + var tcs = new CancellationTokenSource(); + var grainTasks = Enumerable.Range(0, 5) + .Select(i => this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()) + .LongWait(tcs.Token, TimeSpan.FromSeconds(10))) + .Select(task => Assert.ThrowsAsync(() => task)).ToList(); + await Task.Delay(TimeSpan.FromMilliseconds(delay)); + await tcs.CancelAsync(); + await Task.WhenAll(grainTasks); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task GrainTaskMultipleCancellations(int delay) + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var grainTasks = Enumerable.Range(0, 5) + .Select(async i => + { + var tcs = new CancellationTokenSource(); + var task = grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)); + await Task.WhenAny(task, Task.Delay(TimeSpan.FromMilliseconds(delay))); + await tcs.CancelAsync(); + try + { + await task; + Assert.Fail("Expected TaskCancelledException, but message completed"); + } + catch (TaskCanceledException) { } + }) + .ToList(); + await Task.WhenAll(grainTasks); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task MultipleGrainTasksSingleCancellation() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + + var primaryTsc = new CancellationTokenSource(); + var primaryTask = grain.LongWait(primaryTsc.Token, TimeSpan.FromSeconds(10)); + var otherTsc = new CancellationTokenSource(); + var otherGrainTask = grain.LongWait(otherTsc.Token, TimeSpan.FromSeconds(10)); + + primaryTsc.Cancel(); + await Assert.ThrowsAnyAsync(() => primaryTask); + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + Assert.Equal(TaskStatus.Running, otherGrainTask.Status); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task TokenPassingWithoutCancellation_NoExceptionShouldBeThrown() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + try + { + await grain.LongWait(tcs.Token, TimeSpan.FromMilliseconds(1)); + } + catch (Exception ex) + { + Assert.Fail("Expected no exception, but got: " + ex.Message); + } + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task PreCancelledTokenPassing() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + await tcs.CancelAsync(); + + // Except a OperationCanceledException to be thrown as the token is already cancelled + Assert.Throws(() => grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)).Ignore()); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task CancellationTokenCallbacksExecutionContext() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + var grainTask = grain.CancellationTokenCallbackResolve(tcs.Token); + await Task.Delay(TimeSpan.FromMilliseconds(100)); + await tcs.CancelAsync(); + var result = await grainTask; + Assert.True(result); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task CancellationTokenCallbacksTaskSchedulerContext() + { + var grains = await GetGrains(false); + + var tcs = new CancellationTokenSource(); + var grainTask = grains.Item1.CallOtherCancellationTokenCallbackResolve(grains.Item2); + await tcs.CancelAsync(); + var result = await grainTask; + Assert.True(result); + } + + [Fact, TestCategory("Cancellation")] + public async Task CancellationTokenCallbacksThrow_ExceptionDoesNotPropagate() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + _ = grain.CancellationTokenCallbackThrow(tcs.Token); + await Task.Delay(TimeSpan.FromMilliseconds(100)); + // Cancellation is a cooperative mechanism, so we don't expect the exception to propagate + await tcs.CancelAsync(); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InSiloGrainCancellation(int delay) + { + await GrainGrainCancellation(false, delay); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InterSiloGrainCancellation(int delay) + { + await GrainGrainCancellation(true, delay); + } + + [SkippableTheory(Skip = "https://github.com/dotnet/orleans/issues/5654"), TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InterSiloClientCancellationTokenPassing(int delay) + { + await ClientGrainGrainTokenPassing(delay, true); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InSiloClientCancellationTokenPassing(int delay) + { + await ClientGrainGrainTokenPassing(delay, false); + } + + private async Task ClientGrainGrainTokenPassing(int delay, bool interSilo) + { + var grains = await GetGrains(interSilo); + var grain = grains.Item1; + var target = grains.Item2; + var tcs = new CancellationTokenSource(); + var grainTask = grain.CallOtherLongRunningTask(target, tcs.Token, TimeSpan.FromSeconds(10)); + await Task.Delay(TimeSpan.FromMilliseconds(delay)); + await tcs.CancelAsync(); + await Assert.ThrowsAnyAsync(() => grainTask); + } + + private async Task GrainGrainCancellation(bool interSilo, int delay) + { + var grains = await GetGrains(interSilo); + var grain = grains.Item1; + var target = grains.Item2; + var grainTask = grain.CallOtherLongRunningTaskWithLocalToken(target, TimeSpan.FromSeconds(10), + TimeSpan.FromMilliseconds(delay)); + await Assert.ThrowsAnyAsync(() => grainTask); + } + + private async Task, ILongRunningTaskGrain>> GetGrains(bool placeOnDifferentSilos = true) + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var instanceId = await grain.GetRuntimeInstanceId(); + var target = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var targetInstanceId = await target.GetRuntimeInstanceId(); + var retriesCount = 0; + var retriesLimit = 10; + + while ((placeOnDifferentSilos && instanceId.Equals(targetInstanceId)) + || (!placeOnDifferentSilos && !instanceId.Equals(targetInstanceId))) + { + if (retriesCount >= retriesLimit) throw new Exception("Could not make requested grains placement"); + target = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + targetInstanceId = await target.GetRuntimeInstanceId(); + retriesCount++; + } + + return new Tuple, ILongRunningTaskGrain>(grain, target); + } + } +} \ No newline at end of file From 2d7409f99212e9de85bab1760fe0a53012430152 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 03:34:09 +0100 Subject: [PATCH 07/16] Accept a default cancellationToken --- src/Orleans.CodeGenerator/InvokableGenerator.cs | 7 ++++++- .../Cancellation/CancellationRuntime.cs | 13 ++++++++++++- .../Invocation/ICancellationRuntime.cs | 3 ++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index b60473e1e9..1874d72a4e 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -553,6 +553,9 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( } else if (method.IsCancellable) { + var defaultCancellationTokenFieldName = fields.OfType() + .First(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type)).FieldName; + body = Block( LocalDeclarationStatement( VariableDeclaration( @@ -594,7 +597,9 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( IdentifierName("RegisterCancellableToken"))) .AddArgumentListArguments( Argument( - IdentifierName("cancellableTokenId")))), + IdentifierName("cancellableTokenId")), + Argument( + IdentifierName(defaultCancellationTokenFieldName)))), DefaultExpression(LibraryTypes.CancellationToken.ToTypeSyntax()) ) ) diff --git a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs index 3e5652230a..d50f2b6b5c 100644 --- a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs +++ b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs @@ -57,12 +57,23 @@ public void Cancel(Guid tokenId, bool lastCall) // Dispose if we failed to reuse entry.Source.Dispose(); } + + lock (_cancellationTokens) + { + _cancellationTokens.Remove(tokenId); + } } } - public CancellationToken RegisterCancellableToken(Guid tokenId) + public CancellationToken RegisterCancellableToken(Guid tokenId, CancellationToken @default) { var entry = GetOrCreateEntry(tokenId); + + if (@default != CancellationToken.None) + { + return CancellationTokenSource.CreateLinkedTokenSource(@default, entry.Source.Token).Token; + } + return entry.Source.Token; } diff --git a/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs index 64a80f8468..3cf20dd6a8 100644 --- a/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs +++ b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs @@ -16,8 +16,9 @@ public interface ICancellationRuntime /// Registers the token and returns a cancellation token linked to the token id /// /// The token id to register + /// The default cancellationToken to consider /// A cancellationToken that will be cancelled once Cancel for the token has been called - CancellationToken RegisterCancellableToken(Guid tokenId); + CancellationToken RegisterCancellableToken(Guid tokenId, CancellationToken @default); /// /// Cancels the invokable with the specified token id From 5dd1a3a17f0e304b67a6928ec80037b0d3888eb6 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 15:00:46 +0100 Subject: [PATCH 08/16] Removed bad caching --- .../Cancellation/CancellationRuntime.cs | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs index d50f2b6b5c..103618f560 100644 --- a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs +++ b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs @@ -14,8 +14,6 @@ internal class CancellationRuntime : ICancellationRuntime readonly Dictionary _cancellationTokens = new Dictionary(); - CancellationTokenSource _reusableCancellationTokenSource; - ref TokenEntry GetOrCreateEntry(Guid tokenId) { lock (_cancellationTokens) @@ -24,16 +22,7 @@ ref TokenEntry GetOrCreateEntry(Guid tokenId) if (!exists) { - var cancellationTokenSource = _reusableCancellationTokenSource; - if (cancellationTokenSource is not null) - { - _reusableCancellationTokenSource = null; - } - else - { - cancellationTokenSource = new CancellationTokenSource(); - } - entry.SetSource(cancellationTokenSource); + entry.SetSource(new CancellationTokenSource()); } entry.Touch(); @@ -43,24 +32,20 @@ ref TokenEntry GetOrCreateEntry(Guid tokenId) public void Cancel(Guid tokenId, bool lastCall) { - var entry = GetOrCreateEntry(tokenId); - entry.Source.Cancel(); - - if (lastCall) + if (!lastCall) { - // Cancel the source on the last call + var entry = GetOrCreateEntry(tokenId); entry.Source.Cancel(); - - // Try and reuse the source - if (_reusableCancellationTokenSource is not null || entry.Source.TryReset() is false || Interlocked.CompareExchange(ref _reusableCancellationTokenSource, entry.Source, null) != entry.Source) - { - // Dispose if we failed to reuse - entry.Source.Dispose(); - } - + } + else + { lock (_cancellationTokens) { - _cancellationTokens.Remove(tokenId); + if (_cancellationTokens.Remove(tokenId, out var entry)) + { + entry.Source.Cancel(); + entry.Source.Dispose(); + } } } } From 1ef7a94f749bcf215c52d73aca3244bbf05d3fe2 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 15:01:01 +0100 Subject: [PATCH 09/16] Fixed tests --- .../CancellationTokenTests.cs | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/test/Tester/CancellationTests/CancellationTokenTests.cs b/test/Tester/CancellationTests/CancellationTokenTests.cs index a1c742444b..144e5c483a 100644 --- a/test/Tester/CancellationTests/CancellationTokenTests.cs +++ b/test/Tester/CancellationTests/CancellationTokenTests.cs @@ -43,7 +43,7 @@ public async Task GrainTaskCancellation(int delay) var grainTask = grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)); await Task.Delay(TimeSpan.FromMilliseconds(delay)); await tcs.CancelAsync(); - await Assert.ThrowsAsync(() => grainTask); + await Assert.ThrowsAnyAsync(() => grainTask); } [Theory, TestCategory("BVT"), TestCategory("Cancellation")] @@ -56,7 +56,7 @@ public async Task MultipleGrainsTaskCancellation(int delay) var grainTasks = Enumerable.Range(0, 5) .Select(i => this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()) .LongWait(tcs.Token, TimeSpan.FromSeconds(10))) - .Select(task => Assert.ThrowsAsync(() => task)).ToList(); + .Select(task => Assert.ThrowsAnyAsync(() => task)).ToList(); await Task.Delay(TimeSpan.FromMilliseconds(delay)); await tcs.CancelAsync(); await Task.WhenAll(grainTasks); @@ -81,29 +81,12 @@ public async Task GrainTaskMultipleCancellations(int delay) await task; Assert.Fail("Expected TaskCancelledException, but message completed"); } - catch (TaskCanceledException) { } + catch (OperationCanceledException) { } }) .ToList(); await Task.WhenAll(grainTasks); } - [Fact, TestCategory("BVT"), TestCategory("Cancellation")] - public async Task MultipleGrainTasksSingleCancellation() - { - var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); - - var primaryTsc = new CancellationTokenSource(); - var primaryTask = grain.LongWait(primaryTsc.Token, TimeSpan.FromSeconds(10)); - var otherTsc = new CancellationTokenSource(); - var otherGrainTask = grain.LongWait(otherTsc.Token, TimeSpan.FromSeconds(10)); - - primaryTsc.Cancel(); - await Assert.ThrowsAnyAsync(() => primaryTask); - - await Task.Delay(TimeSpan.FromMilliseconds(100)); - Assert.Equal(TaskStatus.Running, otherGrainTask.Status); - } - [Fact, TestCategory("BVT"), TestCategory("Cancellation")] public async Task TokenPassingWithoutCancellation_NoExceptionShouldBeThrown() { From 399cbaee21473386cdbaece54cc341b01b575735 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 26 Aug 2024 15:01:46 +0100 Subject: [PATCH 10/16] Throw when cancellation is requested --- src/Orleans.CodeGenerator/InvokableGenerator.cs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index 1874d72a4e..e2e8fdbd5d 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -609,7 +609,12 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( ), TryStatement().WithBlock( Block( - ((INamedTypeSymbol)method.Method.ReturnType).ConstructedFrom is { IsGenericType: true } + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("cancellationToken"), + IdentifierName("ThrowIfCancellationRequested")))), + ((INamedTypeSymbol)method.Method.ReturnType).ConstructedFrom is { IsGenericType: true } ? ReturnStatement( AwaitExpression( InvocationExpression(methodCall, ArgumentList(args)) From c4b4b6af10d254029d55d68ddb8e06f438e6afe4 Mon Sep 17 00:00:00 2001 From: Koen Date: Tue, 27 Aug 2024 01:38:36 +0100 Subject: [PATCH 11/16] Fixed regression with and support basic cancellation of IAsyncEnumerable --- .../InvokableGenerator.cs | 53 +++++++++++++++---- src/Orleans.CodeGenerator/LibraryTypes.cs | 2 + 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index e2e8fdbd5d..3518a5c5ee 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -556,6 +556,47 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( var defaultCancellationTokenFieldName = fields.OfType() .First(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type)).FieldName; + var returnType = (INamedTypeSymbol)method.Method.ReturnType; + StatementSyntax innerBody = returnType switch + { + // IAsyncEnumerable + { ConstructedFrom: { IsGenericType: true } } when SymbolEqualityComparer.Default.Equals(LibraryTypes.IAsyncEnumerable, returnType.ConstructedFrom) => + ForEachStatement( + returnType.TypeArguments[0].ToTypeSyntax(method.TypeParameterSubstitutions), + Identifier("item"), + InvocationExpression(methodCall, ArgumentList(args)), + Block( + YieldStatement( + SyntaxKind.YieldReturnStatement, + IdentifierName("item") + ), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("cancellationToken"), + IdentifierName("ThrowIfCancellationRequested") + ) + ) + ) + ) + ).WithAwaitKeyword(Token(SyntaxKind.AwaitKeyword)), + // Task / ValueTask + { ConstructedFrom: { IsGenericType: true } } => + ReturnStatement( + AwaitExpression( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ), + // Task / ValueTask / Void + _ => + ExpressionStatement( + AwaitExpression( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ) + }; + body = Block( LocalDeclarationStatement( VariableDeclaration( @@ -614,17 +655,7 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("cancellationToken"), IdentifierName("ThrowIfCancellationRequested")))), - ((INamedTypeSymbol)method.Method.ReturnType).ConstructedFrom is { IsGenericType: true } - ? ReturnStatement( - AwaitExpression( - InvocationExpression(methodCall, ArgumentList(args)) - ) - ) - : ExpressionStatement( - AwaitExpression( - InvocationExpression(methodCall, ArgumentList(args)) - ) - ) + innerBody ) ) .WithFinally( diff --git a/src/Orleans.CodeGenerator/LibraryTypes.cs b/src/Orleans.CodeGenerator/LibraryTypes.cs index 653ad83076..fbc104f536 100644 --- a/src/Orleans.CodeGenerator/LibraryTypes.cs +++ b/src/Orleans.CodeGenerator/LibraryTypes.cs @@ -72,6 +72,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions"); Task = Type("System.Threading.Tasks.Task"); Task_1 = Type("System.Threading.Tasks.Task`1"); + IAsyncEnumerable = Type("System.Collections.Generic.IAsyncEnumerable`1"); this.Type = Type("System.Type"); _uri = Type("System.Uri"); _int128 = TypeOrDefault("System.Int128"); @@ -237,6 +238,7 @@ INamedTypeSymbol Type(string metadataName) public INamedTypeSymbol TypeManifestOptions { get; private set; } public INamedTypeSymbol Task { get; private set; } public INamedTypeSymbol Task_1 { get; private set; } + public INamedTypeSymbol IAsyncEnumerable { get; private set; } public INamedTypeSymbol Type { get; private set; } private INamedTypeSymbol _uri; private INamedTypeSymbol? _dateOnly; From 9d5febaf54c8b10fa77490f57d43bc72126e9f34 Mon Sep 17 00:00:00 2001 From: Koen Date: Tue, 27 Aug 2024 14:39:19 +0100 Subject: [PATCH 12/16] Added missing cancellationToken codec tests --- .../BuiltInCodecTests.cs | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs index e5169a99c0..b8ed74d564 100644 --- a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs +++ b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs @@ -25,6 +25,7 @@ using System.Collections; using Orleans.Serialization.Invocation; using System.Globalization; +using System.Threading; namespace Orleans.Serialization.UnitTests { @@ -3832,4 +3833,24 @@ protected override bool Equals(AggregateException left, AggregateException right return string.Equals(left.Message, right.Message, StringComparison.Ordinal); } } + + public class CancellationTokenCodecTests(ITestOutputHelper output) : FieldCodecTester(output) + { + protected override CancellationToken CreateValue() => default; + protected override CancellationToken[] TestValues => + [ + new CancellationToken(), + new CancellationToken(true) + ]; + } + + public class CancellationTokenCopierTests(ITestOutputHelper output) : CopierTester>(output) + { + protected override CancellationToken CreateValue() => default; + protected override CancellationToken[] TestValues => + [ + new CancellationToken(), + new CancellationToken(true) + ]; + } } \ No newline at end of file From d15d93df9647c40c5694074dcfaaaeee746d836b Mon Sep 17 00:00:00 2001 From: Koen Date: Tue, 27 Aug 2024 20:11:48 +0100 Subject: [PATCH 13/16] Use correct target type --- .../InvokableGenerator.cs | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index 3518a5c5ee..58f446eb03 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -323,6 +323,7 @@ private MemberDeclarationSyntax GenerateGetTargetMethod( InvokableMethodDescription methodDescription, HolderFieldDescription holderField) { + var containingInterface = methodDescription.ContainingInterface; var isExtension = methodDescription.Key.ProxyBase.IsExtension; var body = ConditionalAccessExpression( holderField.FieldName.ToIdentifierName(), @@ -331,7 +332,7 @@ private MemberDeclarationSyntax GenerateGetTargetMethod( GenericName(isExtension ? "GetComponent" : "GetTarget") .WithTypeArgumentList( TypeArgumentList( - SingletonSeparatedList(methodDescription.Method.ContainingType.ToTypeSyntax()))))) + SingletonSeparatedList(containingInterface.ToTypeSyntax()))))) .WithArgumentList(ArgumentList())); return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") @@ -512,17 +513,14 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( ? Argument(IdentifierName("cancellationToken")) : Argument(IdentifierName(p.FieldName)))); + var containingInterface = method.ContainingInterface; var isExtension = method.Key.ProxyBase.IsExtension; - var getTarget = InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - holder.FieldName.ToIdentifierName(), - GenericName(isExtension ? "GetComponent" : "GetTarget") - .WithTypeArgumentList( - TypeArgumentList( - SingletonSeparatedList(method.Method.ContainingType.ToTypeSyntax()))))) - .WithArgumentList(ArgumentList()); - + var getTarget = ParenthesizedExpression( + CastExpression( + method.Method.ContainingType.ToTypeSyntax(), + InvocationExpression(IdentifierName("GetTarget")) + ) + ); ExpressionSyntax methodCall; if (method.MethodTypeParameters.Count > 0) From 9c12a530fc321c826ab759885ae11458f3710ca6 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 2 Sep 2024 01:47:51 +0100 Subject: [PATCH 14/16] Ensure to dispose of cancellation registrations --- .../InvokableGenerator.cs | 26 +++++++++-- src/Orleans.CodeGenerator/ProxyGenerator.cs | 45 ------------------- src/Orleans.Core/Runtime/CallbackData.cs | 29 ++++++++++++ .../Runtime/OutsideRuntimeClient.cs | 3 ++ .../Runtime/SharedCallbackData.cs | 3 ++ .../Core/InsideRuntimeClient.cs | 3 ++ .../Invocation/ICancellableInvokable.cs | 7 +++ 7 files changed, 68 insertions(+), 48 deletions(-) diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index 58f446eb03..54f30cc0f3 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -165,6 +165,7 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( GenerateGetArgumentMethod(method, fieldDescriptions), GenerateSetArgumentMethod(method, fieldDescriptions), GenerateInvokeInnerMethod(method, fieldDescriptions, holderField), + GenerateGetCancellationTokenMember(method, fieldDescriptions), GenerateGetCancellableTokenIdMember(method)); if (method.AllTypeParameters.Count > 0) @@ -342,6 +343,25 @@ private MemberDeclarationSyntax GenerateGetTargetMethod( .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); } + private MemberDeclarationSyntax GenerateGetCancellationTokenMember(InvokableMethodDescription method, List fields) + { + if (!method.IsCancellable) + { + return null; + } + + var cancellationTokenField = fields.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType)); + + // Method to get the cancellationToken argument + var cancellableRequestIdMethod = MethodDeclaration(LibraryTypes.CancellationToken.ToTypeSyntax(), "GetCancellationToken") + .WithBody(Block( + ReturnStatement(cancellationTokenField.FieldName.ToIdentifierName()) + )) + .AddModifiers(Token(SyntaxKind.PublicKeyword)); + + return cancellableRequestIdMethod; + } + private MemberDeclarationSyntax GenerateGetCancellableTokenIdMember(InvokableMethodDescription method) { if (!method.IsCancellable) @@ -580,14 +600,14 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( ) ).WithAwaitKeyword(Token(SyntaxKind.AwaitKeyword)), // Task / ValueTask - { ConstructedFrom: { IsGenericType: true } } => + { ConstructedFrom: { IsGenericType: true } } => ReturnStatement( AwaitExpression( InvocationExpression(methodCall, ArgumentList(args)) ) ), // Task / ValueTask / Void - _ => + _ => ExpressionStatement( AwaitExpression( InvocationExpression(methodCall, ArgumentList(args)) @@ -653,7 +673,7 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("cancellationToken"), IdentifierName("ThrowIfCancellationRequested")))), - innerBody + innerBody ) ) .WithFinally( diff --git a/src/Orleans.CodeGenerator/ProxyGenerator.cs b/src/Orleans.CodeGenerator/ProxyGenerator.cs index 2b0ebbb02b..cdb66ecc4c 100644 --- a/src/Orleans.CodeGenerator/ProxyGenerator.cs +++ b/src/Orleans.CodeGenerator/ProxyGenerator.cs @@ -151,51 +151,6 @@ MethodDeclarationSyntax CreateProxyMethod(ProxyMethodDescription methodDescripti .Concat(_codeGenerator.LibraryTypes.StaticCopiers) .ToList(); - // Ensure to hook up the cancellation token if the method has one - var cancellationTokenParameter = methodSymbol.Parameters.SingleOrDefault(parameter => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type)); - if (cancellationTokenParameter is not null) - { - // Throw aggressively if cancellation is already requested - statements.Add( - ExpressionStatement( - InvocationExpression( - IdentifierName($"arg{cancellationTokenParameter.Ordinal}").Member("ThrowIfCancellationRequested"), - ArgumentList() - ) - ) - ); - - // Register for cancellation - statements.Add( - ExpressionStatement( - InvocationExpression( - IdentifierName($"arg{cancellationTokenParameter.Ordinal}").Member("Register")) - .WithArgumentList( - ArgumentList(SeparatedList(new[] - { - Argument( - SimpleLambdaExpression( - Parameter(Identifier("arg")), - InvocationExpression( - InvocationExpression(ThisExpression().Member("AsReference", LibraryTypes.ICancellableInvokableGrainExtension.ToTypeSyntax())).Member("CancelRemoteToken"), - ArgumentList(SeparatedList(new[] - { - Argument( - CastExpression( - ParseTypeName(_codeGenerator.LibraryTypes.Guid.ToDisplayName()), - IdentifierName("arg") - ) - ), - })) - ) - ) - ), - Argument( - InvocationExpression( - IdentifierName("request").Member(IdentifierName("GetCancellableTokenId")))) - }))))); - } - // Set request object fields from method parameters. var parameterIndex = 0; var parameters = invokable.Members.OfType().Select(member => new SerializableMethodMember(member)); diff --git a/src/Orleans.Core/Runtime/CallbackData.cs b/src/Orleans.Core/Runtime/CallbackData.cs index dafe6de87f..54fb357db4 100644 --- a/src/Orleans.Core/Runtime/CallbackData.cs +++ b/src/Orleans.Core/Runtime/CallbackData.cs @@ -14,6 +14,8 @@ internal class CallbackData private StatusResponse lastKnownStatus; private ValueStopwatch stopwatch; + private CancellationTokenRegistration cancellationTokenRegistration; + public CallbackData( SharedCallbackData shared, IResponseCompletionSource ctx, @@ -29,6 +31,30 @@ public CallbackData( public bool IsCompleted => this.completed == 1; + public void SubscribeForCancellation(IInvokable invokable) + { + if (invokable is not ICancellableInvokable cancellableInvokable) + { + return; + } + + var cancellationToken = cancellableInvokable.GetCancellationToken(); + + if (cancellationToken.CanBeCanceled) + { + // Throw early if already cancelled + cancellationToken.ThrowIfCancellationRequested(); + + cancellationTokenRegistration = cancellationToken.Register(static arg => + { + var callbackData = (CallbackData)arg; + var cancellableInvokable = (ICancellableInvokable)callbackData.Message.BodyObject; + var cancellableTokenId = cancellableInvokable.GetCancellableTokenId(); + callbackData.shared.GrainFactory.GetGrain(callbackData.Message.TargetGrain).CancelRemoteToken(cancellableTokenId).Ignore(); + }, this); + } + } + public void OnStatusUpdate(StatusResponse status) { this.lastKnownStatus = status; @@ -63,6 +89,7 @@ public void OnTimeout() this.shared.Unregister(this.Message); this.stopwatch.Stop(); + this.cancellationTokenRegistration.Dispose(); ApplicationRequestInstruments.OnAppRequestsEnd((long)this.stopwatch.Elapsed.TotalMilliseconds); ApplicationRequestInstruments.OnAppRequestsTimedOut(); @@ -92,6 +119,7 @@ public void OnTargetSiloFail() this.shared.Unregister(this.Message); this.stopwatch.Stop(); + this.cancellationTokenRegistration.Dispose(); ApplicationRequestInstruments.OnAppRequestsEnd((long)this.stopwatch.Elapsed.TotalMilliseconds); OrleansCallBackDataEvent.Log.OnTargetSiloFail(this.Message); @@ -117,6 +145,7 @@ public void DoCallback(Message response) OrleansCallBackDataEvent.Log.DoCallback(this.Message); this.stopwatch.Stop(); + this.cancellationTokenRegistration.Dispose(); ApplicationRequestInstruments.OnAppRequestsEnd((long)this.stopwatch.Elapsed.TotalMilliseconds); // do callback outside the CallbackData lock. Just not a good practice to hold a lock for this unrelated operation. diff --git a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs index c44ac4316d..bd078419d0 100644 --- a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs +++ b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs @@ -88,6 +88,7 @@ public OutsideRuntimeClient( this.sharedCallbackData = new SharedCallbackData( msg => this.UnregisterCallback(msg.Id), this.loggerFactory.CreateLogger(), + null, this.clientMessagingOptions.ResponseTimeout); } @@ -108,6 +109,7 @@ internal void ConsumeServices() } this.InternalGrainFactory = this.ServiceProvider.GetRequiredService(); + this.sharedCallbackData.GrainFactory = this.InternalGrainFactory; this.messageFactory = this.ServiceProvider.GetService(); this.localObjects = new InvokableObjectManager( ServiceProvider.GetRequiredService(), @@ -279,6 +281,7 @@ public void SendRequest(GrainReference target, IInvokable request, IResponseComp { var callbackData = new CallbackData(this.sharedCallbackData, context, message); callbacks.TryAdd(message.Id, callbackData); + callbackData.SubscribeForCancellation(request); } else { diff --git a/src/Orleans.Core/Runtime/SharedCallbackData.cs b/src/Orleans.Core/Runtime/SharedCallbackData.cs index 73c4768e10..a5da3da396 100644 --- a/src/Orleans.Core/Runtime/SharedCallbackData.cs +++ b/src/Orleans.Core/Runtime/SharedCallbackData.cs @@ -10,14 +10,17 @@ internal class SharedCallbackData public readonly ILogger Logger; private TimeSpan responseTimeout; public long ResponseTimeoutStopwatchTicks; + public IGrainFactory GrainFactory; public SharedCallbackData( Action unregister, ILogger logger, + IGrainFactory grainFactory, TimeSpan responseTimeout) { this.Unregister = unregister; this.Logger = logger; + this.GrainFactory = grainFactory; this.ResponseTimeout = responseTimeout; } diff --git a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs index 922fe223d2..803d1325ca 100644 --- a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs +++ b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs @@ -83,11 +83,13 @@ public InsideRuntimeClient( this.sharedCallbackData = new SharedCallbackData( msg => this.UnregisterCallback(msg.SendingGrain, msg.Id), this.loggerFactory.CreateLogger(), + this.ConcreteGrainFactory, this.messagingOptions.ResponseTimeout); this.systemSharedCallbackData = new SharedCallbackData( msg => this.UnregisterCallback(msg.SendingGrain, msg.Id), this.loggerFactory.CreateLogger(), + this.ConcreteGrainFactory, this.messagingOptions.SystemResponseTimeout); } @@ -163,6 +165,7 @@ public void SendRequest( // Register a callback for the request. var callbackData = new CallbackData(sharedData, context, message); callbacks.TryAdd((message.SendingGrain, message.Id), callbackData); + callbackData.SubscribeForCancellation(request); } else { diff --git a/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs b/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs index 55fce390ce..23ecab73c9 100644 --- a/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs +++ b/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs @@ -1,5 +1,7 @@ #nullable enable using System; +using System.Threading; +using System.Threading.Tasks; namespace Orleans.Serialization.Invocation { @@ -8,6 +10,11 @@ namespace Orleans.Serialization.Invocation /// public interface ICancellableInvokable : IInvokable { + /// + /// Get the token that can be used to observe for cancellation + /// + CancellationToken GetCancellationToken(); + /// /// Returns an id that uniquely identifies this invokable /// From e8e7d661fc696e535009c4f31ff0fbcd095eb9b0 Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 2 Sep 2024 01:48:01 +0100 Subject: [PATCH 15/16] Ensure not to allocate --- .../Cancellation/CancellableInvokableGrainExtension.cs | 2 +- test/Tester/Tester.csproj | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs index b418c810c8..75ccab0b46 100644 --- a/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs +++ b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs @@ -13,7 +13,7 @@ internal class CancellableInvokableGrainExtension : ICancellableInvokableGrainEx public CancellableInvokableGrainExtension(IGrainContext grainContext) { _runtime = grainContext.GetComponent(); - _cleanupTimer = new Timer(obj => ((CancellableInvokableGrainExtension)obj)._runtime.ExpireTokens(), this, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30)); + _cleanupTimer = new Timer(static obj => ((CancellableInvokableGrainExtension)obj)._runtime.ExpireTokens(), this, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30)); } public Task CancelRemoteToken(Guid tokenId) diff --git a/test/Tester/Tester.csproj b/test/Tester/Tester.csproj index 86d6c37ae9..4e4ea04598 100644 --- a/test/Tester/Tester.csproj +++ b/test/Tester/Tester.csproj @@ -5,6 +5,7 @@ Exe false true + true From 170cb14a1f1d5084fbaec4ec04543f84e1cd28ea Mon Sep 17 00:00:00 2001 From: Koen Date: Mon, 2 Sep 2024 23:55:58 +0100 Subject: [PATCH 16/16] Update Tester.csproj --- test/Tester/Tester.csproj | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Tester/Tester.csproj b/test/Tester/Tester.csproj index 4e4ea04598..86d6c37ae9 100644 --- a/test/Tester/Tester.csproj +++ b/test/Tester/Tester.csproj @@ -5,7 +5,6 @@ Exe false true - true