Also cancel sub-tasks (those that are awaited by this one) when cancelling a task.
This commit is contained in:
parent
d98e14285b
commit
065181fc69
@ -180,7 +180,7 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept
|
||||
return;
|
||||
}
|
||||
|
||||
MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!");
|
||||
MIJIN_ASSERT_FATAL(impl::gCurrentTask != nullptr, "Trying to call transferCurrentTask() while not running a task!");
|
||||
|
||||
// now start the transfer, first disown the task
|
||||
StoredTask storedTask = std::move(*impl::gCurrentTask);
|
||||
|
@ -4,6 +4,15 @@
|
||||
#ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
|
||||
#define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1
|
||||
|
||||
|
||||
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
|
||||
#if defined(MIJIN_DEBUG)
|
||||
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 1
|
||||
#else
|
||||
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <any>
|
||||
#include <coroutine>
|
||||
#include <memory>
|
||||
@ -15,6 +24,9 @@
|
||||
#include "../container/optional.hpp"
|
||||
#include "../util/flag.hpp"
|
||||
#include "../util/traits.hpp"
|
||||
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||
#include "../debug/stacktrace.hpp"
|
||||
#endif
|
||||
|
||||
namespace mijin
|
||||
{
|
||||
@ -23,10 +35,6 @@ namespace mijin
|
||||
// public defines
|
||||
//
|
||||
|
||||
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
|
||||
#define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
|
||||
#endif
|
||||
|
||||
//
|
||||
// public types
|
||||
//
|
||||
@ -56,6 +64,32 @@ namespace impl
|
||||
inline void throwIfCancelled();
|
||||
} // namespace impl
|
||||
|
||||
class TaskHandle
|
||||
{
|
||||
private:
|
||||
std::weak_ptr<struct TaskSharedState> state_;
|
||||
public:
|
||||
TaskHandle() = default;
|
||||
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
|
||||
TaskHandle(const TaskHandle&) = default;
|
||||
TaskHandle(TaskHandle&&) = default;
|
||||
|
||||
TaskHandle& operator=(const TaskHandle&) = default;
|
||||
TaskHandle& operator=(TaskHandle&&) = default;
|
||||
|
||||
[[nodiscard]] bool isValid() const noexcept
|
||||
{
|
||||
return !state_.expired();
|
||||
}
|
||||
|
||||
inline void cancel() const noexcept;
|
||||
};
|
||||
struct TaskSharedState
|
||||
{
|
||||
std::atomic_bool cancelled_ = false;
|
||||
TaskHandle subTask;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct TaskState
|
||||
{
|
||||
@ -193,10 +227,11 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
|
||||
using result_t = typename TTraits::result_t;
|
||||
|
||||
TaskState<result_t> state_;
|
||||
std::shared_ptr<TaskSharedState> sharedState_ = std::make_shared<TaskSharedState>();
|
||||
TaskLoop* loop_ = nullptr;
|
||||
|
||||
constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); }
|
||||
constexpr std::suspend_always initial_suspend() noexcept { return {}; }
|
||||
constexpr TaskAwaitableSuspend initial_suspend() noexcept { return {}; }
|
||||
constexpr std::suspend_always final_suspend() noexcept { return {}; }
|
||||
|
||||
// template<typename TValue>
|
||||
@ -229,7 +264,7 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
|
||||
auto await_transform(TaskBase<TResultOther> task) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!");
|
||||
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task)); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
|
||||
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
|
||||
return await_transform(future);
|
||||
}
|
||||
|
||||
@ -305,14 +340,38 @@ public:
|
||||
using handle_t = typename promise_type::handle_t;
|
||||
private:
|
||||
handle_t handle_;
|
||||
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||
Stacktrace creationStack_;
|
||||
#endif
|
||||
public:
|
||||
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {}
|
||||
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {
|
||||
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||
if (Result<Stacktrace> stacktrace = captureStacktrace(1); stacktrace.isSuccess())
|
||||
{
|
||||
creationStack_ = *stacktrace;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
TaskBase(const TaskBase&) = delete;
|
||||
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {}
|
||||
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr))
|
||||
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||
, creationStack_(std::move(other.creationStack_))
|
||||
#endif
|
||||
{}
|
||||
~TaskBase() noexcept;
|
||||
public:
|
||||
TaskBase& operator=(const TaskBase&) = default;
|
||||
TaskBase& operator=(TaskBase&& other) noexcept = default;
|
||||
TaskBase& operator=(const TaskBase&) = delete;
|
||||
TaskBase& operator=(TaskBase&& other) noexcept
|
||||
{
|
||||
if (handle_) {
|
||||
handle_.destroy();
|
||||
}
|
||||
handle_ = std::exchange(other.handle_, nullptr);
|
||||
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||
creationStack_ = std::move(other.creationStack_);
|
||||
#endif
|
||||
return *this;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; }
|
||||
@ -331,6 +390,10 @@ public:
|
||||
handle_.resume();
|
||||
return state();
|
||||
}
|
||||
constexpr std::shared_ptr<TaskSharedState>& sharedState() noexcept
|
||||
{
|
||||
return handle_.promise().sharedState_;
|
||||
}
|
||||
private:
|
||||
[[nodiscard]]
|
||||
constexpr handle_t handle() const noexcept { return handle_; }
|
||||
@ -365,6 +428,7 @@ public:
|
||||
virtual void* raw() noexcept = 0;
|
||||
virtual std::coroutine_handle<> handle() noexcept = 0;
|
||||
virtual void setLoop(TaskLoop* loop) noexcept = 0;
|
||||
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept = 0;
|
||||
|
||||
[[nodiscard]] inline bool canResume() {
|
||||
const TaskStatus stat = status();
|
||||
@ -400,38 +464,7 @@ public:
|
||||
void* raw() noexcept override { return &task_; }
|
||||
std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
|
||||
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
|
||||
};
|
||||
|
||||
struct TaskSharedState
|
||||
{
|
||||
std::atomic_bool cancelled_ = false;
|
||||
};
|
||||
|
||||
class TaskHandle
|
||||
{
|
||||
private:
|
||||
std::weak_ptr<TaskSharedState> state_;
|
||||
public:
|
||||
TaskHandle() = default;
|
||||
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
|
||||
TaskHandle(const TaskHandle&) = default;
|
||||
TaskHandle(TaskHandle&&) = default;
|
||||
|
||||
TaskHandle& operator=(const TaskHandle&) = default;
|
||||
TaskHandle& operator=(TaskHandle&&) = default;
|
||||
|
||||
[[nodiscard]] bool isValid() const noexcept
|
||||
{
|
||||
return !state_.expired();
|
||||
}
|
||||
|
||||
void cancel() const noexcept
|
||||
{
|
||||
if (std::shared_ptr<TaskSharedState> state = state_.lock())
|
||||
{
|
||||
state->cancelled_ = true;
|
||||
}
|
||||
}
|
||||
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept { return task_.sharedState(); }
|
||||
};
|
||||
|
||||
template<typename TTask>
|
||||
@ -450,7 +483,6 @@ public:
|
||||
using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
|
||||
struct StoredTask
|
||||
{
|
||||
std::shared_ptr<TaskSharedState> sharedState;
|
||||
wrapped_task_base_ptr_t task;
|
||||
std::function<void(StoredTask&)> setFuture;
|
||||
std::any resultData;
|
||||
@ -546,13 +578,22 @@ extern thread_local TaskLoop::StoredTask* gCurrentTask;
|
||||
|
||||
inline void throwIfCancelled()
|
||||
{
|
||||
if (gCurrentTask->sharedState->cancelled_)
|
||||
if (gCurrentTask->task->sharedState()->cancelled_)
|
||||
{
|
||||
throw TaskCancelled();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TaskHandle::cancel() const noexcept
|
||||
{
|
||||
if (std::shared_ptr<TaskSharedState> state = state_.lock())
|
||||
{
|
||||
state->cancelled_ = true;
|
||||
state->subTask.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TResult>
|
||||
TaskBase<TResult>::~TaskBase() noexcept
|
||||
{
|
||||
@ -568,18 +609,16 @@ inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task, TaskHandle*
|
||||
MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
|
||||
task.setLoop(this);
|
||||
|
||||
auto sharedState = std::make_shared<TaskSharedState>();
|
||||
auto future = std::make_shared<Future<TResult>>();
|
||||
auto setFuture = &setFutureHelper<TResult>;
|
||||
|
||||
if (outHandle != nullptr)
|
||||
{
|
||||
*outHandle = TaskHandle(sharedState);
|
||||
*outHandle = TaskHandle(task.sharedState());
|
||||
}
|
||||
|
||||
// add tasks to a seperate vector first as we might be running another task right now
|
||||
addStoredTask(StoredTask{
|
||||
.sharedState = std::move(sharedState),
|
||||
.task = wrapTask(std::move(task)),
|
||||
.setFuture = setFuture,
|
||||
.resultData = future
|
||||
@ -703,6 +742,11 @@ inline auto SimpleTaskLoop::tick() -> CanContinue
|
||||
{
|
||||
StoredTask& task = *currentTask_;
|
||||
TaskStatus status = task.task->status();
|
||||
if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_)
|
||||
{
|
||||
// always continue a cancelled task, even if it was still waiting for a result
|
||||
status = TaskStatus::SUSPENDED;
|
||||
}
|
||||
if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
|
||||
{
|
||||
MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");
|
||||
|
Loading…
x
Reference in New Issue
Block a user