117 lines
3.1 KiB
C++

#include "./ip.hpp"

#include <utility>
#include <vector>

#include "../detect.hpp"
#include "./detail/net_common.hpp"
namespace mijin
{
namespace
{
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
/// Per-query state for a GetAddrInfoExA lookup on Windows.
///
/// getAddrComplete() recovers this object from its `overlapped` member via
/// CONTAINING_RECORD, so the object must stay alive and at the same address
/// until `result` has been set.
struct WSAQueryContext
{
    // WSA state
    // Zero-initialised so a freshly constructed context is a valid, clean overlapped block.
    OVERLAPPED overlapped = {};
    // Address list written by GetAddrInfoExA and released in getAddrComplete().
    // Must start as nullptr: getAddrComplete() walks/frees it unconditionally when non-null.
    PADDRINFOEXA results = nullptr;
    // Not currently requested from GetAddrInfoExA (lpNameHandle is passed as nullptr).
    HANDLE cancelHandle = nullptr;
    // Our state
    // Outcome published by getAddrComplete(); osResolveDone() treats a non-empty value as completion.
    StreamResult<std::vector<ip_address_t>> result;
};
using os_resolve_handle_t = WSAQueryContext;
/// Completion handler for the GetAddrInfoExA query started by osBeginResolve().
/// osBeginResolve() also calls it directly when the query does not return WSA_IO_PENDING.
///
/// On success, converts the returned address list into ip_address_t values.
/// Only AF_INET and AF_INET6 entries are kept; other families are skipped.
/// It then releases the list and finally publishes either the addresses or the
/// translated error in queryContext.result, which is what osResolveDone() polls.
///
/// @param error       ERROR_SUCCESS on success, otherwise the error code of the query.
/// @param bytes       Unused.
/// @param overlapped  Must point at WSAQueryContext::overlapped of the owning context.
void WINAPI getAddrComplete(DWORD error, DWORD bytes, LPOVERLAPPED overlapped) noexcept
{
    (void) bytes;
    WSAQueryContext& queryContext = *CONTAINING_RECORD(overlapped, WSAQueryContext, overlapped);

    std::vector<ip_address_t> resultAddresses;
    if (error == ERROR_SUCCESS)
    {
        for (PADDRINFOEXA entry = queryContext.results; entry != nullptr; entry = entry->ai_next)
        {
            switch (entry->ai_family)
            {
            case AF_INET:
            {
                const sockaddr_in& addr = *reinterpret_cast<const sockaddr_in*>(entry->ai_addr);
                resultAddresses.emplace_back(std::bit_cast<IPv4Address>(addr.sin_addr));
                break;
            }
            case AF_INET6:
            {
                const sockaddr_in6& addr = *reinterpret_cast<const sockaddr_in6*>(entry->ai_addr);
                resultAddresses.emplace_back(std::bit_cast<IPv6Address>(addr.sin6_addr));
                break;
            }
            default:
                break;
            }
        }
    }

    if (queryContext.results != nullptr)
    {
        // The Windows SDK headers may define FreeAddrInfoEx as a macro, which keeps
        // FreeAddrInfoExA from being reachable here; undefine it before the call.
#if defined(FreeAddrInfoEx)
#undef FreeAddrInfoEx
#endif
        FreeAddrInfoExA(queryContext.results);
        // Never leave a dangling pointer behind in case the context is reused.
        queryContext.results = nullptr;
    }

    // Publish last: once `result` is non-empty, osResolveDone() reports completion and
    // the waiting coroutine may finish, so the context must not be touched afterwards.
    if (error != ERROR_SUCCESS)
    {
        queryContext.result = detail::translateWinError(error);
    }
    else
    {
        // Previously the collected addresses were dropped here, so a successful
        // lookup never completed.
        queryContext.result = std::move(resultAddresses);
    }
}
/// Starts resolving `hostname` using GetAddrInfoExA with getAddrComplete() as completion routine.
///
/// @param hostname      Host name to resolve.
/// @param queryContext  Per-query state; must stay alive and in place until osResolveDone() is true.
/// @return StreamError::SUCCESS once the query has been started, or has already completed.
///         The lookup outcome itself is delivered through queryContext.result.
///         If WSA initialisation fails, returns the translated WSA error.
StreamError osBeginResolve(const std::string& hostname, os_resolve_handle_t& queryContext) noexcept
{
    if (!detail::initWSA())
    {
        return detail::translateWSAError();
    }

    // Start from a clean state regardless of how the caller constructed the context.
    // getAddrComplete() walks and frees `results`, so it must never hold a stale or
    // indeterminate pointer (e.g. if the query fails before the API writes it).
    queryContext.overlapped = {};
    queryContext.results = nullptr;

    ADDRINFOEXA hints = {.ai_family = AF_UNSPEC};
    // NOTE(review): Microsoft's documentation describes asynchronous operation
    // (lpOverlapped / lpCompletionRoutine) as supported by GetAddrInfoExW on Windows 8+.
    // It lists these parameters as reserved for the ANSI variant.
    // Confirm that this call really completes asynchronously; otherwise consider GetAddrInfoExW.
    const int error = GetAddrInfoExA(
        /* pName = */ hostname.c_str(),
        /* pServiceName = */ nullptr,
        /* dwNameSpace = */ NS_DNS,
        /* lpNspId = */ nullptr,
        /* hints = */ &hints,
        /* ppResult = */ &queryContext.results,
        /* timeout = */ nullptr,
        /* lpOverlapped = */ &queryContext.overlapped,
        /* lpCompletionRoutine = */ &getAddrComplete,
        /* lpNameHandle = */ nullptr
    );
    if (error != WSA_IO_PENDING)
    {
        // Any return other than WSA_IO_PENDING is treated as synchronous completion
        // (success or failure). Run the completion routine here so the result is published.
        getAddrComplete(static_cast<DWORD>(error), 0, &queryContext.overlapped);
    }
    return StreamError::SUCCESS;
}
/// Reports whether getAddrComplete() has stored an outcome (addresses or error) for this query.
bool osResolveDone(os_resolve_handle_t& queryContext) noexcept
{
    const bool stillPending = queryContext.result.isEmpty();
    return !stillPending;
}
/// Returns a copy of the outcome stored by getAddrComplete().
/// c_resolveHostname only calls this after osResolveDone() has returned true.
StreamResult<std::vector<ip_address_t>> osResolveResult(os_resolve_handle_t& queryContext) noexcept
{
    const StreamResult<std::vector<ip_address_t>>& outcome = queryContext.result;
    return outcome;
}
#endif // MIJIN_TARGET_OS
}
/// Resolves `hostname` to a list of IP addresses without blocking the calling task.
///
/// Starts an OS-level lookup, then suspends via c_suspend() until the backend reports completion.
///
/// @param hostname  Name to resolve; taken by value so the coroutine frame owns its copy.
/// @return The resolved addresses. On failure, returns the StreamError from starting
///         or performing the lookup instead.
Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(std::string hostname) noexcept
{
    // Value-initialise the handle. For the Windows backend (WSAQueryContext) this zeroes the
    // OVERLAPPED block and the result-list pointer, which plain default-initialisation would
    // leave indeterminate.
    os_resolve_handle_t resolveHandle{};
    if (const StreamError error = osBeginResolve(hostname, resolveHandle); error != StreamError::SUCCESS)
    {
        co_return error;
    }
    // NOTE(review): resolveHandle lives in the coroutine frame. If the task is destroyed while
    // the lookup is still pending, the completion routine would write into freed memory.
    // Confirm that tasks are always driven to completion, or add cancellation.
    while (!osResolveDone(resolveHandle))
    {
        co_await c_suspend();
    }
    co_return osResolveResult(resolveHandle);
}
}