From a956560183a61d2608a6e8e35499f56a5f323d25 Mon Sep 17 00:00:00 2001 From: Patrick Wuttke Date: Mon, 23 Jun 2025 00:23:50 +0200 Subject: [PATCH] Made coroutines allocator-aware (let's hope this really works). --- source/mijin/async/coroutine.cpp | 192 +----------- source/mijin/async/coroutine.hpp | 494 ++++++++++++++++++++++++++----- 2 files changed, 429 insertions(+), 257 deletions(-) diff --git a/source/mijin/async/coroutine.cpp b/source/mijin/async/coroutine.cpp index 3ad5681..02f9da3 100644 --- a/source/mijin/async/coroutine.cpp +++ b/source/mijin/async/coroutine.cpp @@ -26,205 +26,15 @@ namespace mijin namespace impl { -thread_local TaskLoop::StoredTask* gCurrentTask = nullptr; +thread_local std::shared_ptr gCurrentTaskState; } // // internal functions // -void MultiThreadedTaskLoop::managerThread(std::stop_token stopToken) // NOLINT(performance-unnecessary-value-param) -{ - setCurrentThreadName("Task Manager"); - - while (!stopToken.stop_requested()) - { - // first clear out any parked tasks that are actually finished - auto itRem = std::remove_if(parkedTasks_.begin(), parkedTasks_.end(), [](StoredTask& task) { - return !task.task || task.task->status() == TaskStatus::FINISHED; - }); - parkedTasks_.erase(itRem, parkedTasks_.end()); - - // then try to push any task from the buffer into the queue, if possible - for (auto it = parkedTasks_.begin(); it != parkedTasks_.end();) - { - if (!it->task->canResume()) - { - ++it; - continue; - } - - if (readyTasks_.tryPushMaybeMove(*it)) { - it = parkedTasks_.erase(it); - } - else { - break; - } - } - - // then clear the incoming task queue - while (true) - { - std::optional task = queuedTasks_.tryPop(); - if (!task.has_value()) { - break; - } - - // try to directly move it into the next queue - if (readyTasks_.tryPushMaybeMove(*task)) { - continue; - } - - // otherwise park it - parkedTasks_.push_back(std::move(*task)); - } - - // next collect tasks returning from the worker threads - while (true) - { - std::optional task = returningTasks_.tryPop(); - if (!task.has_value()) { - break; - } - - if (task->task == nullptr || task->task->status() == TaskStatus::FINISHED) { - continue; // task has been transferred or finished - } - - if (task->task->canResume() && readyTasks_.tryPushMaybeMove(*task)) { - continue; // instantly resume, no questions asked - } - - // otherwise park it for future processing - parkedTasks_.push_back(std::move(*task)); - } - - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } -} - -void MultiThreadedTaskLoop::workerThread(std::stop_token stopToken, std::size_t workerId) // NOLINT(performance-unnecessary-value-param) -{ - currentLoopStorage() = this; // forever (on this thread) - - std::array threadName; - (void) std::snprintf(threadName.data(), 16, "Task Worker %lu", static_cast(workerId)); - setCurrentThreadName(threadName.data()); - - while (!stopToken.stop_requested()) - { - // try to fetch a task to run - std::optional task = readyTasks_.tryPop(); - if (!task.has_value()) - { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - continue; - } - - // run it - impl::gCurrentTask = &*task; - tickTask(*task); - impl::gCurrentTask = nullptr; - - // and give it back - returningTasks_.push(std::move(*task)); - } -} - // // public functions // -void SimpleTaskLoop::transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT -{ - assertCorrectThread(); - - if (&otherLoop == this) { - return; - } - - MIJIN_ASSERT_FATAL(currentTask_ != tasks_.end(), "Trying to call transferCurrentTask() while not running a task!"); - - // now start the transfer, first disown the task - StoredTask storedTask = std::move(*currentTask_); - currentTask_->task = nullptr; // just to be sure - - // then send it over to the other loop - otherLoop.addStoredTask(std::move(storedTask)); -} - -void SimpleTaskLoop::addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT -{ - storedTask.task->setLoop(this); - if (threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id()) - { - // same thread, just copy it over - if (currentLoopStorage() != nullptr) { - // currently running, can't append to tasks_ directly - newTasks_.push_back(std::move(storedTask)); - } - else { - tasks_.push_back(std::move(storedTask)); - } - } - else - { - // other thread, better be safe - queuedTasks_.push(std::move(storedTask)); - } -} - -std::size_t SimpleTaskLoop::getActiveTasks() const MIJIN_NOEXCEPT -{ - std::size_t sum = 0; - for (const StoredTask& task : mijin::chain(tasks_, newTasks_)) - { - const TaskStatus status = task.task ? task.task->status() : TaskStatus::FINISHED; - if (status == TaskStatus::SUSPENDED || status == TaskStatus::RUNNING) - { - ++sum; - } - } - return sum; -} - -void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT -{ - if (&otherLoop == this) { - return; - } - - MIJIN_ASSERT_FATAL(impl::gCurrentTask != nullptr, "Trying to call transferCurrentTask() while not running a task!"); - - // now start the transfer, first disown the task - StoredTask storedTask = std::move(*impl::gCurrentTask); - impl::gCurrentTask->task = nullptr; // just to be sure - - // then send it over to the other loop - otherLoop.addStoredTask(std::move(storedTask)); -} - -void MultiThreadedTaskLoop::addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT -{ - storedTask.task->setLoop(this); - - // just assume we are not on the manager thread, as that wouldn't make sense - queuedTasks_.push(std::move(storedTask)); -} - -void MultiThreadedTaskLoop::start(std::size_t numWorkerThreads) -{ - managerThread_ = std::jthread([this](std::stop_token stopToken) { managerThread(std::move(stopToken)); }); - workerThreads_.reserve(numWorkerThreads); - for (std::size_t workerId = 0; workerId < numWorkerThreads; ++workerId) { - workerThreads_.emplace_back([this, workerId](std::stop_token stopToken) { workerThread(std::move(stopToken), workerId); }); - } -} - -void MultiThreadedTaskLoop::stop() -{ - workerThreads_.clear(); // will also set the stop token - managerThread_ = {}; // this too -} - } // namespace mijin diff --git a/source/mijin/async/coroutine.hpp b/source/mijin/async/coroutine.hpp index 2433af9..06cf809 100644 --- a/source/mijin/async/coroutine.hpp +++ b/source/mijin/async/coroutine.hpp @@ -10,6 +10,7 @@ #endif #include +#include #include #include #include @@ -20,6 +21,7 @@ #include "./message_queue.hpp" #include "../container/optional.hpp" #include "../internal/common.hpp" +#include "../memory/memutil.hpp" #include "../util/flag.hpp" #include "../util/iterators.hpp" #include "../util/traits.hpp" @@ -63,9 +65,10 @@ enum class TaskStatus template struct TaskState; +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class TaskLoop; -template +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class TaskBase; #if MIJIN_COROUTINE_ENABLE_CANCEL @@ -103,6 +106,7 @@ public: } inline void cancel() const MIJIN_NOEXCEPT; + [[nodiscard]] inline Optional getLocation() const MIJIN_NOEXCEPT; #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO inline Optional getCreationStack() const MIJIN_NOEXCEPT; #endif @@ -111,6 +115,7 @@ struct TaskSharedState { std::atomic_bool cancelled_ = false; TaskHandle subTask; + std::source_location sourceLoc; #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO Stacktrace creationStack_; #endif @@ -245,16 +250,61 @@ struct TaskAwaitableSuspend } }; -template +namespace impl +{ +template +using default_is_valid = T::default_is_valid_t; +} + +template typename TAllocator> +struct TaskAllocatorTraits +{ + static constexpr bool default_is_valid_v = detect_or_t>::value; + + template + static TAllocator create() + { + auto taskLoop = TaskLoop::currentOpt(); + if (taskLoop != nullptr) + { + return TAllocator(taskLoop->getAllocator()); + } + return TAllocator(); + } +}; + +template<> +struct TaskAllocatorTraits +{ + static constexpr bool default_is_valid_v = true; + + template + static std::allocator create() noexcept + { + return std::allocator(); + } +}; + +template typename TAllocator, typename T> +TAllocator makeTaskAllocator() +{ + return TaskAllocatorTraits::template create(); +} + +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> struct TaskPromise : impl::TaskReturn> { using handle_t = std::coroutine_handle; using task_t = typename TTraits::task_t; using result_t = typename TTraits::result_t; + [[no_unique_address]] TAllocator allocator_; TaskState state_; - std::shared_ptr sharedState_ = std::make_shared(); - TaskLoop* loop_ = nullptr; + std::shared_ptr sharedState_; + TaskLoop* loop_ = nullptr; + + explicit TaskPromise(TAllocator allocator = makeTaskAllocator()) MIJIN_NOEXCEPT_IF(std::is_nothrow_move_constructible_v>) + : allocator_(std::move(allocator)), sharedState_(std::allocate_shared(TAllocator(allocator_))) {} constexpr task_t get_return_object() MIJIN_NOEXCEPT { return task_t(handle_t::from_promise(*this)); } constexpr TaskAwaitableSuspend initial_suspend() MIJIN_NOEXCEPT { return {}; } @@ -271,9 +321,10 @@ struct TaskPromise : impl::TaskReturn - auto await_transform(FuturePtr future) MIJIN_NOEXCEPT + auto await_transform(FuturePtr future, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { MIJIN_ASSERT(loop_ != nullptr, "Cannot await future outside of a loop!"); + sharedState_->sourceLoc = std::move(sourceLoc); TaskAwaitableFuture awaitable{future}; if (!awaitable.await_ready()) { @@ -287,17 +338,18 @@ struct TaskPromise : impl::TaskReturn - auto await_transform(TaskBase task) MIJIN_NOEXCEPT + auto await_transform(TaskBase task, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!"); // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult) - auto future = delayEvaluation(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here - return await_transform(future); + auto future = delayEvaluation(loop_)->addTaskImpl(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here + return await_transform(future, std::move(sourceLoc)); } template - auto await_transform(Signal& signal) MIJIN_NOEXCEPT + auto await_transform(Signal& signal, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { auto data = std::make_shared>(); + sharedState_->sourceLoc = std::move(sourceLoc); signal.connect([this, data](TFirstArg arg0, TSecondArg arg1, TArgs... args) mutable { *data = std::make_tuple(std::move(arg0), std::move(arg1), std::move(args)...); @@ -309,9 +361,10 @@ struct TaskPromise : impl::TaskReturn - auto await_transform(Signal& signal) MIJIN_NOEXCEPT + auto await_transform(Signal& signal, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { auto data = std::make_shared(); + sharedState_->sourceLoc = std::move(sourceLoc); signal.connect([this, data](TFirstArg arg0) mutable { *data = std::move(arg0); @@ -322,8 +375,9 @@ struct TaskPromise : impl::TaskReturn& signal) MIJIN_NOEXCEPT + auto await_transform(Signal<>& signal, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { + sharedState_->sourceLoc = std::move(sourceLoc); signal.connect([this]() { state_.status = TaskStatus::SUSPENDED; @@ -333,24 +387,39 @@ struct TaskPromise : impl::TaskReturnsourceLoc = std::move(sourceLoc); state_.status = TaskStatus::SUSPENDED; return std::suspend_always(); } - std::suspend_never await_transform(std::suspend_never) MIJIN_NOEXCEPT { + std::suspend_never await_transform(std::suspend_never, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { + sharedState_->sourceLoc = std::move(sourceLoc); return std::suspend_never(); } - TaskAwaitableSuspend await_transform(TaskAwaitableSuspend) MIJIN_NOEXCEPT + TaskAwaitableSuspend await_transform(TaskAwaitableSuspend, std::source_location sourceLoc = std::source_location::current()) MIJIN_NOEXCEPT { + sharedState_->sourceLoc = std::move(sourceLoc); state_.status = TaskStatus::SUSPENDED; return TaskAwaitableSuspend(); } + + // make sure the allocators are also used for the promise itself + void* operator new(std::size_t size) + { + return makeTaskAllocator().allocate((size - 1) / sizeof(std::max_align_t) + 1); + } + + void operator delete(void* ptr, std::size_t size) noexcept + { + TaskPromise* self = static_cast(ptr); + self->allocator_.deallocate(static_cast(ptr), (size - 1) / sizeof(std::max_align_t) + 1); + } }; -template +template typename TAllocator> class [[nodiscard("Tasks should either we awaited or added to a loop.")]] TaskBase { public: @@ -362,7 +431,7 @@ public: using result_t = TResult; }; public: - using promise_type = TaskPromise; + using promise_type = TaskPromise; using handle_t = typename promise_type::handle_t; private: handle_t handle_; @@ -415,11 +484,11 @@ private: [[nodiscard]] constexpr handle_t handle() const MIJIN_NOEXCEPT { return handle_; } [[nodiscard]] - constexpr TaskLoop* getLoop() MIJIN_NOEXCEPT + constexpr TaskLoop* getLoop() MIJIN_NOEXCEPT { return handle_.promise().loop_; } - constexpr void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT + constexpr void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT { // MIJIN_ASSERT(handle_.promise().loop_ == nullptr // || handle_.promise().loop_ == loop @@ -427,12 +496,13 @@ private: handle_.promise().loop_ = loop; } - friend class TaskLoop; + friend class TaskLoop; - template + template typename TAllocator2> friend class WrappedTask; }; +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> class WrappedTaskBase { public: @@ -444,7 +514,7 @@ public: virtual void resume() = 0; virtual void* raw() MIJIN_NOEXCEPT = 0; virtual std::coroutine_handle<> handle() MIJIN_NOEXCEPT = 0; - virtual void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT = 0; + virtual void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT = 0; virtual std::shared_ptr& sharedState() MIJIN_NOEXCEPT = 0; [[nodiscard]] inline bool canResume() { @@ -453,8 +523,8 @@ public: } }; -template -class WrappedTask : public WrappedTaskBase +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> +class WrappedTask : public WrappedTaskBase { private: TTask task_; @@ -480,24 +550,30 @@ public: void resume() override { task_.resume(); } void* raw() MIJIN_NOEXCEPT override { return &task_; } std::coroutine_handle<> handle() MIJIN_NOEXCEPT override { return task_.handle(); } - void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT override { task_.setLoop(loop); } + void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT override { task_.setLoop(loop); } virtual std::shared_ptr& sharedState() MIJIN_NOEXCEPT override { return task_.sharedState(); } }; -template -std::unique_ptr> wrapTask(TTask&& task) MIJIN_NOEXCEPT +template typename TAllocator> +auto wrapTask(TAllocator> allocator, TTask&& task) { - return std::make_unique>(std::forward(task)); + using wrapped_task_t = WrappedTask; + using deleter_t = AllocatorDeleter>; + using allocator_t = TAllocator; + + wrapped_task_t* ptr = ::new (allocator.allocate(1)) wrapped_task_t(std::forward(task)); + return std::unique_ptr(ptr, AllocatorDeleter(std::move(allocator))); } +template typename TAllocator> class TaskLoop { public: MIJIN_DEFINE_FLAG(CanContinue); MIJIN_DEFINE_FLAG(IgnoreWaiting); - using wrapped_task_t = WrappedTaskBase; - using wrapped_task_base_ptr_t = std::unique_ptr; + using wrapped_task_t = WrappedTaskBase; + using wrapped_task_base_ptr_t = std::unique_ptr>>; struct StoredTask { wrapped_task_base_ptr_t task; @@ -506,31 +582,55 @@ public: }; using exception_handler_t = std::function; + using allocator_t = TAllocator; protected: - using task_vector_t = std::vector; + using task_vector_t = std::vector>; template using wrapped_task_ptr_t = std::unique_ptr>; exception_handler_t uncaughtExceptionHandler_; + [[no_unique_address]] allocator_t allocator_; public: - TaskLoop() MIJIN_NOEXCEPT = default; + explicit TaskLoop(allocator_t allocator = {}) MIJIN_NOEXCEPT_IF(std::is_nothrow_move_constructible_v) + : allocator_(std::move(allocator)) {}; TaskLoop(const TaskLoop&) = delete; TaskLoop(TaskLoop&&) = delete; virtual ~TaskLoop() MIJIN_NOEXCEPT = default; + [[nodiscard]] + const allocator_t& getAllocator() const MIJIN_NOEXCEPT { return allocator_; } + TaskLoop& operator=(const TaskLoop&) = delete; TaskLoop& operator=(TaskLoop&&) = delete; void setUncaughtExceptionHandler(exception_handler_t handler) MIJIN_NOEXCEPT { uncaughtExceptionHandler_ = std::move(handler); } template - inline FuturePtr addTask(TaskBase task, TaskHandle* outHandle = nullptr) MIJIN_NOEXCEPT; + FuturePtr addTaskImpl(TaskBase task, TaskHandle* outHandle) MIJIN_NOEXCEPT; + + template + FuturePtr addTask(TaskBase task, TaskHandle* outHandle = nullptr) MIJIN_NOEXCEPT + { + static_assert(TaskAllocatorTraits::default_is_valid_v, "Allocator is not valid when default constructed, use makeTask() instead."); + return addTaskImpl(std::move(task), outHandle); + } + + template + auto makeTask(TCoro&& coro, TaskHandle& outHandle, TArgs&&... args) MIJIN_NOEXCEPT; + + template + auto makeTask(TCoro&& coro, TArgs&&... args) MIJIN_NOEXCEPT + { + TaskHandle dummy; + return makeTask(std::forward(coro), dummy, std::forward(args)...); + } virtual void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT = 0; virtual void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT = 0; [[nodiscard]] static TaskLoop& current() MIJIN_NOEXCEPT; + [[nodiscard]] static TaskLoop* currentOpt() MIJIN_NOEXCEPT; protected: inline TaskStatus tickTask(StoredTask& task); protected: @@ -542,17 +642,28 @@ protected: template using Task = TaskBase; -class SimpleTaskLoop : public TaskLoop +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> +class BaseSimpleTaskLoop : public TaskLoop { private: + using base_t = TaskLoop; + using typename TaskLoop::task_vector_t; + using typename TaskLoop::allocator_t; + using typename TaskLoop::StoredTask; + using typename TaskLoop::CanContinue; + using typename TaskLoop::IgnoreWaiting; + + using base_t::allocator_; task_vector_t tasks_; task_vector_t newTasks_; task_vector_t::iterator currentTask_; MessageQueue queuedTasks_; std::thread::id threadId_; - +public: + explicit BaseSimpleTaskLoop(const allocator_t& allocator = {}) MIJIN_NOEXCEPT_IF(std::is_nothrow_copy_constructible_v) + : base_t(std::move(allocator)), tasks_(TAllocator(allocator_)), newTasks_(TAllocator(allocator_)) {} public: // TaskLoop implementation - void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override; + void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override; void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT override; public: // public interface @@ -562,23 +673,34 @@ public: // public interface inline CanContinue tick(); inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO); inline void cancelAllTasks() MIJIN_NOEXCEPT; - [[nodiscard]] inline std::vector getAllTasks() const MIJIN_NOEXCEPT; + [[nodiscard]] inline std::vector> getAllTasks() const MIJIN_NOEXCEPT; private: inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); } }; +using SimpleTaskLoop = BaseSimpleTaskLoop<>; -class MultiThreadedTaskLoop : public TaskLoop +template typename TAllocator = MIJIN_DEFAULT_ALLOCATOR> +class BaseMultiThreadedTaskLoop : public TaskLoop { private: + using base_t = TaskLoop; + using typename base_t::task_vector_t; + using typename base_t::allocator_t; + using typename base_t::StoredTask; + + using base_t::allocator_; task_vector_t parkedTasks_; // buffer for tasks that don't fit into readyTasks_ MessageQueue queuedTasks_; // tasks that should be appended to parked tasks MessageQueue readyTasks_; // task queue to send tasks to a worker thread MessageQueue returningTasks_; // task that have executed on a worker thread and return for further processing std::jthread managerThread_; - std::vector workerThreads_; + std::vector> workerThreads_; +public: + explicit BaseMultiThreadedTaskLoop(allocator_t allocator = {}) MIJIN_NOEXCEPT_IF(std::is_nothrow_copy_constructible_v) + : base_t(std::move(allocator)), parkedTasks_(TAllocator(allocator_)), workerThreads_(TAllocator(allocator_)) {} public: // TaskLoop implementation - void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override; + void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override; void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT override; public: // public interface @@ -587,7 +709,10 @@ public: // public interface private: // private stuff void managerThread(std::stop_token stopToken); void workerThread(std::stop_token stopToken, std::size_t workerId); + + static thread_local StoredTask* currentTask_ = nullptr; }; +using MultiThreadedTaskLoop = BaseMultiThreadedTaskLoop<>; // // public functions @@ -595,12 +720,12 @@ private: // private stuff namespace impl { -extern thread_local TaskLoop::StoredTask* gCurrentTask; +extern thread_local std::shared_ptr gCurrentTaskState; inline void throwIfCancelled() { #if MIJIN_COROUTINE_ENABLE_CANCEL - if (gCurrentTask->task->sharedState()->cancelled_) + if (gCurrentTaskState->cancelled_) { throw TaskCancelled(); } @@ -617,6 +742,15 @@ void TaskHandle::cancel() const MIJIN_NOEXCEPT } } +Optional TaskHandle::getLocation() const noexcept +{ + if (std::shared_ptr state = state_.lock()) + { + return state->sourceLoc; + } + return NULL_OPTIONAL; +} + #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO Optional TaskHandle::getCreationStack() const MIJIN_NOEXCEPT { @@ -628,8 +762,8 @@ Optional TaskHandle::getCreationStack() const MIJIN_NOEXCEPT } #endif // MIJIN_COROUTINE_ENABLE_DEBUG_INFO -template -TaskBase::~TaskBase() MIJIN_NOEXCEPT +template typename TAllocator> +TaskBase::~TaskBase() MIJIN_NOEXCEPT { if (handle_) { @@ -637,13 +771,14 @@ TaskBase::~TaskBase() MIJIN_NOEXCEPT } } +template typename TAllocator> template -inline FuturePtr TaskLoop::addTask(TaskBase task, TaskHandle* outHandle) MIJIN_NOEXCEPT +FuturePtr TaskLoop::addTaskImpl(TaskBase task, TaskHandle* outHandle) MIJIN_NOEXCEPT { MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); task.setLoop(this); - auto future = std::make_shared>(); + auto future = std::allocate_shared>(TAllocator>(allocator_)); auto setFuture = &setFutureHelper; if (outHandle != nullptr) @@ -653,7 +788,7 @@ inline FuturePtr TaskLoop::addTask(TaskBase task, TaskHandle* // add tasks to a seperate vector first as we might be running another task right now addStoredTask(StoredTask{ - .task = wrapTask(std::move(task)), + .task = wrapTask(TAllocator>>(allocator_), std::move(task)), .setFuture = setFuture, .resultData = future }); @@ -661,17 +796,29 @@ inline FuturePtr TaskLoop::addTask(TaskBase task, TaskHandle* return future; } -inline TaskStatus TaskLoop::tickTask(StoredTask& task) +template typename TAllocator> +template +auto TaskLoop::makeTask(TCoro&& coro, TaskHandle& outHandle, TArgs&&... args) MIJIN_NOEXCEPT +{ + TaskLoop* previousLoop = currentLoopStorage(); + currentLoopStorage() = this; + auto result = addTaskImpl(std::invoke(std::forward(coro), std::forward(args)...), &outHandle); + currentLoopStorage() = previousLoop; + return result; +} + +template typename TAllocator> +TaskStatus TaskLoop::tickTask(StoredTask& task) { TaskStatus status = {}; - impl::gCurrentTask = &task; + impl::gCurrentTaskState = task.task->sharedState(); do { task.task->resume(); status = task.task ? task.task->status() : TaskStatus::WAITING; // no inner task -> task switch context (and will be removed later) } while (status == TaskStatus::RUNNING); - impl::gCurrentTask = nullptr; + impl::gCurrentTaskState = nullptr; #if MIJIN_COROUTINE_ENABLE_EXCEPTION_HANDLING if (task.task && task.task->exception()) @@ -706,22 +853,31 @@ inline TaskStatus TaskLoop::tickTask(StoredTask& task) return status; } -/* static */ inline auto TaskLoop::current() MIJIN_NOEXCEPT -> TaskLoop& +template typename TAllocator> +/* static */ inline auto TaskLoop::current() MIJIN_NOEXCEPT -> TaskLoop& { MIJIN_ASSERT(currentLoopStorage() != nullptr, "Attempting to fetch current loop while no coroutine is running!"); return *currentLoopStorage(); } -/* static */ auto TaskLoop::currentLoopStorage() MIJIN_NOEXCEPT -> TaskLoop*& +template typename TAllocator> +/* static */ inline auto TaskLoop::currentOpt() MIJIN_NOEXCEPT -> TaskLoop* +{ + return currentLoopStorage(); +} + +template typename TAllocator> +/* static */ auto TaskLoop::currentLoopStorage() MIJIN_NOEXCEPT -> TaskLoop*& { static thread_local TaskLoop* storage = nullptr; return storage; } +template typename TAllocator> template -/* static */ inline void TaskLoop::setFutureHelper(StoredTask& storedTask) MIJIN_NOEXCEPT +/* static */ inline void TaskLoop::setFutureHelper(StoredTask& storedTask) MIJIN_NOEXCEPT { - TaskBase& task = *static_cast*>(storedTask.task->raw()); + TaskBase& task = *static_cast*>(storedTask.task->raw()); auto future = std::any_cast>(storedTask.resultData); if constexpr (!std::is_same_v) @@ -734,9 +890,10 @@ template } } -inline std::suspend_always switchContext(TaskLoop& taskLoop) +template typename TAllocator> +inline std::suspend_always switchContext(TaskLoop& taskLoop) { - TaskLoop& currentTaskLoop = TaskLoop::current(); + TaskLoop& currentTaskLoop = TaskLoop::current(); if (¤tTaskLoop == &taskLoop) { return {}; } @@ -744,11 +901,68 @@ inline std::suspend_always switchContext(TaskLoop& taskLoop) return {}; } -inline auto SimpleTaskLoop::tick() -> CanContinue +template typename TAllocator> +void BaseSimpleTaskLoop::transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT +{ + assertCorrectThread(); + + if (&otherLoop == this) { + return; + } + + MIJIN_ASSERT_FATAL(currentTask_ != tasks_.end(), "Trying to call transferCurrentTask() while not running a task!"); + + // now start the transfer, first disown the task + StoredTask storedTask = std::move(*currentTask_); + currentTask_->task = nullptr; // just to be sure + + // then send it over to the other loop + otherLoop.addStoredTask(std::move(storedTask)); +} + +template typename TAllocator> +void BaseSimpleTaskLoop::addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT +{ + storedTask.task->setLoop(this); + if (threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id()) + { + // same thread, just copy it over + if (TaskLoop::currentLoopStorage() != nullptr) { + // currently running, can't append to tasks_ directly + newTasks_.push_back(std::move(storedTask)); + } + else { + tasks_.push_back(std::move(storedTask)); + } + } + else + { + // other thread, better be safe + queuedTasks_.push(std::move(storedTask)); + } +} + +template typename TAllocator> +std::size_t BaseSimpleTaskLoop::getActiveTasks() const MIJIN_NOEXCEPT +{ + std::size_t sum = 0; + for (const StoredTask& task : mijin::chain(tasks_, newTasks_)) + { + const TaskStatus status = task.task ? task.task->status() : TaskStatus::FINISHED; + if (status == TaskStatus::SUSPENDED || status == TaskStatus::RUNNING) + { + ++sum; + } + } + return sum; +} + +template typename TAllocator> +inline auto BaseSimpleTaskLoop::tick() -> CanContinue { // set current taskloop - MIJIN_ASSERT(currentLoopStorage() == nullptr, "Trying to tick a loop from a coroutine, this is not supported."); - currentLoopStorage() = this; + MIJIN_ASSERT(TaskLoop::currentLoopStorage() == nullptr, "Trying to tick a loop from a coroutine, this is not supported."); + TaskLoop::currentLoopStorage() = this; threadId_ = std::this_thread::get_id(); // move over all tasks from newTasks @@ -791,7 +1005,7 @@ inline auto SimpleTaskLoop::tick() -> CanContinue continue; } - status = tickTask(task); + status = base_t::tickTask(task); if (status == TaskStatus::SUSPENDED || status == TaskStatus::YIELDED) { @@ -799,7 +1013,7 @@ inline auto SimpleTaskLoop::tick() -> CanContinue } } // reset current loop - currentLoopStorage() = nullptr; + TaskLoop::currentLoopStorage() = nullptr; // remove any tasks that have been transferred to another queue it = std::remove_if(tasks_.begin(), tasks_.end(), [](const StoredTask& task) { @@ -810,7 +1024,8 @@ inline auto SimpleTaskLoop::tick() -> CanContinue return canContinue; } -inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) +template typename TAllocator> +void BaseSimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) { while (!tasks_.empty() || !newTasks_.empty()) { @@ -822,7 +1037,8 @@ inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) } } -inline void SimpleTaskLoop::cancelAllTasks() MIJIN_NOEXCEPT +template typename TAllocator> +void BaseSimpleTaskLoop::cancelAllTasks() MIJIN_NOEXCEPT { for (StoredTask& task : mijin::chain(tasks_, newTasks_)) { @@ -835,9 +1051,10 @@ inline void SimpleTaskLoop::cancelAllTasks() MIJIN_NOEXCEPT } } -inline std::vector SimpleTaskLoop::getAllTasks() const MIJIN_NOEXCEPT +template typename TAllocator> +std::vector> BaseSimpleTaskLoop::getAllTasks() const MIJIN_NOEXCEPT { - std::vector result; + std::vector> result((TAllocator(TaskLoop::allocator_))); for (const StoredTask& task : mijin::chain(tasks_, newTasks_)) { result.emplace_back(task.task->sharedState()); @@ -845,6 +1062,151 @@ inline std::vector SimpleTaskLoop::getAllTasks() const MIJIN_NOEXCEP return result; } +template typename TAllocator> +void BaseMultiThreadedTaskLoop::managerThread(std::stop_token stopToken) // NOLINT(performance-unnecessary-value-param) +{ + // setCurrentThreadName("Task Manager"); + + while (!stopToken.stop_requested()) + { + // first clear out any parked tasks that are actually finished + auto itRem = std::remove_if(parkedTasks_.begin(), parkedTasks_.end(), [](StoredTask& task) { + return !task.task || task.task->status() == TaskStatus::FINISHED; + }); + parkedTasks_.erase(itRem, parkedTasks_.end()); + + // then try to push any task from the buffer into the queue, if possible + for (auto it = parkedTasks_.begin(); it != parkedTasks_.end();) + { + if (!it->task->canResume()) + { + ++it; + continue; + } + + if (readyTasks_.tryPushMaybeMove(*it)) { + it = parkedTasks_.erase(it); + } + else { + break; + } + } + + // then clear the incoming task queue + while (true) + { + std::optional task = queuedTasks_.tryPop(); + if (!task.has_value()) { + break; + } + + // try to directly move it into the next queue + if (readyTasks_.tryPushMaybeMove(*task)) { + continue; + } + + // otherwise park it + parkedTasks_.push_back(std::move(*task)); + } + + // next collect tasks returning from the worker threads + while (true) + { + std::optional task = returningTasks_.tryPop(); + if (!task.has_value()) { + break; + } + + if (task->task == nullptr || task->task->status() == TaskStatus::FINISHED) { + continue; // task has been transferred or finished + } + + if (task->task->canResume() && readyTasks_.tryPushMaybeMove(*task)) { + continue; // instantly resume, no questions asked + } + + // otherwise park it for future processing + parkedTasks_.push_back(std::move(*task)); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +template typename TAllocator> +void BaseMultiThreadedTaskLoop::workerThread(std::stop_token stopToken, std::size_t workerId) // NOLINT(performance-unnecessary-value-param) +{ + TaskLoop::currentLoopStorage() = this; // forever (on this thread) + + std::array threadName; + (void) std::snprintf(threadName.data(), 16, "Task Worker %lu", static_cast(workerId)); + // setCurrentThreadName(threadName.data()); + + while (!stopToken.stop_requested()) + { + // try to fetch a task to run + std::optional task = readyTasks_.tryPop(); + if (!task.has_value()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + + // run it + currentTask_ = &*task; + impl::gCurrentTaskState = task->task->sharedState(); + tickTask(*task); + currentTask_ = nullptr; + impl::gCurrentTaskState = nullptr; + + // and give it back + returningTasks_.push(std::move(*task)); + } +} + +template typename TAllocator> +void BaseMultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT +{ + if (&otherLoop == this) { + return; + } + + MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!"); + + // now start the transfer, first disown the task + StoredTask storedTask = std::move(*currentTask_); + currentTask_->task = nullptr; // just to be sure + + // then send it over to the other loop + otherLoop.addStoredTask(std::move(storedTask)); +} + +template typename TAllocator> +void BaseMultiThreadedTaskLoop::addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT +{ + storedTask.task->setLoop(this); + + // just assume we are not on the manager thread, as that wouldn't make sense + queuedTasks_.push(std::move(storedTask)); +} + +template typename TAllocator> +void BaseMultiThreadedTaskLoop::start(std::size_t numWorkerThreads) +{ + managerThread_ = std::jthread([this](std::stop_token stopToken) { managerThread(std::move(stopToken)); }); + workerThreads_.reserve(numWorkerThreads); + for (std::size_t workerId = 0; workerId < numWorkerThreads; ++workerId) { + workerThreads_.emplace_back([this, workerId](std::stop_token stopToken) { workerThread(std::move(stopToken), workerId); }); + } +} + +template typename TAllocator> +void BaseMultiThreadedTaskLoop::stop() +{ + workerThreads_.clear(); // will also set the stop token + managerThread_ = {}; // this too +} + // utility stuff inline TaskAwaitableSuspend c_suspend() { @@ -871,8 +1233,8 @@ Task<> c_allDone(const TCollection, TTemplateArgs...>& futures) [[nodiscard]] inline TaskHandle getCurrentTask() MIJIN_NOEXCEPT { - MIJIN_ASSERT(impl::gCurrentTask != nullptr, "Attempt to call getCurrentTask() outside of task."); - return TaskHandle(impl::gCurrentTask->task->sharedState()); + MIJIN_ASSERT(impl::gCurrentTaskState != nullptr, "Attempt to call getCurrentTask() outside of task."); + return TaskHandle(impl::gCurrentTaskState); } }