From df260808b9724ba6a585e87dac1550432eb7ff04 Mon Sep 17 00:00:00 2001 From: Patrick Wuttke Date: Mon, 19 Aug 2024 18:35:55 +0200 Subject: [PATCH] Implemented/fixed Windows/MSVC support for sockets. --- dependencies.json | 4 + source/mijin/net/socket.cpp | 214 ++++++++++++++++++++++++++------- source/mijin/net/socket.hpp | 27 ++++- source/mijin/util/string.hpp | 7 +- source/mijin/util/winundef.hpp | 16 +++ 5 files changed, 217 insertions(+), 51 deletions(-) create mode 100644 source/mijin/util/winundef.hpp diff --git a/dependencies.json b/dependencies.json index 482b106..e2931e7 100644 --- a/dependencies.json +++ b/dependencies.json @@ -2,5 +2,9 @@ "libbacktrace": { "condition": "compiler_family == 'gcc' or compiler_family == 'clang'" + }, + "winsock2": + { + "condition": "target_os == 'nt'" } } diff --git a/source/mijin/net/socket.cpp b/source/mijin/net/socket.cpp index 4236556..e6f9949 100644 --- a/source/mijin/net/socket.cpp +++ b/source/mijin/net/socket.cpp @@ -5,13 +5,17 @@ #include "../detect.hpp" #include "../util/string.hpp" +#include "../util/variant.hpp" #if MIJIN_TARGET_OS == MIJIN_OS_LINUX #include #include #include #include -#include "../util/variant.hpp" +#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +#include +#include +#include "../util/winundef.hpp" #endif namespace mijin @@ -23,11 +27,22 @@ StreamError translateErrno() noexcept { switch (errno) { - default: - return StreamError::UNKNOWN_ERROR; + case EIO: + return StreamError::IO_ERROR; + default: + return StreamError::UNKNOWN_ERROR; } } +int readFlags(const ReadOptions& options) +{ + return (options.partial ? 0 : MSG_WAITALL) + | (options.peek ? MSG_PEEK : 0); +} + +#if MIJIN_TARGET_OS == MIJIN_OS_LINUX +const int SOCKOPT_ONE = 1; + bool appendSocketFlags(int handle, int flags) noexcept { const int currentFlags = fcntl(handle, F_GETFL); @@ -48,11 +63,102 @@ bool removeSocketFlags(int handle, int flags) noexcept return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0; } -int readFlags(const ReadOptions& options) +long osRecv(int socket, std::span buffer, int flags) { - return (options.partial ? 0 : MSG_WAITALL) - | (options.peek ? MSG_PEEK : 0); + return static_cast(recv(socket, buffer.data(), buffer.size(), flags); } + +long osSend(int socket, std::span buffer, int flags) +{ + return static_cast(send(handle_, buffer.data(), buffer.size(), flags)); +} + +int osCreateSocket(int domain, int type, int protocol) +{ + return socket(domain, type, protocol); +} + +int osCloseSocket(int socket) +{ + return ::close(socket); +} + +bool osIsSocketValid(int socket) +{ + return socket >= 0; +} + +bool osSetSocketNonBlocking(int socket, bool blocking) +{ + if (blocking) + { + return appendSocketFlags(socket, O_NONBLOCK); + } + else + { + return removeSocketFlags(socket, O_NONBLOCK); + } +} +#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +using in_addr_t = ULONG; + +const char SOCKOPT_ONE = 1; +thread_local int numSocketsOpen = 0; + +long osRecv(SOCKET socket, std::span buffer, int flags) +{ + return recv(socket, reinterpret_cast(buffer.data()), static_cast(buffer.size()), flags); +} + +long osSend(SOCKET socket, std::span buffer, int flags) +{ + return send(socket, reinterpret_cast(buffer.data()), static_cast(buffer.size()), flags); +} + +SOCKET osCreateSocket(int addressFamily, int type, int protocol) +{ + if (numSocketsOpen == 0) + { + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) + { + return INVALID_SOCKET_HANDLE; + } + } + SOCKET result = socket(addressFamily, type, protocol); + if (result != INVALID_SOCKET_HANDLE) + { + ++numSocketsOpen; + } + return result; +} + +int osCloseSocket(SOCKET socket) +{ + const int result = closesocket(socket); + if (result == 0) + { + MIJIN_ASSERT(numSocketsOpen > 0, "Inbalanced calls to osOpenSocket and osCloseSocket!"); + --numSocketsOpen; + if (numSocketsOpen == 0) + { + WSACleanup(); + } + } + return result; +} + +bool osIsSocketValid(SOCKET socket) +{ + return socket != INVALID_SOCKET; +} + +bool osSetSocketNonBlocking(SOCKET socket, bool blocking) +{ + u_long value = blocking ? 0 : 1; + return ioctlsocket(socket, FIONBIO, &value) == NO_ERROR; +} +#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX } Optional IPv4Address::fromString(std::string_view stringView) noexcept @@ -101,7 +207,7 @@ Optional IPv6Address::fromString(std::string_view stringView) noexc return NULL_OPTIONAL; } - IPv6Address address; + IPv6Address address = {}; unsigned hextet = 0; for (std::string_view part : partsLeft) { @@ -131,7 +237,7 @@ StreamError TCPStream::readRaw(std::span buffer, const ReadOptions MIJIN_ASSERT(isOpen(), "Socket is not open."); setAsync(false); - const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options)); + const long bytesRead = osRecv(handle_, buffer, readFlags(options)); if (bytesRead < 0) { return translateErrno(); @@ -146,7 +252,7 @@ StreamError TCPStream::writeRaw(std::span buffer) MIJIN_ASSERT(isOpen(), "Socket is not open."); setAsync(false); - if (send(handle_, buffer.data(), buffer.size(), 0) < 0) + if (osSend(handle_, buffer, 0) < 0) { return translateErrno(); } @@ -161,7 +267,7 @@ mijin::Task TCPStream::c_readRaw(std::span buffer, co while(true) { - const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options)); + const long bytesRead = osRecv(handle_, buffer, readFlags(options)); if (bytesRead >= 0) { if (outBytesRead != nullptr) { @@ -184,7 +290,7 @@ mijin::Task TCPStream::c_writeRaw(std::span buf while (true) { - if (send(handle_, buffer.data(), buffer.size(), 0) >= 0) + if (osSend(handle_, buffer, 0) >= 0) { co_return StreamError::SUCCESS; } @@ -204,14 +310,7 @@ void TCPStream::setAsync(bool async) } async_ = async; - if (async) - { - appendSocketFlags(handle_, O_NONBLOCK); - } - else - { - removeSocketFlags(handle_, O_NONBLOCK); - } + osSetSocketNonBlocking(handle_, async); } std::size_t TCPStream::tell() @@ -253,8 +352,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept { MIJIN_ASSERT(!isOpen(), "Socket is already open."); - handle_ = socket(AF_INET, SOCK_STREAM, 0); - if (handle_ < 0) + handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0); + if (!osIsSocketValid(handle_)) { return translateErrno(); } @@ -263,29 +362,31 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept [&](const IPv4Address& address4) { #if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ -#error "TODO: swap byte orderof thre address" +#error "TODO: swap byte order of the address" #endif sockaddr_in connectAddress = { .sin_family = AF_INET, .sin_port = htons(port), - .sin_addr = {.s_addr = std::bit_cast(address4)} + .sin_addr = std::bit_cast(address4.octets) }; return connect(handle_, reinterpret_cast(&connectAddress), sizeof(sockaddr_in)) == 0; }, - [&](const IPv6Address& address6) { + [&](const IPv6Address& address6) + { sockaddr_in6 connectAddress = { .sin6_family = AF_INET, .sin6_port = htons(port), .sin6_addr = std::bit_cast(address6) }; - return connect(handle_, reinterpret_cast(&connectAddress), sizeof(sockaddr_in6)) == 0;} + return connect(handle_, reinterpret_cast(&connectAddress), sizeof(sockaddr_in6)) == 0; + } }, address); if (!connected) { - ::close(handle_); - handle_ = -1; + osCloseSocket(handle_); + handle_ = INVALID_SOCKET_HANDLE; return translateErrno(); } @@ -295,8 +396,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept void TCPStream::close() noexcept { MIJIN_ASSERT(isOpen(), "Socket is not open."); - ::close(handle_); - handle_ = -1; + osCloseSocket(handle_); + handle_ = INVALID_SOCKET_HANDLE; } TCPStream& TCPSocket::getStream() noexcept @@ -304,26 +405,49 @@ TCPStream& TCPSocket::getStream() noexcept return stream_; } -StreamError TCPServerSocket::setup(const char* address, std::uint16_t port) noexcept +StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept { MIJIN_ASSERT(!isListening(), "Socket is already listening."); - handle_ = socket(AF_INET, SOCK_STREAM, 0); - if (handle_ < 0) + handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0); + if (!osIsSocketValid(handle_)) { return translateErrno(); } - sockaddr_in bindAddress = + if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(int)) != 0) { - .sin_family = AF_INET, - .sin_port = htons(port), - .sin_addr = {inet_addr(address)} - }; - static const int ONE = 1; - if ((setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &ONE, sizeof(int))) - || (bind(handle_, reinterpret_cast(&bindAddress), sizeof(sockaddr_in)) < 0) + close(); + return translateErrno(); + } + + const bool bound = std::visit(Visitor{ + [&](const IPv4Address& address4) + { +#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ +#error "TODO: swap byte order of the address" +#endif + sockaddr_in bindAddress = + { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = std::bit_cast(address4.octets) + }; + return bind(handle_, reinterpret_cast(&bindAddress), sizeof(sockaddr_in)) == 0; + }, + [&](const IPv6Address& address6) + { + sockaddr_in6 bindAddress = + { + .sin6_family = AF_INET, + .sin6_port = htons(port), + .sin6_addr = std::bit_cast(address6) + }; + return bind(handle_, reinterpret_cast(&bindAddress), sizeof(sockaddr_in6)) == 0; + } + }, address); + if (!bound || (listen(handle_, LISTEN_BACKLOG) < 0) - || !appendSocketFlags(handle_, O_NONBLOCK)) + || !osSetSocketNonBlocking(handle_, true)) { close(); return translateErrno(); @@ -335,8 +459,8 @@ void TCPServerSocket::close() noexcept { MIJIN_ASSERT(isListening(), "Socket is not listening."); - ::close(handle_); - handle_ = -1; + osCloseSocket(handle_); + handle_ = INVALID_SOCKET_HANDLE; } Task>> TCPServerSocket::c_waitForConnection() noexcept @@ -345,8 +469,8 @@ Task>> TCPServerSocket::c_waitForConnection { sockaddr_in client = {}; socklen_t LENGTH = sizeof(sockaddr_in); - const int newSocket = accept(handle_, reinterpret_cast(&client), &LENGTH); - if (newSocket < 0) + const socket_handle_t newSocket = accept(handle_, reinterpret_cast(&client), &LENGTH); + if (!osIsSocketValid(newSocket)) { if (errno != EAGAIN) { diff --git a/source/mijin/net/socket.hpp b/source/mijin/net/socket.hpp index eb73990..eef95cf 100644 --- a/source/mijin/net/socket.hpp +++ b/source/mijin/net/socket.hpp @@ -6,6 +6,7 @@ #include #include +#include "../detect.hpp" #include "../async/coroutine.hpp" #include "../container/optional.hpp" #include "../io/stream.hpp" @@ -17,6 +18,16 @@ namespace mijin // public types // +#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS +using socket_handle_t = std::uintptr_t; + +inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = static_cast(-1); +#else +using socket_handle_t = int; + +inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = -1; +#endif + struct IPv4Address { std::array octets; @@ -86,7 +97,7 @@ public: class TCPStream : public Stream { private: - int handle_ = -1; + socket_handle_t handle_ = INVALID_SOCKET_HANDLE; bool async_ = false; public: StreamError readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) override; @@ -101,7 +112,7 @@ public: StreamError open(ip_address_t address, std::uint16_t port) noexcept; void close() noexcept; - [[nodiscard]] bool isOpen() const noexcept { return handle_ >= 0; } + [[nodiscard]] bool isOpen() const noexcept { return handle_ != INVALID_SOCKET_HANDLE; } private: void setAsync(bool async); @@ -133,9 +144,17 @@ public: class TCPServerSocket : public ServerSocket { private: - int handle_ = -1; + socket_handle_t handle_ = INVALID_SOCKET_HANDLE; public: - StreamError setup(const char* address, std::uint16_t port) noexcept; + StreamError setup(ip_address_t address, std::uint16_t port) noexcept; + StreamError setup(std::string_view addressText, std::uint16_t port) noexcept + { + if (Optional address = ipAddressFromString(addressText); !address.empty()) + { + return setup(*address, port); + } + return StreamError::UNKNOWN_ERROR; + } void close() noexcept override; [[nodiscard]] bool isListening() const noexcept { return handle_ >= 0; } diff --git a/source/mijin/util/string.hpp b/source/mijin/util/string.hpp index 8ba8237..f41a719 100644 --- a/source/mijin/util/string.hpp +++ b/source/mijin/util/string.hpp @@ -4,6 +4,7 @@ #if !defined(MIJIN_UTIL_STRING_HPP_INCLUDED) #define MIJIN_UTIL_STRING_HPP_INCLUDED 1 +#include #include #include #include @@ -288,8 +289,10 @@ template [[nodiscard]] bool toNumber(std::string_view stringView, TNumber& outNumber, int base = 10) noexcept { - const std::from_chars_result res = std::from_chars(&*stringView.begin(), &*stringView.end(), outNumber, base); - return res.ec == std::errc{} && res.ptr == &*stringView.end(); + const char* start = &*stringView.begin(); + const char* end = start + stringView.size(); + const std::from_chars_result res = std::from_chars(start, end, outNumber, base); + return res.ec == std::errc{} && res.ptr == end; } namespace pipe diff --git a/source/mijin/util/winundef.hpp b/source/mijin/util/winundef.hpp new file mode 100644 index 0000000..f72ccf4 --- /dev/null +++ b/source/mijin/util/winundef.hpp @@ -0,0 +1,16 @@ + +#if defined(NEAR) +#undef NEAR +#endif + +#if defined(FAR) +#undef FAR +#endif + +#if defined(ERROR) +#undef ERROR +#endif + +#if defined(IGNORE) +#undef IGNORE +#endif