Initial commit.

This commit is contained in:
Tim Ambrogi
2022-03-04 15:30:18 -05:00
commit 0aad97fa48
49 changed files with 23959 additions and 0 deletions

View File

@@ -0,0 +1,220 @@
// WARNING: This is an internal implementation header, which must be included from a specific location/namespace
// That is the reason that this header does not contain a #pragma once, nor namespace guards
// Helper struct representing a transition event to a new FSM state
struct TransitionEvent
{
Task<> newTask;
StateId newStateId;
};
// Base class for defining links between states
class LinkBase
{
public:
virtual ~LinkBase() = default;
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const = 0;
};
// Type-safe link handle
class LinkHandle
{
bool IsOnCompleteLink() const
{
return m_linkType == eType::OnComplete;
}
bool HasCondition() const
{
return m_isConditionalLink;
}
protected:
// Link-type enum
enum class eType
{
Normal,
OnComplete,
};
// Friends
template<class, class> friend class StateHandle;
friend class ::TaskFSM;
// Constructors (friend-only)
LinkHandle() = delete;
LinkHandle(TSharedPtr<LinkBase> in_link, eType in_linkType, bool in_isConditional)
: m_link(MoveTemp(in_link))
, m_linkType(in_linkType)
, m_isConditionalLink(in_isConditional)
{
}
TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const
{
return m_link->EvaluateLink(in_onTransitionFn);
}
private:
TSharedPtr<LinkBase> m_link; // The underlying link
eType m_linkType; // Whether the link is normal or OnComplete
bool m_isConditionalLink; // Whether the link has an associated condition predicate
};
// Internal FSM state object
template<class tStateInput, class tStateConstructorFn>
struct State
{
State(tStateConstructorFn in_stateCtorFn, StateId in_stateId, FString in_debugName)
: stateCtorFn(in_stateCtorFn)
, stateId(in_stateId)
, debugName(in_debugName)
{
}
tStateConstructorFn stateCtorFn;
StateId stateId;
FString debugName;
};
// Internal FSM state object (exit state specialization)
template<>
struct State<void, void>
{
State(StateId in_stateId, FString in_debugName)
: stateId(in_stateId)
, debugName(in_debugName)
{
}
StateId stateId;
FString debugName;
};
// Internal link definition object
template<class ReturnT, class tStateConstructorFn, class tPredicateFn>
class Link : public LinkBase
{
public:
Link(TSharedPtr<State<ReturnT, tStateConstructorFn>> in_targetState, tPredicateFn in_predicate)
: m_targetState(MoveTemp(in_targetState))
, m_predicate(in_predicate)
{
}
private:
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const final
{
TOptional<TransitionEvent> result;
if(TOptional<ReturnT> payload = m_predicate())
{
if(in_onTransitionFn)
{
in_onTransitionFn();
}
result = TransitionEvent{ m_targetState->stateCtorFn(payload.GetValue()), m_targetState->stateId };
}
return result;
}
TSharedPtr<State<ReturnT, tStateConstructorFn>> m_targetState;
tPredicateFn m_predicate;
};
// Internal link definition object (no-payload specialization)
template<class tStateConstructorFn, class tPredicateFn>
class Link<void, tStateConstructorFn, tPredicateFn> : public LinkBase
{
public:
Link(TSharedPtr<State<void, tStateConstructorFn>> in_targetState, tPredicateFn in_predicate)
: m_targetState(MoveTemp(in_targetState))
, m_predicate(in_predicate)
{
}
private:
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const final
{
TOptional<TransitionEvent> result;
if(m_predicate())
{
if(in_onTransitionFn)
{
in_onTransitionFn();
}
result = TransitionEvent{ m_targetState->stateCtorFn(), m_targetState->stateId };
}
return result;
}
TSharedPtr<State<void, tStateConstructorFn>> m_targetState;
tPredicateFn m_predicate;
};
// Internal link definition object (exit-state specialization)
template<class tPredicateFn>
class Link<void, void, tPredicateFn> : public LinkBase
{
public:
Link(TSharedPtr<State<void, void>> in_targetState, tPredicateFn in_predicate)
: m_targetState(MoveTemp(in_targetState))
, m_predicate(in_predicate)
{
}
private:
virtual TOptional<TransitionEvent> EvaluateLink(const tOnStateTransitionFn& in_onTransitionFn) const final
{
TOptional<TransitionEvent> result;
if(m_predicate())
{
if(in_onTransitionFn)
{
in_onTransitionFn();
}
result = TransitionEvent{ Task<>(), m_targetState->stateId };
}
return result;
}
TSharedPtr<State<void, void>> m_targetState;
tPredicateFn m_predicate;
};
// Specialized type traits that deduce the first argument type of an arbitrary callable type
template <typename tRet, typename tArg>
static tArg get_first_arg_type(TFunction<tRet(tArg)> f); // Return type is first argument type
template <typename tRet>
static void get_first_arg_type(TFunction<tRet()> f); // Return type is void (function has no arguments)
template <typename T>
struct function_traits : public function_traits<decltype(&T::operator())> // Generic callable objects (use operator())
{
};
template <typename tRet, typename... tArgs> // Function
struct function_traits<tRet(tArgs...)>
{
using tFunction = TFunction<tRet(tArgs...)>;
using tArg = decltype(get_first_arg_type(tFunction()));
};
template <typename tRet, typename... tArgs> // Function ptr
struct function_traits<tRet(*)(tArgs...)>
{
using tFunction = TFunction<tRet(tArgs...)>;
using tArg = decltype(get_first_arg_type(tFunction()));
};
template <typename tClass, typename tRet, typename... tArgs> // Member function ptr (const)
struct function_traits<tRet(tClass::*)(tArgs...) const>
{
using tFunction = TFunction<tRet(tArgs...)>;
using tArg = decltype(get_first_arg_type(tFunction()));
};
template <typename tClass, typename tRet, typename... tArgs> // Member function ptr
struct function_traits<tRet(tClass::*)(tArgs...)>
{
using tFunction = TFunction<tRet(tArgs...)>;
using tArg = decltype(get_first_arg_type(tFunction()));
};

