Implemented/fixed Windows/MSVC support for sockets.
This commit is contained in:
parent
35e7131780
commit
df260808b9
@ -2,5 +2,9 @@
|
||||
"libbacktrace":
|
||||
{
|
||||
"condition": "compiler_family == 'gcc' or compiler_family == 'clang'"
|
||||
},
|
||||
"winsock2":
|
||||
{
|
||||
"condition": "target_os == 'nt'"
|
||||
}
|
||||
}
|
||||
|
@ -5,13 +5,17 @@
|
||||
|
||||
#include "../detect.hpp"
|
||||
#include "../util/string.hpp"
|
||||
#include "../util/variant.hpp"
|
||||
|
||||
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <sys/socket.h>
|
||||
#include "../util/variant.hpp"
|
||||
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
|
||||
#include <WinSock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#include "../util/winundef.hpp"
|
||||
#endif
|
||||
|
||||
namespace mijin
|
||||
@ -23,11 +27,22 @@ StreamError translateErrno() noexcept
|
||||
{
|
||||
switch (errno)
|
||||
{
|
||||
default:
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
case EIO:
|
||||
return StreamError::IO_ERROR;
|
||||
default:
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
int readFlags(const ReadOptions& options)
|
||||
{
|
||||
return (options.partial ? 0 : MSG_WAITALL)
|
||||
| (options.peek ? MSG_PEEK : 0);
|
||||
}
|
||||
|
||||
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
|
||||
const int SOCKOPT_ONE = 1;
|
||||
|
||||
bool appendSocketFlags(int handle, int flags) noexcept
|
||||
{
|
||||
const int currentFlags = fcntl(handle, F_GETFL);
|
||||
@ -48,11 +63,102 @@ bool removeSocketFlags(int handle, int flags) noexcept
|
||||
return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0;
|
||||
}
|
||||
|
||||
int readFlags(const ReadOptions& options)
|
||||
long osRecv(int socket, std::span<std::uint8_t> buffer, int flags)
|
||||
{
|
||||
return (options.partial ? 0 : MSG_WAITALL)
|
||||
| (options.peek ? MSG_PEEK : 0);
|
||||
return static_cast<long>(recv(socket, buffer.data(), buffer.size(), flags);
|
||||
}
|
||||
|
||||
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags)
|
||||
{
|
||||
return static_cast<long>(send(handle_, buffer.data(), buffer.size(), flags));
|
||||
}
|
||||
|
||||
int osCreateSocket(int domain, int type, int protocol)
|
||||
{
|
||||
return socket(domain, type, protocol);
|
||||
}
|
||||
|
||||
int osCloseSocket(int socket)
|
||||
{
|
||||
return ::close(socket);
|
||||
}
|
||||
|
||||
bool osIsSocketValid(int socket)
|
||||
{
|
||||
return socket >= 0;
|
||||
}
|
||||
|
||||
bool osSetSocketNonBlocking(int socket, bool blocking)
|
||||
{
|
||||
if (blocking)
|
||||
{
|
||||
return appendSocketFlags(socket, O_NONBLOCK);
|
||||
}
|
||||
else
|
||||
{
|
||||
return removeSocketFlags(socket, O_NONBLOCK);
|
||||
}
|
||||
}
|
||||
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
|
||||
using in_addr_t = ULONG;
|
||||
|
||||
const char SOCKOPT_ONE = 1;
|
||||
thread_local int numSocketsOpen = 0;
|
||||
|
||||
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
|
||||
{
|
||||
return recv(socket, reinterpret_cast<char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
|
||||
}
|
||||
|
||||
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
|
||||
{
|
||||
return send(socket, reinterpret_cast<const char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
|
||||
}
|
||||
|
||||
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
|
||||
{
|
||||
if (numSocketsOpen == 0)
|
||||
{
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
|
||||
{
|
||||
return INVALID_SOCKET_HANDLE;
|
||||
}
|
||||
}
|
||||
SOCKET result = socket(addressFamily, type, protocol);
|
||||
if (result != INVALID_SOCKET_HANDLE)
|
||||
{
|
||||
++numSocketsOpen;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int osCloseSocket(SOCKET socket)
|
||||
{
|
||||
const int result = closesocket(socket);
|
||||
if (result == 0)
|
||||
{
|
||||
MIJIN_ASSERT(numSocketsOpen > 0, "Inbalanced calls to osOpenSocket and osCloseSocket!");
|
||||
--numSocketsOpen;
|
||||
if (numSocketsOpen == 0)
|
||||
{
|
||||
WSACleanup();
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool osIsSocketValid(SOCKET socket)
|
||||
{
|
||||
return socket != INVALID_SOCKET;
|
||||
}
|
||||
|
||||
bool osSetSocketNonBlocking(SOCKET socket, bool blocking)
|
||||
{
|
||||
u_long value = blocking ? 0 : 1;
|
||||
return ioctlsocket(socket, FIONBIO, &value) == NO_ERROR;
|
||||
}
|
||||
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
|
||||
}
|
||||
|
||||
Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept
|
||||
@ -101,7 +207,7 @@ Optional<IPv6Address> IPv6Address::fromString(std::string_view stringView) noexc
|
||||
return NULL_OPTIONAL;
|
||||
}
|
||||
|
||||
IPv6Address address;
|
||||
IPv6Address address = {};
|
||||
unsigned hextet = 0;
|
||||
for (std::string_view part : partsLeft)
|
||||
{
|
||||
@ -131,7 +237,7 @@ StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions
|
||||
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
||||
setAsync(false);
|
||||
|
||||
const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options));
|
||||
const long bytesRead = osRecv(handle_, buffer, readFlags(options));
|
||||
if (bytesRead < 0)
|
||||
{
|
||||
return translateErrno();
|
||||
@ -146,7 +252,7 @@ StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
||||
setAsync(false);
|
||||
|
||||
if (send(handle_, buffer.data(), buffer.size(), 0) < 0)
|
||||
if (osSend(handle_, buffer, 0) < 0)
|
||||
{
|
||||
return translateErrno();
|
||||
}
|
||||
@ -161,7 +267,7 @@ mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, co
|
||||
|
||||
while(true)
|
||||
{
|
||||
const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options));
|
||||
const long bytesRead = osRecv(handle_, buffer, readFlags(options));
|
||||
if (bytesRead >= 0)
|
||||
{
|
||||
if (outBytesRead != nullptr) {
|
||||
@ -184,7 +290,7 @@ mijin::Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buf
|
||||
|
||||
while (true)
|
||||
{
|
||||
if (send(handle_, buffer.data(), buffer.size(), 0) >= 0)
|
||||
if (osSend(handle_, buffer, 0) >= 0)
|
||||
{
|
||||
co_return StreamError::SUCCESS;
|
||||
}
|
||||
@ -204,14 +310,7 @@ void TCPStream::setAsync(bool async)
|
||||
}
|
||||
async_ = async;
|
||||
|
||||
if (async)
|
||||
{
|
||||
appendSocketFlags(handle_, O_NONBLOCK);
|
||||
}
|
||||
else
|
||||
{
|
||||
removeSocketFlags(handle_, O_NONBLOCK);
|
||||
}
|
||||
osSetSocketNonBlocking(handle_, async);
|
||||
}
|
||||
|
||||
std::size_t TCPStream::tell()
|
||||
@ -253,8 +352,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(!isOpen(), "Socket is already open.");
|
||||
|
||||
handle_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (handle_ < 0)
|
||||
handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
|
||||
if (!osIsSocketValid(handle_))
|
||||
{
|
||||
return translateErrno();
|
||||
}
|
||||
@ -263,29 +362,31 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
|
||||
[&](const IPv4Address& address4)
|
||||
{
|
||||
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
|
||||
#error "TODO: swap byte orderof thre address"
|
||||
#error "TODO: swap byte order of the address"
|
||||
#endif
|
||||
sockaddr_in connectAddress =
|
||||
{
|
||||
.sin_family = AF_INET,
|
||||
.sin_port = htons(port),
|
||||
.sin_addr = {.s_addr = std::bit_cast<in_addr_t>(address4)}
|
||||
.sin_addr = std::bit_cast<in_addr>(address4.octets)
|
||||
};
|
||||
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
|
||||
},
|
||||
[&](const IPv6Address& address6) {
|
||||
[&](const IPv6Address& address6)
|
||||
{
|
||||
sockaddr_in6 connectAddress =
|
||||
{
|
||||
.sin6_family = AF_INET,
|
||||
.sin6_port = htons(port),
|
||||
.sin6_addr = std::bit_cast<in6_addr>(address6)
|
||||
};
|
||||
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;}
|
||||
return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
|
||||
}
|
||||
}, address);
|
||||
if (!connected)
|
||||
{
|
||||
::close(handle_);
|
||||
handle_ = -1;
|
||||
osCloseSocket(handle_);
|
||||
handle_ = INVALID_SOCKET_HANDLE;
|
||||
return translateErrno();
|
||||
}
|
||||
|
||||
@ -295,8 +396,8 @@ StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
|
||||
void TCPStream::close() noexcept
|
||||
{
|
||||
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
||||
::close(handle_);
|
||||
handle_ = -1;
|
||||
osCloseSocket(handle_);
|
||||
handle_ = INVALID_SOCKET_HANDLE;
|
||||
}
|
||||
|
||||
TCPStream& TCPSocket::getStream() noexcept
|
||||
@ -304,26 +405,49 @@ TCPStream& TCPSocket::getStream() noexcept
|
||||
return stream_;
|
||||
}
|
||||
|
||||
StreamError TCPServerSocket::setup(const char* address, std::uint16_t port) noexcept
|
||||
StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(!isListening(), "Socket is already listening.");
|
||||
|
||||
handle_ = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (handle_ < 0)
|
||||
handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
|
||||
if (!osIsSocketValid(handle_))
|
||||
{
|
||||
return translateErrno();
|
||||
}
|
||||
sockaddr_in bindAddress =
|
||||
if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(int)) != 0)
|
||||
{
|
||||
.sin_family = AF_INET,
|
||||
.sin_port = htons(port),
|
||||
.sin_addr = {inet_addr(address)}
|
||||
};
|
||||
static const int ONE = 1;
|
||||
if ((setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &ONE, sizeof(int)))
|
||||
|| (bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) < 0)
|
||||
close();
|
||||
return translateErrno();
|
||||
}
|
||||
|
||||
const bool bound = std::visit(Visitor{
|
||||
[&](const IPv4Address& address4)
|
||||
{
|
||||
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
|
||||
#error "TODO: swap byte order of the address"
|
||||
#endif
|
||||
sockaddr_in bindAddress =
|
||||
{
|
||||
.sin_family = AF_INET,
|
||||
.sin_port = htons(port),
|
||||
.sin_addr = std::bit_cast<in_addr>(address4.octets)
|
||||
};
|
||||
return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
|
||||
},
|
||||
[&](const IPv6Address& address6)
|
||||
{
|
||||
sockaddr_in6 bindAddress =
|
||||
{
|
||||
.sin6_family = AF_INET,
|
||||
.sin6_port = htons(port),
|
||||
.sin6_addr = std::bit_cast<in6_addr>(address6)
|
||||
};
|
||||
return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
|
||||
}
|
||||
}, address);
|
||||
if (!bound
|
||||
|| (listen(handle_, LISTEN_BACKLOG) < 0)
|
||||
|| !appendSocketFlags(handle_, O_NONBLOCK))
|
||||
|| !osSetSocketNonBlocking(handle_, true))
|
||||
{
|
||||
close();
|
||||
return translateErrno();
|
||||
@ -335,8 +459,8 @@ void TCPServerSocket::close() noexcept
|
||||
{
|
||||
MIJIN_ASSERT(isListening(), "Socket is not listening.");
|
||||
|
||||
::close(handle_);
|
||||
handle_ = -1;
|
||||
osCloseSocket(handle_);
|
||||
handle_ = INVALID_SOCKET_HANDLE;
|
||||
}
|
||||
|
||||
Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept
|
||||
@ -345,8 +469,8 @@ Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection
|
||||
{
|
||||
sockaddr_in client = {};
|
||||
socklen_t LENGTH = sizeof(sockaddr_in);
|
||||
const int newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH);
|
||||
if (newSocket < 0)
|
||||
const socket_handle_t newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH);
|
||||
if (!osIsSocketValid(newSocket))
|
||||
{
|
||||
if (errno != EAGAIN)
|
||||
{
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <array>
|
||||
#include <variant>
|
||||
#include "../detect.hpp"
|
||||
#include "../async/coroutine.hpp"
|
||||
#include "../container/optional.hpp"
|
||||
#include "../io/stream.hpp"
|
||||
@ -17,6 +18,16 @@ namespace mijin
|
||||
// public types
|
||||
//
|
||||
|
||||
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
|
||||
using socket_handle_t = std::uintptr_t;
|
||||
|
||||
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = static_cast<socket_handle_t>(-1);
|
||||
#else
|
||||
using socket_handle_t = int;
|
||||
|
||||
inline constexpr socket_handle_t INVALID_SOCKET_HANDLE = -1;
|
||||
#endif
|
||||
|
||||
struct IPv4Address
|
||||
{
|
||||
std::array<std::uint8_t, 4> octets;
|
||||
@ -86,7 +97,7 @@ public:
|
||||
class TCPStream : public Stream
|
||||
{
|
||||
private:
|
||||
int handle_ = -1;
|
||||
socket_handle_t handle_ = INVALID_SOCKET_HANDLE;
|
||||
bool async_ = false;
|
||||
public:
|
||||
StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
|
||||
@ -101,7 +112,7 @@ public:
|
||||
|
||||
StreamError open(ip_address_t address, std::uint16_t port) noexcept;
|
||||
void close() noexcept;
|
||||
[[nodiscard]] bool isOpen() const noexcept { return handle_ >= 0; }
|
||||
[[nodiscard]] bool isOpen() const noexcept { return handle_ != INVALID_SOCKET_HANDLE; }
|
||||
private:
|
||||
void setAsync(bool async);
|
||||
|
||||
@ -133,9 +144,17 @@ public:
|
||||
class TCPServerSocket : public ServerSocket
|
||||
{
|
||||
private:
|
||||
int handle_ = -1;
|
||||
socket_handle_t handle_ = INVALID_SOCKET_HANDLE;
|
||||
public:
|
||||
StreamError setup(const char* address, std::uint16_t port) noexcept;
|
||||
StreamError setup(ip_address_t address, std::uint16_t port) noexcept;
|
||||
StreamError setup(std::string_view addressText, std::uint16_t port) noexcept
|
||||
{
|
||||
if (Optional<ip_address_t> address = ipAddressFromString(addressText); !address.empty())
|
||||
{
|
||||
return setup(*address, port);
|
||||
}
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
void close() noexcept override;
|
||||
[[nodiscard]] bool isListening() const noexcept { return handle_ >= 0; }
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#if !defined(MIJIN_UTIL_STRING_HPP_INCLUDED)
|
||||
#define MIJIN_UTIL_STRING_HPP_INCLUDED 1
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <charconv>
|
||||
#include <iterator>
|
||||
@ -288,8 +289,10 @@ template<typename TNumber>
|
||||
[[nodiscard]]
|
||||
bool toNumber(std::string_view stringView, TNumber& outNumber, int base = 10) noexcept
|
||||
{
|
||||
const std::from_chars_result res = std::from_chars(&*stringView.begin(), &*stringView.end(), outNumber, base);
|
||||
return res.ec == std::errc{} && res.ptr == &*stringView.end();
|
||||
const char* start = &*stringView.begin();
|
||||
const char* end = start + stringView.size();
|
||||
const std::from_chars_result res = std::from_chars(start, end, outNumber, base);
|
||||
return res.ec == std::errc{} && res.ptr == end;
|
||||
}
|
||||
|
||||
namespace pipe
|
||||
|
16
source/mijin/util/winundef.hpp
Normal file
16
source/mijin/util/winundef.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
|
||||
#if defined(NEAR)
|
||||
#undef NEAR
|
||||
#endif
|
||||
|
||||
#if defined(FAR)
|
||||
#undef FAR
|
||||
#endif
|
||||
|
||||
#if defined(ERROR)
|
||||
#undef ERROR
|
||||
#endif
|
||||
|
||||
#if defined(IGNORE)
|
||||
#undef IGNORE
|
||||
#endif
|
Loading…
x
Reference in New Issue
Block a user