diff --git a/src/Components/Server/src/Circuits/RemoteJSDataStream.cs b/src/Components/Server/src/Circuits/RemoteJSDataStream.cs index 75c432a9ec74..d1cdf0724832 100644 --- a/src/Components/Server/src/Circuits/RemoteJSDataStream.cs +++ b/src/Components/Server/src/Circuits/RemoteJSDataStream.cs @@ -181,28 +181,14 @@ public override void Write(byte[] buffer, int offset, int count) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - var linkedCancellationToken = GetLinkedCancellationToken(_streamCancellationToken, cancellationToken); - return await _pipeReaderStream.ReadAsync(buffer.AsMemory(offset, count), linkedCancellationToken); + using var linkedCts = ValueLinkedCancellationTokenSource.Create(_streamCancellationToken, cancellationToken); + return await _pipeReaderStream.ReadAsync(buffer.AsMemory(offset, count), linkedCts.Token); } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - var linkedCancellationToken = GetLinkedCancellationToken(_streamCancellationToken, cancellationToken); - return await _pipeReaderStream.ReadAsync(buffer, linkedCancellationToken); - } - - private static CancellationToken GetLinkedCancellationToken(CancellationToken a, CancellationToken b) - { - if (a.CanBeCanceled && b.CanBeCanceled) - { - return CancellationTokenSource.CreateLinkedTokenSource(a, b).Token; - } - else if (a.CanBeCanceled) - { - return a; - } - - return b; + using var linkedCts = ValueLinkedCancellationTokenSource.Create(_streamCancellationToken, cancellationToken); + return await _pipeReaderStream.ReadAsync(buffer, linkedCts.Token); } private async Task ThrowOnTimeout() @@ -243,4 +229,45 @@ protected override void Dispose(bool disposing) _disposed = true; } + + // A helper for creating and disposing linked CancellationTokenSources + // without allocating, when possible. + // Internal for testing. + internal readonly struct ValueLinkedCancellationTokenSource : IDisposable + { + private readonly CancellationTokenSource? _linkedCts; + + public readonly CancellationToken Token; + + // For testing. + internal bool HasLinkedCancellationTokenSource => _linkedCts is not null; + + public static ValueLinkedCancellationTokenSource Create( + CancellationToken token1, CancellationToken token2) + { + if (!token1.CanBeCanceled) + { + return new(linkedCts: null, token2); + } + + if (!token2.CanBeCanceled) + { + return new(linkedCts: null, token1); + } + + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); + return new(linkedCts, linkedCts.Token); + } + + private ValueLinkedCancellationTokenSource(CancellationTokenSource? linkedCts, CancellationToken token) + { + _linkedCts = linkedCts; + Token = token; + } + + public void Dispose() + { + _linkedCts?.Dispose(); + } + } } diff --git a/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs b/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs index dac28c3acea4..e737cedbe7c3 100644 --- a/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs +++ b/src/Components/Server/test/Circuits/RemoteJSDataStreamTest.cs @@ -287,6 +287,61 @@ public async Task ReceiveData_ReceivesDataThenTimesout_StreamDisposed() Assert.False(success); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ValueLinkedCts_Works_WhenOneTokenCannotBeCanceled(bool isToken1Cancelable) + { + var cts = new CancellationTokenSource(); + var token1 = isToken1Cancelable ? cts.Token : CancellationToken.None; + var token2 = isToken1Cancelable ? CancellationToken.None : cts.Token; + + using var linkedCts = RemoteJSDataStream.ValueLinkedCancellationTokenSource.Create(token1, token2); + + Assert.False(linkedCts.HasLinkedCancellationTokenSource); + Assert.False(linkedCts.Token.IsCancellationRequested); + + cts.Cancel(); + + Assert.True(linkedCts.Token.IsCancellationRequested); + } + + [Fact] + public void ValueLinkedCts_Works_WhenBothTokensCannotBeCanceled() + { + using var linkedCts = RemoteJSDataStream.ValueLinkedCancellationTokenSource.Create( + CancellationToken.None, + CancellationToken.None); + + Assert.False(linkedCts.HasLinkedCancellationTokenSource); + Assert.False(linkedCts.Token.IsCancellationRequested); + } + + [Theory] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public void ValueLinkedCts_Works_WhenBothTokensCanBeCanceled(bool shouldCancelToken1, bool shouldCancelToken2) + { + var cts1 = new CancellationTokenSource(); + var cts2 = new CancellationTokenSource(); + using var linkedCts = RemoteJSDataStream.ValueLinkedCancellationTokenSource.Create(cts1.Token, cts2.Token); + + Assert.True(linkedCts.HasLinkedCancellationTokenSource); + Assert.False(linkedCts.Token.IsCancellationRequested); + + if (shouldCancelToken1) + { + cts1.Cancel(); + } + if (shouldCancelToken2) + { + cts2.Cancel(); + } + + Assert.True(linkedCts.Token.IsCancellationRequested); + } + private static async Task CreateRemoteJSDataStreamAsync(TestRemoteJSRuntime jsRuntime = null) { var jsStreamReference = Mock.Of();