Moved the creation stack to the shared state so it is retrievable, added comparison operators to TaskHandle and added getCurrentTask() function to retrieve the handle of the current task.

This commit is contained in:
Patrick 2023-11-19 20:04:19 +01:00
parent 803f1463dc
commit ba8c1ebe1e

View File

@ -23,6 +23,7 @@
#include "./message_queue.hpp"
#include "../container/optional.hpp"
#include "../util/flag.hpp"
#include "../util/iterators.hpp"
#include "../util/traits.hpp"
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
#include "../debug/stacktrace.hpp"
@ -77,17 +78,30 @@ public:
TaskHandle& operator=(const TaskHandle&) = default;
TaskHandle& operator=(TaskHandle&&) = default;
bool operator==(const TaskHandle& other) const noexcept {
return !state_.owner_before(other.state_) && !other.state_.owner_before(state_);
}
bool operator!=(const TaskHandle& other) const noexcept {
return !(*this == other);
}
[[nodiscard]] bool isValid() const noexcept
{
return !state_.expired();
}
inline void cancel() const noexcept;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
inline Optional<Stacktrace> getCreationStack() const noexcept;
#endif
};
struct TaskSharedState
{
std::atomic_bool cancelled_ = false;
TaskHandle subTask;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Stacktrace creationStack_;
#endif
};
template<typename T>
@ -340,23 +354,17 @@ public:
using handle_t = typename promise_type::handle_t;
private:
handle_t handle_;
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Stacktrace creationStack_;
#endif
public:
constexpr explicit TaskBase(handle_t handle) noexcept : handle_(handle) {
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
if (Result<Stacktrace> stacktrace = captureStacktrace(1); stacktrace.isSuccess())
if (Result<Stacktrace> stacktrace = captureStacktrace(2); stacktrace.isSuccess())
{
creationStack_ = *stacktrace;
handle_.promise().sharedState_->creationStack_ = *stacktrace;
}
#endif
}
TaskBase(const TaskBase&) = delete;
TaskBase(TaskBase&& other) noexcept : handle_(std::exchange(other.handle_, nullptr))
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
, creationStack_(std::move(other.creationStack_))
#endif
{}
~TaskBase() noexcept;
public:
@ -367,9 +375,6 @@ public:
handle_.destroy();
}
handle_ = std::exchange(other.handle_, nullptr);
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
creationStack_ = std::move(other.creationStack_);
#endif
return *this;
}
@ -464,7 +469,7 @@ public:
void* raw() noexcept override { return &task_; }
std::coroutine_handle<> handle() noexcept override { return task_.handle(); }
void setLoop(TaskLoop* loop) noexcept override { task_.setLoop(loop); }
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept { return task_.sharedState(); }
virtual std::shared_ptr<TaskSharedState>& sharedState() noexcept override { return task_.sharedState(); }
};
template<typename TTask>
@ -540,8 +545,11 @@ public: // TaskLoop implementation
public: // public interface
[[nodiscard]] constexpr bool empty() const noexcept { return tasks_.empty() && newTasks_.empty(); }
[[nodiscard]] constexpr std::size_t getNumTasks() const noexcept { return tasks_.size() + newTasks_.size(); }
inline CanContinue tick();
inline void runUntilDone(IgnoreWaiting ignoreWaiting = IgnoreWaiting::NO);
inline void cancelAllTasks() noexcept;
[[nodiscard]] inline std::vector<TaskHandle> getAllTasks() const noexcept;
private:
inline void assertCorrectThread() { MIJIN_ASSERT(threadId_ == std::thread::id() || threadId_ == std::this_thread::get_id(), "Unsafe to TaskLoop from different thread!"); }
};
@ -594,6 +602,17 @@ void TaskHandle::cancel() const noexcept
}
}
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
Optional<Stacktrace> TaskHandle::getCreationStack() const noexcept
{
if (std::shared_ptr<TaskSharedState> state = state_.lock())
{
return state->creationStack_;
}
return NULL_OPTIONAL;
}
#endif // MIJIN_COROUTINE_ENABLE_DEBUG_INFO
template<typename TResult>
TaskBase<TResult>::~TaskBase() noexcept
{
@ -784,6 +803,29 @@ inline void SimpleTaskLoop::runUntilDone(IgnoreWaiting ignoreWaiting)
}
}
inline void SimpleTaskLoop::cancelAllTasks() noexcept
{
for (StoredTask& task : mijin::chain(tasks_, newTasks_))
{
task.task->sharedState()->cancelled_ = true;
}
for (StoredTask& task : queuedTasks_)
{
// just discard it
(void) task;
}
}
inline std::vector<TaskHandle> SimpleTaskLoop::getAllTasks() const noexcept
{
std::vector<TaskHandle> result;
for (const StoredTask& task : mijin::chain(tasks_, newTasks_))
{
result.emplace_back(task.task->sharedState());
}
return result;
}
// utility stuff
inline TaskAwaitableSuspend c_suspend() {
@ -808,8 +850,11 @@ Task<> c_allDone(const TCollection<FuturePtr<TType>, TTemplateArgs...>& futures)
} while (!allDone);
}
#if MIJIN_COROUTINE_ENABLE_DEBUG_INFO
#endif
[[nodiscard]] inline TaskHandle getCurrentTask() noexcept
{
MIJIN_ASSERT(impl::gCurrentTask != nullptr, "Attempt to call getCurrentTask() outside of task.");
return TaskHandle(impl::gCurrentTask->task->sharedState());
}
}
#endif // MIJIN_ASYNC_COROUTINE_HPP_INCLUDED