View File

@@ -0,0 +1,896 @@
// WARNING: This is an internal implementation header, which must be included from a specific location/namespace
// That is the reason that this header does not contain a #pragma once, nor namespace guards
enum class eTaskRef;
template <typename tRet> class TaskPromise;
class TaskInternalBase;
template <typename tRet> class TaskInternal;
//--- tTaskReadyFn ---//
using tTaskReadyFn = TFunction<bool()>;
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
auto CancelTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn);
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
auto StopTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn);
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename tTimeFn>
auto StopTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn, tTaskTime in_timeout, tTimeFn in_timeFn);
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename T>
auto StopTaskIf(Task<tRet, RefType, Resumable>&& in_task, tTaskCancelFn in_cancelFn, tTaskTime in_timeout);
//--- Suspend-If Awaiter ---//
struct SuspendIf
{
SuspendIf(bool in_suspend)
: m_suspend(in_suspend)
{
}
bool await_ready() noexcept { return !m_suspend; }
void await_suspend(std::coroutine_handle<>) noexcept {}
void await_resume() noexcept {}
private:
bool m_suspend;
};
//--- Task Debug Stack Formatter ---//
struct TaskDebugStackFormatter
{
// Format function (formats a debug output string) [virtual]
virtual FString Format(const FString& in_str) const
{
FString result = Indent(0);
int32_t indent = 0;
int32_t start = 0;
int32_t found = 0;
while((found = in_str.FindChar('\n', start)) != INDEX_NONE)
{
int32_t end = found + 1;
if((found < in_str.Len() - 1) && (in_str[found + 1] == '`')) // indent
{
++indent;
++end;
}
else if((found >= 1) && (in_str[found - 1] == '`')) // dedent
{
--indent;
--found;
}
result += in_str.Mid(start, found - start) + '\n' + Indent(indent);
start = end;
}
result += in_str.Mid(start);
return result;
}
virtual FString Indent(int32_t in_indent) const
{
return FString::ChrN(in_indent * 2, ' ');
}
};
static FString FormatDebugString(FString in_str)
{
in_str.ReplaceCharInline('\n', ' ');
in_str.LeftChopInline(32, false);
return in_str;
}
//--- SetDebugName Awaiter ---//
#if SQUID_ENABLE_TASK_DEBUG
struct SetDebugName
{
// Sets a Task's debug name field
SetDebugName(const char* in_name)
: m_name(in_name)
{
}
SetDebugName(const char* in_name, TFunction<FString()> in_dataFn)
: m_name(in_name)
, m_dataFn(in_dataFn)
{
}
private:
template <typename tRet> friend class TaskPromiseBase;
const char* m_name = nullptr;
TFunction<FString()> m_dataFn;
};
#endif //SQUID_ENABLE_TASK_DEBUG
//--- AddStopTask Awaiter ---//
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
struct AddStopTaskAwaiter
{
AddStopTaskAwaiter(Task<tRet, RefType, Resumable>& in_taskToStop)
: m_taskToStop(&in_taskToStop)
{
}
private:
template <typename tRet> friend class TaskPromiseBase;
Task<tRet, RefType, Resumable>* m_taskToStop = nullptr;
};
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
auto AddStopTask(Task<tRet, RefType, Resumable>& in_taskToStop)
{
return AddStopTaskAwaiter<tRet, RefType, Resumable>(in_taskToStop);
};
//--- RemoveStopTask Awaiter ---//
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
struct RemoveStopTaskAwaiter
{
RemoveStopTaskAwaiter(Task<tRet, RefType, Resumable>& in_taskToStop)
: m_taskToStop(&in_taskToStop)
{
}
private:
template <typename tRet> friend class TaskPromiseBase;
Task<tRet, RefType, Resumable>* m_taskToStop = nullptr;
};
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
auto RemoveStopTask(Task<tRet, RefType, Resumable>& in_taskToStop)
{
return RemoveStopTaskAwaiter<tRet, RefType, Resumable>(in_taskToStop);
};
//--- Task Awaiter ---//
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename promise_type>
struct TaskAwaiterBase
{
TaskAwaiterBase(const Task<tRet, RefType, Resumable>& in_task)
{
// This constructor exists to minimize downstream compile-error spam when co_awaiting a non-copyable Task by copy
}
TaskAwaiterBase(Task<tRet, RefType, Resumable>&& in_task)
: m_task(MoveTemp(in_task))
{
SQUID_RUNTIME_CHECK(m_task.IsValid(), "Tried to await an invalid task");
}
TaskAwaiterBase(TaskAwaiterBase&& in_taskAwaiter) noexcept
{
m_task = MoveTemp(in_taskAwaiter.m_task);
}
bool await_ready() noexcept
{
if(m_task.IsDone())
{
return true;
}
return false;
}
template <eTaskResumable UResumable = Resumable, typename std::enable_if_t<UResumable == eTaskResumable::Yes>* = nullptr>
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
{
// Set the sub-task on the suspending task
auto& promise = in_coroHandle.promise();
auto taskInternal = promise.GetInternalTask();
auto subTaskInternal = m_task.GetInternalTask();
if(taskInternal->IsStopRequested())
{
subTaskInternal->RequestStop(); // Propagate any stop request to new sub-tasks
}
taskInternal->SetSubTask(StaticCastSharedPtr<TaskInternalBase>(subTaskInternal));
// Resume the task
if(m_task.Resume() == eTaskStatus::Done)
{
taskInternal->SetSubTask(nullptr);
return false; // Do not suspend, because the task is done
}
return true; // Suspend, because the task is not done
}
template <eTaskResumable UResumable = Resumable, typename std::enable_if_t<UResumable == eTaskResumable::No>* = nullptr>
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
{
auto& promise = in_coroHandle.promise();
if(!m_task.IsDone())
{
promise.SetReadyFunction([this] { return m_task.IsDone(); });
return true; // Suspend, because the task is not done
}
return false; // Do not suspend, because the task is done
}
protected:
auto GetInternalTask() const
{
return m_task.GetInternalTask();
}
Task<tRet, RefType, Resumable> m_task;
};
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename promise_type>
struct TaskAwaiter : public TaskAwaiterBase<tRet, RefType, Resumable, promise_type>
{
using TaskAwaiterBase<tRet, RefType, Resumable, promise_type>::TaskAwaiterBase;
template <typename U = tRet, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
auto await_resume()
{
this->m_task.RethrowUnhandledException(); // Re-throw any exceptions
auto retVal = this->m_task.TakeReturnValue();
SQUID_RUNTIME_CHECK(retVal, "Awaited task return value is unset");
return MoveTemp(retVal.GetValue());
}
template <typename U = tRet, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
void await_resume()
{
this->m_task.RethrowUnhandledException(); // Re-throw any exceptions
}
};
//--- Future Awaiter ---//
template <typename tRet, typename promise_type>
struct FutureAwaiter
{
FutureAwaiter(TFuture<tRet>&& in_future)
: m_future(MoveTemp(in_future))
{
}
~FutureAwaiter()
{
}
FutureAwaiter(FutureAwaiter&& in_futureAwaiter) noexcept
{
m_future = MoveTemp(in_futureAwaiter.m_future);
}
bool await_ready() noexcept
{
bool isReady = m_future.IsReady();
return isReady;
}
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
{
// Set the ready function
auto& promise = in_coroHandle.promise();
// Suspend if future is not ready
bool shouldSuspend = !m_future.IsReady();
if(shouldSuspend)
{
promise.SetReadyFunction([this] { return m_future.IsReady(); });
}
return shouldSuspend;
}
template <typename U = tRet, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
auto await_resume()
{
return m_future.Get();
}
template <typename U = tRet, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
void await_resume()
{
m_future.Get();
}
private:
TFuture<tRet> m_future;
};
//--- Shared Future Awaiter ---//
template <typename tRet, typename promise_type>
struct SharedFutureAwaiter
{
SharedFutureAwaiter(const TSharedFuture<tRet>& in_sharedFuture)
: m_sharedFuture(in_sharedFuture)
{
}
bool await_ready() noexcept
{
bool isReady = m_sharedFuture.IsReady();
return isReady;
}
bool await_suspend(std::coroutine_handle<promise_type> in_coroHandle) noexcept
{
// Set the ready function
auto& promise = in_coroHandle.promise();
// Suspend if future is not ready
bool shouldSuspend = !m_sharedFuture.IsReady();
if(shouldSuspend)
{
promise.SetReadyFunction([this] { return m_sharedFuture.IsReady(); });
}
return shouldSuspend;
}
template <typename U = tRet, typename std::enable_if_t<!std::is_void<U>::value>* = nullptr>
auto await_resume()
{
return m_sharedFuture.Get();
}
template <typename U = tRet, typename std::enable_if_t<std::is_void<U>::value>* = nullptr>
void await_resume()
{
m_sharedFuture.Get(); // Trigger any pending errors
}
private:
TSharedFuture<tRet> m_sharedFuture;
};
//--- TaskPromiseBase ---//
template <typename tRet>
class alignas(16) TaskPromiseBase
{
public:
// Type aliases
using promise_type = TaskPromise<tRet>;
using tTaskInternal = TaskInternal<tRet>;
// Destructor
~TaskPromiseBase()
{
// NOTE: Destructor is non-virtual, because it is always handled + destroyed as its concrete type
m_taskInternal->OnTaskPromiseDestroyed();
}
// Coroutine interface functions
auto initial_suspend() noexcept
{
return std::suspend_always();
}
auto final_suspend() noexcept
{
return std::suspend_always();
}
auto get_return_object()
{
return std::coroutine_handle<promise_type>::from_promise(*static_cast<promise_type*>(this));
}
static TSharedPtr<tTaskInternal> get_return_object_on_allocation_failure()
{
SQUID_THROW(std::bad_alloc(), "Failed to allocate memory for Task");
return {};
}
//----------------------------------------------------------------------------
// HACK: Coroutines in UE5 under MSVC is currently causing a memory underrun
// These allocators are a workaround for the issue (as is alignas(16))
void* operator new(size_t Size) noexcept
{
const size_t WorkaroundAlign = std::alignment_of<TaskPromiseBase>();
Size += WorkaroundAlign;
return (void*)((uint8_t*)FMemory::Malloc(Size, WorkaroundAlign) + WorkaroundAlign);
}
void operator delete(void* Ptr) noexcept
{
const size_t WorkaroundAlign = std::alignment_of<TaskPromiseBase>();
auto OffsetPtr = (uint8_t*)Ptr - WorkaroundAlign;
FMemory::Free(OffsetPtr);
}
//----------------------------------------------------------------------------
#if SQUID_NEEDS_UNHANDLED_EXCEPTION
void unhandled_exception() noexcept
{
#if SQUID_USE_EXCEPTIONS
// Propagate exceptions for handling
m_taskInternal->SetUnhandledException(std::current_exception());
#endif //SQUID_USE_EXCEPTIONS
}
#endif // SQUID_NEEDS_UNHANDLED_EXCEPTION
// Internal Task
void SetInternalTask(tTaskInternal* in_taskInternal)
{
m_taskInternal = in_taskInternal;
}
tTaskInternal* GetInternalTask()
{
return m_taskInternal;
}
const tTaskInternal* GetInternalTask() const
{
return m_taskInternal;
}
// Ready Function
void SetReadyFunction(const tTaskReadyFn& in_taskReadyFn)
{
m_taskInternal->SetReadyFunction(in_taskReadyFn);
}
// Await-Transforms
auto await_transform(Suspend in_awaiter)
{
return in_awaiter;
}
auto await_transform(std::suspend_never in_awaiter)
{
return in_awaiter;
}
#if SQUID_ENABLE_TASK_DEBUG
auto await_transform(SetDebugName in_awaiter)
{
m_taskInternal->SetDebugName(in_awaiter.m_name);
m_taskInternal->SetDebugDataFn(in_awaiter.m_dataFn);
return std::suspend_never();
}
#endif //SQUID_ENABLE_TASK_DEBUG
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
auto await_transform(AddStopTaskAwaiter<tRet, RefType, Resumable> in_awaiter)
{
m_taskInternal->AddStopTask(*in_awaiter.m_taskToStop);
return std::suspend_never();
}
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
auto await_transform(RemoveStopTaskAwaiter<tRet, RefType, Resumable> in_awaiter)
{
m_taskInternal->RemoveStopTask(*in_awaiter.m_taskToStop);
return std::suspend_never();
}
auto await_transform(GetStopContext in_awaiter)
{
struct GetStopContextAwaiter : public std::suspend_never
{
GetStopContextAwaiter(StopContext in_stopCtx)
: stopCtx(in_stopCtx)
{
}
auto await_resume() noexcept
{
return stopCtx;
}
StopContext stopCtx;
};
GetStopContextAwaiter stopCtxAwaiter{ m_taskInternal->GetStopContext() };
return stopCtxAwaiter;
}
auto await_transform(const tTaskReadyFn& in_taskReadyFn)
{
// Check if we are already ready, and suspend if we are not
bool isReady = in_taskReadyFn();
if(!isReady)
{
m_taskInternal->SetReadyFunction(in_taskReadyFn);
}
return SuspendIf(!isReady); // Suspend if the function isn't already ready
}
template <typename tFutureRet>
auto await_transform(TFuture<tFutureRet>&& in_future)
{
return FutureAwaiter<tFutureRet, promise_type>(MoveTemp(in_future));
}
template <typename tFutureRet>
auto await_transform(const TSharedFuture<tFutureRet>& in_sharedFuture)
{
return SharedFutureAwaiter<tFutureRet, promise_type>(in_sharedFuture);
}
// Task Await-Transforms
template <typename tTaskRet, eTaskRef RefType, eTaskResumable Resumable,
typename std::enable_if_t<Resumable == eTaskResumable::Yes>* = nullptr>
auto await_transform(Task<tTaskRet, RefType, Resumable>&& in_task) // Move version
{
return TaskAwaiter<tTaskRet, RefType, Resumable, promise_type>(MoveTemp(in_task));
}
template <typename tTaskRet, eTaskRef RefType, eTaskResumable Resumable,
typename std::enable_if_t<Resumable == eTaskResumable::No>* = nullptr>
auto await_transform(Task<tTaskRet, RefType, Resumable> in_task) // Copy version (Non-Resumable)
{
return TaskAwaiter<tTaskRet, RefType, Resumable, promise_type>(MoveTemp(in_task));
}
template <typename tTaskRet, eTaskRef RefType, eTaskResumable Resumable,
typename std::enable_if_t<Resumable == eTaskResumable::Yes>* = nullptr>
auto await_transform(const Task<tTaskRet, RefType, Resumable>& in_task) // Invalid copy version (Resumable)
{
static_assert(static_false<tTaskRet>::value, "Cannot await a non-copyable (resumable) Task by copy (try co_await MoveTemp(task), co_await WeakTaskHandle(task), or co_await task.WaitUntilDone()");
return TaskAwaiter<tTaskRet, RefType, Resumable, promise_type>(MoveTemp(in_task));
}
protected:
tTaskInternal* m_taskInternal = nullptr;
};
//--- TaskPromise ---//
template <typename tRet>
class TaskPromise : public TaskPromiseBase<tRet>
{
public:
// Return value access
void return_value(const tRet& in_retVal) // Copy return value
{
this->m_taskInternal->SetReturnValue(in_retVal);
}
void return_value(tRet&& in_retVal) // Move return value
{
this->m_taskInternal->SetReturnValue(MoveTemp(in_retVal));
}
};
template <>
class TaskPromise<void> : public TaskPromiseBase<void>
{
public:
void return_void()
{
}
};
//--- TaskInternalBase ---//
class TaskInternalBase
{
public:
TaskInternalBase(std::coroutine_handle<> in_coroHandle)
: m_coroHandle(in_coroHandle)
{
SQUID_RUNTIME_CHECK(m_coroHandle, "Invalid coroutine handle passed into Task");
}
~TaskInternalBase() // NOTE: Destructor is intentionally non-virtual (shared_ptr preserves concrete type via deleter)
{
Kill(); // Used for killing subtasks
}
StopContext GetStopContext() const
{
return { &m_isStopRequested };
}
bool IsStopRequested() const
{
return m_isStopRequested;
}
void RequestStop() // Propagates a request for the task to come to a 'graceful' stop
{
m_isStopRequested = true;
for(auto& stopTask : m_stopTasks)
{
if(auto locked = stopTask.Pin())
{
locked->RequestStop();
}
}
m_stopTasks.SetNum(0);
}
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
void AddStopTask(Task<tRet, RefType, Resumable>& in_taskToStop) // Adds a task to the list of tasks to which we propagate stop requests
{
if(m_isStopRequested)
{
in_taskToStop.RequestStop();
}
else if(in_taskToStop.IsValid())
{
m_stopTasks.Add(in_taskToStop.GetInternalTask());
}
}
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable>
void RemoveStopTask(Task<tRet, RefType, Resumable>& in_taskToStop) // Removes a task to the list of tasks to which we propagate stop requests
{
if(in_taskToStop.IsValid())
{
for(int32_t i = 0; i < m_stopTasks.Num(); ++i)
{
if(m_stopTasks[i].Pin() == in_taskToStop.GetInternalTask())
{
m_stopTasks[i] = m_stopTasks.Last();
m_stopTasks.Pop();
return;
}
}
}
}
eTaskStatus Resume() // Returns whether the task is still running
{
// Make sure this task is not already mid-resume
SQUID_RUNTIME_CHECK(m_internalState != eInternalState::Resuming, "Attempted to resume Task while already resumed");
// Task is destroyed, therefore task is done
if(m_internalState == eInternalState::Destroyed)
{
return eTaskStatus::Done;
}
// Mark task as resuming
m_internalState = eInternalState::Resuming;
// Resume any active sub-task
if(m_subTaskInternal)
{
// Propagate any stop requests to sub-task prior to resuming
if(m_isStopRequested)
{
m_subTaskInternal->m_isStopRequested = true;
}
// Resume the sub-task
if(m_subTaskInternal->Resume() != eTaskStatus::Done)
{
m_internalState = eInternalState::Idle;
return eTaskStatus::Suspended; // Sub-task not done, therefore task is not done
}
// Clear the sub-task
m_subTaskInternal = nullptr;
}
// Resume task, if necessary
if(CanResume())
{
m_taskReadyFn = nullptr; // Clear any ready function we were waiting on
m_coroHandle.resume(); // Resume the underlying std::coroutine_handle
}
// Return to idle state and return current task status
auto taskStatus = m_coroHandle.done() ? eTaskStatus::Done : eTaskStatus::Suspended;
if(taskStatus == eTaskStatus::Done)
{
m_isDone = true; // Mark task done
}
m_internalState = eInternalState::Idle;
return taskStatus;
}
// Sub-task
void SetSubTask(TSharedPtr<TaskInternalBase> in_subTaskInternal)
{
m_subTaskInternal = in_subTaskInternal;
}
#if SQUID_ENABLE_TASK_DEBUG
// Debug task name + stack
FString GetDebugName() const
{
return (!IsDone() && m_debugDataFn) ? (FString(m_debugName) + " [" + m_debugDataFn() + "]") : m_debugName;
}
FString GetDebugStack() const
{
FString result = m_subTaskInternal ? (GetDebugName() + " -> " + m_subTaskInternal->GetDebugStack()) : GetDebugName();
return result;
}
void SetDebugName(const char* in_debugName)
{
if(in_debugName)
{
m_debugName = in_debugName;
}
}
void SetDebugDataFn(TFunction<FString()> in_debugDataFn)
{
m_debugDataFn = in_debugDataFn;
}
#endif //SQUID_ENABLE_TASK_DEBUG
// Exceptions
#if SQUID_USE_EXCEPTIONS
std::exception_ptr GetUnhandledException() const
{
if(m_isExceptionSet)
{
return m_exception;
}
return nullptr;
}
#endif //SQUID_USE_EXCEPTIONS
protected:
#if SQUID_USE_EXCEPTIONS
// Internal implementation of exception-setting (called by TaskInternal<> child classes)
void InternalSetUnhandledException(std::exception_ptr in_exception)
{
// NOTE: This must never be called more than once in the lifetime of an internal task
SQUID_RUNTIME_CHECK(!m_isExceptionSet, "Exception was set for a task after it had already been set");
if(!m_isExceptionSet)
{
m_exception = in_exception;
m_isExceptionSet = true;
}
}
#endif //SQUID_USE_EXCEPTIONS
private:
template <typename tRet> friend class TaskPromiseBase;
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable, typename promise_type> friend struct TaskAwaiterBase;
template <typename tRet, eTaskRef RefType, eTaskResumable Resumable> friend class Task;
// Kill this task
void Kill() // Kill() can safely be called multiple times
{
SQUID_RUNTIME_CHECK(m_internalState != eInternalState::Resuming, "Attempted to kill Task while resumed");
if(m_internalState == eInternalState::Idle)
{
// Mark task done
m_isDone = true;
// First destroy any sub-tasks
if(m_subTaskInternal)
{
m_subTaskInternal->Kill();
}
// Destroy the underlying std::coroutine_handle
m_coroHandle.destroy(); // This should only ever be called directly from this one place
m_coroHandle = nullptr;
m_taskReadyFn = nullptr; // Clear out the ready function
m_internalState = eInternalState::Destroyed;
}
}
// Done + can-resume status
void SetReadyFunction(const tTaskReadyFn& in_taskReadyFn)
{
m_taskReadyFn = in_taskReadyFn;
}
bool CanResume() const
{
if(IsDone())
{
return false;
}
if(m_subTaskInternal)
{
bool canResume = m_subTaskInternal->CanResume();
return canResume;
}
bool isReady = !m_taskReadyFn || m_taskReadyFn();
return isReady;
}
bool IsDone() const
{
return m_isDone;
}
bool m_isDone = false;
// Internal state
enum class eInternalState
{
Idle,
Resuming,
Destroyed,
};
eInternalState m_internalState = eInternalState::Idle;
// Task ready condition (when awaiting a TFunction<bool>)
tTaskReadyFn m_taskReadyFn;
#if SQUID_USE_EXCEPTIONS
// Exceptions
std::exception_ptr m_exception = nullptr;
bool m_isExceptionSet = false;
#endif //SQUID_USE_EXCEPTIONS
// Sub-task
TSharedPtr<TaskInternalBase> m_subTaskInternal;
// Reference-counting (determines underlying std::coroutine_handle lifetime, not lifetime of this internal task)
void AddLogicalRef()
{
++m_refCount;
}
void RemoveLogicalRef()
{
--m_refCount;
if(m_refCount == 0)
{
Kill();
}
}
int32_t m_refCount = 0; // Number of (strong) non-weak tasks referencing the internal task
// C++ std::coroutine_handle
std::coroutine_handle<> m_coroHandle;
// Stop request
bool m_isStopRequested = false;
TArray<TWeakPtr<TaskInternalBase>> m_stopTasks;
#if SQUID_ENABLE_TASK_DEBUG
// Debug Data
const char* m_debugName = "[unnamed task]";
TFunction<FString()> m_debugDataFn;
#endif //SQUID_ENABLE_TASK_DEBUG
};
//--- TaskInternal ---//
template <typename tRet>
class TaskInternal : public TaskInternalBase
{
public:
using promise_type = TaskPromise<tRet>;
TaskInternal(std::coroutine_handle<promise_type> in_handle)
: TaskInternalBase(in_handle)
{
auto& promisePtr = in_handle.promise();
promisePtr.SetInternalTask(this);
}
#if SQUID_USE_EXCEPTIONS
void SetUnhandledException(std::exception_ptr in_exception)
{
m_retValState = eTaskRetValState::Orphaned; // Return value can never be set if there was an unhandled exception
InternalSetUnhandledException(in_exception);
}
#endif //SQUID_USE_EXCEPTIONS
void SetReturnValue(const tRet& in_retVal)
{
tRet retVal = in_retVal;
SetReturnValue(MoveTemp(retVal));
}
void SetReturnValue(tRet&& in_retVal)
{
if(m_retValState == eTaskRetValState::Unset)
{
m_retVal = MoveTemp(in_retVal);
m_retValState = eTaskRetValState::Set;
return;
}
// These conditions should (logically) never be met, but they are included in case future changes violate that constraint
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Set, "Attempted to set a task's return value when it was already set");
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Taken, "Attempted to set a task's return value after it was already taken");
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Orphaned, "Attempted to set a task's return value after it was orphaned");
}
TOptional<tRet> TakeReturnValue()
{
// If the value has been set, mark it as taken and move-return the value
if(m_retValState == eTaskRetValState::Set)
{
m_retValState = eTaskRetValState::Taken;
return MoveTemp(m_retVal);
}
// If the value was not set, return an unset optional (checking that it was neither taken nor orphaned)
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Taken, "Attempted to take a task's return value after it was already successfully taken");
SQUID_RUNTIME_CHECK(m_retValState != eTaskRetValState::Orphaned, "Attempted to take a task's return value that will never be set (task ended prematurely)");
return {};
}
void OnTaskPromiseDestroyed()
{
// Mark the return value as orphaned if it was never set
m_retValState = eTaskRetValState::Orphaned;
}
private:
// Internal state
enum class eTaskRetValState
{
Unset, // Has not yet been set
Set, // Has been set and can be taken
Taken, // Has been taken and can no longer be taken
Orphaned, // Will never be set because the TaskPromise has been destroyed
};
eTaskRetValState m_retValState = eTaskRetValState::Unset; // Initially unset
TOptional<tRet> m_retVal;
};
template <>
class TaskInternal<void> : public TaskInternalBase
{
public:
using promise_type = TaskPromise<void>;
TaskInternal(std::coroutine_handle<promise_type> in_handle)
: TaskInternalBase(in_handle)
{
auto& promisePtr = in_handle.promise();
promisePtr.SetInternalTask(this);
}
#if SQUID_USE_EXCEPTIONS
void SetUnhandledException(std::exception_ptr in_exception)
{
InternalSetUnhandledException(in_exception);
}
#endif //SQUID_USE_EXCEPTIONS
void TakeReturnValue() // This function is an intentional no-op, to simplify certain templated function implementations
{
}
void OnTaskPromiseDestroyed()
{
}
};

