diff --git a/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md b/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md index b1b99aaf26..d5a7efc0b7 100644 --- a/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md +++ b/src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md @@ -1,3 +1,8 @@ -; Unshipped analyzer release +; Unshipped analyzer release ; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md +### New Rules + +Rule ID | Category | Severity | Notes +--------|----------|----------|-------------------- +ORLEANS0109 | Usage | Error | Method has multiple CancellationToken parameters diff --git a/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs new file mode 100644 index 0000000000..526c74fc4e --- /dev/null +++ b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs @@ -0,0 +1,16 @@ +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Orleans.CodeGenerator.Diagnostics; + +public static class MultipleCancellationTokenParametersDiagnostic +{ + public const string DiagnosticId = "ORLEANS0109"; + public const string Title = "Grain method has multiple parameters of type CancellationToken"; + public const string MessageFormat = "The type {0} contains method {1} which has multiple CancellationToken parameters. Only a single CancellationToken parameter is supported."; + public const string Category = "Usage"; + + private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true); + + internal static Diagnostic CreateDiagnostic(IMethodSymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), symbol.Name); +} \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index a09f75a5be..54f30cc0f3 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -30,6 +30,7 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab var generatedClassName = GetSimpleClassName(invokableMethodInfo); var baseClassType = GetBaseClassType(invokableMethodInfo); + var additionalInterfaceTypes = GetAdditionalInterfaceTypes(invokableMethodInfo); var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo); var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions); var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType); @@ -46,7 +47,7 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab } } - var targetField = fieldDescriptions.OfType().Single(); + var holderField = fieldDescriptions.OfType().Single(); var accessibilityKind = accessibility switch { @@ -58,11 +59,12 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab invokableMethodInfo, generatedClassName, baseClassType, + additionalInterfaceTypes, fieldDescriptions, fields, ctor, compoundTypeAliases, - targetField, + holderField, accessibilityKind); string returnValueInitializerMethod = null; @@ -112,11 +114,12 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( InvokableMethodDescription method, string generatedClassName, INamedTypeSymbol baseClassType, + INamedTypeSymbol[] additionalInterfaceTypes, List fieldDescriptions, MemberDeclarationSyntax[] fields, ConstructorDeclarationSyntax ctor, List compoundTypeAliases, - TargetFieldDescription targetField, + HolderFieldDescription holderField, SyntaxKind accessibilityKind) { var classDeclaration = ClassDeclaration(generatedClassName) @@ -125,6 +128,14 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) .AddMembers(fields); + if (additionalInterfaceTypes.Length > 0) + { + foreach (var interfaceType in additionalInterfaceTypes) + { + classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(interfaceType.ToTypeSyntax())); + } + } + foreach (var alias in compoundTypeAliases) { classDeclaration = classDeclaration.AddAttributeLists( @@ -148,12 +159,14 @@ private ClassDeclarationSyntax GetClassDeclarationSyntax( GenerateGetActivityName(method), GenerateGetInterfaceType(method), GenerateGetMethod(), - GenerateSetTargetMethod(method, targetField), - GenerateGetTargetMethod(targetField), + GenerateSetTargetMethod(holderField), + GenerateGetTargetMethod(method, holderField), GenerateDisposeMethod(fieldDescriptions, baseClassType), GenerateGetArgumentMethod(method, fieldDescriptions), GenerateSetArgumentMethod(method, fieldDescriptions), - GenerateInvokeInnerMethod(method, fieldDescriptions, targetField)); + GenerateInvokeInnerMethod(method, fieldDescriptions, holderField), + GenerateGetCancellationTokenMember(method, fieldDescriptions), + GenerateGetCancellableTokenIdMember(method)); if (method.AllTypeParameters.Count > 0) { @@ -182,7 +195,7 @@ private MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(long va .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue"))) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); -; + ; return new MemberDeclarationSyntax[] { timespanField, responseTimeoutProperty }; } @@ -274,44 +287,96 @@ private INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method) throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method)); } - private MemberDeclarationSyntax GenerateSetTargetMethod( - InvokableMethodDescription methodDescription, - TargetFieldDescription targetField) + private INamedTypeSymbol[] GetAdditionalInterfaceTypes(InvokableMethodDescription method) + { + if (method.IsCancellable) + { + var cancellationTokensCount = method.Method.Parameters.Count(parameterSymbol => SymbolEqualityComparer.Default.Equals(method.CodeGenerator.LibraryTypes.CancellationToken, parameterSymbol.Type)); + if (cancellationTokensCount is > 1) + { + throw new OrleansGeneratorDiagnosticAnalysisException(MultipleCancellationTokenParametersDiagnostic.CreateDiagnostic(method.Method)); + } + + return [LibraryTypes.ICancellableInvokable]; + } + + return []; + } + + private MemberDeclarationSyntax GenerateSetTargetMethod(HolderFieldDescription holderField) { var holder = IdentifierName("holder"); var holderParameter = holder.Identifier; + return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget") + .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax())))) + .WithExpressionBody(ArrowExpressionClause( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ThisExpression(), + IdentifierName(holderField.FieldName) + ), holder))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); + } + + private MemberDeclarationSyntax GenerateGetTargetMethod( + InvokableMethodDescription methodDescription, + HolderFieldDescription holderField) + { var containingInterface = methodDescription.ContainingInterface; var isExtension = methodDescription.Key.ProxyBase.IsExtension; - var getTarget = InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - holder, + var body = ConditionalAccessExpression( + holderField.FieldName.ToIdentifierName(), + InvocationExpression( + MemberBindingExpression( GenericName(isExtension ? "GetComponent" : "GetTarget") .WithTypeArgumentList( TypeArgumentList( SingletonSeparatedList(containingInterface.ToTypeSyntax()))))) - .WithArgumentList(ArgumentList()); + .WithArgumentList(ArgumentList())); - var body = - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(targetField.FieldName), - getTarget); - return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget") - .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax())))) + return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") + .WithParameterList(ParameterList()) .WithExpressionBody(ArrowExpressionClause(body)) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); } - private MemberDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField) + private MemberDeclarationSyntax GenerateGetCancellationTokenMember(InvokableMethodDescription method, List fields) { - return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") - .WithParameterList(ParameterList()) - .WithExpressionBody(ArrowExpressionClause(IdentifierName(targetField.FieldName))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); + if (!method.IsCancellable) + { + return null; + } + + var cancellationTokenField = fields.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType)); + + // Method to get the cancellationToken argument + var cancellableRequestIdMethod = MethodDeclaration(LibraryTypes.CancellationToken.ToTypeSyntax(), "GetCancellationToken") + .WithBody(Block( + ReturnStatement(cancellationTokenField.FieldName.ToIdentifierName()) + )) + .AddModifiers(Token(SyntaxKind.PublicKeyword)); + + return cancellableRequestIdMethod; + } + + private MemberDeclarationSyntax GenerateGetCancellableTokenIdMember(InvokableMethodDescription method) + { + if (!method.IsCancellable) + { + return null; + } + + // Method to get the CancellableTokenId + var cancellableRequestIdMethod = MethodDeclaration(LibraryTypes.Guid.ToTypeSyntax(), "GetCancellableTokenId") + .WithBody(Block( + ReturnStatement(IdentifierName("cancellableTokenId")) + )) + .AddModifiers(Token(SyntaxKind.PublicKeyword)); + + return cancellableRequestIdMethod; } private MemberDeclarationSyntax GenerateGetArgumentMethod( @@ -456,7 +521,7 @@ private MemberDeclarationSyntax GenerateSetArgumentMethod( private MemberDeclarationSyntax GenerateInvokeInnerMethod( InvokableMethodDescription method, List fields, - TargetFieldDescription target) + HolderFieldDescription holder) { var resultTask = IdentifierName("resultTask"); @@ -464,13 +529,25 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( var args = SeparatedList( fields.OfType() .OrderBy(p => p.ParameterOrdinal) - .Select(p => Argument(IdentifierName(p.FieldName)))); + .Select(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type) + ? Argument(IdentifierName("cancellationToken")) + : Argument(IdentifierName(p.FieldName)))); + + var containingInterface = method.ContainingInterface; + var isExtension = method.Key.ProxyBase.IsExtension; + var getTarget = ParenthesizedExpression( + CastExpression( + method.Method.ContainingType.ToTypeSyntax(), + InvocationExpression(IdentifierName("GetTarget")) + ) + ); + ExpressionSyntax methodCall; if (method.MethodTypeParameters.Count > 0) { methodCall = MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(target.FieldName), + getTarget, GenericName( Identifier(method.Method.Name), TypeArgumentList( @@ -479,14 +556,166 @@ private MemberDeclarationSyntax GenerateInvokeInnerMethod( } else { - methodCall = IdentifierName(target.FieldName).Member(method.Method.Name); + methodCall = getTarget.Member(method.Method.Name); + } + + BlockSyntax body; + + if (method.Method.ReturnsVoid) + { + body = Block( + ExpressionStatement( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ); + } + else if (method.IsCancellable) + { + var defaultCancellationTokenFieldName = fields.OfType() + .First(p => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, p.Parameter.Type)).FieldName; + + var returnType = (INamedTypeSymbol)method.Method.ReturnType; + StatementSyntax innerBody = returnType switch + { + // IAsyncEnumerable + { ConstructedFrom: { IsGenericType: true } } when SymbolEqualityComparer.Default.Equals(LibraryTypes.IAsyncEnumerable, returnType.ConstructedFrom) => + ForEachStatement( + returnType.TypeArguments[0].ToTypeSyntax(method.TypeParameterSubstitutions), + Identifier("item"), + InvocationExpression(methodCall, ArgumentList(args)), + Block( + YieldStatement( + SyntaxKind.YieldReturnStatement, + IdentifierName("item") + ), + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("cancellationToken"), + IdentifierName("ThrowIfCancellationRequested") + ) + ) + ) + ) + ).WithAwaitKeyword(Token(SyntaxKind.AwaitKeyword)), + // Task / ValueTask + { ConstructedFrom: { IsGenericType: true } } => + ReturnStatement( + AwaitExpression( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ), + // Task / ValueTask / Void + _ => + ExpressionStatement( + AwaitExpression( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ) + }; + + body = Block( + LocalDeclarationStatement( + VariableDeclaration( + LibraryTypes.ICancellationRuntime.ToTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(Identifier("cancellationRuntime")).WithInitializer( + EqualsValueClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("holder"), + GenericName("GetComponent") + .WithTypeArgumentList( + TypeArgumentList( + SingletonSeparatedList( + LibraryTypes.ICancellationRuntime.ToTypeSyntax() + ) + ) + ) + ) + ) + ) + ) + ) + ) + ), + LocalDeclarationStatement( + VariableDeclaration( + LibraryTypes.CancellationToken.ToTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(Identifier("cancellationToken")).WithInitializer( + EqualsValueClause( + BinaryExpression( + SyntaxKind.CoalesceExpression, + ConditionalAccessExpression( + IdentifierName("cancellationRuntime"), + InvocationExpression( + MemberBindingExpression( + IdentifierName("RegisterCancellableToken"))) + .AddArgumentListArguments( + Argument( + IdentifierName("cancellableTokenId")), + Argument( + IdentifierName(defaultCancellationTokenFieldName)))), + DefaultExpression(LibraryTypes.CancellationToken.ToTypeSyntax()) + ) + ) + ) + ) + ) + ), + TryStatement().WithBlock( + Block( + ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("cancellationToken"), + IdentifierName("ThrowIfCancellationRequested")))), + innerBody + ) + ) + .WithFinally( + FinallyClause( + Block( + ExpressionStatement( + ConditionalAccessExpression( + IdentifierName("cancellationRuntime"), + InvocationExpression( + MemberBindingExpression( + IdentifierName("Cancel"))) + .AddArgumentListArguments( + Argument(IdentifierName("cancellableTokenId")), + Argument(LiteralExpression(SyntaxKind.TrueLiteralExpression)) + ) + ) + ) + ) + ) + ) + ); + } + else + { + body = Block( + ReturnStatement( + InvocationExpression(methodCall, ArgumentList(args)) + ) + ); } - return MethodDeclaration(method.Method.ReturnType.ToTypeSyntax(method.TypeParameterSubstitutions), "InvokeInner") + var methodDeclaration = MethodDeclaration(method.Method.ReturnType.ToTypeSyntax(method.TypeParameterSubstitutions), "InvokeInner") .WithParameterList(ParameterList()) - .WithExpressionBody(ArrowExpressionClause(InvocationExpression(methodCall, ArgumentList(args)))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .WithBody(body) .WithModifiers(TokenList(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword))); + + if (!method.Method.ReturnsVoid && method.IsCancellable) + { + methodDeclaration = methodDeclaration.AddModifiers(Token(SyntaxKind.AsyncKeyword)); + } + + return methodDeclaration; } private MemberDeclarationSyntax GenerateDisposeMethod( @@ -635,6 +864,9 @@ MemberDeclarationSyntax GetFieldDeclaration(InvokerFieldDescription description) case MethodParameterFieldDescription _: field = field.AddModifiers(Token(SyntaxKind.PublicKeyword)); break; + case CancellableTokenFieldDescription _: + field = field.AddModifiers(Token(SyntaxKind.PublicKeyword)); + break; } return field; @@ -690,6 +922,16 @@ private ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnume body.Add(ExpressionStatement(InvocationExpression(IdentifierName(methodName), ArgumentList(SeparatedList(new[] { Argument(argumentExpression) }))))); } + if (method.IsCancellable) + { + body.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName("cancellableTokenId"), + InvocationExpression(LibraryTypes.Guid.ToTypeSyntax().Member("NewGuid"))))); + } + if (body.Count == 0 && parameters.Count == 0) return default; @@ -708,17 +950,22 @@ private ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnume private List GetFieldDescriptions(InvokableMethodDescription method) { var fields = new List(); - uint fieldId = 0; + foreach (var parameter in method.Method.Parameters) { fields.Add(new MethodParameterFieldDescription(method.CodeGenerator, parameter, $"arg{fieldId}", fieldId, method.TypeParameterSubstitutions)); fieldId++; } - fields.Add(new TargetFieldDescription(method.Method.ContainingType)); + fields.Add(new HolderFieldDescription(LibraryTypes.ITargetHolder)); fields.Add(new MethodInfoFieldDescription(LibraryTypes.MethodInfo, "MethodBackingField")); + if (method.IsCancellable) + { + fields.Add(new CancellableTokenFieldDescription(LibraryTypes.Guid, "cancellableTokenId", fieldId, method.ContainingInterface)); + } + return fields; } @@ -736,9 +983,9 @@ protected InvokerFieldDescription(ITypeSymbol fieldType, string fieldName) public abstract bool IsInstanceField { get; } } - internal sealed class TargetFieldDescription : InvokerFieldDescription + internal sealed class HolderFieldDescription : InvokerFieldDescription { - public TargetFieldDescription(ITypeSymbol fieldType) : base(fieldType, "target") { } + public HolderFieldDescription(ITypeSymbol fieldType) : base(fieldType, "holder") { } public override bool IsSerializable => false; public override bool IsInstanceField => true; @@ -813,5 +1060,34 @@ public MethodInfoFieldDescription(ITypeSymbol fieldType, string fieldName) : bas public override bool IsSerializable => false; public override bool IsInstanceField => false; } + + internal sealed class CancellableTokenFieldDescription : InvokerFieldDescription, IMemberDescription + { + public CancellableTokenFieldDescription( + ITypeSymbol fieldType, + string fieldName, + uint fieldId, + INamedTypeSymbol containingType) : base(fieldType, fieldName) + { + FieldId = fieldId; + ContainingType = containingType; + } + + public IFieldSymbol Field => null; + public uint FieldId { get; } + public ISymbol Symbol => FieldType; + public ITypeSymbol Type => FieldType; + public INamedTypeSymbol ContainingType { get; } + public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); + public TypeSyntax TypeSyntax => Type.ToTypeSyntax(); + public string TypeName => Type.ToDisplayName(); + public string TypeNameIdentifier => Type.GetValidIdentifier(); + public bool IsPrimaryConstructorParameter => false; + + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => TypeSyntax; + + public override bool IsSerializable => true; + public override bool IsInstanceField => true; + } } } diff --git a/src/Orleans.CodeGenerator/LibraryTypes.cs b/src/Orleans.CodeGenerator/LibraryTypes.cs index a432ab96a1..fbc104f536 100644 --- a/src/Orleans.CodeGenerator/LibraryTypes.cs +++ b/src/Orleans.CodeGenerator/LibraryTypes.cs @@ -40,6 +40,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) ConstructorAttributeTypes = options.ConstructorAttributes.Select(Type).ToArray(); AliasAttribute = Type("Orleans.AliasAttribute"); IInvokable = Type("Orleans.Serialization.Invocation.IInvokable"); + ICancellableInvokable = Type("Orleans.Serialization.Invocation.ICancellableInvokable"); InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute"); RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers"); InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute"); @@ -58,6 +59,8 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute"); OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute"); ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder"); + ICancellationRuntime = Type("Orleans.Serialization.Invocation.ICancellationRuntime"); + ICancellableInvokableGrainExtension = TypeOrDefault("Orleans.Runtime.ICancellableInvokableGrainExtension"); TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute"); NonSerializedAttribute = Type("System.NonSerializedAttribute"); ObsoleteAttribute = Type("System.ObsoleteAttribute"); @@ -69,6 +72,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions"); Task = Type("System.Threading.Tasks.Task"); Task_1 = Type("System.Threading.Tasks.Task`1"); + IAsyncEnumerable = Type("System.Collections.Generic.IAsyncEnumerable`1"); this.Type = Type("System.Type"); _uri = Type("System.Uri"); _int128 = TypeOrDefault("System.Int128"); @@ -77,11 +81,11 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) _dateOnly = TypeOrDefault("System.DateOnly"); _dateTimeOffset = Type("System.DateTimeOffset"); _bitVector32 = Type("System.Collections.Specialized.BitVector32"); - _guid = Type("System.Guid"); _compareInfo = Type("System.Globalization.CompareInfo"); _cultureInfo = Type("System.Globalization.CultureInfo"); _version = Type("System.Version"); _timeOnly = TypeOrDefault("System.TimeOnly"); + Guid = Type("System.Guid"); ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider"); ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1"); ValueTask = Type("System.Threading.Tasks.ValueTask"); @@ -124,6 +128,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) new(TypeOrDefault("System.Int128"), TypeOrDefault("Orleans.Serialization.Codecs.Int128Codec")), new(TypeOrDefault("System.Half"), TypeOrDefault("Orleans.Serialization.Codecs.HalfCodec")), new(Type("System.Uri"), Type("Orleans.Serialization.Codecs.UriCodec")), + new(Type("System.Threading.CancellationToken"), Type("Orleans.Serialization.Codecs.CancellationTokenCodec")), }.Where(desc => desc.UnderlyingType is { } && desc.CodecType is { }).ToArray(); WellKnownCodecs = new WellKnownCodecDescription[] { @@ -153,7 +158,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) TimeSpan = Type("System.TimeSpan"); _ipAddress = Type("System.Net.IPAddress"); _ipEndPoint = Type("System.Net.IPEndPoint"); - _cancellationToken = Type("System.Threading.CancellationToken"); + CancellationToken = Type("System.Threading.CancellationToken"); _immutableContainerTypes = new[] { compilation.GetSpecialType(SpecialType.System_Nullable_T), @@ -218,7 +223,10 @@ INamedTypeSymbol Type(string metadataName) public INamedTypeSymbol IActivator_1 { get; private set; } public INamedTypeSymbol IBufferWriter { get; private set; } public INamedTypeSymbol IInvokable { get; private set; } + public INamedTypeSymbol ICancellableInvokable { get; private set; } public INamedTypeSymbol ITargetHolder { get; private set; } + public INamedTypeSymbol ICancellationRuntime { get; private set; } + public INamedTypeSymbol? ICancellableInvokableGrainExtension { get; private set; } public INamedTypeSymbol TypeManifestProviderAttribute { get; private set; } public INamedTypeSymbol NonSerializedAttribute { get; private set; } public INamedTypeSymbol ObsoleteAttribute { get; private set; } @@ -230,6 +238,7 @@ INamedTypeSymbol Type(string metadataName) public INamedTypeSymbol TypeManifestOptions { get; private set; } public INamedTypeSymbol Task { get; private set; } public INamedTypeSymbol Task_1 { get; private set; } + public INamedTypeSymbol IAsyncEnumerable { get; private set; } public INamedTypeSymbol Type { get; private set; } private INamedTypeSymbol _uri; private INamedTypeSymbol? _dateOnly; @@ -259,13 +268,13 @@ INamedTypeSymbol Type(string metadataName) public INamedTypeSymbol SuppressReferenceTrackingAttribute { get; private set; } public INamedTypeSymbol OmitDefaultMemberValuesAttribute { get; private set; } public INamedTypeSymbol CopyContext { get; private set; } + public INamedTypeSymbol CancellationToken { get; private set; } + public INamedTypeSymbol Guid { get; private set; } public Compilation Compilation { get; private set; } public INamedTypeSymbol TimeSpan { get; private set; } private INamedTypeSymbol _ipAddress; private INamedTypeSymbol _ipEndPoint; - private INamedTypeSymbol _cancellationToken; private INamedTypeSymbol[] _immutableContainerTypes; - private INamedTypeSymbol _guid; private INamedTypeSymbol _bitVector32; private INamedTypeSymbol _compareInfo; private INamedTypeSymbol _cultureInfo; @@ -280,14 +289,14 @@ INamedTypeSymbol Type(string metadataName) _dateOnly, _timeOnly, _dateTimeOffset, - _guid, + Guid, _bitVector32, _compareInfo, _cultureInfo, _version, _ipAddress, _ipEndPoint, - _cancellationToken, + CancellationToken, Type, _uri, _uInt128, diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs index ec6b30c831..6b72db6ce1 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Globalization; +using System.Linq; namespace Orleans.CodeGenerator { @@ -206,6 +207,11 @@ static bool TryGetNamedArgument(ImmutableArray public INamedTypeSymbol ContainingInterface { get; } + /// + /// Gets a value indicating whether this method is cancellable. + /// + public bool IsCancellable => Method.Parameters.Any(parameterSymbol => SymbolEqualityComparer.Default.Equals(CodeGenerator.LibraryTypes.CancellationToken, parameterSymbol.Type)); + public bool Equals(InvokableMethodDescription other) => Key.Equals(other.Key); public override bool Equals(object obj) => obj is InvokableMethodDescription imd && Equals(imd); public override int GetHashCode() => Key.GetHashCode(); diff --git a/src/Orleans.CodeGenerator/SerializerGenerator.cs b/src/Orleans.CodeGenerator/SerializerGenerator.cs index 16b21b6d4e..9899cf3c16 100644 --- a/src/Orleans.CodeGenerator/SerializerGenerator.cs +++ b/src/Orleans.CodeGenerator/SerializerGenerator.cs @@ -47,6 +47,10 @@ public ClassDeclarationSyntax Generate(ISerializableTypeDescription type) { members.Add(new SerializableMethodMember(methodParameter)); } + else if (member is CancellableTokenFieldDescription cancellableTokenField) + { + members.Add(new SerializableCancellableTokenMember(_codeGenerator, cancellableTokenField)); + } } var fieldDescriptions = GetFieldDescriptions(type, members); @@ -1170,6 +1174,8 @@ public SerializableMethodMember(MethodParameterFieldDescription member) public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Parameter.Type) || _member.Parameter.HasAnyAttribute(LibraryTypes.ImmutableAttributes); + public bool IsCancellationToken => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, _member.Parameter.Type); + /// /// Gets syntax representing the type of this field. /// @@ -1201,6 +1207,56 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va public FieldAccessorDescription GetSetterFieldDescription() => null; } + internal class SerializableCancellableTokenMember : ISerializableMember + { + private readonly CodeGenerator _codeGenerator; + private readonly CancellableTokenFieldDescription _member; + + public SerializableCancellableTokenMember(CodeGenerator codeGenerator, CancellableTokenFieldDescription member) + { + _codeGenerator = codeGenerator; + _member = member; + } + + public IMemberDescription Member => _member; + + private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; + + public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Type); + + public bool IsCancellationToken => false; + + /// + /// Gets syntax representing the type of this field. + /// + public TypeSyntax TypeSyntax => _member.TypeSyntax; + + public bool IsValueType => _member.Type.IsValueType; + + public bool IsPrimaryConstructorParameter => false; + + /// + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. + /// + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_member.FieldName); + + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) => AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + instance.Member(_member.FieldName), + value); + + public FieldAccessorDescription GetGetterFieldDescription() => null; + public FieldAccessorDescription GetSetterFieldDescription() => null; + } + /// /// Represents a serializable member (field/property) of a type. /// diff --git a/src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs b/src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs new file mode 100644 index 0000000000..75337bba81 --- /dev/null +++ b/src/Orleans.Core.Abstractions/Cancellation/ICancellableInvokableGrainExtension.cs @@ -0,0 +1,19 @@ +using System; +using System.Threading.Tasks; +using Orleans.Concurrency; + +namespace Orleans.Runtime; + +public interface ICancellableInvokableGrainExtension : IGrainExtension +{ + /// + /// Indicates that a cancellation token has been canceled. + /// + /// + /// The token id + /// + /// A representing the operation. + /// + [AlwaysInterleave] + Task CancelRemoteToken(Guid tokenId); +} \ No newline at end of file diff --git a/src/Orleans.Core/Runtime/CallbackData.cs b/src/Orleans.Core/Runtime/CallbackData.cs index dafe6de87f..54fb357db4 100644 --- a/src/Orleans.Core/Runtime/CallbackData.cs +++ b/src/Orleans.Core/Runtime/CallbackData.cs @@ -14,6 +14,8 @@ internal class CallbackData private StatusResponse lastKnownStatus; private ValueStopwatch stopwatch; + private CancellationTokenRegistration cancellationTokenRegistration; + public CallbackData( SharedCallbackData shared, IResponseCompletionSource ctx, @@ -29,6 +31,30 @@ public CallbackData( public bool IsCompleted => this.completed == 1; + public void SubscribeForCancellation(IInvokable invokable) + { + if (invokable is not ICancellableInvokable cancellableInvokable) + { + return; + } + + var cancellationToken = cancellableInvokable.GetCancellationToken(); + + if (cancellationToken.CanBeCanceled) + { + // Throw early if already cancelled + cancellationToken.ThrowIfCancellationRequested(); + + cancellationTokenRegistration = cancellationToken.Register(static arg => + { + var callbackData = (CallbackData)arg; + var cancellableInvokable = (ICancellableInvokable)callbackData.Message.BodyObject; + var cancellableTokenId = cancellableInvokable.GetCancellableTokenId(); + callbackData.shared.GrainFactory.GetGrain(callbackData.Message.TargetGrain).CancelRemoteToken(cancellableTokenId).Ignore(); + }, this); + } + } + public void OnStatusUpdate(StatusResponse status) { this.lastKnownStatus = status; @@ -63,6 +89,7 @@ public void OnTimeout() this.shared.Unregister(this.Message); this.stopwatch.Stop(); + this.cancellationTokenRegistration.Dispose(); ApplicationRequestInstruments.OnAppRequestsEnd((long)this.stopwatch.Elapsed.TotalMilliseconds); ApplicationRequestInstruments.OnAppRequestsTimedOut(); @@ -92,6 +119,7 @@ public void OnTargetSiloFail() this.shared.Unregister(this.Message); this.stopwatch.Stop(); + this.cancellationTokenRegistration.Dispose(); ApplicationRequestInstruments.OnAppRequestsEnd((long)this.stopwatch.Elapsed.TotalMilliseconds); OrleansCallBackDataEvent.Log.OnTargetSiloFail(this.Message); @@ -117,6 +145,7 @@ public void DoCallback(Message response) OrleansCallBackDataEvent.Log.DoCallback(this.Message); this.stopwatch.Stop(); + this.cancellationTokenRegistration.Dispose(); ApplicationRequestInstruments.OnAppRequestsEnd((long)this.stopwatch.Elapsed.TotalMilliseconds); // do callback outside the CallbackData lock. Just not a good practice to hold a lock for this unrelated operation. diff --git a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs index 29d536db26..a59e780b60 100644 --- a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs +++ b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs @@ -90,6 +90,7 @@ public OutsideRuntimeClient( this.sharedCallbackData = new SharedCallbackData( msg => this.UnregisterCallback(msg.Id), this.loggerFactory.CreateLogger(), + null, this.clientMessagingOptions.ResponseTimeout); } @@ -101,6 +102,7 @@ internal void ConsumeServices() _manifestProvider = ServiceProvider.GetRequiredService(); this.InternalGrainFactory = this.ServiceProvider.GetRequiredService(); + this.sharedCallbackData.GrainFactory = this.InternalGrainFactory; this.messageFactory = this.ServiceProvider.GetService(); this.localObjects = new InvokableObjectManager( ServiceProvider.GetRequiredService(), @@ -278,6 +280,7 @@ public void SendRequest(GrainReference target, IInvokable request, IResponseComp { var callbackData = new CallbackData(this.sharedCallbackData, context, message); callbacks.TryAdd(message.Id, callbackData); + callbackData.SubscribeForCancellation(request); } else { diff --git a/src/Orleans.Core/Runtime/SharedCallbackData.cs b/src/Orleans.Core/Runtime/SharedCallbackData.cs index 73c4768e10..a5da3da396 100644 --- a/src/Orleans.Core/Runtime/SharedCallbackData.cs +++ b/src/Orleans.Core/Runtime/SharedCallbackData.cs @@ -10,14 +10,17 @@ internal class SharedCallbackData public readonly ILogger Logger; private TimeSpan responseTimeout; public long ResponseTimeoutStopwatchTicks; + public IGrainFactory GrainFactory; public SharedCallbackData( Action unregister, ILogger logger, + IGrainFactory grainFactory, TimeSpan responseTimeout) { this.Unregister = unregister; this.Logger = logger; + this.GrainFactory = grainFactory; this.ResponseTimeout = responseTimeout; } diff --git a/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs new file mode 100644 index 0000000000..75ccab0b46 --- /dev/null +++ b/src/Orleans.Runtime/Cancellation/CancellableInvokableGrainExtension.cs @@ -0,0 +1,33 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Orleans.Serialization.Invocation; + +namespace Orleans.Runtime.Cancellation; + +internal class CancellableInvokableGrainExtension : ICancellableInvokableGrainExtension, IDisposable +{ + readonly ICancellationRuntime _runtime; + readonly Timer _cleanupTimer; + + public CancellableInvokableGrainExtension(IGrainContext grainContext) + { + _runtime = grainContext.GetComponent(); + _cleanupTimer = new Timer(static obj => ((CancellableInvokableGrainExtension)obj)._runtime.ExpireTokens(), this, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30)); + } + + public Task CancelRemoteToken(Guid tokenId) + { + if (_runtime is not null) + { + _runtime.Cancel(tokenId, lastCall: false); + } + + return Task.CompletedTask; + } + + public void Dispose() + { + _cleanupTimer.Dispose(); + } +} diff --git a/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs new file mode 100644 index 0000000000..103618f560 --- /dev/null +++ b/src/Orleans.Runtime/Cancellation/CancellationRuntime.cs @@ -0,0 +1,100 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading; +using Microsoft.Extensions.Logging; +using Orleans.Serialization.Invocation; + +namespace Orleans.Runtime.Cancellation; + +internal class CancellationRuntime : ICancellationRuntime +{ + private static readonly TimeSpan _cleanupFrequency = TimeSpan.FromMinutes(7); + + readonly Dictionary _cancellationTokens = new Dictionary(); + + ref TokenEntry GetOrCreateEntry(Guid tokenId) + { + lock (_cancellationTokens) + { + ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_cancellationTokens, tokenId, out var exists); + + if (!exists) + { + entry.SetSource(new CancellationTokenSource()); + } + + entry.Touch(); + return ref entry; + } + } + + public void Cancel(Guid tokenId, bool lastCall) + { + if (!lastCall) + { + var entry = GetOrCreateEntry(tokenId); + entry.Source.Cancel(); + } + else + { + lock (_cancellationTokens) + { + if (_cancellationTokens.Remove(tokenId, out var entry)) + { + entry.Source.Cancel(); + entry.Source.Dispose(); + } + } + } + } + + public CancellationToken RegisterCancellableToken(Guid tokenId, CancellationToken @default) + { + var entry = GetOrCreateEntry(tokenId); + + if (@default != CancellationToken.None) + { + return CancellationTokenSource.CreateLinkedTokenSource(@default, entry.Source.Token).Token; + } + + return entry.Source.Token; + } + + public void ExpireTokens() + { + var now = Stopwatch.GetTimestamp(); + lock (_cancellationTokens) + { + foreach (var token in _cancellationTokens) + { + if (token.Value.IsExpired(_cleanupFrequency, now)) + { + _cancellationTokens.Remove(token.Key); + } + } + } + } + + struct TokenEntry + { + private long _createdTime; + + public void Touch() => _createdTime = Stopwatch.GetTimestamp(); + + public void SetSource(CancellationTokenSource source) + { + Source = source; + } + + public CancellationTokenSource Source { get; private set; } + + public bool IsExpired(TimeSpan expiry, long nowTimestamp) + { + var untouchedTime = TimeSpan.FromSeconds((nowTimestamp - _createdTime) / (double)Stopwatch.Frequency); + + return untouchedTime >= expiry; + } + } +} \ No newline at end of file diff --git a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs index d3069636fc..55466ab1c2 100644 --- a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs +++ b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs @@ -83,11 +83,13 @@ public InsideRuntimeClient( this.sharedCallbackData = new SharedCallbackData( msg => this.UnregisterCallback(msg.SendingGrain, msg.Id), this.loggerFactory.CreateLogger(), + this.ConcreteGrainFactory, this.messagingOptions.ResponseTimeout); this.systemSharedCallbackData = new SharedCallbackData( msg => this.UnregisterCallback(msg.SendingGrain, msg.Id), this.loggerFactory.CreateLogger(), + this.ConcreteGrainFactory, this.messagingOptions.SystemResponseTimeout); } @@ -163,6 +165,7 @@ public void SendRequest( // Register a callback for the request. var callbackData = new CallbackData(sharedData, context, message); callbacks.TryAdd((message.SendingGrain, message.Id), callbackData); + callbackData.SubscribeForCancellation(request); } else { diff --git a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs index 67f98907cd..248256946e 100644 --- a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs +++ b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs @@ -43,8 +43,8 @@ using Orleans.Serialization.Internal; using Orleans.Core; using Orleans.Placement.Repartitioning; -using Orleans.GrainDirectory; -using Orleans.Runtime.Hosting; +using Orleans.Runtime.Cancellation; +using Orleans.Serialization.Invocation; namespace Orleans.Hosting { @@ -101,6 +101,9 @@ internal static void AddDefaultServices(ISiloBuilder builder) services.TryAddSingleton(); services.AddTransient(); services.AddKeyedTransient(typeof(ICancellationSourcesExtension), (sp, _) => sp.GetRequiredService()); + services.TryAddTransient(); + services.AddTransient(); + services.AddKeyedTransient(typeof(ICancellableInvokableGrainExtension), (sp, _) => sp.GetRequiredService()); services.TryAddSingleton(sp => sp.GetRequiredService().ConcreteGrainFactory); services.TryAddSingleton(); services.TryAddSingleton(); diff --git a/src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs b/src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs new file mode 100644 index 0000000000..5f452428b8 --- /dev/null +++ b/src/Orleans.Serialization/Codecs/CancellationTokenCodec.cs @@ -0,0 +1,55 @@ +using System; +using System.Buffers; +using System.Runtime.CompilerServices; +using System.Threading; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.WireProtocol; + +namespace Orleans.Serialization.Codecs +{ + /// + /// Serializer for . + /// + [RegisterSerializer] + public sealed class CancellationTokenCodec : IFieldCodec + { + void IFieldCodec.WriteField(ref Writer writer, uint fieldIdDelta, Type expectedType, CancellationToken value) + { + ReferenceCodec.MarkValueField(writer.Session); + writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(CancellationToken), WireType.Fixed32); + writer.WriteInt32(value.IsCancellationRequested ? 1 : 0); + } + + /// + /// Writes a field without type info (expected type is statically known). + /// + /// The buffer writer type. + /// The writer. + /// The field identifier delta. + /// The value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void WriteField(ref Writer writer, uint fieldIdDelta, CancellationToken value) where TBufferWriter : IBufferWriter + { + ReferenceCodec.MarkValueField(writer.Session); + writer.WriteFieldHeaderExpected(fieldIdDelta, WireType.Fixed32); + writer.WriteInt32(value.IsCancellationRequested ? 1 : 0); + } + + /// + CancellationToken IFieldCodec.ReadValue(ref Reader reader, Field field) => ReadValue(ref reader, field); + + /// + /// Reads a value. + /// + /// The reader input type. + /// The reader. + /// The field. + /// The value. + public static CancellationToken ReadValue(ref Reader reader, Field field) + { + ReferenceCodec.MarkValueField(reader.Session); + field.EnsureWireType(WireType.Fixed32); + return new CancellationToken(reader.ReadInt32() == 1 ? true : false); + } + } +} \ No newline at end of file diff --git a/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs b/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs new file mode 100644 index 0000000000..23ecab73c9 --- /dev/null +++ b/src/Orleans.Serialization/Invocation/ICancellableInvokable.cs @@ -0,0 +1,23 @@ +#nullable enable +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Orleans.Serialization.Invocation +{ + /// + /// Represents an invokable that can be canceled + /// + public interface ICancellableInvokable : IInvokable + { + /// + /// Get the token that can be used to observe for cancellation + /// + CancellationToken GetCancellationToken(); + + /// + /// Returns an id that uniquely identifies this invokable + /// + Guid GetCancellableTokenId(); + } +} \ No newline at end of file diff --git a/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs new file mode 100644 index 0000000000..3cf20dd6a8 --- /dev/null +++ b/src/Orleans.Serialization/Invocation/ICancellationRuntime.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Orleans.Serialization.Invocation; + +/// +/// An optional runtime to fascilitate in cancelling invokables +/// +public interface ICancellationRuntime +{ + /// + /// Registers the token and returns a cancellation token linked to the token id + /// + /// The token id to register + /// The default cancellationToken to consider + /// A cancellationToken that will be cancelled once Cancel for the token has been called + CancellationToken RegisterCancellableToken(Guid tokenId, CancellationToken @default); + + /// + /// Cancels the invokable with the specified token id + /// + /// The token id to cancel + /// Whether this is the last call associated with the token + void Cancel(Guid tokenId, bool lastCall); + + /// + /// Expires any tokens that have not yet been cancelled and have been inactive for a period of time + /// + void ExpireTokens(); +} \ No newline at end of file diff --git a/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs b/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs index 9a84c549ce..1a500fa27e 100644 --- a/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs +++ b/test/Grains/TestGrainInterfaces/IGenericInterfaces.cs @@ -193,15 +193,19 @@ public interface ILongRunningTaskGrain : IGrainWithGuidKey Task GetRuntimeInstanceIdWithDelay(TimeSpan delay); Task LongWait(GrainCancellationToken tc, TimeSpan delay); + Task LongWait(CancellationToken tc, TimeSpan delay); Task LongRunningTask(T t, TimeSpan delay); Task CallOtherLongRunningTask(ILongRunningTaskGrain target, T t, TimeSpan delay); Task FanOutOtherLongRunningTask(ILongRunningTaskGrain target, T t, TimeSpan delay, int degreeOfParallelism); Task CallOtherLongRunningTask(ILongRunningTaskGrain target, GrainCancellationToken tc, TimeSpan delay); + Task CallOtherLongRunningTask(ILongRunningTaskGrain target, CancellationToken tc, TimeSpan delay); Task CallOtherLongRunningTaskWithLocalToken(ILongRunningTaskGrain target, TimeSpan delay, TimeSpan delayBeforeCancel); Task CancellationTokenCallbackResolve(GrainCancellationToken tc); + Task CancellationTokenCallbackResolve(CancellationToken tc); Task CallOtherCancellationTokenCallbackResolve(ILongRunningTaskGrain target); Task CancellationTokenCallbackThrow(GrainCancellationToken tc); + Task CancellationTokenCallbackThrow(CancellationToken tc); Task GetLastValue(); } diff --git a/test/Grains/TestGrains/GenericGrains.cs b/test/Grains/TestGrains/GenericGrains.cs index ac7196f59f..d69575e26a 100644 --- a/test/Grains/TestGrains/GenericGrains.cs +++ b/test/Grains/TestGrains/GenericGrains.cs @@ -627,6 +627,16 @@ public Task CancellationTokenCallbackThrow(GrainCancellationToken tc) return Task.CompletedTask; } + public Task CancellationTokenCallbackThrow(CancellationToken tc) + { + tc.Register(() => + { + throw new InvalidOperationException("From cancellation token callback"); + }); + + return Task.CompletedTask; + } + public Task GetLastValue() { return Task.FromResult(lastValue); @@ -660,6 +670,25 @@ public Task CancellationTokenCallbackResolve(GrainCancellationToken tc) return tcs.Task; } + public Task CancellationTokenCallbackResolve(CancellationToken tc) + { + var tcs = new TaskCompletionSource(); + var orleansTs = TaskScheduler.Current; + tc.Register(() => + { + if (TaskScheduler.Current != orleansTs) + { + tcs.SetException(new Exception("Callback executed on wrong thread")); + } + else + { + tcs.SetResult(true); + } + }); + + return tcs.Task; + } + public async Task CallOtherLongRunningTask(ILongRunningTaskGrain target, T t, TimeSpan delay) { return await target.LongRunningTask(t, delay); @@ -681,6 +710,11 @@ public async Task CallOtherLongRunningTask(ILongRunningTaskGrain target, Grai await target.LongWait(tc, delay); } + public async Task CallOtherLongRunningTask(ILongRunningTaskGrain target, CancellationToken tc, TimeSpan delay) + { + await target.LongWait(tc, delay); + } + public async Task CallOtherLongRunningTaskWithLocalToken(ILongRunningTaskGrain target, TimeSpan delay, TimeSpan delayBeforeCancel) { var tcs = new GrainCancellationTokenSource(); @@ -695,6 +729,11 @@ public async Task LongWait(GrainCancellationToken tc, TimeSpan delay) await Task.Delay(delay, tc.CancellationToken); } + public async Task LongWait(CancellationToken tc, TimeSpan delay) + { + await Task.Delay(delay, tc); + } + public async Task LongRunningTask(T t, TimeSpan delay) { await Task.Delay(delay); diff --git a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs index a025011c8d..d791fa0cc9 100644 --- a/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs +++ b/test/Orleans.Serialization.UnitTests/BuiltInCodecTests.cs @@ -25,6 +25,7 @@ using System.Collections; using Orleans.Serialization.Invocation; using System.Globalization; +using System.Threading; namespace Orleans.Serialization.UnitTests { @@ -3882,4 +3883,24 @@ protected override bool Equals(AggregateException left, AggregateException right return string.Equals(left.Message, right.Message, StringComparison.Ordinal); } } + + public class CancellationTokenCodecTests(ITestOutputHelper output) : FieldCodecTester(output) + { + protected override CancellationToken CreateValue() => default; + protected override CancellationToken[] TestValues => + [ + new CancellationToken(), + new CancellationToken(true) + ]; + } + + public class CancellationTokenCopierTests(ITestOutputHelper output) : CopierTester>(output) + { + protected override CancellationToken CreateValue() => default; + protected override CancellationToken[] TestValues => + [ + new CancellationToken(), + new CancellationToken(true) + ]; + } } \ No newline at end of file diff --git a/test/Tester/CancellationTests/CancellationTokenTests.cs b/test/Tester/CancellationTests/CancellationTokenTests.cs new file mode 100644 index 0000000000..144e5c483a --- /dev/null +++ b/test/Tester/CancellationTests/CancellationTokenTests.cs @@ -0,0 +1,230 @@ +using Microsoft.Extensions.Logging; +using Orleans.TestingHost; +using TestExtensions; +using UnitTests.GrainInterfaces; +using Xunit; + +namespace UnitTests.CancellationTests +{ + public class CancellationTokenTests : OrleansTestingBase, IClassFixture + { + private readonly Fixture fixture; + + public class Fixture : BaseTestClusterFixture + { + protected override void ConfigureTestCluster(TestClusterBuilder builder) + { + base.ConfigureTestCluster(builder); + builder.AddSiloBuilderConfigurator(); + } + + private class SiloConfig : ISiloConfigurator + { + public void Configure(ISiloBuilder siloBuilder) + { + siloBuilder.ConfigureLogging(logging => logging.AddDebug()); + } + } + } + + public CancellationTokenTests(Fixture fixture) + { + this.fixture = fixture; + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task GrainTaskCancellation(int delay) + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + var grainTask = grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)); + await Task.Delay(TimeSpan.FromMilliseconds(delay)); + await tcs.CancelAsync(); + await Assert.ThrowsAnyAsync(() => grainTask); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task MultipleGrainsTaskCancellation(int delay) + { + var tcs = new CancellationTokenSource(); + var grainTasks = Enumerable.Range(0, 5) + .Select(i => this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()) + .LongWait(tcs.Token, TimeSpan.FromSeconds(10))) + .Select(task => Assert.ThrowsAnyAsync(() => task)).ToList(); + await Task.Delay(TimeSpan.FromMilliseconds(delay)); + await tcs.CancelAsync(); + await Task.WhenAll(grainTasks); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task GrainTaskMultipleCancellations(int delay) + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var grainTasks = Enumerable.Range(0, 5) + .Select(async i => + { + var tcs = new CancellationTokenSource(); + var task = grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)); + await Task.WhenAny(task, Task.Delay(TimeSpan.FromMilliseconds(delay))); + await tcs.CancelAsync(); + try + { + await task; + Assert.Fail("Expected TaskCancelledException, but message completed"); + } + catch (OperationCanceledException) { } + }) + .ToList(); + await Task.WhenAll(grainTasks); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task TokenPassingWithoutCancellation_NoExceptionShouldBeThrown() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + try + { + await grain.LongWait(tcs.Token, TimeSpan.FromMilliseconds(1)); + } + catch (Exception ex) + { + Assert.Fail("Expected no exception, but got: " + ex.Message); + } + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task PreCancelledTokenPassing() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + await tcs.CancelAsync(); + + // Except a OperationCanceledException to be thrown as the token is already cancelled + Assert.Throws(() => grain.LongWait(tcs.Token, TimeSpan.FromSeconds(10)).Ignore()); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task CancellationTokenCallbacksExecutionContext() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + var grainTask = grain.CancellationTokenCallbackResolve(tcs.Token); + await Task.Delay(TimeSpan.FromMilliseconds(100)); + await tcs.CancelAsync(); + var result = await grainTask; + Assert.True(result); + } + + [Fact, TestCategory("BVT"), TestCategory("Cancellation")] + public async Task CancellationTokenCallbacksTaskSchedulerContext() + { + var grains = await GetGrains(false); + + var tcs = new CancellationTokenSource(); + var grainTask = grains.Item1.CallOtherCancellationTokenCallbackResolve(grains.Item2); + await tcs.CancelAsync(); + var result = await grainTask; + Assert.True(result); + } + + [Fact, TestCategory("Cancellation")] + public async Task CancellationTokenCallbacksThrow_ExceptionDoesNotPropagate() + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var tcs = new CancellationTokenSource(); + _ = grain.CancellationTokenCallbackThrow(tcs.Token); + await Task.Delay(TimeSpan.FromMilliseconds(100)); + // Cancellation is a cooperative mechanism, so we don't expect the exception to propagate + await tcs.CancelAsync(); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InSiloGrainCancellation(int delay) + { + await GrainGrainCancellation(false, delay); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InterSiloGrainCancellation(int delay) + { + await GrainGrainCancellation(true, delay); + } + + [SkippableTheory(Skip = "https://github.com/dotnet/orleans/issues/5654"), TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InterSiloClientCancellationTokenPassing(int delay) + { + await ClientGrainGrainTokenPassing(delay, true); + } + + [Theory, TestCategory("BVT"), TestCategory("Cancellation")] + [InlineData(0)] + [InlineData(10)] + [InlineData(300)] + public async Task InSiloClientCancellationTokenPassing(int delay) + { + await ClientGrainGrainTokenPassing(delay, false); + } + + private async Task ClientGrainGrainTokenPassing(int delay, bool interSilo) + { + var grains = await GetGrains(interSilo); + var grain = grains.Item1; + var target = grains.Item2; + var tcs = new CancellationTokenSource(); + var grainTask = grain.CallOtherLongRunningTask(target, tcs.Token, TimeSpan.FromSeconds(10)); + await Task.Delay(TimeSpan.FromMilliseconds(delay)); + await tcs.CancelAsync(); + await Assert.ThrowsAnyAsync(() => grainTask); + } + + private async Task GrainGrainCancellation(bool interSilo, int delay) + { + var grains = await GetGrains(interSilo); + var grain = grains.Item1; + var target = grains.Item2; + var grainTask = grain.CallOtherLongRunningTaskWithLocalToken(target, TimeSpan.FromSeconds(10), + TimeSpan.FromMilliseconds(delay)); + await Assert.ThrowsAnyAsync(() => grainTask); + } + + private async Task, ILongRunningTaskGrain>> GetGrains(bool placeOnDifferentSilos = true) + { + var grain = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var instanceId = await grain.GetRuntimeInstanceId(); + var target = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + var targetInstanceId = await target.GetRuntimeInstanceId(); + var retriesCount = 0; + var retriesLimit = 10; + + while ((placeOnDifferentSilos && instanceId.Equals(targetInstanceId)) + || (!placeOnDifferentSilos && !instanceId.Equals(targetInstanceId))) + { + if (retriesCount >= retriesLimit) throw new Exception("Could not make requested grains placement"); + target = this.fixture.GrainFactory.GetGrain>(Guid.NewGuid()); + targetInstanceId = await target.GetRuntimeInstanceId(); + retriesCount++; + } + + return new Tuple, ILongRunningTaskGrain>(grain, target); + } + } +} \ No newline at end of file