#pragma once #ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED #define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1 #include #include #include #include #include #include "./future.hpp" #include "./message_queue.hpp" #include "../container/optional.hpp" #include "../util/flag.hpp" #include "../util/traits.hpp" namespace mijin { // // public defines // #if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO) #define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0 #endif // // public types // enum class TaskStatus { SUSPENDED = 0, RUNNING = 1, WAITING = 2, FINISHED = 3, YIELDED = 4 }; // forward declarations template struct TaskState; class TaskLoop; template class TaskBase; namespace impl { struct TaskCancelled : std::exception {}; inline void throwIfCancelled(); } // namespace impl template struct TaskState { Optional value; std::exception_ptr exception; TaskStatus status = TaskStatus::SUSPENDED; TaskState() = default; TaskState(const TaskState&) = default; TaskState(TaskState&&) noexcept = default; inline TaskState(T _value, TaskStatus _status) noexcept : value(std::move(_value)), status(_status) {} inline TaskState(std::exception_ptr _exception) noexcept : exception(std::move(_exception)), status(TaskStatus::FINISHED) {} TaskState& operator=(const TaskState&) = default; TaskState& operator=(TaskState&&) noexcept = default; }; template<> struct TaskState { std::exception_ptr exception; TaskStatus status = TaskStatus::SUSPENDED; TaskState() = default; TaskState(const TaskState&) = default; TaskState(TaskState&&) noexcept = default; inline TaskState(TaskStatus _status) noexcept : status(_status) {} inline TaskState(std::exception_ptr _exception) noexcept : exception(std::move(_exception)), status(TaskStatus::FINISHED) {} TaskState& operator=(const TaskState&) = default; TaskState& operator=(TaskState&&) noexcept = default; }; namespace impl { template struct TaskReturn { template constexpr void return_value(TArgs&&... args) noexcept { (static_cast(*this).state_) = TaskState(TReturn(std::forward(args)...), TaskStatus::FINISHED); } constexpr void return_value(TReturn value) noexcept { (static_cast(*this).state_) = TaskState(TReturn(std::move(value)), TaskStatus::FINISHED); } constexpr void unhandled_exception() noexcept { (static_cast(*this).state_) = TaskState(std::current_exception()); } }; template struct TaskReturn { constexpr void return_void() noexcept { static_cast(*this).state_.status = TaskStatus::FINISHED; } constexpr void unhandled_exception() noexcept { (static_cast(*this).state_) = TaskState(std::current_exception()); } }; } template struct TaskAwaitableFuture { FuturePtr future; [[nodiscard]] constexpr bool await_ready() const noexcept { return future->ready(); } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} constexpr TValue await_resume() const noexcept { impl::throwIfCancelled(); if constexpr (std::is_same_v) { return; } else { return std::move(future->get()); } } }; template struct TaskAwaitableSignal { std::shared_ptr> data; [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} inline auto& await_resume() const { impl::throwIfCancelled(); return *data; } }; template struct TaskAwaitableSignal { std::shared_ptr data; [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} constexpr auto& await_resume() const { impl::throwIfCancelled(); return *data; } }; template<> struct TaskAwaitableSignal<> { [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} inline void await_resume() const { impl::throwIfCancelled(); } }; struct TaskAwaitableSuspend { [[nodiscard]] constexpr bool await_ready() const noexcept { return false; } constexpr void await_suspend(std::coroutine_handle<>) const noexcept {} inline void await_resume() const { impl::throwIfCancelled(); } }; template struct TaskPromise : impl::TaskReturn> { using handle_t = std::coroutine_handle; using task_t = typename TTraits::task_t; using result_t = typename TTraits::result_t; TaskState state_; TaskLoop* loop_ = nullptr; constexpr task_t get_return_object() noexcept { return task_t(handle_t::from_promise(*this)); } constexpr std::suspend_always initial_suspend() noexcept { return {}; } constexpr std::suspend_always final_suspend() noexcept { return {}; } // template // constexpr std::suspend_always yield_value(TValue value) noexcept { // *state_ = TaskState(std::move(value), TaskStatus::YIELDED); // return {}; // } // TODO: implement yielding (can't use futures for this) // constexpr void unhandled_exception() noexcept {} template auto await_transform(FuturePtr future) noexcept { MIJIN_ASSERT(loop_ != nullptr, "Cannot await future outside of a loop!"); TaskAwaitableFuture awaitable{future}; if (!awaitable.await_ready()) { state_.status = TaskStatus::WAITING; future->sigSet.connect([this, future]() mutable { state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); } return awaitable; } template auto await_transform(TaskBase task) noexcept { MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!"); auto future = delayEvaluation(loop_)->addTask(std::move(task)); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here return await_transform(future); } template auto await_transform(Signal& signal) noexcept { auto data = std::make_shared>(); signal.connect([this, data](TFirstArg arg0, TSecondArg arg1, TArgs... args) mutable { *data = std::make_tuple(std::move(arg0), std::move(arg1), std::move(args)...); state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal awaitable{data}; state_.status = TaskStatus::WAITING; return awaitable; } template auto await_transform(Signal& signal) noexcept { auto data = std::make_shared(); signal.connect([this, data](TFirstArg arg0) mutable { *data = std::move(arg0); state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal awaitable{data}; state_.status = TaskStatus::WAITING; return awaitable; } auto await_transform(Signal<>& signal) noexcept { signal.connect([this]() { state_.status = TaskStatus::SUSPENDED; }, Oneshot::YES); TaskAwaitableSignal<> awaitable{}; state_.status = TaskStatus::WAITING; return awaitable; } std::suspend_always await_transform(std::suspend_always) noexcept { state_.status = TaskStatus::SUSPENDED; return std::suspend_always(); } std::suspend_never await_transform(std::suspend_never) noexcept { return std::suspend_never(); } TaskAwaitableSuspend await_transform(TaskAwaitableSuspend) noexcept { state_.status = TaskStatus::SUSPENDED; return TaskAwaitableSuspend(); } }; template class TaskBase { public: using task_t = TaskBase; using result_t = TResult; struct Traits { using task_t = TaskBase; using result_t = TResult; }; public: using promise_type = TaskPromise; using handle_t = typename promise_type::handle_t; private: handle_t handle_; public: constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {} TaskBase(const TaskBase&) = delete; TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {} ~TaskBase() noexcept; public: TaskBase& operator=(const TaskBase&) = default; TaskBase& operator=(TaskBase&& other) noexcept = default; [[nodiscard]] constexpr bool operator==(const TaskBase& other) const noexcept { return handle_ == other.handle_; } [[nodiscard]] constexpr bool operator!=(const TaskBase& other) const noexcept { return handle_ != other.handle_; } public: [[nodiscard]] constexpr TaskState& state() noexcept { return handle_.promise().state_; } constexpr TaskState& resume() { state().status = TaskStatus::RUNNING; handle_.resume(); return state(); } private: [[nodiscard]] constexpr handle_t handle() const noexcept { return handle_; } [[nodiscard]] constexpr TaskLoop* getLoop() noexcept { return handle_.promise().loop_; } constexpr void setLoop(TaskLoop* loop) noexcept { // MIJIN_ASSERT(handle_.promise().loop_ == nullptr // || handle_.promise().loop_ == loop // || loop == nullptr, "Task already has a loop assigned!"); handle_.promise().loop_ = loop; } friend class TaskLoop; template friend class WrappedTask; }; class WrappedTaskBase { public: virtual ~WrappedTaskBase() = default; public: virtual TaskStatus status() noexcept = 0; virtual std::exception_ptr exception() noexcept = 0; // virtual std::any result() noexcept = 0; virtual void resume() = 0; virtual void* raw() noexcept = 0; virtual std::coroutine_handle<> handle() noexcept = 0; virtual void setLoop(TaskLoop* loop) noexcept = 0; [[nodiscard]] inline bool canResume() { const TaskStatus stat = status(); return (stat == TaskStatus::SUSPENDED || stat == TaskStatus::YIELDED); } }; template class WrappedTask : public WrappedTaskBase { private: TTask task_; public: constexpr explicit WrappedTask(TTask&& task) noexcept : task_(std::move(task)) {} WrappedTask(const WrappedTask&) = delete; WrappedTask(WrappedTask&&) noexcept = default; public: WrappedTask& operator=(const WrappedTask&) = delete; WrappedTask& operator=(WrappedTask&&) noexcept = default; public: TaskStatus status() noexcept override { return task_.state().status; } std::exception_ptr exception() noexcept override { return task_.state().exception; } // std::any result() noexcept override // { // if constexpr (std::is_same_v) { // return {}; // } // else { // return std::any(task_.state().value); // } // } void resume() override { task_.resume(); } void* raw() noexcept override { return &task_; } std::coroutine_handle<> handle() noexcept override { return task_.handle(); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } }; struct TaskSharedState { std::atomic_bool cancelled_ = false; }; class TaskHandle { private: std::weak_ptr state_; public: TaskHandle() = default; explicit TaskHandle(std::weak_ptr state) noexcept : state_(std::move(state)) {} TaskHandle(const TaskHandle&) = default; TaskHandle(TaskHandle&&) = default; TaskHandle& operator=(const TaskHandle&) = default; TaskHandle& operator=(TaskHandle&&) = default; [[nodiscard]] bool isValid() const noexcept { return !state_.expired(); } void cancel() const noexcept { if (std::shared_ptr state = state_.lock()) { state->cancelled_ = true; } } }; template std::unique_ptr> wrapTask(TTask&& task) noexcept { return std::make_unique>(std::forward(task)); } class TaskLoop { public: MIJIN_DEFINE_FLAG(CanContinue); MIJIN_DEFINE_FLAG(IgnoreWaiting); using wrapped_task_t = WrappedTaskBase; using wrapped_task_base_ptr_t = std::unique_ptr; struct StoredTask { std::shared_ptr sharedState; wrapped_task_base_ptr_t task; std::function setFuture; std::any resultData; }; protected: using task_vector_t = std::vector; template using wrapped_task_ptr_t = std::unique_ptr>; public: TaskLoop() = default; TaskLoop(const TaskLoop&) = delete; TaskLoop(TaskLoop&&) = delete; virtual ~TaskLoop() = default; TaskLoop& operator=(const TaskLoop&) = delete; TaskLoop& operator=(TaskLoop&&) = delete; template inline FuturePtr addTask(TaskBase task, TaskHandle* outHandle = nullptr) noexcept; virtual void transferCurrentTask(TaskLoop& otherLoop) noexcept = 0; virtual void addStoredTask(StoredTask&& storedTask) noexcept = 0; [[nodiscard]] static TaskLoop& current() noexcept; protected: inline TaskStatus tickTask(StoredTask& task) noexcept; protected: static inline TaskLoop*& currentLoopStorage() noexcept; template static inline void setFutureHelper(StoredTask& storedTask) noexcept; }; template using Task = TaskBase; class SimpleTaskLoop : public TaskLoop { private: task_vector_t tasks_; task_vector_t newTasks_; task_vector_t::iterator currentTask_; MessageQueue queuedTasks_; std::thread::id threadId_; public: // TaskLoop implementation void transferCurrentTask(TaskLoop& otherLoop) noexcept override; void addStoredTask(StoredTask&& storedTask) noexcept override; public: // public interface [[nodiscard]] constexpr bool empty() const noexcept { return tasks_.empty() && newTasks_.empty(); } inline CanContinue tick() noexcept; inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO) noexcept; private: inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); } }; class MultiThreadedTaskLoop : public TaskLoop { private: task_vector_t parkedTasks_; // buffer for tasks that don't fit into readyTasks_ MessageQueue queuedTasks_; // tasks that should be appended to parked tasks MessageQueue readyTasks_; // task queue to send tasks to a worker thread MessageQueue returningTasks_; // task that have executed on a worker thread and return for further processing std::jthread managerThread_; std::vector workerThreads_; public: // TaskLoop implementation void transferCurrentTask(TaskLoop& otherLoop) noexcept override; void addStoredTask(StoredTask&& storedTask) noexcept override; public: // public interface void start(std::size_t numWorkerThreads); void stop(); private: // private stuff void managerThread(std::stop_token stopToken); void workerThread(std::stop_token stopToken, std::size_t workerId); }; // // public functions // namespace impl { extern thread_local TaskLoop::StoredTask* gCurrentTask; inline void throwIfCancelled() { if (gCurrentTask->sharedState->cancelled_) { throw TaskCancelled(); } } } template TaskBase::~TaskBase() noexcept { if (handle_) { handle_.destroy(); } } template inline FuturePtr TaskLoop::addTask(TaskBase task, TaskHandle* outHandle) noexcept { MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!"); task.setLoop(this); auto sharedState = std::make_shared(); auto future = std::make_shared>(); auto setFuture = &setFutureHelper; if (outHandle != nullptr) { *outHandle = TaskHandle(sharedState); } // add tasks to a seperate vector first as we might be running another task right now addStoredTask(StoredTask{ .sharedState = std::move(sharedState), .task = wrapTask(std::move(task)), .setFuture = setFuture, .resultData = future }); return future; } inline TaskStatus TaskLoop::tickTask(StoredTask& task) noexcept { TaskStatus status = {}; impl::gCurrentTask = &task; do { task.task->resume(); status = task.task ? task.task->status() : TaskStatus::WAITING; // no inner task -> task switch context (and will be removed later) } while (status == TaskStatus::RUNNING); impl::gCurrentTask = nullptr; if (task.task && task.task->exception()) { // TODO: handle the exception somehow, others may be waiting return TaskStatus::FINISHED; } if (status == TaskStatus::YIELDED || status == TaskStatus::FINISHED) { task.setFuture(task); } return status; } /* static */ inline auto TaskLoop::current() noexcept -> TaskLoop& { MIJIN_ASSERT(currentLoopStorage() != nullptr, "Attempting to fetch current loop while no coroutine is running!"); return *currentLoopStorage(); } /* static */ auto TaskLoop::currentLoopStorage() noexcept -> TaskLoop*& { static thread_local TaskLoop* storage = nullptr; return storage; } template /* static */ inline void TaskLoop::setFutureHelper(StoredTask& storedTask) noexcept { TaskBase& task = *static_cast*>(storedTask.task->raw()); auto future = std::any_cast>(storedTask.resultData); if constexpr (!std::is_same_v) { MIJIN_ASSERT(!task.state().value.empty(), "Task did not produce a value?"); future->set(std::move(task.state().value.get())); } else { future->set(); } } inline std::suspend_always switchContext(TaskLoop& taskLoop) { TaskLoop& currentTaskLoop = TaskLoop::current(); if (¤tTaskLoop == &taskLoop) { return {}; } currentTaskLoop.transferCurrentTask(taskLoop); return {}; } inline auto SimpleTaskLoop::tick() noexcept -> CanContinue { // set current taskloop MIJIN_ASSERT(currentLoopStorage() == nullptr, "Trying to tick a loop from a coroutine, this is not supported."); currentLoopStorage() = this; threadId_ = std::this_thread::get_id(); // move over all tasks from newTasks for (StoredTask& task : newTasks_) { tasks_.push_back(std::move(task)); } newTasks_.clear(); // also pick up tasks from other threads while(true) { std::optional task = queuedTasks_.tryPop(); if (!task.has_value()) { break; } tasks_.push_back(std::move(*task)); } // remove any tasks that are finished executing auto it = std::remove_if(tasks_.begin(), tasks_.end(), [](StoredTask& task) { return task.task->status() == TaskStatus::FINISHED; }); tasks_.erase(it, tasks_.end()); CanContinue canContinue = CanContinue::NO; // then execute all tasks that can be executed for (currentTask_ = tasks_.begin(); currentTask_ != tasks_.end(); ++currentTask_) { StoredTask& task = *currentTask_; TaskStatus status = task.task->status(); if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED) { MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!"); continue; } status = tickTask(task); if (status == TaskStatus::SUSPENDED || status == TaskStatus::YIELDED) { canContinue = CanContinue::YES; } } // reset current loop currentLoopStorage() = nullptr; // remove any tasks that have been transferred to another queue it = std::remove_if(tasks_.begin(), tasks_.end(), [](const StoredTask& task) { return task.task == nullptr; }); tasks_.erase(it, tasks_.end()); return canContinue; } inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting) noexcept { while (!tasks_.empty() || !newTasks_.empty()) { const CanContinue canContinue = tick(); if (ignoreWaiting && !canContinue) { break; } } } // utility stuff inline TaskAwaitableSuspend c_suspend() { return TaskAwaitableSuspend(); } template typename TCollection, typename TType, typename... TTemplateArgs> Task<> c_allDone(const TCollection, TTemplateArgs...>& futures) { bool allDone = true; do { allDone = true; for (const FuturePtr& future : futures) { if (future && !future->ready()) { allDone = false; break; } } co_await c_suspend(); } while (!allDone); } #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO #endif } #endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED