diff --git a/source/mijin/net/detail/net_common.hpp b/source/mijin/net/detail/net_common.hpp new file mode 100644 index 0000000..47fe2f9 --- /dev/null +++ b/source/mijin/net/detail/net_common.hpp @@ -0,0 +1,26 @@ + +#pragma once + +#include "../../detect.hpp" + +#if MIJIN_TARGET_OS == MIJIN_OS_LINUX +#include +#include +#include +#include +#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +#define _WINSOCK_DEPRECATED_NO_WARNINGS +#include +#include +#include "../util/winundef.hpp" +#endif // MIJIN_TARGET_OS + +namespace mijin::detail +{ +#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +bool initWSA() noexcept; +StreamError translateWSAError() noexcept; +StreamError translateWinError(DWORD error) noexcept; +StreamError translateWinError() noexcept; +#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +}// namespace mijin::detail diff --git a/source/mijin/net/ip.cpp b/source/mijin/net/ip.cpp new file mode 100644 index 0000000..5e54d72 --- /dev/null +++ b/source/mijin/net/ip.cpp @@ -0,0 +1,116 @@ + +#include "./ip.hpp" + +#include "../detect.hpp" +#include "./detail/net_common.hpp" + +namespace mijin +{ +namespace +{ +#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +struct WSAQueryContext +{ + // WSA stuff + OVERLAPPED overlapped; + PADDRINFOEXA results; + HANDLE cancelHandle = nullptr; + + // my stuff + StreamResult> result; +}; +using os_resolve_handle_t = WSAQueryContext; + +void WINAPI getAddrComplete(DWORD error, DWORD bytes, LPOVERLAPPED overlapped) noexcept +{ + (void) bytes; + + WSAQueryContext& queryContext = *CONTAINING_RECORD(overlapped, WSAQueryContext, overlapped); + if (error != ERROR_SUCCESS) + { + queryContext.result = detail::translateWinError(error); + } + std::vector resultAddresses; + for (PADDRINFOEXA result = queryContext.results; result != nullptr; result = result->ai_next) + { + switch (result->ai_family) + { + case AF_INET: + { + sockaddr_in& addr = *reinterpret_cast(result->ai_addr); + resultAddresses.emplace_back(std::bit_cast(addr.sin_addr)); + } + break; + case AF_INET6: + { + sockaddr_in6& addr = *reinterpret_cast(result->ai_addr); + resultAddresses.emplace_back(std::bit_cast(addr.sin6_addr)); + } + break; + default: break; + } + } + if (queryContext.results != nullptr) + { + // WTF is wrong with people at MS? + // you can't access FreeAddrInfoExA otherwise... +#if defined(FreeAddrInfoEx) +#undef FreeAddrInfoEx +#endif + FreeAddrInfoExA(queryContext.results); + } +} + +StreamError osBeginResolve(const std::string& hostname, os_resolve_handle_t& queryContext) noexcept +{ + if (!detail::initWSA()) + { + return detail::translateWSAError(); + } + ADDRINFOEXA hints = {.ai_family = AF_UNSPEC}; + + const int error = GetAddrInfoExA( + /* pName = */ hostname.c_str(), + /* pServiceName = */ nullptr, + /* dwNameSpace = */ NS_DNS, + /* lpNspId = */ nullptr, + /* hints = */ &hints, + /* ppResult = */ &queryContext.results, + /* timeout = */ nullptr, + /* lpOverlapped = */ &queryContext.overlapped, + /* lpCompletionRoutine = */ &getAddrComplete, + /* lpNameHandle = */ nullptr + ); + if (error != WSA_IO_PENDING) + { + getAddrComplete(error, 0, &queryContext.overlapped); + } + return StreamError::SUCCESS; +} + +bool osResolveDone(os_resolve_handle_t& queryContext) noexcept +{ + return !queryContext.result.isEmpty(); +} + +StreamResult> osResolveResult(os_resolve_handle_t& queryContext) noexcept +{ + return queryContext.result; +} +#endif // MIJIN_TARGET_OS +} + +Task>> c_resolveHostname(std::string hostname) noexcept +{ + os_resolve_handle_t resolveHandle; + if (StreamError error = osBeginResolve(hostname, resolveHandle); error != StreamError::SUCCESS) + { + co_return error; + } + while (!osResolveDone(resolveHandle)) + { + co_await c_suspend(); + } + co_return osResolveResult(resolveHandle); +} +} diff --git a/source/mijin/net/ip.hpp b/source/mijin/net/ip.hpp new file mode 100644 index 0000000..8166328 --- /dev/null +++ b/source/mijin/net/ip.hpp @@ -0,0 +1,53 @@ +#pragma once + +#if !defined(MIJIN_NET_IP_HPP_INCLUDED) +#define MIJIN_NET_IP_HPP_INCLUDED 1 + +#include +#include +#include "../async/coroutine.hpp" +#include "../container/optional.hpp" +#include "../io/stream.hpp" // TODO: rename Stream{Error,Result} to IO{*} + +namespace mijin +{ +struct IPv4Address +{ + std::array octets; + + auto operator<=>(const IPv4Address&) const noexcept = default; + + [[nodiscard]] + static Optional fromString(std::string_view stringView) noexcept; +}; + +struct IPv6Address +{ + std::array hextets; + + auto operator<=>(const IPv6Address&) const noexcept = default; + + [[nodiscard]] + static Optional fromString(std::string_view stringView) noexcept; +}; +using ip_address_t = std::variant; + +[[nodiscard]] +inline Optional ipAddressFromString(std::string_view stringView) noexcept +{ + if (Optional ipv4Address = IPv4Address::fromString(stringView); !ipv4Address.empty()) + { + return ip_address_t(*ipv4Address); + } + if (Optional ipv6Address = IPv6Address::fromString(stringView); !ipv6Address.empty()) + { + return ip_address_t(*ipv6Address); + } + return NULL_OPTIONAL; +} + +[[nodiscard]] +Task>> c_resolveHostname(std::string hostname) noexcept; +} + +#endif // !defined(MIJIN_NET_IP_HPP_INCLUDED) diff --git a/source/mijin/net/socket.cpp b/source/mijin/net/socket.cpp index e71a2a1..10e1d23 100644 --- a/source/mijin/net/socket.cpp +++ b/source/mijin/net/socket.cpp @@ -1,23 +1,11 @@ #include "./socket.hpp" -#include - +#include "./detail/net_common.hpp" #include "../detect.hpp" #include "../util/string.hpp" #include "../util/variant.hpp" -#if MIJIN_TARGET_OS == MIJIN_OS_LINUX -#include -#include -#include -#include -#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS -#include -#include -#include "../util/winundef.hpp" -#endif - namespace mijin { namespace @@ -103,7 +91,19 @@ bool osSetSocketNonBlocking(int socket, bool blocking) using in_addr_t = ULONG; const char SOCKOPT_ONE = 1; -thread_local int numSocketsOpen = 0; +thread_local bool gWsaInited = false; + +class WSAGuard +{ +public: + ~WSAGuard() noexcept + { + if (gWsaInited) + { + WSACleanup(); + } + } +} thread_local [[maybe_unused]] gWsaGuard; long osRecv(SOCKET socket, std::span buffer, int flags) { @@ -117,35 +117,16 @@ long osSend(SOCKET socket, std::span buffer, int flags) SOCKET osCreateSocket(int addressFamily, int type, int protocol) { - if (numSocketsOpen == 0) + if (!detail::initWSA()) { - WSADATA wsaData; - if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) - { - return INVALID_SOCKET_HANDLE; - } + return INVALID_SOCKET_HANDLE; } - SOCKET result = socket(addressFamily, type, protocol); - if (result != INVALID_SOCKET_HANDLE) - { - ++numSocketsOpen; - } - return result; + return socket(addressFamily, type, protocol); } int osCloseSocket(SOCKET socket) { - const int result = closesocket(socket); - if (result == 0) - { - MIJIN_ASSERT(numSocketsOpen > 0, "Inbalanced calls to osOpenSocket and osCloseSocket!"); - --numSocketsOpen; - if (numSocketsOpen == 0) - { - WSACleanup(); - } - } - return result; + return closesocket(socket); } bool osIsSocketValid(SOCKET socket) @@ -161,6 +142,45 @@ bool osSetSocketNonBlocking(SOCKET socket, bool blocking) #endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX } +namespace detail +{ +#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +bool initWSA() noexcept +{ + if (gWsaInited) + { + return true; + } + + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) + { + return false; + } + gWsaInited = true; + return true; +} + +StreamError translateWSAError() noexcept +{ + // TODO + return StreamError::UNKNOWN_ERROR; +} + +StreamError translateWinError(DWORD error) noexcept +{ + // TODO + (void) error; + return StreamError::UNKNOWN_ERROR; +} + +StreamError translateWinError() noexcept +{ + return translateWinError(GetLastError()); +} +#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +}// namespace impl + Optional IPv4Address::fromString(std::string_view stringView) noexcept { std::vector parts = split(stringView, ".", {.limitParts = 4}); @@ -280,11 +300,11 @@ mijin::Task TCPStream::c_readRaw(std::span buffer, co } co_return StreamError::SUCCESS; } - else if (bytesRead == 0) + if (bytesRead == 0) { co_return StreamError::CONNECTION_CLOSED; } - else if (errno != EAGAIN) + if (errno != EAGAIN) { co_return translateErrno(); } @@ -292,7 +312,7 @@ mijin::Task TCPStream::c_readRaw(std::span buffer, co } } -mijin::Task TCPStream::c_writeRaw(std::span buffer) +Task TCPStream::c_writeRaw(std::span buffer) { MIJIN_ASSERT(isOpen(), "Socket is not open."); @@ -318,7 +338,7 @@ mijin::Task TCPStream::c_writeRaw(std::span buf { co_return translateErrno(); } - co_await mijin::c_suspend(); + co_await c_suspend(); } } @@ -338,7 +358,7 @@ std::size_t TCPStream::tell() return 0; } -StreamError TCPStream::seek(std::intptr_t /* pos */, mijin::SeekMode /* seekMode */) +StreamError TCPStream::seek(std::intptr_t /* pos */, SeekMode /* seekMode */) { return StreamError::NOT_SUPPORTED; } @@ -396,7 +416,7 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept { sockaddr_in6 connectAddress = { - .sin6_family = AF_INET, + .sin6_family = AF_INET6, .sin6_port = htons(port), .sin6_addr = std::bit_cast(address6) }; @@ -458,7 +478,7 @@ StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noe { sockaddr_in6 bindAddress = { - .sin6_family = AF_INET, + .sin6_family = AF_INET6, .sin6_port = htons(port), .sin6_addr = std::bit_cast(address6) }; diff --git a/source/mijin/net/socket.hpp b/source/mijin/net/socket.hpp index eef95cf..39672e9 100644 --- a/source/mijin/net/socket.hpp +++ b/source/mijin/net/socket.hpp @@ -4,8 +4,7 @@ #if !defined(MIJIN_NET_SOCKET_HPP_INCLUDED) #define MIJIN_NET_SOCKET_HPP_INCLUDED 1 -#include -#include +#include "./ip.hpp" #include "../detect.hpp" #include "../async/coroutine.hpp" #include "../container/optional.hpp" @@ -28,41 +27,6 @@ using socket_handle_t = int; inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = -1; #endif -struct IPv4Address -{ - std::array octets; - - auto operator<=>(const IPv4Address&) const noexcept = default; - - [[nodiscard]] - static Optional fromString(std::string_view stringView) noexcept; -}; - -struct IPv6Address -{ - std::array hextets; - - auto operator<=>(const IPv6Address&) const noexcept = default; - - [[nodiscard]] - static Optional fromString(std::string_view stringView) noexcept; -}; -using ip_address_t = std::variant; - -[[nodiscard]] -inline Optional ipAddressFromString(std::string_view stringView) noexcept -{ - if (Optional ipv4Address = IPv4Address::fromString(stringView); !ipv4Address.empty()) - { - return ip_address_t(*ipv4Address); - } - if (Optional ipv6Address = IPv6Address::fromString(stringView); !ipv6Address.empty()) - { - return ip_address_t(*ipv6Address); - } - return NULL_OPTIONAL; -} - class Socket { protected: diff --git a/source/mijin/util/string.hpp b/source/mijin/util/string.hpp index 5189f9a..91221e6 100644 --- a/source/mijin/util/string.hpp +++ b/source/mijin/util/string.hpp @@ -299,7 +299,10 @@ template [[nodiscard]] constexpr bool toNumber(std::basic_string_view stringView, TNumber& outNumber, int base = 10) noexcept requires (!std::is_same_v) { - std::string asString(stringView.begin(), stringView.end()); + std::string asString; + asString.resize(stringView.size()); + // should only contain number symbols, so just cast down to char + std::transform(stringView.begin(), stringView.end(), asString.begin(), [](TChar chr) { return static_cast(chr); }); return toNumber(asString, outNumber, base); }