Also cancel sub-tasks (those that are awaited by this one) when cancelling a task.

This commit is contained in:
Patrick 2023-11-18 22:20:47 +01:00
parent d98e14285b
commit 065181fc69
2 changed files with 92 additions and 48 deletions

View File

@ -180,7 +180,7 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept
return; return;
} }
MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!"); MIJIN_ASSERT_FATAL(impl::gCurrentTask != nullptr, "Trying to call transferCurrentTask() while not running a task!");
// now start the transfer, first disown the task // now start the transfer, first disown the task
StoredTask storedTask = std::move(*impl::gCurrentTask); StoredTask storedTask = std::move(*impl::gCurrentTask);

View File

@ -4,6 +4,15 @@
#ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED #ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
#define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1 #define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
#if defined(MIJIN_DEBUG)
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 1
#else
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
#endif
#endif
#include <any> #include <any>
#include <coroutine> #include <coroutine>
#include <memory> #include <memory>
@ -15,6 +24,9 @@
#include "../container/optional.hpp" #include "../container/optional.hpp"
#include "../util/flag.hpp" #include "../util/flag.hpp"
#include "../util/traits.hpp" #include "../util/traits.hpp"
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
#include "../debug/stacktrace.hpp"
#endif
namespace mijin namespace mijin
{ {
@ -23,10 +35,6 @@ namespace mijin
// public defines // public defines
// //
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
#define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
#endif
// //
// public types // public types
// //
@ -56,6 +64,32 @@ namespace impl
inline void throwIfCancelled(); inline void throwIfCancelled();
} // namespace impl } // namespace impl
class TaskHandle
{
private:
std::weak_ptr<struct TaskSharedState> state_;
public:
TaskHandle() = default;
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
TaskHandle(const TaskHandle&) = default;
TaskHandle(TaskHandle&&) = default;
TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default;
[[nodiscard]] bool isValid() const noexcept
{
return !state_.expired();
}
inline void cancel() const noexcept;
};
struct TaskSharedState
{
std::atomic_bool cancelled_ = false;
TaskHandle subTask;
};
template<typename T> template<typename T>
struct TaskState struct TaskState
{ {
@ -193,10 +227,11 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
using result_t = typename TTraits::result_t; using result_t = typename TTraits::result_t;
TaskState<result_t> state_; TaskState<result_t> state_;
std::shared_ptr<TaskSharedState> sharedState_ = std::make_shared<TaskSharedState>();
TaskLoop* loop_ = nullptr; TaskLoop* loop_ = nullptr;
constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); } constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); }
constexpr std::suspend_always initial_suspend() noexcept { return {}; } constexpr TaskAwaitableSuspend initial_suspend() noexcept { return {}; }
constexpr std::suspend_always final_suspend() noexcept { return {}; } constexpr std::suspend_always final_suspend() noexcept { return {}; }
// template<typename TValue> // template<typename TValue>
@ -229,7 +264,7 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
auto await_transform(TaskBase<TResultOther> task) noexcept auto await_transform(TaskBase<TResultOther> task) noexcept
{ {
MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!"); MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!");
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task)); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
return await_transform(future); return await_transform(future);
} }
@ -305,14 +340,38 @@ public:
using handle_t = typename promise_type::handle_t; using handle_t = typename promise_type::handle_t;
private: private:
handle_t handle_; handle_t handle_;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Stacktrace creationStack_;
#endif
public: public:
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {} constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
if (Result<Stacktrace> stacktrace = captureStacktrace(1); stacktrace.isSuccess())
{
creationStack_ = *stacktrace;
}
#endif
}
TaskBase(const TaskBase&) = delete; TaskBase(const TaskBase&) = delete;
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {} TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr))
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
, creationStack_(std::move(other.creationStack_))
#endif
{}
~TaskBase() noexcept; ~TaskBase() noexcept;
public: public:
TaskBase& operator=(const TaskBase&) = default; TaskBase& operator=(const TaskBase&) = delete;
TaskBase& operator=(TaskBase&& other) noexcept = default; TaskBase& operator=(TaskBase&& other) noexcept
{
if (handle_) {
handle_.destroy();
}
handle_ = std::exchange(other.handle_, nullptr);
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
creationStack_ = std::move(other.creationStack_);
#endif
return *this;
}
[[nodiscard]] [[nodiscard]]
constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; } constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; }
@ -331,6 +390,10 @@ public:
handle_.resume(); handle_.resume();
return state(); return state();
} }
constexpr std::shared_ptr<TaskSharedState>& sharedState() noexcept
{
return handle_.promise().sharedState_;
}
private: private:
[[nodiscard]] [[nodiscard]]
constexpr handle_t handle() const noexcept { return handle_; } constexpr handle_t handle() const noexcept { return handle_; }
@ -365,6 +428,7 @@ public:
virtual void* raw() noexcept = 0; virtual void* raw() noexcept = 0;
virtual std::coroutine_handle<> handle() noexcept = 0; virtual std::coroutine_handle<> handle() noexcept = 0;
virtual void setLoop(TaskLoop* loop) noexcept = 0; virtual void setLoop(TaskLoop* loop) noexcept = 0;
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept = 0;
[[nodiscard]] inline bool canResume() { [[nodiscard]] inline bool canResume() {
const TaskStatus stat = status(); const TaskStatus stat = status();
@ -400,38 +464,7 @@ public:
void* raw() noexcept override { return &task_; } void* raw() noexcept override { return &task_; }
std::coroutine_handle<> handle() noexcept override { return task_.handle(); } std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
}; virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept { return task_.sharedState(); }
struct TaskSharedState
{
std::atomic_bool cancelled_ = false;
};
class TaskHandle
{
private:
std::weak_ptr<TaskSharedState> state_;
public:
TaskHandle() = default;
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
TaskHandle(const TaskHandle&) = default;
TaskHandle(TaskHandle&&) = default;
TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default;
[[nodiscard]] bool isValid() const noexcept
{
return !state_.expired();
}
void cancel() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
state->cancelled_ = true;
}
}
}; };
template<typename TTask> template<typename TTask>
@ -450,7 +483,6 @@ public:
using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>; using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
struct StoredTask struct StoredTask
{ {
std::shared_ptr<TaskSharedState> sharedState;
wrapped_task_base_ptr_t task; wrapped_task_base_ptr_t task;
std::function<void(StoredTask&)> setFuture; std::function<void(StoredTask&)> setFuture;
std::any resultData; std::any resultData;
@ -546,13 +578,22 @@ extern thread_local TaskLoop::StoredTask* gCurrentTask;
inline void throwIfCancelled() inline void throwIfCancelled()
{ {
if (gCurrentTask->sharedState->cancelled_) if (gCurrentTask->task->sharedState()->cancelled_)
{ {
throw TaskCancelled(); throw TaskCancelled();
} }
} }
} }
void TaskHandle::cancel() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
state->cancelled_ = true;
state->subTask.cancel();
}
}
template<typename TResult> template<typename TResult>
TaskBase<TResult>::~TaskBase() noexcept TaskBase<TResult>::~TaskBase() noexcept
{ {
@ -568,18 +609,16 @@ inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task, TaskHandle*
MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
task.setLoop(this); task.setLoop(this);
auto sharedState = std::make_shared<TaskSharedState>();
auto future = std::make_shared<Future<TResult>>(); auto future = std::make_shared<Future<TResult>>();
auto setFuture = &setFutureHelper<TResult>; auto setFuture = &setFutureHelper<TResult>;
if (outHandle != nullptr) if (outHandle != nullptr)
{ {
*outHandle = TaskHandle(sharedState); *outHandle = TaskHandle(task.sharedState());
} }
// add tasks to a seperate vector first as we might be running another task right now // add tasks to a seperate vector first as we might be running another task right now
addStoredTask(StoredTask{ addStoredTask(StoredTask{
.sharedState = std::move(sharedState),
.task = wrapTask(std::move(task)), .task = wrapTask(std::move(task)),
.setFuture = setFuture, .setFuture = setFuture,
.resultData = future .resultData = future
@ -703,6 +742,11 @@ inline auto SimpleTaskLoop::tick() -> CanContinue
{ {
StoredTask& task = *currentTask_; StoredTask& task = *currentTask_;
TaskStatus status = task.task->status(); TaskStatus status = task.task->status();
if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_)
{
// always continue a cancelled task, even if it was still waiting for a result
status = TaskStatus::SUSPENDED;
}
if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED) if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
{ {
MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!"); MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");