Implemented/fixed Windows/MSVC support for sockets.

This commit is contained in:
Patrick 2024-08-19 18:35:55 +02:00
parent 35e7131780
commit df260808b9
5 changed files with 217 additions and 51 deletions

View File

@ -2,5 +2,9 @@
"libbacktrace": "libbacktrace":
{ {
"condition": "compiler_family == 'gcc' or compiler_family == 'clang'" "condition": "compiler_family == 'gcc' or compiler_family == 'clang'"
},
"winsock2":
{
"condition": "target_os == 'nt'"
} }
} }

View File

@ -5,13 +5,17 @@
#include "../detect.hpp" #include "../detect.hpp"
#include "../util/string.hpp" #include "../util/string.hpp"
#include "../util/variant.hpp"
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX #if MIJIN_TARGET_OS == MIJIN_OS_LINUX
#include <fcntl.h> #include <fcntl.h>
#include <unistd.h> #include <unistd.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <sys/socket.h> #include <sys/socket.h>
#include "../util/variant.hpp" #elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
#include <WinSock2.h>
#include <ws2tcpip.h>
#include "../util/winundef.hpp"
#endif #endif
namespace mijin namespace mijin
@ -23,11 +27,22 @@ StreamError translateErrno() noexcept
{ {
switch (errno) switch (errno)
{ {
case EIO:
return StreamError::IO_ERROR;
default: default:
return StreamError::UNKNOWN_ERROR; return StreamError::UNKNOWN_ERROR;
} }
} }
int readFlags(const ReadOptions& options)
{
return (options.partial ? 0 : MSG_WAITALL)
| (options.peek ? MSG_PEEK : 0);
}
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
const int SOCKOPT_ONE = 1;
bool appendSocketFlags(int handle, int flags) noexcept bool appendSocketFlags(int handle, int flags) noexcept
{ {
const int currentFlags = fcntl(handle, F_GETFL); const int currentFlags = fcntl(handle, F_GETFL);
@ -48,11 +63,102 @@ bool removeSocketFlags(int handle, int flags) noexcept
return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0; return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0;
} }
int readFlags(const ReadOptions& options) long osRecv(int socket, std::span<std::uint8_t> buffer, int flags)
{ {
return (options.partial ? 0 : MSG_WAITALL) return static_cast<long>(recv(socket, buffer.data(), buffer.size(), flags);
| (options.peek ? MSG_PEEK : 0);
} }
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags)
{
return static_cast<long>(send(handle_, buffer.data(), buffer.size(), flags));
}
int osCreateSocket(int domain, int type, int protocol)
{
return socket(domain, type, protocol);
}
int osCloseSocket(int socket)
{
return ::close(socket);
}
bool osIsSocketValid(int socket)
{
return socket >= 0;
}
bool osSetSocketNonBlocking(int socket, bool blocking)
{
if (blocking)
{
return appendSocketFlags(socket, O_NONBLOCK);
}
else
{
return removeSocketFlags(socket, O_NONBLOCK);
}
}
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
using in_addr_t = ULONG;
const char SOCKOPT_ONE = 1;
thread_local int numSocketsOpen = 0;
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
{
return recv(socket, reinterpret_cast<char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
}
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
{
return send(socket, reinterpret_cast<const char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
}
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
{
if (numSocketsOpen == 0)
{
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
{
return INVALID_SOCKET_HANDLE;
}
}
SOCKET result = socket(addressFamily, type, protocol);
if (result != INVALID_SOCKET_HANDLE)
{
++numSocketsOpen;
}
return result;
}
int osCloseSocket(SOCKET socket)
{
const int result = closesocket(socket);
if (result == 0)
{
MIJIN_ASSERT(numSocketsOpen > 0, "Inbalanced calls to osOpenSocket and osCloseSocket!");
--numSocketsOpen;
if (numSocketsOpen == 0)
{
WSACleanup();
}
}
return result;
}
bool osIsSocketValid(SOCKET socket)
{
return socket != INVALID_SOCKET;
}
bool osSetSocketNonBlocking(SOCKET socket, bool blocking)
{
u_long value = blocking ? 0 : 1;
return ioctlsocket(socket, FIONBIO, &value) == NO_ERROR;
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
} }
Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept
@ -101,7 +207,7 @@ Optional<IPv6Address> IPv6Address::fromString(std::string_view stringView) noexc
return NULL_OPTIONAL; return NULL_OPTIONAL;
} }
IPv6Address address; IPv6Address address = {};
unsigned hextet = 0; unsigned hextet = 0;
for (std::string_view part : partsLeft) for (std::string_view part : partsLeft)
{ {
@ -131,7 +237,7 @@ StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions
MIJIN_ASSERT(isOpen(), "Socket is not open."); MIJIN_ASSERT(isOpen(), "Socket is not open.");
setAsync(false); setAsync(false);
const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options)); const long bytesRead = osRecv(handle_, buffer, readFlags(options));
if (bytesRead < 0) if (bytesRead < 0)
{ {
return translateErrno(); return translateErrno();
@ -146,7 +252,7 @@ StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
MIJIN_ASSERT(isOpen(), "Socket is not open."); MIJIN_ASSERT(isOpen(), "Socket is not open.");
setAsync(false); setAsync(false);
if (send(handle_, buffer.data(), buffer.size(), 0) < 0) if (osSend(handle_, buffer, 0) < 0)
{ {
return translateErrno(); return translateErrno();
} }
@ -161,7 +267,7 @@ mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, co
while(true) while(true)
{ {
const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options)); const long bytesRead = osRecv(handle_, buffer, readFlags(options));
if (bytesRead >= 0) if (bytesRead >= 0)
{ {
if (outBytesRead != nullptr) { if (outBytesRead != nullptr) {
@ -184,7 +290,7 @@ mijin::Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buf
while (true) while (true)
{ {
if (send(handle_, buffer.data(), buffer.size(), 0) >= 0) if (osSend(handle_, buffer, 0) >= 0)
{ {
co_return StreamError::SUCCESS; co_return StreamError::SUCCESS;
} }
@ -204,14 +310,7 @@ void TCPStream::setAsync(bool async)
} }
async_ = async; async_ = async;
if (async) osSetSocketNonBlocking(handle_, async);
{
appendSocketFlags(handle_, O_NONBLOCK);
}
else
{
removeSocketFlags(handle_, O_NONBLOCK);
}
} }
std::size_t TCPStream::tell() std::size_t TCPStream::tell()
@ -253,8 +352,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
{ {
MIJIN_ASSERT(!isOpen(), "Socket is already open."); MIJIN_ASSERT(!isOpen(), "Socket is already open.");
handle_ = socket(AF_INET, SOCK_STREAM, 0); handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
if (handle_ < 0) if (!osIsSocketValid(handle_))
{ {
return translateErrno(); return translateErrno();
} }
@ -263,29 +362,31 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
[&](const IPv4Address& address4) [&](const IPv4Address& address4)
{ {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__ #if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte orderof thre address" #error "TODO: swap byte order of the address"
#endif #endif
sockaddr_in connectAddress = sockaddr_in connectAddress =
{ {
.sin_family = AF_INET, .sin_family = AF_INET,
.sin_port = htons(port), .sin_port = htons(port),
.sin_addr = {.s_addr = std::bit_cast<in_addr_t>(address4)} .sin_addr = std::bit_cast<in_addr>(address4.octets)
}; };
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0; return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
}, },
[&](const IPv6Address& address6) { [&](const IPv6Address& address6)
{
sockaddr_in6 connectAddress = sockaddr_in6 connectAddress =
{ {
.sin6_family = AF_INET, .sin6_family = AF_INET,
.sin6_port = htons(port), .sin6_port = htons(port),
.sin6_addr = std::bit_cast<in6_addr>(address6) .sin6_addr = std::bit_cast<in6_addr>(address6)
}; };
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;} return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
}
}, address); }, address);
if (!connected) if (!connected)
{ {
::close(handle_); osCloseSocket(handle_);
handle_ = -1; handle_ = INVALID_SOCKET_HANDLE;
return translateErrno(); return translateErrno();
} }
@ -295,8 +396,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
void TCPStream::close() noexcept void TCPStream::close() noexcept
{ {
MIJIN_ASSERT(isOpen(), "Socket is not open."); MIJIN_ASSERT(isOpen(), "Socket is not open.");
::close(handle_); osCloseSocket(handle_);
handle_ = -1; handle_ = INVALID_SOCKET_HANDLE;
} }
TCPStream& TCPSocket::getStream() noexcept TCPStream& TCPSocket::getStream() noexcept
@ -304,26 +405,49 @@ TCPStream& TCPSocket::getStream() noexcept
return stream_; return stream_;
} }
StreamError TCPServerSocket::setup(const char* address, std::uint16_t port) noexcept StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept
{ {
MIJIN_ASSERT(!isListening(), "Socket is already listening."); MIJIN_ASSERT(!isListening(), "Socket is already listening.");
handle_ = socket(AF_INET, SOCK_STREAM, 0); handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
if (handle_ < 0) if (!osIsSocketValid(handle_))
{ {
return translateErrno(); return translateErrno();
} }
if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(int)) != 0)
{
close();
return translateErrno();
}
const bool bound = std::visit(Visitor{
[&](const IPv4Address& address4)
{
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
sockaddr_in bindAddress = sockaddr_in bindAddress =
{ {
.sin_family = AF_INET, .sin_family = AF_INET,
.sin_port = htons(port), .sin_port = htons(port),
.sin_addr = {inet_addr(address)} .sin_addr = std::bit_cast<in_addr>(address4.octets)
}; };
static const int ONE = 1; return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
if ((setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &ONE, sizeof(int))) },
|| (bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) < 0) [&](const IPv6Address& address6)
{
sockaddr_in6 bindAddress =
{
.sin6_family = AF_INET,
.sin6_port = htons(port),
.sin6_addr = std::bit_cast<in6_addr>(address6)
};
return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
}
}, address);
if (!bound
|| (listen(handle_, LISTEN_BACKLOG) < 0) || (listen(handle_, LISTEN_BACKLOG) < 0)
|| !appendSocketFlags(handle_, O_NONBLOCK)) || !osSetSocketNonBlocking(handle_, true))
{ {
close(); close();
return translateErrno(); return translateErrno();
@ -335,8 +459,8 @@ void TCPServerSocket::close() noexcept
{ {
MIJIN_ASSERT(isListening(), "Socket is not listening."); MIJIN_ASSERT(isListening(), "Socket is not listening.");
::close(handle_); osCloseSocket(handle_);
handle_ = -1; handle_ = INVALID_SOCKET_HANDLE;
} }
Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept
@ -345,8 +469,8 @@ Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection
{ {
sockaddr_in client = {}; sockaddr_in client = {};
socklen_t LENGTH = sizeof(sockaddr_in); socklen_t LENGTH = sizeof(sockaddr_in);
const int newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH); const socket_handle_t newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH);
if (newSocket < 0) if (!osIsSocketValid(newSocket))
{ {
if (errno != EAGAIN) if (errno != EAGAIN)
{ {

View File

@ -6,6 +6,7 @@
#include <array> #include <array>
#include <variant> #include <variant>
#include "../detect.hpp"
#include "../async/coroutine.hpp" #include "../async/coroutine.hpp"
#include "../container/optional.hpp" #include "../container/optional.hpp"
#include "../io/stream.hpp" #include "../io/stream.hpp"
@ -17,6 +18,16 @@ namespace mijin
// public types // public types
// //
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
using socket_handle_t = std::uintptr_t;
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = static_cast<socket_handle_t>(-1);
#else
using socket_handle_t = int;
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = -1;
#endif
struct IPv4Address struct IPv4Address
{ {
std::array<std::uint8_t, 4> octets; std::array<std::uint8_t, 4> octets;
@ -86,7 +97,7 @@ public:
class TCPStream : public Stream class TCPStream : public Stream
{ {
private: private:
int handle_ = -1; socket_handle_t handle_ = INVALID_SOCKET_HANDLE;
bool async_ = false; bool async_ = false;
public: public:
StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override; StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
@ -101,7 +112,7 @@ public:
StreamError open(ip_address_t address, std::uint16_t port) noexcept; StreamError open(ip_address_t address, std::uint16_t port) noexcept;
void close() noexcept; void close() noexcept;
[[nodiscard]] bool isOpen() const noexcept { return handle_ >= 0; } [[nodiscard]] bool isOpen() const noexcept { return handle_ != INVALID_SOCKET_HANDLE; }
private: private:
void setAsync(bool async); void setAsync(bool async);
@ -133,9 +144,17 @@ public:
class TCPServerSocket : public ServerSocket class TCPServerSocket : public ServerSocket
{ {
private: private:
int handle_ = -1; socket_handle_t handle_ = INVALID_SOCKET_HANDLE;
public: public:
StreamError setup(const char* address, std::uint16_t port) noexcept; StreamError setup(ip_address_t address, std::uint16_t port) noexcept;
StreamError setup(std::string_view addressText, std::uint16_t port) noexcept
{
if (Optional<ip_address_t> address = ipAddressFromString(addressText); !address.empty())
{
return setup(*address, port);
}
return StreamError::UNKNOWN_ERROR;
}
void close() noexcept override; void close() noexcept override;
[[nodiscard]] bool isListening() const noexcept { return handle_ >= 0; } [[nodiscard]] bool isListening() const noexcept { return handle_ >= 0; }

View File

@ -4,6 +4,7 @@
#if !defined(MIJIN_UTIL_STRING_HPP_INCLUDED) #if !defined(MIJIN_UTIL_STRING_HPP_INCLUDED)
#define MIJIN_UTIL_STRING_HPP_INCLUDED 1 #define MIJIN_UTIL_STRING_HPP_INCLUDED 1
#include <algorithm>
#include <array> #include <array>
#include <charconv> #include <charconv>
#include <iterator> #include <iterator>
@ -288,8 +289,10 @@ template<typename TNumber>
[[nodiscard]] [[nodiscard]]
bool toNumber(std::string_view stringView, TNumber& outNumber, int base = 10) noexcept bool toNumber(std::string_view stringView, TNumber& outNumber, int base = 10) noexcept
{ {
const std::from_chars_result res = std::from_chars(&*stringView.begin(), &*stringView.end(), outNumber, base); const char* start = &*stringView.begin();
return res.ec == std::errc{} && res.ptr == &*stringView.end(); const char* end = start + stringView.size();
const std::from_chars_result res = std::from_chars(start, end, outNumber, base);
return res.ec == std::errc{} && res.ptr == end;
} }
namespace pipe namespace pipe

View File

@ -0,0 +1,16 @@
#if defined(NEAR)
#undef NEAR
#endif
#if defined(FAR)
#undef FAR
#endif
#if defined(ERROR)
#undef ERROR
#endif
#if defined(IGNORE)
#undef IGNORE
#endif