Also cancel sub-tasks (those that are awaited by this one) when cancelling a task.
This commit is contained in:
parent
d98e14285b
commit
065181fc69
@ -180,7 +180,7 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!");
|
MIJIN_ASSERT_FATAL(impl::gCurrentTask != nullptr, "Trying to call transferCurrentTask() while not running a task!");
|
||||||
|
|
||||||
// now start the transfer, first disown the task
|
// now start the transfer, first disown the task
|
||||||
StoredTask storedTask = std::move(*impl::gCurrentTask);
|
StoredTask storedTask = std::move(*impl::gCurrentTask);
|
||||||
|
@ -4,6 +4,15 @@
|
|||||||
#ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
|
#ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
|
||||||
#define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1
|
#define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1
|
||||||
|
|
||||||
|
|
||||||
|
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
|
||||||
|
#if defined(MIJIN_DEBUG)
|
||||||
|
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 1
|
||||||
|
#else
|
||||||
|
# define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <any>
|
#include <any>
|
||||||
#include <coroutine>
|
#include <coroutine>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
@ -15,6 +24,9 @@
|
|||||||
#include "../container/optional.hpp"
|
#include "../container/optional.hpp"
|
||||||
#include "../util/flag.hpp"
|
#include "../util/flag.hpp"
|
||||||
#include "../util/traits.hpp"
|
#include "../util/traits.hpp"
|
||||||
|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||||
|
#include "../debug/stacktrace.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mijin
|
namespace mijin
|
||||||
{
|
{
|
||||||
@ -23,10 +35,6 @@ namespace mijin
|
|||||||
// public defines
|
// public defines
|
||||||
//
|
//
|
||||||
|
|
||||||
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
|
|
||||||
#define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// public types
|
// public types
|
||||||
//
|
//
|
||||||
@ -56,6 +64,32 @@ namespace impl
|
|||||||
inline void throwIfCancelled();
|
inline void throwIfCancelled();
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
|
||||||
|
class TaskHandle
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
std::weak_ptr<struct TaskSharedState> state_;
|
||||||
|
public:
|
||||||
|
TaskHandle() = default;
|
||||||
|
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
|
||||||
|
TaskHandle(const TaskHandle&) = default;
|
||||||
|
TaskHandle(TaskHandle&&) = default;
|
||||||
|
|
||||||
|
TaskHandle& operator=(const TaskHandle&) = default;
|
||||||
|
TaskHandle& operator=(TaskHandle&&) = default;
|
||||||
|
|
||||||
|
[[nodiscard]] bool isValid() const noexcept
|
||||||
|
{
|
||||||
|
return !state_.expired();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void cancel() const noexcept;
|
||||||
|
};
|
||||||
|
struct TaskSharedState
|
||||||
|
{
|
||||||
|
std::atomic_bool cancelled_ = false;
|
||||||
|
TaskHandle subTask;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct TaskState
|
struct TaskState
|
||||||
{
|
{
|
||||||
@ -193,10 +227,11 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
|
|||||||
using result_t = typename TTraits::result_t;
|
using result_t = typename TTraits::result_t;
|
||||||
|
|
||||||
TaskState<result_t> state_;
|
TaskState<result_t> state_;
|
||||||
|
std::shared_ptr<TaskSharedState> sharedState_ = std::make_shared<TaskSharedState>();
|
||||||
TaskLoop* loop_ = nullptr;
|
TaskLoop* loop_ = nullptr;
|
||||||
|
|
||||||
constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); }
|
constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); }
|
||||||
constexpr std::suspend_always initial_suspend() noexcept { return {}; }
|
constexpr TaskAwaitableSuspend initial_suspend() noexcept { return {}; }
|
||||||
constexpr std::suspend_always final_suspend() noexcept { return {}; }
|
constexpr std::suspend_always final_suspend() noexcept { return {}; }
|
||||||
|
|
||||||
// template<typename TValue>
|
// template<typename TValue>
|
||||||
@ -229,7 +264,7 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
|
|||||||
auto await_transform(TaskBase<TResultOther> task) noexcept
|
auto await_transform(TaskBase<TResultOther> task) noexcept
|
||||||
{
|
{
|
||||||
MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!");
|
MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!");
|
||||||
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task)); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
|
auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
|
||||||
return await_transform(future);
|
return await_transform(future);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,14 +340,38 @@ public:
|
|||||||
using handle_t = typename promise_type::handle_t;
|
using handle_t = typename promise_type::handle_t;
|
||||||
private:
|
private:
|
||||||
handle_t handle_;
|
handle_t handle_;
|
||||||
|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||||
|
Stacktrace creationStack_;
|
||||||
|
#endif
|
||||||
public:
|
public:
|
||||||
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {}
|
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {
|
||||||
|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||||
|
if (Result<Stacktrace> stacktrace = captureStacktrace(1); stacktrace.isSuccess())
|
||||||
|
{
|
||||||
|
creationStack_ = *stacktrace;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
TaskBase(const TaskBase&) = delete;
|
TaskBase(const TaskBase&) = delete;
|
||||||
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {}
|
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr))
|
||||||
|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||||
|
, creationStack_(std::move(other.creationStack_))
|
||||||
|
#endif
|
||||||
|
{}
|
||||||
~TaskBase() noexcept;
|
~TaskBase() noexcept;
|
||||||
public:
|
public:
|
||||||
TaskBase& operator=(const TaskBase&) = default;
|
TaskBase& operator=(const TaskBase&) = delete;
|
||||||
TaskBase& operator=(TaskBase&& other) noexcept = default;
|
TaskBase& operator=(TaskBase&& other) noexcept
|
||||||
|
{
|
||||||
|
if (handle_) {
|
||||||
|
handle_.destroy();
|
||||||
|
}
|
||||||
|
handle_ = std::exchange(other.handle_, nullptr);
|
||||||
|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
|
||||||
|
creationStack_ = std::move(other.creationStack_);
|
||||||
|
#endif
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]]
|
[[nodiscard]]
|
||||||
constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; }
|
constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; }
|
||||||
@ -331,6 +390,10 @@ public:
|
|||||||
handle_.resume();
|
handle_.resume();
|
||||||
return state();
|
return state();
|
||||||
}
|
}
|
||||||
|
constexpr std::shared_ptr<TaskSharedState>& sharedState() noexcept
|
||||||
|
{
|
||||||
|
return handle_.promise().sharedState_;
|
||||||
|
}
|
||||||
private:
|
private:
|
||||||
[[nodiscard]]
|
[[nodiscard]]
|
||||||
constexpr handle_t handle() const noexcept { return handle_; }
|
constexpr handle_t handle() const noexcept { return handle_; }
|
||||||
@ -365,6 +428,7 @@ public:
|
|||||||
virtual void* raw() noexcept = 0;
|
virtual void* raw() noexcept = 0;
|
||||||
virtual std::coroutine_handle<> handle() noexcept = 0;
|
virtual std::coroutine_handle<> handle() noexcept = 0;
|
||||||
virtual void setLoop(TaskLoop* loop) noexcept = 0;
|
virtual void setLoop(TaskLoop* loop) noexcept = 0;
|
||||||
|
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept = 0;
|
||||||
|
|
||||||
[[nodiscard]] inline bool canResume() {
|
[[nodiscard]] inline bool canResume() {
|
||||||
const TaskStatus stat = status();
|
const TaskStatus stat = status();
|
||||||
@ -400,38 +464,7 @@ public:
|
|||||||
void* raw() noexcept override { return &task_; }
|
void* raw() noexcept override { return &task_; }
|
||||||
std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
|
std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
|
||||||
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
|
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
|
||||||
};
|
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept { return task_.sharedState(); }
|
||||||
|
|
||||||
struct TaskSharedState
|
|
||||||
{
|
|
||||||
std::atomic_bool cancelled_ = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TaskHandle
|
|
||||||
{
|
|
||||||
private:
|
|
||||||
std::weak_ptr<TaskSharedState> state_;
|
|
||||||
public:
|
|
||||||
TaskHandle() = default;
|
|
||||||
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
|
|
||||||
TaskHandle(const TaskHandle&) = default;
|
|
||||||
TaskHandle(TaskHandle&&) = default;
|
|
||||||
|
|
||||||
TaskHandle& operator=(const TaskHandle&) = default;
|
|
||||||
TaskHandle& operator=(TaskHandle&&) = default;
|
|
||||||
|
|
||||||
[[nodiscard]] bool isValid() const noexcept
|
|
||||||
{
|
|
||||||
return !state_.expired();
|
|
||||||
}
|
|
||||||
|
|
||||||
void cancel() const noexcept
|
|
||||||
{
|
|
||||||
if (std::shared_ptr<TaskSharedState> state = state_.lock())
|
|
||||||
{
|
|
||||||
state->cancelled_ = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename TTask>
|
template<typename TTask>
|
||||||
@ -450,7 +483,6 @@ public:
|
|||||||
using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
|
using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
|
||||||
struct StoredTask
|
struct StoredTask
|
||||||
{
|
{
|
||||||
std::shared_ptr<TaskSharedState> sharedState;
|
|
||||||
wrapped_task_base_ptr_t task;
|
wrapped_task_base_ptr_t task;
|
||||||
std::function<void(StoredTask&)> setFuture;
|
std::function<void(StoredTask&)> setFuture;
|
||||||
std::any resultData;
|
std::any resultData;
|
||||||
@ -546,13 +578,22 @@ extern thread_local TaskLoop::StoredTask* gCurrentTask;
|
|||||||
|
|
||||||
inline void throwIfCancelled()
|
inline void throwIfCancelled()
|
||||||
{
|
{
|
||||||
if (gCurrentTask->sharedState->cancelled_)
|
if (gCurrentTask->task->sharedState()->cancelled_)
|
||||||
{
|
{
|
||||||
throw TaskCancelled();
|
throw TaskCancelled();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TaskHandle::cancel() const noexcept
|
||||||
|
{
|
||||||
|
if (std::shared_ptr<TaskSharedState> state = state_.lock())
|
||||||
|
{
|
||||||
|
state->cancelled_ = true;
|
||||||
|
state->subTask.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename TResult>
|
template<typename TResult>
|
||||||
TaskBase<TResult>::~TaskBase() noexcept
|
TaskBase<TResult>::~TaskBase() noexcept
|
||||||
{
|
{
|
||||||
@ -568,18 +609,16 @@ inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task, TaskHandle*
|
|||||||
MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
|
MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
|
||||||
task.setLoop(this);
|
task.setLoop(this);
|
||||||
|
|
||||||
auto sharedState = std::make_shared<TaskSharedState>();
|
|
||||||
auto future = std::make_shared<Future<TResult>>();
|
auto future = std::make_shared<Future<TResult>>();
|
||||||
auto setFuture = &setFutureHelper<TResult>;
|
auto setFuture = &setFutureHelper<TResult>;
|
||||||
|
|
||||||
if (outHandle != nullptr)
|
if (outHandle != nullptr)
|
||||||
{
|
{
|
||||||
*outHandle = TaskHandle(sharedState);
|
*outHandle = TaskHandle(task.sharedState());
|
||||||
}
|
}
|
||||||
|
|
||||||
// add tasks to a seperate vector first as we might be running another task right now
|
// add tasks to a seperate vector first as we might be running another task right now
|
||||||
addStoredTask(StoredTask{
|
addStoredTask(StoredTask{
|
||||||
.sharedState = std::move(sharedState),
|
|
||||||
.task = wrapTask(std::move(task)),
|
.task = wrapTask(std::move(task)),
|
||||||
.setFuture = setFuture,
|
.setFuture = setFuture,
|
||||||
.resultData = future
|
.resultData = future
|
||||||
@ -703,6 +742,11 @@ inline auto SimpleTaskLoop::tick() -> CanContinue
|
|||||||
{
|
{
|
||||||
StoredTask& task = *currentTask_;
|
StoredTask& task = *currentTask_;
|
||||||
TaskStatus status = task.task->status();
|
TaskStatus status = task.task->status();
|
||||||
|
if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_)
|
||||||
|
{
|
||||||
|
// always continue a cancelled task, even if it was still waiting for a result
|
||||||
|
status = TaskStatus::SUSPENDED;
|
||||||
|
}
|
||||||
if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
|
if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
|
||||||
{
|
{
|
||||||
MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");
|
MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");
|
||||||
|
Loading…
x
Reference in New Issue
Block a user