#include "./socket.hpp"

#include "./detail/net_common.hpp"
#include "../detect.hpp"
#include "../util/variant.hpp"

// NOTE(review): the reviewed text of this file had lost its template argument lists
// (span element types, cast targets, Task/StreamResult parameters). They have been
// filled in from how the values are used below — verify them against socket.hpp.

namespace mijin
{
namespace
{
// Backlog passed to listen() in TCPServerSocket::setup().
inline constexpr int LISTEN_BACKLOG = 3;

// Option value handed to setsockopt() to switch a boolean socket option on.
// Kept int-sized on every platform so that sizeof(SOCKOPT_ONE) always matches the
// object whose address is passed.
inline constexpr int SOCKOPT_ONE = 1;

// Maps the current value of errno to a StreamError.
// Only EIO has a dedicated mapping, everything else becomes UNKNOWN_ERROR.
// NOTE(review): Winsock reports failures through WSAGetLastError() rather than errno,
// so on Windows this presumably always yields UNKNOWN_ERROR — consider routing through
// detail::translateWSAError() once that is implemented.
StreamError translateErrno() noexcept
{
    switch (errno)
    {
        case EIO:
            return StreamError::IO_ERROR;
        default:
            return StreamError::UNKNOWN_ERROR;
    }
}

// Builds the flags argument for recv() from the stream read options:
// non-partial reads request MSG_WAITALL, peeking reads add MSG_PEEK.
int readFlags(const ReadOptions& options)
{
    return (options.partial ? 0 : MSG_WAITALL)
         | (options.peek ? MSG_PEEK : 0);
}

// Selects the socket address family that matches the alternative held by the address.
int addressFamily(const ip_address_t& address) noexcept
{
    return std::holds_alternative<IPv6Address>(address) ? AF_INET6 : AF_INET;
}

#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
// Sets the given file status flags on the descriptor in addition to the ones already set.
// Returns false if either fcntl() call fails.
bool appendSocketFlags(int handle, int flags) noexcept
{
    const int currentFlags = fcntl(handle, F_GETFL);
    if (currentFlags < 0)
    {
        return false;
    }
    return fcntl(handle, F_SETFL, currentFlags | flags) >= 0;
}

// Clears the given file status flags on the descriptor, leaving the others untouched.
// Returns false if either fcntl() call fails.
bool removeSocketFlags(int handle, int flags) noexcept
{
    const int currentFlags = fcntl(handle, F_GETFL);
    if (currentFlags < 0)
    {
        return false;
    }
    return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0;
}

// Thin wrapper around recv(); returns its result (negative on failure).
long osRecv(int socket, std::span<std::uint8_t> buffer, int flags)
{
    return static_cast<long>(recv(socket, buffer.data(), buffer.size(), flags));
}

// Thin wrapper around send(); returns its result (negative on failure).
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags)
{
    return static_cast<long>(send(socket, buffer.data(), buffer.size(), flags));
}

int osCreateSocket(int domain, int type, int protocol)
{
    return socket(domain, type, protocol);
}

int osCloseSocket(int socket)
{
    return ::close(socket);
}

bool osIsSocketValid(int socket)
{
    return socket >= 0;
}

// Enables (nonBlocking == true) or disables O_NONBLOCK on the socket.
// Returns whether the mode could be changed.
bool osSetSocketNonBlocking(int socket, bool nonBlocking)
{
    if (nonBlocking)
    {
        return appendSocketFlags(socket, O_NONBLOCK);
    }
    return removeSocketFlags(socket, O_NONBLOCK);
}

// Whether the last failed socket call only failed because it would have blocked.
bool osWouldBlock() noexcept
{
    return errno == EAGAIN;
}
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
using in_addr_t = ULONG;

// Set by detail::initWSA() once WSAStartup() succeeded on this thread.
thread_local bool gWsaInited = false;

// Calls WSACleanup() on destruction if WSAStartup() was called through detail::initWSA().
class WSAGuard
{
public:
    ~WSAGuard() noexcept
    {
        if (gWsaInited)
        {
            WSACleanup();
        }
    }
} thread_local [[maybe_unused]] gWsaGuard;

// Thin wrapper around recv(); returns its result (negative on failure).
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags)
{
    return recv(socket, reinterpret_cast<char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
}

// Thin wrapper around send(); returns its result (negative on failure).
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags)
{
    return send(socket, reinterpret_cast<const char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
}

// Creates a socket, initializing Winsock first if that has not happened yet.
SOCKET osCreateSocket(int addressFamily, int type, int protocol)
{
    if (!detail::initWSA())
    {
        return INVALID_SOCKET_HANDLE;
    }
    return socket(addressFamily, type, protocol);
}

int osCloseSocket(SOCKET socket)
{
    return closesocket(socket);
}

bool osIsSocketValid(SOCKET socket)
{
    return socket != INVALID_SOCKET;
}

// Enables (nonBlocking == true) or disables non-blocking mode on the socket.
// FIONBIO takes a non-zero value to ENABLE non-blocking mode, matching the Linux variant
// and what setAsync()/setup() expect.
bool osSetSocketNonBlocking(SOCKET socket, bool nonBlocking)
{
    u_long value = nonBlocking ? 1 : 0;
    return ioctlsocket(socket, FIONBIO, &value) == NO_ERROR;
}

// Whether the last failed socket call only failed because it would have blocked.
// Winsock does not set errno, the error has to be queried via WSAGetLastError().
bool osWouldBlock() noexcept
{
    return WSAGetLastError() == WSAEWOULDBLOCK;
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
}

namespace detail
{
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
// Calls WSAStartup() (requesting version 2.2) unless it already succeeded on this thread.
// Returns false if WSAStartup() fails.
bool initWSA() noexcept
{
    if (gWsaInited)
    {
        return true;
    }
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
    {
        return false;
    }
    gWsaInited = true;
    return true;
}

StreamError translateWSAError() noexcept
{
    // TODO
    return StreamError::UNKNOWN_ERROR;
}

StreamError translateWinError(DWORD error) noexcept
{
    // TODO
    (void) error;
    return StreamError::UNKNOWN_ERROR;
}

StreamError translateWinError() noexcept
{
    return translateWinError(GetLastError());
}
#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
} // namespace detail

// Blocking read into buffer.
// On success stores the number of bytes received in *outBytesRead, if outBytesRead is not null.
StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    setAsync(false);

    const long bytesRead = osRecv(handle_, buffer, readFlags(options));
    if (bytesRead < 0)
    {
        return translateErrno();
    }
    // The out-parameter is optional (c_readRaw() treats it the same way).
    if (outBytesRead != nullptr)
    {
        *outBytesRead = static_cast<std::size_t>(bytesRead);
    }

    return StreamError::SUCCESS;
}

