// mijin2/source/mijin/net/socket.cpp
//
// 463 lines
// 11 KiB
// C++

#include "./socket.hpp"
#include "./detail/net_common.hpp"
#include "../detect.hpp"
#include "../util/variant.hpp"
namespace mijin
{
namespace
{
// Backlog value handed to listen() when a server socket is set up.
inline constexpr int LISTEN_BACKLOG{3};
// Converts the current value of errno into a StreamError.
// Only EIO has a dedicated mapping (IO_ERROR); every other value is
// reported as UNKNOWN_ERROR.
StreamError translateErrno() noexcept
{
    if (errno == EIO)
    {
        return StreamError::IO_ERROR;
    }
    return StreamError::UNKNOWN_ERROR;
}
// Builds the flag mask passed to recv() for the given read options:
// MSG_WAITALL is requested unless options.partial is set, and MSG_PEEK is
// added when options.peek is set.
int readFlags(const ReadOptions& options)
{
    int flags = 0;
    if (!options.partial)
    {
        flags |= MSG_WAITALL;
    }
    if (options.peek)
    {
        flags |= MSG_PEEK;
    }
    return flags;
}
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
// Value whose address is passed to setsockopt() to switch a socket option on.
const int SOCKOPT_ONE{1};
// Adds `flags` to the file status flags of `handle` (F_GETFL followed by F_SETFL).
// Returns true on success and false if either fcntl() call fails.
bool appendSocketFlags(int handle, int flags) noexcept
{
    const int previous = fcntl(handle, F_GETFL);
    if (previous >= 0)
    {
        const int combined = previous | flags;
        return fcntl(handle, F_SETFL, combined) >= 0;
    }
    return false;
}
// Clears `flags` from the file status flags of `handle` (F_GETFL followed by F_SETFL).
// Returns true on success and false if either fcntl() call fails.
bool removeSocketFlags(int handle, int flags) noexcept
{
    const int previous = fcntl(handle, F_GETFL);
    if (previous >= 0)
    {
        const int reduced = previous & ~flags;
        return fcntl(handle, F_SETFL, reduced) >= 0;
    }
    return false;
}
// Linux wrapper around recv(): receives into the whole span and returns the
// result of recv() converted to long.
long osRecv(int socket, std::span<std::uint8_t> buffer, int flags)
{
    const auto received = recv(socket, buffer.data(), buffer.size(), flags);
    return static_cast<long>(received);
}
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags)
{
return static_cast<long>(send(socket, buffer.data(), buffer.size(), flags));
}
// Linux wrapper around socket(); returns whatever socket() returns.
int osCreateSocket(int domain, int type, int protocol)
{
    const int handle = socket(domain, type, protocol);
    return handle;
}
// Linux wrapper around ::close(); returns the result of ::close().
int osCloseSocket(int socket)
{
    const int result = ::close(socket);
    return result;
}
// A Linux socket handle is treated as valid when it is not negative.
bool osIsSocketValid(int socket)
{
    return !(socket < 0);
}
// Switches a socket between non-blocking and blocking mode.
//
// nonBlocking == true sets O_NONBLOCK on the descriptor, false clears it.
// (The parameter used to be called `blocking`, which was the opposite of what
// it does and of what the callers pass.)
// Returns true on success, false if the underlying fcntl() calls fail.
bool osSetSocketNonBlocking(int socket, bool nonBlocking)
{
    return nonBlocking
        ? appendSocketFlags(socket, O_NONBLOCK)
        : removeSocketFlags(socket, O_NONBLOCK);
}
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
// Makes the name in_addr_t available in the Windows build.
using in_addr_t = ULONG;
// Value whose address is passed to setsockopt() to switch a socket option on.
// NOTE(review): presumably a char because Winsock's setsockopt() takes a const char* - confirm.
const char SOCKOPT_ONE{1};
// Per-thread flag: set by detail::initWSA() after WSAStartup() succeeded on this thread.
thread_local bool gWsaInited{false};
// Helper whose destructor calls WSACleanup() if gWsaInited is set for the
// current thread. One thread_local instance (gWsaGuard) is declared below.
// NOTE(review): gWsaGuard is not referenced anywhere in the visible code;
// confirm the toolchain still runs the destructor of an unreferenced
// thread_local object at thread exit.
class WSAGuard
{
public:
    ~WSAGuard() noexcept
    {
        if (!gWsaInited)
        {
            return;
        }
        WSACleanup();
    }
};
[[maybe_unused]] thread_local WSAGuard gWsaGuard;
// Windows wrapper around recv(): adapts the span to the char* / int
// parameters of the Winsock call and returns its result.
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
{
    char* const data = reinterpret_cast<char*>(buffer.data());
    const int length = static_cast<int>(buffer.size());
    return recv(socket, data, length, flags);
}
// Windows wrapper around send(): adapts the span to the const char* / int
// parameters of the Winsock call and returns its result.
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
{
    const char* const data = reinterpret_cast<const char*>(buffer.data());
    const int length = static_cast<int>(buffer.size());
    return send(socket, data, length, flags);
}
// Windows wrapper around socket(). Makes sure Winsock is initialised for the
// calling thread first; returns INVALID_SOCKET_HANDLE if that fails.
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
{
    if (detail::initWSA())
    {
        return socket(addressFamily, type, protocol);
    }
    return INVALID_SOCKET_HANDLE;
}
// Windows wrapper around closesocket(); returns the result of closesocket().
int osCloseSocket(SOCKET socket)
{
    const int result = closesocket(socket);
    return result;
}
// A Windows socket handle is treated as valid when it differs from INVALID_SOCKET.
bool osIsSocketValid(SOCKET socket)
{
    return !(socket == INVALID_SOCKET);
}
// Switches a socket between non-blocking and blocking mode.
//
// nonBlocking == true enables non-blocking mode, false restores blocking mode.
// This matches the Linux implementation and the callers in this file
// (TCPStream::setAsync() and TCPServerSocket::setup() pass `true` to request
// non-blocking behaviour). The previous code passed the inverted value to
// FIONBIO, so `true` selected blocking mode on Windows.
// Returns true if ioctlsocket() reports NO_ERROR.
bool osSetSocketNonBlocking(SOCKET socket, bool nonBlocking)
{
    // FIONBIO: a non-zero argument enables non-blocking mode, zero disables it.
    u_long mode = nonBlocking ? 1 : 0;
    return ioctlsocket(socket, FIONBIO, &mode) == NO_ERROR;
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
}
namespace detail
{
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
bool initWSA() noexcept
{
if (gWsaInited)
{
return true;
}
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
{
return false;
}
gWsaInited = true;
return true;
}
// Placeholder for mapping Winsock errors to StreamError values.
// Not implemented yet: always reports UNKNOWN_ERROR.
StreamError translateWSAError() noexcept
{
    // TODO
    return StreamError::UNKNOWN_ERROR;
}
StreamError translateWinError(DWORD error) noexcept
{
// TODO
(void) error;
return StreamError::UNKNOWN_ERROR;
}
// Convenience overload: translates the calling thread's GetLastError() value.
StreamError translateWinError() noexcept
{
    const DWORD lastError = GetLastError();
    return translateWinError(lastError);
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
} // namespace detail
// Blocking read from the socket.
//
// buffer       - destination; at most buffer.size() bytes are received.
// options      - translated to recv() flags via readFlags().
// outBytesRead - optional; if not null it receives the number of bytes that
//                recv() delivered.
//
// Returns SUCCESS when data was read (or when buffer is empty),
// CONNECTION_CLOSED when recv() reports 0 bytes for a non-empty buffer, and
// the translated errno when recv() fails.
StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(false);

