486 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			486 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| 
 | |
| #include "./socket.hpp"
 | |
| 
 | |
| #include "./detail/net_common.hpp"
 | |
| #include "../detect.hpp"
 | |
| #include "../internal/common.hpp"
 | |
| #include "../util/variant.hpp"
 | |
| 
 | |
| #if MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 | |
| #pragma clang diagnostic push
 | |
| #pragma clang diagnostic ignored "-Wmissing-field-initializers"
 | |
| #endif // MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 | |
| 
 | |
| namespace mijin
 | |
| {
 | |
| namespace
 | |
| {
 | |
// Maximum number of pending connections passed to listen() by TCPServerSocket::setup().
constexpr int LISTEN_BACKLOG{3};
 | |
| StreamError translateErrno() MIJIN_NOEXCEPT
 | |
| {
 | |
|     switch (errno)
 | |
|     {
 | |
|     case EIO:
 | |
|         return StreamError::IO_ERROR;
 | |
|     case ECONNREFUSED:
 | |
|         return StreamError::CONNECTION_REFUSED;
 | |
|     default:
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
| }
 | |
| 
 | |
| int readFlags(const ReadOptions& options)
 | |
| {
 | |
|     return (options.partial ? 0 : MSG_WAITALL)
 | |
|          | (options.peek ? MSG_PEEK : 0);
 | |
| }
 | |
| 
 | |
| #if MIJIN_TARGET_OS == MIJIN_OS_LINUX
 | |
// "Enabled" value for boolean setsockopt() options.
constexpr int SOCKOPT_ONE{1};
 | |
| 
 | |
| bool appendSocketFlags(int handle, int flags) MIJIN_NOEXCEPT
 | |
| {
 | |
|     const int currentFlags = fcntl(handle, F_GETFL);
 | |
|     if (currentFlags < 0)
 | |
|     {
 | |
|         return false;
 | |
|     }
 | |
|     return fcntl(handle, F_SETFL, currentFlags | flags) >= 0;
 | |
| }
 | |
| 
 | |
| bool removeSocketFlags(int handle, int flags) MIJIN_NOEXCEPT
 | |
| {
 | |
|     const int currentFlags = fcntl(handle, F_GETFL);
 | |
|     if (currentFlags < 0)
 | |
|     {
 | |
|         return false;
 | |
|     }
 | |
|     return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0;
 | |
| }
 | |
| 
 | |
| long osRecv(int socket, std::span<std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return static_cast<long>(recv(socket, buffer.data(), buffer.size(), flags));
 | |
| }
 | |
| 
 | |
| long osSend(int socket, std::span<const std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return static_cast<long>(send(socket, buffer.data(), buffer.size(), flags));
 | |
| }
 | |
| 
 | |
| int osCreateSocket(int domain, int type, int protocol) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return socket(domain, type, protocol);
 | |
| }
 | |
| 
 | |
| int osCloseSocket(int socket) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return ::close(socket);
 | |
| }
 | |
| 
 | |
| bool osIsSocketValid(int socket) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return socket >= 0;
 | |
| }
 | |
| 
 | |
| bool osSetSocketNonBlocking(int socket, bool blocking) MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (blocking)
 | |
|     {
 | |
|         return appendSocketFlags(socket, O_NONBLOCK);
 | |
|     }
 | |
|     else
 | |
|     {
 | |
|         return removeSocketFlags(socket, O_NONBLOCK);
 | |
|     }
 | |
| }
 | |
| #elif MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | |
| using in_addr_t = ULONG;
 | |
| 
 | |
| const char SOCKOPT_ONE = 1;
 | |
| thread_local bool gWsaInited = false;
 | |
| 
 | |
| class WSAGuard
 | |
| {
 | |
| public:
 | |
|     ~WSAGuard() MIJIN_NOEXCEPT
 | |
|     {
 | |
|         if (gWsaInited)
 | |
|         {
 | |
|             WSACleanup();
 | |
|         }
 | |
|     }
 | |
| } thread_local [[maybe_unused]] gWsaGuard;
 | |
| 
 | |
| long osRecv(SOCKET socket, std::span<std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return recv(socket, reinterpret_cast<char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
 | |
| }
 | |
| 
 | |
| long osSend(SOCKET socket, std::span<const std::uint8_t> buffer, int flags) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return send(socket, reinterpret_cast<const char*>(buffer.data()), static_cast<int>(buffer.size()), flags);
 | |
| }
 | |
| 
 | |
| SOCKET osCreateSocket(int addressFamily, int type, int protocol) MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (!detail::initWSA())
 | |
|     {
 | |
|         return INVALID_SOCKET_HANDLE;
 | |
|     }
 | |
|     return socket(addressFamily, type, protocol);
 | |
| }
 | |
| 
 | |
| int osCloseSocket(SOCKET socket) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return closesocket(socket);
 | |
| }
 | |
| 
 | |
| bool osIsSocketValid(SOCKET socket) MIJIN_NOEXCEPT
 | |
| {
 | |
|     return socket != INVALID_SOCKET;
 | |
| }
 | |
| 
 | |
| bool osSetSocketNonBlocking(SOCKET socket, bool blocking) MIJIN_NOEXCEPT
 | |
| {
 | |
|     u_long value = blocking ? 0 : 1;
 | |
|     return ioctlsocket(socket, FIONBIO, &value) == NO_ERROR;
 | |
| }
 | |
| #endif // MIJIN_TARGET_OS == MIJIN_OS_LINUX
 | |
| }
 | |
| 
 | |
| namespace detail
 | |
| {
 | |
| #if MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | |
| bool initWSA() MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (gWsaInited)
 | |
|     {
 | |
|         return true;
 | |
|     }
 | |
| 
 | |
|     WSADATA wsaData;
 | |
|     if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
 | |
|     {
 | |
|         return false;
 | |
|     }
 | |
|     gWsaInited = true;
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| StreamError translateWSAError() MIJIN_NOEXCEPT
 | |
| {
 | |
|     // TODO
 | |
|     return StreamError::UNKNOWN_ERROR;
 | |
| }
 | |
| 
 | |
| StreamError translateWinError(DWORD error) MIJIN_NOEXCEPT
 | |
| {
 | |
|     // TODO
 | |
|     (void) error;
 | |
|     return StreamError::UNKNOWN_ERROR;
 | |
| }
 | |
| 
 | |
| StreamError translateWinError() MIJIN_NOEXCEPT
 | |
| {
 | |
|     return translateWinError(GetLastError());
 | |
| }
 | |
| #endif // MIJIN_TARGET_OS == MIJIN_OS_WINDOWS
 | |
| }// namespace impl
 | |
| 
 | |
| StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
 | |
| {
 | |
|     MIJIN_ASSERT(isOpen(), "Socket is not open.");
 | |
|     setNoblock(options.noBlock);
 | |
| 
 | |
|     const long bytesRead = osRecv(handle_, buffer, readFlags(options));
 | |
|     if (bytesRead < 0)
 | |
|     {
 | |
|         if (!options.noBlock || errno != EAGAIN)
 | |
|         {
 | |
|             return translateErrno();
 | |
|         }
 | |
|         if (outBytesRead != nullptr)
 | |
|         {
 | |
|             *outBytesRead = 0;
 | |
|         }
 | |
|         return StreamError::SUCCESS;
 | |
|     }
 | |
|     if (outBytesRead != nullptr)
 | |
|     {
 | |
|         *outBytesRead = static_cast<std::size_t>(bytesRead);
 | |
|     }
 | |
| 
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
 | |
| {
 | |
|     MIJIN_ASSERT(isOpen(), "Socket is not open.");
 | |
|     setNoblock(false);
 | |
| 
 | |
|     if (osSend(handle_, buffer, 0) < 0)
 | |
|     {
 | |
|         return translateErrno();
 | |
|     }
 | |
| 
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
 | |
| {
 | |
|     MIJIN_ASSERT(isOpen(), "Socket is not open.");
 | |
|     setNoblock(true);
 | |
| 
 | |
|     if (buffer.empty())
 | |
|     {
 | |
|         co_return StreamError::SUCCESS;
 | |
|     }
 | |
| 
 | |
|     while(true)
 | |
|     {
 | |
|         const long bytesRead = osRecv(handle_, buffer, readFlags(options));
 | |
|         if (bytesRead > 0)
 | |
|         {
 | |
|             if (outBytesRead != nullptr) {
 | |
|                 *outBytesRead = static_cast<std::size_t>(bytesRead);
 | |
|             }
 | |
|             co_return StreamError::SUCCESS;
 | |
|         }
 | |
|         if (bytesRead == 0)
 | |
|         {
 | |
|             co_return StreamError::CONNECTION_CLOSED;
 | |
|         }
 | |
|         if (errno != EAGAIN)
 | |
|         {
 | |
|             co_return translateErrno();
 | |
|         }
 | |
|         co_await mijin::c_suspend();
 | |
|     }
 | |
| }
 | |
| 
 | |
| Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
 | |
| {
 | |
|     MIJIN_ASSERT(isOpen(), "Socket is not open.");
 | |
| 
 | |
|     if (buffer.empty())
 | |
|     {
 | |
|         co_return StreamError::SUCCESS;
 | |
|     }
 | |
| 
 | |
|     setNoblock(true);
 | |
| 
 | |
|     while (true)
 | |
|     {
 | |
|         const long bytesSent = osSend(handle_, buffer, 0);
 | |
|         if (bytesSent == static_cast<long>(buffer.size()))
 | |
|         {
 | |
|             co_return StreamError::SUCCESS;
 | |
|         }
 | |
|         else if (bytesSent == 0)
 | |
|         {
 | |
|             co_return StreamError::CONNECTION_CLOSED;
 | |
|         }
 | |
|         else if (errno != EAGAIN)
 | |
|         {
 | |
|             co_return translateErrno();
 | |
|         }
 | |
|         co_await c_suspend();
 | |
|     }
 | |
| }
 | |
| 
 | |
| void TCPStream::setNoblock(bool async)
 | |
| {
 | |
|     if (async == async_)
 | |
|     {
 | |
|         return;
 | |
|     }
 | |
|     async_ = async;
 | |
| 
 | |
|     osSetSocketNonBlocking(handle_, async);
 | |
| }
 | |
| 
 | |
| std::size_t TCPStream::tell()
 | |
| {
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| StreamError TCPStream::seek(std::intptr_t /* pos */, SeekMode /* seekMode */)
 | |
| {
 | |
|     return StreamError::NOT_SUPPORTED;
 | |
| }
 | |
| 
 | |
| void TCPStream::flush()
 | |
| {
 | |
| 
 | |
| }
 | |
| 
 | |
| bool TCPStream::isAtEnd()
 | |
| {
 | |
|     return !isOpen();
 | |
| }
 | |
| 
 | |
| StreamFeatures TCPStream::getFeatures()
 | |
| {
 | |
|     return {
 | |
|         .read = true,
 | |
|         .write = true,
 | |
|         .tell = false,
 | |
|         .seek = false,
 | |
|         .async = true,
 | |
|         .readOptions = {
 | |
|             .partial = true,
 | |
|             .peek = true
 | |
|         }
 | |
|     };
 | |
| }
 | |
| 
 | |
| StreamError TCPStream::open(ip_address_t address, std::uint16_t port) MIJIN_NOEXCEPT
 | |
| {
 | |
|     MIJIN_ASSERT(!isOpen(), "Socket is already open.");
 | |
| 
 | |
|     handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
 | |
|     if (!osIsSocketValid(handle_))
 | |
|     {
 | |
|         return translateErrno();
 | |
|     }
 | |
| 
 | |
|     const bool connected = std::visit(Visitor{
 | |
|         [&](const IPv4Address& address4)
 | |
|         {
 | |
| #if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
 | |
| #error "TODO: swap byte order of the address"
 | |
| #endif
 | |
|             sockaddr_in connectAddress =
 | |
|             {
 | |
|                 .sin_family = AF_INET,
 | |
|                 .sin_port = htons(port),
 | |
|                 .sin_addr = std::bit_cast<in_addr>(address4.octets)
 | |
|             };
 | |
|             return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in)) == 0;
 | |
|         },
 | |
|         [&](IPv6Address address6)
 | |
|         {
 | |
|             for (std::uint16_t& hextet : address6.hextets) {
 | |
|                 hextet = htons(hextet);
 | |
|             }
 | |
|             sockaddr_in6 connectAddress =
 | |
|             {
 | |
|                 .sin6_family = AF_INET6,
 | |
|                 .sin6_port = htons(port),
 | |
|                 .sin6_addr = std::bit_cast<in6_addr>(address6)
 | |
|             };
 | |
|             return connect(handle_, reinterpret_cast<sockaddr*>(&connectAddress), sizeof(sockaddr_in6)) == 0;
 | |
|         }
 | |
|     }, address);
 | |
|     if (!connected)
 | |
|     {
 | |
|         osCloseSocket(handle_);
 | |
|         handle_ = INVALID_SOCKET_HANDLE;
 | |
|         return translateErrno();
 | |
|     }
 | |
| 
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| void TCPStream::close() MIJIN_NOEXCEPT
 | |
| {
 | |
|     MIJIN_ASSERT(isOpen(), "Socket is not open.");
 | |
|     osCloseSocket(handle_);
 | |
|     handle_ = INVALID_SOCKET_HANDLE;
 | |
| }
 | |
| 
 | |
| TCPStream& TCPSocket::getStream() MIJIN_NOEXCEPT
 | |
| {
 | |
|     return stream_;
 | |
| }
 | |
| 
 | |
| StreamError TCPServerSocket::setup(ip_address_t address, std::uint16_t port) MIJIN_NOEXCEPT
 | |
| {
 | |
|     MIJIN_ASSERT(!isListening(), "Socket is already listening.");
 | |
| 
 | |
|     handle_ = osCreateSocket(AF_INET, SOCK_STREAM, 0);
 | |
|     if (!osIsSocketValid(handle_))
 | |
|     {
 | |
|         return translateErrno();
 | |
|     }
 | |
|     if (setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &SOCKOPT_ONE, sizeof(int)) != 0)
 | |
|     {
 | |
|         close();
 | |
|         return translateErrno();
 | |
|     }
 | |
| 
 | |
|     const bool bound = std::visit(Visitor{
 | |
|         [&](const IPv4Address& address4)
 | |
|         {
 | |
| #if __BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__
 | |
| #error "TODO: swap byte order of the address"
 | |
| #endif
 | |
|             sockaddr_in bindAddress =
 | |
|             {
 | |
|                 .sin_family = AF_INET,
 | |
|                 .sin_port = htons(port),
 | |
|                 .sin_addr = std::bit_cast<in_addr>(address4.octets)
 | |
|             };
 | |
|             return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in)) == 0;
 | |
|         },
 | |
|         [&](IPv6Address address6)
 | |
|         {
 | |
|             for (std::uint16_t& hextet : address6.hextets) {
 | |
|                 hextet = htons(hextet);
 | |
|             }
 | |
|             sockaddr_in6 bindAddress =
 | |
|             {
 | |
|                 .sin6_family = AF_INET6,
 | |
|                 .sin6_port = htons(port),
 | |
|                 .sin6_addr = std::bit_cast<in6_addr>(address6)
 | |
|             };
 | |
|             return bind(handle_, reinterpret_cast<sockaddr*>(&bindAddress), sizeof(sockaddr_in6)) == 0;
 | |
|         }
 | |
|     }, address);
 | |
|     if (!bound
 | |
|         || (listen(handle_, LISTEN_BACKLOG) < 0)
 | |
|         || !osSetSocketNonBlocking(handle_, true))
 | |
|     {
 | |
|         close();
 | |
|         return translateErrno();
 | |
|     }
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| void TCPServerSocket::close() MIJIN_NOEXCEPT
 | |
| {
 | |
|     MIJIN_ASSERT(isListening(), "Socket is not listening.");
 | |
| 
 | |
|     osCloseSocket(handle_);
 | |
|     handle_ = INVALID_SOCKET_HANDLE;
 | |
| }
 | |
| 
 | |
| Task<StreamResult<std::unique_ptr<TCPSocket>>> TCPServerSocket::c_waitForConnection() MIJIN_NOEXCEPT
 | |
| {
 | |
|     while (isListening())
 | |
|     {
 | |
|         sockaddr_in client = {};
 | |
|         socklen_t LENGTH = sizeof(sockaddr_in);
 | |
|         const socket_handle_t newSocket = accept(handle_, reinterpret_cast<sockaddr*>(&client), &LENGTH);
 | |
|         if (!osIsSocketValid(newSocket))
 | |
|         {
 | |
|             if (errno != EAGAIN)
 | |
|             {
 | |
|                 co_return translateErrno();
 | |
|             }
 | |
|             co_await c_suspend();
 | |
|             continue;
 | |
|         }
 | |
|         std::unique_ptr<TCPSocket> socket = std::make_unique<TCPSocket>();
 | |
|         socket->stream_.handle_ = newSocket;
 | |
|         co_return socket;
 | |
|     }
 | |
|     co_return StreamError::CONNECTION_CLOSED;
 | |
| }
 | |
| }
 | |
| 
 | |
| #if MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 | |
| #pragma clang diagnostic pop
 | |
| #endif // MIJIN_COMPILER == MIJIN_COMPILER_CLANG
 |