489 lines
12 KiB
C++
489 lines
12 KiB
C++
|
|
#include "./socket.hpp"
|
|
|
|
#include <iostream>
|
|
|
|
#include "../detect.hpp"
|
|
#include "../util/string.hpp"
|
|
#include "../util/variant.hpp"
|
|
|
|
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
|
|
#include <fcntl.h>
|
|
#include <unistd.h>
|
|
#include <arpa/inet.h>
|
|
#include <sys/socket.h>
|
|
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
|
|
#include <WinSock2.h>
|
|
#include <ws2tcpip.h>
|
|
#include "../util/winundef.hpp"
|
|
#endif
|
|
|
|
namespace mijin
|
|
{
|
|
namespace
|
|
{
|
|
// Backlog argument passed to listen() by TCPServerSocket::setup().
// NOTE(review): 3 pending connections is very small for a busy server; confirm it is intended.
inline constexpr int LISTEN_BACKLOG{3};
|
|
/// Translates the error of the most recent failed socket call into a StreamError.
///
/// On Linux the error is taken from `errno`. Winsock does not report errors through
/// `errno`, so on Windows the thread's last Winsock error (WSAGetLastError()) is used.
/// Call this before anything else that may overwrite that error state (e.g. closing
/// the socket).
///
/// @return IO_ERROR for low-level I/O failures, CONNECTION_CLOSED when the connection
///         was reset/aborted by the peer or is no longer connected, UNKNOWN_ERROR otherwise.
StreamError translateErrno() noexcept
{
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
    switch (errno)
    {
    case EIO:
        return StreamError::IO_ERROR;
    case ECONNRESET:
    case ECONNABORTED:
    case ENOTCONN:
    case EPIPE:
        return StreamError::CONNECTION_CLOSED;
    default:
        return StreamError::UNKNOWN_ERROR;
    }
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
    switch (WSAGetLastError())
    {
    case WSAECONNRESET:
    case WSAECONNABORTED:
    case WSAENOTCONN:
    case WSAESHUTDOWN:
        return StreamError::CONNECTION_CLOSED;
    default:
        return StreamError::UNKNOWN_ERROR;
    }
#else
    return StreamError::UNKNOWN_ERROR;
#endif
}
|
|
|
|
int readFlags(const ReadOptions& options)
|
|
{
|
|
return (options.partial ? 0 : MSG_WAITALL)
|
|
| (options.peek ? MSG_PEEK : 0);
|
|
}
|
|
|
|
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
|
|
// Value used to switch boolean socket options on via setsockopt().
constexpr int SOCKOPT_ONE{1};
|
|
|
|
/// ORs `flags` into the file status flags of `handle` (fcntl F_GETFL / F_SETFL).
///
/// @return true on success, false if reading or writing the flags failed.
bool appendSocketFlags(int handle, int flags) noexcept
{
    const int previousFlags = fcntl(handle, F_GETFL);
    return previousFlags >= 0
        && fcntl(handle, F_SETFL, previousFlags | flags) >= 0;
}
|
|
|
|
/// Clears `flags` from the file status flags of `handle` (fcntl F_GETFL / F_SETFL).
///
/// @return true on success, false if reading or writing the flags failed.
bool removeSocketFlags(int handle, int flags) noexcept
{
    const int previousFlags = fcntl(handle, F_GETFL);
    return previousFlags >= 0
        && fcntl(handle, F_SETFL, previousFlags & ~flags) >= 0;
}
|
|
|
|
long osRecv(int socket, std::span<std::uint8_t> buffer, int flags)
|
|
{
|
|
return static_cast<long>(recv(socket, buffer.data(), buffer.size(), flags);
|
|
}
|
|
|
|
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags)
|
|
{
|
|
return static_cast<long>(send(handle_, buffer.data(), buffer.size(), flags));
|
|
}
|
|
|
|
/// Creates a new socket; thin wrapper over socket().
///
/// @return the new descriptor, or a negative value with errno set on failure.
int osCreateSocket(int domain, int type, int protocol)
{
    const int descriptor = ::socket(domain, type, protocol);
    return descriptor;
}
|
|
|
|
/// Closes a socket descriptor; thin wrapper over close().
///
/// @return 0 on success, -1 with errno set on failure.
int osCloseSocket(int socket)
{
    const int status = ::close(socket);
    return status;
}
|
|
|
|
/// A POSIX descriptor is valid when it is non-negative.
bool osIsSocketValid(int socket)
{
    return !(socket < 0);
}
|
|
|
|
bool osSetSocketNonBlocking(int socket, bool blocking)
|
|
{
|
|
if (blocking)
|
|
{
|
|
return appendSocketFlags(socket, O_NONBLOCK);
|
|
}
|
|
else
|
|
{
|
|
return removeSocketFlags(socket, O_NONBLOCK);
|
|
}
|
|
}
|
|
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
|
|
// Winsock does not provide in_addr_t.
using in_addr_t = ULONG;

// Value used to switch boolean socket options on; setsockopt() takes `const char*` on Windows.
constexpr char SOCKOPT_ONE{1};

// Number of sockets created by osCreateSocket() on this thread and not yet closed;
// WSAStartup()/WSACleanup() are issued when this goes from 0 to 1 and back to 0.
// NOTE(review): because this is thread_local, a socket must be closed on the thread that
// created it, otherwise the assertion in osCloseSocket() fires — confirm callers respect that.
thread_local int numSocketsOpen{0};
|
|
|
|
/// Receives up to `buffer.size()` bytes from `socket`; thin wrapper over Winsock recv().
///
/// @return bytes received, or SOCKET_ERROR (check WSAGetLastError()) on failure.
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
{
    char* const destination = reinterpret_cast<char*>(buffer.data());
    // NOTE(review): Winsock takes an int length; buffers above INT_MAX bytes do not fit.
    const int length = static_cast<int>(buffer.size());
    return recv(socket, destination, length, flags);
}
|
|
|
|
/// Sends (a prefix of) `buffer` over `socket`; thin wrapper over Winsock send().
///
/// @return bytes sent, or SOCKET_ERROR (check WSAGetLastError()) on failure.
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
{
    const char* const source = reinterpret_cast<const char*>(buffer.data());
    // NOTE(review): Winsock takes an int length; buffers above INT_MAX bytes do not fit.
    const int length = static_cast<int>(buffer.size());
    return send(socket, source, length, flags);
}
|
|
|
|
/// Creates a new socket, initialising Winsock first if this thread has no open sockets.
///
/// @return the new socket, or INVALID_SOCKET_HANDLE on failure (the cause is available
///         through WSAGetLastError()).
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
{
    const bool startsWinsock = (numSocketsOpen == 0);
    if (startsWinsock)
    {
        WSADATA wsaData;
        const int startupError = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (startupError != 0)
        {
            // WSAStartup() returns its error code instead of setting the last error;
            // publish it so callers inspecting WSAGetLastError() see the real cause.
            WSASetLastError(startupError);
            return INVALID_SOCKET_HANDLE;
        }
    }

    const SOCKET result = socket(addressFamily, type, protocol);
    if (result == INVALID_SOCKET_HANDLE)
    {
        if (startsWinsock)
        {
            // No socket will ever be closed to balance the WSAStartup() above, so undo it
            // here while keeping the socket() error visible to the caller.
            const int socketError = WSAGetLastError();
            WSACleanup();
            WSASetLastError(socketError);
        }
        return result;
    }

    ++numSocketsOpen;
    return result;
}
|
|
|
|
/// Closes `socket` and shuts Winsock down again once this thread has no open sockets left.
///
/// @return 0 on success, SOCKET_ERROR on failure (the open-socket count is then unchanged).
int osCloseSocket(SOCKET socket)
{
    const int result = closesocket(socket);
    if (result == 0)
    {
        MIJIN_ASSERT(numSocketsOpen > 0, "Unbalanced calls to osCreateSocket and osCloseSocket!");
        --numSocketsOpen;
        if (numSocketsOpen == 0)
        {
            // Balances the WSAStartup() done by osCreateSocket() for the first socket.
            WSACleanup();
        }
    }
    return result;
}
|
|
|
|
/// A Winsock handle is valid unless it equals INVALID_SOCKET.
bool osIsSocketValid(SOCKET socket)
{
    const bool isInvalid = (socket == INVALID_SOCKET);
    return !isInvalid;
}
|
|
|
|
/// Switches `socket` into non-blocking mode (`nonBlocking == true`) or back into blocking
/// mode, matching the Linux implementation and all callers (setAsync(), setup()).
///
/// @return true if ioctlsocket() succeeded.
bool osSetSocketNonBlocking(SOCKET socket, bool nonBlocking)
{
    // FIONBIO: a non-zero argument enables non-blocking mode, zero disables it.
    // (The previous version had this inverted and made "non-blocking" sockets blocking.)
    u_long mode = nonBlocking ? 1 : 0;
    return ioctlsocket(socket, FIONBIO, &mode) == NO_ERROR;
}
|
|
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
|
|
}
|
|
|
|
/// Parses dotted-decimal IPv4 text such as "192.168.0.1".
///
/// @param stringView text to parse; must consist of exactly four dot-separated numbers.
/// @return the parsed address, or NULL_OPTIONAL if the text has the wrong number of
///         components or a component is rejected by toNumber().
Optional<IPv4Address> IPv4Address::fromString(std::string_view stringView) noexcept
{
    const std::vector<std::string_view> components = split(stringView, ".", {.limitParts = 4});
    if (components.size() != 4)
    {
        return NULL_OPTIONAL;
    }

    IPv4Address result;
    std::size_t octetIndex = 0;
    for (const std::string_view component : components)
    {
        if (!toNumber(component, result.octets[octetIndex]))
        {
            return NULL_OPTIONAL;
        }
        ++octetIndex;
    }
    return result;
}
|
|
|
|
/// Parses the textual form of an IPv6 address (RFC 4291, section 2.2), e.g.
/// "2001:db8::1", "::" or "fe80:0:0:0:0:0:0:1".
///
/// Accepted forms: exactly eight colon-separated groups of 1-4 hex digits, or fewer groups
/// plus a single "::" that stands for one or more all-zero groups. Embedded IPv4 notation
/// ("::ffff:1.2.3.4") is rejected because its last group is not 1-4 hex digits.
///
/// @param stringView text to parse
/// @return the parsed address (hextets in host order), or NULL_OPTIONAL if invalid
Optional<IPv6Address> IPv6Address::fromString(std::string_view stringView) noexcept
{
    // The empty string is not an address, and ":::" cannot be split unambiguously.
    if (stringView.empty() || stringView.contains(":::"))
    {
        return NULL_OPTIONAL;
    }

    // At most one "::" is allowed; each side of it is a (possibly empty) list of groups.
    std::vector<std::string_view> halves = split(stringView, "::", {.ignoreEmpty = false});
    if (halves.empty() || halves.size() > 2)
    {
        return NULL_OPTIONAL;
    }
    const bool compressed = (halves.size() == 2);
    if (!compressed)
    {
        halves.emplace_back("");
    }

    // Splits one side of the "::" into its groups. Empty groups (stray leading, trailing or
    // doubled single colons) and groups longer than four hex digits are rejected.
    const auto splitGroups = [](std::string_view side, std::vector<std::string_view>& outGroups)
    {
        if (side.empty())
        {
            return true;
        }
        outGroups = split(side, ":", {.ignoreEmpty = false});
        for (const std::string_view group : outGroups)
        {
            if (group.empty() || group.size() > 4)
            {
                return false;
            }
        }
        return true;
    };

    std::vector<std::string_view> groupsLeft;
    std::vector<std::string_view> groupsRight;
    if (!splitGroups(halves[0], groupsLeft) || !splitGroups(halves[1], groupsRight))
    {
        return NULL_OPTIONAL;
    }

    // Without "::" all eight groups must be written out; with it, at least one is implied.
    const std::size_t explicitGroups = groupsLeft.size() + groupsRight.size();
    if (compressed ? (explicitGroups > 7) : (explicitGroups != 8))
    {
        return NULL_OPTIONAL;
    }

    IPv6Address address = {};
    std::size_t hextet = 0;
    for (const std::string_view group : groupsLeft)
    {
        if (!toNumber(group, address.hextets[hextet], /* base = */ 16))
        {
            return NULL_OPTIONAL;
        }
        ++hextet;
    }
    // the groups replaced by "::" are zero
    for (; hextet < (8 - groupsRight.size()); ++hextet)
    {
        address.hextets[hextet] = 0;
    }
    for (const std::string_view group : groupsRight)
    {
        if (!toNumber(group, address.hextets[hextet], /* base = */ 16))
        {
            return NULL_OPTIONAL;
        }
        ++hextet;
    }
    return address;
}
|
|
|
|
/// Blocking read from the connection.
///
/// Switches the socket to blocking mode and performs a single recv() with the flags derived
/// from `options` (MSG_WAITALL unless partial, MSG_PEEK when peeking).
///
/// @param buffer       destination buffer
/// @param options      partial/peek behavior
/// @param outBytesRead receives the number of bytes read; may be nullptr (as in c_readRaw())
/// @return SUCCESS, or the translated error if recv() failed.
StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(false);

    const long bytesRead = osRecv(handle_, buffer, readFlags(options));
    if (bytesRead < 0)
    {
        return translateErrno();
    }

    if (outBytesRead != nullptr)
    {
        *outBytesRead = static_cast<std::size_t>(bytesRead);
    }
    return StreamError::SUCCESS;
}
|
|
|
|
/// Blocking write of the whole `buffer` to the connection.
///
/// send() may accept fewer bytes than requested (for example when interrupted by a
/// signal), so the remainder is re-sent until everything has been handed to the OS.
///
/// @return SUCCESS once all bytes were sent, the translated error if send() failed, or
///         IO_ERROR if send() unexpectedly made no progress.
StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(false);

    while (!buffer.empty())
    {
        const long bytesSent = osSend(handle_, buffer, 0);
        if (bytesSent < 0)
        {
            return translateErrno();
        }
        if (bytesSent == 0)
        {
            // should not happen for a non-empty buffer; bail out instead of spinning forever
            return StreamError::IO_ERROR;
        }
        buffer = buffer.subspan(static_cast<std::size_t>(bytesSent));
    }

    return StreamError::SUCCESS;
}
|
|
|
|
/// Asynchronous read from the connection; suspends the coroutine while no data is available.
///
/// The socket is switched to non-blocking mode. MSG_WAITALL cannot be relied on there
/// (it has no waiting effect on Linux and Winsock rejects it with WSAEOPNOTSUPP), so a
/// non-partial read fills the buffer by looping instead. Peeking reads cannot be
/// accumulated this way and return whatever is currently queued.
///
/// @param buffer       destination buffer
/// @param options      partial/peek behavior
/// @param outBytesRead receives the number of bytes read (fewer than requested if the peer
///                     shut the connection down first); may be nullptr
/// @return SUCCESS, or the translated error of a failed recv().
mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(true);

    const int flags = readFlags(options) & ~MSG_WAITALL;
    const bool fillBuffer = !options.partial && !options.peek;

    std::size_t totalRead = 0;
    while (true)
    {
        const long bytesRead = osRecv(handle_, buffer.subspan(totalRead), flags);
        if (bytesRead > 0)
        {
            totalRead += static_cast<std::size_t>(bytesRead);
            if (!fillBuffer || totalRead == buffer.size())
            {
                break;
            }
            continue;
        }
        if (bytesRead == 0)
        {
            // peer shut down (or nothing left to read into): report what we have
            break;
        }

#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
        const bool wouldBlock = (errno == EAGAIN);
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
        // Winsock reports "no data yet" through WSAGetLastError(), not errno.
        const bool wouldBlock = (WSAGetLastError() == WSAEWOULDBLOCK);
#endif
        if (!wouldBlock)
        {
            co_return translateErrno();
        }
        co_await mijin::c_suspend();
    }

    if (outBytesRead != nullptr)
    {
        *outBytesRead = totalRead;
    }
    co_return StreamError::SUCCESS;
}
|
|
|
|
/// Asynchronous write of the whole `buffer`; suspends the coroutine while the socket's
/// send buffer is full.
///
/// A non-blocking send() frequently accepts only part of the data, so the remainder is
/// re-sent (after suspending when the OS reports it would block) until everything is out.
///
/// @return SUCCESS once all bytes were sent, the translated error if send() failed, or
///         IO_ERROR if send() unexpectedly made no progress.
mijin::Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(true);

    while (!buffer.empty())
    {
        const long bytesSent = osSend(handle_, buffer, 0);
        if (bytesSent > 0)
        {
            buffer = buffer.subspan(static_cast<std::size_t>(bytesSent));
            continue;
        }
        if (bytesSent == 0)
        {
            // should not happen for a non-empty buffer; bail out instead of spinning forever
            co_return StreamError::IO_ERROR;
        }

#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
        const bool wouldBlock = (errno == EAGAIN);
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
        // Winsock reports "would block" through WSAGetLastError(), not errno.
        const bool wouldBlock = (WSAGetLastError() == WSAEWOULDBLOCK);
#endif
        if (!wouldBlock)
        {
            co_return translateErrno();
        }
        co_await mijin::c_suspend();
    }

    co_return StreamError::SUCCESS;
}
|
|
|
|
void TCPStream::setAsync(bool async)
|
|
{
|
|
if (async == async_)
|
|
{
|
|
return;
|
|
}
|
|
async_ = async;
|
|
|
|
osSetSocketNonBlocking(handle_, async);
|
|
}
|
|
|
|
/// A TCP connection has no stream position; always reports 0 (getFeatures() sets .tell = false).
std::size_t TCPStream::tell()
{
    return std::size_t{0};
}
|
|
|
|
/// Seeking is impossible on a TCP connection; both arguments are ignored.
///
/// @return always NOT_SUPPORTED (getFeatures() reports .seek = false).
StreamError TCPStream::seek(std::intptr_t /* pos */, SeekMode /* seekMode */)
{
    const StreamError unsupported = StreamError::NOT_SUPPORTED;
    return unsupported;
}
|
|
|
|
/// No-op: writeRaw()/c_writeRaw() hand data directly to the OS, so nothing is buffered here.
void TCPStream::flush()
{
    // intentionally empty
}
|
|
|
|
bool TCPStream::isAtEnd()
|
|
{
|
|
return !isOpen();
|
|
}
|
|
|
|
/// Describes what a TCP stream supports: reading and writing (blocking and async),
/// partial and peeking reads, but neither tell() nor seek().
StreamFeatures TCPStream::getFeatures()
{
    StreamFeatures features{};
    features.read = true;
    features.write = true;
    features.tell = false;  // tell() always reports 0
    features.seek = false;  // seek() returns NOT_SUPPORTED
    features.async = true;  // c_readRaw() / c_writeRaw()
    features.readOptions.partial = true;
    features.readOptions.peek = true;
    return features;
}
|
|
|
|
/// Opens a TCP connection to `address`:`port` (IPv4 or IPv6).
///
/// @param address peer address; its alternative selects the socket's address family
/// @param port    peer port in host byte order
/// @return SUCCESS, or the translated error of the failed socket()/connect() call
///         (the socket is closed again in that case).
StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
{
    MIJIN_ASSERT(!isOpen(), "Socket is already open.");

    // The socket's family must match the address: an AF_INET socket cannot connect to a
    // sockaddr_in6.
    const int addressFamily = std::holds_alternative<IPv6Address>(address) ? AF_INET6 : AF_INET;
    handle_ = osCreateSocket(addressFamily, SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }
    // A new socket starts out blocking; reset the cached mode (it may still describe a
    // previously closed socket) so a later setAsync(true) really switches it.
    async_ = false;

    const bool connected = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
            // The octets are stored in textual order, which already is network byte order,
            // so they can be copied into in_addr unchanged on any host.
            sockaddr_in connectAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(connectAddress)) == 0;
        },
        [&](const IPv6Address& address6)
        {
            sockaddr_in6 connectAddress = {};
            connectAddress.sin6_family = AF_INET6;
            connectAddress.sin6_port = htons(port);
            // hextets hold host-order 16-bit values; in6_addr expects them big-endian.
            for (std::size_t idx = 0; idx < 8; ++idx)
            {
                connectAddress.sin6_addr.s6_addr[2 * idx] = static_cast<unsigned char>(address6.hextets[idx] >> 8);
                connectAddress.sin6_addr.s6_addr[2 * idx + 1] = static_cast<unsigned char>(address6.hextets[idx] & 0xFF);
            }
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(connectAddress)) == 0;
        }
    }, address);

    if (!connected)
    {
        // Translate first: closing the socket may overwrite errno / the Winsock last error.
        const StreamError error = translateErrno();
        osCloseSocket(handle_);
        handle_ = INVALID_SOCKET_HANDLE;
        return error;
    }

    return StreamError::SUCCESS;
}
|
|
|
|
void TCPStream::close() noexcept
|
|
{
|
|
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
|
osCloseSocket(handle_);
|
|
handle_ = INVALID_SOCKET_HANDLE;
|
|
}
|
|
|
|
/// Gives access to the stream that performs this socket's I/O.
TCPStream& TCPSocket::getStream() noexcept
{
    return this->stream_;
}
|
|
|
|
/// Creates a non-blocking listening socket bound to `address`:`port` (IPv4 or IPv6).
///
/// Enables SO_REUSEADDR, binds, listens with LISTEN_BACKLOG and switches the socket to
/// non-blocking mode so c_waitForConnection() can poll it.
///
/// @return SUCCESS, or the translated error of the first failing call (the socket is
///         closed again in that case).
StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept
{
    MIJIN_ASSERT(!isListening(), "Socket is already listening.");

    // The socket's family must match the address: an AF_INET socket cannot be bound to a
    // sockaddr_in6.
    const int addressFamily = std::holds_alternative<IPv6Address>(address) ? AF_INET6 : AF_INET;
    handle_ = osCreateSocket(addressFamily, SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }

    // Closes the half-initialised socket after a failed call. The error is translated first
    // because closing may overwrite errno / the Winsock last error.
    const auto fail = [this]()
    {
        const StreamError error = translateErrno();
        close();
        return error;
    };

    // Passed as an int with its real size on every platform: on Windows SOCKOPT_ONE is a
    // single char, so pairing its address with sizeof(int) read past the variable.
    const int reuseAddress = SOCKOPT_ONE;
    if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&reuseAddress), sizeof(reuseAddress)) != 0)
    {
        return fail();
    }

    const bool bound = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
            // The octets are stored in textual order, which already is network byte order,
            // so they can be copied into in_addr unchanged on any host.
            sockaddr_in bindAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(bindAddress)) == 0;
        },
        [&](const IPv6Address& address6)
        {
            sockaddr_in6 bindAddress = {};
            bindAddress.sin6_family = AF_INET6;
            bindAddress.sin6_port = htons(port);
            // hextets hold host-order 16-bit values; in6_addr expects them big-endian.
            for (std::size_t idx = 0; idx < 8; ++idx)
            {
                bindAddress.sin6_addr.s6_addr[2 * idx] = static_cast<unsigned char>(address6.hextets[idx] >> 8);
                bindAddress.sin6_addr.s6_addr[2 * idx + 1] = static_cast<unsigned char>(address6.hextets[idx] & 0xFF);
            }
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(bindAddress)) == 0;
        }
    }, address);

    if (!bound
        || (listen(handle_, LISTEN_BACKLOG) < 0)
        || !osSetSocketNonBlocking(handle_, true))
    {
        return fail();
    }
    return StreamError::SUCCESS;
}
|
|
|
|
void TCPServerSocket::close() noexcept
|
|
{
|
|
MIJIN_ASSERT(isListening(), "Socket is not listening.");
|
|
|
|
osCloseSocket(handle_);
|
|
handle_ = INVALID_SOCKET_HANDLE;
|
|
}
|
|
|
|
/// Waits (by suspending the coroutine) until a client connects to the non-blocking
/// listening socket created by setup().
///
/// @return the accepted connection, the translated error of a failed accept(), or
///         CONNECTION_CLOSED if the server socket stops listening while waiting.
Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept
{
    while (isListening())
    {
        // sockaddr_storage fits IPv4 as well as IPv6 peers; a sockaddr_in would be too
        // small for the peer address of an IPv6 listener.
        sockaddr_storage client = {};
        socklen_t clientLength = sizeof(client);
        const socket_handle_t newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &clientLength);
        if (!osIsSocketValid(newSocket))
        {
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
            // EAGAIN: nothing pending yet. ECONNABORTED: a pending connection was aborted
            // before it could be accepted; keep waiting for the next one.
            const bool retry = (errno == EAGAIN || errno == ECONNABORTED);
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
            // Winsock reports through WSAGetLastError(), not errno. WSAECONNRESET: the
            // pending connection was reset by the peer before it could be accepted.
            const int error = WSAGetLastError();
            const bool retry = (error == WSAEWOULDBLOCK || error == WSAECONNRESET);
#endif
            if (!retry)
            {
                co_return translateErrno();
            }
            co_await c_suspend();
            continue;
        }

        std::unique_ptr<TCPSocket> socket = std::make_unique<TCPSocket>();
        socket->stream_.handle_ = newSocket;
        co_return socket;
    }

    co_return StreamError::CONNECTION_CLOSED;
}
|
|
}
|