    const long bytesRead = osRecv(handle_, buffer, readFlags(options));
    if (bytesRead < 0)
    {
        return translateErrno();
    }
    // The out-parameter is optional (c_readRaw() treats it the same way), so
    // it must not be dereferenced unconditionally.
    if (outBytesRead != nullptr)
    {
        *outBytesRead = static_cast<std::size_t>(bytesRead);
    }
    // Zero bytes for a non-empty request means the peer shut the connection
    // down; report it like c_readRaw() does instead of claiming success.
    if (bytesRead == 0 && !buffer.empty())
    {
        return StreamError::CONNECTION_CLOSED;
    }
    return StreamError::SUCCESS;
}
// Blocking write of the whole buffer to the socket.
//
// send() is allowed to accept only part of the data, so the call is repeated
// with the unsent remainder until everything has been handed to the OS.
//
// Returns SUCCESS once all bytes were sent (immediately for an empty buffer),
// CONNECTION_CLOSED if send() reports 0 bytes for a non-empty remainder, and
// the translated errno if send() fails.
StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(false);

    std::span<const std::uint8_t> remaining = buffer;
    while (!remaining.empty())
    {
        const long bytesSent = osSend(handle_, remaining, 0);
        if (bytesSent < 0)
        {
            return translateErrno();
        }
        if (bytesSent == 0)
        {
            // No progress is possible; bail out instead of spinning forever.
            return StreamError::CONNECTION_CLOSED;
        }
        remaining = remaining.subspan(static_cast<std::size_t>(bytesSent));
    }
    return StreamError::SUCCESS;
}
// Coroutine variant of readRaw(): puts the socket into non-blocking mode and
// polls recv(), suspending between attempts while no data is available.
//
// buffer       - destination; at most buffer.size() bytes are received.
// options      - translated to recv() flags via readFlags().
// outBytesRead - optional; if not null it receives the number of bytes read
//                (0 for an empty buffer).
//
// Completes with SUCCESS as soon as one recv() call delivered data (or
// immediately for an empty buffer), CONNECTION_CLOSED when recv() reports 0
// bytes, and the translated errno for any failure other than "would block".
// NOTE(review): the first successful recv() ends the read even when
// options.partial is false; confirm callers do not rely on a full buffer here.
mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setAsync(true);

    if (buffer.empty())
    {
        // Still report a byte count so the caller never reads an unset value.
        if (outBytesRead != nullptr)
        {
            *outBytesRead = 0;
        }
        co_return StreamError::SUCCESS;
    }

    // The flags do not change between attempts; computing them once also means
    // `options` is not accessed again after the first suspension.
    const int flags = readFlags(options);
    while (true)
    {
        const long bytesRead = osRecv(handle_, buffer, flags);
        if (bytesRead > 0)
        {
            if (outBytesRead != nullptr)
            {
                *outBytesRead = static_cast<std::size_t>(bytesRead);
            }
            co_return StreamError::SUCCESS;
        }
        if (bytesRead == 0)
        {
            co_return StreamError::CONNECTION_CLOSED;
        }
        // "No data yet" may be reported as either EAGAIN or EWOULDBLOCK.
        if (errno != EAGAIN && errno != EWOULDBLOCK)
        {
            co_return translateErrno();
        }
        co_await mijin::c_suspend();
    }
}
// Coroutine variant of writeRaw(): puts the socket into non-blocking mode and
// keeps calling send() until the whole buffer was accepted, suspending
// whenever the OS cannot take (more) data right now.
//
// A send() that accepts only part of the data advances past the bytes already
// sent. Previously such a partial send fell through to the errno check with a
// stale errno, which produced either a spurious error or a retransmission of
// the whole buffer from its start.
//
// Completes with SUCCESS once everything was sent (immediately for an empty
// buffer), CONNECTION_CLOSED if send() reports 0 bytes, and the translated
// errno for any failure other than "would block".
Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    if (buffer.empty())
    {
        co_return StreamError::SUCCESS;
    }
    setAsync(true);

    std::span<const std::uint8_t> remaining = buffer;
    while (true)
    {
        const long bytesSent = osSend(handle_, remaining, 0);
        if (bytesSent > 0)
        {
            remaining = remaining.subspan(static_cast<std::size_t>(bytesSent));
            if (remaining.empty())
            {
                co_return StreamError::SUCCESS;
            }
            // Partial send: fall through and give other tasks a turn before
            // trying to push the rest.
        }
        else if (bytesSent == 0)
        {
            co_return StreamError::CONNECTION_CLOSED;
        }
        else if (errno != EAGAIN && errno != EWOULDBLOCK)
        {
            co_return translateErrno();
        }
        co_await c_suspend();
    }
}
// Switches the socket between non-blocking (async == true) and blocking
// (async == false) mode. Does nothing if the cached mode already matches.
void TCPStream::setAsync(bool async)
{
    if (async == async_)
    {
        return;
    }
    // Only remember the new mode when the OS accepted the change. Otherwise
    // the cached flag would disagree with the real socket state and later
    // calls would never retry the switch.
    if (osSetSocketNonBlocking(handle_, async))
    {
        async_ = async;
    }
}
// This stream does not track a position (getFeatures() reports tell = false);
// the result is always 0.
std::size_t TCPStream::tell()
{
    const std::size_t position = 0;
    return position;
}
// Seeking is not available on this stream (getFeatures() reports seek = false);
// both arguments are ignored and NOT_SUPPORTED is returned.
StreamError TCPStream::seek(std::intptr_t /* pos */, SeekMode /* seekMode */)
{
    const StreamError result = StreamError::NOT_SUPPORTED;
    return result;
}
// Intentionally empty: writeRaw() and c_writeRaw() pass data straight to
// send(), so this class holds no buffered data to flush.
void TCPStream::flush()
{
    return;
}
bool TCPStream::isAtEnd()
{
return !isOpen();
}
// Describes what this stream supports: reading, writing and async operation,
// including partial and peeking reads, but neither tell() nor seek().
StreamFeatures TCPStream::getFeatures()
{
    StreamFeatures features = {
        .read = true,
        .write = true,
        .tell = false,
        .seek = false,
        .async = true,
        .readOptions = {
            .partial = true,
            .peek = true
        }
    };
    return features;
}
// Creates a TCP socket and connects it to address:port.
//
// address - IPv4 or IPv6 address of the peer.
// port    - TCP port in host byte order (converted with htons() here).
//
// The socket is created with the address family that matches the given
// address: AF_INET for IPv4, AF_INET6 for IPv6. Previously it was always
// AF_INET, so the IPv6 branch handed a sockaddr_in6 to an IPv4 socket.
//
// Returns SUCCESS when connected. On failure the socket is closed again,
// handle_ is reset to INVALID_SOCKET_HANDLE and the translated errno of the
// failing call is returned.
StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
{
    MIJIN_ASSERT(!isOpen(), "Socket is already open.");

    const int addressFamily = std::visit(Visitor{
        [](const IPv4Address&) -> int { return AF_INET; },
        [](const IPv6Address&) -> int { return AF_INET6; }
    }, address);

    handle_ = osCreateSocket(addressFamily, SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }

    const bool connected = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
            sockaddr_in connectAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
        },
        [&](IPv6Address address6)
        {
            // address6 is a by-value copy, so its hextets can be converted to
            // network byte order in place.
            for (std::uint16_t& hextet : address6.hextets) {
                hextet = htons(hextet);
            }
            sockaddr_in6 connectAddress =
            {
                .sin6_family = AF_INET6,
                .sin6_port = htons(port),
                .sin6_addr = std::bit_cast<in6_addr>(address6)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
        }
    }, address);

    if (!connected)
    {
        // Translate first: closing the socket may overwrite errno and would
        // hide the reason connect() failed.
        const StreamError error = translateErrno();
        osCloseSocket(handle_);
        handle_ = INVALID_SOCKET_HANDLE;
        return error;
    }
    return StreamError::SUCCESS;
}
// Closes the socket and marks the stream as not open by resetting handle_ to
// INVALID_SOCKET_HANDLE. Asserts that the stream is open.
void TCPStream::close() noexcept
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    const auto closingHandle = handle_;
    handle_ = INVALID_SOCKET_HANDLE;
    osCloseSocket(closingHandle);
}
// Gives access to the stream owned by this socket.
TCPStream& TCPSocket::getStream() noexcept
{
    return this->stream_;
}
// Creates a listening TCP socket bound to address:port.
//
// address - IPv4 or IPv6 address to bind to.
// port    - TCP port in host byte order (converted with htons() here).
//
// Steps: create the socket with the address family matching `address`
// (previously always AF_INET, which made the IPv6 branch bind a sockaddr_in6
// to an IPv4 socket), enable SO_REUSEADDR, bind, listen with LISTEN_BACKLOG
// and switch the socket to non-blocking mode so c_waitForConnection() can
// poll accept().
//
// Returns SUCCESS when the socket is listening. On failure the socket is
// closed again and the translated errno of the failing call is returned.
StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept
{
    MIJIN_ASSERT(!isListening(), "Socket is already listening.");

    const int addressFamily = std::visit(Visitor{
        [](const IPv4Address&) -> int { return AF_INET; },
        [](const IPv6Address&) -> int { return AF_INET6; }
    }, address);

    handle_ = osCreateSocket(addressFamily, SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }

    // Pass the real size of SOCKOPT_ONE: it is an int on Linux but a char on
    // Windows, where sizeof(int) would read past the object.
    if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(SOCKOPT_ONE)) != 0)
    {
        // Translate before close(): closing may overwrite errno.
        const StreamError error = translateErrno();
        close();
        return error;
    }

    const bool bound = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
            sockaddr_in bindAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
        },
        [&](IPv6Address address6)
        {
            // address6 is a by-value copy, so its hextets can be converted to
            // network byte order in place.
            for (std::uint16_t& hextet : address6.hextets) {
                hextet = htons(hextet);
            }
            sockaddr_in6 bindAddress =
            {
                .sin6_family = AF_INET6,
                .sin6_port = htons(port),
                .sin6_addr = std::bit_cast<in6_addr>(address6)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
        }
    }, address);

    if (!bound
        || (listen(handle_, LISTEN_BACKLOG) < 0)
        || !osSetSocketNonBlocking(handle_, true))
    {
        // Translate before close(): closing may overwrite errno.
        const StreamError error = translateErrno();
        close();
        return error;
    }
    return StreamError::SUCCESS;
}
// Closes the listening socket and resets handle_ to INVALID_SOCKET_HANDLE.
// Asserts that the socket is currently listening.
void TCPServerSocket::close() noexcept
{
    MIJIN_ASSERT(isListening(), "Socket is not listening.");

    const auto closingHandle = handle_;
    handle_ = INVALID_SOCKET_HANDLE;
    osCloseSocket(closingHandle);
}
// Coroutine that waits for the next incoming connection.
//
// Polls accept() on the listening socket, suspending between attempts while
// no connection is pending. Completes with a new TCPSocket wrapping the
// accepted handle, with the translated errno if accept() fails for any reason
// other than "would block", or with CONNECTION_CLOSED once the server socket
// is no longer listening.
Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept
{
    while (isListening())
    {
        // The peer address is not used, so none is requested. This also
        // avoids handing accept() an IPv4-sized buffer on an IPv6 listener.
        const socket_handle_t newSocket = accept(handle_, nullptr, nullptr);
        if (!osIsSocketValid(newSocket))
        {
            // "Nothing pending" may be reported as either EAGAIN or EWOULDBLOCK.
            if (errno != EAGAIN && errno != EWOULDBLOCK)
            {
                co_return translateErrno();
            }
            co_await c_suspend();
            continue;
        }
        std::unique_ptr<TCPSocket> socket = std::make_unique<TCPSocket>();
        socket->stream_.handle_ = newSocket;
        co_return socket;
    }
    co_return StreamError::CONNECTION_CLOSED;
}
}