Also cancel sub-tasks (those that are awaited by this one) when cancelling a task.

This commit is contained in:
Patrick 2023-11-18 22:20:47 +01:00
parent d98e14285b
commit 065181fc69
2 changed files with 92 additions and 48 deletions

View File

@ -180,7 +180,7 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept
return;
}
MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!");
MIJIN_ASSERT_FATAL(impl::gCurrentTask != nullptr, "Trying to call transferCurrentTask() while not running a task!");
// now start the transfer, first disown the task
StoredTask storedTask = std::move(*impl::gCurrentTask);

View File

@ -4,6 +4,15 @@
#ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
#define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
#if defined(MIJIN_DEBUG)
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 1
#else
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
#endif
#endif
#include <any>
#include <coroutine>
#include <memory>
@ -15,6 +24,9 @@
#include "../container/optional.hpp"
#include "../util/flag.hpp"
#include "../util/traits.hpp"
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
#include "../debug/stacktrace.hpp"
#endif
namespace mijin
{
@ -23,10 +35,6 @@ namespace mijin
// public defines
//
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
#define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
#endif
//
// public types
//
@ -56,6 +64,32 @@ namespace impl
inline void throwIfCancelled();
} // namespace impl
class TaskHandle
{
private:
std::weak_ptr<struct TaskSharedState> state_;
public:
TaskHandle() = default;
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
TaskHandle(const TaskHandle&) = default;
TaskHandle(TaskHandle&&) = default;
TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default;
[[nodiscard]] bool isValid() const noexcept
{
return !state_.expired();
}
inline void cancel() const noexcept;
};
struct TaskSharedState
{
std::atomic_bool cancelled_ = false;
TaskHandle subTask;
};
template<typename T>
struct TaskState
{
@ -193,10 +227,11 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
using result_t = typename TTraits::result_t;
TaskState<result_t> state_;
std::shared_ptr<TaskSharedState> sharedState_ = std::make_shared<TaskSharedState>();
TaskLoop* loop_ = nullptr;
constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); }
constexpr std::suspend_always initial_suspend() noexcept { return {}; }
constexpr TaskAwaitableSuspend initial_suspend() noexcept { return {}; }
constexpr std::suspend_always final_suspend() noexcept { return {}; }
// template<typename TValue>
@ -229,7 +264,7 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
auto await_transform(TaskBase<TResultOther> task) noexcept
{
MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!");
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task)); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
return await_transform(future);
}
@ -305,14 +340,38 @@ public:
using handle_t = typename promise_type::handle_t;
private:
handle_t handle_;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Stacktrace creationStack_;
#endif
public:
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {}
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
if (Result<Stacktrace> stacktrace = captureStacktrace(1); stacktrace.isSuccess())
{
creationStack_ = *stacktrace;
}
#endif
}
TaskBase(const TaskBase&) = delete;
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {}
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr))
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
, creationStack_(std::move(other.creationStack_))
#endif
{}
~TaskBase() noexcept;
public:
TaskBase& operator=(const TaskBase&) = default;
TaskBase& operator=(TaskBase&& other) noexcept = default;
TaskBase& operator=(const TaskBase&) = delete;
TaskBase& operator=(TaskBase&& other) noexcept
{
if (handle_) {
handle_.destroy();
}
handle_ = std::exchange(other.handle_, nullptr);
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
creationStack_ = std::move(other.creationStack_);
#endif
return *this;
}
[[nodiscard]]
constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; }
@ -331,6 +390,10 @@ public:
handle_.resume();
return state();
}
constexpr std::shared_ptr<TaskSharedState>& sharedState() noexcept
{
return handle_.promise().sharedState_;
}
private:
[[nodiscard]]
constexpr handle_t handle() const noexcept { return handle_; }
@ -365,6 +428,7 @@ public:
virtual void* raw() noexcept = 0;
virtual std::coroutine_handle<> handle() noexcept = 0;
virtual void setLoop(TaskLoop* loop) noexcept = 0;
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept = 0;
[[nodiscard]] inline bool canResume() {
const TaskStatus stat = status();
@ -400,38 +464,7 @@ public:
void* raw() noexcept override { return &task_; }
std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
};
struct TaskSharedState
{
std::atomic_bool cancelled_ = false;
};
class TaskHandle
{
private:
std::weak_ptr<TaskSharedState> state_;
public:
TaskHandle() = default;
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
TaskHandle(const TaskHandle&) = default;
TaskHandle(TaskHandle&&) = default;
TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default;
[[nodiscard]] bool isValid() const noexcept
{
return !state_.expired();
}
void cancel() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
state->cancelled_ = true;
}
}
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept { return task_.sharedState(); }
};
template<typename TTask>
@ -450,7 +483,6 @@ public:
using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
struct StoredTask
{
std::shared_ptr<TaskSharedState> sharedState;
wrapped_task_base_ptr_t task;
std::function<void(StoredTask&)> setFuture;
std::any resultData;
@ -546,13 +578,22 @@ extern thread_local TaskLoop::StoredTask* gCurrentTask;
inline void throwIfCancelled()
{
if (gCurrentTask->sharedState->cancelled_)
if (gCurrentTask->task->sharedState()->cancelled_)
{
throw TaskCancelled();
}
}
}
void TaskHandle::cancel() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
state->cancelled_ = true;
state->subTask.cancel();
}
}
template<typename TResult>
TaskBase<TResult>::~TaskBase() noexcept
{
@ -568,18 +609,16 @@ inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task, TaskHandle*
MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
task.setLoop(this);
auto sharedState = std::make_shared<TaskSharedState>();
auto future = std::make_shared<Future<TResult>>();
auto setFuture = &setFutureHelper<TResult>;
if (outHandle != nullptr)
{
*outHandle = TaskHandle(sharedState);
*outHandle = TaskHandle(task.sharedState());
}
// add tasks to a seperate vector first as we might be running another task right now
addStoredTask(StoredTask{
.sharedState = std::move(sharedState),
.task = wrapTask(std::move(task)),
.setFuture = setFuture,
.resultData = future
@ -703,6 +742,11 @@ inline auto SimpleTaskLoop::tick() -> CanContinue
{
StoredTask& task = *currentTask_;
TaskStatus status = task.task->status();
if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_)
{
// always continue a cancelled task, even if it was still waiting for a result
status = TaskStatus::SUSPENDED;
}
if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
{
MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");