diff --git a/source/mijin/async/coroutine.hpp b/source/mijin/async/coroutine.hpp index 13cc18d..07c7782 100644 --- a/source/mijin/async/coroutine.hpp +++ b/source/mijin/async/coroutine.hpp @@ -23,6 +23,7 @@ #include "./message_queue.hpp" #include "../container/optional.hpp" #include "../util/flag.hpp" +#include "../util/iterators.hpp" #include "../util/traits.hpp" #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO #include "../debug/stacktrace.hpp" @@ -77,17 +78,30 @@ public: TaskHandle& operator=(const TaskHandle&) = default; TaskHandle& operator=(TaskHandle&&) = default; + bool operator==(const TaskHandle& other) const noexcept { + return !state_.owner_before(other.state_) && !other.state_.owner_before(state_); + } + bool operator!=(const TaskHandle& other) const noexcept { + return !(*this == other); + } + [[nodiscard]] bool isValid() const noexcept { return !state_.expired(); } inline void cancel() const noexcept; +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO + inline Optional getCreationStack() const noexcept; +#endif }; struct TaskSharedState { std::atomic_bool cancelled_ = false; TaskHandle subTask; +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO + Stacktrace creationStack_; +#endif }; template @@ -340,23 +354,17 @@ public: using handle_t = typename promise_type::handle_t; private: handle_t handle_; -#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO - Stacktrace creationStack_; -#endif public: constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) { #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO - if (Result stacktrace = captureStacktrace(1); stacktrace.isSuccess()) + if (Result stacktrace = captureStacktrace(2); stacktrace.isSuccess()) { - creationStack_ = *stacktrace; + handle_.promise().sharedState_->creationStack_ = *stacktrace; } #endif } TaskBase(const TaskBase&) = delete; TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) -#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO - , creationStack_(std::move(other.creationStack_)) -#endif {} ~TaskBase() noexcept; public: @@ -367,9 +375,6 @@ public: handle_.destroy(); } handle_ = std::exchange(other.handle_, nullptr); -#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO - creationStack_ = std::move(other.creationStack_); -#endif return *this; } @@ -464,7 +469,7 @@ public: void* raw() noexcept override { return &task_; } std::coroutine_handle<> handle() noexcept override { return task_.handle(); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } - virtual std::shared_ptr& sharedState() noexcept { return task_.sharedState(); } + virtual std::shared_ptr& sharedState() noexcept override { return task_.sharedState(); } }; template @@ -540,8 +545,11 @@ public: // TaskLoop implementation public: // public interface [[nodiscard]] constexpr bool empty() const noexcept { return tasks_.empty() && newTasks_.empty(); } + [[nodiscard]] constexpr std::size_t getNumTasks() const noexcept { return tasks_.size() + newTasks_.size(); } inline CanContinue tick(); inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO); + inline void cancelAllTasks() noexcept; + [[nodiscard]] inline std::vector getAllTasks() const noexcept; private: inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); } }; @@ -594,6 +602,17 @@ void TaskHandle::cancel() const noexcept } } +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO +Optional TaskHandle::getCreationStack() const noexcept +{ + if (std::shared_ptr state = state_.lock()) + { + return state->creationStack_; + } + return NULL_OPTIONAL; +} +#endif // MIJIN_COROUTINE_ENABLE_DEBUG_INFO + template TaskBase::~TaskBase() noexcept { @@ -784,6 +803,29 @@ inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) } } +inline void SimpleTaskLoop::cancelAllTasks() noexcept +{ + for (StoredTask& task : mijin::chain(tasks_, newTasks_)) + { + task.task->sharedState()->cancelled_ = true; + } + for (StoredTask& task : queuedTasks_) + { + // just discard it + (void) task; + } +} + +inline std::vector SimpleTaskLoop::getAllTasks() const noexcept +{ + std::vector result; + for (const StoredTask& task : mijin::chain(tasks_, newTasks_)) + { + result.emplace_back(task.task->sharedState()); + } + return result; +} + // utility stuff inline TaskAwaitableSuspend c_suspend() { @@ -808,8 +850,11 @@ Task<> c_allDone(const TCollection, TTemplateArgs...>& futures) } while (!allDone); } -#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO -#endif +[[nodiscard]] inline TaskHandle getCurrentTask() noexcept +{ + MIJIN_ASSERT(impl::gCurrentTask != nullptr, "Attempt to call getCurrentTask() outside of task."); + return TaskHandle(impl::gCurrentTask->task->sharedState()); +} } #endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED