Moved the creation stack to the shared state so it is retrievable, added comparison operators to TaskHandle and added getCurrentTask() function to retrieve the handle of the current task.

This commit is contained in:
Patrick 2023-11-19 20:04:19 +01:00
parent 803f1463dc
commit ba8c1ebe1e

View File

@ -23,6 +23,7 @@
#include "./message_queue.hpp" #include "./message_queue.hpp"
#include "../container/optional.hpp" #include "../container/optional.hpp"
#include "../util/flag.hpp" #include "../util/flag.hpp"
#include "../util/iterators.hpp"
#include "../util/traits.hpp" #include "../util/traits.hpp"
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
#include "../debug/stacktrace.hpp" #include "../debug/stacktrace.hpp"
@ -77,17 +78,30 @@ public:
TaskHandle& operator=(const TaskHandle&) = default; TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default; TaskHandle& operator=(TaskHandle&&) = default;
bool operator==(const TaskHandle& other) const noexcept {
return !state_.owner_before(other.state_) && !other.state_.owner_before(state_);
}
bool operator!=(const TaskHandle& other) const noexcept {
return !(*this == other);
}
[[nodiscard]] bool isValid() const noexcept [[nodiscard]] bool isValid() const noexcept
{ {
return !state_.expired(); return !state_.expired();
} }
inline void cancel() const noexcept; inline void cancel() const noexcept;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
inline Optional<Stacktrace> getCreationStack() const noexcept;
#endif
}; };
struct TaskSharedState struct TaskSharedState
{ {
std::atomic_bool cancelled_ = false; std::atomic_bool cancelled_ = false;
TaskHandle subTask; TaskHandle subTask;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Stacktrace creationStack_;
#endif
}; };
template<typename T> template<typename T>
@ -340,23 +354,17 @@ public:
using handle_t = typename promise_type::handle_t; using handle_t = typename promise_type::handle_t;
private: private:
handle_t handle_; handle_t handle_;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Stacktrace creationStack_;
#endif
public: public:
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) { constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO #if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
if (Result<Stacktrace> stacktrace = captureStacktrace(1); stacktrace.isSuccess()) if (Result<Stacktrace> stacktrace = captureStacktrace(2); stacktrace.isSuccess())
{ {
creationStack_ = *stacktrace; handle_.promise().sharedState_->creationStack_ = *stacktrace;
} }
#endif #endif
} }
TaskBase(const TaskBase&) = delete; TaskBase(const TaskBase&) = delete;
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr)) TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr))
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
, creationStack_(std::move(other.creationStack_))
#endif
{} {}
~TaskBase() noexcept; ~TaskBase() noexcept;
public: public:
@ -367,9 +375,6 @@ public:
handle_.destroy(); handle_.destroy();
} }
handle_ = std::exchange(other.handle_, nullptr); handle_ = std::exchange(other.handle_, nullptr);
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
creationStack_ = std::move(other.creationStack_);
#endif
return *this; return *this;
} }
@ -464,7 +469,7 @@ public:
void* raw() noexcept override { return &task_; } void* raw() noexcept override { return &task_; }
std::coroutine_handle<> handle() noexcept override { return task_.handle(); } std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); } void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept { return task_.sharedState(); } virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept override { return task_.sharedState(); }
}; };
template<typename TTask> template<typename TTask>
@ -540,8 +545,11 @@ public: // TaskLoop implementation
public: // public interface public: // public interface
[[nodiscard]] constexpr bool empty() const noexcept { return tasks_.empty() && newTasks_.empty(); } [[nodiscard]] constexpr bool empty() const noexcept { return tasks_.empty() && newTasks_.empty(); }
[[nodiscard]] constexpr std::size_t getNumTasks() const noexcept { return tasks_.size() + newTasks_.size(); }
inline CanContinue tick(); inline CanContinue tick();
inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO); inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO);
inline void cancelAllTasks() noexcept;
[[nodiscard]] inline std::vector<TaskHandle> getAllTasks() const noexcept;
private: private:
inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); } inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); }
}; };
@ -594,6 +602,17 @@ void TaskHandle::cancel() const noexcept
} }
} }
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Optional<Stacktrace> TaskHandle::getCreationStack() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
return state->creationStack_;
}
return NULL_OPTIONAL;
}
#endif // MIJIN_COROUTINE_ENABLE_DEBUG_INFO
template<typename TResult> template<typename TResult>
TaskBase<TResult>::~TaskBase() noexcept TaskBase<TResult>::~TaskBase() noexcept
{ {
@ -784,6 +803,29 @@ inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting)
} }
} }
inline void SimpleTaskLoop::cancelAllTasks() noexcept
{
for (StoredTask& task : mijin::chain(tasks_, newTasks_))
{
task.task->sharedState()->cancelled_ = true;
}
for (StoredTask& task : queuedTasks_)
{
// just discard it
(void) task;
}
}
inline std::vector<TaskHandle> SimpleTaskLoop::getAllTasks() const noexcept
{
std::vector<TaskHandle> result;
for (const StoredTask& task : mijin::chain(tasks_, newTasks_))
{
result.emplace_back(task.task->sharedState());
}
return result;
}
// utility stuff // utility stuff
inline TaskAwaitableSuspend c_suspend() { inline TaskAwaitableSuspend c_suspend() {
@ -808,8 +850,11 @@ Task<> c_allDone(const TCollection<FuturePtr<TType>, TTemplateArgs...>& futures)
} while (!allDone); } while (!allDone);
} }
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO [[nodiscard]] inline TaskHandle getCurrentTask() noexcept
#endif {
MIJIN_ASSERT(impl::gCurrentTask != nullptr, "Attempt to call getCurrentTask() outside of task.");
return TaskHandle(impl::gCurrentTask->task->sharedState());
}
} }
#endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED #endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED