486 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			486 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
 | 
						|
#include "./socket.hpp"
 | 
						|
 | 
						|
#include "./detail/net_common.hpp"
 | 
						|
#include "../detect.hpp"
 | 
						|
#include "../internal/common.hpp"
 | 
						|
#include "../util/variant.hpp"
 | 
						|
 | 
						|
#if MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 | 
						|
#pragma clang diagnostic push
 | 
						|
#pragma clang diagnostic ignored "-Wmissing-field-initializers"
 | 
						|
#endif // MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 | 
						|
 | 
						|
namespace mijin
 | 
						|
{
 | 
						|
namespace
 | 
						|
{
 | 
						|
// Backlog value handed to listen() for server sockets.
constexpr int LISTEN_BACKLOG{3};
 | 
						|
// Maps the calling thread's current errno value to the closest StreamError.
// ECONNRESET (peer reset the connection) and EPIPE (writing to a connection the
// peer already shut down) are reported as CONNECTION_CLOSED. Any code not listed
// explicitly becomes UNKNOWN_ERROR.
StreamError translateErrno() MIJIN_NOEXCEPT
{
    switch (errno)
    {
    case EIO:
        return StreamError::IO_ERROR;
    case ECONNREFUSED:
        return StreamError::CONNECTION_REFUSED;
    case ECONNRESET:
    case EPIPE:
        return StreamError::CONNECTION_CLOSED;
    default:
        return StreamError::UNKNOWN_ERROR;
    }
}
 | 
						|
 | 
						|
int readFlags(const ReadOptions& options)
 | 
						|
{
 | 
						|
    return (options.partial ? 0 : MSG_WAITALL)
 | 
						|
         | (options.peek ? MSG_PEEK : 0);
 | 
						|
}
 | 
						|
 | 
						|
#if MIJIN_TARGET_OS == MIJIN_OS_LINUX
 | 
						|
// "Enable" value for boolean socket options; setsockopt() receives its address (int-sized on POSIX).
constexpr int SOCKOPT_ONE{1};
 | 
						|
 | 
						|
// Sets `flags` in the file status flags of `handle` using an fcntl() read-modify-write.
// Returns false when either the F_GETFL or the F_SETFL call fails.
bool appendSocketFlags(int handle, int flags) MIJIN_NOEXCEPT
{
    if (const int current = fcntl(handle, F_GETFL); current >= 0)
    {
        return fcntl(handle, F_SETFL, current | flags) >= 0;
    }
    return false;
}
 | 
						|
 | 
						|
// Clears `flags` from the file status flags of `handle` using an fcntl() read-modify-write.
// Returns false when either the F_GETFL or the F_SETFL call fails.
bool removeSocketFlags(int handle, int flags) MIJIN_NOEXCEPT
{
    const int current = fcntl(handle, F_GETFL);
    // Short-circuit: F_SETFL is only attempted when F_GETFL succeeded.
    return current >= 0 && fcntl(handle, F_SETFL, current & ~flags) >= 0;
}
 | 
						|
 | 
						|
// Thin wrapper around recv(): returns the number of bytes received, 0 when the peer
// has shut down the connection, or a negative value on error (errno is set).
long osRecv(int socket, std::span<std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
{
    const ssize_t received = recv(socket, buffer.data(), buffer.size(), flags);
    return static_cast<long>(received);
}
 | 
						|
 | 
						|
// Thin wrapper around send(): returns the number of bytes accepted by the kernel
// (possibly fewer than requested) or a negative value on error (errno is set).
// MSG_NOSIGNAL is always added. Without it, sending on a connection the peer has
// already closed raises SIGPIPE, whose default action terminates the whole process.
// With it, the call simply fails with EPIPE.
long osSend(int socket, std::span<const std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
{
    return static_cast<long>(send(socket, buffer.data(), buffer.size(), flags | MSG_NOSIGNAL));
}
 | 
						|
 | 
						|
// Creates a socket descriptor; a negative result signals failure (errno is set).
int osCreateSocket(int domain, int type, int protocol) MIJIN_NOEXCEPT
{
    const int descriptor = ::socket(domain, type, protocol);
    return descriptor;
}
 | 
						|
 | 
						|
// Closes a socket descriptor; forwards close()'s result (0 on success, -1 on error).
int osCloseSocket(int socket) MIJIN_NOEXCEPT
{
    const int result = ::close(socket);
    return result;
}
 | 
						|
 | 
						|
// POSIX socket descriptors are non-negative; socket()/accept() return -1 on failure.
bool osIsSocketValid(int socket) MIJIN_NOEXCEPT
{
    return socket > -1;
}
 | 
						|
 | 
						|
// Switches `socket` into non-blocking mode (`nonBlocking == true`, O_NONBLOCK set)
// or back into blocking mode (O_NONBLOCK cleared).
// Returns false if the underlying fcntl() calls fail.
// The parameter was previously named `blocking` although `true` enables O_NONBLOCK.
// The new name matches the actual behavior and what callers pass.
bool osSetSocketNonBlocking(int socket, bool nonBlocking) MIJIN_NOEXCEPT
{
    return nonBlocking
        ? appendSocketFlags(socket, O_NONBLOCK)
        : removeSocketFlags(socket, O_NONBLOCK);
}
 | 
						|
#elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | 
						|
// Presumably mirrors the POSIX in_addr_t typedef for shared code. NOTE(review): not referenced in this file.
using in_addr_t = ULONG;

// "Enable" value for boolean socket options; Winsock's setsockopt() takes a const char* option value.
constexpr char SOCKOPT_ONE{1};
// Set by detail::initWSA() once WSAStartup() has succeeded on the current thread.
thread_local bool gWsaInited{false};
 | 
						|
 | 
						|
class WSAGuard
 | 
						|
{
 | 
						|
public:
 | 
						|
    ~WSAGuard() MIJIN_NOEXCEPT
 | 
						|
    {
 | 
						|
        if (gWsaInited)
 | 
						|
        {
 | 
						|
            WSACleanup();
 | 
						|
        }
 | 
						|
    }
 | 
						|
} thread_local [[maybe_unused]] gWsaGuard;
 | 
						|
 | 
						|
// Thin wrapper around Winsock recv(); returns its result widened to long.
long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
{
    // Winsock takes a char* buffer and an int length.
    char* const data = reinterpret_cast<char*>(buffer.data());
    const int length = static_cast<int>(buffer.size());
    return recv(socket, data, length, flags);
}
 | 
						|
 | 
						|
// Thin wrapper around Winsock send(); returns its result widened to long.
long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
{
    // Winsock takes a const char* buffer and an int length.
    const char* const data = reinterpret_cast<const char*>(buffer.data());
    const int length = static_cast<int>(buffer.size());
    return send(socket, data, length, flags);
}
 | 
						|
 | 
						|
// Creates a Winsock socket, starting Winsock on this thread first if necessary.
// Returns INVALID_SOCKET_HANDLE when Winsock could not be started.
SOCKET osCreateSocket(int addressFamily, int type, int protocol) MIJIN_NOEXCEPT
{
    if (detail::initWSA())
    {
        return socket(addressFamily, type, protocol);
    }
    return INVALID_SOCKET_HANDLE;
}
 | 
						|
 | 
						|
// Closes a Winsock socket; forwards closesocket()'s result.
int osCloseSocket(SOCKET socket) MIJIN_NOEXCEPT
{
    const int result = closesocket(socket);
    return result;
}
 | 
						|
 | 
						|
// A Winsock handle is valid unless it equals INVALID_SOCKET.
bool osIsSocketValid(SOCKET socket) MIJIN_NOEXCEPT
{
    return !(socket == INVALID_SOCKET);
}
 | 
						|
 | 
						|
// Switches `socket` into non-blocking mode (`nonBlocking == true`) or back into blocking
// mode via ioctlsocket(FIONBIO), where a non-zero argument enables non-blocking mode.
// This matches the POSIX implementation and all callers (setNoblock(true) and the
// listening socket setup both expect `true` to mean non-blocking). The previous version
// passed 0 for `true`, which selected blocking mode.
// Returns false if ioctlsocket() fails.
bool osSetSocketNonBlocking(SOCKET socket, bool nonBlocking) MIJIN_NOEXCEPT
{
    u_long mode = nonBlocking ? 1 : 0;
    return ioctlsocket(socket, FIONBIO, &mode) == NO_ERROR;
}
 | 
						|
#endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
 | 
						|
}
 | 
						|
 | 
						|
namespace detail
 | 
						|
{
 | 
						|
#if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | 
						|
// Starts Winsock 2.2 on the calling thread the first time it is needed.
// Returns true if Winsock is (now) started on this thread. After a failed
// WSAStartup() the flag stays false, so the next call tries again.
bool initWSA() MIJIN_NOEXCEPT
{
    if (!gWsaInited)
    {
        WSADATA wsaData;
        gWsaInited = (WSAStartup(MAKEWORD(2, 2), &wsaData) == 0);
    }
    return gWsaInited;
}
 | 
						|
 | 
						|
// Maps the calling thread's last Winsock error (WSAGetLastError()) to a StreamError.
// Mirrors translateErrno() for connection refusal. Every other code is still reported
// as UNKNOWN_ERROR.
StreamError translateWSAError() MIJIN_NOEXCEPT
{
    switch (WSAGetLastError())
    {
    case WSAECONNREFUSED:
        return StreamError::CONNECTION_REFUSED;
    default:
        // TODO: map further WSAE* codes.
        return StreamError::UNKNOWN_ERROR;
    }
}
 | 
						|
 | 
						|
// Maps a Win32 error code to a StreamError. Not implemented yet: every code
// currently yields UNKNOWN_ERROR.
StreamError translateWinError(DWORD error) MIJIN_NOEXCEPT
{
    static_cast<void>(error); // TODO: map Win32 error codes
    return StreamError::UNKNOWN_ERROR;
}
 | 
						|
 | 
						|
// Translates the calling thread's GetLastError() value via translateWinError(DWORD).
StreamError translateWinError() MIJIN_NOEXCEPT
{
    const DWORD lastError = GetLastError();
    return translateWinError(lastError);
}
 | 
						|
#endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | 
						|
}// namespace impl
 | 
						|
 | 
						|
// Synchronous read into `buffer`. The socket is first switched to the blocking mode
// requested by options.noBlock.
// - An empty buffer succeeds immediately with 0 bytes read.
// - An orderly shutdown by the peer (recv() returning 0 for a non-empty buffer) is
//   reported as CONNECTION_CLOSED, the same as c_readRaw(). Without this, EOF would be
//   indistinguishable from "no data available yet" in non-blocking mode.
// - With options.noBlock, a read that would block (EAGAIN) succeeds with 0 bytes read.
// On success the number of bytes read is stored in *outBytesRead (if non-null).
StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setNoblock(options.noBlock);

    if (buffer.empty())
    {
        if (outBytesRead != nullptr)
        {
            *outBytesRead = 0;
        }
        return StreamError::SUCCESS;
    }

    const long bytesRead = osRecv(handle_, buffer, readFlags(options));
    if (bytesRead == 0)
    {
        return StreamError::CONNECTION_CLOSED;
    }
    if (bytesRead < 0)
    {
        // NOTE(review): errno is only meaningful on POSIX; Winsock reports failures via WSAGetLastError().
        if (!options.noBlock || errno != EAGAIN)
        {
            return translateErrno();
        }
        if (outBytesRead != nullptr)
        {
            *outBytesRead = 0;
        }
        return StreamError::SUCCESS;
    }

    if (outBytesRead != nullptr)
    {
        *outBytesRead = static_cast<std::size_t>(bytesRead);
    }
    return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
// Synchronous write of the whole of `buffer` with the socket in blocking mode.
// send() may accept fewer bytes than requested, so the remainder is sent again
// until everything has been written or an error occurs.
// A send() that makes no progress on a non-empty remainder is reported as
// CONNECTION_CLOSED, the same as c_writeRaw().
StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setNoblock(false);

    std::span<const std::uint8_t> remaining = buffer;
    while (!remaining.empty())
    {
        const long bytesSent = osSend(handle_, remaining, 0);
        if (bytesSent < 0)
        {
            return translateErrno();
        }
        if (bytesSent == 0)
        {
            return StreamError::CONNECTION_CLOSED;
        }
        remaining = remaining.subspan(static_cast<std::size_t>(bytesSent));
    }

    return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
// Coroutine read into `buffer`. The socket is switched to non-blocking mode, and the
// coroutine suspends (c_suspend()) whenever no data is available yet (EAGAIN).
// - options.partial: succeed as soon as any data has arrived.
// - otherwise: keep going until the whole buffer is filled. A non-blocking recv() can
//   return fewer bytes than requested even with MSG_WAITALL, so received bytes are
//   accumulated across attempts. When peeking, nothing is consumed, so each attempt
//   re-peeks the full buffer until enough data is queued.
// On success, *outBytesRead (if non-null) receives the byte count. If the peer closes
// the connection first, CONNECTION_CLOSED is returned and *outBytesRead receives the
// number of bytes already stored in `buffer`. An empty buffer succeeds immediately.
mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");
    setNoblock(true);

    if (buffer.empty())
    {
        co_return StreamError::SUCCESS;
    }

    const int flags = readFlags(options);
    std::size_t totalRead = 0;
    while (true)
    {
        const std::span<std::uint8_t> target = options.peek ? buffer : buffer.subspan(totalRead);
        const long bytesRead = osRecv(handle_, target, flags);
        if (bytesRead > 0)
        {
            const auto received = static_cast<std::size_t>(bytesRead);
            totalRead = options.peek ? received : totalRead + received;
            if (options.partial || totalRead == buffer.size())
            {
                if (outBytesRead != nullptr)
                {
                    *outBytesRead = totalRead;
                }
                co_return StreamError::SUCCESS;
            }
        }
        else if (bytesRead == 0)
        {
            if (outBytesRead != nullptr)
            {
                *outBytesRead = totalRead;
            }
            co_return StreamError::CONNECTION_CLOSED;
        }
        else if (errno != EAGAIN)
        {
            co_return translateErrno();
        }
        co_await mijin::c_suspend();
    }
}
 | 
						|
 | 
						|
// Coroutine write of the whole of `buffer` with the socket in non-blocking mode.
// send() may accept only part of the data. The already-sent prefix is skipped, and the
// remainder is sent on later attempts. The coroutine suspends whenever the socket would
// block (EAGAIN). errno is only consulted after send() actually failed, never after a
// successful partial send (errno is stale then).
Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    if (buffer.empty())
    {
        co_return StreamError::SUCCESS;
    }

    setNoblock(true);

    std::span<const std::uint8_t> remaining = buffer;
    while (true)
    {
        const long bytesSent = osSend(handle_, remaining, 0);
        if (bytesSent > 0)
        {
            remaining = remaining.subspan(static_cast<std::size_t>(bytesSent));
            if (remaining.empty())
            {
                co_return StreamError::SUCCESS;
            }
            continue; // try to push the rest right away; EAGAIN will suspend
        }
        if (bytesSent == 0)
        {
            co_return StreamError::CONNECTION_CLOSED;
        }
        if (errno != EAGAIN)
        {
            co_return translateErrno();
        }
        co_await c_suspend();
    }
}
 | 
						|
 | 
						|
void TCPStream::setNoblock(bool async)
 | 
						|
{
 | 
						|
    if (async == async_)
 | 
						|
    {
 | 
						|
        return;
 | 
						|
    }
 | 
						|
    async_ = async;
 | 
						|
 | 
						|
    osSetSocketNonBlocking(handle_, async);
 | 
						|
}
 | 
						|
 | 
						|
// A TCP stream has no position; always reports 0 (getFeatures() advertises tell = false).
std::size_t TCPStream::tell()
{
    return std::size_t{0};
}
 | 
						|
 | 
						|
// Seeking is impossible on a TCP stream; both arguments are ignored.
StreamError TCPStream::seek(std::intptr_t, SeekMode)
{
    return StreamError::NOT_SUPPORTED;
}
 | 
						|
 | 
						|
// No-op: writeRaw()/c_writeRaw() pass data straight to send(), so this class holds no write buffer.
void TCPStream::flush()
{
    // intentionally empty
}
 | 
						|
 | 
						|
bool TCPStream::isAtEnd()
 | 
						|
{
 | 
						|
    return !isOpen();
 | 
						|
}
 | 
						|
 | 
						|
// TCP streams can be read, written and used asynchronously (coroutines), and support
// partial and peeking reads. They have no position, so tell/seek are unsupported.
StreamFeatures TCPStream::getFeatures()
{
    StreamFeatures features = {
        .read = true,
        .write = true,
        .tell = false,
        .seek = false,
        .async = true,
        .readOptions = {.partial = true, .peek = true}
    };
    return features;
}
 | 
						|
 | 
						|
// Connects the stream to `address`:`port`.
// The socket's address family follows the address type (AF_INET for IPv4Address,
// AF_INET6 for IPv6Address), so the sockaddr passed to connect() matches the socket.
// An AF_INET socket can never connect to a sockaddr_in6.
// On failure the socket is closed again, the stream stays closed, and the errno-based
// error is returned. The error is captured before closing because the close call may
// overwrite errno.
StreamError TCPStream::open(ip_address_t address, std::uint16_t port) MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(!isOpen(), "Socket is already open.");

    const int addressFamily = std::holds_alternative<IPv6Address>(address) ? AF_INET6 : AF_INET;
    handle_ = osCreateSocket(addressFamily, SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }

    const bool connected = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
            sockaddr_in connectAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
        },
        [&](IPv6Address address6)
        {
            // Convert each 16-bit group to network byte order before reinterpreting the address.
            for (std::uint16_t& hextet : address6.hextets) {
                hextet = htons(hextet);
            }
            sockaddr_in6 connectAddress =
            {
                .sin6_family = AF_INET6,
                .sin6_port = htons(port),
                .sin6_addr = std::bit_cast<in6_addr>(address6)
            };
            return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
        }
    }, address);
    if (!connected)
    {
        const StreamError error = translateErrno();
        osCloseSocket(handle_);
        handle_ = INVALID_SOCKET_HANDLE;
        return error;
    }

    return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
// Closes the connection: handle_ is reset to INVALID_SOCKET_HANDLE and the old handle
// is released. The result of the OS close call is ignored.
void TCPStream::close() MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(isOpen(), "Socket is not open.");

    const auto closingHandle = handle_;
    handle_ = INVALID_SOCKET_HANDLE;
    osCloseSocket(closingHandle);
}
 | 
						|
 | 
						|
// Returns the stream that carries this socket's data.
TCPStream& TCPSocket::getStream() MIJIN_NOEXCEPT
{
    TCPStream& stream = stream_;
    return stream;
}
 | 
						|
 | 
						|
// Creates a listening TCP socket bound to `address`:`port`.
// - The address family follows the address type (AF_INET / AF_INET6), so bind() gets a
//   sockaddr that matches the socket.
// - SO_REUSEADDR is enabled. The option length is taken from SOCKOPT_ONE itself (int on
//   Linux, char on Windows), so setsockopt() never reads past the object.
// - listen() is called with LISTEN_BACKLOG, and the socket is switched to non-blocking
//   mode, which c_waitForConnection() relies on (it polls and suspends on EAGAIN).
// On any failure the socket is closed again and the errno-based error is returned. The
// error is captured before closing, because the close call may overwrite errno.
StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(!isListening(), "Socket is already listening.");

    const int addressFamily = std::holds_alternative<IPv6Address>(address) ? AF_INET6 : AF_INET;
    handle_ = osCreateSocket(addressFamily, SOCK_STREAM, 0);
    if (!osIsSocketValid(handle_))
    {
        return translateErrno();
    }
    if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(SOCKOPT_ONE)) != 0)
    {
        const StreamError error = translateErrno();
        close();
        return error;
    }

    const bool bound = std::visit(Visitor{
        [&](const IPv4Address& address4)
        {
#if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
#error "TODO: swap byte order of the address"
#endif
            sockaddr_in bindAddress =
            {
                .sin_family = AF_INET,
                .sin_port = htons(port),
                .sin_addr = std::bit_cast<in_addr>(address4.octets)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
        },
        [&](IPv6Address address6)
        {
            // Convert each 16-bit group to network byte order before reinterpreting the address.
            for (std::uint16_t& hextet : address6.hextets) {
                hextet = htons(hextet);
            }
            sockaddr_in6 bindAddress =
            {
                .sin6_family = AF_INET6,
                .sin6_port = htons(port),
                .sin6_addr = std::bit_cast<in6_addr>(address6)
            };
            return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
        }
    }, address);
    if (!bound
        || (listen(handle_, LISTEN_BACKLOG) < 0)
        || !osSetSocketNonBlocking(handle_, true))
    {
        const StreamError error = translateErrno();
        close();
        return error;
    }
    return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
// Stops listening: handle_ is reset to INVALID_SOCKET_HANDLE and the old handle is
// released. The result of the OS close call is ignored.
void TCPServerSocket::close() MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(isListening(), "Socket is not listening.");

    const auto listeningHandle = handle_;
    handle_ = INVALID_SOCKET_HANDLE;
    osCloseSocket(listeningHandle);
}
 | 
						|
 | 
						|
// Coroutine that waits for the next incoming connection while the server is listening.
// accept() is polled; when it would block (EAGAIN) the coroutine suspends and tries again.
// Yields the accepted connection wrapped in a new TCPSocket, the translated error for any
// other accept() failure, or CONNECTION_CLOSED once the server is no longer listening.
Task<StreamResult<std::unique_ptr<TCPSocket>>> TCPServerSocket::c_waitForConnection() MIJIN_NOEXCEPT
{
    for (;;)
    {
        if (!isListening())
        {
            co_return StreamError::CONNECTION_CLOSED;
        }

        sockaddr_in clientAddress = {};
        socklen_t addressLength = sizeof(sockaddr_in);
        const socket_handle_t accepted = accept(handle_, reinterpret_cast<sockaddr*>(&clientAddress), &addressLength);
        if (osIsSocketValid(accepted))
        {
            std::unique_ptr<TCPSocket> connection = std::make_unique<TCPSocket>();
            connection->stream_.handle_ = accepted;
            co_return connection;
        }

        if (errno != EAGAIN)
        {
            co_return translateErrno();
        }
        co_await c_suspend();
    }
}
 | 
						|
}
 | 
						|
 | 
						|
#if MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 | 
						|
#pragma clang diagnostic pop
 | 
						|
#endif // MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 |