292 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			292 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
 | 
						|
#include "./ip.hpp"
 | 
						|
 | 
						|
#include <format>
 | 
						|
 | 
						|
#include "../detect.hpp"
 | 
						|
#include "../util/string.hpp"
 | 
						|
#include "./detail/net_common.hpp"
 | 
						|
 | 
						|
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
 | 
						|
#if !defined(_GNU_SOURCE)
 | 
						|
#define _GNU_SOURCE
 | 
						|
#endif
 | 
						|
#include <netdb.h>
 | 
						|
#endif
 | 
						|
 | 
						|
namespace mijin
 | 
						|
{
 | 
						|
namespace
 | 
						|
{
 | 
						|
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
 | 
						|
struct AddrInfoContext
 | 
						|
{
 | 
						|
    gaicb item;
 | 
						|
    gaicb* list = &item;
 | 
						|
};
 | 
						|
using os_resolve_handle_t = AddrInfoContext;
 | 
						|
 | 
						|
StreamError translateGAIError(int error)
 | 
						|
{
 | 
						|
    (void) error; // TODO
 | 
						|
    return StreamError::UNKNOWN_ERROR;
 | 
						|
}
 | 
						|
 | 
						|
/**
 * Queues an asynchronous lookup of `hostname` with getaddrinfo_a(GAI_NOWAIT).
 *
 * The request stores hostname.c_str(), so `hostname` must stay alive and unmodified until
 * osResolveDone() reports completion.
 *
 * @param hostname name to resolve
 * @param handle   request state; must not be moved while the request is in flight
 * @return SUCCESS if the request was queued, otherwise the translated getaddrinfo_a() error
 */
StreamError osBeginResolve(const std::string& hostname, os_resolve_handle_t& handle) noexcept
{
    // No service and no hints (ar_request == nullptr): results for every socket type come back,
    // and osResolveResult() filters them.
    handle.item = {.ar_name = hostname.c_str()};
    // getaddrinfo_a() reads the request through this one-element list; make sure it targets our item.
    handle.list = &handle.item;

    if (const int result = getaddrinfo_a(GAI_NOWAIT, &handle.list, 1, nullptr); result != 0)
    {
        return translateGAIError(result);
    }
    return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
/// Returns true once the getaddrinfo_a() request has finished, whether it succeeded or failed.
bool osResolveDone(os_resolve_handle_t& handle) noexcept
{
    const int status = gai_error(&handle.item);
    return status != EAI_INPROGRESS;
}
 | 
						|
 | 
						|
/**
 * Collects the addresses produced by a finished getaddrinfo_a() request started by osBeginResolve().
 * Call only after osResolveDone() returned true.
 *
 * Only entries with ai_protocol == IPPROTO_TCP are kept. IPv6 hextets are converted from network to
 * host byte order. The addrinfo list is freed on every path, and the pointer is cleared so it can
 * never be freed twice.
 *
 * @return the resolved addresses, or an error if the lookup failed or returned no list
 */
StreamResult<std::vector<ip_address_t>> osResolveResult(os_resolve_handle_t& handle) noexcept
{
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
    if (const int error = gai_error(&handle.item); error != 0)
    {
        if (handle.item.ar_result != nullptr)
        {
            freeaddrinfo(handle.item.ar_result);
            handle.item.ar_result = nullptr;
        }
        return translateGAIError(error);
    }
    if (handle.item.ar_result == nullptr)
    {
        return StreamError::UNKNOWN_ERROR;
    }

    std::vector<ip_address_t> resultAddresses;
    for (const addrinfo* result = handle.item.ar_result; result != nullptr; result = result->ai_next)
    {
        // The request is sent without hints, so the same address may appear once per socket type.
        // Keeping only the TCP entries reports each address once.
        if (result->ai_protocol != IPPROTO_TCP)
        {
            continue;
        }
        switch (result->ai_family)
        {
            case AF_INET:
            {
                const sockaddr_in& addr = *reinterpret_cast<const sockaddr_in*>(result->ai_addr);
                resultAddresses.emplace_back(std::bit_cast<IPv4Address>(addr.sin_addr));
                break;
            }
            case AF_INET6:
            {
                const sockaddr_in6& addr = *reinterpret_cast<const sockaddr_in6*>(result->ai_addr);
                IPv6Address addr6 = std::bit_cast<IPv6Address>(addr.sin6_addr);
                for (std::uint16_t& hextet : addr6.hextets)
                {
                    hextet = ntohs(hextet);
                }
                resultAddresses.emplace_back(addr6);
                break;
            }
            default: break;
        }
    }
    freeaddrinfo(handle.item.ar_result);
    handle.item.ar_result = nullptr;
    return resultAddresses;
}
 | 
						|
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | 
						|
struct WSAQueryContext
 | 
						|
{
 | 
						|
    // WSA stuff
 | 
						|
    OVERLAPPED overlapped = {};
 | 
						|
    PADDRINFOEX results;
 | 
						|
    HANDLE cancelHandle = nullptr;
 | 
						|
 | 
						|
    // my stuff
 | 
						|
    StreamResult<std::vector<ip_address_t>> result;
 | 
						|
};
 | 
						|
using os_resolve_handle_t = WSAQueryContext;
 | 
						|
 | 
						|
/**
 * Completion routine for the overlapped GetAddrInfoEx() call issued by osBeginResolve(). osBeginResolve()
 * also calls it directly when GetAddrInfoEx() does not return WSA_IO_PENDING.
 *
 * On failure, stores the translated error. On success, converts the ADDRINFOEX list into ip_address_t
 * values, with IPv6 hextets converted to host byte order. The list is freed in both cases.
 * `queryContext.result` is assigned last, because osResolveDone() treats a non-empty result as "finished".
 * NOTE(review): if Windows runs this routine on another thread, the polling in osResolveDone() is
 * unsynchronized — confirm.
 */
void WINAPI getAddrComplete(DWORD error, DWORD bytes, LPOVERLAPPED overlapped) noexcept
{
    (void) bytes;

    WSAQueryContext& queryContext = *CONTAINING_RECORD(overlapped, WSAQueryContext, overlapped);
    if (error != ERROR_SUCCESS)
    {
        // Stop here so the error is not overwritten by the success path below.
        if (queryContext.results != nullptr)
        {
            FreeAddrInfoEx(queryContext.results);
            queryContext.results = nullptr;
        }
        queryContext.result = detail::translateWinError(error);
        return;
    }

    std::vector<ip_address_t> resultAddresses;
    for (PADDRINFOEX result = queryContext.results; result != nullptr; result = result->ai_next)
    {
        switch (result->ai_family)
        {
        case AF_INET:
            {
                const sockaddr_in& addr = *reinterpret_cast<const sockaddr_in*>(result->ai_addr);
                resultAddresses.emplace_back(std::bit_cast<IPv4Address>(addr.sin_addr));
            }
            break;
        case AF_INET6:
            {
                const sockaddr_in6& addr = *reinterpret_cast<const sockaddr_in6*>(result->ai_addr);
                IPv6Address addr6 = std::bit_cast<IPv6Address>(addr.sin6_addr);
                for (std::uint16_t& hextet : addr6.hextets)
                {
                    hextet = ntohs(hextet);
                }
                resultAddresses.emplace_back(addr6);
            }
            break;
        default: break;
        }
    }
    if (queryContext.results != nullptr)
    {
        FreeAddrInfoEx(queryContext.results);
        queryContext.results = nullptr;
    }
    queryContext.result = std::move(resultAddresses);
}
 | 
						|
 | 
						|
/**
 * Starts an asynchronous GetAddrInfoEx() lookup of `hostname`; completion is reported through getAddrComplete().
 *
 * @param hostname     UTF-8 encoded host name
 * @param queryContext request state; must stay at the same address until osResolveDone() returns true
 * @return SUCCESS once the request was issued (lookup failures are reported via queryContext.result),
 *         or an error if WSA initialization or the UTF-8 -> UTF-16 conversion failed
 */
StreamError osBeginResolve(const std::string& hostname, os_resolve_handle_t& queryContext) noexcept
{
    if (!detail::initWSA())
    {
        return detail::translateWSAError();
    }

    // GetAddrInfoEx() here takes a wide (UTF-16) string. Widening each char on its own would corrupt
    // every non-ASCII byte, so convert properly from UTF-8.
    std::wstring hostnameW;
    if (!hostname.empty())
    {
        const int sourceLength = static_cast<int>(hostname.size());
        if (sourceLength < 0 || static_cast<std::size_t>(sourceLength) != hostname.size())
        {
            return StreamError::UNKNOWN_ERROR;
        }
        const int wideLength = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, hostname.data(), sourceLength,
                                                   nullptr, 0);
        if (wideLength <= 0)
        {
            return detail::translateWinError(GetLastError());
        }
        hostnameW.resize(static_cast<std::size_t>(wideLength));
        if (MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, hostname.data(), sourceLength,
                                hostnameW.data(), wideLength) != wideLength)
        {
            return detail::translateWinError(GetLastError());
        }
    }

    ADDRINFOEX hints = {.ai_family = AF_UNSPEC};

    // getAddrComplete() walks and frees this list, so it must never be left indeterminate.
    queryContext.results = nullptr;
    const int error = GetAddrInfoEx(
        /* pName = */ hostnameW.c_str(),
        /* pServiceName = */ nullptr,
        /* dwNameSpace = */ NS_DNS,
        /* lpNspId = */ nullptr,
        /* hints = */ &hints,
        /* ppResult = */ &queryContext.results,
        /* timeout = */ nullptr,
        /* lpOverlapped = */ &queryContext.overlapped,
        /* lpCompletionRoutine = */ &getAddrComplete,
        /* lpNameHandle = */ nullptr
    );
    if (error != WSA_IO_PENDING)
    {
        // Not pending: finish the request here so osResolveDone() sees a result.
        getAddrComplete(static_cast<DWORD>(error), 0, &queryContext.overlapped);
    }
    return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
/// Returns true once getAddrComplete() has stored an outcome (addresses or error) in the context.
bool osResolveDone(os_resolve_handle_t& queryContext) noexcept
{
    const bool stillPending = queryContext.result.isEmpty();
    return !stillPending;
}
 | 
						|
 | 
						|
/// Returns a copy of the outcome stored by getAddrComplete(). Call only after osResolveDone() returned true.
StreamResult<std::vector<ip_address_t>> osResolveResult(os_resolve_handle_t& queryContext) noexcept
{
    const StreamResult<std::vector<ip_address_t>>& stored = queryContext.result;
    return stored;
}
 | 
						|
#endif // MIJIN_TARGET_OS
 | 
						|
}
 | 
						|
 | 
						|
/// Formats the address in dotted-decimal notation, e.g. "192.168.0.1".
std::string IPv4Address::toString() const
{
    const auto& o = octets;
    return std::format("{}.{}.{}.{}", o[0], o[1], o[2], o[3]);
}
 | 
						|
 | 
						|
/// Formats the address as eight lowercase hexadecimal groups separated by ':'.
/// No "::" zero compression is applied and leading zeros are omitted (e.g. "2001:db8:0:0:0:0:0:1").
std::string IPv6Address::toString() const
{
    const auto& h = hextets;
    return std::format("{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}:{:x}",
                       h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7]);
}
 | 
						|
 | 
						|
/**
 * Parses a dotted-quad IPv4 address such as "10.0.0.1".
 *
 * The text is split on '.' into at most four parts, and each part is converted with toNumber() into the
 * matching octet. Range checking (e.g. rejecting "256") is whatever toNumber() does for the octet type.
 *
 * @return the parsed address, or NULL_OPTIONAL if there are not exactly four parts or a part fails to convert
 */
Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept
{
    const std::vector<std::string_view> octetTexts = split(stringView, ".", {.limitParts = 4});
    if (octetTexts.size() != 4)
    {
        return NULL_OPTIONAL;
    }

    IPv4Address parsed;
    for (std::size_t index = 0; index < octetTexts.size(); ++index)
    {
        const bool converted = toNumber(octetTexts[index], parsed.octets[index]);
        if (!converted)
        {
            return NULL_OPTIONAL;
        }
    }
    return parsed;
}
 | 
						|
 | 
						|
/**
 * Parses a textual IPv6 address made of hexadecimal hextets separated by ':'. At most one "::" may appear,
 * and it stands for one or more zero hextets (e.g. "2001:db8::1", "::1", "fe80::").
 *
 * Rejected inputs:
 * - ":::" and more than one "::";
 * - a stray leading or trailing single ':' (e.g. ":1::2", "1::2:");
 * - hextets with more than four hex digits;
 * - no "::" and not exactly eight hextets (e.g. "", "1:2:3");
 * - a "::" together with eight or more explicit hextets, since "::" must replace at least one group.
 * Embedded IPv4 notation ("::ffff:1.2.3.4") is not supported.
 *
 * @return the parsed address with hextets in host byte order, or NULL_OPTIONAL if the text is invalid
 */
Optional<IPv6Address> IPv6Address::fromString(std::string_view stringView) noexcept
{
    constexpr std::size_t NUM_HEXTETS = 8;
    constexpr std::size_t MAX_HEXTET_DIGITS = 4;

    // Three or more consecutive colons are never valid.
    if (stringView.contains(":::"))
    {
        return NULL_OPTIONAL;
    }

    // Split around the optional "::" marker, which stands for a run of zero hextets.
    const std::size_t gapPos = stringView.find("::");
    const bool hasGap = (gapPos != std::string_view::npos);
    const std::string_view leftView = hasGap ? stringView.substr(0, gapPos) : stringView;
    const std::string_view rightView = hasGap ? stringView.substr(gapPos + 2) : std::string_view{};

    // Only one "::" is allowed.
    if (rightView.contains("::"))
    {
        return NULL_OPTIONAL;
    }

    // A single ':' may only separate two hextets, never start or end either side.
    if (leftView.starts_with(':') || leftView.ends_with(':') || rightView.starts_with(':') || rightView.ends_with(':'))
    {
        return NULL_OPTIONAL;
    }

    std::vector<std::string_view> partsLeft = split(leftView, ":");
    std::vector<std::string_view> partsRight = split(rightView, ":");

    // Depending on split()'s defaults, an empty side may come back as one empty piece; drop such pieces
    // so only real hextets are counted.
    std::erase_if(partsLeft, std::mem_fn(&std::string_view::empty));
    std::erase_if(partsRight, std::mem_fn(&std::string_view::empty));

    const std::size_t numExplicit = partsLeft.size() + partsRight.size();
    if (hasGap ? (numExplicit >= NUM_HEXTETS) : (numExplicit != NUM_HEXTETS))
    {
        return NULL_OPTIONAL;
    }

    IPv6Address address = {};
    std::size_t hextet = 0;
    for (const std::string_view part : partsLeft)
    {
        if (part.size() > MAX_HEXTET_DIGITS || !toNumber(part, address.hextets[hextet], /* base = */ 16))
        {
            return NULL_OPTIONAL;
        }
        ++hextet;
    }
    // Hextets covered by "::" are zero (none remain when there is no gap, since then all eight were given).
    for (; hextet < NUM_HEXTETS - partsRight.size(); ++hextet)
    {
        address.hextets[hextet] = 0;
    }
    for (const std::string_view part : partsRight)
    {
        if (part.size() > MAX_HEXTET_DIGITS || !toNumber(part, address.hextets[hextet], /* base = */ 16))
        {
            return NULL_OPTIONAL;
        }
        ++hextet;
    }
    return address;
}
 | 
						|
 | 
						|
/**
 * Resolves `hostname` to its IPv4/IPv6 addresses. The platform lookup is started once, then the coroutine
 * yields via c_suspend() until osResolveDone() reports completion.
 *
 * `hostname` is taken by value on purpose. It lives in the coroutine frame, and the platform request may
 * keep referring to its characters until the lookup completes (the Linux backend stores hostname.c_str()).
 * NOTE(review): nothing here cancels the OS request if the coroutine is destroyed before completion.
 *
 * @return the resolved addresses, or the error reported while starting or finishing the lookup
 */
Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(std::string hostname) noexcept
{
    os_resolve_handle_t resolveHandle;

    const StreamError startError = osBeginResolve(hostname, resolveHandle);
    if (startError != StreamError::SUCCESS)
    {
        co_return startError;
    }

    for (;;)
    {
        if (osResolveDone(resolveHandle))
        {
            break;
        }
        co_await c_suspend();
    }

    co_return osResolveResult(resolveHandle);
}
 | 
						|
}
 |