Added cancelling (and some exception handling) for coroutines.

This commit is contained in:
Patrick 2023-11-03 00:48:13 +01:00
parent 54c63cfe69
commit 89bb110116
2 changed files with 141 additions and 47 deletions

View File

@ -24,7 +24,10 @@ namespace mijin
// internal variables // internal variables
// //
thread_local TaskLoop::StoredTask* MultiThreadedTaskLoop::currentTask_ = nullptr; namespace impl
{
thread_local TaskLoop::StoredTask* gCurrentTask = nullptr;
}
// //
// internal functions // internal functions
@ -119,9 +122,9 @@ void MultiThreadedTaskLoop::workerThread(std::stop_token stopToken, std::size_t
} }
// run it // run it
currentTask_ = &*task; impl::gCurrentTask = &*task;
tickTask(*task); tickTask(*task);
currentTask_ = nullptr; impl::gCurrentTask = nullptr;
// and give it back // and give it back
returningTasks_.push(std::move(*task)); returningTasks_.push(std::move(*task));
@ -180,8 +183,8 @@ void MultiThreadedTaskLoop::transferCurrentTask(TaskLoop& otherLoop) noexcept
MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!"); MIJIN_ASSERT_FATAL(currentTask_ != nullptr, "Trying to call transferCurrentTask() while not running a task!");
// now start the transfer, first disown the task // now start the transfer, first disown the task
StoredTask storedTask = std::move(*currentTask_); StoredTask storedTask = std::move(*impl::gCurrentTask);
currentTask_->task = nullptr; // just to be sure impl::gCurrentTask->task = nullptr; // just to be sure
// then send it over to the other loop // then send it over to the other loop
otherLoop.addStoredTask(std::move(storedTask)); otherLoop.addStoredTask(std::move(storedTask));

View File

