Skip to content

Instantly share code, notes, and snippets.

@aannenko
Last active November 11, 2023 19:45
Show Gist options
  • Save aannenko/4dfd87296115f97ca790764ce57788fc to your computer and use it in GitHub Desktop.
Save aannenko/4dfd87296115f97ca790764ce57788fc to your computer and use it in GitHub Desktop.
One Task, many awaiters - starts a task, issues awaiter-tasks for it, tracks the task until it's finished or all awaiter-tasks are cancelled.

One Task - Many Awaiters

Purpose

I needed code that:

  • when it's called the first time, starts a cancellable task and returns a cancellable awaiter-task for it
  • on subsequent calls, returns cancellable awaiter-tasks for the original task
  • lets the task keep running while there is someone waiting for its completion (some awaiter-tasks are not cancelled)
  • cancels the task when there is no one left waiting for its completion (all awaiter-tasks are cancelled)

I came up with several classes where AwaiterTaskSource is the centerpiece.

Scenario

Let's say we are building a web-service with an endpoint that accepts a string and performs a long-running calculation based on this string. And we want this endpoint to behave like this:

                                        time

                 Start of the operation. ---
          First actor calls the endpoint  |
    with a 1 second timeout, passes "a".  |
             Calculation for "a" starts,  |
       first actor waits for its result.  |
                                          |
                                         --- 0.5 seconds later.
                                          |  Second actor calls the endpoint
                                          |  with no timeout, passes "a" too.
                                          |  Calculation for "a" is already running,
                                          |  second actor waits for its result.
                                          |
            1 second into the operation. ---
        First actor's request times out,  |
  cancellation issued to the calculation  |
       yet the calculation keeps running  |
because second actor still waits for it.  |
                                          |
                                         --- 1.5 seconds into the operation.
                                             The calculation is finished,
                                             second actor gets the calculation result.

OneTaskManyAwaitersService allows us to achieve this if we pass the incoming "a" and a task-creating delegate to its RunOrAwait method.

Implementation

  • AwaiterTaskSource executes a task factory and then issues awaiter-tasks for the task created by the factory
  • OneTaskManyAwaitersService creates and looks up instances of AwaiterTaskSource by key and returns their awaiter-tasks to callers
  • OneTaskManyAwaitersServiceTests contains a slew of tests (and example usages) for the above classes.
#nullable enable
using System;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Starts one task through a factory and hands out independently cancellable
/// "awaiter" tasks for it. When the last pending awaiter stops waiting while the
/// underlying task is still running, the token given to the factory is cancelled.
/// </summary>
/// <typeparam name="TResult">Result type of the underlying task.</typeparam>
public sealed class AwaiterTaskSource<TResult>
{
    private readonly CancellationTokenSource _cancellationTokenSource;

    // Number of awaiter-tasks currently waiting on Task; updated with Interlocked only.
    private int _awaitersCount;

    private AwaiterTaskSource(Task<TResult>? task, CancellationTokenSource cancellationTokenSource)
    {
        // Reject a null task here so the failure is reported at start-up rather than
        // as a NullReferenceException on the first GetAwaiterTask call.
        Task = task ?? throw new InvalidOperationException("The task factory returned a null task.");
        _cancellationTokenSource = cancellationTokenSource;
    }

    /// <summary>The task produced by the factory passed to <c>Run</c>.</summary>
    public Task<TResult> Task { get; }

    /// <summary>
    /// Returns a task that completes with the outcome of <see cref="Task"/>, or is
    /// cancelled when <paramref name="cancellationToken"/> fires first.
    /// </summary>
    /// <param name="cancellationToken">Cancels only this wait, not the underlying task directly.</param>
    /// <returns>The result of <see cref="Task"/>.</returns>
    public async Task<TResult> GetAwaiterTask(CancellationToken cancellationToken = default)
    {
        // Already finished: return (or rethrow) its outcome without counting an awaiter.
        if (Task.IsCompleted)
        {
            return Task.GetAwaiter().GetResult();
        }

        Interlocked.Increment(ref _awaitersCount);
        try
        {
            return await Task.WaitAsync(cancellationToken).ConfigureAwait(false);
        }
        finally
        {
            // The last awaiter to leave while the task is still running cancels the task.
            if (Interlocked.Decrement(ref _awaitersCount) is 0 && !Task.IsCompleted)
            {
                _cancellationTokenSource.Cancel();
            }
        }
    }

    /// <summary>
    /// Invokes <paramref name="taskFactory"/> with a fresh cancellation token and wraps the task it returns.
    /// </summary>
    /// <param name="taskFactory">Creates the task to track; receives the token cancelled when no awaiters remain.</param>
    /// <returns>A source that issues awaiter-tasks for the created task.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="taskFactory"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="taskFactory"/> returned null.</exception>
    public static AwaiterTaskSource<TResult> Run(Func<CancellationToken, Task<TResult>> taskFactory)
    {
        ArgumentNullException.ThrowIfNull(taskFactory);

        var cancellationTokenSource = new CancellationTokenSource();
        try
        {
            return new(taskFactory(cancellationTokenSource.Token), cancellationTokenSource);
        }
        catch
        {
            // The factory failed synchronously: nothing will ever own this source, so release it.
            cancellationTokenSource.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Invokes <paramref name="taskFactory"/> with <paramref name="factoryArgument"/> and a fresh
    /// cancellation token, and wraps the task it returns.
    /// </summary>
    /// <typeparam name="TArg">Type of the state passed to the factory.</typeparam>
    /// <param name="taskFactory">Creates the task to track; receives the token cancelled when no awaiters remain.</param>
    /// <param name="factoryArgument">State forwarded to <paramref name="taskFactory"/>.</param>
    /// <returns>A source that issues awaiter-tasks for the created task.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="taskFactory"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="taskFactory"/> returned null.</exception>
    public static AwaiterTaskSource<TResult> Run<TArg>(
        Func<TArg, CancellationToken, Task<TResult>> taskFactory,
        TArg factoryArgument)
    {
        ArgumentNullException.ThrowIfNull(taskFactory);

        var cancellationTokenSource = new CancellationTokenSource();
        try
        {
            return new(taskFactory(factoryArgument, cancellationTokenSource.Token), cancellationTokenSource);
        }
        catch
        {
            // The factory failed synchronously: nothing will ever own this source, so release it.
            cancellationTokenSource.Dispose();
            throw;
        }
    }
}
#nullable enable
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Keeps at most one tracked task per key. The first caller for a key starts the task
/// through its factory; callers arriving while that entry is still registered receive
/// awaiter-tasks for the same task instead of starting a new one.
/// </summary>
/// <typeparam name="TKey">Type of the key identifying a tracked task.</typeparam>
/// <typeparam name="TResult">Result type of the tracked tasks.</typeparam>
public class OneTaskManyAwaitersService<TKey, TResult> where TKey : notnull
{
    private readonly ConcurrentDictionary<TKey, Lazy<AwaiterTaskSource<TResult>>> _trackers;

    /// <summary>Creates the service.</summary>
    /// <param name="keyComparer">Optional comparer used to match keys.</param>
    public OneTaskManyAwaitersService(IEqualityComparer<TKey>? keyComparer = null)
    {
        _trackers = new ConcurrentDictionary<TKey, Lazy<AwaiterTaskSource<TResult>>>(keyComparer);
    }

    /// <summary>
    /// Starts the task for <paramref name="key"/> if none is registered, then waits for it.
    /// </summary>
    /// <param name="key">Identifies the tracked task.</param>
    /// <param name="taskFactory">Creates the task; only invoked when no entry exists for the key.</param>
    /// <param name="cancellationToken">Cancels this caller's wait.</param>
    /// <returns>The result of the tracked task.</returns>
    public async Task<TResult> RunOrAwait(TKey key,
        Func<CancellationToken, Task<TResult>> taskFactory,
        CancellationToken cancellationToken = default)
    {
        // The Lazy wrapper defers starting the task until .Value is read on the stored entry.
        var lazySource = _trackers.GetOrAdd(key, CreateTracker, (key, taskFactory, _trackers));
        var awaiterTask = lazySource.Value.GetAwaiterTask(cancellationToken);
        return await awaiterTask.ConfigureAwait(false);
    }

    // Builds the lazily started tracker that is stored in the dictionary for a key.
    private static Lazy<AwaiterTaskSource<TResult>> CreateTracker(
        TKey key,
        (TKey Key,
         Func<CancellationToken, Task<TResult>> Factory,
         ConcurrentDictionary<TKey, Lazy<AwaiterTaskSource<TResult>>> Trackers) state)
    {
        return new Lazy<AwaiterTaskSource<TResult>>(
            () => AwaiterTaskSource<TResult>.Run(RunAndUntrack, state));
    }

    // Runs the caller's factory and removes the key's entry once the task has finished,
    // whatever its outcome.
    private static async Task<TResult> RunAndUntrack(
        (TKey Key,
         Func<CancellationToken, Task<TResult>> Factory,
         ConcurrentDictionary<TKey, Lazy<AwaiterTaskSource<TResult>>> Trackers) state,
        CancellationToken cancellationToken)
    {
        try
        {
            return await state.Factory(cancellationToken);
        }
        finally
        {
            state.Trackers.TryRemove(state.Key, out _);
        }
    }
}
#nullable enable
using NUnit.Framework;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Timing-based tests and usage examples for <c>OneTaskManyAwaitersService</c>.
/// Tokens that time out after <c>ShortDelay</c>, <c>MediumDelay</c> or <c>LongDelay</c>
/// are raced against task factories that take one of the same delays to finish.
/// </summary>
[TestFixture, Parallelizable]
public class OneTaskManyAwaitersServiceTests
{
    private const int ShortDelay = 10; // ms
    private const int MediumDelay = 40; // ms
    private const int LongDelay = 70; // ms
    private const int TaskKey = 1;
    private const int TaskResult = 2;

    // Default factory: finishes after MediumDelay unless cancelled first.
    private static async Task<int> TaskFactory(CancellationToken cancellationToken)
    {
        await Task.Delay(MediumDelay, cancellationToken);
        return TaskResult;
    }

    [Test]
    public void OneTaskWithNoTokenSucceeds()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var result = 0;

        Assert.DoesNotThrowAsync(async () => result = await service.RunOrAwait(TaskKey, TaskFactory));

        Assert.AreEqual(TaskResult, result);
    }

    [Test]
    public void OneTaskWithLateTokenSucceeds()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var result = 0;

        Assert.DoesNotThrowAsync(async () =>
        {
            using var lateCts = new CancellationTokenSource(LongDelay);
            result = await service.RunOrAwait(TaskKey, TaskFactory, lateCts.Token);
        });

        Assert.AreEqual(TaskResult, result);
    }

