The /include directory now contains the correct STL-based implementation.

This commit is contained in:
Tim Ambrogi
2022-03-04 16:06:29 -05:00
parent cb814bd7dd
commit 675cf43acc
10 changed files with 2474 additions and 490 deletions

View File

@@ -31,13 +31,13 @@ struct StateId
struct TransitionDebugData
{
FSM::StateId oldStateId;
FString oldStateName;
std::string oldStateName;
FSM::StateId newStateId;
FString newStateName;
std::string newStateName;
};
// State transition callback function
using tOnStateTransitionFn = TFunction<void()>;
using tOnStateTransitionFn = std::function<void()>;
#include "Private/TaskFSMPrivate.h" // Internal use only! Do not include elsewhere!
@@ -45,8 +45,8 @@ using tOnStateTransitionFn = TFunction<void()>;
template<class tStateInput, class tStateConstructorFn>
class StateHandle
{
using tPredicateRet = typename std::conditional<!std::is_void<tStateInput>::value, TOptional<tStateInput>, bool>::type;
using tPredicateFn = TFunction<tPredicateRet()>;
using tPredicateRet = typename std::conditional<!std::is_void<tStateInput>::value, std::optional<tStateInput>, bool>::type;
using tPredicateFn = std::function<tPredicateRet()>;
public:
StateHandle(StateHandle&& in_other) = default;
StateHandle& operator=(StateHandle&& in_other) = default;
@@ -72,7 +72,7 @@ public:
}
NONVOID_ONLY LinkHandle Link(tPayload in_payload) //< Empty predicate link w/ payload (always follow link, using provided payload)
{
return _InternalLink([payload = MoveTemp(in_payload)]() -> tPredicateRet { return payload; }, LinkHandle::eType::Normal);
return _InternalLink([payload = std::move(in_payload)]() -> tPredicateRet { return payload; }, LinkHandle::eType::Normal);
}
PREDICATE_ONLY LinkHandle Link(tPredicateFn in_predicate) //< Predicate link w/ implicit payload (follow link when predicate returns a value; use return value as payload)
{
@@ -80,7 +80,7 @@ public:
}
NONVOID_ONLY_WITH_PREDICATE LinkHandle Link(tPredicateFn in_predicate, tPayload in_payload) //< Predicate link w/ explicit payload (follow link when predicate returns true; use provided payload)
{
return _InternalLink(in_predicate, MoveTemp(in_payload), LinkHandle::eType::Normal);
return _InternalLink(in_predicate, std::move(in_payload), LinkHandle::eType::Normal);
}
// OnCompleteLink methods
@@ -90,7 +90,7 @@ public:
}
NONVOID_ONLY LinkHandle OnCompleteLink(tPayload in_payload) //< Empty predicate link w/ payload (always follow link, using provided payload)
{
return _InternalLink([payload = MoveTemp(in_payload)]() -> tPredicateRet { return payload; }, LinkHandle::eType::OnComplete);
return _InternalLink([payload = std::move(in_payload)]() -> tPredicateRet { return payload; }, LinkHandle::eType::OnComplete);
}
PREDICATE_ONLY LinkHandle OnCompleteLink(tPredicateFn in_predicate) //< Predicate link w/ implicit payload (follow link when predicate returns a value; use return value as payload)
{
@@ -98,14 +98,14 @@ public:
}
NONVOID_ONLY_WITH_PREDICATE LinkHandle OnCompleteLink(tPredicateFn in_predicate, tPayload in_payload) //< Predicate link w/ explicit payload (follow link when predicate returns true; use provided payload)
{
return _InternalLink(in_predicate, MoveTemp(in_payload), LinkHandle::eType::OnComplete, true);
return _InternalLink(in_predicate, std::move(in_payload), LinkHandle::eType::OnComplete, true);
}
private:
friend class ::TaskFSM;
StateHandle() = delete;
StateHandle(TSharedPtr<State<tStateInput, tStateConstructorFn>> InStatePtr)
StateHandle(std::shared_ptr<State<tStateInput, tStateConstructorFn>> InStatePtr)
: m_state(InStatePtr)
{
}
@@ -116,21 +116,21 @@ private:
VOID_ONLY_WITH_PREDICATE LinkHandle _InternalLink(tPredicateFn in_predicate, LinkHandle::eType in_linkType, bool in_isConditional = false) // bool-returning predicate
{
static_assert(std::is_same<bool, decltype(in_predicate())>::value, "This link requires a predicate function returning bool");
TSharedPtr<LinkBase> link = MakeShared<FSM::Link<tStateInput, tStateConstructorFn, tPredicateFn>>(m_state, in_predicate);
std::shared_ptr<LinkBase> link = std::make_shared<FSM::Link<tStateInput, tStateConstructorFn, tPredicateFn>>(m_state, in_predicate);
return LinkHandle(link, in_linkType, in_isConditional);
}
NONVOID_ONLY_WITH_PREDICATE LinkHandle _InternalLink(tPredicateFn in_predicate, LinkHandle::eType in_linkType, bool in_isConditional = false) // optional-returning predicate
{
static_assert(std::is_same<TOptional<tStateInput>, decltype(in_predicate())>::value, "This link requires a predicate function returning TOptional<tStateInput>");
TSharedPtr<LinkBase> link = MakeShared<FSM::Link<tStateInput, tStateConstructorFn, tPredicateFn>>(m_state, in_predicate);
static_assert(std::is_same<std::optional<tStateInput>, decltype(in_predicate())>::value, "This link requires a predicate function returning std::optional<tStateInput>");
std::shared_ptr<LinkBase> link = std::make_shared<FSM::Link<tStateInput, tStateConstructorFn, tPredicateFn>>(m_state, in_predicate);
return LinkHandle(link, in_linkType, in_isConditional);
}
NONVOID_ONLY_WITH_PREDICATE LinkHandle _InternalLink(tPredicateFn in_predicate, tPayload in_payload, LinkHandle::eType in_linkType, bool in_isConditional = false) // bool-returning predicate w/ fixed payload
{
static_assert(std::is_same<bool, decltype(in_predicate())>::value, "This link requires a predicate function returning bool");
auto predicate = [in_predicate, in_payload]() -> TOptional<tStateInput>
auto predicate = [in_predicate, in_payload]() -> std::optional<tStateInput>
{
return in_predicate() ? TOptional<tStateInput>(in_payload) : TOptional<tStateInput>{};
return in_predicate() ? std::optional<tStateInput>(in_payload) : std::optional<tStateInput>{};
};
return _InternalLink(predicate, in_linkType, in_isConditional);
}
@@ -142,7 +142,7 @@ private:
#undef VOID_ONLY
#undef PREDICATE_ONLY
TSharedPtr<State<tStateInput, tStateConstructorFn>> m_state; // Internal state object
std::shared_ptr<State<tStateInput, tStateConstructorFn>> m_state; // Internal state object
};
} // namespace FSM
@@ -157,68 +157,67 @@ using tOnStateTransitionFn = FSM::tOnStateTransitionFn;
class TaskFSM
{
public:
using tOnStateTransitionFn = TFunction<void()>;
using tDebugStateTransitionFn = TFunction<void(TransitionDebugData)>;
using tDebugStateTransitionFn = std::function<void(TransitionDebugData)>;
// Create a new FSM state [fancy param-deducing version (hopefully) coming soon!]
template<typename tStateConstructorFn>
auto State(FString in_name, tStateConstructorFn in_stateCtorFn)
auto State(std::string in_name, tStateConstructorFn in_stateCtorFn)
{
typedef FSM::function_traits<tStateConstructorFn> tFnTraits;
using tStateInput = typename tFnTraits::tArg;
const FSM::StateId newStateId = m_states.Num();
m_states.Add(InternalStateData(in_name));
auto state = MakeShared<FSM::State<tStateInput, tStateConstructorFn>>(MoveTemp(in_stateCtorFn), newStateId, in_name);
const FSM::StateId newStateId = m_states.size();
m_states.push_back(InternalStateData(in_name));
auto state = std::make_shared<FSM::State<tStateInput, tStateConstructorFn>>(std::move(in_stateCtorFn), newStateId, in_name);
return FSM::StateHandle<tStateInput, tStateConstructorFn>{ state };
}
// Create a new FSM exit state (immediately terminates the FSM when executed)
FSM::StateHandle<void, void> State(FString in_name)
FSM::StateHandle<void, void> State(std::string in_name)
{
const FSM::StateId newStateId = m_states.Num();
m_states.Add(InternalStateData(in_name));
m_exitStates.Add(newStateId);
auto state = MakeShared<FSM::State<void, void>>(newStateId, in_name);
const FSM::StateId newStateId = m_states.size();
m_states.push_back(InternalStateData(in_name));
m_exitStates.push_back(newStateId);
auto state = std::make_shared<FSM::State<void, void>>(newStateId, in_name);
return FSM::StateHandle<void, void>{ state };
}
// Define the initial entry links into the state machine
void EntryLinks(TArray<FSM::LinkHandle> in_entryLinks);
void EntryLinks(std::vector<FSM::LinkHandle> in_entryLinks);
// Define all outgoing links from a given state (may only be called once per state)
template<class tStateInput, class tStateConstructorFn>
void StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn>& in_originState, TArray<FSM::LinkHandle> in_outgoingLinks);
void StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn>& in_originState, std::vector<FSM::LinkHandle> in_outgoingLinks);
// Begins execution of the state machine (returns id of final exit state)
Task<FSM::StateId> Run(tOnStateTransitionFn in_onTransitionFn = {}, tDebugStateTransitionFn in_debugStateTransitionFn = {}) const;
private:
// Evaluates all possible outgoing links from the current state, returning the first valid transition (if any transitions are valid)
TOptional<FSM::TransitionEvent> EvaluateLinks(FSM::StateId in_curStateId, bool in_isCurrentStateComplete, const tOnStateTransitionFn& in_onTransitionFn) const;
std::optional<FSM::TransitionEvent> EvaluateLinks(FSM::StateId in_curStateId, bool in_isCurrentStateComplete, const tOnStateTransitionFn& in_onTransitionFn) const;
// Internal state
struct InternalStateData
{
InternalStateData(FString in_debugName)
InternalStateData(std::string in_debugName)
: debugName(in_debugName)
{
}
TArray<FSM::LinkHandle> outgoingLinks;
FString debugName;
std::vector<FSM::LinkHandle> outgoingLinks;
std::string debugName;
};
TArray<InternalStateData> m_states;
TArray<FSM::LinkHandle> m_entryLinks;
TArray<FSM::StateId> m_exitStates;
std::vector<InternalStateData> m_states;
std::vector<FSM::LinkHandle> m_entryLinks;
std::vector<FSM::StateId> m_exitStates;
};
/// @} end of group TaskFSM
//--- TaskFSM Methods ---//
template<class tStateInput, class tStateConstructorFn>
void TaskFSM::StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn>& in_originState, TArray<FSM::LinkHandle> in_outgoingLinks)
void TaskFSM::StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn>& in_originState, std::vector<FSM::LinkHandle> in_outgoingLinks)
{
const int32_t stateIdx = in_originState.m_state->stateId.idx;
SQUID_RUNTIME_CHECK(m_states[stateIdx].outgoingLinks.Num() == 0, "Cannot set outgoing links more than once for each state");
SQUID_RUNTIME_CHECK(m_states[stateIdx].outgoingLinks.size() == 0, "Cannot set outgoing links more than once for each state");
// Validate that there are exactly 0 or 1 unconditional OnComplete links (there may be any number of other OnComplete links, but only one with no condition)
int32_t numOnCompleteLinks = 0;
@@ -238,9 +237,9 @@ void TaskFSM::StateLinks(const FSM::StateHandle<tStateInput, tStateConstructorFn
SQUID_RUNTIME_CHECK(numOnCompleteLinks == 0 || numOnCompleteLinks_Unconditional > 0, "More than one unconditional OnCompleteLink() was set");
// Set the outgoing links for the origin state
m_states[stateIdx].outgoingLinks = MoveTemp(in_outgoingLinks);
m_states[stateIdx].outgoingLinks = std::move(in_outgoingLinks);
}
inline void TaskFSM::EntryLinks(TArray<FSM::LinkHandle> in_entryLinks)
inline void TaskFSM::EntryLinks(std::vector<FSM::LinkHandle> in_entryLinks)
{
// Validate to ensure there are no OnComplete links set as entry links
int32_t numOnCompleteLinks = 0;
@@ -254,12 +253,12 @@ inline void TaskFSM::EntryLinks(TArray<FSM::LinkHandle> in_entryLinks)
SQUID_RUNTIME_CHECK(numOnCompleteLinks == 0, "EntryLinks() list may not contain any OnCompleteLink() links");
// Set the entry links list for this FSM
m_entryLinks = MoveTemp(in_entryLinks);
m_entryLinks = std::move(in_entryLinks);
}
inline TOptional<FSM::TransitionEvent> TaskFSM::EvaluateLinks(FSM::StateId in_curStateId, bool in_isCurrentStateComplete, const tOnStateTransitionFn& in_onTransitionFn) const
inline std::optional<FSM::TransitionEvent> TaskFSM::EvaluateLinks(FSM::StateId in_curStateId, bool in_isCurrentStateComplete, const tOnStateTransitionFn& in_onTransitionFn) const
{
// Determine whether to use entry links or state-specific outgoing links
const TArray<FSM::LinkHandle>& links = (in_curStateId.idx < m_states.Num()) ? m_states[in_curStateId.idx].outgoingLinks : m_entryLinks;
const std::vector<FSM::LinkHandle>& links = (in_curStateId.idx < m_states.size()) ? m_states[in_curStateId.idx].outgoingLinks : m_entryLinks;
// Find the first valid transition from the current state
for(const FSM::LinkHandle& link : links)
@@ -283,17 +282,17 @@ inline Task<FSM::StateId> TaskFSM::Run(tOnStateTransitionFn in_onTransitionFn, t
// Custom debug task name logic
TASK_NAME(__FUNCTION__, [this, &curStateId, &task]
{
const auto stateName = m_states.IsValidIndex(curStateId.idx) ? m_states[curStateId.idx].debugName : "";
return FString::Printf(TEXT("%s -- %s"), *stateName, *task.GetDebugStack());
const std::string stateName = (curStateId.idx < m_states.size()) ? m_states[curStateId.idx].debugName : "";
return stateName + " -- " + task.GetDebugStack();
});
// Debug state transition lambda
auto DebugStateTransition = [this, in_debugStateTransitionFn](FSM::StateId in_oldStateId, FSM::StateId in_newStateId) {
if(in_debugStateTransitionFn)
{
FString oldStateName = in_oldStateId.IsValid() ? m_states[in_oldStateId.idx].debugName : FString("<ENTRY>");
FString newStateName = m_states[in_newStateId.idx].debugName;
in_debugStateTransitionFn({ in_oldStateId, MoveTemp(oldStateName), in_newStateId, MoveTemp(newStateName) });
std::string oldStateName = in_oldStateId.IsValid() ? m_states[in_oldStateId.idx].debugName : std::string("<ENTRY>");
std::string newStateName = m_states[in_newStateId.idx].debugName;
in_debugStateTransitionFn({ in_oldStateId, std::move(oldStateName), in_newStateId, std::move(newStateName) });
}
};
@@ -301,22 +300,23 @@ inline Task<FSM::StateId> TaskFSM::Run(tOnStateTransitionFn in_onTransitionFn, t
while(true)
{
// Evaluate links, checking for a valid transition
if(TOptional<FSM::TransitionEvent> transition = EvaluateLinks(curStateId, task.IsDone(), in_onTransitionFn))
if(std::optional<FSM::TransitionEvent> transition = EvaluateLinks(curStateId, task.IsDone(), in_onTransitionFn))
{
auto newStateId = transition->newStateId;
DebugStateTransition(curStateId, newStateId); // Call state-transition debug function
// If the transition is to an exit state, return that state ID (terminating the FSM)
if(m_exitStates.Contains(newStateId.idx))
auto Found = std::find(m_exitStates.begin(), m_exitStates.end(), newStateId.idx);
if(Found != m_exitStates.end())
{
co_return newStateId;
}
SQUID_RUNTIME_CHECK(newStateId.idx < m_states.Num(), "It should be logically impossible to get an invalid state to this point");
SQUID_RUNTIME_CHECK(newStateId.idx < m_states.size(), "It should be logically impossible to get an invalid state to this point");
// Begin running new state (implicitly killing old state)
curStateId = newStateId;
co_await RemoveStopTask(task);
task = MoveTemp(transition->newTask); // NOTE: Initial call to Resume() happens below
task = std::move(transition->newTask); // NOTE: Initial call to Resume() happens below
co_await AddStopTask(task);
}