#pragma once #ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED #define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1 #include #include #include #include #include #include #include #include "./future.hpp" #include "./message_queue.hpp" #include "../container/optional.hpp" #include "../internal/common.hpp" #include "../memory/memutil.hpp" #include "../util/flag.hpp" #include "../util/iterators.hpp" #include "../util/misc.hpp" #include "../util/scope_guard.hpp" #include "../util/traits.hpp" #if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO) # define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0 // Capture stack each time a coroutine is started. Warning, expensive! // TODO: maybe implement a lighter version only storing the return address? #endif #if !defined(MIJIN_COROUTINE_ENABLE_EXCEPTIONS) # define MIJIN_COROUTINE_ENABLE_EXCEPTIONS 0 #endif #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO #include "../debug/stacktrace.hpp" #endif #if !defined(MIJIN_COROUTINE_ENABLE_EXCEPTION_HANDLING) # define MIJIN_COROUTINE_ENABLE_EXCEPTION_HANDLING MIJIN_ENABLE_EXCEPTIONS #elif !__cpp_exceptions # error "Coroutine exception handling enabled, but exceptions are disabled." #endif #if !defined(MIJIN_COROUTINE_ENABLE_CANCEL) # define MIJIN_COROUTINE_ENABLE_CANCEL MIJIN_ENABLE_EXCEPTIONS #elif !__cpp_exceptions # error "Cancelling tasks requires exceptions to be anbled." 
#endif namespace mijin { // // public defines // // // public types // enum class TaskStatus { SUSPENDED = 0, RUNNING = 1, WAITING = 2, FINISHED = 3, YIELDED = 4 }; // forward declarations template struct TaskState; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class TaskLoop; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class TaskBase; #if MIJIN_COROUTINE_ENABLE_CANCEL struct TaskCancelled : std::exception {}; #endif namespace impl { inline void throwIfCancelled(); } // namespace impl class TaskHandle { private: std::weak_ptr state_; public: TaskHandle() = default; explicit TaskHandle(std::weak_ptr state) MIJIN_NOEXCEPT : state_(std::move(state)) {} TaskHandle(const TaskHandle&) = default; TaskHandle(TaskHandle&&) = default; TaskHandle& operator=(const TaskHandle&) = default; TaskHandle& operator=(TaskHandle&&) = default; bool operator==(const TaskHandle& other) const MIJIN_NOEXCEPT { return !state_.owner_before(other.state_) && !other.state_.owner_before(state_); } bool operator!=(const TaskHandle& other) const MIJIN_NOEXCEPT { return !(*this == other); } [[nodiscard]] bool isValid() const MIJIN_NOEXCEPT { return !state_.expired(); } inline void cancel() const MIJIN_NOEXCEPT; [[nodiscard]] inline Optional getLocation() const MIJIN_NOEXCEPT; #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO inline Optional getCreationStack() const MIJIN_NOEXCEPT; #endif }; struct TaskSharedState { std::atomic_bool cancelled_ = false; TaskHandle subTask; std::source_location sourceLoc; #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO Stacktrace creationStack_; #endif }; template struct TaskState { Optional value; std::exception_ptr exception; TaskStatus status = TaskStatus::SUSPENDED; TaskState() = default; TaskState(const TaskState&) = default; TaskState(TaskState&&) MIJIN_NOEXCEPT = default; inline TaskState(T _value, TaskStatus _status) MIJIN_NOEXCEPT : value(std::move(_value)), status(_status) {} inline TaskState(std::exception_ptr _exception) MIJIN_NOEXCEPT : 
exception(std::move(_exception)), status(TaskStatus::FINISHED) {} TaskState& operator=(const TaskState&) = default; TaskState& operator=(TaskState&&) MIJIN_NOEXCEPT = default; }; template<> struct TaskState { std::exception_ptr exception; TaskStatus status = TaskStatus::SUSPENDED; TaskState() = default; TaskState(const TaskState&) = default; TaskState(TaskState&&) MIJIN_NOEXCEPT = default; inline TaskState(TaskStatus _status) MIJIN_NOEXCEPT : status(_status) {} inline TaskState(std::exception_ptr _exception) MIJIN_NOEXCEPT : exception(std::move(_exception)), status(TaskStatus::FINISHED) {} TaskState& operator=(const TaskState&) = default; TaskState& operator=(TaskState&&) MIJIN_NOEXCEPT = default; }; namespace impl { template struct TaskReturn { template constexpr void return_value(TArgs&&... args) MIJIN_NOEXCEPT { (static_cast(*this).state_) = TaskState(TReturn(std::forward(args)...), TaskStatus::FINISHED); } constexpr void return_value(TReturn value) MIJIN_NOEXCEPT { (static_cast(*this).state_) = TaskState(TReturn(std::move(value)), TaskStatus::FINISHED); } constexpr void unhandled_exception() MIJIN_NOEXCEPT { (static_cast(*this).state_) = TaskState(std::current_exception()); } }; template struct TaskReturn { constexpr void return_void() MIJIN_NOEXCEPT { static_cast(*this).state_.status = TaskStatus::FINISHED; } constexpr void unhandled_exception() MIJIN_NOEXCEPT { (static_cast(*this).state_) = TaskState(std::current_exception()); } }; } template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> using TaskFuture = Future; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> using TaskFuturePtr = FuturePtr; template typename TAllocator> struct TaskAwaitableFuture { TaskFuturePtr future; [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return future->ready(); } constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {} constexpr TValue await_resume() const { impl::throwIfCancelled(); if constexpr (std::is_same_v) { return; } else 
{ return std::move(future->get()); } } }; template struct TaskAwaitableSignal { std::shared_ptr> data; [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; } constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {} inline auto& await_resume() const { impl::throwIfCancelled(); return *data; } }; template struct TaskAwaitableSignal { std::shared_ptr data; [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; } constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {} constexpr auto& await_resume() const { impl::throwIfCancelled(); return *data; } }; template<> struct TaskAwaitableSignal<> { [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; } constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {} inline void await_resume() const { impl::throwIfCancelled(); } }; struct TaskAwaitableSuspend { [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; } constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {} inline void await_resume() const { impl::throwIfCancelled(); } }; namespace impl { template using default_is_valid = T::default_is_valid_t; } template typename TAllocator> struct TaskAllocatorTraits { static constexpr bool default_is_valid_v = detect_or_t>::value; template static TAllocator create() { auto taskLoop = TaskLoop::currentOpt(); if (taskLoop != nullptr) { return TAllocator(taskLoop->getAllocator()); } if constexpr (std::is_default_constructible_v>) { return TAllocator(); } else { MIJIN_FATAL("Could not create task allocator."); } } }; template<> struct TaskAllocatorTraits { static constexpr bool default_is_valid_v = true; template static std::allocator create() noexcept { return std::allocator(); } }; template typename TAllocator, typename T> TAllocator makeTaskAllocator() { return TaskAllocatorTraits::template create(); } template typename TAllocator = 
MIJIN_DEFAULT_ALLOCATOR> struct TaskPromise : impl::TaskReturn> { using handle_t = std::coroutine_handle; using task_t = typename TTraits::task_t; using result_t = typename TTraits::result_t; [[no_unique_address]] TAllocator allocator_; TaskState state_; std::shared_ptr sharedState_; TaskLoop* loop_ = nullptr; explicit TaskPromise(TAllocator allocator = makeTaskAllocator()) MIJIN_NOEXCEPT_IF(std::is_nothrow_move_constructible_v>) : allocator_(std::move(allocator)), sharedState_(std::allocate_shared(TAllocator(allocator_))) {} constexpr task_t get_return_object() MIJIN_NOEXCEPT { return task_t(handle_t::from_promise(*this)); } constexpr TaskAwaitableSuspend initial_suspend() MIJIN_NOEXCEPT { return {}; } constexpr std::suspend_always final_suspend() noexcept { return {}; } // note: this must always be noexcept, no matter what // template // constexpr std::suspend_always yield_value(TValue value) MIJIN_NOEXCEPT { // *state_ = TaskState(std::move(value), TaskStatus::YIELDED); // return {}; // } // TODO: implement yielding (can't use futures for this) // constexpr void unhandled_exception() MIJIN_NOEXCEPT {} template typename TAllocator2> auto await_transform(TaskFuturePtr future, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { MIJIN_ASSERT(loop_ != nullptr, "Cannot await future outside of a loop!"); sharedState_->sourceLoc = std::move(sourceLoc); TaskAwaitableFuture awaitable{future}; if (!awaitable.await_ready()) { state_.status = TaskStatus::WAITING; future->sigSet.connect([this, future]() mutable { state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); } return awaitable; } template auto await_transform(TaskBase task, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!"); // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) auto future = delayEvaluation(loop_)->addTaskImpl(std::move(task), &sharedState_->subTask); 
// hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here return await_transform(future, std::move(sourceLoc)); } template auto await_transform(Signal& signal, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { auto data = std::make_shared>(); sharedState_->sourceLoc = std::move(sourceLoc); signal.connect([this, data](TFirstArg arg0, TSecondArg arg1, TArgs... args) mutable { *data = std::make_tuple(std::move(arg0), std::move(arg1), std::move(args)...); state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal awaitable{data}; state_.status = TaskStatus::WAITING; return awaitable; } template auto await_transform(Signal& signal, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { auto data = std::make_shared(); sharedState_->sourceLoc = std::move(sourceLoc); signal.connect([this, data](TFirstArg arg0) mutable { *data = std::move(arg0); state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal awaitable{data}; state_.status = TaskStatus::WAITING; return awaitable; } auto await_transform(Signal<>& signal, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { sharedState_->sourceLoc = std::move(sourceLoc); signal.connect([this]() { state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal<> awaitable{}; state_.status = TaskStatus::WAITING; return awaitable; } std::suspend_always await_transform(std::suspend_always, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { sharedState_->sourceLoc = std::move(sourceLoc); state_.status = TaskStatus::SUSPENDED; return std::suspend_always(); } std::suspend_never await_transform(std::suspend_never, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { sharedState_->sourceLoc = std::move(sourceLoc); return std::suspend_never(); } TaskAwaitableSuspend await_transform(TaskAwaitableSuspend, 
std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { sharedState_->sourceLoc = std::move(sourceLoc); state_.status = TaskStatus::SUSPENDED; return TaskAwaitableSuspend(); } // make sure the allocators are also used for the promise itself void* operator new(std::size_t size) { return makeTaskAllocator().allocate((size - 1) / sizeof(std::max_align_t) + 1); } void operator delete(void* ptr, std::size_t size) noexcept { TaskPromise* self = static_cast(ptr); self->allocator_.deallocate(static_cast(ptr), (size - 1) / sizeof(std::max_align_t) + 1); } }; template typename TAllocator> class [[nodiscard("Tasks should either we awaited or added to a loop.")]] TaskBase { public: using task_t = TaskBase; using result_t = TResult; struct Traits { using task_t = TaskBase; using result_t = TResult; }; public: using promise_type = TaskPromise; using handle_t = typename promise_type::handle_t; private: handle_t handle_; public: constexpr explicit TaskBase(handle_t handle) MIJIN_NOEXCEPT : handle_(handle) { #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO if (Result stacktrace = captureStacktrace(2); stacktrace.isSuccess()) { handle_.promise().sharedState_->creationStack_ = *stacktrace; } #endif } TaskBase(const TaskBase&) = delete; TaskBase(TaskBase&& other) MIJIN_NOEXCEPT : handle_(std::exchange(other.handle_, nullptr)) {} ~TaskBase() MIJIN_NOEXCEPT; public: TaskBase& operator=(const TaskBase&) = delete; TaskBase& operator=(TaskBase&& other) MIJIN_NOEXCEPT { if (handle_) { handle_.destroy(); } handle_ = std::exchange(other.handle_, nullptr); return *this; } [[nodiscard]] constexpr bool operator==(const TaskBase& other) const MIJIN_NOEXCEPT { return handle_ == other.handle_; } [[nodiscard]] constexpr bool operator!=(const TaskBase& other) const MIJIN_NOEXCEPT { return handle_ != other.handle_; } public: [[nodiscard]] constexpr TaskState& state() MIJIN_NOEXCEPT { return handle_.promise().state_; } constexpr TaskState& resume() { state().status = 
TaskStatus::RUNNING; handle_.resume(); return state(); } constexpr std::shared_ptr& sharedState() MIJIN_NOEXCEPT { return handle_.promise().sharedState_; } private: [[nodiscard]] constexpr handle_t handle() const MIJIN_NOEXCEPT { return handle_; } [[nodiscard]] constexpr TaskLoop* getLoop() MIJIN_NOEXCEPT { return handle_.promise().loop_; } constexpr void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT { // MIJIN_ASSERT(handle_.promise().loop_ == nullptr // || handle_.promise().loop_ == loop // || loop == nullptr, "Task already has a loop assigned!"); handle_.promise().loop_ = loop; } friend class TaskLoop; template typename TAllocator2> friend class WrappedTask; }; template struct is_task : std::false_type {}; template typename TAllocator> struct is_task> : std::true_type {}; template inline constexpr bool is_task_v = is_task::value; template concept task_type = is_task_v; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class WrappedTaskBase { public: virtual ~WrappedTaskBase() = default; public: virtual TaskStatus status() MIJIN_NOEXCEPT = 0; virtual std::exception_ptr exception() MIJIN_NOEXCEPT = 0; // virtual std::any result() MIJIN_NOEXCEPT = 0; virtual void resume() = 0; virtual void* raw() MIJIN_NOEXCEPT = 0; virtual std::coroutine_handle<> handle() MIJIN_NOEXCEPT = 0; virtual void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT = 0; virtual std::shared_ptr& sharedState() MIJIN_NOEXCEPT = 0; [[nodiscard]] inline bool canResume() { const TaskStatus stat = status(); return (stat == TaskStatus::SUSPENDED || stat == TaskStatus::YIELDED); } }; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class WrappedTask : public WrappedTaskBase { private: TTask task_; public: constexpr explicit WrappedTask(TTask&& task) MIJIN_NOEXCEPT : task_(std::move(task)) {} WrappedTask(const WrappedTask&) = delete; WrappedTask(WrappedTask&&) MIJIN_NOEXCEPT = default; public: WrappedTask& operator=(const WrappedTask&) = delete; WrappedTask& operator=(WrappedTask&&) MIJIN_NOEXCEPT = 
default; public: TaskStatus status() MIJIN_NOEXCEPT override { return task_.state().status; } std::exception_ptr exception() MIJIN_NOEXCEPT override { return task_.state().exception; } // std::any result() MIJIN_NOEXCEPT // { // if constexpr (std::is_same_v) { // return {}; // } // else { // return std::any(task_.state().value); // } // } void resume() override { task_.resume(); } void* raw() MIJIN_NOEXCEPT override { return &task_; } std::coroutine_handle<> handle() MIJIN_NOEXCEPT override { return task_.handle(); } void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT override { task_.setLoop(loop); } virtual std::shared_ptr& sharedState() MIJIN_NOEXCEPT override { return task_.sharedState(); } }; template typename TAllocator> auto wrapTask(TAllocator> allocator, TTask&& task) { using wrapped_task_t = WrappedTask; using deleter_t = AllocatorDeleter>; using allocator_t = TAllocator; wrapped_task_t* ptr = ::new (allocator.allocate(1)) wrapped_task_t(std::forward(task)); return std::unique_ptr(ptr, AllocatorDeleter(std::move(allocator))); } template typename TAllocator> class TaskLoop { public: MIJIN_DEFINE_FLAG(CanContinue); MIJIN_DEFINE_FLAG(IgnoreWaiting); using wrapped_task_t = WrappedTaskBase; using wrapped_allocator_t = TAllocator; using wrapped_deleter_t = AllocatorDeleter; using wrapped_task_base_ptr_t = std::unique_ptr; struct StoredTask { using set_future_t = std::function; wrapped_task_base_ptr_t task; set_future_t setFuture; std::any resultData; StoredTask(wrapped_task_base_ptr_t&& task_, set_future_t&& setFuture_, std::any&& resultData_) : task(std::move(task_)), setFuture(std::move(setFuture_)), resultData(std::move(resultData_)) {} template StoredTask(TAllocator allocator_) : task(nullptr, wrapped_deleter_t(wrapped_allocator_t(allocator_))) {} }; using exception_handler_t = std::function; using allocator_t = TAllocator; protected: using task_vector_t = std::vector>; template using wrapped_task_ptr_t = std::unique_ptr>; exception_handler_t 
uncaughtExceptionHandler_; [[no_unique_address]] allocator_t allocator_; public: explicit TaskLoop(allocator_t allocator = {}) MIJIN_NOEXCEPT_IF(std::is_nothrow_move_constructible_v) : allocator_(std::move(allocator)) {}; TaskLoop(const TaskLoop&) = delete; TaskLoop(TaskLoop&&) = delete; virtual ~TaskLoop() MIJIN_NOEXCEPT = default; [[nodiscard]] const allocator_t& getAllocator() const MIJIN_NOEXCEPT { return allocator_; } TaskLoop& operator=(const TaskLoop&) = delete; TaskLoop& operator=(TaskLoop&&) = delete; void setUncaughtExceptionHandler(exception_handler_t handler) MIJIN_NOEXCEPT { uncaughtExceptionHandler_ = std::move(handler); } template TaskFuturePtr addTaskImpl(TaskBase task, TaskHandle* outHandle) MIJIN_NOEXCEPT; template TaskFuturePtr addTask(TaskBase task, TaskHandle* outHandle = nullptr) MIJIN_NOEXCEPT { static_assert(TaskAllocatorTraits::default_is_valid_v, "Allocator is not valid when default constructed, use makeTask() instead."); return addTaskImpl(std::move(task), outHandle); } template auto makeTask(TCoro&& coro, TaskHandle& outHandle, TArgs&&... args) MIJIN_NOEXCEPT; template auto makeTask(TCoro&& coro, TArgs&&... 
args) MIJIN_NOEXCEPT { TaskHandle dummy; return makeTask(std::forward(coro), dummy, std::forward(args)...); } virtual void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT = 0; virtual void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT = 0; [[nodiscard]] static TaskLoop& current() MIJIN_NOEXCEPT; [[nodiscard]] static TaskLoop* currentOpt() MIJIN_NOEXCEPT; protected: inline TaskStatus tickTask(StoredTask& task); protected: static inline TaskLoop*& currentLoopStorage() MIJIN_NOEXCEPT; template static inline void setFutureHelper(StoredTask& storedTask) MIJIN_NOEXCEPT_IF(!MIJIN_COROUTINE_ENABLE_EXCEPTIONS); }; template using Task = TaskBase; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class BaseSimpleTaskLoop : public TaskLoop { private: using base_t = TaskLoop; using typename TaskLoop::task_vector_t; using typename TaskLoop::allocator_t; using typename TaskLoop::StoredTask; using typename TaskLoop::CanContinue; using typename TaskLoop::IgnoreWaiting; using base_t::allocator_; task_vector_t tasks_; task_vector_t newTasks_; task_vector_t::iterator currentTask_; MessageQueue queuedTasks_; std::thread::id threadId_; public: explicit BaseSimpleTaskLoop(const allocator_t& allocator = {}) MIJIN_NOEXCEPT_IF(std::is_nothrow_copy_constructible_v) : base_t(std::move(allocator)), tasks_(TAllocator(allocator_)), newTasks_(TAllocator(allocator_)), queuedTasks_(constructArray::BUFFER_SIZE>(allocator_)) {} public: // TaskLoop implementation void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override; void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT override; public: // public interface [[nodiscard]] constexpr bool empty() const MIJIN_NOEXCEPT { return tasks_.empty() && newTasks_.empty(); } [[nodiscard]] constexpr std::size_t getNumTasks() const MIJIN_NOEXCEPT { return tasks_.size() + newTasks_.size(); } [[nodiscard]] std::size_t getActiveTasks() const MIJIN_NOEXCEPT; inline CanContinue tick(); inline void runUntilDone(IgnoreWaiting 
ignoreWaiting = IgnoreWaiting::NO); inline void cancelAllTasks() MIJIN_NOEXCEPT; [[nodiscard]] inline std::vector> getAllTasks() const MIJIN_NOEXCEPT; private: inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); } }; using SimpleTaskLoop = BaseSimpleTaskLoop<>; template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class BaseMultiThreadedTaskLoop : public TaskLoop { private: using base_t = TaskLoop; using typename base_t::task_vector_t; using typename base_t::allocator_t; using typename base_t::StoredTask; using base_t::allocator_; task_vector_t parkedTasks_; // buffer for tasks that don't fit into readyTasks_ MessageQueue queuedTasks_; // tasks that should be appended to parked tasks MessageQueue readyTasks_; // task queue to send tasks to a worker thread MessageQueue returningTasks_; // task that have executed on a worker thread and return for further processing std::jthread managerThread_; std::vector> workerThreads_; public: explicit BaseMultiThreadedTaskLoop(allocator_t allocator = {}) MIJIN_NOEXCEPT_IF(std::is_nothrow_copy_constructible_v) : base_t(std::move(allocator)), parkedTasks_(TAllocator(allocator_)), queuedTasks_(constructArray::BUFFER_SIZE>(allocator_)), readyTasks_(constructArray::BUFFER_SIZE>(allocator_)), returningTasks_(constructArray::BUFFER_SIZE>(allocator_)), workerThreads_(TAllocator(allocator_)) {} public: // TaskLoop implementation void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override; void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT override; public: // public interface void start(std::size_t numWorkerThreads); void stop(); private: // private stuff void managerThread(std::stop_token stopToken); void workerThread(std::stop_token stopToken, std::size_t workerId); static StoredTask*& getCurrentTask() { static thread_local StoredTask* task = nullptr; return task; } }; using MultiThreadedTaskLoop = 
BaseMultiThreadedTaskLoop<>; // // public functions // namespace impl { extern thread_local std::shared_ptr gCurrentTaskState; inline void throwIfCancelled() { #if MIJIN_COROUTINE_ENABLE_CANCEL if (gCurrentTaskState->cancelled_) { throw TaskCancelled(); } #endif } } void TaskHandle::cancel() const MIJIN_NOEXCEPT { if (std::shared_ptr state = state_.lock()) { state->cancelled_ = true; state->subTask.cancel(); } } Optional TaskHandle::getLocation() const MIJIN_NOEXCEPT { if (std::shared_ptr state = state_.lock()) { return state->sourceLoc; } return NULL_OPTIONAL; } #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO Optional TaskHandle::getCreationStack() const MIJIN_NOEXCEPT { if (std::shared_ptr state = state_.lock()) { return state->creationStack_; } return NULL_OPTIONAL; } #endif // MIJIN_COROUTINE_ENABLE_DEBUG_INFO template typename TAllocator> TaskBase::~TaskBase() MIJIN_NOEXCEPT { if (handle_) { handle_.destroy(); } } template typename TAllocator> template TaskFuturePtr TaskLoop::addTaskImpl(TaskBase task, TaskHandle* outHandle) MIJIN_NOEXCEPT { MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); task.setLoop(this); TaskFuturePtr future = std::allocate_shared>(TAllocator>(allocator_), allocator_); auto setFuture = &setFutureHelper; if (outHandle != nullptr) { *outHandle = TaskHandle(task.sharedState()); } // add tasks to a seperate vector first as we might be running another task right now TAllocator>> allocator(allocator_); addStoredTask(StoredTask(wrapTask(std::move(allocator), std::move(task)), std::move(setFuture), future)); return future; } template typename TAllocator> template auto TaskLoop::makeTask(TCoro&& coro, TaskHandle& outHandle, TArgs&&... 
args) MIJIN_NOEXCEPT { TaskLoop* previousLoop = currentLoopStorage(); currentLoopStorage() = this; auto result = addTaskImpl(std::invoke(std::forward(coro), std::forward(args)...), &outHandle); currentLoopStorage() = previousLoop; return result; } template typename TAllocator> TaskStatus TaskLoop::tickTask(StoredTask& task) { TaskStatus status = {}; impl::gCurrentTaskState = task.task->sharedState(); do { task.task->resume(); status = task.task ? task.task->status() : TaskStatus::WAITING; // no inner task -> task switch context (and will be removed later) } while (status == TaskStatus::RUNNING); impl::gCurrentTaskState = nullptr; #if MIJIN_COROUTINE_ENABLE_EXCEPTION_HANDLING && !MIJIN_COROUTINE_ENABLE_EXCEPTIONS if (task.task && task.task->exception()) { try { std::rethrow_exception(task.task->exception()); } #if MIJIN_COROUTINE_ENABLE_CANCEL catch(TaskCancelled&) {} // ignore those #endif catch(...) { if (uncaughtExceptionHandler_) { uncaughtExceptionHandler_(std::current_exception()); } else { throw; } } // TODO: handle the exception somehow, others may be waiting return TaskStatus::FINISHED; } #endif // MIJIN_COROUTINE_ENABLE_EXCEPTION_HANDLING if (status == TaskStatus::YIELDED || status == TaskStatus::FINISHED) { try { task.setFuture(task); } catch(TaskCancelled&) {} catch(...) 
{ if (uncaughtExceptionHandler_) { uncaughtExceptionHandler_(std::current_exception()); } else { throw; } } } return status; } template typename TAllocator> /* static */ inline auto TaskLoop::current() MIJIN_NOEXCEPT -> TaskLoop& { MIJIN_ASSERT(currentLoopStorage() != nullptr, "Attempting to fetch current loop while no coroutine is running!"); return *currentLoopStorage(); } template typename TAllocator> /* static */ inline auto TaskLoop::currentOpt() MIJIN_NOEXCEPT -> TaskLoop* { return currentLoopStorage(); } template typename TAllocator> /* static */ auto TaskLoop::currentLoopStorage() MIJIN_NOEXCEPT -> TaskLoop*& { static thread_local TaskLoop* storage = nullptr; return storage; } template typename TAllocator> template /* static */ inline void TaskLoop::setFutureHelper(StoredTask& storedTask) MIJIN_NOEXCEPT_IF(!MIJIN_COROUTINE_ENABLE_EXCEPTIONS) { TaskBase& task = *static_cast*>(storedTask.task->raw()); const auto& future = std::any_cast&>(storedTask.resultData); #if MIJIN_COROUTINE_ENABLE_EXCEPTIONS if (task.state().exception) { if (future.use_count() < 2) { // future has been discarded, but someone must handle the exception std::rethrow_exception(task.state().exception); } future->setException(task.state().exception); return; } #endif if constexpr (!std::is_same_v) { MIJIN_ASSERT(!task.state().value.empty(), "Task did not produce a value?"); future->set(std::move(task.state().value.get())); } else { future->set(); } } template typename TAllocator> inline std::suspend_always switchContext(TaskLoop& taskLoop) { TaskLoop& currentTaskLoop = TaskLoop::current(); if (¤tTaskLoop == &taskLoop) { return {}; } currentTaskLoop.transferCurrentTask(taskLoop); return {}; } template typename TAllocator> void BaseSimpleTaskLoop::transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT { assertCorrectThread(); if (&otherLoop == this) { return; } MIJIN_ASSERT_FATAL(currentTask_ != tasks_.end(), "Trying to call transferCurrentTask() while not running a task!"); // now start the 
transfer, first disown the task StoredTask storedTask = std::move(*currentTask_); currentTask_->task = nullptr; // just to be sure // then send it over to the other loop otherLoop.addStoredTask(std::move(storedTask)); } template typename TAllocator> void BaseSimpleTaskLoop::addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT { storedTask.task->setLoop(this); if (threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id()) { // same thread, just copy it over if (TaskLoop::currentLoopStorage() != nullptr) { // currently running, can't append to tasks_ directly newTasks_.push_back(std::move(storedTask)); } else { tasks_.push_back(std::move(storedTask)); } } else { // other thread, better be safe queuedTasks_.push(std::move(storedTask)); } } template typename TAllocator> std::size_t BaseSimpleTaskLoop::getActiveTasks() const MIJIN_NOEXCEPT { std::size_t sum = 0; for (const StoredTask& task : mijin::chain(tasks_, newTasks_)) { const TaskStatus status = task.task ? task.task->status() : TaskStatus::FINISHED; if (status == TaskStatus::SUSPENDED || status == TaskStatus::RUNNING) { ++sum; } } return sum; } template typename TAllocator> inline auto BaseSimpleTaskLoop::tick() -> CanContinue { // set current taskloop MIJIN_ASSERT(TaskLoop::currentLoopStorage() == nullptr, "Trying to tick a loop from a coroutine, this is not supported."); TaskLoop::currentLoopStorage() = this; MIJIN_SCOPE_EXIT { TaskLoop::currentLoopStorage() = nullptr; }; threadId_ = std::this_thread::get_id(); // move over all tasks from newTasks for (StoredTask& task : newTasks_) { tasks_.push_back(std::move(task)); } newTasks_.clear(); // also pick up tasks from other threads while(true) { std::optional task = queuedTasks_.tryPop(); if (!task.has_value()) { break; } tasks_.push_back(std::move(*task)); } // remove any tasks that are finished executing auto it = std::remove_if(tasks_.begin(), tasks_.end(), [](StoredTask& task) { return task.task->status() == TaskStatus::FINISHED; }); 
tasks_.erase(it, tasks_.end());
    CanContinue canContinue = CanContinue::NO;

    // then execute all tasks that can be executed
    // canContinue becomes YES as soon as at least one ticked task ends up SUSPENDED or YIELDED,
    // i.e. there is still work that another tick() could make progress on.
    for (currentTask_ = tasks_.begin(); currentTask_ != tasks_.end(); ++currentTask_)
    {
        StoredTask& task = *currentTask_;
        TaskStatus status = task.task->status();
        if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_)
        {
            // always continue a cancelled task, even if it was still waiting for a result
            status = TaskStatus::SUSPENDED;
        }
        // Anything that is not runnable at this point is expected to be WAITING; such tasks are skipped.
        if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
        {
            MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");
            continue;
        }
        status = base_t::tickTask(task);
        if (status == TaskStatus::SUSPENDED || status == TaskStatus::YIELDED)
        {
            canContinue = CanContinue::YES;
        }
    }

    // remove any tasks that have been transferred to another queue
    // (a transferred task leaves its StoredTask behind with task == nullptr)
    it = std::remove_if(tasks_.begin(), tasks_.end(), [](const StoredTask& task) { return task.task == nullptr; });
    tasks_.erase(it, tasks_.end());

    return canContinue;
}

// Runs the loop on the calling thread by calling tick() as long as tasks_ or newTasks_ is non-empty.
// If ignoreWaiting is set, returns early as soon as a tick reports that no task could continue
// (every remaining task was skipped as WAITING or did not end SUSPENDED/YIELDED).
// NOTE(review): without ignoreWaiting this keeps calling tick() until both lists are empty, even when
// every remaining task is WAITING, so it busy-loops until those tasks become runnable.
template typename TAllocator>
void BaseSimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting)
{
    while (!tasks_.empty() || !newTasks_.empty())
    {
        const CanContinue canContinue = tick();
        if (ignoreWaiting && !canContinue)
        {
            break;
        }
    }
}

// Requests cancellation of every task in tasks_ and newTasks_ by setting the atomic cancelled_ flag in
// its shared state. Tasks are not removed here; tick() will resume a flagged task even if it is WAITING.
// Entries in queuedTasks_ are iterated and not flagged ("just discard it").
// NOTE(review): whether iterating queuedTasks_ actually consumes/drops its entries depends on that
// queue type's iterator, which is not visible here -- confirm.
template typename TAllocator>
void BaseSimpleTaskLoop::cancelAllTasks() MIJIN_NOEXCEPT
{
    for (StoredTask& task : mijin::chain(tasks_, newTasks_))
    {
        task.task->sharedState()->cancelled_ = true;
    }
    for (StoredTask& task : queuedTasks_)
    {
        // just discard it
        (void) task;
    }
}

