860 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			860 lines
		
	
	
		
			26 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
 | 
						|
#pragma once
 | 
						|
 | 
						|
#ifndef MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
 | 
						|
#define MIJIN_ASYNC_COROUTINE_HPP_INCLUDED 1
 | 
						|
 | 
						|
 | 
						|
#if !defined(MIJIN_COROUTINE_ENABLE_DEBUG_INFO)
 | 
						|
#   define MIJIN_COROUTINE_ENABLE_DEBUG_INFO 0 // Capture stack each time a coroutine is started. Warning, expensive! // TODO: maybe implement a lighter version only storing the return address?
 | 
						|
#endif
 | 
						|
 | 
						|
#include <any>
#include <atomic>
#include <coroutine>
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "./future.hpp"
#include "./message_queue.hpp"
#include "../container/optional.hpp"
#include "../internal/common.hpp"
#include "../util/flag.hpp"
#include "../util/iterators.hpp"
#include "../util/traits.hpp"
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
#include "../debug/stacktrace.hpp"
#endif
 | 
						|
 | 
						|
namespace mijin
 | 
						|
{
 | 
						|
 | 
						|
//
 | 
						|
// public defines
 | 
						|
//
 | 
						|
 | 
						|
//
 | 
						|
// public types
 | 
						|
//
 | 
						|
 | 
						|
/// Lifecycle state of a coroutine task.
enum class TaskStatus
{
    SUSPENDED = 0, ///< not executing; WrappedTaskBase::canResume() reports it as resumable
    RUNNING   = 1, ///< set by TaskBase::resume() right before the coroutine is resumed
    WAITING   = 2, ///< waiting on a future or signal (set by the promise's await_transform overloads)
    FINISHED  = 3, ///< completed, either with a result or with an exception
    YIELDED   = 4  ///< produced an intermediate value (yielding itself is not implemented yet)
};
 | 
						|
 | 
						|
// forward declarations
 | 
						|
template<typename T>
 | 
						|
struct TaskState;
 | 
						|
 | 
						|
class TaskLoop;
 | 
						|
 | 
						|
template<typename TResult = void>
 | 
						|
class TaskBase;
 | 
						|
 | 
						|
/// Exception thrown inside a task's body (by impl::throwIfCancelled()) when the task
/// resumes after having been cancelled through a TaskHandle.
struct TaskCancelled : public std::exception
{
};
 | 
						|
 | 
						|
namespace impl
 | 
						|
{
 | 
						|
inline void throwIfCancelled();
 | 
						|
} // namespace impl
 | 
						|
 | 
						|
class TaskHandle
 | 
						|
{
 | 
						|
private:
 | 
						|
    std::weak_ptr<struct TaskSharedState> state_;
 | 
						|
public:
 | 
						|
    TaskHandle() = default;
 | 
						|
    explicit TaskHandle(std::weak_ptr<TaskSharedState> state) MIJIN_NOEXCEPT : state_(std::move(state)) {}
 | 
						|
    TaskHandle(const TaskHandle&) = default;
 | 
						|
    TaskHandle(TaskHandle&&) = default;
 | 
						|
 | 
						|
    TaskHandle& operator=(const TaskHandle&) = default;
 | 
						|
    TaskHandle& operator=(TaskHandle&&) = default;
 | 
						|
 | 
						|
    bool operator==(const TaskHandle& other) const MIJIN_NOEXCEPT {
 | 
						|
        return !state_.owner_before(other.state_) && !other.state_.owner_before(state_);
 | 
						|
    }
 | 
						|
    bool operator!=(const TaskHandle& other) const MIJIN_NOEXCEPT {
 | 
						|
        return !(*this == other);
 | 
						|
    }
 | 
						|
 | 
						|
    [[nodiscard]] bool isValid() const MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        return !state_.expired();
 | 
						|
    }
 | 
						|
 | 
						|
    inline void cancel() const MIJIN_NOEXCEPT;
 | 
						|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
 | 
						|
    inline Optional<Stacktrace> getCreationStack() const MIJIN_NOEXCEPT;
 | 
						|
#endif
 | 
						|
};
 | 
						|
struct TaskSharedState
 | 
						|
{
 | 
						|
    std::atomic_bool cancelled_ = false;
 | 
						|
    TaskHandle subTask;
 | 
						|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
 | 
						|
    Stacktrace creationStack_;
 | 
						|
#endif
 | 
						|
};
 | 
						|
 | 
						|
template<typename T>
 | 
						|
struct TaskState
 | 
						|
{
 | 
						|
    Optional<T> value;
 | 
						|
    std::exception_ptr exception;
 | 
						|
    TaskStatus status = TaskStatus::SUSPENDED;
 | 
						|
 | 
						|
    TaskState() = default;
 | 
						|
    TaskState(const TaskState&) = default;
 | 
						|
    TaskState(TaskState&&) MIJIN_NOEXCEPT = default;
 | 
						|
    inline TaskState(T _value, TaskStatus _status) MIJIN_NOEXCEPT : value(std::move(_value)), status(_status) {}
 | 
						|
    inline TaskState(std::exception_ptr _exception) MIJIN_NOEXCEPT : exception(std::move(_exception)), status(TaskStatus::FINISHED) {}
 | 
						|
    TaskState& operator=(const TaskState&) = default;
 | 
						|
    TaskState& operator=(TaskState&&) MIJIN_NOEXCEPT = default;
 | 
						|
};
 | 
						|
 | 
						|
template<>
 | 
						|
struct TaskState<void>
 | 
						|
{
 | 
						|
    std::exception_ptr exception;
 | 
						|
    TaskStatus status = TaskStatus::SUSPENDED;
 | 
						|
 | 
						|
    TaskState() = default;
 | 
						|
    TaskState(const TaskState&) = default;
 | 
						|
    TaskState(TaskState&&) MIJIN_NOEXCEPT = default;
 | 
						|
    inline TaskState(TaskStatus _status) MIJIN_NOEXCEPT : status(_status) {}
 | 
						|
    inline TaskState(std::exception_ptr _exception) MIJIN_NOEXCEPT : exception(std::move(_exception)), status(TaskStatus::FINISHED) {}
 | 
						|
    TaskState& operator=(const TaskState&) = default;
 | 
						|
    TaskState& operator=(TaskState&&) MIJIN_NOEXCEPT = default;
 | 
						|
};
 | 
						|
 | 
						|
namespace impl
 | 
						|
{
 | 
						|
template<typename TReturn, typename TPromise>
 | 
						|
struct TaskReturn
 | 
						|
{
 | 
						|
    template<typename... TArgs>
 | 
						|
    constexpr void return_value(TArgs&&... args) MIJIN_NOEXCEPT {
 | 
						|
        (static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(TReturn(std::forward<TArgs>(args)...), TaskStatus::FINISHED);
 | 
						|
    }
 | 
						|
 | 
						|
    constexpr void return_value(TReturn value) MIJIN_NOEXCEPT {
 | 
						|
        (static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(TReturn(std::move(value)), TaskStatus::FINISHED);
 | 
						|
    }
 | 
						|
 | 
						|
    constexpr void unhandled_exception() MIJIN_NOEXCEPT {
 | 
						|
        (static_cast<TPromise&>(*this).state_) = TaskState<TReturn>(std::current_exception());
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename TPromise>
 | 
						|
struct TaskReturn<void, TPromise>
 | 
						|
{
 | 
						|
    constexpr void return_void() MIJIN_NOEXCEPT {
 | 
						|
        static_cast<TPromise&>(*this).state_.status = TaskStatus::FINISHED;
 | 
						|
    }
 | 
						|
 | 
						|
    constexpr void unhandled_exception() MIJIN_NOEXCEPT {
 | 
						|
        (static_cast<TPromise&>(*this).state_) = TaskState<void>(std::current_exception());
 | 
						|
    }
 | 
						|
};
 | 
						|
}
 | 
						|
 | 
						|
template<typename TValue>
 | 
						|
struct TaskAwaitableFuture
 | 
						|
{
 | 
						|
    FuturePtr<TValue> future;
 | 
						|
 | 
						|
    [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return future->ready(); }
 | 
						|
    constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {}
 | 
						|
    constexpr TValue await_resume() const
 | 
						|
    {
 | 
						|
        impl::throwIfCancelled();
 | 
						|
        if constexpr (std::is_same_v<TValue, void>) {
 | 
						|
            return;
 | 
						|
        }
 | 
						|
        else {
 | 
						|
            return std::move(future->get());
 | 
						|
        }
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename... TArgs>
 | 
						|
struct TaskAwaitableSignal
 | 
						|
{
 | 
						|
    std::shared_ptr<std::tuple<TArgs...>> data;
 | 
						|
 | 
						|
    [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; }
 | 
						|
    constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {}
 | 
						|
    inline auto& await_resume() const
 | 
						|
    {
 | 
						|
        impl::throwIfCancelled();
 | 
						|
        return *data;
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename TSingleArg>
 | 
						|
struct TaskAwaitableSignal<TSingleArg>
 | 
						|
{
 | 
						|
    std::shared_ptr<TSingleArg> data;
 | 
						|
 | 
						|
    [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; }
 | 
						|
    constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {}
 | 
						|
    constexpr auto& await_resume() const
 | 
						|
    {
 | 
						|
        impl::throwIfCancelled();
 | 
						|
        return *data;
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<>
 | 
						|
struct TaskAwaitableSignal<>
 | 
						|
{
 | 
						|
    [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; }
 | 
						|
    constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {}
 | 
						|
    inline void await_resume() const {
 | 
						|
        impl::throwIfCancelled();
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
struct TaskAwaitableSuspend
 | 
						|
{
 | 
						|
    [[nodiscard]] constexpr bool await_ready() const MIJIN_NOEXCEPT { return false; }
 | 
						|
    constexpr void await_suspend(std::coroutine_handle<>) const MIJIN_NOEXCEPT {}
 | 
						|
    inline void await_resume() const {
 | 
						|
        impl::throwIfCancelled();
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename TTraits>
 | 
						|
struct TaskPromise : impl::TaskReturn<typename TTraits::result_t, TaskPromise<TTraits>>
 | 
						|
{
 | 
						|
    using handle_t = std::coroutine_handle<TaskPromise>;
 | 
						|
    using task_t = typename TTraits::task_t;
 | 
						|
    using result_t = typename TTraits::result_t;
 | 
						|
 | 
						|
    TaskState<result_t> state_;
 | 
						|
    std::shared_ptr<TaskSharedState> sharedState_ = std::make_shared<TaskSharedState>();
 | 
						|
    TaskLoop* loop_ = nullptr;
 | 
						|
 | 
						|
    constexpr task_t get_return_object() MIJIN_NOEXCEPT { return task_t(handle_t::from_promise(*this)); }
 | 
						|
    constexpr TaskAwaitableSuspend initial_suspend() MIJIN_NOEXCEPT { return {}; }
 | 
						|
    constexpr std::suspend_always final_suspend() noexcept { return {}; } // note: this must always be noexcept, no matter what
 | 
						|
    
 | 
						|
    // template<typename TValue>
 | 
						|
    // constexpr std::suspend_always yield_value(TValue value) MIJIN_NOEXCEPT {
 | 
						|
    //     *state_ = TaskState<result_t>(std::move(value), TaskStatus::YIELDED);
 | 
						|
    //     return {};
 | 
						|
    // }
 | 
						|
 | 
						|
    // TODO: implement yielding (can't use futures for this)
 | 
						|
 | 
						|
    // constexpr void unhandled_exception() MIJIN_NOEXCEPT {}
 | 
						|
 | 
						|
    template<typename TValue>
 | 
						|
    auto await_transform(FuturePtr<TValue> future) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        MIJIN_ASSERT(loop_ != nullptr, "Cannot await future outside of a loop!");
 | 
						|
        TaskAwaitableFuture<TValue> awaitable{future};
 | 
						|
        if (!awaitable.await_ready())
 | 
						|
        {
 | 
						|
            state_.status = TaskStatus::WAITING;
 | 
						|
            future->sigSet.connect([this, future]() mutable
 | 
						|
            {
 | 
						|
                state_.status = TaskStatus::SUSPENDED;
 | 
						|
            }, Oneshot::YES);
 | 
						|
        }
 | 
						|
        return awaitable;
 | 
						|
    }
 | 
						|
 | 
						|
    template<typename TResultOther>
 | 
						|
    auto await_transform(TaskBase<TResultOther> task) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        MIJIN_ASSERT(loop_ != nullptr, "Cannot await another task outside of a loop!"); // NOLINT(clang-analyzer-core.UndefinedBinaryOperatorResult)
 | 
						|
        auto future = delayEvaluation<TResultOther>(loop_)->addTask(std::move(task), &sharedState_->subTask); // hackidyhack: delay evaluation of the type of loop_ as it is only forward-declared here
 | 
						|
        return await_transform(future);
 | 
						|
    }
 | 
						|
 | 
						|
    template<typename TFirstArg, typename TSecondArg, typename... TArgs>
 | 
						|
    auto await_transform(Signal<TFirstArg, TSecondArg, TArgs...>& signal) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        auto data = std::make_shared<std::tuple<TFirstArg, TSecondArg, TArgs...>>();
 | 
						|
        signal.connect([this, data](TFirstArg arg0, TSecondArg arg1, TArgs... args) mutable
 | 
						|
        {
 | 
						|
            *data = std::make_tuple(std::move(arg0), std::move(arg1), std::move(args)...);
 | 
						|
            state_.status = TaskStatus::SUSPENDED;
 | 
						|
        }, Oneshot::YES);
 | 
						|
        TaskAwaitableSignal<TFirstArg, TSecondArg, TArgs...> awaitable{data};
 | 
						|
        state_.status = TaskStatus::WAITING;
 | 
						|
        return awaitable;
 | 
						|
    }
 | 
						|
 | 
						|
    template<typename TFirstArg>
 | 
						|
    auto await_transform(Signal<TFirstArg>& signal) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        auto data = std::make_shared<TFirstArg>();
 | 
						|
        signal.connect([this, data](TFirstArg arg0) mutable
 | 
						|
        {
 | 
						|
            *data = std::move(arg0);
 | 
						|
            state_.status = TaskStatus::SUSPENDED;
 | 
						|
        }, Oneshot::YES);
 | 
						|
        TaskAwaitableSignal<TFirstArg> awaitable{data};
 | 
						|
        state_.status = TaskStatus::WAITING;
 | 
						|
        return awaitable;
 | 
						|
    }
 | 
						|
 | 
						|
    auto await_transform(Signal<>& signal) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        signal.connect([this]()
 | 
						|
        {
 | 
						|
            state_.status = TaskStatus::SUSPENDED;
 | 
						|
        }, Oneshot::YES);
 | 
						|
        TaskAwaitableSignal<> awaitable{};
 | 
						|
        state_.status = TaskStatus::WAITING;
 | 
						|
        return awaitable;
 | 
						|
    }
 | 
						|
 | 
						|
    std::suspend_always await_transform(std::suspend_always) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        state_.status = TaskStatus::SUSPENDED;
 | 
						|
        return std::suspend_always();
 | 
						|
    }
 | 
						|
 | 
						|
    std::suspend_never await_transform(std::suspend_never) MIJIN_NOEXCEPT {
 | 
						|
        return std::suspend_never();
 | 
						|
    }
 | 
						|
 | 
						|
    TaskAwaitableSuspend await_transform(TaskAwaitableSuspend) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        state_.status = TaskStatus::SUSPENDED;
 | 
						|
        return TaskAwaitableSuspend();
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename TResult>
 | 
						|
class [[nodiscard("Tasks should either we awaited or added to a loop.")]] TaskBase
 | 
						|
{
 | 
						|
public:
 | 
						|
    using task_t = TaskBase;
 | 
						|
    using result_t = TResult;
 | 
						|
    struct Traits
 | 
						|
    {
 | 
						|
        using task_t = TaskBase;
 | 
						|
        using result_t = TResult;
 | 
						|
    };
 | 
						|
public:
 | 
						|
    using promise_type = TaskPromise<Traits>;
 | 
						|
    using handle_t = typename promise_type::handle_t;
 | 
						|
private:
 | 
						|
    handle_t handle_;
 | 
						|
public:
 | 
						|
    constexpr explicit TaskBase(handle_t handle) MIJIN_NOEXCEPT : handle_(handle) {
 | 
						|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
 | 
						|
        if (Result<Stacktrace> stacktrace = captureStacktrace(2); stacktrace.isSuccess())
 | 
						|
        {
 | 
						|
            handle_.promise().sharedState_->creationStack_ = *stacktrace;
 | 
						|
        }
 | 
						|
#endif
 | 
						|
    }
 | 
						|
    TaskBase(const TaskBase&) = delete;
 | 
						|
    TaskBase(TaskBase&& other) MIJIN_NOEXCEPT : handle_(std::exchange(other.handle_, nullptr))
 | 
						|
    {}
 | 
						|
    ~TaskBase() MIJIN_NOEXCEPT;
 | 
						|
public:
 | 
						|
    TaskBase& operator=(const TaskBase&) = delete;
 | 
						|
    TaskBase& operator=(TaskBase&& other) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        if (handle_) {
 | 
						|
            handle_.destroy();
 | 
						|
        }
 | 
						|
        handle_ = std::exchange(other.handle_, nullptr);
 | 
						|
        return *this;
 | 
						|
    }
 | 
						|
 | 
						|
    [[nodiscard]]
 | 
						|
    constexpr bool operator==(const TaskBase& other) const MIJIN_NOEXCEPT { return handle_ == other.handle_; }
 | 
						|
 | 
						|
    [[nodiscard]]
 | 
						|
    constexpr bool operator!=(const TaskBase& other) const MIJIN_NOEXCEPT { return handle_ != other.handle_; }
 | 
						|
public:
 | 
						|
    [[nodiscard]]
 | 
						|
    constexpr TaskState<TResult>& state() MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        return handle_.promise().state_;
 | 
						|
    }
 | 
						|
    constexpr TaskState<TResult>& resume()
 | 
						|
    {
 | 
						|
        state().status = TaskStatus::RUNNING;
 | 
						|
        handle_.resume();
 | 
						|
        return state();
 | 
						|
    }
 | 
						|
    constexpr std::shared_ptr<TaskSharedState>& sharedState() MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        return handle_.promise().sharedState_;
 | 
						|
    }
 | 
						|
private:
 | 
						|
    [[nodiscard]]
 | 
						|
    constexpr handle_t handle() const MIJIN_NOEXCEPT { return handle_; }
 | 
						|
    [[nodiscard]]
 | 
						|
    constexpr TaskLoop* getLoop() MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        return handle_.promise().loop_;
 | 
						|
    }
 | 
						|
    constexpr void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        // MIJIN_ASSERT(handle_.promise().loop_ == nullptr
 | 
						|
        //     || handle_.promise().loop_ == loop
 | 
						|
        //     || loop == nullptr, "Task already has a loop assigned!");
 | 
						|
        handle_.promise().loop_ = loop;
 | 
						|
    }
 | 
						|
 | 
						|
    friend class TaskLoop;
 | 
						|
 | 
						|
    template<typename TTask>
 | 
						|
    friend class WrappedTask;
 | 
						|
};
 | 
						|
 | 
						|
class WrappedTaskBase
 | 
						|
{
 | 
						|
public:
 | 
						|
    virtual ~WrappedTaskBase() = default;
 | 
						|
public:
 | 
						|
    virtual TaskStatus status() MIJIN_NOEXCEPT = 0;
 | 
						|
    virtual std::exception_ptr exception() MIJIN_NOEXCEPT = 0;
 | 
						|
    // virtual std::any result() MIJIN_NOEXCEPT = 0;
 | 
						|
    virtual void resume() = 0;
 | 
						|
    virtual void* raw() MIJIN_NOEXCEPT = 0;
 | 
						|
    virtual std::coroutine_handle<> handle() MIJIN_NOEXCEPT = 0;
 | 
						|
    virtual void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT = 0;
 | 
						|
    virtual std::shared_ptr<TaskSharedState>& sharedState() MIJIN_NOEXCEPT = 0;
 | 
						|
 | 
						|
    [[nodiscard]] inline bool canResume() {
 | 
						|
        const TaskStatus stat = status();
 | 
						|
        return (stat == TaskStatus::SUSPENDED || stat == TaskStatus::YIELDED);
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename TTask>
 | 
						|
class WrappedTask : public WrappedTaskBase
 | 
						|
{
 | 
						|
private:
 | 
						|
    TTask task_;
 | 
						|
public:
 | 
						|
    constexpr explicit WrappedTask(TTask&& task) MIJIN_NOEXCEPT : task_(std::move(task)) {}
 | 
						|
    WrappedTask(const WrappedTask&) = delete;
 | 
						|
    WrappedTask(WrappedTask&&) MIJIN_NOEXCEPT = default;
 | 
						|
public:
 | 
						|
    WrappedTask& operator=(const WrappedTask&) = delete;
 | 
						|
    WrappedTask& operator=(WrappedTask&&) MIJIN_NOEXCEPT = default;
 | 
						|
public:
 | 
						|
    TaskStatus status() MIJIN_NOEXCEPT override { return task_.state().status; }
 | 
						|
    std::exception_ptr exception() MIJIN_NOEXCEPT override { return task_.state().exception; }
 | 
						|
    // std::any result() MIJIN_NOEXCEPT
 | 
						|
    // {
 | 
						|
    //     if constexpr (std::is_same_v<typename TTask::result_t, void>) {
 | 
						|
    //         return {};
 | 
						|
    //     }
 | 
						|
    //     else {
 | 
						|
    //         return std::any(task_.state().value);
 | 
						|
    //     }
 | 
						|
    // }
 | 
						|
    void resume() override { task_.resume(); }
 | 
						|
    void* raw() MIJIN_NOEXCEPT override { return &task_; }
 | 
						|
    std::coroutine_handle<> handle() MIJIN_NOEXCEPT override { return task_.handle(); }
 | 
						|
    void setLoop(TaskLoop* loop) MIJIN_NOEXCEPT override { task_.setLoop(loop); }
 | 
						|
    virtual std::shared_ptr<TaskSharedState>& sharedState() MIJIN_NOEXCEPT override { return task_.sharedState(); }
 | 
						|
};
 | 
						|
 | 
						|
template<typename TTask>
 | 
						|
std::unique_ptr<WrappedTask<TTask>> wrapTask(TTask&& task) MIJIN_NOEXCEPT
 | 
						|
{
 | 
						|
    return std::make_unique<WrappedTask<TTask>>(std::forward<TTask>(task));
 | 
						|
}
 | 
						|
 | 
						|
class TaskLoop
 | 
						|
{
 | 
						|
public:
 | 
						|
    MIJIN_DEFINE_FLAG(CanContinue);
 | 
						|
    MIJIN_DEFINE_FLAG(IgnoreWaiting);
 | 
						|
 | 
						|
    using wrapped_task_t = WrappedTaskBase;
 | 
						|
    using wrapped_task_base_ptr_t = std::unique_ptr<wrapped_task_t>;
 | 
						|
    struct StoredTask
 | 
						|
    {
 | 
						|
        wrapped_task_base_ptr_t task;
 | 
						|
        std::function<void(StoredTask&)> setFuture;
 | 
						|
        std::any resultData;
 | 
						|
    };
 | 
						|
 | 
						|
    using exception_handler_t = std::function<void(std::exception_ptr)>;
 | 
						|
protected:
 | 
						|
    using task_vector_t = std::vector<StoredTask>;
 | 
						|
 | 
						|
    template<typename TTask>
 | 
						|
    using wrapped_task_ptr_t = std::unique_ptr<WrappedTask<TTask>>;
 | 
						|
 | 
						|
    exception_handler_t uncaughtExceptionHandler_;
 | 
						|
public:
 | 
						|
    TaskLoop() MIJIN_NOEXCEPT = default;
 | 
						|
    TaskLoop(const TaskLoop&) = delete;
 | 
						|
    TaskLoop(TaskLoop&&) = delete;
 | 
						|
    virtual ~TaskLoop() MIJIN_NOEXCEPT = default;
 | 
						|
 | 
						|
    TaskLoop& operator=(const TaskLoop&) = delete;
 | 
						|
    TaskLoop& operator=(TaskLoop&&) = delete;
 | 
						|
 | 
						|
    void setUncaughtExceptionHandler(exception_handler_t handler) MIJIN_NOEXCEPT { uncaughtExceptionHandler_ = std::move(handler); }
 | 
						|
 | 
						|
    template<typename TResult>
 | 
						|
    inline FuturePtr<TResult> addTask(TaskBase<TResult> task, TaskHandle* outHandle = nullptr) MIJIN_NOEXCEPT;
 | 
						|
 | 
						|
    virtual void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT = 0;
 | 
						|
    virtual void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT = 0;
 | 
						|
 | 
						|
    [[nodiscard]] static TaskLoop& current() MIJIN_NOEXCEPT;
 | 
						|
protected:
 | 
						|
    inline TaskStatus tickTask(StoredTask& task);
 | 
						|
protected:
 | 
						|
    static inline TaskLoop*& currentLoopStorage() MIJIN_NOEXCEPT;
 | 
						|
    template<typename TResult>
 | 
						|
    static inline void setFutureHelper(StoredTask& storedTask) MIJIN_NOEXCEPT;
 | 
						|
};
 | 
						|
 | 
						|
template<typename TResult = void>
 | 
						|
using Task = TaskBase<TResult>;
 | 
						|
 | 
						|
class SimpleTaskLoop : public TaskLoop
 | 
						|
{
 | 
						|
private:
 | 
						|
    task_vector_t tasks_;
 | 
						|
    task_vector_t newTasks_;
 | 
						|
    task_vector_t::iterator currentTask_;
 | 
						|
    MessageQueue<StoredTask> queuedTasks_;
 | 
						|
    std::thread::id threadId_;
 | 
						|
 | 
						|
public: // TaskLoop implementation
 | 
						|
    void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override;
 | 
						|
    void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT override;
 | 
						|
 | 
						|
public: // public interface
 | 
						|
    [[nodiscard]] constexpr bool empty() const MIJIN_NOEXCEPT { return tasks_.empty() && newTasks_.empty(); }
 | 
						|
    [[nodiscard]] constexpr std::size_t getNumTasks() const MIJIN_NOEXCEPT { return tasks_.size() + newTasks_.size(); }
 | 
						|
    [[nodiscard]] std::size_t getActiveTasks() const MIJIN_NOEXCEPT;
 | 
						|
    inline CanContinue tick();
 | 
						|
    inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO);
 | 
						|
    inline void cancelAllTasks() MIJIN_NOEXCEPT;
 | 
						|
    [[nodiscard]] inline std::vector<TaskHandle> getAllTasks() const MIJIN_NOEXCEPT;
 | 
						|
private:
 | 
						|
    inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); }
 | 
						|
};
 | 
						|
 | 
						|
class MultiThreadedTaskLoop : public TaskLoop
 | 
						|
{
 | 
						|
private:
 | 
						|
    task_vector_t parkedTasks_; // buffer for tasks that don't fit into readyTasks_
 | 
						|
    MessageQueue<StoredTask> queuedTasks_; // tasks that should be appended to parked tasks
 | 
						|
    MessageQueue<StoredTask> readyTasks_; // task queue to send tasks to a worker thread
 | 
						|
    MessageQueue<StoredTask> returningTasks_; // task that have executed on a worker thread and return for further processing
 | 
						|
    std::jthread managerThread_;
 | 
						|
    std::vector<std::jthread> workerThreads_;
 | 
						|
 | 
						|
public: // TaskLoop implementation
 | 
						|
    void transferCurrentTask(TaskLoop& otherLoop) MIJIN_NOEXCEPT override;
 | 
						|
    void addStoredTask(StoredTask&& storedTask) MIJIN_NOEXCEPT override;
 | 
						|
 | 
						|
public: // public interface
 | 
						|
    void start(std::size_t numWorkerThreads);
 | 
						|
    void stop();
 | 
						|
private: // private stuff
 | 
						|
    void managerThread(std::stop_token stopToken);
 | 
						|
    void workerThread(std::stop_token stopToken, std::size_t workerId);
 | 
						|
};
 | 
						|
 | 
						|
//
 | 
						|
// public functions
 | 
						|
//
 | 
						|
 | 
						|
namespace impl
 | 
						|
{
 | 
						|
extern thread_local TaskLoop::StoredTask* gCurrentTask;
 | 
						|
 | 
						|
inline void throwIfCancelled()
 | 
						|
{
 | 
						|
    if (gCurrentTask->task->sharedState()->cancelled_)
 | 
						|
    {
 | 
						|
        throw TaskCancelled();
 | 
						|
    }
 | 
						|
}
 | 
						|
}
 | 
						|
 | 
						|
void TaskHandle::cancel() const MIJIN_NOEXCEPT
{
    // If the shared state is already gone there is nothing to cancel.
    const std::shared_ptr<TaskSharedState> state = state_.lock();
    if (!state)
    {
        return;
    }
    state->cancelled_ = true;
    // Forward the request to the task being awaited (a no-op for an empty/expired handle).
    state->subTask.cancel();
}
 | 
						|
 | 
						|
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Optional<Stacktrace> TaskHandle::getCreationStack() const MIJIN_NOEXCEPT
{
    const std::shared_ptr<TaskSharedState> state = state_.lock();
    if (!state)
    {
        // The task no longer exists.
        return NULL_OPTIONAL;
    }
    return state->creationStack_;
}
#endif // MIJIN_COROUTINE_ENABLE_DEBUG_INFO
 | 
						|
 | 
						|
template<typename TResult>
 | 
						|
TaskBase<TResult>::~TaskBase() MIJIN_NOEXCEPT
 | 
						|
{
 | 
						|
    if (handle_)
 | 
						|
    {
 | 
						|
        handle_.destroy();
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
template<typename TResult>
 | 
						|
inline FuturePtr<TResult> TaskLoop::addTask(TaskBase<TResult> task, TaskHandle* outHandle) MIJIN_NOEXCEPT
 | 
						|
{
 | 
						|
    MIJIN_ASSERT(!task.getLoop(), "Attempting to add task that already has a loop!");
 | 
						|
    task.setLoop(this);
 | 
						|
 | 
						|
    auto future = std::make_shared<Future<TResult>>();
 | 
						|
    auto setFuture = &setFutureHelper<TResult>;
 | 
						|
 | 
						|
    if (outHandle != nullptr)
 | 
						|
    {
 | 
						|
        *outHandle = TaskHandle(task.sharedState());
 | 
						|
    }
 | 
						|
 | 
						|
    // add tasks to a seperate vector first as we might be running another task right now
 | 
						|
    addStoredTask(StoredTask{
 | 
						|
        .task = wrapTask(std::move(task)),
 | 
						|
        .setFuture = setFuture,
 | 
						|
        .resultData = future
 | 
						|
    });
 | 
						|
 | 
						|
    return future;
 | 
						|
}
 | 
						|
 | 
						|
inline TaskStatus TaskLoop::tickTask(StoredTask& task)
{
    // Expose the task to the awaitables (impl::throwIfCancelled) while it runs.
    impl::gCurrentTask = &task;

    // Resume at least once, and keep resuming while the task suspended without changing its status.
    TaskStatus status = TaskStatus::RUNNING;
    while (status == TaskStatus::RUNNING)
    {
        task.task->resume();
        // no inner task -> task switch context (and will be removed later)
        status = task.task ? task.task->status() : TaskStatus::WAITING;
    }
    impl::gCurrentTask = nullptr;

    const bool failed = task.task && task.task->exception();
    if (failed)
    {
        try
        {
            std::rethrow_exception(task.task->exception());
        }
        catch (TaskCancelled&)
        {
            // cancellation is an expected way for a task to end; nothing to report
        }
        catch (...)
        {
            if (!uncaughtExceptionHandler_)
            {
                throw;
            }
            uncaughtExceptionHandler_(std::current_exception());
        }

        // TODO: handle the exception somehow, others may be waiting
        return TaskStatus::FINISHED;
    }

    switch (status)
    {
    case TaskStatus::YIELDED:
    case TaskStatus::FINISHED:
        task.setFuture(task);
        break;
    default:
        break;
    }
    return status;
}
 | 
						|
 | 
						|
/* static */ inline auto TaskLoop::current() MIJIN_NOEXCEPT -> TaskLoop&
{
    TaskLoop* const loop = currentLoopStorage();
    MIJIN_ASSERT(loop != nullptr, "Attempting to fetch current loop while no coroutine is running!");
    return *loop;
}
 | 
						|
 | 
						|
/* static */ auto TaskLoop::currentLoopStorage() MIJIN_NOEXCEPT -> TaskLoop*&
{
    // One slot per thread, holding the loop that is ticking on that thread (or nullptr).
    static thread_local TaskLoop* tlsCurrentLoop = nullptr;
    return tlsCurrentLoop;
}
 | 
						|
 | 
						|
template<typename TResult>
 | 
						|
/* static */ inline void TaskLoop::setFutureHelper(StoredTask& storedTask) MIJIN_NOEXCEPT
 | 
						|
{
 | 
						|
    TaskBase<TResult>& task = *static_cast<TaskBase<TResult>*>(storedTask.task->raw());
 | 
						|
    auto future = std::any_cast<FuturePtr<TResult>>(storedTask.resultData);
 | 
						|
 | 
						|
    if constexpr (!std::is_same_v<TResult, void>)
 | 
						|
    {
 | 
						|
        MIJIN_ASSERT(!task.state().value.empty(), "Task did not produce a value?");
 | 
						|
        future->set(std::move(task.state().value.get()));
 | 
						|
    }
 | 
						|
    else {
 | 
						|
        future->set();
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
/**
 * Moves the currently running task over to `taskLoop`.
 *
 * Intended to be used as `co_await switchContext(otherLoop);` from inside a task. If the task
 * is already running on `taskLoop`, nothing is transferred. Otherwise the current loop hands
 * the task over via transferCurrentTask(); the entry left behind is removed at the end of the
 * current loop's tick.
 *
 * In both cases the returned awaitable always suspends the coroutine once.
 *
 * @param taskLoop Loop that should continue executing the current task.
 * @return An always-suspending awaitable.
 */
inline std::suspend_always switchContext(TaskLoop& taskLoop)
{
    TaskLoop& currentTaskLoop = TaskLoop::current();
    // compare addresses: loops are identified by object identity
    if (&currentTaskLoop != &taskLoop)
    {
        currentTaskLoop.transferCurrentTask(taskLoop);
    }
    return {};
}
 | 
						|
 | 
						|
/**
 * Runs a single iteration of this loop on the calling thread.
 *
 * Steps:
 *  1. Registers this loop as the thread's current loop and records the thread id.
 *  2. Adopts tasks from newTasks_ and any tasks pushed to queuedTasks_ by other threads.
 *  3. Drops finished tasks.
 *  4. Resumes every task that is suspended or yielded, including waiting tasks that were
 *     cancelled.
 *  5. Unregisters the loop and erases entries whose task was transferred to another loop
 *     during this tick (their `task` is null).
 *
 * Must not be called from inside a coroutine (asserted).
 *
 * @return CanContinue::YES if at least one task resumed in this tick ended up suspended or
 *         yielded, CanContinue::NO otherwise.
 *
 * Exceptions rethrown by tickTask() (when no uncaught-exception handler is installed)
 * propagate to the caller. The current-loop storage is reset before they leave, so the loop
 * can be ticked again afterwards.
 */
inline auto SimpleTaskLoop::tick() -> CanContinue
{
    // set current taskloop
    MIJIN_ASSERT(currentLoopStorage() == nullptr, "Trying to tick a loop from a coroutine, this is not supported.");
    currentLoopStorage() = this;
    threadId_ = std::this_thread::get_id();

    CanContinue canContinue = CanContinue::NO;
    try
    {
        // move over all tasks from newTasks
        for (StoredTask& task : newTasks_)
        {
            tasks_.push_back(std::move(task));
        }
        newTasks_.clear();

        // also pick up tasks from other threads
        while (true)
        {
            std::optional<StoredTask> task = queuedTasks_.tryPop();
            if (!task.has_value()) {
                break;
            }
            tasks_.push_back(std::move(*task));
        }

        // Remove any tasks that are finished executing. Also drop entries with a null task:
        // they were transferred to another loop by a previous tick that was aborted by an
        // exception before its final cleanup ran, and must not be dereferenced.
        auto finishedBegin = std::remove_if(tasks_.begin(), tasks_.end(), [](StoredTask& task) {
            return task.task == nullptr || task.task->status() == TaskStatus::FINISHED;
        });
        tasks_.erase(finishedBegin, tasks_.end());

        // then execute all tasks that can be executed
        for (currentTask_ = tasks_.begin(); currentTask_ != tasks_.end(); ++currentTask_)
        {
            StoredTask& task = *currentTask_;
            TaskStatus status = task.task->status();
            if (status == TaskStatus::WAITING && task.task->sharedState()->cancelled_)
            {
                // always continue a cancelled task, even if it was still waiting for a result
                status = TaskStatus::SUSPENDED;
            }
            if (status != TaskStatus::SUSPENDED && status != TaskStatus::YIELDED)
            {
                MIJIN_ASSERT(status == TaskStatus::WAITING, "Task with invalid status in task list!");
                continue;
            }

            status = tickTask(task);

            if (status == TaskStatus::SUSPENDED || status == TaskStatus::YIELDED)
            {
                canContinue = CanContinue::YES;
            }
        }
    }
    catch (...)
    {
        // Don't leave this loop registered as "current" on this thread. Otherwise every
        // later tick() would fail the assertion above and TaskLoop::current() would keep
        // returning this loop outside of a tick.
        currentLoopStorage() = nullptr;
        throw;
    }

    // reset current loop
    currentLoopStorage() = nullptr;

    // remove any tasks that have been transferred to another queue
    auto transferredBegin = std::remove_if(tasks_.begin(), tasks_.end(), [](const StoredTask& task) {
        return task.task == nullptr;
    });
    tasks_.erase(transferredBegin, tasks_.end());

    return canContinue;
}
 | 
						|
 | 
						|
/**
 * Ticks this loop repeatedly until it has no tasks left in tasks_ or newTasks_.
 *
 * If `ignoreWaiting` is set, it also stops as soon as a tick reports that no resumed task
 * ended suspended or yielded. In that case the remaining tasks are only waiting.
 *
 * NOTE(review): tasks that are still sitting in queuedTasks_ (pushed from other threads) do
 * not keep this loop running by themselves; only tasks_ and newTasks_ are checked.
 *
 * @param ignoreWaiting Whether to return early when only waiting tasks remain.
 */
inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting)
{
    for (;;)
    {
        const bool nothingLeft = tasks_.empty() && newTasks_.empty();
        if (nothingLeft) {
            return;
        }

        const CanContinue canContinue = tick();
        if (ignoreWaiting && !canContinue) {
            return;
        }
    }
}
 | 
						|
 | 
						|
/**
 * Requests cancellation of every task owned by this loop.
 *
 * Tasks in tasks_ and newTasks_ get their shared state's `cancelled_` flag set. tick() resumes
 * cancelled tasks even while they are waiting, so they can observe the cancellation.
 *
 * Entries whose task is null are skipped. These are tasks transferred to another loop during
 * the current tick(), which stay in tasks_ until the end of that tick. They would otherwise
 * be dereferenced when this is called from inside a running task.
 */
inline void SimpleTaskLoop::cancelAllTasks() MIJIN_NOEXCEPT
{
    for (StoredTask& task : mijin::chain(tasks_, newTasks_))
    {
        if (task.task == nullptr) {
            continue;
        }
        task.task->sharedState()->cancelled_ = true;
    }
    // NOTE(review): this loop does nothing with the entries itself. Discarding queued tasks
    // relies on iteration of queuedTasks_ consuming its elements; confirm against the
    // MessageQueue iterator. Otherwise queued tasks are neither cancelled nor discarded here.
    for (StoredTask& task : queuedTasks_)
    {
        // just discard it
        (void) task;
    }
}
 | 
						|
 | 
						|
/**
 * Returns handles to all tasks currently owned by this loop (tasks_ followed by newTasks_).
 *
 * Tasks still pending in queuedTasks_ are not included.
 *
 * Entries with a null task are skipped. These were transferred to another loop during the
 * current tick() and have not been erased yet; dereferencing them would crash when this is
 * called from inside a running task.
 */
inline std::vector<TaskHandle> SimpleTaskLoop::getAllTasks() const MIJIN_NOEXCEPT
{
    std::vector<TaskHandle> result;
    result.reserve(tasks_.size() + newTasks_.size()); // upper bound, avoids regrowth
    for (const StoredTask& task : mijin::chain(tasks_, newTasks_))
    {
        if (task.task == nullptr) {
            continue;
        }
        result.emplace_back(task.task->sharedState());
    }
    return result;
}
 | 
						|
 | 
						|
// utility stuff
 | 
						|
 | 
						|
/**
 * Creates an awaitable for use as `co_await c_suspend();` inside a task.
 *
 * Presumably this hands control back to the task loop so other tasks can run before this one
 * continues; see TaskAwaitableSuspend for the exact semantics.
 */
inline TaskAwaitableSuspend c_suspend()
{
    return TaskAwaitableSuspend{};
}
 | 
						|
 | 
						|
template<template<typename...> typename TCollection, typename TType, typename... TTemplateArgs>
 | 
						|
Task<> c_allDone(const TCollection<FuturePtr<TType>, TTemplateArgs...>& futures)
 | 
						|
{
 | 
						|
    bool allDone = true;
 | 
						|
    do
 | 
						|
    {
 | 
						|
        allDone = true;
 | 
						|
        for (const FuturePtr<TType>& future : futures)
 | 
						|
        {
 | 
						|
            if (future && !future->ready()) {
 | 
						|
                allDone = false;
 | 
						|
                break;
 | 
						|
            }
 | 
						|
        }
 | 
						|
        co_await c_suspend();
 | 
						|
    } while (!allDone);
 | 
						|
}
 | 
						|
 | 
						|
/**
 * Returns a handle to the task that is currently being executed on this thread.
 *
 * Only valid while a task is running: impl::gCurrentTask is set around resuming a task and
 * cleared afterwards. Calling this outside of a task triggers the assertion.
 */
[[nodiscard]] inline TaskHandle getCurrentTask() MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(impl::gCurrentTask != nullptr, "Attempt to call getCurrentTask() outside of task.");
    auto& runningTask = *impl::gCurrentTask;
    return TaskHandle(runningTask.task->sharedState());
}
 | 
						|
}
 | 
						|
 | 
						|
#endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED
 |