Split IP stuff into separate source and WIP implementation of name resolution.

This commit is contained in:
Patrick 2024-08-20 12:07:25 +02:00
parent 05f0e1474a
commit 8a611bf4f3
6 changed files with 264 additions and 82 deletions

View File

@ -0,0 +1,26 @@
#pragma once
#include "../../detect.hpp"
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
#include <fcntl.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
#define _WINSOCK_DEPRECATED_NO_WARNINGS
#include <WinSock2.h>
#include <ws2tcpip.h>
#include "../util/winundef.hpp"
#endif // MIJIN_TARGET_OS
namespace mijin::detail
{
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
bool initWSA() noexcept;
StreamError translateWSAError() noexcept;
StreamError translateWinError(DWORD error) noexcept;
StreamError translateWinError() noexcept;
#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
}// namespace mijin::detail

116
source/mijin/net/ip.cpp Normal file
View File

@ -0,0 +1,116 @@
#include "./ip.hpp"
#include "../detect.hpp"
#include "./detail/net_common.hpp"
namespace mijin
{
namespace
{
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
struct WSAQueryContext
{
// WSA stuff
OVERLAPPED overlapped;
PADDRINFOEXA results;
HANDLE cancelHandle = nullptr;
// my stuff
StreamResult<std::vector<ip_address_t>> result;
};
using os_resolve_handle_t = WSAQueryContext;
void WINAPI getAddrComplete(DWORD error, DWORD bytes, LPOVERLAPPED overlapped) noexcept
{
(void) bytes;
WSAQueryContext& queryContext = *CONTAINING_RECORD(overlapped, WSAQueryContext, overlapped);
if (error != ERROR_SUCCESS)
{
queryContext.result = detail::translateWinError(error);
}
std::vector<ip_address_t> resultAddresses;
for (PADDRINFOEXA result = queryContext.results; result != nullptr; result = result->ai_next)
{
switch (result->ai_family)
{
case AF_INET:
{
sockaddr_in& addr = *reinterpret_cast<sockaddr_in*>(result->ai_addr);
resultAddresses.emplace_back(std::bit_cast<IPv4Address>(addr.sin_addr));
}
break;
case AF_INET6:
{
sockaddr_in6& addr = *reinterpret_cast<sockaddr_in6*>(result->ai_addr);
resultAddresses.emplace_back(std::bit_cast<IPv6Address>(addr.sin6_addr));
}
break;
default: break;
}
}
if (queryContext.results != nullptr)
{
// WTF is wrong with people at MS?
// you can't access FreeAddrInfoExA otherwise...
#if defined(FreeAddrInfoEx)
#undef FreeAddrInfoEx
#endif
FreeAddrInfoExA(queryContext.results);
}
}
StreamError osBeginResolve(const std::string& hostname, os_resolve_handle_t& queryContext) noexcept
{
if (!detail::initWSA())
{
return detail::translateWSAError();
}
ADDRINFOEXA hints = {.ai_family = AF_UNSPEC};
const int error = GetAddrInfoExA(
/* pName = */ hostname.c_str(),
/* pServiceName = */ nullptr,
/* dwNameSpace = */ NS_DNS,
/* lpNspId = */ nullptr,
/* hints = */ &hints,
/* ppResult = */ &queryContext.results,
/* timeout = */ nullptr,
/* lpOverlapped = */ &queryContext.overlapped,
/* lpCompletionRoutine = */ &getAddrComplete,
/* lpNameHandle = */ nullptr
);
if (error != WSA_IO_PENDING)
{
getAddrComplete(error, 0, &queryContext.overlapped);
}
return StreamError::SUCCESS;
}
bool osResolveDone(os_resolve_handle_t& queryContext) noexcept
{
return !queryContext.result.isEmpty();
}
StreamResult<std::vector<ip_address_t>> osResolveResult(os_resolve_handle_t& queryContext) noexcept
{
return queryContext.result;
}
#endif // MIJIN_TARGET_OS
}
Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(std::string hostname) noexcept
{
os_resolve_handle_t resolveHandle;
if (StreamError error = osBeginResolve(hostname, resolveHandle); error != StreamError::SUCCESS)
{
co_return error;
}
while (!osResolveDone(resolveHandle))
{
co_await c_suspend();
}
co_return osResolveResult(resolveHandle);
}
}

53
source/mijin/net/ip.hpp Normal file
View File

@ -0,0 +1,53 @@
#pragma once
#if !defined(MIJIN_NET_IP_HPP_INCLUDED)
#define MIJIN_NET_IP_HPP_INCLUDED 1
#include <array>
#include <variant>
#include "../async/coroutine.hpp"
#include "../container/optional.hpp"
#include "../io/stream.hpp" // TODO: rename Stream{Error,Result} to IO{*}
namespace mijin
{
struct IPv4Address
{
std::array<std::uint8_t, 4> octets;
auto operator<=>(const IPv4Address&) const noexcept = default;
[[nodiscard]]
static Optional<IPv4Address> fromString(std::string_view stringView) noexcept;
};
struct IPv6Address
{
std::array<std::uint16_t, 8> hextets;
auto operator<=>(const IPv6Address&) const noexcept = default;
[[nodiscard]]
static Optional<IPv6Address> fromString(std::string_view stringView) noexcept;
};
using ip_address_t = std::variant<IPv4Address, IPv6Address>;
[[nodiscard]]
inline Optional<ip_address_t> ipAddressFromString(std::string_view stringView) noexcept
{
if (Optional<IPv4Address> ipv4Address = IPv4Address::fromString(stringView); !ipv4Address.empty())
{
return ip_address_t(*ipv4Address);
}
if (Optional<IPv6Address> ipv6Address = IPv6Address::fromString(stringView); !ipv6Address.empty())
{
return ip_address_t(*ipv6Address);
}
return NULL_OPTIONAL;
}
[[nodiscard]]
Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(std::string hostname) noexcept;
}
#endif // !defined(MIJIN_NET_IP_HPP_INCLUDED)

View File

@ -1,23 +1,11 @@
#include "./socket.hpp"
#include <iostream>
#include "./detail/net_common.hpp"
#include "../detect.hpp"
#include "../util/string.hpp"
#include "../util/variant.hpp"
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
#include <fcntl.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
#include <WinSock2.h>
#include <ws2tcpip.h>
#include "../util/winundef.hpp"
#endif
namespace mijin
{
namespace
@ -103,7 +91,19 @@ bool osSetSocketNonBlocking(int socket, bool blocking)
using in_addr_t = ULONG;
const char SOCKOPT_ONE = 1;
thread_local int numSocketsOpen = 0;
thread_local bool gWsaInited = false;
class WSAGuard
{
public:
~WSAGuard() noexcept
{
if (gWsaInited)
{
WSACleanup();
}
}
} thread_local [[maybe_unused]] gWsaGuard;
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
{
@ -117,35 +117,16 @@ long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
{
if (numSocketsOpen == 0)
if (!detail::initWSA())
{
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
{
return INVALID_SOCKET_HANDLE;
}
return INVALID_SOCKET_HANDLE;
}
SOCKET result = socket(addressFamily, type, protocol);
if (result != INVALID_SOCKET_HANDLE)
{
++numSocketsOpen;
}
return result;
return socket(addressFamily, type, protocol);
}
int osCloseSocket(SOCKET socket)
{
const int result = closesocket(socket);
if (result == 0)
{
MIJIN_ASSERT(numSocketsOpen > 0, "Inbalanced calls to osOpenSocket and osCloseSocket!");
--numSocketsOpen;
if (numSocketsOpen == 0)
{
WSACleanup();
}
}
return result;
return closesocket(socket);
}
bool osIsSocketValid(SOCKET socket)
@ -161,6 +142,45 @@ bool osSetSocketNonBlocking(SOCKET socket, bool blocking)
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
}
namespace detail
{
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
bool initWSA() noexcept
{
if (gWsaInited)
{
return true;
}
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
{
return false;
}
gWsaInited = true;
return true;
}
StreamError translateWSAError() noexcept
{
// TODO
return StreamError::UNKNOWN_ERROR;
}
StreamError translateWinError(DWORD error) noexcept
{
// TODO
(void) error;
return StreamError::UNKNOWN_ERROR;
}
StreamError translateWinError() noexcept
{
return translateWinError(GetLastError());
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
}// namespace impl
Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept
{
std::vector<std::string_view> parts = split(stringView, ".", {.limitParts = 4});
@ -280,11 +300,11 @@ mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, co
}
co_return StreamError::SUCCESS;
}
else if (bytesRead == 0)
if (bytesRead == 0)
{
co_return StreamError::CONNECTION_CLOSED;
}
else if (errno != EAGAIN)
if (errno != EAGAIN)
{
co_return translateErrno();
}
@ -292,7 +312,7 @@ mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, co
}
}
mijin::Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
{
MIJIN_ASSERT(isOpen(), "Socket is not open.");
@ -318,7 +338,7 @@ mijin::Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buf
{
co_return translateErrno();
}
co_await mijin::c_suspend();
co_await c_suspend();
}
}
@ -338,7 +358,7 @@ std::size_t TCPStream::tell()
return 0;
}
StreamError TCPStream::seek(std::intptr_t /* pos */, mijin::SeekMode /* seekMode */)
StreamError TCPStream::seek(std::intptr_t /* pos */, SeekMode /* seekMode */)
{
return StreamError::NOT_SUPPORTED;
}
@ -396,7 +416,7 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
{
sockaddr_in6 connectAddress =
{
.sin6_family = AF_INET,
.sin6_family = AF_INET6,
.sin6_port = htons(port),
.sin6_addr = std::bit_cast<in6_addr>(address6)
};
@ -458,7 +478,7 @@ StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noe
{
sockaddr_in6 bindAddress =
{
.sin6_family = AF_INET,
.sin6_family = AF_INET6,
.sin6_port = htons(port),
.sin6_addr = std::bit_cast<in6_addr>(address6)
};

View File

@ -4,8 +4,7 @@
#if !defined(MIJIN_NET_SOCKET_HPP_INCLUDED)
#define MIJIN_NET_SOCKET_HPP_INCLUDED 1
#include <array>
#include <variant>
#include "./ip.hpp"
#include "../detect.hpp"
#include "../async/coroutine.hpp"
#include "../container/optional.hpp"
@ -28,41 +27,6 @@ using socket_handle_t = int;
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = -1;
#endif
struct IPv4Address
{
std::array<std::uint8_t, 4> octets;
auto operator<=>(const IPv4Address&) const noexcept = default;
[[nodiscard]]
static Optional<IPv4Address> fromString(std::string_view stringView) noexcept;
};
struct IPv6Address
{
std::array<std::uint16_t, 8> hextets;
auto operator<=>(const IPv6Address&) const noexcept = default;
[[nodiscard]]
static Optional<IPv6Address> fromString(std::string_view stringView) noexcept;
};
using ip_address_t = std::variant<IPv4Address, IPv6Address>;
[[nodiscard]]
inline Optional<ip_address_t> ipAddressFromString(std::string_view stringView) noexcept
{
if (Optional<IPv4Address> ipv4Address = IPv4Address::fromString(stringView); !ipv4Address.empty())
{
return ip_address_t(*ipv4Address);
}
if (Optional<IPv6Address> ipv6Address = IPv6Address::fromString(stringView); !ipv6Address.empty())
{
return ip_address_t(*ipv6Address);
}
return NULL_OPTIONAL;
}
class Socket
{
protected:

View File

@ -299,7 +299,10 @@ template<typename TChar, typename TTraits, typename TNumber>
[[nodiscard]]
constexpr bool toNumber(std::basic_string_view<TChar, TTraits> stringView, TNumber& outNumber, int base = 10) noexcept requires (!std::is_same_v<TChar, char>)
{
std::string asString(stringView.begin(), stringView.end());
std::string asString;
asString.resize(stringView.size());
// should only contain number symbols, so just cast down to char
std::transform(stringView.begin(), stringView.end(), asString.begin(), [](TChar chr) { return static_cast<char>(chr); });
return toNumber(asString, outNumber, base);
}