Initial commit.
This commit is contained in:
128
include/FunctionGuard.h
Normal file
128
include/FunctionGuard.h
Normal file
@@ -0,0 +1,128 @@
|
||||
#pragma once
|
||||
|
||||
/// @defgroup FunctionGuard Function Guard
|
||||
/// @brief Scope guard that calls a function as it leaves scope.
|
||||
/// @{
|
||||
///
|
||||
/// A FunctionGuard is an scope guard object that stores a functor that will be called from its destructor. By
|
||||
/// convention, scope guards are move-only objects that are intended for allocation on the stack, to ensure that certain
|
||||
/// operations are performed exactly once (when their scope collapses).
|
||||
///
|
||||
/// Because tasks can be canceled while suspended (and thus do not reach the end of the function), any cleanup code at
|
||||
/// the end of a task isn't guaranteed to execute. Because FunctionGuard is an RAII object, it gives programmers an
|
||||
/// opportunity to schedule guaranteed cleanup code, no matter how a task terminates.
|
||||
///
|
||||
/// Consider the following example of a task that manages a character's "charge attack" in a combat-oriented game:
|
||||
///
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{.cpp}
|
||||
///
|
||||
/// class Character : public Actor
|
||||
/// {
|
||||
/// public:
|
||||
/// Task<> ChargeAttackState()
|
||||
/// {
|
||||
/// bool bIsFullyCharged = false;
|
||||
/// if(Input->IsAttackButtonPressed())
|
||||
/// {
|
||||
/// StartCharging(); // Start playing charge effects
|
||||
/// auto stopChargingGuard = MakeFnGuard([&]{
|
||||
/// StopCharging(); // Stop playing charge effects
|
||||
/// });
|
||||
///
|
||||
/// // Wait for N seconds (canceling if button is no longer held)
|
||||
/// bIsFullyCharged = co_await WaitSeconds(chargeTime).CancelIf([&] {
|
||||
/// return !Input->IsAttackButtonPressed();
|
||||
/// });
|
||||
/// } // <-- This is when StopCharging() will be called
|
||||
/// FireShot(bIsFullyCharged);
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
///
|
||||
/// In the above example, we can guarantee that StopCharging will logically be called exactly once for every call to
|
||||
/// StartCharging(), even the ChargeAttackState() task is killed or canceled. Furthermore, we know that StopCharging()
|
||||
/// will always be called prior to the call to FireShot().
|
||||
///
|
||||
/// In practice, it is often desirable to create more domain-specific scope guards for specific use cases, but
|
||||
/// FunctionGuard provides a simple general-purpose tool for writing robust, water-tight coroutine logic without the
|
||||
/// overhead of creating bespoke support classes.
|
||||
|
||||
//--- User configuration header ---//
|
||||
#include "TasksConfig.h"
|
||||
|
||||
NAMESPACE_SQUID_BEGIN
|
||||
|
||||
template <typename tFn = TFunction<void()>>
|
||||
class FunctionGuard
|
||||
{
|
||||
public:
|
||||
FunctionGuard() = default; /// Default constructor
|
||||
FunctionGuard(nullptr_t) /// Null-pointer constructor
|
||||
{
|
||||
}
|
||||
FunctionGuard(tFn in_fn) /// Functor constructor
|
||||
: m_fn(MoveTemp(in_fn))
|
||||
{
|
||||
}
|
||||
~FunctionGuard() /// Destructor
|
||||
{
|
||||
Execute();
|
||||
}
|
||||
FunctionGuard(FunctionGuard&& in_other) noexcept /// Move constructor
|
||||
: m_fn(MoveTemp(in_other.m_fn))
|
||||
{
|
||||
in_other.Forget();
|
||||
}
|
||||
FunctionGuard& operator=(FunctionGuard<tFn>&& in_other) noexcept /// Move assignment operator
|
||||
{
|
||||
m_fn = MoveTemp(in_other.m_fn);
|
||||
in_other.Forget();
|
||||
return *this;
|
||||
}
|
||||
FunctionGuard& operator=(nullptr_t) noexcept /// Null-pointer assignment operator (calls Forget() to clear the functor)
|
||||
{
|
||||
Forget();
|
||||
return *this;
|
||||
}
|
||||
operator bool() const /// Convenience conversion operator that calls IsBound()
|
||||
{
|
||||
return IsBound();
|
||||
}
|
||||
bool IsBound() noexcept /// Returns whether functor has been bound to this FunctionGuard
|
||||
{
|
||||
return m_fn;
|
||||
}
|
||||
void Execute() /// Executes and clears the functor (if bound)
|
||||
{
|
||||
if(m_fn)
|
||||
{
|
||||
m_fn.GetValue()();
|
||||
Forget();
|
||||
}
|
||||
}
|
||||
void Forget() noexcept /// Clear the functor (without calling it)
|
||||
{
|
||||
m_fn.Reset();
|
||||
}
|
||||
|
||||
private:
|
||||
TOptional<tFn> m_fn; // The function to call when this scope guard is destroyed
|
||||
};
|
||||
|
||||
/// Create a function guard (directly stores the concretely-typed functor in the FunctionGuard)
|
||||
template <typename tFn>
|
||||
FunctionGuard<tFn> MakeFnGuard(tFn in_fn)
|
||||
{
|
||||
return FunctionGuard<tFn>(MoveTemp(in_fn));
|
||||
}
|
||||
|
||||
/// Create a generic function guard (preferable when re-assigning new functor values to the same variable)
|
||||
inline FunctionGuard<> MakeGenericFnGuard(TFunction<void()> in_fn)
|
||||
{
|
||||
return FunctionGuard<>(MoveTemp(in_fn));
|
||||
}
|
||||
|
||||
NAMESPACE_SQUID_END
|
||||
|
||||
///@} end of FunctionGuard group
|
||||
220
include/Private/TaskFSMPrivate.h
Normal file
220
include/Private/TaskFSMPrivate.h
Normal file
@@ -0,0 +1,220 @@
|
||||
// WARNING: This is an internal implementation header, which must be included from a specific location/namespace
|
||||
// That is the reason that this header does not contain a #pragma once, nor namespace guards
|
||||
|
||||
// Helper struct representing a transition event to a new FSM state
|
||||
struct TransitionEvent
|
||||
{
|
||||
Task<> newTask;
|
||||
StateId newStateId;
|
||||
};
|
||||
|
||||
// Base class for defining links between states
|
||||
class LinkBase
|
||||
{
|
||||
public:
|
||||
virtual ~LinkBase() = default;
|
||||
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const = 0;
|
||||
};
|
||||
|
||||
// Type-safe link handle
|
||||
class LinkHandle
|
||||
{
|
||||
bool IsOnCompleteLink() const
|
||||
{
|
||||
return m_linkType == eType::OnComplete;
|
||||
}
|
||||
bool HasCondition() const
|
||||
{
|
||||
return m_isConditionalLink;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Link-type enum
|
||||
enum class eType
|
||||
{
|
||||
Normal,
|
||||
OnComplete,
|
||||
};
|
||||
|
||||
// Friends
|
||||
template<class, class> friend class StateHandle;
|
||||
friend class ::TaskFSM;
|
||||
|
||||
// Constructors (friend-only)
|
||||
LinkHandle() = delete;
|
||||
LinkHandle(TSharedPtr<LinkBase> in_link, eType in_linkType, bool in_isConditional)
|
||||
: m_link(MoveTemp(in_link))
|
||||
, m_linkType(in_linkType)
|
||||
, m_isConditionalLink(in_isConditional)
|
||||
{
|
||||
}
|
||||
TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const
|
||||
{
|
||||
return m_link->EvaluateLink(in_onTransitionFn);
|
||||
}
|
||||
|
||||
private:
|
||||
TSharedPtr<LinkBase> m_link; // The underlying link
|
||||
eType m_linkType; // Whether the link is normal or OnComplete
|
||||
bool m_isConditionalLink; // Whether the link has an associated condition predicate
|
||||
};
|
||||
|
||||
// Internal FSM state object
|
||||
template<class tStateInput, class tStateConstructorFn>
|
||||
struct State
|
||||
{
|
||||
State(tStateConstructorFn in_stateCtorFn, StateId in_stateId, FString in_debugName)
|
||||
: stateCtorFn(in_stateCtorFn)
|
||||
, stateId(in_stateId)
|
||||
, debugName(in_debugName)
|
||||
{
|
||||
}
|
||||
|
||||
tStateConstructorFn stateCtorFn;
|
||||
StateId stateId;
|
||||
FString debugName;
|
||||
};
|
||||
|
||||
// Internal FSM state object (exit state specialization)
|
||||
template<>
|
||||
struct State<void, void>
|
||||
{
|
||||
State(StateId in_stateId, FString in_debugName)
|
||||
: stateId(in_stateId)
|
||||
, debugName(in_debugName)
|
||||
{
|
||||
}
|
||||
|
||||
StateId stateId;
|
||||
FString debugName;
|
||||
};
|
||||
|
||||
// Internal link definition object
|
||||
template<class ReturnT, class tStateConstructorFn, class tPredicateFn>
|
||||
class Link : public LinkBase
|
||||
{
|
||||
public:
|
||||
Link(TSharedPtr<State<ReturnT, tStateConstructorFn>> in_targetState, tPredicateFn in_predicate)
|
||||
: m_targetState(MoveTemp(in_targetState))
|
||||
, m_predicate(in_predicate)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const final
|
||||
{
|
||||
TOptional<TransitionEvent> result;
|
||||
if(TOptional<ReturnT> payload = m_predicate())
|
||||
{
|
||||
if(in_onTransitionFn)
|
||||
{
|
||||
in_onTransitionFn();
|
||||
}
|
||||
result = TransitionEvent{ m_targetState->stateCtorFn(payload.GetValue()), m_targetState->stateId };
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TSharedPtr<State<ReturnT, tStateConstructorFn>> m_targetState;
|
||||
tPredicateFn m_predicate;
|
||||
};
|
||||
|
||||
// Internal link definition object (no-payload specialization)
|
||||
template<class tStateConstructorFn, class tPredicateFn>
|
||||
class Link<void, tStateConstructorFn, tPredicateFn> : public LinkBase
|
||||
{
|
||||
public:
|
||||
Link(TSharedPtr<State<void, tStateConstructorFn>> in_targetState, tPredicateFn in_predicate)
|
||||
: m_targetState(MoveTemp(in_targetState))
|
||||
, m_predicate(in_predicate)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const final
|
||||
{
|
||||
TOptional<TransitionEvent> result;
|
||||
if(m_predicate())
|
||||
{
|
||||
if(in_onTransitionFn)
|
||||
{
|
||||
in_onTransitionFn();
|
||||
}
|
||||
result = TransitionEvent{ m_targetState->stateCtorFn(), m_targetState->stateId };
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TSharedPtr<State<void, tStateConstructorFn>> m_targetState;
|
||||
tPredicateFn m_predicate;
|
||||
};
|
||||
|
||||
// Internal link definition object (exit-state specialization)
|
||||
template<class tPredicateFn>
|
||||
class Link<void, void, tPredicateFn> : public LinkBase
|
||||
{
|
||||
public:
|
||||
Link(TSharedPtr<State<void, void>> in_targetState, tPredicateFn in_predicate)
|
||||
: m_targetState(MoveTemp(in_targetState))
|
||||
, m_predicate(in_predicate)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const final
|
||||
{
|
||||
TOptional<TransitionEvent> result;
|
||||
if(m_predicate())
|
||||
{
|
||||
if(in_onTransitionFn)
|
||||
{
|
||||
in_onTransitionFn();
|
||||
}
|
||||
result = TransitionEvent{ Task<>(), m_targetState->stateId };
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
TSharedPtr<State<void, void>> m_targetState;
|
||||
tPredicateFn m_predicate;
|
||||
};
|
||||
|
||||
// Specialized type traits that deduce the first argument type of an arbitrary callable type
|
||||
template <typename tRet, typename tArg>
|
||||
static tArg get_first_arg_type(TFunction<tRet(tArg)> f); // Return type is first argument type
|
||||
|
||||
template <typename tRet>
|
||||
static void get_first_arg_type(TFunction<tRet()> f); // Return type is void (function has no arguments)
|
||||
|
||||
template <typename T>
|
||||
struct function_traits : public function_traits<decltype(&T::operator())> // Generic callable objects (use operator())
|
||||
{
|
||||
};
|
||||
|
||||
template <typename tRet, typename... tArgs> // Function
|
||||
struct function_traits<tRet(tArgs...)>
|
||||
{
|
||||
using tFunction = TFunction<tRet(tArgs...)>;
|
||||
using tArg = decltype(get_first_arg_type(tFunction()));
|
||||
};
|
||||
|
||||
template <typename tRet, typename... tArgs> // Function ptr
|
||||
struct function_traits<tRet(*)(tArgs...)>
|
||||
{
|
||||
using tFunction = TFunction<tRet(tArgs...)>;
|
||||
using tArg = decltype(get_first_arg_type(tFunction()));
|
||||
};
|
||||
|
||||
template <typename tClass, typename tRet, typename... tArgs> // Member function ptr (const)
|
||||
struct function_traits<tRet(tClass::*)(tArgs...) const>
|
||||
{
|
||||
using tFunction = TFunction<tRet(tArgs...)>;
|
||||
using tArg = decltype(get_first_arg_type(tFunction()));
|
||||
};
|
||||
|
||||
template <typename tClass, typename tRet, typename... tArgs> // Member function ptr
|
||||
struct function_traits<tRet(tClass::*)(tArgs...)>
|
||||
{
|
||||
using tFunction = TFunction<tRet(tArgs...)>;
|
||||
using tArg = decltype(get_first_arg_type(tFunction()));
|
||||
};
|
||||
896
include/Private/TaskPrivate.h
Normal file
896
include/Private/TaskPrivate.h
Normal file
@@ -0,0 +1,896 @@
|
||||
// WARNING: This is an internal implementation header, which must be included from a specific location/namespace
|
||||
// That is the reason that this header does not contain a #pragma once, nor namespace guards
|
||||
|
||||
enum class eTaskRef;
|
||||
template <typename tRet> class TaskPromise;
|
||||
class TaskInternalBase;
|
||||
template <typename tRet> class TaskInternal;
|
||||
|
||||
//--- tTaskReadyFn ---//
|
||||
using tTaskReadyFn = TFunction<bool()>;
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
auto CancelTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn);
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
auto StopTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn);
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename tTimeFn>
|
||||
auto StopTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn, tTaskTime in_timeout, tTimeFn in_timeFn);
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename T>
|
||||
auto StopTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn, tTaskTime in_timeout);
|
||||
|
||||
//--- Suspend-If Awaiter ---//
|
||||
struct SuspendIf
|
||||
{
|
||||
SuspendIf(bool in_suspend)
|
||||
: m_suspend(in_suspend)
|
||||
{
|
||||
}
|
||||
bool await_ready() noexcept { return !m_suspend; }
|
||||
void await_suspend(std::coroutine_handle<>) noexcept {}
|
||||
void await_resume() noexcept {}
|
||||
|
||||
private:
|
||||
bool m_suspend;
|
||||
};
|
||||
|
||||
//--- Task Debug Stack Formatter ---//
|
||||
struct TaskDebugStackFormatter
|
||||
{
|
||||
// Format function (formats a debug output string) [virtual]
|
||||
virtual FString Format(const FString& in_str) const
|
||||
{
|
||||
FString result = Indent(0);
|
||||
int32_t indent = 0;
|
||||
int32_t start = 0;
|
||||
int32_t found = 0;
|
||||
while((found = in_str.FindChar('\n', start)) != INDEX_NONE)
|
||||
{
|
||||
int32_t end = found + 1;
|
||||
if((found < in_str.Len() - 1) && (in_str[found + 1] == '`')) // indent
|
||||
{
|
||||
++indent;
|
||||
++end;
|
||||
}
|
||||
else if((found >= 1) && (in_str[found - 1] == '`')) // dedent
|
||||
{
|
||||
--indent;
|
||||
--found;
|
||||
}
|
||||
result += in_str.Mid(start, found - start) + '\n' + Indent(indent);
|
||||
start = end;
|
||||
}
|
||||
result += in_str.Mid(start);
|
||||
return result;
|
||||
}
|
||||
virtual FString Indent(int32_t in_indent) const
|
||||
{
|
||||
return FString::ChrN(in_indent * 2, ' ');
|
||||
}
|
||||
};
|
||||
static FString FormatDebugString(FString in_str)
|
||||
{
|
||||
in_str.ReplaceCharInline('\n', ' ');
|
||||
in_str.LeftChopInline(32, false);
|
||||
return in_str;
|
||||
}
|
||||
|
||||
//--- SetDebugName Awaiter ---//
|
||||
#if SQUID_ENABLE_TASK_DEBUG
|
||||
struct SetDebugName
|
||||
{
|
||||
// Sets a Task's debug name field
|
||||
SetDebugName(const char* in_name)
|
||||
: m_name(in_name)
|
||||
{
|
||||
}
|
||||
SetDebugName(const char* in_name, TFunction<FString()> in_dataFn)
|
||||
: m_name(in_name)
|
||||
, m_dataFn(in_dataFn)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename tRet> friend class TaskPromiseBase;
|
||||
const char* m_name = nullptr;
|
||||
TFunction<FString()> m_dataFn;
|
||||
};
|
||||
#endif //SQUID_ENABLE_TASK_DEBUG
|
||||
|
||||
//--- AddStopTask Awaiter ---//
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
struct AddStopTaskAwaiter
|
||||
{
|
||||
AddStopTaskAwaiter(Task<tRet, RefType, Resumable>& in_taskToStop)
|
||||
: m_taskToStop(&in_taskToStop)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename tRet> friend class TaskPromiseBase;
|
||||
Task<tRet, RefType, Resumable>* m_taskToStop = nullptr;
|
||||
};
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
auto AddStopTask(Task<tRet, RefType, Resumable>& in_taskToStop)
|
||||
{
|
||||
return AddStopTaskAwaiter<tRet, RefType, Resumable>(in_taskToStop);
|
||||
};
|
||||
|
||||
//--- RemoveStopTask Awaiter ---//
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
struct RemoveStopTaskAwaiter
|
||||
{
|
||||
RemoveStopTaskAwaiter(Task<tRet, RefType, Resumable>& in_taskToStop)
|
||||
: m_taskToStop(&in_taskToStop)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename tRet> friend class TaskPromiseBase;
|
||||
Task<tRet, RefType, Resumable>* m_taskToStop = nullptr;
|
||||
};
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
auto RemoveStopTask(Task<tRet, RefType, Resumable>& in_taskToStop)
|
||||
{
|
||||
return RemoveStopTaskAwaiter<tRet, RefType, Resumable>(in_taskToStop);
|
||||
};
|
||||
|
||||
//--- Task Awaiter ---//
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename promise_type>
|
||||
struct TaskAwaiterBase
|
||||
{
|
||||
TaskAwaiterBase(const Task<tRet, RefType, Resumable>& in_task)
|
||||
{
|
||||
// This constructor exists to minimize downstream compile-error spam when co_awaiting a non-copyable Task by copy
|
||||
}
|
||||
TaskAwaiterBase(Task<tRet, RefType, Resumable>&& in_task)
|
||||
: m_task(MoveTemp(in_task))
|
||||
{
|
||||
SQUID_RUNTIME_CHECK(m_task.IsValid(), "Tried to await an invalid task");
|
||||
}
|
||||
TaskAwaiterBase(TaskAwaiterBase&& in_taskAwaiter) noexcept
|
||||
{
|
||||
m_task = MoveTemp(in_taskAwaiter.m_task);
|
||||
}
|
||||
bool await_ready() noexcept
|
||||
{
|
||||
if(m_task.IsDone())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
template <eTaskResumable UResumable = Resumable, typename std::enable_if_t<UResumable == eTaskResumable::Yes>* = nullptr>
|
||||
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
|
||||
{
|
||||
// Set the sub-task on the suspending task
|
||||
auto& promise = in_coroHandle.promise();
|
||||
auto taskInternal = promise.GetInternalTask();
|
||||
auto subTaskInternal = m_task.GetInternalTask();
|
||||
if(taskInternal->IsStopRequested())
|
||||
{
|
||||
subTaskInternal->RequestStop(); // Propagate any stop request to new sub-tasks
|
||||
}
|
||||
taskInternal->SetSubTask(StaticCastSharedPtr<TaskInternalBase>(subTaskInternal));
|
||||
|
||||
// Resume the task
|
||||
if(m_task.Resume() == eTaskStatus::Done)
|
||||
{
|
||||
taskInternal->SetSubTask(nullptr);
|
||||
return false; // Do not suspend, because the task is done
|
||||
}
|
||||
return true; // Suspend, because the task is not done
|
||||
}
|
||||
template <eTaskResumable UResumable = Resumable, typename std::enable_if_t<UResumable == eTaskResumable::No>* = nullptr>
|
||||
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
|
||||
{
|
||||
auto& promise = in_coroHandle.promise();
|
||||
if(!m_task.IsDone())
|
||||
{
|
||||
promise.SetReadyFunction([this] { return m_task.IsDone(); });
|
||||
return true; // Suspend, because the task is not done
|
||||
}
|
||||
return false; // Do not suspend, because the task is done
|
||||
}
|
||||
|
||||
protected:
|
||||
auto GetInternalTask() const
|
||||
{
|
||||
return m_task.GetInternalTask();
|
||||
}
|
||||
Task<tRet, RefType, Resumable> m_task;
|
||||
};
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename promise_type>
|
||||
struct TaskAwaiter : public TaskAwaiterBase<tRet, RefType, Resumable, promise_type>
|
||||
{
|
||||
using TaskAwaiterBase<tRet, RefType, Resumable, promise_type>::TaskAwaiterBase;
|
||||
|
||||
template <typename U = tRet, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
|
||||
auto await_resume()
|
||||
{
|
||||
this->m_task.RethrowUnhandledException(); // Re-throw any exceptions
|
||||
auto retVal = this->m_task.TakeReturnValue();
|
||||
SQUID_RUNTIME_CHECK(retVal, "Awaited task return value is unset");
|
||||
return MoveTemp(retVal.GetValue());
|
||||
}
|
||||
|
||||
template <typename U = tRet, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
|
||||
void await_resume()
|
||||
{
|
||||
this->m_task.RethrowUnhandledException(); // Re-throw any exceptions
|
||||
}
|
||||
};
|
||||
|
||||
//--- Future Awaiter ---//
|
||||
template <typename tRet, typename promise_type>
|
||||
struct FutureAwaiter
|
||||
{
|
||||
FutureAwaiter(TFuture<tRet>&& in_future)
|
||||
: m_future(MoveTemp(in_future))
|
||||
{
|
||||
}
|
||||
~FutureAwaiter()
|
||||
{
|
||||
}
|
||||
FutureAwaiter(FutureAwaiter&& in_futureAwaiter) noexcept
|
||||
{
|
||||
m_future = MoveTemp(in_futureAwaiter.m_future);
|
||||
}
|
||||
bool await_ready() noexcept
|
||||
{
|
||||
bool isReady = m_future.IsReady();
|
||||
return isReady;
|
||||
}
|
||||
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
|
||||
{
|
||||
// Set the ready function
|
||||
auto& promise = in_coroHandle.promise();
|
||||
|
||||
// Suspend if future is not ready
|
||||
bool shouldSuspend = !m_future.IsReady();
|
||||
if(shouldSuspend)
|
||||
{
|
||||
promise.SetReadyFunction([this] { return m_future.IsReady(); });
|
||||
}
|
||||
return shouldSuspend;
|
||||
}
|
||||
|
||||
template <typename U = tRet, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
|
||||
auto await_resume()
|
||||
{
|
||||
return m_future.Get();
|
||||
}
|
||||
|
||||
template <typename U = tRet, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
|
||||
void await_resume()
|
||||
{
|
||||
m_future.Get();
|
||||
}
|
||||
|
||||
private:
|
||||
TFuture<tRet> m_future;
|
||||
};
|
||||
|
||||
//--- Shared Future Awaiter ---//
|
||||
template <typename tRet, typename promise_type>
|
||||
struct SharedFutureAwaiter
|
||||
{
|
||||
SharedFutureAwaiter(const TSharedFuture<tRet>& in_sharedFuture)
|
||||
: m_sharedFuture(in_sharedFuture)
|
||||
{
|
||||
}
|
||||
bool await_ready() noexcept
|
||||
{
|
||||
bool isReady = m_sharedFuture.IsReady();
|
||||
return isReady;
|
||||
}
|
||||
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
|
||||
{
|
||||
// Set the ready function
|
||||
auto& promise = in_coroHandle.promise();
|
||||
|
||||
// Suspend if future is not ready
|
||||
bool shouldSuspend = !m_sharedFuture.IsReady();
|
||||
if(shouldSuspend)
|
||||
{
|
||||
promise.SetReadyFunction([this] { return m_sharedFuture.IsReady(); });
|
||||
}
|
||||
return shouldSuspend;
|
||||
}
|
||||
|
||||
template <typename U = tRet, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
|
||||
auto await_resume()
|
||||
{
|
||||
return m_sharedFuture.Get();
|
||||
}
|
||||
|
||||
template <typename U = tRet, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
|
||||
void await_resume()
|
||||
{
|
||||
m_sharedFuture.Get(); // Trigger any pending errors
|
||||
}
|
||||
|
||||
private:
|
||||
TSharedFuture<tRet> m_sharedFuture;
|
||||
};
|
||||
|
||||
//--- TaskPromiseBase ---//
|
||||
template <typename tRet>
|
||||
class alignas(16) TaskPromiseBase
|
||||
{
|
||||
public:
|
||||
// Type aliases
|
||||
using promise_type = TaskPromise<tRet>;
|
||||
using tTaskInternal = TaskInternal<tRet>;
|
||||
|
||||
// Destructor
|
||||
~TaskPromiseBase()
|
||||
{
|
||||
// NOTE: Destructor is non-virtual, because it is always handled + destroyed as its concrete type
|
||||
m_taskInternal->OnTaskPromiseDestroyed();
|
||||
}
|
||||
|
||||
// Coroutine interface functions
|
||||
auto initial_suspend() noexcept
|
||||
{
|
||||
return std::suspend_always();
|
||||
}
|
||||
auto final_suspend() noexcept
|
||||
{
|
||||
return std::suspend_always();
|
||||
}
|
||||
auto get_return_object()
|
||||
{
|
||||
return std::coroutine_handle<promise_type>::from_promise(*static_cast<promise_type*>(this));
|
||||
}
|
||||
static TSharedPtr<tTaskInternal> get_return_object_on_allocation_failure()
|
||||
{
|
||||
SQUID_THROW(std::bad_alloc(), "Failed to allocate memory for Task");
|
||||
return {};
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// HACK: Coroutines in UE5 under MSVC is currently causing a memory underrun
|
||||
// These allocators are a workaround for the issue (as is alignas(16))
|
||||
void* operator new(size_t Size) noexcept
|
||||
{
|
||||
const size_t WorkaroundAlign = std::alignment_of<TaskPromiseBase>();
|
||||
Size += WorkaroundAlign;
|
||||
return (void*)((uint8_t*)FMemory::Malloc(Size, WorkaroundAlign) + WorkaroundAlign);
|
||||
}
|
||||
void operator delete(void* Ptr) noexcept
|
||||
{
|
||||
const size_t WorkaroundAlign = std::alignment_of<TaskPromiseBase>();
|
||||
auto OffsetPtr = (uint8_t*)Ptr - WorkaroundAlign;
|
||||
FMemory::Free(OffsetPtr);
|
||||
}
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
#if SQUID_NEEDS_UNHANDLED_EXCEPTION
|
||||
void unhandled_exception() noexcept
|
||||
{
|
||||
#if SQUID_USE_EXCEPTIONS
|
||||
// Propagate exceptions for handling
|
||||
m_taskInternal->SetUnhandledException(std::current_exception());
|
||||
#endif //SQUID_USE_EXCEPTIONS
|
||||
}
|
||||
#endif // SQUID_NEEDS_UNHANDLED_EXCEPTION
|
||||
|
||||
// Internal Task
|
||||
void SetInternalTask(tTaskInternal* in_taskInternal)
|
||||
{
|
||||
m_taskInternal = in_taskInternal;
|
||||
}
|
||||
tTaskInternal* GetInternalTask()
|
||||
{
|
||||
return m_taskInternal;
|
||||
}
|
||||
const tTaskInternal* GetInternalTask() const
|
||||
{
|
||||
return m_taskInternal;
|
||||
}
|
||||
|
||||
// Ready Function
|
||||
void SetReadyFunction(const tTaskReadyFn& in_taskReadyFn)
|
||||
{
|
||||
m_taskInternal->SetReadyFunction(in_taskReadyFn);
|
||||
}
|
||||
|
||||
// Await-Transforms
|
||||
auto await_transform(Suspend in_awaiter)
|
||||
{
|
||||
return in_awaiter;
|
||||
}
|
||||
auto await_transform(std::suspend_never in_awaiter)
|
||||
{
|
||||
return in_awaiter;
|
||||
}
|
||||
|
||||
#if SQUID_ENABLE_TASK_DEBUG
|
||||
auto await_transform(SetDebugName in_awaiter)
|
||||
{
|
||||
m_taskInternal->SetDebugName(in_awaiter.m_name);
|
||||
m_taskInternal->SetDebugDataFn(in_awaiter.m_dataFn);
|
||||
return std::suspend_never();
|
||||
}
|
||||
#endif //SQUID_ENABLE_TASK_DEBUG
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
auto await_transform(AddStopTaskAwaiter<tRet, RefType, Resumable> in_awaiter)
|
||||
{
|
||||
m_taskInternal->AddStopTask(*in_awaiter.m_taskToStop);
|
||||
return std::suspend_never();
|
||||
}
|
||||
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
auto await_transform(RemoveStopTaskAwaiter<tRet, RefType, Resumable> in_awaiter)
|
||||
{
|
||||
m_taskInternal->RemoveStopTask(*in_awaiter.m_taskToStop);
|
||||
return std::suspend_never();
|
||||
}
|
||||
|
||||
auto await_transform(GetStopContext in_awaiter)
|
||||
{
|
||||
struct GetStopContextAwaiter : public std::suspend_never
|
||||
{
|
||||
GetStopContextAwaiter(StopContext in_stopCtx)
|
||||
: stopCtx(in_stopCtx)
|
||||
{
|
||||
}
|
||||
auto await_resume() noexcept
|
||||
{
|
||||
return stopCtx;
|
||||
}
|
||||
StopContext stopCtx;
|
||||
};
|
||||
GetStopContextAwaiter stopCtxAwaiter{ m_taskInternal->GetStopContext() };
|
||||
return stopCtxAwaiter;
|
||||
}
|
||||
auto await_transform(const tTaskReadyFn& in_taskReadyFn)
|
||||
{
|
||||
// Check if we are already ready, and suspend if we are not
|
||||
bool isReady = in_taskReadyFn();
|
||||
if(!isReady)
|
||||
{
|
||||
m_taskInternal->SetReadyFunction(in_taskReadyFn);
|
||||
}
|
||||
return SuspendIf(!isReady); // Suspend if the function isn't already ready
|
||||
}
|
||||
|
||||
template <typename tFutureRet>
|
||||
auto await_transform(TFuture<tFutureRet>&& in_future)
|
||||
{
|
||||
return FutureAwaiter<tFutureRet, promise_type>(MoveTemp(in_future));
|
||||
}
|
||||
|
||||
template <typename tFutureRet>
|
||||
auto await_transform(const TSharedFuture<tFutureRet>& in_sharedFuture)
|
||||
{
|
||||
return SharedFutureAwaiter<tFutureRet, promise_type>(in_sharedFuture);
|
||||
}
|
||||
|
||||
// Task Await-Transforms
|
||||
template <typename tTaskRet, eTaskRef RefType, eTaskResumable Resumable,
|
||||
typename std::enable_if_t<Resumable == eTaskResumable::Yes>* = nullptr>
|
||||
auto await_transform(Task<tTaskRet, RefType, Resumable>&& in_task) // Move version
|
||||
{
|
||||
return TaskAwaiter<tTaskRet, RefType, Resumable, promise_type>(MoveTemp(in_task));
|
||||
}
|
||||
|
||||
template <typename tTaskRet, eTaskRef RefType, eTaskResumable Resumable,
|
||||
typename std::enable_if_t<Resumable == eTaskResumable::No>* = nullptr>
|
||||
auto await_transform(Task<tTaskRet, RefType, Resumable> in_task) // Copy version (Non-Resumable)
|
||||
{
|
||||
return TaskAwaiter<tTaskRet, RefType, Resumable, promise_type>(MoveTemp(in_task));
|
||||
}
|
||||
|
||||
template <typename tTaskRet, eTaskRef RefType, eTaskResumable Resumable,
|
||||
typename std::enable_if_t<Resumable == eTaskResumable::Yes>* = nullptr>
|
||||
auto await_transform(const Task<tTaskRet, RefType, Resumable>& in_task) // Invalid copy version (Resumable)
|
||||
{
|
||||
static_assert(static_false<tTaskRet>::value, "Cannot await a non-copyable (resumable) Task by copy (try co_await MoveTemp(task), co_await WeakTaskHandle(task), or co_await task.WaitUntilDone()");
|
||||
return TaskAwaiter<tTaskRet, RefType, Resumable, promise_type>(MoveTemp(in_task));
|
||||
}
|
||||
|
||||
protected:
|
||||
tTaskInternal* m_taskInternal = nullptr;
|
||||
};
|
||||
|
||||
//--- TaskPromise ---//
|
||||
template <typename tRet>
|
||||
class TaskPromise : public TaskPromiseBase<tRet>
|
||||
{
|
||||
public:
|
||||
// Return value access
|
||||
void return_value(const tRet& in_retVal) // Copy return value
|
||||
{
|
||||
this->m_taskInternal->SetReturnValue(in_retVal);
|
||||
}
|
||||
void return_value(tRet&& in_retVal) // Move return value
|
||||
{
|
||||
this->m_taskInternal->SetReturnValue(MoveTemp(in_retVal));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class TaskPromise<void> : public TaskPromiseBase<void>
|
||||
{
|
||||
public:
|
||||
void return_void()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
//--- TaskInternalBase ---//
|
||||
class TaskInternalBase
|
||||
{
|
||||
public:
|
||||
TaskInternalBase(std::coroutine_handle<> in_coroHandle)
|
||||
: m_coroHandle(in_coroHandle)
|
||||
{
|
||||
SQUID_RUNTIME_CHECK(m_coroHandle, "Invalid coroutine handle passed into Task");
|
||||
}
|
||||
~TaskInternalBase() // NOTE: Destructor is intentionally non-virtual (shared_ptr preserves concrete type via deleter)
|
||||
{
|
||||
Kill(); // Used for killing subtasks
|
||||
}
|
||||
StopContext GetStopContext() const
|
||||
{
|
||||
return { &m_isStopRequested };
|
||||
}
|
||||
bool IsStopRequested() const
|
||||
{
|
||||
return m_isStopRequested;
|
||||
}
|
||||
void RequestStop() // Propagates a request for the task to come to a 'graceful' stop
|
||||
{
|
||||
m_isStopRequested = true;
|
||||
for(auto& stopTask : m_stopTasks)
|
||||
{
|
||||
if(auto locked = stopTask.Pin())
|
||||
{
|
||||
locked->RequestStop();
|
||||
}
|
||||
}
|
||||
m_stopTasks.SetNum(0);
|
||||
}
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
void AddStopTask(Task<tRet, RefType, Resumable>& in_taskToStop) // Adds a task to the list of tasks to which we propagate stop requests
|
||||
{
|
||||
if(m_isStopRequested)
|
||||
{
|
||||
in_taskToStop.RequestStop();
|
||||
}
|
||||
else if(in_taskToStop.IsValid())
|
||||
{
|
||||
m_stopTasks.Add(in_taskToStop.GetInternalTask());
|
||||
}
|
||||
}
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
|
||||
void RemoveStopTask(Task<tRet, RefType, Resumable>& in_taskToStop) // Removes a task to the list of tasks to which we propagate stop requests
|
||||
{
|
||||
if(in_taskToStop.IsValid())
|
||||
{
|
||||
for(int32_t i = 0; i < m_stopTasks.Num(); ++i)
|
||||
{
|
||||
if(m_stopTasks[i].Pin() == in_taskToStop.GetInternalTask())
|
||||
{
|
||||
m_stopTasks[i] = m_stopTasks.Last();
|
||||
m_stopTasks.Pop();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
eTaskStatus Resume() // Returns whether the task is still running
|
||||
{
|
||||
// Make sure this task is not already mid-resume
|
||||
SQUID_RUNTIME_CHECK(m_internalState != eInternalState::Resuming, "Attempted to resume Task while already resumed");
|
||||
|
||||
// Task is destroyed, therefore task is done
|
||||
if(m_internalState == eInternalState::Destroyed)
|
||||
{
|
||||
return eTaskStatus::Done;
|
||||
}
|
||||
|
||||
// Mark task as resuming
|
||||
m_internalState = eInternalState::Resuming;
|
||||
|
||||
// Resume any active sub-task
|
||||
if(m_subTaskInternal)
|
||||
{
|
||||
// Propagate any stop requests to sub-task prior to resuming
|
||||
if(m_isStopRequested)
|
||||
{
|
||||
m_subTaskInternal->m_isStopRequested = true;
|
||||
}
|
||||
|
||||
// Resume the sub-task
|
||||
if(m_subTaskInternal->Resume() != eTaskStatus::Done)
|
||||
{
|
||||
m_internalState = eInternalState::Idle;
|
||||
return eTaskStatus::Suspended; // Sub-task not done, therefore task is not done
|
||||
}
|
||||
|
||||
// Clear the sub-task
|
||||
m_subTaskInternal = nullptr;
|
||||
}
|
||||
|
||||
// Resume task, if necessary
|
||||
if(CanResume())
|
||||
{
|
||||
m_taskReadyFn = nullptr; // Clear any ready function we were waiting on
|
||||
m_coroHandle.resume(); // Resume the underlying std::coroutine_handle
|
||||
}
|
||||
|
||||
// Return to idle state and return current task status
|
||||
auto taskStatus = m_coroHandle.done() ? eTaskStatus::Done : eTaskStatus::Suspended;
|
||||
if(taskStatus == eTaskStatus::Done)
|
||||
{
|
||||
m_isDone = true; // Mark task done
|
||||
}
|
||||
m_internalState = eInternalState::Idle;
|
||||
return taskStatus;
|
||||
}
|
||||
|
||||
// Sub-task
|
||||
void SetSubTask(TSharedPtr<TaskInternalBase> in_subTaskInternal)
|
||||
{
|
||||
m_subTaskInternal = in_subTaskInternal;
|
||||
}
|
||||
|
||||
#if SQUID_ENABLE_TASK_DEBUG
|
||||
// Debug task name + stack
|
||||
FString GetDebugName() const
|
||||
{
|
||||
return (!IsDone() && m_debugDataFn) ? (FString(m_debugName) + " [" + m_debugDataFn() + "]") : m_debugName;
|
||||
}
|
||||
FString GetDebugStack() const
|
||||
{
|
||||
FString result = m_subTaskInternal ? (GetDebugName() + " -> " + m_subTaskInternal->GetDebugStack()) : GetDebugName();
|
||||
return result;
|
||||
}
|
||||
void SetDebugName(const char* in_debugName)
|
||||
{
|
||||
if(in_debugName)
|
||||
{
|
||||
m_debugName = in_debugName;
|
||||
}
|
||||
}
|
||||
void SetDebugDataFn(TFunction<FString()> in_debugDataFn)
|
||||
{
|
||||
m_debugDataFn = in_debugDataFn;
|
||||
}
|
||||
#endif //SQUID_ENABLE_TASK_DEBUG
|
||||
|
||||
// Exceptions
|
||||
#if SQUID_USE_EXCEPTIONS
|
||||
std::exception_ptr GetUnhandledException() const
|
||||
{
|
||||
if(m_isExceptionSet)
|
||||
{
|
||||
return m_exception;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
#endif //SQUID_USE_EXCEPTIONS
|
||||
|
||||
protected:
|
||||
#if SQUID_USE_EXCEPTIONS
|
||||
// Internal implementation of exception-setting (called by TaskInternal<> child classes)
|
||||
void InternalSetUnhandledException(std::exception_ptr in_exception)
|
||||
{
|
||||
// NOTE: This must never be called more than once in the lifetime of an internal task
|
||||
SQUID_RUNTIME_CHECK(!m_isExceptionSet, "Exception was set for a task after it had already been set");
|
||||
if(!m_isExceptionSet)
|
||||
{
|
||||
m_exception = in_exception;
|
||||
m_isExceptionSet = true;
|
||||
}
|
||||
}
|
||||
#endif //SQUID_USE_EXCEPTIONS
|
||||
|
||||
private:
|
||||
template <typename tRet> friend class TaskPromiseBase;
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename promise_type> friend struct TaskAwaiterBase;
|
||||
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable> friend class Task;
|
||||
|
||||
// Kill this task
|
||||
void Kill() // Kill() can safely be called multiple times
|
||||
{
|
||||
SQUID_RUNTIME_CHECK(m_internalState != eInternalState::Resuming, "Attempted to kill Task while resumed");
|
||||
if(m_internalState == eInternalState::Idle)
|
||||
{
|
||||
// Mark task done
|
||||
m_isDone = true;
|
||||
|
||||
// First destroy any sub-tasks
|
||||
if(m_subTaskInternal)
|
||||
{
|
||||
m_subTaskInternal->Kill();
|
||||
}
|
||||
|
||||
// Destroy the underlying std::coroutine_handle
|
||||
m_coroHandle.destroy(); // This should only ever be called directly from this one place
|
||||
m_coroHandle = nullptr;
|
||||
m_taskReadyFn = nullptr; // Clear out the ready function
|
||||
m_internalState = eInternalState::Destroyed;
|
||||
}
|
||||
}
|
||||
|
||||
// Done + can-resume status
|
||||
void SetReadyFunction(const tTaskReadyFn& in_taskReadyFn)
|
||||
{
|
||||
m_taskReadyFn = in_taskReadyFn;
|
||||
}
|
||||
bool CanResume() const
|
||||
{
|
||||
if(IsDone())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(m_subTaskInternal)
|
||||
{
|
||||
bool canResume = m_subTaskInternal->CanResume();
|
||||
return canResume;
|
||||
}
|
||||
bool isReady = !m_taskReadyFn || m_taskReadyFn();
|
||||
return isReady;
|
||||
}
|
||||
bool IsDone() const
|
||||
{
|
||||
return m_isDone;
|
||||
}
|
||||
bool m_isDone = false;
|
||||
|
||||
// Internal state
|
||||
enum class eInternalState
|
||||
{
|
||||
Idle,
|
||||
Resuming,
|
||||
Destroyed,
|
||||
};
|
||||
eInternalState m_internalState = eInternalState::Idle;
|
||||
|
||||
// Task ready condition (when awaiting a TFunction<bool>)
|
||||
tTaskReadyFn m_taskReadyFn;
|
||||
|
||||
#if SQUID_USE_EXCEPTIONS
|
||||
// Exceptions
|
||||
std::exception_ptr m_exception = nullptr;
|
||||
bool m_isExceptionSet = false;
|
||||
#endif //SQUID_USE_EXCEPTIONS
|
||||
|
||||
// Sub-task
|
||||
TSharedPtr<TaskInternalBase> m_subTaskInternal;
|
||||
|
||||
// Reference-counting (determines underlying std::coroutine_handle lifetime, not lifetime of this internal task)
|
||||
void AddLogicalRef()
|
||||
{
|
||||
++m_refCount;
|
||||
}
|
||||
void RemoveLogicalRef()
|
||||
{
|
||||
--m_refCount;
|
||||
if(m_refCount == 0)
|
||||
{
|
||||
Kill();
|
||||
}
|
||||
}
|
||||
int32_t m_refCount = 0; // Number of (strong) non-weak tasks referencing the internal task
|
||||
|
||||
// C++ std::coroutine_handle
|
||||
std::coroutine_handle<> m_coroHandle;
|
||||
|
||||
// Stop request
|
||||
bool m_isStopRequested = false;
|
||||
TArray<TWeakPtr<TaskInternalBase>> m_stopTasks;
|
||||
|
||||
#if SQUID_ENABLE_TASK_DEBUG
|
||||
// Debug Data
|
||||
const char* m_debugName = "[unnamed task]";
|
||||
TFunction<FString()> m_debugDataFn;
|
||||
#endif //SQUID_ENABLE_TASK_DEBUG
|
||||
};
|
||||
|
||||
//--- TaskInternal ---//
|
||||
template <typename tRet>
|
||||
class TaskInternal : public TaskInternalBase
|
||||
{
|
||||
public:
|
||||
using promise_type = TaskPromise<tRet>;
|
||||
|
||||
TaskInternal(std::coroutine_handle<promise_type> in_handle)
|
||||
: TaskInternalBase(in_handle)
|
||||
{
|
||||
auto& promisePtr = in_handle.promise();
|
||||
promisePtr.SetInternalTask(this);
|
||||
}
|
||||
#if SQUID_USE_EXCEPTIONS
|
||||
void SetUnhandledException(std::exception_ptr in_exception)
|
||||
{
|
||||
m_retValState = eTaskRetValState::Orphaned; // Return value can never be set if there was an unhandled exception
|
||||
InternalSetUnhandledException(in_exception);
|
||||
}
|
||||
#endif //SQUID_USE_EXCEPTIONS
|
||||
void SetReturnValue(const tRet& in_retVal)
|
||||
{
|
||||
tRet retVal = in_retVal;
|
||||
SetReturnValue(MoveTemp(retVal));
|
||||
}
|
||||
void SetReturnValue(tRet&& in_retVal)
|
||||
{
|
||||
if(m_retValState == eTaskRetValState::Unset)
|
||||
{
|
||||
m_retVal = MoveTemp(in_retVal);
|
||||
m_retValState = eTaskRetValState::Set;
|
||||
return;
|
||||
}
|
||||
|
||||
// These conditions should (logically) never be met, but they are included in case future changes violate that constraint
|
||||
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Set, "Attempted to set a task's return value when it was already set");
|
||||
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Taken, "Attempted to set a task's return value after it was already taken");
|
||||
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Orphaned, "Attempted to set a task's return value after it was orphaned");
|
||||
}
|
||||
TOptional<tRet> TakeReturnValue()
|
||||
{
|
||||
// If the value has been set, mark it as taken and move-return the value
|
||||
if(m_retValState == eTaskRetValState::Set)
|
||||
{
|
||||
m_retValState = eTaskRetValState::Taken;
|
||||
return MoveTemp(m_retVal);
|
||||
}
|
||||
|
||||
// If the value was not set, return an unset optional (checking that it was neither taken nor orphaned)
|
||||
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Taken, "Attempted to take a task's return value after it was already successfully taken");
|
||||
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Orphaned, "Attempted to take a task's return value that will never be set (task ended prematurely)");
|
||||
return {};
|
||||
}
|
||||
void OnTaskPromiseDestroyed()
|
||||
{
|
||||
// Mark the return value as orphaned if it was never set
|
||||
m_retValState = eTaskRetValState::Orphaned;
|
||||
}
|
||||
|
||||
private:
|
||||
// Internal state
|
||||
enum class eTaskRetValState
|
||||
{
|
||||
Unset, // Has not yet been set
|
||||
Set, // Has been set and can be taken
|
||||
Taken, // Has been taken and can no longer be taken
|
||||
Orphaned, // Will never be set because the TaskPromise has been destroyed
|
||||
};
|
||||
|
||||
eTaskRetValState m_retValState = eTaskRetValState::Unset; // Initially unset
|
||||
TOptional<tRet> m_retVal;
|
||||
};
|
||||
|
||||
template <>
|
||||
class TaskInternal<void> : public TaskInternalBase
|
||||
{
|
||||
public:
|
||||
using promise_type = TaskPromise<void>;
|
||||
|
||||
TaskInternal(std::coroutine_handle<promise_type> in_handle)
|
||||
: TaskInternalBase(in_handle)
|
||||
{
|
||||
auto& promisePtr = in_handle.promise();
|
||||
promisePtr.SetInternalTask(this);
|
||||
}
|
||||
#if SQUID_USE_EXCEPTIONS
|
||||
void SetUnhandledException(std::exception_ptr in_exception)
|
||||
{
|
||||
InternalSetUnhandledException(in_exception);
|
||||
}
|
||||
#endif //SQUID_USE_EXCEPTIONS
|
||||
void TakeReturnValue() // This function is an intentional no-op, to simplify certain templated function implementations
|
||||
{
|
||||
}
|
||||
void OnTaskPromiseDestroyed()
|
||||
{
|
||||
}
|
||||
};
|
||||
236
include/Private/TasksCommonPrivate.h
Normal file
236
include/Private/TasksCommonPrivate.h
Normal file
@@ -0,0 +1,236 @@
|
||||
#pragma once
|
||||
|
||||
//--- User configuration header ---//
|
||||
#include "../TasksConfig.h"
|
||||
|
||||
// Namespace macros (enabled/disabled via SQUID_ENABLE_NAMESPACE)
|
||||
#if SQUID_ENABLE_NAMESPACE
|
||||
#define NAMESPACE_SQUID_BEGIN namespace Squid {
|
||||
#define NAMESPACE_SQUID_END }
|
||||
#define NAMESPACE_SQUID Squid
|
||||
#else
|
||||
#define NAMESPACE_SQUID_BEGIN
|
||||
#define NAMESPACE_SQUID_END
|
||||
#define NAMESPACE_SQUID
|
||||
namespace Squid {} // Convenience to allow 'using namespace Squid' even when namespace is disabled
|
||||
#endif
|
||||
|
||||
// Exception macros (to support environments with exceptions disabled)
|
||||
#if SQUID_USE_EXCEPTIONS && (defined(__cpp_exceptions) || defined(__EXCEPTIONS))
|
||||
#include <stdexcept>
|
||||
#define SQUID_THROW(exception, errStr) throw exception;
|
||||
#define SQUID_RUNTIME_ERROR(errStr) throw std::runtime_error(errStr);
|
||||
#define SQUID_RUNTIME_CHECK(condition, errStr) if(!(condition)) throw std::runtime_error(errStr);
|
||||
#else
|
||||
#include <cassert>
|
||||
#define SQUID_THROW(exception, errStr) assert(false && errStr);
|
||||
#define SQUID_RUNTIME_ERROR(errStr) assert(false && errStr);
|
||||
#define SQUID_RUNTIME_CHECK(condition, errStr) assert((condition) && errStr);
|
||||
#endif //__cpp_exceptions
|
||||
|
||||
// Time Interface
|
||||
NAMESPACE_SQUID_BEGIN
|
||||
#if SQUID_ENABLE_DOUBLE_PRECISION_TIME
|
||||
using tTaskTime = double;
|
||||
#else
|
||||
using tTaskTime = float; // Defines time units for use with the Task system
|
||||
#endif //SQUID_ENABLE_DOUBLE_PRECISION_TIME
|
||||
NAMESPACE_SQUID_END
|
||||
|
||||
// Coroutine de-optimization macros [DEPRECATED]
|
||||
#ifdef _MSC_VER
|
||||
#if _MSC_VER >= 1920
|
||||
// Newer versions of Visual Studio (>= VS2019) compile coroutines correctly
|
||||
#define COROUTINE_OPTIMIZE_OFF
|
||||
#define COROUTINE_OPTIMIZE_ON
|
||||
#else
|
||||
// Older versions of Visual Studio had code generation bugs when optimizing coroutines (they would compile, but have incorrect runtime results)
|
||||
#define COROUTINE_OPTIMIZE_OFF __pragma(optimize("", off))
|
||||
#define COROUTINE_OPTIMIZE_ON __pragma(optimize("", on))
|
||||
#endif // _MSC_VER >= 1920
|
||||
#else
|
||||
// The Clang compiler has sometimes crashed when optimizing/inlining certain coroutines, so this macro can be used to disable inlining
|
||||
#define COROUTINE_OPTIMIZE_OFF _Pragma("clang optimize off")
|
||||
#define COROUTINE_OPTIMIZE_ON _Pragma("clang optimize on")
|
||||
#endif
|
||||
|
||||
// False type for use in static_assert() [static_assert(false, ...) -> static_assert(static_false<T>, ...)]
|
||||
template<typename T>
|
||||
struct static_false : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
// Determine C++ Language Version
|
||||
#if defined(_MSVC_LANG)
|
||||
#define CPP_LANGUAGE_VERSION _MSVC_LANG
|
||||
#elif defined(__cplusplus)
|
||||
#define CPP_LANGUAGE_VERSION __cplusplus
|
||||
#else
|
||||
#define CPP_LANGUAGE_VERSION 0L
|
||||
#endif
|
||||
|
||||
#if CPP_LANGUAGE_VERSION > 201703L // C++20 or higher
|
||||
#define HAS_CXX17 1
|
||||
#define HAS_CXX20 1
|
||||
#elif CPP_LANGUAGE_VERSION > 201402L // C++17 or higher
|
||||
#define HAS_CXX17 1
|
||||
#define HAS_CXX20 0
|
||||
#elif CPP_LANGUAGE_VERSION > 201103L // C++14 or higher
|
||||
#define HAS_CXX17 0
|
||||
#define HAS_CXX20 0
|
||||
#else // C++11 or lower
|
||||
#error Squid::Tasks requires C++14 or higher
|
||||
#define HAS_CXX17 0
|
||||
#define HAS_CXX20 0
|
||||
#endif
|
||||
#undef CPP_LANGUAGE_VERSION
|
||||
|
||||
// C++20 Compatibility (std::coroutine)
|
||||
#if HAS_CXX20 || (defined(_MSVC_LANG) && defined(__cpp_lib_coroutine)) // Standard coroutines
|
||||
#include <coroutine>
|
||||
#define SQUID_EXPERIMENTAL_COROUTINES 0
|
||||
#else // Experimental coroutines
|
||||
#if defined(__clang__) && defined(_STL_COMPILER_PREPROCESSOR)
|
||||
// HACK: Some distributions of clang don't have a <experimental/coroutine> header. We only need a few symbols, so just define them ourselves
|
||||
namespace std {
|
||||
namespace experimental {
|
||||
inline namespace coroutines_v1 {
|
||||
|
||||
template <typename R, typename...> struct coroutine_traits {
|
||||
using promise_type = typename R::promise_type;
|
||||
};
|
||||
|
||||
template <typename Promise = void> struct coroutine_handle;
|
||||
|
||||
template <> struct coroutine_handle<void> {
|
||||
static coroutine_handle from_address(void* addr) noexcept {
|
||||
coroutine_handle me;
|
||||
me.ptr = addr;
|
||||
return me;
|
||||
}
|
||||
void operator()() { resume(); }
|
||||
void* address() const { return ptr; }
|
||||
void resume() const { __builtin_coro_resume(ptr); }
|
||||
void destroy() const { __builtin_coro_destroy(ptr); }
|
||||
bool done() const { return __builtin_coro_done(ptr); }
|
||||
coroutine_handle& operator=(decltype(nullptr)) {
|
||||
ptr = nullptr;
|
||||
return *this;
|
||||
}
|
||||
coroutine_handle(decltype(nullptr)) : ptr(nullptr) {}
|
||||
coroutine_handle() : ptr(nullptr) {}
|
||||
// void reset() { ptr = nullptr; } // add to P0057?
|
||||
explicit operator bool() const { return ptr; }
|
||||
|
||||
protected:
|
||||
void* ptr;
|
||||
};
|
||||
|
||||
template <typename Promise> struct coroutine_handle : coroutine_handle<> {
|
||||
using coroutine_handle<>::operator=;
|
||||
|
||||
static coroutine_handle from_address(void* addr) noexcept {
|
||||
coroutine_handle me;
|
||||
me.ptr = addr;
|
||||
return me;
|
||||
}
|
||||
|
||||
Promise& promise() const {
|
||||
return *reinterpret_cast<Promise*>(
|
||||
__builtin_coro_promise(ptr, alignof(Promise), false));
|
||||
}
|
||||
static coroutine_handle from_promise(Promise& promise) {
|
||||
coroutine_handle p;
|
||||
p.ptr = __builtin_coro_promise(&promise, alignof(Promise), true);
|
||||
return p;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename _PromiseT>
|
||||
bool operator==(coroutine_handle<_PromiseT> const& _Left,
|
||||
coroutine_handle<_PromiseT> const& _Right) noexcept
|
||||
{
|
||||
return _Left.address() == _Right.address();
|
||||
}
|
||||
|
||||
template <typename _PromiseT>
|
||||
bool operator!=(coroutine_handle<_PromiseT> const& _Left,
|
||||
coroutine_handle<_PromiseT> const& _Right) noexcept
|
||||
{
|
||||
return !(_Left == _Right);
|
||||
}
|
||||
|
||||
struct suspend_always {
|
||||
bool await_ready() noexcept { return false; }
|
||||
void await_suspend(coroutine_handle<>) noexcept {}
|
||||
void await_resume() noexcept {}
|
||||
};
|
||||
struct suspend_never {
|
||||
bool await_ready() noexcept { return true; }
|
||||
void await_suspend(coroutine_handle<>) noexcept {}
|
||||
void await_resume() noexcept {}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#include <experimental/coroutine>
|
||||
#endif
|
||||
namespace std // Alias experimental coroutine symbols into std namespace
|
||||
{
|
||||
template <class _Promise = void>
|
||||
using coroutine_handle = experimental::coroutine_handle<_Promise>;
|
||||
using suspend_never = experimental::suspend_never;
|
||||
using suspend_always = experimental::suspend_always;
|
||||
};
|
||||
#define SQUID_EXPERIMENTAL_COROUTINES 1
|
||||
#endif
|
||||
|
||||
// Determine whether our tasks need the member function "unhandled_exception()" defined or not
|
||||
#if defined(_MSC_VER)
|
||||
// MSVC's rules for exceptions differ between standard + experimental coroutines
|
||||
#if SQUID_EXPERIMENTAL_COROUTINES
|
||||
// If exceptions are enabled, we must define unhandled_exception()
|
||||
#if defined(__cpp_exceptions) && __cpp_exceptions == 199711
|
||||
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
|
||||
#else
|
||||
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 0
|
||||
#endif
|
||||
#else
|
||||
// If we're using VS16.11 or newer -- or older than 16.10, we have one set of rules for standard coroutines
|
||||
#if _MSC_FULL_VER >= 192930133L || _MSC_VER < 1429L
|
||||
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
|
||||
#else
|
||||
#if defined(__cpp_exceptions) && __cpp_exceptions == 199711
|
||||
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
|
||||
#else
|
||||
// 16.10 has a bug with their standard coroutine implementation that creates a set of contradicting requirements
|
||||
// https://developercommunity.visualstudio.com/t/coroutine-uses-promise_type::unhandled_e/1374530
|
||||
#error Visual Studio 16.10 has a compiler bug that prevents all coroutines from compiling when exceptions are disabled and using standard C++20 coroutines or /await:strict. Please either upgrade your version of Visual Studio, or use the experimental /await flag, or enable exceptions.
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
#else
|
||||
// Clang always requires unhandled_exception() to be defined
|
||||
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
|
||||
#endif
|
||||
|
||||
// C++17 Compatibility ([[nodiscard]])
|
||||
#if !defined(SQUID_NODISCARD) && defined(__has_cpp_attribute)
|
||||
#if __has_cpp_attribute(nodiscard)
|
||||
#define SQUID_NODISCARD [[nodiscard]]
|
||||
#endif
|
||||
#endif
|
||||
#ifndef SQUID_NODISCARD
|
||||
#define SQUID_NODISCARD
|
||||
#endif
|
||||
|
||||
#undef HAS_CXX17
|
||||
#undef HAS_CXX20
|
||||
|
||||
// Include UE core headers
|
||||
#include "CoreMinimal.h"
|
||||
#include "Engine/World.h"
|
||||
#include "Engine/Engine.h"
|
||||
#include "Async/Future.h"
|
||||
1153
include/Task.h
Normal file
1153
include/Task.h
Normal file
File diff suppressed because it is too large
Load Diff
331
include/TaskFSM.h
Normal file
331
include/TaskFSM.h
Normal file
@@ -0,0 +1,331 @@
|
||||
#pragma once
|
||||
|
||||
/// @defgroup TaskFSM Task FSM
|
||||
/// @brief Finite state machine that implements states using task factories
|
||||
/// @{
|
||||
///
|
||||
/// Full documentation of the TaskFSM system coming soon!
|
||||
|
||||
#include "Task.h"
|
||||
|
||||
NAMESPACE_SQUID_BEGIN
|
||||
|
||||
class TaskFSM;
|
||||
|
||||
namespace FSM
|
||||
{
|
||||
// State ID wrapper
|
||||
struct StateId
|
||||
{
|
||||
StateId() = default;
|
||||
StateId(int32_t in_idx) : idx(in_idx) {}
|
||||
StateId(size_t in_idx) : idx((int32_t)in_idx) {}
|
||||
bool operator==(const StateId& other) const { return (other.idx == idx); }
|
||||
bool operator!=(const StateId& other) const { return (other.idx != idx); }
|
||||
bool IsValid() const { return idx != INT32_MAX; }
|
||||
|
||||
int32_t idx = INT32_MAX; // Default to invalid idx
|
||||
};
|
||||
|
||||
// State transition debug data
|
||||
struct TransitionDebugData
|
||||
{
|
||||
FSM::StateId oldStateId;
|
||||
FString oldStateName;
|
||||
FSM::StateId newStateId;
|
||||
FString newStateName;
|
||||
};
|
||||
|
||||
// State transition callback function
|
||||
using tOnStateTransitionFn = TFunction<void()>;
|
||||
|
||||
#include "Private/TaskFSMPrivate.h" // Internal use only! Do not include elsewhere!
|
||||
|
||||
//--- State Handle ---//
|
||||
template<class tStateInput, class tStateConstructorFn>
|
||||
class StateHandle
|
||||
{
|
||||
using tPredicateRet = typename std::conditional<!std::is_void<tStateInput>::value, TOptional<tStateInput>, bool>::type;
|
||||
using tPredicateFn = TFunction<tPredicateRet()>;
|
||||
public:
|
||||
StateHandle(StateHandle&& in_other) = default;
|
||||
StateHandle& operator=(StateHandle&& in_other) = default;
|
||||
|
||||
StateId GetId() const //< Get the ID of this state
|
||||
{
|
||||
return m_state ? m_state->idx : StateId{};
|
||||
}
|
||||
|
||||
// SFINAE Template Declaration Macros (#defines)
|
||||
/// @cond
|
||||
#define NONVOID_ONLY_WITH_PREDICATE template <class tPredicateFn, typename tPayload = tStateInput, typename std::enable_if_t<!std::is_void<tPayload>::value>* = nullptr>
|
||||
#define VOID_ONLY_WITH_PREDICATE template <class tPredicateFn, typename tPayload = tStateInput, typename std::enable_if_t<std::is_void<tPayload>::value>* = nullptr>
|
||||
#define NONVOID_ONLY template <typename tPayload = tStateInput, typename std::enable_if_t<!std::is_void<tPayload>::value && !std::is_convertible<tPayload, tPredicateFn>::value>* = nullptr>
|
||||
#define VOID_ONLY template <typename tPayload = tStateInput, typename std::enable_if_t<std::is_void<tPayload>::value>* = nullptr>
|
||||
#define PREDICATE_ONLY template <typename tPredicateFn, typename std::enable_if_t<!std::is_convertible<tStateInput, tPredicateFn>::value>* = nullptr>
|
||||
/// @endcond
|
||||
|
||||
// Link methods
|
||||
VOID_ONLY LinkHandle Link() //< Empty predicate link (always follow link)
|
||||
{
|
||||
return _InternalLink([] { return true; }, LinkHandle::eType::Normal);
|
||||
}
|
||||
NONVOID_ONLY LinkHandle Link(tPayload in_payload) //< Empty predicate link w/ payload (always follow link, using provided payload)
|
||||
{
|
||||
return _InternalLink([payload = MoveTemp(in_payload)]() -> tPredicateRet { return payload; }, LinkHandle::eType::Normal);
|
||||
}
|
||||
PREDICATE_ONLY LinkHandle Link(tPredicateFn in_predicate) //< Predicate link w/ implicit payload (follow link when predicate returns a value; use return value as payload)
|
||||
{
|
||||
return _InternalLink(in_predicate, LinkHandle::eType::Normal);
|
||||
}
|
||||
NONVOID_ONLY_WITH_PREDICATE LinkHandle Link(tPredicateFn in_predicate, tPayload in_payload) //< Predicate link w/ explicit payload (follow link when predicate returns true; use provided payload)
|
||||
{
|
||||
return _InternalLink(in_predicate, MoveTemp(in_payload), LinkHandle::eType::Normal);
|
||||
}
|
||||
|
||||
// OnCompleteLink methods
|
||||
VOID_ONLY LinkHandle OnCompleteLink() //< Empty predicate link (always follow link)
|
||||
{
|
||||
return _InternalLink([] { return true; }, LinkHandle::eType::OnComplete);
|
||||
}
|
||||
NONVOID_ONLY LinkHandle OnCompleteLink(tPayload in_payload) //< Empty predicate link w/ payload (always follow link, using provided payload)
|
||||
{
|
||||
return _InternalLink([payload = MoveTemp(in_payload)]() -> tPredicateRet { return payload; }, LinkHandle::eType::OnComplete);
|
||||
}
|
||||
PREDICATE_ONLY LinkHandle OnCompleteLink(tPredicateFn in_predicate) //< Predicate link w/ implicit payload (follow link when predicate returns a value; use return value as payload)
|
||||
{
|
||||
return _InternalLink(in_predicate, LinkHandle::eType::OnComplete, true);
|
||||
}
|
||||
NONVOID_ONLY_WITH_PREDICATE LinkHandle OnCompleteLink(tPredicateFn in_predicate, tPayload in_payload) //< Predicate link w/ explicit payload (follow link when predicate returns true; use provided payload)
|
||||
{
|
||||
return _InternalLink(in_predicate, MoveTemp(in_payload), LinkHandle::eType::OnComplete, true);
|
||||
}
|
||||
|
||||
private:
|
||||
friend class ::TaskFSM;
|
||||
|
||||
StateHandle() = delete;
|
||||
StateHandle(TSharedPtr<State<tStateInput, tStateConstructorFn>> InStatePtr)
|
||||
: m_state(InStatePtr)
|
||||
{
|
||||
}
|
||||
StateHandle(const StateHandle& Other) = delete;
|
||||
StateHandle& operator=(const StateHandle& Other) = delete;
|
||||
|
||||
// Internal link function implementations
|
||||
VOID_ONLY_WITH_PREDICATE LinkHandle _InternalLink(tPredicateFn in_predicate, LinkHandle::eType in_linkType, bool in_isConditional = false) // bool-returning predicate
|
||||
{
|
||||
static_assert(std::is_same<bool, decltype(in_predicate())>::value, "This link requires a predicate function returning bool");
|
||||
TSharedPtr<LinkBase> link = MakeShared<FSM::Link<tStateInput, tStateConstructorFn, tPredicateFn>>(m_state, in_predicate);
|
||||
return LinkHandle(link, in_linkType, in_isConditional);
|
||||
}
|
||||
NONVOID_ONLY_WITH_PREDICATE LinkHandle _InternalLink(tPredicateFn in_predicate, LinkHandle::eType in_linkType, bool in_isConditional = false) // optional-returning predicate
|
||||
{
|
||||
static_assert(std::is_same<TOptional<tStateInput>, decltype(in_predicate())>::value, "This link requires a predicate function returning TOptional<tStateInput>");
|
||||
TSharedPtr<LinkBase> link = MakeShared<FSM::Link<tStateInput, tStateConstructorFn, tPredicateFn>>(m_state, in_predicate);
|
||||
return LinkHandle(link, in_linkType, in_isConditional);
|
||||
}
|
||||
NONVOID_ONLY_WITH_PREDICATE LinkHandle _InternalLink(tPredicateFn in_predicate, tPayload in_payload, LinkHandle::eType in_linkType, bool in_isConditional = false) // bool-returning predicate w/ fixed payload
|
||||
{
|
||||
static_assert(std::is_same<bool, decltype(in_predicate())>::value, "This link requires a predicate function returning bool");
|
||||
auto predicate = [in_predicate, in_payload]() -> TOptional<tStateInput>
|
||||
{
|
||||
return in_predicate() ? TOptional<tStateInput>(in_payload) : TOptional<tStateInput>{};
|
||||
};
|
||||
return _InternalLink(predicate, in_linkType, in_isConditional);
|
||||
}
|
||||
|
||||
// SFINAE Template Declaration Macros (#undefs)
|
||||
#undef NONVOID_ONLY_WITH_PREDICATE
|
||||
#undef VOID_ONLY_WITH_PREDICATE
|
||||
#undef NONVOID_ONLY
|
||||
#undef VOID_ONLY
|
||||
#undef PREDICATE_ONLY
|
||||
|
||||
TSharedPtr<State<tStateInput, tStateConstructorFn>> m_state; // Internal state object
|
||||
};
|
||||
|
||||
} // namespace FSM
|
||||
|
||||
using StateId = FSM::StateId;
|
||||
template<class tStateInput, class tStateConstructorFn>
|
||||
using StateHandle = FSM::StateHandle<tStateInput, tStateConstructorFn>;
|
||||
using TransitionDebugData = FSM::TransitionDebugData;
|
||||
using tOnStateTransitionFn = FSM::tOnStateTransitionFn;
|
||||
|
||||
//--- TaskFSM ---//
|
||||
class TaskFSM
|
||||
{
|
||||
public:
|
||||
using tOnStateTransitionFn = TFunction<void()>;
|
||||
using tDebugStateTransitionFn = TFunction<void(TransitionDebugData)>;
|
||||
|
||||
// Create a new FSM state [fancy param-deducing version (hopefully) coming soon!]
|
||||
template<typename tStateConstructorFn>
|
||||
auto State(FString in_name, tStateConstructorFn in_stateCtorFn)
|
||||
{
|
||||
typedef FSM::function_traits<tStateConstructorFn> tFnTraits;
|
||||
using tStateInput = typename tFnTraits::tArg;
|
||||
const FSM::StateId newStateId = m_states.Num();
|
||||
m_states.Add(InternalStateData(in_name));
|
||||
auto state = MakeShared<FSM::State<tStateInput, tStateConstructorFn>>(MoveTemp(in_stateCtorFn), newStateId, in_name);
|
||||
return FSM::StateHandle<tStateInput, tStateConstructorFn>{ state };
|
||||
}
|
||||
|
||||
// Create a new FSM exit state (immediately terminates the FSM when executed)
|
||||
FSM::StateHandle<void, void> State(FString in_name)
|
||||
{
|
||||
const FSM::StateId newStateId = m_states.Num();
|
||||
m_states.Add(InternalStateData(in_name));
|
||||
m_exitStates.Add(newStateId);
|
||||
auto state = MakeShared<FSM::State<void, void>>(newStateId, in_name);
|
||||
return FSM::StateHandle<void, void>{ state };
|
||||
}
|
||||
|
||||
// Define the initial entry links into the state machine
|
||||
void EntryLinks(TArray<FSM::LinkHandle> in_entryLinks);
|
||||
|
||||
// Define all outgoing links from a given state (may only be called once per state)
|
||||
template<class tStateInput, class tStateConstructorFn>
|
||||
void StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn>& in_originState, TArray<FSM::LinkHandle> in_outgoingLinks);
|
||||
|
||||
// Begins execution of the state machine (returns id of final exit state)
|
||||
Task<FSM::StateId> Run(tOnStateTransitionFn in_onTransitionFn = {}, tDebugStateTransitionFn in_debugStateTransitionFn = {}) const;
|
||||
|
||||
private:
|
||||
// Evaluates all possible outgoing links from the current state, returning the first valid transition (if any transitions are valid)
|
||||
TOptional<FSM::TransitionEvent> EvaluateLinks(FSM::StateId in_curStateId, bool in_isCurrentStateComplete, const tOnStateTransitionFn& in_onTransitionFn) const;
|
||||
|
||||
// Internal state
|
||||
struct InternalStateData
|
||||
{
|
||||
InternalStateData(FString in_debugName)
|
||||
: debugName(in_debugName)
|
||||
{
|
||||
}
|
||||
TArray<FSM::LinkHandle> outgoingLinks;
|
||||
FString debugName;
|
||||
};
|
||||
TArray<InternalStateData> m_states;
|
||||
TArray<FSM::LinkHandle> m_entryLinks;
|
||||
TArray<FSM::StateId> m_exitStates;
|
||||
};
|
||||
|
||||
/// @} end of group TaskFSM
|
||||
|
||||
//--- TaskFSM Methods ---//
|
||||
template<class tStateInput, class tStateConstructorFn>
|
||||
void TaskFSM::StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn>& in_originState, TArray<FSM::LinkHandle> in_outgoingLinks)
|
||||
{
|
||||
const int32_t stateIdx = in_originState.m_state->stateId.idx;
|
||||
SQUID_RUNTIME_CHECK(m_states[stateIdx].outgoingLinks.Num() == 0, "Cannot set outgoing links more than once for each state");
|
||||
|
||||
// Validate that there are exactly 0 or 1 unconditional OnComplete links (there may be any number of other OnComplete links, but only one with no condition)
|
||||
int32_t numOnCompleteLinks = 0;
|
||||
int32_t numOnCompleteLinks_Unconditional = 0;
|
||||
for(const FSM::LinkHandle& link : in_outgoingLinks)
|
||||
{
|
||||
if(link.IsOnCompleteLink())
|
||||
{
|
||||
SQUID_RUNTIME_CHECK(numOnCompleteLinks_Unconditional == 0, "Cannot call OnCompleteLink() after calling OnCompleteLink() with no conditions (unreachable link)");
|
||||
++numOnCompleteLinks;
|
||||
if(!link.HasCondition())
|
||||
{
|
||||
numOnCompleteLinks_Unconditional++;
|
||||
}
|
||||
}
|
||||
}
|
||||
SQUID_RUNTIME_CHECK(numOnCompleteLinks == 0 || numOnCompleteLinks_Unconditional > 0, "More than one unconditional OnCompleteLink() was set");
|
||||
|
||||
// Set the outgoing links for the origin state
|
||||
m_states[stateIdx].outgoingLinks = MoveTemp(in_outgoingLinks);
|
||||
}
|
||||
inline void TaskFSM::EntryLinks(TArray<FSM::LinkHandle> in_entryLinks)
|
||||
{
|
||||
// Validate to ensure there are no OnComplete links set as entry links
|
||||
int32_t numOnCompleteLinks = 0;
|
||||
for(const FSM::LinkHandle& link : in_entryLinks)
|
||||
{
|
||||
if(link.IsOnCompleteLink())
|
||||
{
|
||||
++numOnCompleteLinks;
|
||||
}
|
||||
}
|
||||
SQUID_RUNTIME_CHECK(numOnCompleteLinks == 0, "EntryLinks() list may not contain any OnCompleteLink() links");
|
||||
|
||||
// Set the entry links list for this FSM
|
||||
m_entryLinks = MoveTemp(in_entryLinks);
|
||||
}
|
||||
inline TOptional<FSM::TransitionEvent> TaskFSM::EvaluateLinks(FSM::StateId in_curStateId, bool in_isCurrentStateComplete, const tOnStateTransitionFn& in_onTransitionFn) const
|
||||
{
|
||||
// Determine whether to use entry links or state-specific outgoing links
|
||||
const TArray<FSM::LinkHandle>& links = (in_curStateId.idx < m_states.Num()) ? m_states[in_curStateId.idx].outgoingLinks : m_entryLinks;
|
||||
|
||||
// Find the first valid transition from the current state
|
||||
for(const FSM::LinkHandle& link : links)
|
||||
{
|
||||
if(!link.IsOnCompleteLink() || in_isCurrentStateComplete) // Skip link evaluation check for OnComplete links unless current state is complete
|
||||
{
|
||||
if(auto result = link.EvaluateLink(in_onTransitionFn)) // Check if the transition to this state is valid
|
||||
{
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {}; // No valid transition was found
|
||||
}
|
||||
inline Task<FSM::StateId> TaskFSM::Run(tOnStateTransitionFn in_onTransitionFn, tDebugStateTransitionFn in_debugStateTransitionFn) const
|
||||
{
|
||||
// Task-local variables
|
||||
FSM::StateId curStateId; // The current state's ID
|
||||
Task<> task; // The current state's task
|
||||
|
||||
// Custom debug task name logic
|
||||
TASK_NAME(__FUNCTION__, [this, &curStateId, &task]
|
||||
{
|
||||
const auto stateName = m_states.IsValidIndex(curStateId.idx) ? m_states[curStateId.idx].debugName : "";
|
||||
return FString::Printf(TEXT("%s -- %s"), *stateName, *task.GetDebugStack());
|
||||
});
|
||||
|
||||
// Debug state transition lambda
|
||||
auto DebugStateTransition = [this, in_debugStateTransitionFn](FSM::StateId in_oldStateId, FSM::StateId in_newStateId) {
|
||||
if(in_debugStateTransitionFn)
|
||||
{
|
||||
FString oldStateName = in_oldStateId.IsValid() ? m_states[in_oldStateId.idx].debugName : FString("<ENTRY>");
|
||||
FString newStateName = m_states[in_newStateId.idx].debugName;
|
||||
in_debugStateTransitionFn({ in_oldStateId, MoveTemp(oldStateName), in_newStateId, MoveTemp(newStateName) });
|
||||
}
|
||||
};
|
||||
|
||||
// Main FSM loop
|
||||
while(true)
|
||||
{
|
||||
// Evaluate links, checking for a valid transition
|
||||
if(TOptional<FSM::TransitionEvent> transition = EvaluateLinks(curStateId, task.IsDone(), in_onTransitionFn))
|
||||
{
|
||||
auto newStateId = transition->newStateId;
|
||||
DebugStateTransition(curStateId, newStateId); // Call state-transition debug function
|
||||
|
||||
// If the transition is to an exit state, return that state ID (terminating the FSM)
|
||||
if(m_exitStates.Contains(newStateId.idx))
|
||||
{
|
||||
co_return newStateId;
|
||||
}
|
||||
SQUID_RUNTIME_CHECK(newStateId.idx < m_states.Num(), "It should be logically impossible to get an invalid state to this point");
|
||||
|
||||
// Begin running new state (implicitly killing old state)
|
||||
curStateId = newStateId;
|
||||
co_await RemoveStopTask(task);
|
||||
task = MoveTemp(transition->newTask); // NOTE: Initial call to Resume() happens below
|
||||
co_await AddStopTask(task);
|
||||
}
|
||||
|
||||
// Resume current state
|
||||
task.Resume();
|
||||
|
||||
// Suspend until next frame
|
||||
co_await Suspend();
|
||||
}
|
||||
}
|
||||
|
||||
NAMESPACE_SQUID_END
|
||||
215
include/TaskManager.h
Normal file
215
include/TaskManager.h
Normal file
@@ -0,0 +1,215 @@
|
||||
#pragma once
|
||||
|
||||
/// @defgroup TaskManager Task Manager
|
||||
/// @brief Manager that runs and resumes a collection of tasks.
|
||||
/// @{
|
||||
///
|
||||
/// A TaskManager is a simple manager class that holds an ordered list of tasks and resumes them whenever it is updated.
|
||||
///
|
||||
/// Running Tasks
|
||||
/// -------------
|
||||
/// There are two primary ways to run tasks on a task manager.
|
||||
///
|
||||
/// The first method (running an "unmanaged task") is to pass a task into @ref TaskManager::Run(). This will move the task
|
||||
/// into the task manager and return a @ref TaskHandle that can be used to observe and manage the lifetime of the task (as well
|
||||
/// as potentially take a return value after the task finishes). With unmanaged tasks, the task manager only holds a weak
|
||||
/// reference to the task, meaning that the @ref TaskHandle returned by @ref TaskManager::Run() is the only remaining strong
|
||||
/// reference to the task. Because of this, the caller is entirely responsible for managing the lifetime of the task.
|
||||
///
|
||||
/// The second method (running a "managed task") is to pass a task into @ref TaskManager::RunManaged(). Like
|
||||
/// @ref TaskManager::Run(), this will move the task into the task manager and return a @ref WeakTaskHandle that can be used to
|
||||
/// observe the lifetime of the task (as well as manually kill it, if desired). Unlike unmanaged tasks, the task manager
|
||||
/// stores a strong reference to the task. Because of this, that caller is not responsible for managing the lifetime of
|
||||
/// the task. This difference in task ownership means that (unlike an unmanaged task) a managed task can be thought of as
|
||||
/// a "fire-and-forget" task that will run until either it finishes or until something else explicitly kills it.
|
||||
///
|
||||
/// Order of Execution
|
||||
/// ------------------
|
||||
/// The ordering of task updates within a call to @ref TaskManager::Update() is stable, meaning that the first task that
|
||||
/// is run on a task manager will remain the first to resume, no matter how many other tasks are run on the task manager
|
||||
/// (or terminate) in the meantime.
|
||||
///
|
||||
/// Integration into Actor Classes
|
||||
/// ------------------------------
|
||||
/// Consider the following example of a TaskManager that has been integrated into a TaskActor base class:
|
||||
///
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{.cpp}
|
||||
///
|
||||
/// class TaskActor : public Actor
|
||||
/// {
|
||||
/// public:
|
||||
/// virtual void OnInitialize() override // Automatically called when this enemy enters the scene
|
||||
/// {
|
||||
/// Actor::OnInitialize(); // Call the base Actor function
|
||||
/// m_taskMgr.RunManaged(ManageActor()); // Run main actor task as a fire-and-forget "managed task"
|
||||
/// }
|
||||
///
|
||||
/// virtual void Tick(float in_dt) override // Automatically called every frame
|
||||
/// {
|
||||
/// Actor::Tick(in_dt); // Call the base Actor function
|
||||
/// m_taskMgr.Update(); // Resume all active tasks once per tick
|
||||
/// }
|
||||
///
|
||||
/// virtual void OnDestroy() override // Automatically called when this enemy leaves the scene
|
||||
/// {
|
||||
/// m_taskMgr.KillAllTasks(); // Kill all active tasks when we leave the scene
|
||||
/// Actor::OnDestroy(); // Call the base Actor function
|
||||
/// }
|
||||
///
|
||||
/// protected:
|
||||
/// virtual Task<> ManageActor() // Overridden (in its entirety) by child classes
|
||||
/// {
|
||||
/// co_await WaitForever(); // Waits forever (doing nothing)
|
||||
/// }
|
||||
///
|
||||
/// TaskManager m_taskMgr;
|
||||
/// };
|
||||
///
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
///
|
||||
/// In the above example, TaskManager is instantiated once per high-level actor. It is updated once per frame within
|
||||
/// the Tick() method, and all its tasks are killed when it leaves the scene in OnDestroy(). Lastly, a single entry-point
|
||||
/// coroutine is run as a managed task when the actor enters the scene. (The above is the conventional method of integration
|
||||
/// into this style of game engine.)
|
||||
///
|
||||
/// Note that it is sometimes necessary to have multiple TaskManagers within a single actor. For example, if there are
|
||||
/// multiple tick functions (such as one for pre-physics updates and one for post-physics updates), then instantiating
|
||||
/// a second "post-physics" task manager may be desirable.
|
||||
|
||||
#include "Task.h"
|
||||
|
||||
NAMESPACE_SQUID_BEGIN
|
||||
|
||||
//--- TaskManager ---//
|
||||
/// Manager that runs and resumes a collection of tasks.
|
||||
class TaskManager
|
||||
{
|
||||
public:
|
||||
~TaskManager() {} /// Destructor (disables copy/move construction + assignment)
|
||||
|
||||
/// @brief Run an unmanaged task
|
||||
/// @details Run() return a @ref TaskHandle<> that holds a strong reference to the task. If there are ever no
|
||||
/// strong references remaining to an unmanaged task, it will immediately be killed and removed from the manager.
|
||||
template <typename tRet>
|
||||
SQUID_NODISCARD TaskHandle<tRet> Run(Task<tRet>&& in_task)
|
||||
{
|
||||
// Run unmanaged task
|
||||
TaskHandle<tRet> taskHandle = in_task;
|
||||
WeakTask weakTask = MoveTemp(in_task);
|
||||
RunWeakTask(MoveTemp(weakTask));
|
||||
return taskHandle;
|
||||
}
|
||||
template <typename tRet>
|
||||
SQUID_NODISCARD TaskHandle<tRet> Run(const Task<tRet>& in_task) /// @private Illegal copy implementation
|
||||
{
|
||||
static_assert(static_false<tRet>::value, "Cannot run an unmanaged task by copy (try Run(MoveTemp(task)))");
|
||||
return {};
|
||||
}
|
||||
|
||||
/// @brief Run a managed task
|
||||
/// @details RunManaged() return a @ref WeakTaskHandle, meaning it can be used to run a "fire-and-forget" background
|
||||
/// task in situations where it is not necessary to observe or control task lifetime.
|
||||
template <typename tRet>
|
||||
WeakTaskHandle RunManaged(Task<tRet>&& in_task)
|
||||
{
|
||||
// Run managed task
|
||||
WeakTaskHandle weakTaskHandle = in_task;
|
||||
m_strongRefs.Add(Run(MoveTemp(in_task)));
|
||||
return weakTaskHandle;
|
||||
}
|
||||
template <typename tRet>
|
||||
WeakTaskHandle RunManaged(const Task<tRet>& in_task) /// @private Illegal copy implementation
|
||||
{
|
||||
static_assert(static_false<tRet>::value, "Cannot run a managed task by copy (try RunManaged(MoveTemp(task)))");
|
||||
return {};
|
||||
}
|
||||
|
||||
/// @brief Run a weak task
|
||||
/// @details RunWeakTask() runs a WeakTask. The caller is assumed to have already created a strong TaskHandle<> that
|
||||
/// references the WeakTask, thus keeping it from being killed. When the last strong reference to the WeakTask is
|
||||
/// destroyed, the task will immediately be killed and removed from the manager.
|
||||
void RunWeakTask(WeakTask&& in_task)
|
||||
{
|
||||
// Run unmanaged task
|
||||
m_tasks.Add(MoveTemp(in_task));
|
||||
}
|
||||
|
||||
/// Call Task::Kill() on all tasks (managed + unmanaged)
|
||||
void KillAllTasks()
|
||||
{
|
||||
m_tasks.Reset(); // Destroying all the weak tasks implicitly destroys all internal tasks
|
||||
|
||||
// No need to call Kill() on each TaskHandle in m_strongRefs
|
||||
m_strongRefs.Reset(); // Handles in the strong refs array only ever point to tasks in the now-cleared m_tasks array
|
||||
}
|
||||
|
||||
/// @brief Issue a stop request using @ref Task::RequestStop() on all active tasks (managed and unmanaged)
|
||||
/// @details Returns a new awaiter task that will wait until all those tasks have all terminated.
|
||||
Task<> StopAllTasks()
|
||||
{
|
||||
// Request stop on all tasks
|
||||
TArray<WeakTaskHandle> weakHandles;
|
||||
for(auto& task : m_tasks)
|
||||
{
|
||||
task.RequestStop();
|
||||
weakHandles.Add(task);
|
||||
}
|
||||
|
||||
// Return a fence task that waits until all stopped tasks are complete
|
||||
return [](TArray<WeakTaskHandle> in_weakHandles) -> Task<> {
|
||||
TASK_NAME("StopAllTasks() Fence Task");
|
||||
for(const auto& weakHandle : in_weakHandles)
|
||||
{
|
||||
co_await weakHandle; // Wait until task is complete
|
||||
}
|
||||
}(MoveTemp(weakHandles));
|
||||
}
|
||||
|
||||
/// Call @ref Task::Resume() on all active tasks exactly once (managed + unmanaged)
|
||||
void Update()
|
||||
{
|
||||
// Resume all tasks
|
||||
int32 writeIdx = 0;
|
||||
for(int32 readIdx = 0; readIdx < m_tasks.Num(); ++readIdx)
|
||||
{
|
||||
if(m_tasks[readIdx].Resume() != eTaskStatus::Done)
|
||||
{
|
||||
if(writeIdx != readIdx)
|
||||
{
|
||||
m_tasks[writeIdx] = MoveTemp(m_tasks[readIdx]);
|
||||
}
|
||||
++writeIdx;
|
||||
}
|
||||
}
|
||||
m_tasks.SetNum(writeIdx);
|
||||
|
||||
// Prune strong tasks that are done
|
||||
m_strongRefs.RemoveAllSwap([](const auto& in_taskHandle) { return in_taskHandle.IsDone(); });
|
||||
}
|
||||
|
||||
/// Get a debug string containing a list of all active tasks
|
||||
FString GetDebugString(TOptional<TaskDebugStackFormatter> in_formatter = {}) const
|
||||
{
|
||||
FString debugStr;
|
||||
for(const auto& task : m_tasks)
|
||||
{
|
||||
if(!task.IsDone())
|
||||
{
|
||||
if(debugStr.Len())
|
||||
{
|
||||
debugStr += '\n';
|
||||
}
|
||||
debugStr += task.GetDebugStack(in_formatter);
|
||||
}
|
||||
}
|
||||
return debugStr;
|
||||
}
|
||||
|
||||
private:
|
||||
TArray<WeakTask> m_tasks;
|
||||
TArray<TaskHandle<>> m_strongRefs;
|
||||
};
|
||||
|
||||
NAMESPACE_SQUID_END
|
||||
|
||||
///@} end of TaskManager group
|
||||
48
include/TasksConfig.h
Normal file
48
include/TasksConfig.h
Normal file
@@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
// Squid::Tasks version (major.minor.patch)
|
||||
#define SQUID_TASKS_VERSION_MAJOR 0
|
||||
#define SQUID_TASKS_VERSION_MINOR 2
|
||||
#define SQUID_TASKS_VERSION_PATCH 0
|
||||
|
||||
/// @defgroup Config Configuration
|
||||
/// @brief Configuration settings for the Squid::Tasks library
|
||||
/// @{
|
||||
|
||||
/// Enables Task debug names and callstack tracking via Task::GetDebugStack() and Task::GetDebugName()
|
||||
#ifndef SQUID_ENABLE_TASK_DEBUG
|
||||
#define SQUID_ENABLE_TASK_DEBUG 1
|
||||
#endif
|
||||
|
||||
/// Switches time type (tTaskTime) from 32-bit single-precision floats to 64-bit double-precision floats
|
||||
#ifndef SQUID_ENABLE_DOUBLE_PRECISION_TIME
|
||||
#define SQUID_ENABLE_DOUBLE_PRECISION_TIME 0
|
||||
#endif
|
||||
|
||||
/// Wraps a Squid:: namespace around all classes in the Squid::Tasks library
|
||||
#ifndef SQUID_ENABLE_NAMESPACE
|
||||
#define SQUID_ENABLE_NAMESPACE 0
|
||||
#endif
|
||||
|
||||
/// Enables experimental (largely-untested) exception handling, and replaces all asserts with runtime_error exceptions
|
||||
#ifndef SQUID_USE_EXCEPTIONS
|
||||
#define SQUID_USE_EXCEPTIONS 0
|
||||
#endif
|
||||
|
||||
/// Enables global time support(alleviating the need to specify a time stream for time - sensitive awaiters) [see @ref GetGlobalTime()]
|
||||
#ifndef SQUID_ENABLE_GLOBAL_TIME
|
||||
// ***************
|
||||
// *** WARNING ***
|
||||
// ***************
|
||||
// It is generally inadvisable for game projects to define a global task time, as it assumes there is only a single time-stream.
|
||||
// Within game projects, there is usually a "game time" and "real time", as well as others (such as "audio time", "unpaused time").
|
||||
// Furthermore, in engines such as Unreal, a non-static world context object must be provided.
|
||||
|
||||
// To enable global task time, user must *also* define a GetGlobalTime() implementation (otherwise there will be a linker error)
|
||||
#define SQUID_ENABLE_GLOBAL_TIME 0
|
||||
#endif
|
||||
|
||||
/// @} end of addtogroup Config
|
||||
|
||||
//--- C++17/C++20 Compatibility ---//
|
||||
#include "Private/TasksCommonPrivate.h"
|
||||
320
include/TokenList.h
Normal file
320
include/TokenList.h
Normal file
@@ -0,0 +1,320 @@
|
||||
#pragma once
|
||||
|
||||
/// @defgroup Tokens Token List
|
||||
/// @brief Data structure for tracking decentralized state across multiple tasks.
|
||||
/// @{
|
||||
///
|
||||
/// Token objects can be created using @ref TokenList::MakeToken(), returning a shared pointer to a new Token. This
|
||||
/// new Token can then be added to the TokenList using @ref TokenList::AddToken(). @ref TokenList::TakeToken()
|
||||
/// can be used to make + add a new token with a single function call.
|
||||
///
|
||||
/// Because TokenList uses weak pointers to track its elements, Token objects are logically removed from the list once
|
||||
/// they are destroyed. As such, it is usually unnecessary to explicitly call @ref TokenList::RemoveToken() to remove a
|
||||
/// Token from the list. Instead, it is idiomatic to consider the Token to be a sort of "scope guard" that will remove
|
||||
/// itself from all TokenList objects when it leaves scope.
|
||||
///
|
||||
/// The TokenList class is included as part of Squid::Tasks to provide a simple mechanism for robustly sharing aribtrary
|
||||
/// state between multiple tasks. Consider this example of a poison damage-over-time system:
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~{.cpp}
|
||||
///
|
||||
/// class Character : public Actor
|
||||
/// {
|
||||
/// public:
|
||||
/// bool IsPoisoned() const
|
||||
/// {
|
||||
/// return m_poisonTokens; // Whether there are any live poison tokens
|
||||
/// }
|
||||
///
|
||||
/// void OnPoisoned(float in_dps, float in_duration)
|
||||
/// {
|
||||
/// m_taskMgr.RunManaged(ManagePoisonInstance(in_dps, in_duration));
|
||||
/// }
|
||||
///
|
||||
/// private:
|
||||
/// TokenList<float> m_poisonTokens; // Token list indicating live poison damage
|
||||
///
|
||||
/// Task<> ManagePoisonInstance(float in_dps, float in_duration)
|
||||
/// {
|
||||
/// // Take a poison token and hold it for N seconds
|
||||
/// auto poisonToken = m_poisonTokens.TakeToken(__FUNCTION__, in_dps);
|
||||
/// co_await WaitSeconds(in_duration);
|
||||
/// }
|
||||
///
|
||||
/// Task<> ManageCharacter() // Called once per frame
|
||||
/// {
|
||||
/// while(true)
|
||||
/// {
|
||||
/// float poisonDps = m_poisonTokens.GetMax(); // Get highest DPS poison instance
|
||||
/// DealDamage(poisonDps * GetDT()); // Deal the actual poison damage
|
||||
/// co_await Suspend();
|
||||
/// }
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
///
|
||||
/// As the above example shows, this mechanism is well-suited for coroutines, as they can hold a Token across
|
||||
/// multiple frames. Also note that Token objects can optionally hold data. The TokenList class has query functions
|
||||
/// (e.g. GetMin()/GetMax()) that can be used to aggregate the data from the set of live tokens. This is used above
|
||||
/// to quickly find the highest DPS poison instance.
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
//--- User configuration header ---//
|
||||
#include "TasksConfig.h"
|
||||
|
||||
NAMESPACE_SQUID_BEGIN
|
||||
|
||||
template <typename T = void>
|
||||
class TokenList;
|
||||
|
||||
/// @brief Handle to a TokenList element that stores a debug name
|
||||
/// @details In most circumstances, name should be set to \ref __FUNCTION__ at the point of creation.
|
||||
struct Token
|
||||
{
|
||||
Token(FString in_name)
|
||||
: name(MoveTemp(in_name))
|
||||
{
|
||||
}
|
||||
FString name; // Used for debug only
|
||||
};
|
||||
|
||||
/// @brief Handle to a TokenList element that stores both a debug name and associated data
|
||||
/// @details In most circumstances, name should be set to \c __FUNCTION__ at the point of creation.
|
||||
template <typename tData>
|
||||
struct DataToken
|
||||
{
|
||||
DataToken(FString in_name, tData in_data)
|
||||
: name(MoveTemp(in_name))
|
||||
, data(MoveTemp(in_data))
|
||||
{
|
||||
}
|
||||
FString name; // Used for debug only
|
||||
tData data;
|
||||
};
|
||||
|
||||
/// Create a token with the specified debug name
|
||||
inline TSharedPtr<Token> MakeToken(FString in_name)
|
||||
{
|
||||
return MakeShared<Token>(MoveTemp(in_name));
|
||||
}
|
||||
|
||||
/// Create a token with the specified debug name and associated data
|
||||
template <typename tData>
|
||||
TSharedPtr<DataToken<tData>> MakeToken(FString in_name, tData in_data)
|
||||
{
|
||||
return MakeShared<DataToken<tData>>(MoveTemp(in_name), MoveTemp(in_data));
|
||||
}
|
||||
|
||||
/// @brief Container for tracking decentralized state across multiple tasks. (See \ref Tokens for more info...)
|
||||
/// @tparam T Type of data to associate with each Token in this container
|
||||
template <typename T>
|
||||
class TokenList
|
||||
{
|
||||
public:
|
||||
/// Type of Token tracked by this container
|
||||
using Token = typename std::conditional_t<std::is_void<T>::value, Token, DataToken<T>>;
|
||||
|
||||
/// Create a token with the specified debug name
|
||||
template <typename U = T, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
|
||||
static TSharedPtr<Token> MakeToken(FString in_name)
|
||||
{
|
||||
return MakeShared<Token>(MoveTemp(in_name));
|
||||
}
|
||||
|
||||
/// Create a token with the specified debug name and associated data
|
||||
template <typename U = T, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
|
||||
static TSharedPtr<Token> MakeToken(FString in_name, U in_data)
|
||||
{
|
||||
return MakeShared<Token>(MoveTemp(in_name), MoveTemp(in_data));
|
||||
}
|
||||
|
||||
/// Create and add a token with the specified debug name
|
||||
template <typename U = T, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
|
||||
SQUID_NODISCARD TSharedPtr<Token> TakeToken(FString in_name)
|
||||
{
|
||||
return AddToken(MakeToken(MoveTemp(in_name)));
|
||||
}
|
||||
|
||||
/// Create and add a token with the specified debug name and associated data
|
||||
template <typename U = T, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
|
||||
SQUID_NODISCARD TSharedPtr<Token> TakeToken(FString in_name, U in_data)
|
||||
{
|
||||
return AddToken(MakeToken(MoveTemp(in_name), MoveTemp(in_data)));
|
||||
}
|
||||
|
||||
/// Add an existing token to this container
|
||||
TSharedPtr<Token> AddToken(TSharedPtr<Token> in_token)
|
||||
{
|
||||
SQUID_RUNTIME_CHECK(in_token, "Cannot add null token");
|
||||
Sanitize();
|
||||
m_tokens.AddUnique(in_token);
|
||||
return in_token;
|
||||
}
|
||||
|
||||
/// Explicitly remove a token from this container
|
||||
void RemoveToken(TSharedPtr<Token> in_token)
|
||||
{
|
||||
// Find and remove the token
|
||||
m_tokens.Remove(in_token);
|
||||
}
|
||||
|
||||
/// Convenience conversion operator that calls HasTokens()
|
||||
operator bool() const
|
||||
{
|
||||
return HasTokens();
|
||||
}
|
||||
|
||||
/// Returns whether this container holds any live tokens
|
||||
bool HasTokens() const
|
||||
{
|
||||
// Return true when holding any unexpired tokens
|
||||
for(auto i = (int32_t)(m_tokens.Num() - 1); i >= 0; --i)
|
||||
{
|
||||
const auto& token = m_tokens[i];
|
||||
if(token.IsValid())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
m_tokens.Pop(); // Because the token is expired, we can safely remove it from the back
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns an array of all live token data
|
||||
TArray<T> GetTokenData() const
|
||||
{
|
||||
TArray<T> tokenData;
|
||||
for(const auto& tokenWeak : m_tokens)
|
||||
{
|
||||
if(auto token = tokenWeak.Pin())
|
||||
{
|
||||
tokenData.Add(token->data);
|
||||
}
|
||||
}
|
||||
return tokenData;
|
||||
}
|
||||
|
||||
/// @name Data Queries
|
||||
/// Methods for querying and aggregating the data from the set of live tokens.
|
||||
/// @{
|
||||
|
||||
/// Returns associated data from the least-recently-added live token
|
||||
TOptional<T> GetLeastRecent() const
|
||||
{
|
||||
Sanitize();
|
||||
return m_tokens.Num() ? m_tokens[0].Pin()->data : TOptional<T>{};
|
||||
}
|
||||
|
||||
/// Returns associated data from the most-recently-added live token
|
||||
TOptional<T> GetMostRecent() const
|
||||
{
|
||||
Sanitize();
|
||||
return m_tokens.Num() ? m_tokens.Last().Pin()->data : TOptional<T>{};
|
||||
}
|
||||
|
||||
/// Returns smallest associated data from the set of live tokens
|
||||
TOptional<T> GetMin() const
|
||||
{
|
||||
TOptional<T> ret;
|
||||
SanitizeAndProcessData([&ret](const T& in_data) {
|
||||
if(!ret || in_data < ret.GetValue())
|
||||
{
|
||||
ret = in_data;
|
||||
}
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns largest associated data from the set of live tokens
|
||||
TOptional<T> GetMax() const
|
||||
{
|
||||
TOptional<T> ret;
|
||||
SanitizeAndProcessData([&ret](const T& in_data) {
|
||||
if(!ret || in_data > ret.GetValue())
|
||||
{
|
||||
ret = in_data;
|
||||
}
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns arithmetic mean of all associated data from the set of live tokens
|
||||
TOptional<double> GetMean() const
|
||||
{
|
||||
TOptional<double> ret;
|
||||
TOptional<double> total;
|
||||
SanitizeAndProcessData([&total](const T& in_data) {
|
||||
total = total.Get(0.0) + (double)in_data;
|
||||
});
|
||||
if(total)
|
||||
{
|
||||
ret = total.GetValue() / m_tokens.Num();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns whether the set of live tokens contains at least one token associated with the specified data
|
||||
template <typename U = T, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
|
||||
bool Contains(const U& in_searchData) const
|
||||
{
|
||||
bool containsData = false;
|
||||
SanitizeAndProcessData([&in_searchData, &containsData](const T& in_data) {
|
||||
if(in_searchData == in_data)
|
||||
{
|
||||
containsData = true;
|
||||
}
|
||||
});
|
||||
return containsData;
|
||||
}
|
||||
///@} end of Data Queries
|
||||
|
||||
/// Returns a debug string containing a list of the debug names of all live tokens
|
||||
FString GetDebugString() const
|
||||
{
|
||||
TArray<FString> tokenStrings;
|
||||
for(auto token : m_tokens)
|
||||
{
|
||||
if(token.IsValid())
|
||||
{
|
||||
tokenStrings.Add(token.Pin()->name);
|
||||
}
|
||||
}
|
||||
if(tokenStrings.Num())
|
||||
{
|
||||
return FString::Join(tokenStrings, TEXT("\n"));
|
||||
}
|
||||
return TEXT("[no tokens]");
|
||||
}
|
||||
|
||||
private:
|
||||
// Sanitation
|
||||
void Sanitize() const
|
||||
{
|
||||
// Remove all invalid tokens
|
||||
m_tokens.RemoveAll([](const Wp<Token>& in_token) { return !in_token.IsValid(); });
|
||||
}
|
||||
template <typename tFn>
|
||||
void SanitizeAndProcessData(tFn in_dataFn) const
|
||||
{
|
||||
// Remove all invalid tokens while applying a processing function on each valid token
|
||||
m_tokens.RemoveAll([&in_dataFn](const TWeakPtr<Token>& in_token) {
|
||||
if(auto pinnedToken = in_token.Pin())
|
||||
{
|
||||
in_dataFn(pinnedToken->data);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
// Token data
|
||||
mutable TArray<TWeakPtr<Token>> m_tokens; // Mutable so we can remove expired tokens while converting bool
|
||||
};
|
||||
|
||||
NAMESPACE_SQUID_END
|
||||
|
||||
///@} end of Tokens group
|
||||
Reference in New Issue
Block a user