Skip to content

Instantly share code, notes, and snippets.

@radleta
Created November 18, 2019 14:00
Show Gist options
  • Save radleta/7aa90ed03960af29c003d81eb847b588 to your computer and use it in GitHub Desktop.
Save radleta/7aa90ed03960af29c003d81eb847b588 to your computer and use it in GitHub Desktop.
Useful extensions for C#.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace RichardAdleta
{
    /// <summary>
    /// General-purpose extension methods for <see cref="IEnumerable{T}"/> and <see cref="IEnumerable"/>.
    /// </summary>
    public static class IEnumerableExtensions
    {
        /// <summary>
        /// Returns <paramref name="enumerable"/> unchanged, or an empty sequence when it is <c>null</c>.
        /// </summary>
        /// <typeparam name="TItem">The type of the items.</typeparam>
        /// <param name="enumerable">The sequence, possibly <c>null</c>.</param>
        /// <returns>A non-null sequence.</returns>
        public static IEnumerable<TItem> EmptyIfNull<TItem>(this IEnumerable<TItem> enumerable)
        {
            return enumerable ?? Enumerable.Empty<TItem>();
        }

        /// <summary>
        /// Wraps a single item in a lazily evaluated, one-element sequence.
        /// </summary>
        /// <typeparam name="TItem">The type of the item.</typeparam>
        /// <param name="item">The item to wrap (may be <c>null</c>).</param>
        /// <returns>A sequence yielding exactly <paramref name="item"/>.</returns>
        public static IEnumerable<TItem> ToEnumerable<TItem>(this TItem item)
        {
            yield return item;
        }

        /// <summary>
        /// Determines whether the sequence is <c>null</c> or contains no items.
        /// </summary>
        /// <typeparam name="TItem">The type of the items.</typeparam>
        /// <param name="enumerable">The sequence to check.</param>
        /// <returns><c>true</c> when <paramref name="enumerable"/> is <c>null</c> or empty; otherwise, <c>false</c>.</returns>
        public static bool IsNullOrEmpty<TItem>(this IEnumerable<TItem> enumerable)
        {
            // Any() reads at most one item, so this stays cheap for lazy, large or infinite sequences.
            return enumerable == null || !enumerable.Any();
        }

        /// <summary>
        /// Flattens a tree into a depth-first, pre-order sequence: every item is yielded before its descendants,
        /// and siblings keep their original order.
        /// </summary>
        /// <typeparam name="TItem">The type of the tree nodes.</typeparam>
        /// <param name="enumerable">The root items.</param>
        /// <param name="getChildren">Returns the children of an item; a <c>null</c> result is treated as "no children".</param>
        /// <returns>A lazily evaluated sequence of all nodes.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="enumerable"/> or <paramref name="getChildren"/> is <c>null</c>.</exception>
        public static IEnumerable<TItem> Tree<TItem>(this IEnumerable<TItem> enumerable, Func<TItem, IEnumerable<TItem>> getChildren)
        {
            // validate eagerly; checks inside an iterator would only run on first MoveNext
            if (enumerable == null)
            {
                throw new ArgumentNullException(nameof(enumerable));
            }
            if (getChildren == null)
            {
                throw new ArgumentNullException(nameof(getChildren));
            }
            return TreeIterator(enumerable, getChildren);
        }

        // Iterative pre-order walk using a stack of enumerators (avoids nesting one iterator per tree level).
        private static IEnumerable<TItem> TreeIterator<TItem>(IEnumerable<TItem> roots, Func<TItem, IEnumerable<TItem>> getChildren)
        {
            var stack = new Stack<IEnumerator<TItem>>();
            try
            {
                stack.Push(roots.GetEnumerator());
                while (stack.Count > 0)
                {
                    var current = stack.Peek();
                    if (!current.MoveNext())
                    {
                        // this level is exhausted; resume the parent level
                        stack.Pop().Dispose();
                        continue;
                    }
                    var item = current.Current;
                    yield return item;
                    // children are requested lazily, only after the parent has been yielded
                    var children = getChildren(item);
                    if (children != null)
                    {
                        stack.Push(children.GetEnumerator());
                    }
                }
            }
            finally
            {
                // dispose any enumerators still open when the consumer stops early
                while (stack.Count > 0)
                {
                    stack.Pop().Dispose();
                }
            }
        }

        /// <summary>
        /// Splits the sequence into consecutive partitions of <paramref name="partitionSize"/> items; the last partition may be smaller.
        /// </summary>
        /// <typeparam name="T">The type of the items.</typeparam>
        /// <param name="enumerable">The source sequence (enumerated once, lazily).</param>
        /// <param name="partitionSize">The maximum number of items per partition.</param>
        /// <returns>
        /// The partitions, in order. Each partition is materialized into its own list, so partitions can be
        /// enumerated in any order and more than once.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="enumerable"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="partitionSize"/> is less than one.</exception>
        public static IEnumerable<IEnumerable<T>> Partition<T>(this IEnumerable<T> enumerable, int partitionSize)
        {
            if (enumerable is null)
            {
                throw new ArgumentNullException(nameof(enumerable));
            }
            if (partitionSize < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(partitionSize), "Cannot be less than one.");
            }
            return PartitionIterator(enumerable, partitionSize);
        }

        // Buffers items into lists of partitionSize and yields each full (or final partial) list.
        private static IEnumerable<IEnumerable<T>> PartitionIterator<T>(IEnumerable<T> enumerable, int partitionSize)
        {
            List<T> partition = null;
            foreach (var item in enumerable)
            {
                // no up-front capacity: partitionSize may be huge relative to the actual data
                if (partition == null)
                {
                    partition = new List<T>();
                }
                partition.Add(item);
                if (partition.Count == partitionSize)
                {
                    yield return partition;
                    partition = null;
                }
            }
            if (partition != null)
            {
                yield return partition;
            }
        }

        /// <summary>
        /// Concatenates the string form of each item, inserting <paramref name="separator"/> between consecutive items.
        /// </summary>
        /// <typeparam name="TItem">The type of the items.</typeparam>
        /// <param name="strings">The items to join; <c>null</c> items contribute nothing.</param>
        /// <param name="separator">The separator, or <c>null</c> for none.</param>
        /// <returns>A new <see cref="StringBuilder"/> containing the joined text.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="strings"/> is <c>null</c>.</exception>
        public static StringBuilder Join<TItem>(this IEnumerable<TItem> strings, string separator = null)
        {
            if (strings == null)
            {
                throw new ArgumentNullException(nameof(strings));
            }
            var sb = new StringBuilder();
            // track position explicitly: an empty first item must still be followed by a separator
            var isFirst = true;
            foreach (var s in strings)
            {
                if (!isFirst && separator != null)
                {
                    sb.Append(separator);
                }
                isFirst = false;
                sb.Append(s);
            }
            return sb;
        }

        /// <summary>
        /// Determines whether two sequences contain equal items in the same order.
        /// </summary>
        /// <typeparam name="T">The type of the items.</typeparam>
        /// <param name="self">The first sequence.</param>
        /// <param name="other">The second sequence.</param>
        /// <returns><c>true</c> when both sequences have the same length and pairwise-equal items; otherwise, <c>false</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="self"/> or <paramref name="other"/> is <c>null</c>.</exception>
        public static bool Same<T>(this IEnumerable<T> self, IEnumerable<T> other) where T : IEquatable<T>
        {
            if (self == null) throw new ArgumentNullException(nameof(self));
            if (other == null) throw new ArgumentNullException(nameof(other));
            // SequenceEqual uses EqualityComparer<T>.Default (IEquatable<T> here) and handles null elements.
            return self.SequenceEqual(other);
        }

        /// <summary>
        /// Yields items whose key (from <paramref name="keySelector"/>) has not been seen before.
        /// </summary>
        /// <typeparam name="TItem">The type of the items.</typeparam>
        /// <typeparam name="TKey">The type of the key.</typeparam>
        /// <param name="enumerable">The source sequence.</param>
        /// <param name="keySelector">Extracts the key used for duplicate detection.</param>
        /// <param name="allowDuplicateAfter">
        /// When set, an item with an already-seen key is yielded anyway once more than this many
        /// consecutive items have been skipped since the last yielded item. When <c>null</c>, duplicates are never yielded.
        /// </param>
        /// <returns>A lazily evaluated, filtered sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="enumerable"/> or <paramref name="keySelector"/> is <c>null</c>.</exception>
        public static IEnumerable<TItem> Distinct<TItem, TKey>(this IEnumerable<TItem> enumerable, Func<TItem, TKey> keySelector, int? allowDuplicateAfter = null)
        {
            if (enumerable == null)
            {
                throw new ArgumentNullException(nameof(enumerable));
            }
            if (keySelector == null)
            {
                throw new ArgumentNullException(nameof(keySelector));
            }
            return DistinctIterator(enumerable, keySelector, allowDuplicateAfter);
        }

        // Iterator body of Distinct; arguments are already validated.
        private static IEnumerable<TItem> DistinctIterator<TItem, TKey>(IEnumerable<TItem> enumerable, Func<TItem, TKey> keySelector, int? allowDuplicateAfter)
        {
            var existing = new HashSet<TKey>();
            var duplicates = 0;
            foreach (var item in enumerable)
            {
                // HashSet.Add returns false when the key is already present, so it doubles as the "seen before" test
                var isNewKey = existing.Add(keySelector(item));
                // lifted comparison: always false when allowDuplicateAfter is null
                if (isNewKey || allowDuplicateAfter < duplicates)
                {
                    duplicates = 0;
                    yield return item;
                }
                else
                {
                    duplicates++;
                }
            }
        }

        /// <summary>
        /// Lazily invokes <paramref name="map"/> on each item as it is enumerated, yielding the item itself.
        /// </summary>
        /// <typeparam name="TItem">The type of the items.</typeparam>
        /// <param name="enumerable">The source sequence.</param>
        /// <param name="map">The action to run per item; it runs again on every enumeration.</param>
        /// <returns>The original items.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="enumerable"/> or <paramref name="map"/> is <c>null</c>.</exception>
        public static IEnumerable<TItem> Map<TItem>(this IEnumerable<TItem> enumerable, Action<TItem> map)
        {
            if (map == null)
            {
                throw new ArgumentNullException(nameof(map));
            }
            return enumerable.Select(i =>
            {
                map(i);
                return i;
            });
        }

        /// <summary>
        /// Projects each item to its <see cref="object.ToString"/> value (<c>null</c> items become <c>null</c>).
        /// </summary>
        /// <typeparam name="TItem">The type of the items.</typeparam>
        /// <param name="enumerable">The source sequence.</param>
        /// <returns>A lazily evaluated sequence of strings.</returns>
        public static IEnumerable<string> SelectToString<TItem>(this IEnumerable<TItem> enumerable)
        {
            return enumerable.Select(i => i?.ToString());
        }

        /// <summary>
        /// Creates an empty <see cref="List{T}"/> whose capacity matches the size of <c>enumerable</c> when that size is cheaply known.
        /// </summary>
        /// <typeparam name="T">The type of the items in the new list.</typeparam>
        /// <param name="enumerable">The instance whose size is used as the capacity hint.</param>
        /// <returns>
        /// <c>null</c> when <c>enumerable</c> is <c>null</c>; otherwise a new, empty <see cref="List{T}"/> with the
        /// optimal capacity when the count is known, or the default capacity when it is not.
        /// </returns>
        public static List<T> CreateListWithOptimalCapacityOrDefault<T>(this IEnumerable enumerable)
        {
            switch (enumerable)
            {
                // no need to create something when we have nothing
                case null:
                    return null;
                case ICollection collection:
                    return new List<T>(collection.Count);
                // e.g. HashSet<T> does not implement the non-generic ICollection
                case ICollection<T> genericCollection:
                    return new List<T>(genericCollection.Count);
                case IReadOnlyCollection<T> readOnlyCollection:
                    return new List<T>(readOnlyCollection.Count);
                default:
                    // counting would require enumerating, so fall back on default behavior
                    return new List<T>();
            }
        }
    }
}
using Nito.AsyncEx;
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace RichardAdleta
{
    /// <summary>
    /// Extension methods for awaiting, throttling and polling <see cref="Task"/> instances.
    /// </summary>
    public static class TaskExtensions
    {
        /// <summary>
        /// The default concurrency of WhenAllThrottled.
        /// </summary>
        public static int DefaultWhenAllThrottledMaxDegreeOfParallelism { get; set; } = 3;

        /// <summary>
        /// Awaits each task in the <c>tasks</c> with the <see cref="DefaultWhenAllThrottledMaxDegreeOfParallelism"/> concurrency.
        /// </summary>
        /// <param name="tasks">The tasks to await.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Awaitable.</returns>
        public static Task WhenAllThrottledAsync(this IEnumerable<Task> tasks, CancellationToken cancellationToken) => WhenAllThrottledAsync(tasks, DefaultWhenAllThrottledMaxDegreeOfParallelism, cancellationToken);

        /// <summary>
        /// Awaits each task in the <c>tasks</c>, pulling the next task from the sequence only while fewer than
        /// <c>maxDegreeOfParallelism</c> are outstanding. Throttling only has an effect when <c>tasks</c> is lazy
        /// (e.g. a <c>Select</c> that starts work per item); already-started tasks run regardless.
        /// </summary>
        /// <param name="tasks">The tasks to await.</param>
        /// <param name="maxDegreeOfParallelism">The maximum number of tasks to await at any one time.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Awaitable; faults with the first observed task exception.</returns>
        /// <exception cref="ArgumentNullException"><c>tasks</c> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><c>maxDegreeOfParallelism</c> is less than one.</exception>
        public static async Task WhenAllThrottledAsync(this IEnumerable<Task> tasks, int maxDegreeOfParallelism, CancellationToken cancellationToken)
        {
            if (tasks is null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }
            if (maxDegreeOfParallelism < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(maxDegreeOfParallelism), "Cannot be less than one.");
            }
            // when we know how big the collection is and it fits within our parallelism,
            // just do a simple await on all without the throttling logic
            if (tasks is ICollection collection
                && collection.Count <= maxDegreeOfParallelism)
            {
                await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
                cancellationToken.ThrowIfCancellationRequested();
                return;
            }
            await WhenAllThrottledCoreAsync(tasks, maxDegreeOfParallelism, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Awaits each task in the <c>tasks</c> with the <see cref="DefaultWhenAllThrottledMaxDegreeOfParallelism"/> concurrency.
        /// </summary>
        /// <param name="tasks">The tasks to await.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>Awaitable.</returns>
        public static Task<List<TResult>> WhenAllThrottledAsync<TResult>(this IEnumerable<Task<TResult>> tasks, CancellationToken cancellationToken) => WhenAllThrottledAsync(tasks, DefaultWhenAllThrottledMaxDegreeOfParallelism, cancellationToken);

        /// <summary>
        /// Awaits each task in the <c>tasks</c> with the <c>maxDegreeOfParallelism</c> concurrency and collects the results.
        /// </summary>
        /// <param name="tasks">The tasks to await.</param>
        /// <param name="maxDegreeOfParallelism">The maximum number of tasks to await at any one time.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The results, in the same order as <c>tasks</c>; faults with the first observed task exception.</returns>
        /// <exception cref="ArgumentNullException"><c>tasks</c> is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><c>maxDegreeOfParallelism</c> is less than one.</exception>
        public static async Task<List<TResult>> WhenAllThrottledAsync<TResult>(this IEnumerable<Task<TResult>> tasks,
            int maxDegreeOfParallelism,
            CancellationToken cancellationToken)
        {
            if (tasks is null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }
            if (maxDegreeOfParallelism < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(maxDegreeOfParallelism), "Cannot be less than one.");
            }
            // when we know how big the collection is and it fits within our parallelism,
            // just do a simple await on all without the throttling logic
            if (tasks is ICollection collection
                && collection.Count <= maxDegreeOfParallelism)
            {
                var quickResults = await Task.WhenAll(tasks).WaitAsync(cancellationToken).ConfigureAwait(false);
                cancellationToken.ThrowIfCancellationRequested();
                return quickResults.ToList();
            }

            // remember every task in enumeration order so results can be read back positionally
            var ordered = tasks.CreateListWithOptimalCapacityOrDefault<Task<TResult>>();

            // enumerates the source exactly once, recording each task as it is pulled by the throttling core
            IEnumerable<Task> RecordInOrder()
            {
                foreach (var item in tasks)
                {
                    ordered.Add(item);
                    yield return item;
                }
            }

            await WhenAllThrottledCoreAsync(RecordInOrder(), maxDegreeOfParallelism, cancellationToken).ConfigureAwait(false);

            // every recorded task has completed successfully by now (the core rethrows faults),
            // so this finishes synchronously and yields the results in the original order
            var results = await Task.WhenAll(ordered).ConfigureAwait(false);
            return results.ToList();
        }

        // Shared throttling loop: keeps at most maxDegreeOfParallelism tasks outstanding while pulling from `tasks`,
        // and awaits every task it retires so that faults and cancellations are always propagated.
        private static async Task WhenAllThrottledCoreAsync(IEnumerable<Task> tasks, int maxDegreeOfParallelism, CancellationToken cancellationToken)
        {
            var batch = new HashSet<Task>();
            foreach (var task in tasks)
            {
                cancellationToken.ThrowIfCancellationRequested();
                // special optimization for completed tasks (also throws ArgumentNullException for a null task)
                if (task.IsAsyncCompleted())
                {
                    // await the task to ensure any exceptions are thrown here
                    await task.ConfigureAwait(false);
                    continue;
                }
                batch.Add(task);
                // when we still don't have enough, go pull the next task
                if (batch.Count < maxDegreeOfParallelism)
                {
                    continue;
                }
                // the batch is full: wait for any task to finish to free a slot
                var completedTask = await Task.WhenAny(batch).WaitAsync(cancellationToken).ConfigureAwait(false);
                batch.Remove(completedTask);
                // await the finished task so its exception (if any) is not silently dropped
                await completedTask.ConfigureAwait(false);
            }
            cancellationToken.ThrowIfCancellationRequested();
            // wait on the remaining batch
            await Task.WhenAll(batch).WaitAsync(cancellationToken).ConfigureAwait(false);
            cancellationToken.ThrowIfCancellationRequested();
        }

        /// <summary>
        /// Filters the <c>enumerable</c> by the <c>where</c>, evaluating one item at a time.
        /// </summary>
        /// <typeparam name="T">The type contained by the enumerable.</typeparam>
        /// <param name="enumerable">The enumerable.</param>
        /// <param name="where">The function to call per item to evaluate whether it should be in the final result.</param>
        /// <param name="cancellationToken">The cancellation token, checked before each item and passed to <c>where</c>.</param>
        /// <returns>A new instance of <see cref="List{T}"/> with the items from <c>enumerable</c> that evaluated to <c>true</c> when <c>where</c> was called.</returns>
        public static async Task<List<T>> WhereAsync<T>(this IEnumerable<T> enumerable, Func<T, CancellationToken, ValueTask<bool>> where, CancellationToken cancellationToken)
        {
            if (enumerable is null)
            {
                throw new ArgumentNullException(nameof(enumerable));
            }
            if (where is null)
            {
                throw new ArgumentNullException(nameof(where));
            }
            var results = new List<T>();
            foreach (var item in enumerable)
            {
                cancellationToken.ThrowIfCancellationRequested();
                if (await where(item, cancellationToken).ConfigureAwait(false))
                {
                    results.Add(item);
                }
            }
            return results;
        }

        /// <summary>
        /// Determines whether the task has finished, whether it ran to completion, faulted or was canceled.
        /// </summary>
        /// <param name="task">The task to check.</param>
        /// <returns><c>true</c> when the task has finished; otherwise, <c>false</c>.</returns>
        /// <exception cref="ArgumentNullException"><c>task</c> is <c>null</c>.</exception>
        public static bool IsAsyncCompleted(this Task task)
        {
            if (task is null)
            {
                throw new ArgumentNullException(nameof(task));
            }
            // IsCompleted is true exactly for RanToCompletion, Faulted and Canceled
            return task.IsCompleted;
        }

        /// <summary>
        /// Gets or adds a new <see cref="AsyncLockCookie{TValue}"/>; when this call added it, the entry is removed from
        /// the <c>dictionary</c> once the cookie's <c>Completed</c> task finishes.
        /// </summary>
        /// <typeparam name="TKey">The type of the key.</typeparam>
        /// <typeparam name="TValue">The type of the value.</typeparam>
        /// <param name="dictionary">The dictionary.</param>
        /// <param name="key">The key for the item in the dictionary.</param>
        /// <param name="create">The create method passed to the new cookie to produce the <typeparamref name="TValue"/>.</param>
        /// <param name="cancellationToken">Currently unused by this method.</param>
        /// <returns>The cookie stored in the dictionary for <c>key</c>.</returns>
        /// <exception cref="ArgumentNullException"><c>dictionary</c> is <c>null</c>.</exception>
        public static AsyncLockCookie<TValue> GetOrAddThenRemoveOnCompleted<TKey, TValue>(this ConcurrentDictionary<TKey, AsyncLockCookie<TValue>> dictionary, TKey key, Func<CancellationToken, Task<TValue>> create, CancellationToken cancellationToken)
        {
            if (dictionary is null)
            {
                throw new ArgumentNullException(nameof(dictionary));
            }
            // GetOrAdd may run the factory on several racing threads but stores only one result,
            // so decide "we created it" by identity with the stored cookie, not by the factory having run.
            // NOTE(review): assumes AsyncLockCookie<TValue> is a reference type.
            var candidate = default(AsyncLockCookie<TValue>);
            var cookie = dictionary.GetOrAdd(key, k => candidate = new AsyncLockCookie<TValue>(create));
            if (ReferenceEquals(cookie, candidate))
            {
                _ = cookie.Completed
                    .Task
                    .ContinueWith(t =>
                    {
                        // remove only if the key still maps to this cookie, so a newer cookie for the same key is not evicted
                        ((ICollection<KeyValuePair<TKey, AsyncLockCookie<TValue>>>)dictionary)
                            .Remove(new KeyValuePair<TKey, AsyncLockCookie<TValue>>(key, cookie));
                    }, TaskScheduler.Default);
            }
            // return the exclusive access
            return cookie;
        }

        /// <summary>
        /// Invokes <c>action</c> immediately and then again after each <c>delay</c> elapses, until <c>task</c> completes.
        /// </summary>
        /// <typeparam name="TResult">The result type of the task.</typeparam>
        /// <param name="task">The task to wait for.</param>
        /// <param name="delay">The interval between invocations of <c>action</c>.</param>
        /// <param name="action">The callback to run while waiting.</param>
        /// <param name="cancellationToken">Stops waiting (but not the task itself).</param>
        /// <returns>The result of <c>task</c>.</returns>
        /// <exception cref="OperationCanceledException">Cancellation was requested before <c>task</c> completed.</exception>
        public static async Task<TResult> UntilCompletedAsync<TResult>(this Task<TResult> task, TimeSpan delay, Action action, CancellationToken cancellationToken)
        {
            if (task is null)
            {
                throw new ArgumentNullException(nameof(task));
            }
            if (action is null)
            {
                throw new ArgumentNullException(nameof(action));
            }
            await InvokeUntilCompletedAsync(task, delay, action, cancellationToken).ConfigureAwait(false);
            if (!task.IsCompleted)
            {
                cancellationToken.ThrowIfCancellationRequested();
            }
            return await task.ConfigureAwait(false);
        }

        /// <summary>
        /// Invokes <c>action</c> immediately and then again after each <c>delay</c> elapses, until <c>task</c> completes.
        /// </summary>
        /// <param name="task">The task to wait for.</param>
        /// <param name="delay">The interval between invocations of <c>action</c>.</param>
        /// <param name="action">The callback to run while waiting.</param>
        /// <param name="cancellationToken">Stops waiting (but not the task itself).</param>
        /// <returns>Awaitable that completes (or faults) with <c>task</c>.</returns>
        /// <exception cref="OperationCanceledException">Cancellation was requested before <c>task</c> completed.</exception>
        public static async Task UntilCompletedAsync(this Task task, TimeSpan delay, Action action, CancellationToken cancellationToken)
        {
            if (task is null)
            {
                throw new ArgumentNullException(nameof(task));
            }
            if (action is null)
            {
                throw new ArgumentNullException(nameof(action));
            }
            await InvokeUntilCompletedAsync(task, delay, action, cancellationToken).ConfigureAwait(false);
            if (!task.IsCompleted)
            {
                cancellationToken.ThrowIfCancellationRequested();
            }
            // always await so a faulted task's exception is propagated
            await task.ConfigureAwait(false);
        }

        // Polling loop shared by both UntilCompletedAsync overloads; returns when the task completes or cancellation is requested.
        private static async Task InvokeUntilCompletedAsync(Task task, TimeSpan delay, Action action, CancellationToken cancellationToken)
        {
            while (!task.IsCompleted
                && !cancellationToken.IsCancellationRequested)
            {
                action();
                // the delay observes the token so cancellation is noticed without waiting out the full interval
                var finished = await Task.WhenAny(task, Task.Delay(delay, cancellationToken)).ConfigureAwait(false);
                if (finished == task)
                {
                    return;
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment