// Source: GitHub Gist YexuanXiao/2b857a504bc304e9319c96308c56c675
// Description: thread pool (near completion)
// Last active: June 6, 2024 13:17
// Page notice: "Save YexuanXiao/2b857a504bc304e9319c96308c56c675 to your
// computer and use it in GitHub Desktop."
// Page notice: this file contains bidirectional Unicode text that may be
// interpreted or compiled differently than what appears below. To review it,
// open the file in an editor that reveals hidden Unicode characters.
#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <chrono>
#include <compare>
#include <coroutine>
#include <cstddef>
#include <cstdlib>
#include <exception>
#include <mutex>
#include <optional>
#include <semaphore>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
// clang-format off | |
#if !defined(__cpp_size_t_suffix) || __cpp_size_t_suffix < 202011L | |
inline constexpr auto operator""uz(unsigned long long const value) | |
{ | |
return std::size_t{ value }; | |
} | |
inline constexpr auto operator""z(unsigned long long const value) | |
{ | |
return std::ptrdiff_t(value); | |
} | |
#endif | |
// clang-format on | |
#define BIZWEN_THREAD_POOL_DEAD_LOCK_MITIGATE | |
#define BIZWEN_THREAD_POOL_ENABLE_SAFE_CHECK | |
namespace bizwen | |
{ | |
class thread_pool | |
{ | |
class thread_ | |
{ | |
using vector = std::vector<std::coroutine_handle<>>; | |
std::thread t_; | |
std::vector<std::coroutine_handle<>> q_; | |
public: | |
thread_() | |
{ | |
q_.reserve(10uz); | |
} | |
void thread(std::thread t) noexcept | |
{ | |
t_ = std::move(t); | |
} | |
thread_(thread_&&) noexcept = default; | |
thread_& operator=(thread_&&) noexcept = default; | |
operator std::thread::id() const noexcept | |
{ | |
return t_.get_id(); | |
} | |
auto join() noexcept | |
{ | |
return t_.join(); | |
} | |
auto begin() noexcept | |
{ | |
return q_.begin(); | |
} | |
auto end() noexcept | |
{ | |
return q_.end(); | |
} | |
auto size() const noexcept | |
{ | |
return q_.size(); | |
} | |
auto clear() noexcept | |
{ | |
return q_.clear(); | |
} | |
auto resize(std::size_t s) noexcept | |
{ | |
q_.resize(s); | |
} | |
auto push_back(std::coroutine_handle<> h) | |
{ | |
q_.push_back(h); | |
} | |
}; | |
template <typename T> // breakline | |
class priority_queue_: private std::vector<T> | |
{ | |
using vector = std::vector<T>; | |
public: | |
using vector::begin; | |
using vector::empty; | |
using vector::end; | |
using vector::size; | |
using vector::vector; | |
template <class... Args> // breakline | |
void emplace_back(Args&&... args) | |
{ | |
vector::emplace_back(std::forward<Args>(args)...); | |
std::push_heap(vector::begin(), vector::end()); | |
} | |
void pop() noexcept | |
{ | |
std::pop_heap(vector::begin(), vector::end()); | |
vector::pop_back(); | |
} | |
T top() noexcept | |
{ | |
return vector::front(); | |
} | |
}; | |
class task_base_ | |
{ | |
protected: | |
std::coroutine_handle<> handle_; | |
unsigned long long uid_; | |
std::thread::id tid_; | |
public: | |
task_base_(std::coroutine_handle<> handle, unsigned long long id, std::thread::id tid) noexcept : handle_(handle), uid_(id), tid_(tid) | |
{ | |
} | |
void operator()() noexcept | |
{ | |
handle_(); | |
} | |
operator decltype(tid_)() const noexcept | |
{ | |
return tid_; | |
} | |
operator decltype(handle_)() const noexcept | |
{ | |
return handle_; | |
} | |
friend bool operator==(task_base_& t, unsigned long long id) noexcept | |
{ | |
return t.uid_ == id; | |
} | |
}; | |
class normal_task_: public task_base_ | |
{ | |
std::size_t priority_; | |
public: | |
normal_task_(std::coroutine_handle<> handle, unsigned long long id, std::size_t priority, std::thread::id tid) noexcept : task_base_(handle, id, tid), priority_(priority) | |
{ | |
} | |
friend std::weak_ordering operator<=>(normal_task_ const& lhs, normal_task_ const& rhs) noexcept | |
{ | |
if (lhs.priority_ > rhs.priority_) | |
return std::weak_ordering::greater; | |
else if (lhs.priority_ < rhs.priority_) | |
return std::weak_ordering::less; | |
if (lhs.uid_ < rhs.uid_) | |
return std::weak_ordering::greater; | |
else if (lhs.uid_ > rhs.uid_) | |
return std::weak_ordering::less; | |
return std::weak_ordering::equivalent; | |
} | |
}; | |
class lazy_task_: public task_base_ | |
{ | |
std::chrono::steady_clock::time_point time_; | |
public: | |
lazy_task_(std::coroutine_handle<> handle, unsigned long long id, std::chrono::milliseconds duration, std::thread::id tid) noexcept : task_base_(handle, id, tid), time_(std::chrono::steady_clock::now() + duration) | |
{ | |
} | |
operator decltype(time_)() const noexcept | |
{ | |
return time_; | |
} | |
friend std::weak_ordering operator<=>(lazy_task_ const& lhs, lazy_task_ const& rhs) noexcept | |
{ | |
if (lhs.time_ < rhs.time_) | |
return std::weak_ordering::greater; | |
else if (lhs.time_ > rhs.time_) | |
return std::weak_ordering::less; | |
if (lhs.uid_ < rhs.uid_) | |
return std::weak_ordering::greater; | |
else if (lhs.uid_ > rhs.uid_) | |
return std::weak_ordering::less; | |
return std::weak_ordering::equivalent; | |
} | |
}; | |
class normal_callback_ | |
{ | |
thread_pool& pool_; | |
std::size_t pos_; | |
public: | |
normal_callback_(thread_pool& pool, std::size_t pos) noexcept : pool_(pool), pos_(pos) | |
{ | |
} | |
void operator()() noexcept | |
{ | |
pool_.normal_loop_(pos_); | |
} | |
}; | |
class lazy_callback_ | |
{ | |
thread_pool& pool_; | |
public: | |
lazy_callback_(thread_pool& pool) noexcept : pool_(pool) | |
{ | |
} | |
void operator()() noexcept | |
{ | |
pool_.lazy_loop_(); | |
} | |
}; | |
// to ensure exception safety, system_error are not accepted | |
class mutex_ | |
{ | |
std::atomic<bool> s_{}; | |
public: | |
void lock() noexcept | |
{ | |
while (s_.exchange(true, std::memory_order::acquire)) | |
s_.wait(true, std::memory_order::relaxed); | |
} | |
bool try_lock() noexcept | |
{ | |
return !s_.exchange(true, std::memory_order::acquire); | |
} | |
void unlock() noexcept | |
{ | |
s_.store(false, std::memory_order::release); | |
s_.notify_one(); | |
} | |
}; | |
// to ensure exception safety, system_error are not accepted | |
class waiter_ | |
{ | |
std::counting_semaphore<> s_{ 0z }; | |
public: | |
waiter_() noexcept = default; | |
void notify_n(std::ptrdiff_t n) noexcept | |
{ | |
s_.release(n); | |
} | |
void notify_one() noexcept | |
{ | |
s_.release(); | |
} | |
void wait() noexcept | |
{ | |
s_.acquire(); | |
} | |
bool try_wait() noexcept | |
{ | |
return s_.try_acquire(); | |
} | |
template <class Clock, class Duration> // breakline | |
bool try_wait_until(std::chrono::time_point<Clock, Duration> abs_time) noexcept | |
{ | |
return s_.try_acquire_until(abs_time); | |
} | |
template <class Rep, class Period> // breakline | |
bool try_wait_for(std::chrono::duration<Rep, Period> rel_time) noexcept | |
{ | |
return s_.try_acquire_for(rel_time); | |
} | |
}; | |
std::atomic<bool> exit_flag_{}; | |
mutex_ mitigate_mutex_{}; | |
#ifdef BIZWEN_THREAD_POOL_DEAD_LOCK_MITIGATE | |
int thread_count_{}; | |
std::chrono::steady_clock::time_point time_{}; | |
#endif | |
std::atomic<unsigned long long> unique_id_{}; | |
priority_queue_<normal_task_> normals_queue_{}; | |
std::vector<thread_> normals_threads_{}; | |
waiter_ normals_waiter_{}; | |
mutex_ normals_mutex_{}; | |
priority_queue_<lazy_task_> lazys_queue_{}; | |
std::thread lazys_thread_{}; | |
waiter_ lazys_waiter_{}; | |
mutex_ lazys_mutex_{}; | |
#ifdef BIZWEN_THREAD_POOL_DEAD_LOCK_MITIGATE | |
class count_lock_ | |
{ | |
thread_pool& pool_; | |
public: | |
count_lock_(thread_pool& pool) noexcept : pool_(pool) | |
{ | |
++pool_.thread_count_; | |
} | |
~count_lock_() | |
{ | |
--pool_.thread_count_; | |
pool_.time_ = std::chrono::steady_clock::now(); | |
} | |
}; | |
#endif | |
void increase_thread_() | |
{ | |
#ifdef BIZWEN_THREAD_POOL_DEAD_LOCK_MITIGATE | |
if (!mitigate_mutex_.try_lock()) [[unlikely]] | |
return; | |
#if defined(_MSC_VER) && !defined(__clang__) | |
#pragma warning(push) | |
#pragma warning(disable : 26110) | |
#endif | |
std::unique_lock lock{ mitigate_mutex_, std::adopt_lock }; | |
#if defined(_MSC_VER) && !defined(__clang__) | |
#pragma warning(pop) | |
#endif | |
if (auto size{ normals_threads_.size() }; int((size == thread_count_) & (time_ + std::chrono::milliseconds{ 16 } <= std::chrono::steady_clock::now()))) [[unlikely]] | |
{ | |
normals_threads_.resize(size + 1uz); | |
normals_threads_[size].thread(std::thread(normal_callback_{ *this, size })); | |
} | |
#endif | |
} | |
void set_exit_() noexcept | |
{ | |
exit_flag_.store(true, std::memory_order::relaxed); | |
} | |
bool is_exit_() const noexcept | |
{ | |
return exit_flag_.load(std::memory_order::relaxed); | |
} | |
void run_in_(std::coroutine_handle<> handle, std::thread::id id) noexcept | |
{ | |
assert(id != decltype(id){}); | |
std::lock_guard lock{ mitigate_mutex_ }; | |
for (auto& i : normals_threads_) [[likely]] | |
{ | |
if (i == id) [[unlikely]] | |
{ | |
i.push_back(handle); | |
return; | |
} | |
} | |
std::abort(); | |
} | |
std::size_t consume_pre_thread_queue(std::size_t pos) | |
{ | |
while (true) | |
{ | |
// 锁定所有线程和队列防止修改 | |
std::unique_lock lock{ mitigate_mutex_ }; | |
// 上锁之后设置线程数量,为了之后通知用 | |
// 获得当前线程的队列,注意可能由于线程数组被扩容而导致失效,因此需要通过pos访问 | |
auto& src{ normals_threads_[pos] }; | |
// 取最多取10个任务存入tasks,优化size==0的情况 | |
auto size{ src.size() }; | |
if (!size) [[likely]] | |
return normals_threads_.size(); | |
// 每次最多取10个 | |
std::array<std::coroutine_handle<>, 10uz> tasks{}; | |
auto begin{ src.begin() }; | |
if (size < 10uz) [[likely]] | |
{ | |
std::copy(begin, begin + size, tasks.begin()); | |
src.clear(); | |
} | |
else | |
{ | |
std::copy(begin, begin + 10uz, tasks.begin()); | |
std::copy(begin + 10uz, begin + size, begin); | |
src.resize(size - 10uz); | |
size = 10uz; | |
} | |
lock.unlock(); | |
for (auto begin{ tasks.begin() }, end{ begin + size }; begin != end; ++begin) [[unlikely]] | |
(*begin)(); | |
} | |
} | |
void normal_loop_(std::size_t pos) | |
{ | |
auto tid{ std::this_thread::get_id() }; | |
while (true) | |
{ | |
// 首先执行该线程的队列中的任务,并获得当前线程数 | |
std::size_t threads{ consume_pre_thread_queue(pos) }; | |
// 为了使得线程池有能够被析构的可能,不能使用无限阻塞的等待 | |
// 每秒至少取消等待一次,以提供检查退出标志的机会 | |
// 初始状态为阻塞,因为没有元素 | |
// | |
// 通知器是自通知的,也就是说在执行任务之前所有线程通过通知器串行执行 | |
// 当首次出现任务时唤醒第一个通知器 | |
// 当存在下一个任务时唤醒第二个通知器 | |
// 尝试等待有任务可用或者1s | |
if (!normals_waiter_.try_wait_for(std::chrono::seconds{ 1 })) | |
{ | |
// 如果等待1s仍无任务 | |
// 检查退出标志 | |
if (is_exit_()) [[unlikely]] | |
break; | |
// 如果不退出,则进入下一轮 | |
else | |
continue; | |
} | |
// 保护插入 | |
std::unique_lock lock{ normals_mutex_ }; | |
auto size{ normals_queue_.size() }; | |
// 由于可能会被指定任务时多唤醒,因此不保证容器内有元素 | |
if (!size) [[unlikely]] | |
{ | |
if (is_exit_()) [[unlikely]] | |
break; | |
continue; | |
} | |
auto task{ normals_queue_.top() }; | |
// 测试线程是否符合,如果不符合,释放锁,通知并则进入下一轮 | |
if (std::thread::id task_tid{ task }; int(int(task_tid != tid) & (task_tid != std::thread::id{}))) [[unlikely]] | |
{ | |
run_in_(task, task_tid); | |
normals_queue_.pop(); | |
normals_waiter_.notify_n(threads); | |
lock.unlock(); | |
continue; | |
} | |
// 此时所有条件都满足,执行任务 | |
// 弹出并解锁 | |
normals_queue_.pop(); | |
// 如果还有剩余任务则通知 | |
if (size >> 1) | |
normals_waiter_.notify_one(); | |
lock.unlock(); | |
#ifdef BIZWEN_THREAD_POOL_DEAD_LOCK_MITIGATE | |
// 记录执行中的任务数 | |
count_lock_ c_lock{ *this }; | |
#endif | |
// 执行任务 | |
task(); | |
} | |
} | |
// 如果size=0则flag为0,否则为1 | |
void lazy_loop_() | |
{ | |
while (true) | |
{ | |
// lazys是单一消费者,不存在多个线程同时消费的情况 | |
// 尝试等待有任务可用或者1s | |
if (!lazys_waiter_.try_wait_for(std::chrono::seconds{ 1 })) | |
{ | |
// 如果等待1s仍无任务 | |
// 检查退出标志 | |
if (is_exit_()) [[unlikely]] | |
break; | |
// 如果不退出,则进入下一轮 | |
else | |
continue; | |
} | |
// 保护插入 | |
std::unique_lock lock{ lazys_mutex_ }; | |
auto size{ lazys_queue_.size() }; | |
if (!size) [[unlikely]] | |
{ | |
if (is_exit_()) [[unlikely]] | |
break; | |
continue; | |
} | |
auto task{ lazys_queue_.top() }; | |
std::chrono::steady_clock::time_point time{ task }; | |
// 如果到时间,则弹出推入normals并进入下一轮 | |
if (time < decltype(time)::clock::now()) [[unlikely]] | |
{ | |
lazys_queue_.pop(); | |
// 通知并进入下一轮 | |
if (size >> 1) | |
lazys_waiter_.notify_one(); | |
lock.unlock(); | |
run_once(task, std::size_t(-1)); | |
continue; | |
} | |
// 如果未到时间,则等待到时间或被通知 | |
// 如果成功获得,则进入下一轮,这种情况只在新插入元素的情况下才发生 | |
if (lazys_waiter_.try_wait_until(time)) [[unlikely]] | |
{ | |
// 注意,此时已经消费掉了这次通知,因此需要重新补充 | |
lazys_waiter_.notify_one(); | |
continue; | |
} | |
lazys_queue_.pop(); | |
if (size >> 1) | |
lazys_waiter_.notify_one(); | |
lock.unlock(); | |
// 如果成功等待,则推入normals并进入下一轮 | |
run_once(task, std::size_t(-1)); | |
} | |
} | |
auto gen_id_() noexcept | |
{ | |
return unique_id_.fetch_add(1, std::memory_order::relaxed); | |
} | |
class context_base_ | |
{ | |
protected: | |
std::thread::id tid_; | |
context_base_() noexcept = default; | |
context_base_(std::thread::id tid) noexcept : tid_(tid) | |
{ | |
} | |
}; | |
public: | |
enum class id : unsigned long long | |
{ | |
}; | |
class context: private context_base_ | |
{ | |
friend thread_pool; | |
std::thread::id id() const noexcept | |
{ | |
return tid_; | |
} | |
context(std::thread::id tid) noexcept : context_base_(tid) | |
{ | |
} | |
public: | |
context() noexcept = default; | |
context(context const& ctx) noexcept : context_base_(ctx.tid_) | |
{ | |
} | |
bool operator==(const context& c) const noexcept | |
{ | |
return c.tid_ == tid_; | |
} | |
context& operator=(context const& ctx) noexcept | |
{ | |
tid_ = ctx.tid_; | |
return *this; | |
} | |
}; | |
// must be called from a thread owned by the thread pool | |
static context capture_context() noexcept | |
{ | |
return context{ std::this_thread::get_id() }; | |
} | |
private: | |
void check_context_legal_(context ctx) noexcept | |
{ | |
#ifdef BIZWEN_THREAD_POOL_ENABLE_SAFE_CHECK | |
auto id{ ctx.id() }; | |
if (id == std::thread::id{}) [[likely]] | |
return; | |
std::lock_guard lock{ mitigate_mutex_ }; | |
for (auto const& i : normals_threads_) [[likely]] | |
if (id == i) [[unlikely]] | |
return; | |
std::abort(); | |
#endif | |
} | |
static std::size_t cacl_thread_num(std::size_t num) noexcept | |
{ | |
auto n{ std::thread::hardware_concurrency() }; | |
n = n ? n : 2; | |
num = num > 1 ? num : n; | |
return num; | |
} | |
public: | |
id run_once(std::coroutine_handle<> callback, std::size_t priority = 0uz, context ctx = {}) | |
{ | |
check_context_legal_(ctx); | |
std::unique_lock lock{ normals_mutex_ }; | |
if (!normals_threads_.empty()) | |
normals_waiter_.notify_one(); | |
auto n{ gen_id_() }; | |
normals_queue_.emplace_back(callback, n, priority, ctx.id()); | |
lock.unlock(); | |
increase_thread_(); | |
return id{ n }; | |
} | |
id run_after(std::coroutine_handle<> callback, std::chrono::milliseconds duration, context ctx = {}) | |
{ | |
check_context_legal_(ctx); | |
// 无条件通知,这样才能打断等待中的线程 | |
lazys_waiter_.notify_one(); | |
std::lock_guard lock{ lazys_mutex_ }; | |
lazys_waiter_.notify_one(); | |
auto n{ gen_id_() }; | |
lazys_queue_.emplace_back(std::move(callback), n, duration, ctx.id()); | |
return id{ n }; | |
} | |
thread_pool(std::size_t num = 0uz) : normals_threads_(cacl_thread_num(num)), lazys_thread_(lazy_callback_{ *this }) | |
{ | |
for (std::size_t i{}, n{ normals_threads_.size() }; i != n; ++i) [[likely]] | |
normals_threads_[i].thread(std::thread(normal_callback_{ *this, i })); | |
} | |
~thread_pool() | |
{ | |
set_exit_(); | |
for (auto& i : normals_threads_) [[likely]] | |
i.join(); | |
lazys_thread_.join(); | |
} | |
}; | |
} | |
// Process-wide pool used by the awaiters below. `inline` (C++17) gives it a
// single definition even when this code is included in several translation
// units; a plain global here would violate the ODR.
inline bizwen::thread_pool pool{};
namespace bizwen | |
{ | |
class canceled_coroutine | |
{ | |
}; | |
class waiter_ | |
{ | |
std::binary_semaphore s_{ 0z }; | |
public: | |
waiter_() noexcept = default; | |
void notify_one() noexcept | |
{ | |
s_.release(); | |
} | |
void wait() noexcept | |
{ | |
s_.acquire(); | |
} | |
bool try_wait() noexcept | |
{ | |
return s_.try_acquire(); | |
} | |
template <class Clock, class Duration> // breakline | |
bool try_wait_until(std::chrono::time_point<Clock, Duration> abs_time) noexcept | |
{ | |
return s_.try_acquire_until(abs_time); | |
} | |
template <class Rep, class Period> // breakline | |
bool try_wait_for(std::chrono::duration<Rep, Period> rel_time) noexcept | |
{ | |
return s_.try_acquire_for(rel_time); | |
} | |
}; | |
class cancelable_promise_base: public waiter_ | |
{ | |
std::atomic<bool> canceled_{}; | |
protected: | |
std::exception_ptr e_{}; | |
public: | |
void cancel() noexcept | |
{ | |
canceled_.store(true, std::memory_order::relaxed); | |
} | |
bool canceled() noexcept | |
{ | |
return canceled_.load(std::memory_order::relaxed); | |
} | |
}; | |
template <typename T> // breakline | |
static std::coroutine_handle<cancelable_promise_base> to_cancelable(std::coroutine_handle<T> t) | |
requires std::is_convertible_v<T&, cancelable_promise_base&> | |
{ | |
return std::coroutine_handle<cancelable_promise_base>::from_address(t.address()); | |
} | |
static std::coroutine_handle<cancelable_promise_base> to_cancelable(std::coroutine_handle<> t) | |
{ | |
return std::coroutine_handle<cancelable_promise_base>::from_address(t.address()); | |
} | |
class cancelable_awaiter_base: public std::suspend_always | |
{ | |
std::coroutine_handle<cancelable_promise_base> h_; | |
protected: | |
void set_handle(std::coroutine_handle<> h) noexcept | |
{ | |
h_ = to_cancelable(h); | |
} | |
void throw_if_canceled() | |
{ | |
if (h_.promise().canceled()) | |
throw canceled_coroutine{}; | |
} | |
}; | |
class progress_promise_base: public cancelable_promise_base | |
{ | |
protected: | |
std::size_t progress{}; | |
}; | |
auto resume_background() | |
{ | |
struct background_awaiter: cancelable_awaiter_base | |
{ | |
void await_suspend(std::coroutine_handle<> handle) | |
{ | |
set_handle(handle); | |
pool.run_once(handle); | |
} | |
void await_resume() | |
{ | |
throw_if_canceled(); | |
} | |
}; | |
return background_awaiter{}; | |
} | |
auto operator co_await(thread_pool::context c) | |
{ | |
struct apartment_awaiter: cancelable_awaiter_base | |
{ | |
thread_pool::context c; | |
bool await_ready() const noexcept | |
{ | |
return thread_pool::capture_context() == c; | |
} | |
void await_suspend(std::coroutine_handle<> handle) | |
{ | |
set_handle(handle); | |
pool.run_once(handle, 0uz, c); | |
} | |
void await_resume() | |
{ | |
throw_if_canceled(); | |
} | |
}; | |
return apartment_awaiter{ .c = c }; | |
} | |
template <class Rep, class Period> // breakline | |
auto operator co_await(std::chrono::duration<Rep, Period> d) | |
{ | |
struct timer_awaiter: cancelable_awaiter_base | |
{ | |
std::chrono::milliseconds d_; | |
bool await_ready() const noexcept | |
{ | |
return d_ <= decltype(d_)::zero(); | |
} | |
void await_suspend(std::coroutine_handle<> handle) | |
{ | |
set_handle(handle); | |
pool.run_after(handle, d_); | |
} | |
void await_resume() | |
{ | |
throw_if_canceled(); | |
} | |
}; | |
return timer_awaiter{ .d_ = std::chrono::duration_cast<std::chrono::milliseconds>(d) }; | |
} | |
// like C++/WinRT promise_base::final_suspend_awaiter | |
class final_suspend_awaiter: public std::suspend_always | |
{ | |
cancelable_promise_base& p_; | |
public: | |
final_suspend_awaiter(cancelable_promise_base& p) noexcept : p_(p) | |
{ | |
} | |
void await_suspend(std::coroutine_handle<> handle) noexcept | |
{ | |
p_.notify_one(); | |
} | |
}; | |
template <typename T = void> // breakline | |
class task | |
{ | |
public: | |
class promise_type; | |
private: | |
promise_type& p_; | |
friend promise_type; | |
task(promise_type& p) noexcept : p_(p) | |
{ | |
} | |
public: | |
class promise_type: public cancelable_promise_base | |
{ | |
friend task; | |
T t_{}; | |
public: | |
promise_type() | |
{ | |
} | |
task get_return_object() | |
{ | |
return { *this }; | |
} | |
std::suspend_never initial_suspend() | |
{ | |
return {}; | |
} | |
final_suspend_awaiter final_suspend() noexcept | |
{ | |
return { *this }; | |
} | |
void return_value(T&& t) | |
{ | |
t_ = std::move(t); | |
} | |
void unhandled_exception() | |
{ | |
e_ = std::current_exception(); | |
} | |
}; | |
void cancel() noexcept | |
{ | |
p_.cancel(); | |
} | |
T get() | |
{ | |
p_.wait(); | |
std::exception_ptr e{ p_.e_ }; | |
T t{ std::move(p_.t_) }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
return std::move(t); | |
} | |
// like P2300 this_thread::sync_wait | |
std::optional<T> try_get() | |
{ | |
if (p_.try_wait()) | |
{ | |
std::exception_ptr e{ p_.e_ }; | |
T t{ std::move(p_.t_) }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
else | |
return std::move(t); | |
} | |
return std::nullopt; | |
} | |
template <class Clock, class Duration> // breakline | |
std::optional<T> try_get_until(std::chrono::time_point<Clock, Duration> abs_time) noexcept | |
{ | |
if (p_.try_wait_until(abs_time)) | |
{ | |
std::exception_ptr e{ p_.e_ }; | |
T t{ std::move(p_.t_) }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
else | |
return std::move(t); | |
} | |
return std::nullopt; | |
} | |
template <class Rep, class Period> // breakline | |
std::optional<T> try_get_for(std::chrono::duration<Rep, Period> rel_time) noexcept | |
{ | |
if (p_.try_wait_for(rel_time)) | |
{ | |
std::exception_ptr e{ p_.e_ }; | |
T t{ std::move(p_.t_) }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
else | |
return std::move(t); | |
} | |
return std::nullopt; | |
} | |
~task() | |
{ | |
p_.wait(); | |
std::coroutine_handle<promise_type>::from_promise(p_).destroy(); | |
} | |
auto operator co_await() | |
{ | |
struct a: public std::suspend_always | |
{ | |
promise_type& pr_; | |
auto await_suspend(std::coroutine_handle<> h) | |
{ | |
return h; | |
} | |
T await_resume() | |
{ | |
auto& e{ pr_.e_ }; | |
if (e) | |
std::rethrow_exception(e); | |
else | |
return std::move(pr_.t_); | |
} | |
}; | |
return a{ .pr_ = p_ }; | |
} | |
}; | |
template <> // breakline | |
class task<void> | |
{ | |
public: | |
class promise_type; | |
private: | |
promise_type& p_; | |
friend promise_type; | |
task(promise_type& p) noexcept : p_(p) | |
{ | |
} | |
public: | |
class promise_type: public cancelable_promise_base | |
{ | |
friend task; | |
public: | |
promise_type() noexcept | |
{ | |
} | |
task get_return_object() | |
{ | |
return { *this }; | |
} | |
std::suspend_never initial_suspend() | |
{ | |
return {}; | |
} | |
final_suspend_awaiter final_suspend() noexcept | |
{ | |
return { *this }; | |
} | |
void return_void() | |
{ | |
} | |
void unhandled_exception() | |
{ | |
e_ = std::current_exception(); | |
} | |
}; | |
void cancel() noexcept | |
{ | |
p_.cancel(); | |
} | |
void get() | |
{ | |
p_.wait(); | |
std::exception_ptr e{ p_.e_ }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
} | |
bool try_get() | |
{ | |
if (p_.try_wait()) | |
{ | |
std::exception_ptr e{ p_.e_ }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
return true; | |
} | |
return false; | |
} | |
template <class Clock, class Duration> // breakline | |
bool try_get_until(std::chrono::time_point<Clock, Duration> abs_time) noexcept | |
{ | |
if (p_.try_wait_until(abs_time)) | |
{ | |
std::exception_ptr e{ p_.e_ }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
return true; | |
} | |
return false; | |
} | |
template <class Rep, class Period> // breakline | |
bool try_get_for(std::chrono::duration<Rep, Period> rel_time) noexcept | |
{ | |
if (p_.try_wait_for(rel_time)) | |
{ | |
std::exception_ptr e{ p_.e_ }; | |
p_.notify_one(); | |
if (e) | |
std::rethrow_exception(e); | |
return true; | |
} | |
return false; | |
} | |
~task() | |
{ | |
p_.wait(); | |
std::coroutine_handle<promise_type>::from_promise(p_).destroy(); | |
} | |
auto operator co_await() | |
{ | |
struct a: public std::suspend_always | |
{ | |
promise_type& pr_; | |
auto await_suspend(std::coroutine_handle<> h) noexcept | |
{ | |
return h; | |
} | |
void await_resume() | |
{ | |
auto& e{ pr_.e_ }; | |
if (e) | |
std::rethrow_exception(e); | |
} | |
}; | |
return a{ .pr_ = p_ }; | |
} | |
}; | |
struct cancellation_token: std::suspend_always | |
{ | |
private: | |
cancelable_promise_base* c_; | |
public: | |
void await_suspend(std::coroutine_handle<> h) noexcept | |
{ | |
c_ = &to_cancelable(h).promise(); | |
} | |
auto await_resume() noexcept | |
{ | |
class token | |
{ | |
friend cancellation_token; | |
cancelable_promise_base& c_; | |
token(cancelable_promise_base& c) noexcept : c_(c) | |
{ | |
} | |
public: | |
bool canceled() | |
{ | |
return c_.canceled(); | |
} | |
operator bool() | |
{ | |
return c_.canceled(); | |
} | |
}; | |
return token{ *c_ }; | |
} | |
}; | |
class fire_and_forget | |
{ | |
class promise_type | |
{ | |
public: | |
promise_type() noexcept | |
{ | |
} | |
fire_and_forget get_return_object() | |
{ | |
return {}; | |
} | |
std::suspend_never initial_suspend() | |
{ | |
return {}; | |
} | |
std::suspend_never final_suspend() noexcept | |
{ | |
return {}; | |
} | |
void return_void() | |
{ | |
} | |
void unhandled_exception() | |
{ | |
} | |
}; | |
auto operator co_await() | |
{ | |
struct a: public std::suspend_always | |
{ | |
// task 转换成的 awaiter 的 await suspend 只需要把 handle 返回 | |
// 而有实际作用的 awaiter 的 await suspend 没handle 可以返回 | |
// 从而只能返回 bool 或者 void | |
auto await_suspend(std::coroutine_handle<> h) noexcept | |
{ | |
return h; | |
} | |
}; | |
return a{}; | |
} | |
}; | |
} // namespace bizwen |
// (End of the Gist page. GitHub footer: "Sign up for free to join this
// conversation on GitHub. Already have an account? Sign in to comment.")