// Blocking write of buffer.
StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    setAsync(false);

    if (osSend(handle_, buffer, 0) < 0)
    {
        return translateErrno();
    }

    return StreamError::SUCCESS;
}

// Coroutine variant of readRaw(): polls the non-blocking socket and suspends while no data
// is available. Returns CONNECTION_CLOSED if recv() reports 0 bytes for a non-empty buffer.
mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    setAsync(true);

    if (buffer.empty())
    {
        co_return StreamError::SUCCESS;
    }

    while (true)
    {
        const long bytesRead = osRecv(handle_, buffer, readFlags(options));
        if (bytesRead > 0)
        {
            if (outBytesRead != nullptr)
            {
                *outBytesRead = static_cast<std::size_t>(bytesRead);
            }
            co_return StreamError::SUCCESS;
        }
        if (bytesRead == 0)
        {
            co_return StreamError::CONNECTION_CLOSED;
        }
        if (!osWouldBlock())
        {
            co_return translateErrno();
        }
        co_await mijin::c_suspend();
    }
}

// Coroutine variant of writeRaw(): keeps sending on the non-blocking socket until the whole
// buffer has been handed to the OS, suspending whenever the socket cannot take more data.
Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    if (buffer.empty())
    {
        co_return StreamError::SUCCESS;
    }

    setAsync(true);

    while (true)
    {
        const long bytesSent = osSend(handle_, buffer, 0);
        if (bytesSent > 0)
        {
            // A non-blocking send may accept only part of the data. Drop what has been
            // sent so the next attempt continues with the remainder instead of
            // transmitting the beginning of the buffer a second time.
            buffer = buffer.subspan(static_cast<std::size_t>(bytesSent));
            if (buffer.empty())
            {
                co_return StreamError::SUCCESS;
            }
        }
        else if (bytesSent == 0)
        {
            co_return StreamError::CONNECTION_CLOSED;
        }
        else if (!osWouldBlock())
        {
            co_return translateErrno();
        }
        co_await c_suspend();
    }
}

// Switches the socket between blocking (async == false) and non-blocking (async == true) mode.
// Does nothing if the socket is already in the requested mode.
void TCPStream::setAsync(bool async)
{
    if (async == async_)
    {
        return;
    }
    // Only remember the new mode if the OS actually applied it, so a failed switch is
    // retried on the next call instead of being silently assumed.
    if (osSetSocketNonBlocking(handle_, async))
    {
        async_ = async;
    }
}

// Sockets have no position, always returns 0.
std::size_t TCPStream::tell()
{
    return 0;
}

// Seeking is not supported on sockets.
StreamError TCPStream::seek(std::intptr_t /* pos */, SeekMode /* seekMode */)
{
    return StreamError::NOT_SUPPORTED;
}

void TCPStream::flush()
{
}

// The stream counts as "at end" once it is no longer open.
bool TCPStream::isAtEnd()
{
    return !isOpen();
}

StreamFeatures TCPStream::getFeatures()
{
    return {
        .read = true,
        .write = true,
        .tell = false,
        .seek = false,
        .async = true,
        .readOptions = {
            .partial = true,
            .peek = true
        }
    };
}

// Creates a TCP socket matching the address family of address and connects it to address:port.
// On failure the socket is closed again and the translated error is returned.
StreamError TCPStream::open(ip_address_t address, std::uint16_t port) noexcept
{
    MIJIN_ASSERT(!isOpen(), "Socket is already open.");

    // The socket family has to match the sockaddr type passed to connect() below.
    handle_ = osCreateSocket(addressFamily(address), SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }

    const bool connected = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
            sockaddr_in connectAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
        },
        [&](IPv6Address address6)
        {
            // Taken by value: the hextets are converted to network byte order in place.
            for (std::uint16_t& hextet : address6.hextets)
            {
                hextet = htons(hextet);
            }
            sockaddr_in6 connectAddress =
            {
                .sin6_family = AF_INET6,
                .sin6_port = htons(port),
                .sin6_addr = std::bit_cast<in6_addr>(address6)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
        }
    }, address);

    if (!connected)
    {
        // Translate first: closing the socket may overwrite the error of connect().
        const StreamError error = translateErrno();
        osCloseSocket(handle_);
        handle_ = INVALID_SOCKET_HANDLE;
        return error;
    }

    return StreamError::SUCCESS;
}

// Closes the socket and marks the stream as not open.
void TCPStream::close() noexcept
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    osCloseSocket(handle_);
    handle_ = INVALID_SOCKET_HANDLE;
}

TCPStream& TCPSocket::getStream() noexcept
{
    return stream_;
}

// Creates a TCP socket matching the address family of address, enables SO_REUSEADDR, binds it
// to address:port, starts listening and switches it to non-blocking mode.
// On any failure the socket is closed again and the translated error is returned.
StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) noexcept
{
    MIJIN_ASSERT(!isListening(), "Socket is already listening.");

    // The socket family has to match the sockaddr type passed to bind() below.
    handle_ = osCreateSocket(addressFamily(address), SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }

    // The cast is required by the Windows signature (const char*) and converts implicitly
    // to the const void* expected elsewhere.
    if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&SOCKOPT_ONE), sizeof(SOCKOPT_ONE)) != 0)
    {
        // Translate first: closing the socket may overwrite the error of setsockopt().
        const StreamError error = translateErrno();
        close();
        return error;
    }

    const bool bound = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
            sockaddr_in bindAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
        },
        [&](IPv6Address address6)
        {
            // Taken by value: the hextets are converted to network byte order in place.
            for (std::uint16_t& hextet : address6.hextets)
            {
                hextet = htons(hextet);
            }
            sockaddr_in6 bindAddress =
            {
                .sin6_family = AF_INET6,
                .sin6_port = htons(port),
                .sin6_addr = std::bit_cast<in6_addr>(address6)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
        }
    }, address);

    if (!bound
        || (listen(handle_, LISTEN_BACKLOG) < 0)
        || !osSetSocketNonBlocking(handle_, true))
    {
        // Translate first: closing the socket may overwrite the original error.
        const StreamError error = translateErrno();
        close();
        return error;
    }

    return StreamError::SUCCESS;
}

// Closes the listening socket.
void TCPServerSocket::close() noexcept
{
    MIJIN_ASSERT(isListening(), "Socket is not listening.");

    osCloseSocket(handle_);
    handle_ = INVALID_SOCKET_HANDLE;
}

// Polls the (non-blocking) listening socket for an incoming connection, suspending while
// there is none. Returns the accepted connection wrapped in a TCPSocket, the translated
// error if accept() fails, or CONNECTION_CLOSED once the socket is no longer listening.
// NOTE(review): the element type of the returned unique_ptr is assumed to be the Socket
// base class — confirm against the declaration in socket.hpp.
Task<StreamResult<std::unique_ptr<Socket>>> TCPServerSocket::c_waitForConnection() noexcept
{
    while (isListening())
    {
        // sockaddr_storage is large enough for both IPv4 and IPv6 peers.
        sockaddr_storage client = {};
        socklen_t clientLength = sizeof(client);
        const socket_handle_t newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &clientLength);
        if (!osIsSocketValid(newSocket))
        {
            if (!osWouldBlock())
            {
                co_return translateErrno();
            }
            co_await c_suspend();
            continue;
        }

        std::unique_ptr<TCPSocket> socket = std::make_unique<TCPSocket>();
        socket->stream_.handle_ = newSocket;
        co_return socket;
    }
    co_return StreamError::CONNECTION_CLOSED;
}
}