diff --git a/source/mijin/async/coroutine.cpp b/source/mijin/async/coroutine.cpp index 6b4f238..6684b26 100644 --- a/source/mijin/async/coroutine.cpp +++ b/source/mijin/async/coroutine.cpp @@ -180,7 +180,7 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept return; } - MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!"); + MIJIN_ASSERT_FATAL(impl::gCurrentTask != nullptr, "Trying to call transferCurrentTask() while not running a task!"); // now start the transfer, first disown the task StoredTask storedTask = std::move(*impl::gCurrentTask); diff --git a/source/mijin/async/coroutine.hpp b/source/mijin/async/coroutine.hpp index 933fa02..13cc18d 100644 --- a/source/mijin/async/coroutine.hpp +++ b/source/mijin/async/coroutine.hpp @@ -4,6 +4,15 @@ #ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED #define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1 + +#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO) +#if defined(MIJIN_DEBUG) +# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 1 +#else +# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0 +#endif +#endif + #include #include #include @@ -15,6 +24,9 @@ #include "../container/optional.hpp" #include "../util/flag.hpp" #include "../util/traits.hpp" +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO +#include "../debug/stacktrace.hpp" +#endif namespace mijin { @@ -23,10 +35,6 @@ namespace mijin // public defines // -#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO) -#define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0 -#endif - // // public types // @@ -56,6 +64,32 @@ namespace impl inline void throwIfCancelled(); } // namespace impl +class TaskHandle +{ +private: + std::weak_ptr state_; +public: + TaskHandle() = default; + explicit TaskHandle(std::weak_ptr state) noexcept : state_(std::move(state)) {} + TaskHandle(const TaskHandle&) = default; + TaskHandle(TaskHandle&&) = default; + + TaskHandle& operator=(const TaskHandle&) = default; + TaskHandle& operator=(TaskHandle&&) = default; + + [[nodiscard]] bool isValid() const noexcept + { + return !state_.expired(); + } + + inline void cancel() const noexcept; +}; +struct TaskSharedState +{ + std::atomic_bool cancelled_ = false; + TaskHandle subTask; +}; + template struct TaskState { @@ -193,10 +227,11 @@ struct TaskPromise : impl::TaskReturn state_; + std::shared_ptr sharedState_ = std::make_shared(); TaskLoop* loop_ = nullptr; constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); } - constexpr std::suspend_always initial_suspend() noexcept { return {}; } + constexpr TaskAwaitableSuspend initial_suspend() noexcept { return {}; } constexpr std::suspend_always final_suspend() noexcept { return {}; } // template @@ -229,7 +264,7 @@ struct TaskPromise : impl::TaskReturn task) noexcept { MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!"); - auto future = delayEvaluation(loop_)->addTask(std::move(task)); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here + auto future = delayEvaluation(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here return await_transform(future); } @@ -305,14 +340,38 @@ public: using handle_t = typename promise_type::handle_t; private: handle_t handle_; +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO + Stacktrace creationStack_; +#endif public: - constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {} + constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) { +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO + if (Result stacktrace = captureStacktrace(1); stacktrace.isSuccess()) + { + creationStack_ = *stacktrace; + } +#endif + } TaskBase(const TaskBase&) = delete; - TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {} + TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO + , creationStack_(std::move(other.creationStack_)) +#endif + {} ~TaskBase() noexcept; public: - TaskBase& operator=(const TaskBase&) = default; - TaskBase& operator=(TaskBase&& other) noexcept = default; + TaskBase& operator=(const TaskBase&) = delete; + TaskBase& operator=(TaskBase&& other) noexcept + { + if (handle_) { + handle_.destroy(); + } + handle_ = std::exchange(other.handle_, nullptr); +#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO + creationStack_ = std::move(other.creationStack_); +#endif + return *this; + } [[nodiscard]] constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; } @@ -331,6 +390,10 @@ public: handle_.resume(); return state(); } + constexpr std::shared_ptr& sharedState() noexcept + { + return handle_.promise().sharedState_; + } private: [[nodiscard]] constexpr handle_t handle() const noexcept { return handle_; } @@ -365,6 +428,7 @@ public: virtual void* raw() noexcept = 0; virtual std::coroutine_handle<> handle() noexcept = 0; virtual void setLoop(TaskLoop* loop) noexcept = 0; + virtual std::shared_ptr& sharedState() noexcept = 0; [[nodiscard]] inline bool canResume() { const TaskStatus stat = status(); @@ -400,38 +464,7 @@ public: void* raw() noexcept override { return &task_; } std::coroutine_handle<> handle() noexcept override { return task_.handle(); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } -}; - -struct TaskSharedState -{ - std::atomic_bool cancelled_ = false; -}; - -class TaskHandle -{ -private: - std::weak_ptr state_; -public: - TaskHandle() = default; - explicit TaskHandle(std::weak_ptr state) noexcept : state_(std::move(state)) {} - TaskHandle(const TaskHandle&) = default; - TaskHandle(TaskHandle&&) = default; - - TaskHandle& operator=(const TaskHandle&) = default; - TaskHandle& operator=(TaskHandle&&) = default; - - [[nodiscard]] bool isValid() const noexcept - { - return !state_.expired(); - } - - void cancel() const noexcept - { - if (std::shared_ptr state = state_.lock()) - { - state->cancelled_ = true; - } - } + virtual std::shared_ptr& sharedState() noexcept { return task_.sharedState(); } }; template @@ -450,7 +483,6 @@ public: using wrapped_task_base_ptr_t = std::unique_ptr; struct StoredTask { - std::shared_ptr sharedState; wrapped_task_base_ptr_t task; std::function setFuture; std::any resultData; @@ -546,13 +578,22 @@ extern thread_local TaskLoop::StoredTask* gCurrentTask; inline void throwIfCancelled() { - if (gCurrentTask->sharedState->cancelled_) + if (gCurrentTask->task->sharedState()->cancelled_) { throw TaskCancelled(); } } } +void TaskHandle::cancel() const noexcept +{ + if (std::shared_ptr state = state_.lock()) + { + state->cancelled_ = true; + state->subTask.cancel(); + } +} + template TaskBase::~TaskBase() noexcept { @@ -568,18 +609,16 @@ inline FuturePtr TaskLoop::addTask(TaskBase task, TaskHandle* MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); task.setLoop(this); - auto sharedState = std::make_shared(); auto future = std::make_shared>(); auto setFuture = &setFutureHelper; if (outHandle != nullptr) { - *outHandle = TaskHandle(sharedState); + *outHandle = TaskHandle(task.sharedState()); } // add tasks to a seperate vector first as we might be running another task right now addStoredTask(StoredTask{ - .sharedState = std::move(sharedState), .task = wrapTask(std::move(task)), .setFuture = setFuture, .resultData = future @@ -703,6 +742,11 @@ inline auto SimpleTaskLoop::tick() -> CanContinue { StoredTask& task = *currentTask_; TaskStatus status = task.task->status(); + if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_) + { + // always continue a cancelled task, even if it was still waiting for a result + status = TaskStatus::SUSPENDED; + } if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED) { MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");