Implemented/fixed Windows/MSVC support for sockets.

This commit is contained in:
Patrick 2024-08-19 18:35:55 +02:00
parent 35e7131780
commit df260808b9
5 changed files with 217 additions and 51 deletions

View File

@ -2,5 +2,9 @@
"libbacktrace":
{
"condition": "compiler_family == 'gcc' or compiler_family == 'clang'"
},
"winsock2":
{
"condition": "target_os == 'nt'"
}
}

View File

@ -5,13 +5,17 @@
#include "../detect.hpp"
#include "../util/string.hpp"
#include "../util/variant.hpp"
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
#include <fcntl.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include "../util/variant.hpp"
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
#include <WinSock2.h>
#include <ws2tcpip.h>
#include "../util/winundef.hpp"
#endif
namespace mijin
@ -23,11 +27,22 @@ StreamError translateErrno() noexcept
{
switch (errno)
{
default:
return StreamError::UNKNOWN_ERROR;
case EIO:
return StreamError::IO_ERROR;
default:
return StreamError::UNKNOWN_ERROR;
}
}
int readFlags(const ReadOptions& options)
{
return (options.partial ? 0 : MSG_WAITALL)
| (options.peek ? MSG_PEEK : 0);
}
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
const int SOCKOPT_ONE = 1;
bool appendSocketFlags(int handle, int flags) noexcept
{
const int currentFlags = fcntl(handle, F_GETFL);
@ -48,11 +63,102 @@ bool removeSocketFlags(int handle, int flags) noexcept
return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0;
}
int readFlags(const ReadOptions& options)
long osRecv(int socket, std::span<std::uint8_t> buffer, int flags)
{
return (options.partial ? 0 : MSG_WAITALL)
| (options.peek ? MSG_PEEK : 0);
return static_cast<long>(recv(socket, buffer.data(), buffer.size(), flags);
}
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags)
{
return static_cast<long>(send(handle_, buffer.data(), buffer.size(), flags));
}
int osCreateSocket(int domain, int type, int protocol)
{
return socket(domain, type, protocol);
}
int osCloseSocket(int socket)
{
return ::close(socket);
}
bool osIsSocketValid(int socket)
{
return socket >= 0;
}
bool osSetSocketNonBlocking(int socket, bool blocking)
{
if (blocking)
{
return appendSocketFlags(socket, O_NONBLOCK);
}
else
{
return removeSocketFlags(socket, O_NONBLOCK);
}
}
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
using in_addr_t = ULONG;
const char SOCKOPT_ONE = 1;
thread_local int numSocketsOpen = 0;
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
{
return recv(socket, reinterpret_cast<char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
}
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
{
return send(socket, reinterpret_cast<const char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
}
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
{
if (numSocketsOpen == 0)
{
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
{
return INVALID_SOCKET_HANDLE;
}
}
SOCKET result = socket(addressFamily, type, protocol);
if (result != INVALID_SOCKET_HANDLE)
{
++numSocketsOpen;
}
return result;
}
int osCloseSocket(SOCKET socket)
{
const int result = closesocket(socket);
if (result == 0)
{
MIJIN_ASSERT(numSocketsOpen > 0, "Inbalanced calls to osOpenSocket and osCloseSocket!");
--numSocketsOpen;
if (numSocketsOpen == 0)
{
WSACleanup();
}
}
return result;
}
bool osIsSocketValid(SOCKET socket)
{
return socket != INVALID_SOCKET;
}
bool osSetSocketNonBlocking(SOCKET socket, bool blocking)
{
u_long value = blocking ? 0 : 1;
return ioctlsocket(socket, FIONBIO, &value) == NO_ERROR;
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
}
Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept
@ -101,7 +207,7 @@ Optional<IPv6Address> IPv6Address::fromString(std::string_view stringView) noexc
return NULL_OPTIONAL;
}
IPv6Address address;
IPv6Address address = {};
unsigned hextet = 0;
for (std::string_view part : partsLeft)
{
@ -131,7 +237,7 @@ StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions
MIJIN_ASSERT(isOpen(), "Socket is not open.");
setAsync(false);
const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options));
const long bytesRead = osRecv(handle_, buffer, readFlags(options));
if (bytesRead < 0)
{
return translateErrno();
@ -146,7 +252,7 @@ StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
MIJIN_ASSERT(isOpen(), "Socket is not open.");
setAsync(false);
if (send(handle_, buffer.data(), buffer.size(), 0) < 0)
if (osSend(handle_, buffer, 0) < 0)
{
return translateErrno();
}
@ -161,7 +267,7 @@ mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, co
while(true)
{
const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options));
const long bytesRead = osRecv(handle_, buffer, readFlags(options));
if (bytesRead >= 0)
{
if (outBytesRead != nullptr) {
@ -184,7 +290,7 @@ mijin::Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buf
while (true)
{
if (send(handle_, buffer.data(), buffer.size(), 0) >= 0)
if (osSend(handle_, buffer, 0) >= 0)
{
co_return StreamError::SUCCESS;
}
@ -204,14 +310,7 @@ void TCPStream::setAsync(bool async)
}
async_ = async;
if (async)
{
appendSocketFlags(handle_, O_NONBLOCK);
}
else
{
removeSocketFlags(handle_, O_NONBLOCK);
}
osSetSocketNonBlocking(handle_, async);
}
std::size_t TCPStream::tell()
@ -253,8 +352,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
{
MIJIN_ASSERT(!isOpen(), "Socket is already open.");
handle_ = socket(AF_INET, SOCK_STREAM, 0);
if (handle_ < 0)
handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
if (!osIsSocketValid(handle_))
{
return translateErrno();
}
@ -263,29 +362,31 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
[&](const IPv4Address& address4)
{
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte orderof thre address"
#error "TODO: swap byte order of the address"
#endif
sockaddr_in connectAddress =
{
.sin_family = AF_INET,
.sin_port = htons(port),
.sin_addr = {.s_addr = std::bit_cast<in_addr_t>(address4)}
.sin_addr = std::bit_cast<in_addr>(address4.octets)
};
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
},
[&](const IPv6Address& address6) {
[&](const IPv6Address& address6)
{
sockaddr_in6 connectAddress =
{
.sin6_family = AF_INET,
.sin6_port = htons(port),
.sin6_addr = std::bit_cast<in6_addr>(address6)
};
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;}
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
}
}, address);
if (!connected)
{
::close(handle_);
handle_ = -1;
osCloseSocket(handle_);
handle_ = INVALID_SOCKET_HANDLE;
return translateErrno();
}
@ -295,8 +396,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
void TCPStream::close() noexcept
{
MIJIN_ASSERT(isOpen(), "Socket is not open.");
::close(handle_);
handle_ = -1;
osCloseSocket(handle_);
handle_ = INVALID_SOCKET_HANDLE;
}
TCPStream& TCPSocket::getStream() noexcept
@ -304,26 +405,49 @@ TCPStream& TCPSocket::getStream() noexcept
return stream_;
}
StreamError TCPServerSocket::setup(const char* address, std::uint16_t port) noexcept
StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept
{
MIJIN_ASSERT(!isListening(), "Socket is already listening.");
handle_ = socket(AF_INET, SOCK_STREAM, 0);
if (handle_ < 0)
handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
if (!osIsSocketValid(handle_))
{
return translateErrno();
}
sockaddr_in bindAddress =
if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(int)) != 0)
{
.sin_family = AF_INET,
.sin_port = htons(port),
.sin_addr = {inet_addr(address)}
};
static const int ONE = 1;
if ((setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &ONE, sizeof(int)))
|| (bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) < 0)
close();
return translateErrno();
}
const bool bound = std::visit(Visitor{
[&](const IPv4Address& address4)
{
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
sockaddr_in bindAddress =
{
.sin_family = AF_INET,
.sin_port = htons(port),
.sin_addr = std::bit_cast<in_addr>(address4.octets)
};
return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
},
[&](const IPv6Address& address6)
{
sockaddr_in6 bindAddress =
{
.sin6_family = AF_INET,
.sin6_port = htons(port),
.sin6_addr = std::bit_cast<in6_addr>(address6)
};
return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
}
}, address);
if (!bound
|| (listen(handle_, LISTEN_BACKLOG) < 0)
|| !appendSocketFlags(handle_, O_NONBLOCK))
|| !osSetSocketNonBlocking(handle_, true))
{
close();
return translateErrno();
@ -335,8 +459,8 @@ void TCPServerSocket::close() noexcept
{
MIJIN_ASSERT(isListening(), "Socket is not listening.");
::close(handle_);
handle_ = -1;
osCloseSocket(handle_);
handle_ = INVALID_SOCKET_HANDLE;
}
Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept
@ -345,8 +469,8 @@ Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection
{
sockaddr_in client = {};
socklen_t LENGTH = sizeof(sockaddr_in);
const int newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH);
if (newSocket < 0)
const socket_handle_t newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH);
if (!osIsSocketValid(newSocket))
{
if (errno != EAGAIN)
{

View File

@ -6,6 +6,7 @@
#include <array>
#include <variant>
#include "../detect.hpp"
#include "../async/coroutine.hpp"
#include "../container/optional.hpp"
#include "../io/stream.hpp"
@ -17,6 +18,16 @@ namespace mijin
// public types
//
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
using socket_handle_t = std::uintptr_t;
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = static_cast<socket_handle_t>(-1);
#else
using socket_handle_t = int;
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = -1;
#endif
struct IPv4Address
{
std::array<std::uint8_t, 4> octets;
@ -86,7 +97,7 @@ public:
class TCPStream : public Stream
{
private:
int handle_ = -1;
socket_handle_t handle_ = INVALID_SOCKET_HANDLE;
bool async_ = false;
public:
StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
@ -101,7 +112,7 @@ public:
StreamError open(ip_address_t address, std::uint16_t port) noexcept;
void close() noexcept;
[[nodiscard]] bool isOpen() const noexcept { return handle_ >= 0; }
[[nodiscard]] bool isOpen() const noexcept { return handle_ != INVALID_SOCKET_HANDLE; }
private:
void setAsync(bool async);
@ -133,9 +144,17 @@ public:
class TCPServerSocket : public ServerSocket
{
private:
int handle_ = -1;
socket_handle_t handle_ = INVALID_SOCKET_HANDLE;
public:
StreamError setup(const char* address, std::uint16_t port) noexcept;
StreamError setup(ip_address_t address, std::uint16_t port) noexcept;
StreamError setup(std::string_view addressText, std::uint16_t port) noexcept
{
if (Optional<ip_address_t> address = ipAddressFromString(addressText); !address.empty())
{
return setup(*address, port);
}
return StreamError::UNKNOWN_ERROR;
}
void close() noexcept override;
[[nodiscard]] bool isListening() const noexcept { return handle_ >= 0; }

View File

@ -4,6 +4,7 @@
#if !defined(MIJIN_UTIL_STRING_HPP_INCLUDED)
#define MIJIN_UTIL_STRING_HPP_INCLUDED 1
#include <algorithm>
#include <array>
#include <charconv>
#include <iterator>
@ -288,8 +289,10 @@ template<typename TNumber>
[[nodiscard]]
bool toNumber(std::string_view stringView, TNumber& outNumber, int base = 10) noexcept
{
const std::from_chars_result res = std::from_chars(&*stringView.begin(), &*stringView.end(), outNumber, base);
return res.ec == std::errc{} && res.ptr == &*stringView.end();
const char* start = &*stringView.begin();
const char* end = start + stringView.size();
const std::from_chars_result res = std::from_chars(start, end, outNumber, base);
return res.ec == std::errc{} && res.ptr == end;
}
namespace pipe

View File

@ -0,0 +1,16 @@
#if defined(NEAR)
#undef NEAR
#endif
#if defined(FAR)
#undef FAR
#endif
#if defined(ERROR)
#undef ERROR
#endif
#if defined(IGNORE)
#undef IGNORE
#endif