View File

@@ -0,0 +1,236 @@
#pragma once
//--- User configuration header ---//
#include "../TasksConfig.h"
// Namespace macros (enabled/disabled via SQUID_ENABLE_NAMESPACE)
#if SQUID_ENABLE_NAMESPACE
#define NAMESPACE_SQUID_BEGIN namespace Squid {
#define NAMESPACE_SQUID_END }
#define NAMESPACE_SQUID Squid
#else
#define NAMESPACE_SQUID_BEGIN
#define NAMESPACE_SQUID_END
#define NAMESPACE_SQUID
namespace Squid {} // Convenience to allow 'using namespace Squid' even when namespace is disabled
#endif
// Exception macros (to support environments with exceptions disabled)
#if SQUID_USE_EXCEPTIONS && (defined(__cpp_exceptions) || defined(__EXCEPTIONS))
#include <stdexcept>
#define SQUID_THROW(exception, errStr) throw exception;
#define SQUID_RUNTIME_ERROR(errStr) throw std::runtime_error(errStr);
#define SQUID_RUNTIME_CHECK(condition, errStr) if(!(condition)) throw std::runtime_error(errStr);
#else
#include <cassert>
#define SQUID_THROW(exception, errStr) assert(false && errStr);
#define SQUID_RUNTIME_ERROR(errStr) assert(false && errStr);
#define SQUID_RUNTIME_CHECK(condition, errStr) assert((condition) && errStr);
#endif //__cpp_exceptions
// Time Interface
NAMESPACE_SQUID_BEGIN
#if SQUID_ENABLE_DOUBLE_PRECISION_TIME
using tTaskTime = double;
#else
using tTaskTime = float; // Defines time units for use with the Task system
#endif //SQUID_ENABLE_DOUBLE_PRECISION_TIME
NAMESPACE_SQUID_END
// Coroutine de-optimization macros [DEPRECATED]
#ifdef _MSC_VER
#if _MSC_VER >= 1920
// Newer versions of Visual Studio (>= VS2019) compile coroutines correctly
#define COROUTINE_OPTIMIZE_OFF
#define COROUTINE_OPTIMIZE_ON
#else
// Older versions of Visual Studio had code generation bugs when optimizing coroutines (they would compile, but have incorrect runtime results)
#define COROUTINE_OPTIMIZE_OFF __pragma(optimize("", off))
#define COROUTINE_OPTIMIZE_ON __pragma(optimize("", on))
#endif // _MSC_VER >= 1920
#else
// The Clang compiler has sometimes crashed when optimizing/inlining certain coroutines, so this macro can be used to disable inlining
#define COROUTINE_OPTIMIZE_OFF _Pragma("clang optimize off")
#define COROUTINE_OPTIMIZE_ON _Pragma("clang optimize on")
#endif
// False type for use in static_assert() [static_assert(false, ...) -> static_assert(static_false<T>, ...)]
template<typename T>
struct static_false : std::false_type
{
};
// Determine C++ Language Version
#if defined(_MSVC_LANG)
#define CPP_LANGUAGE_VERSION _MSVC_LANG
#elif defined(__cplusplus)
#define CPP_LANGUAGE_VERSION __cplusplus
#else
#define CPP_LANGUAGE_VERSION 0L
#endif
#if CPP_LANGUAGE_VERSION > 201703L // C++20 or higher
#define HAS_CXX17 1
#define HAS_CXX20 1
#elif CPP_LANGUAGE_VERSION > 201402L // C++17 or higher
#define HAS_CXX17 1
#define HAS_CXX20 0
#elif CPP_LANGUAGE_VERSION > 201103L // C++14 or higher
#define HAS_CXX17 0
#define HAS_CXX20 0
#else // C++11 or lower
#error Squid::Tasks requires C++14 or higher
#define HAS_CXX17 0
#define HAS_CXX20 0
#endif
#undef CPP_LANGUAGE_VERSION
// C++20 Compatibility (std::coroutine)
#if HAS_CXX20 || (defined(_MSVC_LANG) && defined(__cpp_lib_coroutine)) // Standard coroutines
#include <coroutine>
#define SQUID_EXPERIMENTAL_COROUTINES 0
#else // Experimental coroutines
#if defined(__clang__) && defined(_STL_COMPILER_PREPROCESSOR)
// HACK: Some distributions of clang don't have a <experimental/coroutine> header. We only need a few symbols, so just define them ourselves
namespace std {
namespace experimental {
inline namespace coroutines_v1 {
template <typename R, typename...> struct coroutine_traits {
using promise_type = typename R::promise_type;
};
template <typename Promise = void> struct coroutine_handle;
template <> struct coroutine_handle<void> {
static coroutine_handle from_address(void* addr) noexcept {
coroutine_handle me;
me.ptr = addr;
return me;
}
void operator()() { resume(); }
void* address() const { return ptr; }
void resume() const { __builtin_coro_resume(ptr); }
void destroy() const { __builtin_coro_destroy(ptr); }
bool done() const { return __builtin_coro_done(ptr); }
coroutine_handle& operator=(decltype(nullptr)) {
ptr = nullptr;
return *this;
}
coroutine_handle(decltype(nullptr)) : ptr(nullptr) {}
coroutine_handle() : ptr(nullptr) {}
// void reset() { ptr = nullptr; } // add to P0057?
explicit operator bool() const { return ptr; }
protected:
void* ptr;
};
template <typename Promise> struct coroutine_handle : coroutine_handle<> {
using coroutine_handle<>::operator=;
static coroutine_handle from_address(void* addr) noexcept {
coroutine_handle me;
me.ptr = addr;
return me;
}
Promise& promise() const {
return *reinterpret_cast<Promise*>(
__builtin_coro_promise(ptr, alignof(Promise), false));
}
static coroutine_handle from_promise(Promise& promise) {
coroutine_handle p;
p.ptr = __builtin_coro_promise(&promise, alignof(Promise), true);
return p;
}
};
template <typename _PromiseT>
bool operator==(coroutine_handle<_PromiseT> const& _Left,
coroutine_handle<_PromiseT> const& _Right) noexcept
{
return _Left.address() == _Right.address();
}
template <typename _PromiseT>
bool operator!=(coroutine_handle<_PromiseT> const& _Left,
coroutine_handle<_PromiseT> const& _Right) noexcept
{
return !(_Left == _Right);
}
struct suspend_always {
bool await_ready() noexcept { return false; }
void await_suspend(coroutine_handle<>) noexcept {}
void await_resume() noexcept {}
};
struct suspend_never {
bool await_ready() noexcept { return true; }
void await_suspend(coroutine_handle<>) noexcept {}
void await_resume() noexcept {}
};
}
}
}
#else
#include <experimental/coroutine>
#endif
namespace std // Alias experimental coroutine symbols into std namespace
{
template <class _Promise = void>
using coroutine_handle = experimental::coroutine_handle<_Promise>;
using suspend_never = experimental::suspend_never;
using suspend_always = experimental::suspend_always;
};
#define SQUID_EXPERIMENTAL_COROUTINES 1
#endif
// Determine whether our tasks need the member function "unhandled_exception()" defined or not
#if defined(_MSC_VER)
// MSVC's rules for exceptions differ between standard + experimental coroutines
#if SQUID_EXPERIMENTAL_COROUTINES
// If exceptions are enabled, we must define unhandled_exception()
#if defined(__cpp_exceptions) && __cpp_exceptions == 199711
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
#else
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 0
#endif
#else
// If we're using VS16.11 or newer -- or older than 16.10, we have one set of rules for standard coroutines
#if _MSC_FULL_VER >= 192930133L || _MSC_VER < 1429L
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
#else
#if defined(__cpp_exceptions) && __cpp_exceptions == 199711
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
#else
// 16.10 has a bug with their standard coroutine implementation that creates a set of contradicting requirements
// https://developercommunity.visualstudio.com/t/coroutine-uses-promise_type::unhandled_e/1374530
#error Visual Studio 16.10 has a compiler bug that prevents all coroutines from compiling when exceptions are disabled and using standard C++20 coroutines or /await:strict. Please either upgrade your version of Visual Studio, or use the experimental /await flag, or enable exceptions.
#endif
#endif
#endif
#else
// Clang always requires unhandled_exception() to be defined
#define SQUID_NEEDS_UNHANDLED_EXCEPTION 1
#endif
// C++17 Compatibility ([[nodiscard]])
#if !defined(SQUID_NODISCARD) && defined(__has_cpp_attribute)
#if __has_cpp_attribute(nodiscard)
#define SQUID_NODISCARD [[nodiscard]]
#endif
#endif
#ifndef SQUID_NODISCARD
#define SQUID_NODISCARD
#endif
#undef HAS_CXX17
#undef HAS_CXX20
// Include UE core headers
#include "CoreMinimal.h"
#include "Engine/World.h"
#include "Engine/Engine.h"
#include "Async/Future.h"