@ -56,11 +56,15 @@ struct TaskReturn
{ {
template<typename... TArgs> template<typename... TArgs>
constexpr void return_value(TArgs&&... args) noexcept { constexpr void return_value(TArgs&&... args) noexcept {
*(static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(TReturn(std::forward<TArgs>(args)...), TaskStatus::FINISHED); (static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(TReturn(std::forward<TArgs>(args)...), TaskStatus::FINISHED);
} }
constexpr void return_value(TReturn value) noexcept { constexpr void return_value(TReturn value) noexcept {
*(static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(TReturn(std::move(value)), TaskStatus::FINISHED); (static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(TReturn(std::move(value)), TaskStatus::FINISHED);
}
constexpr void unhandled_exception() noexcept {
(static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(std::current_exception());
} }
}; };
@ -68,33 +72,45 @@ template<typename TPromise>
struct TaskReturn<void, TPromise> struct TaskReturn<void, TPromise>
{ {
constexpr void return_void() noexcept { constexpr void return_void() noexcept {
static_cast<TPromise&>(*this).state_->status = TaskStatus::FINISHED; static_cast<TPromise&>(*this).state_.status = TaskStatus::FINISHED;
}
constexpr void unhandled_exception() noexcept {
(static_cast<TPromise&>(*this).state_) = TaskState<void>(std::current_exception());
} }
}; };
struct TaskCancelled : std::exception {};
inline void throwIfCancelled();
} // namespace impl } // namespace impl
template<typename T> template<typename T>
struct TaskState struct TaskState
{ {
Optional<T> value; Optional<T> value;
std::exception_ptr exception;
TaskStatus status = TaskStatus::SUSPENDED; TaskStatus status = TaskStatus::SUSPENDED;
TaskState() = default; TaskState() = default;
TaskState(const TaskState&) = default; TaskState(const TaskState&) = default;
TaskState(TaskState&&) noexcept = default; TaskState(TaskState&&) noexcept = default;
constexpr TaskState(T _value, TaskStatus _status) noexcept : value(std::move(_value)), status(_status) {} inline TaskState(T _value, TaskStatus _status) noexcept : value(std::move(_value)), status(_status) {}
inline TaskState(std::exception_ptr _exception) noexcept : exception(std::move(_exception)), status(TaskStatus::FINISHED) {}
TaskState& operator=(const TaskState&) = default; TaskState& operator=(const TaskState&) = default;
TaskState& operator=(TaskState&&) noexcept = default; TaskState& operator=(TaskState&&) noexcept = default;
}; };
template<> template<>
struct TaskState<void> struct TaskState<void>
{ {
std::exception_ptr exception;
TaskStatus status = TaskStatus::SUSPENDED; TaskStatus status = TaskStatus::SUSPENDED;
TaskState() = default; TaskState() = default;
TaskState(const TaskState&) = default; TaskState(const TaskState&) = default;
TaskState(TaskState&&) noexcept = default; TaskState(TaskState&&) noexcept = default;
constexpr TaskState(TaskStatus _status) noexcept : status(_status) {} inline TaskState(TaskStatus _status) noexcept : status(_status) {}
inline TaskState(std::exception_ptr _exception) noexcept : exception(std::move(_exception)), status(TaskStatus::FINISHED) {}
TaskState& operator=(const TaskState&) = default; TaskState& operator=(const TaskState&) = default;
TaskState& operator=(TaskState&&) noexcept = default; TaskState& operator=(TaskState&&) noexcept = default;
}; };
@ -106,7 +122,9 @@ struct TaskAwaitableFuture
[[nodiscard]] constexpr bool await_ready() const noexcept { return future->ready(); } [[nodiscard]] constexpr bool await_ready() const noexcept { return future->ready(); }
constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} constexpr void await_suspend(std::coroutine_handle<>) const noexcept {}
constexpr TValue await_resume() const noexcept { constexpr TValue await_resume() const noexcept
{
impl::throwIfCancelled();
if constexpr (std::is_same_v<TValue, void>) { if constexpr (std::is_same_v<TValue, void>) {
return; return;
} }
@ -123,7 +141,9 @@ struct TaskAwaitableSignal
[[nodiscard]] constexpr bool await_ready() const noexcept { return false; } [[nodiscard]] constexpr bool await_ready() const noexcept { return false; }
constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} constexpr void await_suspend(std::coroutine_handle<>) const noexcept {}
constexpr auto& await_resume() const noexcept { inline auto& await_resume() const
{
impl::throwIfCancelled();
return *data; return *data;
} }
}; };
@ -135,7 +155,9 @@ struct TaskAwaitableSignal<TSingleArg>
[[nodiscard]] constexpr bool await_ready() const noexcept { return false; } [[nodiscard]] constexpr bool await_ready() const noexcept { return false; }
constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} constexpr void await_suspend(std::coroutine_handle<>) const noexcept {}
constexpr auto& await_resume() const noexcept { constexpr auto& await_resume() const
{
impl::throwIfCancelled();
return *data; return *data;
} }
}; };
@ -145,7 +167,18 @@ struct TaskAwaitableSignal<>
{ {
[[nodiscard]] constexpr bool await_ready() const noexcept { return false; } [[nodiscard]] constexpr bool await_ready() const noexcept { return false; }
constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} constexpr void await_suspend(std::coroutine_handle<>) const noexcept {}
constexpr void await_resume() const noexcept {} inline void await_resume() const {
impl::throwIfCancelled();
}
};
struct TaskAwaitableSuspend
{
[[nodiscard]] constexpr bool await_ready() const noexcept { return false; }
constexpr void await_suspend(std::coroutine_handle<>) const noexcept {}
inline void await_resume() const {
impl::throwIfCancelled();
}
}; };
template<typename TTraits> template<typename TTraits>
@ -155,7 +188,7 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
using task_t = typename TTraits::task_t; using task_t = typename TTraits::task_t;
using result_t = typename TTraits::result_t; using result_t = typename TTraits::result_t;
std::shared_ptr<TaskState<result_t>> state_ = std::make_shared<TaskState<result_t>>(); TaskState<result_t> state_;
TaskLoop* loop_ = nullptr; TaskLoop* loop_ = nullptr;
constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); } constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); }
@ -170,7 +203,7 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
// TODO: implement yielding (can't use futures for this) // TODO: implement yielding (can't use futures for this)
constexpr void unhandled_exception() noexcept {} // constexpr void unhandled_exception() noexcept {}
template<typename TValue> template<typename TValue>
auto await_transform(FuturePtr<TValue> future) noexcept auto await_transform(FuturePtr<TValue> future) noexcept
@ -179,10 +212,10 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
TaskAwaitableFuture<TValue> awaitable{future}; TaskAwaitableFuture<TValue> awaitable{future};
if (!awaitable.await_ready()) if (!awaitable.await_ready())
{ {
state_->status = TaskStatus::WAITING; state_.status = TaskStatus::WAITING;
future->sigSet.connect([this, future]() mutable future->sigSet.connect([this, future]() mutable
{ {
state_->status = TaskStatus::SUSPENDED; state_.status = TaskStatus::SUSPENDED;
}, Oneshot::YES); }, Oneshot::YES);
} }
return awaitable; return awaitable;
@ -203,10 +236,10 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
signal.connect([this, data](TFirstArg arg0, TSecondArg arg1, TArgs... args) mutable signal.connect([this, data](TFirstArg arg0, TSecondArg arg1, TArgs... args) mutable
{ {
*data = std::make_tuple(std::move(arg0), std::move(arg1), std::move(args)...); *data = std::make_tuple(std::move(arg0), std::move(arg1), std::move(args)...);
state_->status = TaskStatus::SUSPENDED; state_.status = TaskStatus::SUSPENDED;
}, Oneshot::YES); }, Oneshot::YES);
TaskAwaitableSignal<TFirstArg, TSecondArg, TArgs...> awaitable{data}; TaskAwaitableSignal<TFirstArg, TSecondArg, TArgs...> awaitable{data};
state_->status = TaskStatus::WAITING; state_.status = TaskStatus::WAITING;
return awaitable; return awaitable;
} }
@ -217,10 +250,10 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
signal.connect([this, data](TFirstArg arg0) mutable signal.connect([this, data](TFirstArg arg0) mutable
{ {
*data = std::move(arg0); *data = std::move(arg0);
state_->status = TaskStatus::SUSPENDED; state_.status = TaskStatus::SUSPENDED;
}, Oneshot::YES); }, Oneshot::YES);
TaskAwaitableSignal<TFirstArg> awaitable{data}; TaskAwaitableSignal<TFirstArg> awaitable{data};
state_->status = TaskStatus::WAITING; state_.status = TaskStatus::WAITING;
return awaitable; return awaitable;
} }
@ -228,22 +261,28 @@ struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TT
{ {
signal.connect([this]() signal.connect([this]()
{ {
state_->status = TaskStatus::SUSPENDED; state_.status = TaskStatus::SUSPENDED;
}, Oneshot::YES); }, Oneshot::YES);
TaskAwaitableSignal<> awaitable{}; TaskAwaitableSignal<> awaitable{};
state_->status = TaskStatus::WAITING; state_.status = TaskStatus::WAITING;
return awaitable; return awaitable;
} }
std::suspend_always await_transform(std::suspend_always) noexcept std::suspend_always await_transform(std::suspend_always) noexcept
{ {
state_->status = TaskStatus::SUSPENDED; state_.status = TaskStatus::SUSPENDED;
return std::suspend_always(); return std::suspend_always();
} }
std::suspend_never await_transform(std::suspend_never) noexcept { std::suspend_never await_transform(std::suspend_never) noexcept {
return std::suspend_never(); return std::suspend_never();
} }
TaskAwaitableSuspend await_transform(TaskAwaitableSuspend) noexcept
{
state_.status = TaskStatus::SUSPENDED;
return TaskAwaitableSuspend();
}
}; };
template<typename TResult> template<typename TResult>
@ -262,10 +301,9 @@ public:
using handle_t = typename promise_type::handle_t; using handle_t = typename promise_type::handle_t;
private: private:
handle_t handle_; handle_t handle_;
std::shared_ptr<TaskState<result_t>> state_;
public: public:
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle), state_(handle.promise().state_) {} constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {}
TaskBase(const TaskBase&) = default; TaskBase(const TaskBase&) = delete;
TaskBase(TaskBase&& other) noexcept = default; TaskBase(TaskBase&& other) noexcept = default;
~TaskBase() noexcept; ~TaskBase() noexcept;
public: public:
@ -278,17 +316,16 @@ public:
[[nodiscard]] [[nodiscard]]
constexpr bool operator!=(const TaskBase& other) const noexcept { return handle_ != other.handle_; } constexpr bool operator!=(const TaskBase& other) const noexcept { return handle_ != other.handle_; }
public: public:
constexpr TaskState<TResult>& resume() noexcept
{
state_->status = TaskStatus::RUNNING;
handle_.resume();
return *state_;
}
[[nodiscard]] [[nodiscard]]
constexpr TaskState<TResult>& state() noexcept constexpr TaskState<TResult>& state() noexcept
{ {
return *state_; return handle_.promise().state_;
}
constexpr TaskState<TResult>& resume()
{
state().status = TaskStatus::RUNNING;
handle_.resume();
return state();
} }
private: private:
[[nodiscard]] [[nodiscard]]
@ -318,8 +355,9 @@ public:
virtual ~WrappedTaskBase() = default; virtual ~WrappedTaskBase() = default;
public: public:
virtual TaskStatus status() noexcept = 0; virtual TaskStatus status() noexcept = 0;
virtual std::exception_ptr exception() noexcept = 0;
// virtual std::any result() noexcept = 0; // virtual std::any result() noexcept = 0;
virtual void resume() noexcept = 0; virtual void resume() = 0;
virtual void* raw() noexcept = 0; virtual void* raw() noexcept = 0;
virtual std::coroutine_handle<> handle() noexcept = 0; virtual std::coroutine_handle<> handle() noexcept = 0;
virtual void setLoop(TaskLoop* loop) noexcept = 0; virtual void setLoop(TaskLoop* loop) noexcept = 0;
@ -344,6 +382,7 @@ public:
WrappedTask& operator=(WrappedTask&&) noexcept = default; WrappedTask& operator=(WrappedTask&&) noexcept = default;
public: public:
TaskStatus status() noexcept override { return task_.state().status; } TaskStatus status() noexcept override { return task_.state().status; }
std::exception_ptr exception() noexcept override { return task_.state().exception; }
// std::any result() noexcept override // std::any result() noexcept override
// { // {
// if constexpr (std::is_same_v<typename TTask::result_t, void>) { // if constexpr (std::is_same_v<typename TTask::result_t, void>) {
@ -353,12 +392,39 @@ public:
// return std::any(task_.state().value); // return std::any(task_.state().value);
// } // }
// } // }
void resume() noexcept override { task_.resume(); } void resume() override { task_.resume(); }
void* raw() noexcept override { return &task_; } void* raw() noexcept override { return &task_; }
std::coroutine_handle<> handle() noexcept override { return task_.handle(); } std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
}; };
struct TaskSharedState
{
std::atomic_bool cancelled_ = false;
};
class TaskHandle
{
private:
std::weak_ptr<TaskSharedState> state_;
public:
TaskHandle() = default;
explicit TaskHandle(std::weak_ptr<TaskSharedState> state) noexcept : state_(std::move(state)) {}
TaskHandle(const TaskHandle&) = default;
TaskHandle(TaskHandle&&) = default;
TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default;
void cancel() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
state->cancelled_ = true;
}
}
};
template<typename TTask> template<typename TTask>
std::unique_ptr<WrappedTask<TTask>> wrapTask(TTask&& task) noexcept std::unique_ptr<WrappedTask<TTask>> wrapTask(TTask&& task) noexcept
{ {
@ -370,16 +436,17 @@ class TaskLoop
public: public:
MIJIN_DEFINE_FLAG(CanContinue); MIJIN_DEFINE_FLAG(CanContinue);
MIJIN_DEFINE_FLAG(IgnoreWaiting); MIJIN_DEFINE_FLAG(IgnoreWaiting);
protected:
using wrapped_task_t = WrappedTaskBase; using wrapped_task_t = WrappedTaskBase;
using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>; using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
struct StoredTask struct StoredTask
{ {
std::shared_ptr<TaskSharedState> sharedState;
wrapped_task_base_ptr_t task; wrapped_task_base_ptr_t task;
std::function<void(StoredTask&)> setFuture; std::function<void(StoredTask&)> setFuture;
std::any resultData; std::any resultData;
}; };
protected:
using task_vector_t = std::vector<StoredTask>; using task_vector_t = std::vector<StoredTask>;
template<typename TTask> template<typename TTask>
@ -394,7 +461,7 @@ public:
TaskLoop& operator=(TaskLoop&&) = delete; TaskLoop& operator=(TaskLoop&&) = delete;
template<typename TResult> template<typename TResult>
inline FuturePtr<TResult> addTask(TaskBase<TResult> task) noexcept; inline FuturePtr<TResult> addTask(TaskBase<TResult> task, TaskHandle* outHandle = nullptr) noexcept;
virtual void transferCurrentTask(TaskLoop& otherLoop) noexcept = 0; virtual void transferCurrentTask(TaskLoop& otherLoop) noexcept = 0;
virtual void addStoredTask(StoredTask&& storedTask) noexcept = 0; virtual void addStoredTask(StoredTask&& storedTask) noexcept = 0;
@ -452,14 +519,25 @@ public: // public interface
private: // private stuff private: // private stuff
void managerThread(std::stop_token stopToken); void managerThread(std::stop_token stopToken);
void workerThread(std::stop_token stopToken, std::size_t workerId); void workerThread(std::stop_token stopToken, std::size_t workerId);
static thread_local StoredTask* currentTask_;
}; };
// //
// public functions // public functions
// //
namespace impl
{
extern thread_local TaskLoop::StoredTask* gCurrentTask;
inline void throwIfCancelled()
{
if (gCurrentTask->sharedState->cancelled_)
{
throw TaskCancelled();
}
}
}
template<typename TResult> template<typename TResult>
TaskBase<TResult>::~TaskBase() noexcept TaskBase<TResult>::~TaskBase() noexcept
{ {
@ -470,17 +548,23 @@ TaskBase<TResult>::~TaskBase() noexcept
} }
template<typename TResult> template<typename TResult>
inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task) noexcept inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task, TaskHandle* outHandle) noexcept
{ {
MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
task.setLoop(this); task.setLoop(this);
auto sharedState = std::make_shared<TaskSharedState>();
auto future = std::make_shared<Future<TResult>>(); auto future = std::make_shared<Future<TResult>>();
auto setFuture = &setFutureHelper<TResult>; auto setFuture = &setFutureHelper<TResult>;
if (outHandle != nullptr)
{
*outHandle = TaskHandle(sharedState);
}
// add tasks to a seperate vector first as we might be running another task right now // add tasks to a seperate vector first as we might be running another task right now
addStoredTask(StoredTask{ addStoredTask(StoredTask{
.sharedState = std::move(sharedState),
.task = wrapTask(std::move(task)), .task = wrapTask(std::move(task)),
.setFuture = setFuture, .setFuture = setFuture,
.resultData = future .resultData = future
@ -492,13 +576,20 @@ inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task) noexcept
inline TaskStatus TaskLoop::tickTask(StoredTask& task) noexcept inline TaskStatus TaskLoop::tickTask(StoredTask& task) noexcept
{ {
TaskStatus status = {}; TaskStatus status = {};
impl::gCurrentTask = &task;
do do
{ {
task.task->resume(); task.task->resume();
status = task.task ? task.task->status() : TaskStatus::WAITING; // no inner task -> task switch context (and will be removed later) status = task.task ? task.task->status() : TaskStatus::WAITING; // no inner task -> task switch context (and will be removed later)
} }
while (status == TaskStatus::RUNNING); while (status == TaskStatus::RUNNING);
impl::gCurrentTask = nullptr;
if (task.task && task.task->exception())
{
// TODO: handle the exception somehow, others may be waiting
return TaskStatus::FINISHED;
}
if (status == TaskStatus::YIELDED || status == TaskStatus::FINISHED) if (status == TaskStatus::YIELDED || status == TaskStatus::FINISHED)
{ {
task.setFuture(task); task.setFuture(task);
@ -619,8 +710,8 @@ inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) noexcept
// utility stuff // utility stuff
inline std::suspend_always c_suspend() { inline TaskAwaitableSuspend c_suspend() {
return std::suspend_always(); return TaskAwaitableSuspend();
} }
template<template<typename...> typename TCollection, typename TType, typename... TTemplateArgs> template<template<typename...> typename TCollection, typename TType, typename... TTemplateArgs>