// Returns one entry per task in tasks_ and newTasks_ (queuedTasks_ is not included), each constructed
// from the task's shared state. The vector's allocator is built from the loop's allocator_.
template typename TAllocator>
std::vector> BaseSimpleTaskLoop::getAllTasks() const MIJIN_NOEXCEPT
{
    std::vector> result((TAllocator(TaskLoop::allocator_)));
    for (const StoredTask& task : mijin::chain(tasks_, newTasks_))
    {
        result.emplace_back(task.task->sharedState());
    }
    return result;
}

// Body of the manager thread started by start(). Until a stop is requested it repeats, roughly every 1 ms:
//  1. drop parked entries whose task is null or FINISHED;
//  2. move parked tasks that canResume() into readyTasks_ in order, stopping at the first failed push
//     (non-resumable parked tasks are skipped, not blocking);
//  3. drain queuedTasks_ (new/transferred-in tasks), pushing straight to readyTasks_ or parking them;
//  4. drain returningTasks_ (tasks handed back by workers): discard transferred (null) or FINISHED
//     tasks, push resumable ones straight back to readyTasks_, park the rest.
template typename TAllocator>
void BaseMultiThreadedTaskLoop::managerThread(std::stop_token stopToken) // NOLINT(performance-unnecessary-value-param)
{
    // setCurrentThreadName("Task Manager");
    while (!stopToken.stop_requested())
    {
        // first clear out any parked tasks that are actually finished
        auto itRem = std::remove_if(parkedTasks_.begin(), parkedTasks_.end(), [](StoredTask& task) {
            return !task.task || task.task->status() == TaskStatus::FINISHED;
        });
        parkedTasks_.erase(itRem, parkedTasks_.end());

        // then try to push any task from the buffer into the queue, if possible
        for (auto it = parkedTasks_.begin(); it != parkedTasks_.end();)
        {
            if (!it->task->canResume())
            {
                ++it;
                continue;
            }
            if (readyTasks_.tryPushMaybeMove(*it))
            {
                it = parkedTasks_.erase(it);
            }
            else
            {
                // readyTasks_ refused the push; stop and retry on the next iteration
                break;
            }
        }

        // then clear the incoming task queue
        while (true)
        {
            std::optional task = queuedTasks_.tryPop();
            if (!task.has_value())
            {
                break;
            }
            // try to directly move it into the next queue
            if (readyTasks_.tryPushMaybeMove(*task))
            {
                continue;
            }
            // otherwise park it
            parkedTasks_.push_back(std::move(*task));
        }

        // next collect tasks returning from the worker threads
        while (true)
        {
            std::optional task = returningTasks_.tryPop();
            if (!task.has_value())
            {
                break;
            }
            if (task->task == nullptr || task->task->status() == TaskStatus::FINISHED)
            {
                continue; // task has been transferred or finished
            }
            if (task->task->canResume() && readyTasks_.tryPushMaybeMove(*task))
            {
                continue; // instantly resume, no questions asked
            }
            // otherwise park it for future processing
            parkedTasks_.push_back(std::move(*task));
        }

        // polling interval of the manager loop
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
}

// Body of each worker thread started by start(). Registers this loop as the thread's current loop,
// formats a thread name (the naming call itself is commented out), then until a stop is requested:
// pops a task from readyTasks_ (sleeping 1 ms if none), ticks it with the member getCurrentTask() slot
// and impl::gCurrentTaskState set for the duration, and pushes it into returningTasks_. Finished or
// transferred tasks are handed back as well; the manager thread filters those out.
template typename TAllocator>
void BaseMultiThreadedTaskLoop::workerThread(std::stop_token stopToken, std::size_t workerId) // NOLINT(performance-unnecessary-value-param)
{
    TaskLoop::currentLoopStorage() = this; // forever (on this thread)

    std::array threadName;
    // NOTE(review): the literal 16 must match threadName's size (its template arguments are not visible
    // here); threadName.size() would keep the two in sync.
    (void) std::snprintf(threadName.data(), 16, "Task Worker %lu", static_cast(workerId));
    // setCurrentThreadName(threadName.data());

    while (!stopToken.stop_requested())
    {
        // try to fetch a task to run
        std::optional task = readyTasks_.tryPop();
        if (!task.has_value())
        {
            std::this_thread::sleep_for(std::chrono::milliseconds(1));
            continue;
        }

        // run it
        getCurrentTask() = &*task;
        impl::gCurrentTaskState = task->task->sharedState();
        tickTask(*task);
        getCurrentTask() = nullptr;
        impl::gCurrentTaskState = nullptr;

        // and give it back
        returningTasks_.push(std::move(*task));
    }
}

// Moves the task currently being ticked on this thread out of this loop and hands it to otherLoop via
// addStoredTask(). Does nothing if otherLoop is this loop. Must be called from inside a running task
// (fatal assert otherwise). The emptied slot keeps task == nullptr, which the manager thread treats as
// "transferred" and discards when the worker hands it back.
template typename TAllocator>
void BaseMultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT
{
    if (&otherLoop == this)
    {
        return;
    }

    MIJIN_ASSERT_FATAL(getCurrentTask() != nullptr, "Trying to call transferCurrentTask() while not running a task!");

    // now start the transfer, first disown the task
    StoredTask storedTask = std::move(*getCurrentTask());
    getCurrentTask()->task = nullptr; // just to be sure

    // then send it over to the other loop
    otherLoop.addStoredTask(std::move(storedTask));
}

// Rebinds the task to this loop and enqueues it into queuedTasks_, which the manager thread drains.
template typename TAllocator>
void BaseMultiThreadedTaskLoop::addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT
{
    storedTask.task->setLoop(this);

    // just assume we are not on the manager thread, as that wouldn't make sense
    queuedTasks_.push(std::move(storedTask));
}

// Spawns the manager thread and numWorkerThreads worker threads (ids 0..numWorkerThreads-1).
// Each thread receives its own stop_token from std::jthread.
template typename TAllocator>
void BaseMultiThreadedTaskLoop::start(std::size_t numWorkerThreads)
{
    managerThread_ = std::jthread([this](std::stop_token stopToken) { managerThread(std::move(stopToken)); });
    workerThreads_.reserve(numWorkerThreads);
    for (std::size_t workerId = 0; workerId < numWorkerThreads; ++workerId)
    {
        workerThreads_.emplace_back([this, workerId](std::stop_token stopToken) { workerThread(std::move(stopToken), workerId); });
    }
}

// Stops the loop: destroying/replacing a joinable std::jthread requests stop and then joins it.
// Workers are torn down before the manager (presumably so the manager can keep draining
// returningTasks_ while workers finish their current tick -- confirm).
template typename TAllocator>
void BaseMultiThreadedTaskLoop::stop()
{
    workerThreads_.clear(); // will also set the stop token
    managerThread_ = {}; // this too
}

// utility stuff

// Awaitable used below as `co_await c_suspend()` to give control back to the loop between polls.
inline TaskAwaitableSuspend c_suspend()
{
    return TaskAwaitableSuspend();
}

// Coroutine that finishes once every non-null future in `futures` reports ready().
// NOTE(review): the suspend sits at the end of each iteration, before the condition is checked, so this
// always suspends at least once -- even when all futures were already ready on entry.
template typename TCollection, FutureType TFuture, typename... TTemplateArgs>
Task<> c_allDone(const TCollection& futures)
{
    bool allDone = true;
    do
    {
        allDone = true;
        for (const TFuture& future : futures)
        {
            if (future && !future->ready())
            {
                allDone = false;
                break;
            }
        }
        co_await c_suspend();
    } while (!allDone);
}

// Helper that adds tasks to currentTaskLoop while writing each task's handle into outHandles by index.
// NOTE(review): not used by the variadic c_allDone below, which passes nullptr for the handle.
template typename TAllocator, typename... TResult>
struct AllDoneHelper
{
    TaskLoop& currentTaskLoop;

    template auto makeFuture(TaskBase&& task, std::array& outHandles)
    {
        return currentTaskLoop.addTaskImpl(std::move(task), &outHandles[index]);
    }

    template auto makeFutures(TaskBase&&... tasks, std::array& outHandles, std::index_sequence)
    {
        return std::make_tuple(makeFuture(std::move(tasks), outHandles)...);
    }
};

// Adds every task to the current TaskLoop, suspends until allReady(futures) holds, then returns
// getAll(futures). allReady/getAll are defined elsewhere (presumably future.hpp -- confirm).
template typename TAllocator, typename... TResult>
TaskBase, TAllocator> c_allDone(TaskBase&&... tasks)
{
    TaskLoop& currentTaskLoop = TaskLoop::current();
    std::tuple futures = std::make_tuple(currentTaskLoop.addTaskImpl(std::move(tasks), nullptr)...);
    while (!allReady(futures))
    {
        co_await c_suspend();
    }
    co_return getAll(futures);
}

// Returns a handle to the task whose shared state is stored in impl::gCurrentTaskState (set, e.g., by
// workerThread while ticking a task). Asserts when called outside of a task.
// Not to be confused with the loop's member getCurrentTask() used in workerThread/transferCurrentTask,
// which refers to the StoredTask slot being ticked.
[[nodiscard]]
inline TaskHandle getCurrentTask() MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(impl::gCurrentTaskState != nullptr, "Attempt to call getCurrentTask() outside of task.");
    return TaskHandle(impl::gCurrentTaskState);
}
} // namespace mijin

#endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED