From 89bb11011669b6afd1528feacd513f8755a7fff3 Mon Sep 17 00:00:00 2001 From: Patrick Wuttke Date: Fri, 3 Nov 2023 00:48:13 +0100 Subject: [PATCH] Added cancelling (and some exception handling) for coroutines. --- source/mijin/async/coroutine.cpp | 13 ++- source/mijin/async/coroutine.hpp | 175 +++++++++++++++++++++++-------- 2 files changed, 141 insertions(+), 47 deletions(-) diff --git a/source/mijin/async/coroutine.cpp b/source/mijin/async/coroutine.cpp index 71b25c4..6b4f238 100644 --- a/source/mijin/async/coroutine.cpp +++ b/source/mijin/async/coroutine.cpp @@ -24,7 +24,10 @@ namespace mijin // internal variables // -thread_local TaskLoop::StoredTask* MultiThreadedTaskLoop::currentTask_ = nullptr; +namespace impl +{ +thread_local TaskLoop::StoredTask* gCurrentTask = nullptr; +} // // internal functions @@ -119,9 +122,9 @@ void MultiThreadedTaskLoop::workerThread(std::stop_token stopToken, std::size_t } // run it - currentTask_ = &*task; + impl::gCurrentTask = &*task; tickTask(*task); - currentTask_ = nullptr; + impl::gCurrentTask = nullptr; // and give it back returningTasks_.push(std::move(*task)); @@ -180,8 +183,8 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!"); // now start the transfer, first disown the task - StoredTask storedTask = std::move(*currentTask_); - currentTask_->task = nullptr; // just to be sure + StoredTask storedTask = std::move(*impl::gCurrentTask); + impl::gCurrentTask->task = nullptr; // just to be sure // then send it over to the other loop otherLoop.addStoredTask(std::move(storedTask)); diff --git a/source/mijin/async/coroutine.hpp b/source/mijin/async/coroutine.hpp index 20f039b..f77386f 100644 --- a/source/mijin/async/coroutine.hpp +++ b/source/mijin/async/coroutine.hpp @@ -56,11 +56,15 @@ struct TaskReturn { template constexpr void return_value(TArgs&&... args) noexcept { - *(static_cast(*this).state_) = TaskState(TReturn(std::forward(args)...), TaskStatus::FINISHED); + (static_cast(*this).state_) = TaskState(TReturn(std::forward(args)...), TaskStatus::FINISHED); } constexpr void return_value(TReturn value) noexcept { - *(static_cast(*this).state_) = TaskState(TReturn(std::move(value)), TaskStatus::FINISHED); + (static_cast(*this).state_) = TaskState(TReturn(std::move(value)), TaskStatus::FINISHED); + } + + constexpr void unhandled_exception() noexcept { + (static_cast(*this).state_) = TaskState(std::current_exception()); } }; @@ -68,33 +72,45 @@ template struct TaskReturn { constexpr void return_void() noexcept { - static_cast(*this).state_->status = TaskStatus::FINISHED; + static_cast(*this).state_.status = TaskStatus::FINISHED; + } + + constexpr void unhandled_exception() noexcept { + (static_cast(*this).state_) = TaskState(std::current_exception()); } }; + +struct TaskCancelled : std::exception {}; + +inline void throwIfCancelled(); } // namespace impl template struct TaskState { Optional value; + std::exception_ptr exception; TaskStatus status = TaskStatus::SUSPENDED; TaskState() = default; TaskState(const TaskState&) = default; TaskState(TaskState&&) noexcept = default; - constexpr TaskState(T _value, TaskStatus _status) noexcept : value(std::move(_value)), status(_status) {} + inline TaskState(T _value, TaskStatus _status) noexcept : value(std::move(_value)), status(_status) {} + inline TaskState(std::exception_ptr _exception) noexcept : exception(std::move(_exception)), status(TaskStatus::FINISHED) {} TaskState& operator=(const TaskState&) = default; TaskState& operator=(TaskState&&) noexcept = default; }; template<> struct TaskState { + std::exception_ptr exception; TaskStatus status = TaskStatus::SUSPENDED; TaskState() = default; TaskState(const TaskState&) = default; TaskState(TaskState&&) noexcept = default; - constexpr TaskState(TaskStatus _status) noexcept : status(_status) {} + inline TaskState(TaskStatus _status) noexcept : status(_status) {} + inline TaskState(std::exception_ptr _exception) noexcept : exception(std::move(_exception)), status(TaskStatus::FINISHED) {} TaskState& operator=(const TaskState&) = default; TaskState& operator=(TaskState&&) noexcept = default; }; @@ -106,7 +122,9 @@ struct TaskAwaitableFuture [[nodiscard]] constexpr bool await_ready() const noexcept { return future->ready(); } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} - constexpr TValue await_resume() const noexcept { + constexpr TValue await_resume() const noexcept + { + impl::throwIfCancelled(); if constexpr (std::is_same_v) { return; } @@ -123,7 +141,9 @@ struct TaskAwaitableSignal [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} - constexpr auto& await_resume() const noexcept { + inline auto& await_resume() const + { + impl::throwIfCancelled(); return *data; } }; @@ -135,7 +155,9 @@ struct TaskAwaitableSignal [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} - constexpr auto& await_resume() const noexcept { + constexpr auto& await_resume() const + { + impl::throwIfCancelled(); return *data; } }; @@ -145,7 +167,18 @@ struct TaskAwaitableSignal<> { [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} - constexpr void await_resume() const noexcept {} + inline void await_resume() const { + impl::throwIfCancelled(); + } +}; + +struct TaskAwaitableSuspend +{ + [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } + constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} + inline void await_resume() const { + impl::throwIfCancelled(); + } }; template @@ -155,7 +188,7 @@ struct TaskPromise : impl::TaskReturn> state_ = std::make_shared>(); + TaskState state_; TaskLoop* loop_ = nullptr; constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); } @@ -170,7 +203,7 @@ struct TaskPromise : impl::TaskReturn auto await_transform(FuturePtr future) noexcept @@ -179,10 +212,10 @@ struct TaskPromise : impl::TaskReturn awaitable{future}; if (!awaitable.await_ready()) { - state_->status = TaskStatus::WAITING; + state_.status = TaskStatus::WAITING; future->sigSet.connect([this, future]() mutable { - state_->status = TaskStatus::SUSPENDED; + state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); } return awaitable; @@ -203,10 +236,10 @@ struct TaskPromise : impl::TaskReturnstatus = TaskStatus::SUSPENDED; + state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal awaitable{data}; - state_->status = TaskStatus::WAITING; + state_.status = TaskStatus::WAITING; return awaitable; } @@ -217,10 +250,10 @@ struct TaskPromise : impl::TaskReturnstatus = TaskStatus::SUSPENDED; + state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal awaitable{data}; - state_->status = TaskStatus::WAITING; + state_.status = TaskStatus::WAITING; return awaitable; } @@ -228,22 +261,28 @@ struct TaskPromise : impl::TaskReturnstatus = TaskStatus::SUSPENDED; + state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal<> awaitable{}; - state_->status = TaskStatus::WAITING; + state_.status = TaskStatus::WAITING; return awaitable; } std::suspend_always await_transform(std::suspend_always) noexcept { - state_->status = TaskStatus::SUSPENDED; + state_.status = TaskStatus::SUSPENDED; return std::suspend_always(); } std::suspend_never await_transform(std::suspend_never) noexcept { return std::suspend_never(); } + + TaskAwaitableSuspend await_transform(TaskAwaitableSuspend) noexcept + { + state_.status = TaskStatus::SUSPENDED; + return TaskAwaitableSuspend(); + } }; template @@ -262,10 +301,9 @@ public: using handle_t = typename promise_type::handle_t; private: handle_t handle_; - std::shared_ptr> state_; public: - constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle), state_(handle.promise().state_) {} - TaskBase(const TaskBase&) = default; + constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {} + TaskBase(const TaskBase&) = delete; TaskBase(TaskBase&& other) noexcept = default; ~TaskBase() noexcept; public: @@ -278,17 +316,16 @@ public: [[nodiscard]] constexpr bool operator!=(const TaskBase& other) const noexcept { return handle_ != other.handle_; } public: - constexpr TaskState& resume() noexcept - { - state_->status = TaskStatus::RUNNING; - handle_.resume(); - return *state_; - } - [[nodiscard]] constexpr TaskState& state() noexcept { - return *state_; + return handle_.promise().state_; + } + constexpr TaskState& resume() + { + state().status = TaskStatus::RUNNING; + handle_.resume(); + return state(); } private: [[nodiscard]] @@ -318,8 +355,9 @@ public: virtual ~WrappedTaskBase() = default; public: virtual TaskStatus status() noexcept = 0; + virtual std::exception_ptr exception() noexcept = 0; // virtual std::any result() noexcept = 0; - virtual void resume() noexcept = 0; + virtual void resume() = 0; virtual void* raw() noexcept = 0; virtual std::coroutine_handle<> handle() noexcept = 0; virtual void setLoop(TaskLoop* loop) noexcept = 0; @@ -344,6 +382,7 @@ public: WrappedTask& operator=(WrappedTask&&) noexcept = default; public: TaskStatus status() noexcept override { return task_.state().status; } + std::exception_ptr exception() noexcept override { return task_.state().exception; } // std::any result() noexcept override // { // if constexpr (std::is_same_v) { @@ -353,12 +392,39 @@ public: // return std::any(task_.state().value); // } // } - void resume() noexcept override { task_.resume(); } + void resume() override { task_.resume(); } void* raw() noexcept override { return &task_; } std::coroutine_handle<> handle() noexcept override { return task_.handle(); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } }; +struct TaskSharedState +{ + std::atomic_bool cancelled_ = false; +}; + +class TaskHandle +{ +private: + std::weak_ptr state_; +public: + TaskHandle() = default; + explicit TaskHandle(std::weak_ptr state) noexcept : state_(std::move(state)) {} + TaskHandle(const TaskHandle&) = default; + TaskHandle(TaskHandle&&) = default; + + TaskHandle& operator=(const TaskHandle&) = default; + TaskHandle& operator=(TaskHandle&&) = default; + + void cancel() const noexcept + { + if (std::shared_ptr state = state_.lock()) + { + state->cancelled_ = true; + } + } +}; + template std::unique_ptr> wrapTask(TTask&& task) noexcept { @@ -370,16 +436,17 @@ class TaskLoop public: MIJIN_DEFINE_FLAG(CanContinue); MIJIN_DEFINE_FLAG(IgnoreWaiting); -protected: + using wrapped_task_t = WrappedTaskBase; using wrapped_task_base_ptr_t = std::unique_ptr; - struct StoredTask { + std::shared_ptr sharedState; wrapped_task_base_ptr_t task; std::function setFuture; std::any resultData; }; +protected: using task_vector_t = std::vector; template @@ -394,7 +461,7 @@ public: TaskLoop& operator=(TaskLoop&&) = delete; template - inline FuturePtr addTask(TaskBase task) noexcept; + inline FuturePtr addTask(TaskBase task, TaskHandle* outHandle = nullptr) noexcept; virtual void transferCurrentTask(TaskLoop& otherLoop) noexcept = 0; virtual void addStoredTask(StoredTask&& storedTask) noexcept = 0; @@ -452,14 +519,25 @@ public: // public interface private: // private stuff void managerThread(std::stop_token stopToken); void workerThread(std::stop_token stopToken, std::size_t workerId); - - static thread_local StoredTask* currentTask_; }; // // public functions // +namespace impl +{ +extern thread_local TaskLoop::StoredTask* gCurrentTask; + +inline void throwIfCancelled() +{ + if (gCurrentTask->sharedState->cancelled_) + { + throw TaskCancelled(); + } +} +} + template TaskBase::~TaskBase() noexcept { @@ -470,17 +548,23 @@ TaskBase::~TaskBase() noexcept } template -inline FuturePtr TaskLoop::addTask(TaskBase task) noexcept +inline FuturePtr TaskLoop::addTask(TaskBase task, TaskHandle* outHandle) noexcept { MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); task.setLoop(this); + auto sharedState = std::make_shared(); auto future = std::make_shared>(); - auto setFuture = &setFutureHelper; + if (outHandle != nullptr) + { + *outHandle = TaskHandle(sharedState); + } + // add tasks to a seperate vector first as we might be running another task right now addStoredTask(StoredTask{ + .sharedState = std::move(sharedState), .task = wrapTask(std::move(task)), .setFuture = setFuture, .resultData = future @@ -492,13 +576,20 @@ inline FuturePtr TaskLoop::addTask(TaskBase task) noexcept inline TaskStatus TaskLoop::tickTask(StoredTask& task) noexcept { TaskStatus status = {}; + impl::gCurrentTask = &task; do { task.task->resume(); status = task.task ? task.task->status() : TaskStatus::WAITING; // no inner task -> task switch context (and will be removed later) } while (status == TaskStatus::RUNNING); + impl::gCurrentTask = nullptr; + if (task.task && task.task->exception()) + { + // TODO: handle the exception somehow, others may be waiting + return TaskStatus::FINISHED; + } if (status == TaskStatus::YIELDED || status == TaskStatus::FINISHED) { task.setFuture(task); @@ -619,8 +710,8 @@ inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) noexcept // utility stuff -inline std::suspend_always c_suspend() { - return std::suspend_always(); +inline TaskAwaitableSuspend c_suspend() { + return TaskAwaitableSuspend(); } template typename TCollection, typename TType, typename... TTemplateArgs>