    [Test]
    public void OneTaskWithEarlyTokenThrows()
    {
        var service = new OneTaskManyAwaitersService<int, int>();

        Assert.ThrowsAsync<TaskCanceledException>(async () =>
        {
            using var earlyCts = new CancellationTokenSource(ShortDelay);
            await service.RunOrAwait(TaskKey, TaskFactory, earlyCts.Token);
        });
    }

    [Test]
    public void FirstTaskSucceeds_SecondTaskSucceeds_TaskFactoryCalledOnce()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var taskFactoryCalls = 0;

        async Task<int> LocalTaskFactory(CancellationToken cancellationToken)
        {
            Interlocked.Increment(ref taskFactoryCalls);
            await Task.Delay(ShortDelay, cancellationToken);
            return TaskResult;
        }

        int[]? results = null;

        // Each timer is started right before the call that uses it, as timing matters here.
        using var lateCts = new CancellationTokenSource(LongDelay);
        var firstTask = service.RunOrAwait(TaskKey, LocalTaskFactory, lateCts.Token);
        using var earlyCts = new CancellationTokenSource(MediumDelay);
        var secondTask = service.RunOrAwait(TaskKey, LocalTaskFactory, earlyCts.Token);

        Assert.DoesNotThrowAsync(async () => results = await Task.WhenAll(firstTask, secondTask));

        Assert.IsFalse(earlyCts.IsCancellationRequested);
        Assert.IsFalse(lateCts.IsCancellationRequested);
        CollectionAssert.AreEqual(new[] { TaskResult, TaskResult }, results);
        Assert.AreEqual(1, taskFactoryCalls);
    }

    [Test]
    public void FirstTaskSucceeds_SecondTaskThrows_TaskFactoryCalledOnce()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var taskFactoryCalls = 0;

        async Task<int> LocalTaskFactory(CancellationToken cancellationToken)
        {
            Interlocked.Increment(ref taskFactoryCalls);
            await Task.Delay(MediumDelay, cancellationToken);
            return TaskResult;
        }

        var result = 0;
        Task<int>? firstToComplete = null;

        using var lateCts = new CancellationTokenSource(LongDelay);
        var firstTask = service.RunOrAwait(TaskKey, LocalTaskFactory, lateCts.Token);
        using var earlyCts = new CancellationTokenSource(ShortDelay);
        var secondTask = service.RunOrAwait(TaskKey, LocalTaskFactory, earlyCts.Token);

        Assert.DoesNotThrowAsync(async () => firstToComplete = await Task.WhenAny(firstTask, secondTask));

        Assert.AreEqual(secondTask, firstToComplete);
        Assert.IsFalse(lateCts.IsCancellationRequested);
        Assert.IsFalse(firstTask.IsCompleted);
        Assert.ThrowsAsync<TaskCanceledException>(async () => await secondTask);
        Assert.DoesNotThrowAsync(async () => result = await firstTask);
        Assert.AreEqual(TaskResult, result);
        Assert.AreEqual(1, taskFactoryCalls);
    }

    [Test]
    public void FirstTaskThrows_SecondTaskSucceeds_TaskFactoryCalledOnce()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var taskFactoryCalls = 0;

        async Task<int> LocalTaskFactory(CancellationToken cancellationToken)
        {
            Interlocked.Increment(ref taskFactoryCalls);
            await Task.Delay(MediumDelay, cancellationToken);
            return TaskResult;
        }

        var result = 0;
        Task<int>? firstToComplete = null;

        using var earlyCts = new CancellationTokenSource(ShortDelay);
        var firstTask = service.RunOrAwait(TaskKey, LocalTaskFactory, earlyCts.Token);
        using var lateCts = new CancellationTokenSource(LongDelay);
        var secondTask = service.RunOrAwait(TaskKey, LocalTaskFactory, lateCts.Token);

        Assert.DoesNotThrowAsync(async () => firstToComplete = await Task.WhenAny(firstTask, secondTask));

        Assert.AreEqual(firstTask, firstToComplete);
        Assert.IsFalse(lateCts.IsCancellationRequested);
        Assert.IsFalse(secondTask.IsCompleted);
        Assert.ThrowsAsync<TaskCanceledException>(async () => await firstTask);
        Assert.DoesNotThrowAsync(async () => result = await secondTask);
        Assert.AreEqual(TaskResult, result);
        Assert.AreEqual(1, taskFactoryCalls);
    }

    [Test]
    public void FirstTaskThrows_SecondTaskThrows_OriginalTaskCancellationRequested_TaskFactoryCalledOnce()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var taskFactoryCalls = 0;
        CancellationToken innerCancellationToken = default;

        async Task<int> LocalTaskFactory(CancellationToken cancellationToken)
        {
            Interlocked.Increment(ref taskFactoryCalls);
            // Captured so the test can check that the service cancelled the underlying task.
            innerCancellationToken = cancellationToken;
            await Task.Delay(LongDelay, cancellationToken);
            return TaskResult;
        }

        Task<int>? firstToComplete = null;

        using var earlyCts = new CancellationTokenSource(ShortDelay);
        var firstTask = service.RunOrAwait(TaskKey, LocalTaskFactory, earlyCts.Token);
        using var lateCts = new CancellationTokenSource(MediumDelay);
        var secondTask = service.RunOrAwait(TaskKey, LocalTaskFactory, lateCts.Token);

        Assert.DoesNotThrowAsync(async () => firstToComplete = await Task.WhenAny(firstTask, secondTask));

        Assert.AreEqual(firstTask, firstToComplete);
        Assert.IsFalse(lateCts.IsCancellationRequested);
        Assert.IsFalse(secondTask.IsCompleted);
        Assert.ThrowsAsync<TaskCanceledException>(async () => await firstTask);
        Assert.ThrowsAsync<TaskCanceledException>(async () => await secondTask);
        Assert.IsTrue(innerCancellationToken.IsCancellationRequested);
        Assert.AreEqual(1, taskFactoryCalls);
    }

    [Test]
    public async Task ServiceDoesNotStopTrackingTaskIfOneAwaitingTaskCancelled_TaskFactoryCalledOnce()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var taskFactoryCalls = 0;

        async Task<int> LocalTaskFactory(CancellationToken cancellationToken)
        {
            Interlocked.Increment(ref taskFactoryCalls);
            await Task.Delay(MediumDelay, cancellationToken);
            return TaskResult;
        }

        var firstTask = service.RunOrAwait(TaskKey, LocalTaskFactory, CancellationToken.None);
        await Task.Delay(ShortDelay);
        // An already-cancelled awaiter must not disturb the other awaiters.
        var secondTask = service.RunOrAwait(TaskKey, LocalTaskFactory, new CancellationToken(canceled: true));
        await Task.Delay(ShortDelay);
        var thirdTask = service.RunOrAwait(TaskKey, LocalTaskFactory, CancellationToken.None);

        var firstToComplete = await Task.WhenAny(firstTask, secondTask, thirdTask);

        Assert.AreEqual(secondTask, firstToComplete);
        Assert.ThrowsAsync<TaskCanceledException>(async () => await secondTask);
        Assert.DoesNotThrowAsync(async () => await firstTask);
        Assert.DoesNotThrowAsync(async () => await thirdTask);
        Assert.AreEqual(1, taskFactoryCalls);
    }

    [Test]
    public void ServiceStopsTrackingTaskAfterCompletion()
    {
        var service = new OneTaskManyAwaitersService<int, int>();
        var taskFactoryCalls = 0;

        Task<int> LocalTaskFactory(CancellationToken cancellationToken)
        {
            Interlocked.Increment(ref taskFactoryCalls);
            return Task.FromResult(TaskResult);
        }

        var result1 = 0;
        var result2 = 0;

        Assert.DoesNotThrowAsync(async () =>
        {
            result1 = await service.RunOrAwait(TaskKey, LocalTaskFactory, CancellationToken.None);
            result2 = await service.RunOrAwait(TaskKey, LocalTaskFactory, CancellationToken.None);
        });

        // Sequential calls each start a new task because the first one had already finished.
        Assert.AreEqual(2, taskFactoryCalls);
        Assert.AreEqual(TaskResult, result1);
        Assert.AreEqual(TaskResult, result2);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment