diff --git a/SModule b/SModule index 880d97f..1a823ac 100644 --- a/SModule +++ b/SModule @@ -10,6 +10,7 @@ mijin_sources = Split(""" source/mijin/debug/symbol_info.cpp source/mijin/io/process.cpp source/mijin/io/stream.cpp + source/mijin/net/socket.cpp source/mijin/util/os.cpp source/mijin/types/name.cpp source/mijin/virtual_filesystem/filesystem.cpp diff --git a/source/mijin/io/process.cpp b/source/mijin/io/process.cpp index 591f2df..1e0fde2 100644 --- a/source/mijin/io/process.cpp +++ b/source/mijin/io/process.cpp @@ -53,7 +53,7 @@ int ProcessStream::close() return result; } -StreamError ProcessStream::readRaw(std::span buffer, bool partial, std::size_t* outBytesRead) +StreamError ProcessStream::readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) { assert(handle); assert(mode == FileOpenMode::READ || mode == FileOpenMode::READ_WRITE); @@ -67,7 +67,7 @@ StreamError ProcessStream::readRaw(std::span buffer, bool partial, if (std::ferror(handle)) { return StreamError::IO_ERROR; } - if (!partial && readBytes < buffer.size()) { + if (!options.partial && readBytes < buffer.size()) { return StreamError::IO_ERROR; } if (outBytesRead != nullptr) @@ -134,7 +134,10 @@ StreamFeatures ProcessStream::getFeatures() .read = (mode == FileOpenMode::READ), .write = (mode == FileOpenMode::WRITE || mode == FileOpenMode::READ_WRITE), .tell = true, - .seek = false + .seek = false, + .readOptions = { + .partial = true + } }; } return {}; diff --git a/source/mijin/io/process.hpp b/source/mijin/io/process.hpp index 7d8ce52..260a086 100644 --- a/source/mijin/io/process.hpp +++ b/source/mijin/io/process.hpp @@ -26,7 +26,7 @@ public: [[nodiscard]] inline bool isOpen() const { return handle != nullptr; } // Stream overrides - StreamError readRaw(std::span buffer, bool partial = false, std::size_t* outBytesRead = nullptr) override; + StreamError readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; StreamError writeRaw(std::span buffer) override; std::size_t tell() override; StreamError seek(std::intptr_t pos, SeekMode seekMode = SeekMode::ABSOLUTE) override; diff --git a/source/mijin/io/stream.cpp b/source/mijin/io/stream.cpp index a26014d..6498ad0 100644 --- a/source/mijin/io/stream.cpp +++ b/source/mijin/io/stream.cpp @@ -35,6 +35,24 @@ namespace mijin void Stream::flush() {} +mijin::Task Stream::c_readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) +{ + (void) buffer; + (void) options; + (void) outBytesRead; + + MIJIN_ASSERT(!getFeatures().async || !getFeatures().read, "Stream advertises async read, but doesn't implement it."); + co_return StreamError::NOT_SUPPORTED; +} + +mijin::Task Stream::c_writeRaw(std::span buffer) +{ + (void) buffer; + + MIJIN_ASSERT(!getFeatures().async || !getFeatures().write, "Stream advertises async write, but doesn't implement it."); + co_return StreamError::NOT_SUPPORTED; +} + StreamError Stream::readBinaryString(std::string& outString) { std::uint32_t length; // NOLINT(cppcoreguidelines-init-variables) @@ -55,7 +73,7 @@ StreamError Stream::readBinaryString(std::string& outString) StreamError Stream::writeBinaryString(std::string_view str) { - assert(str.length() <= std::numeric_limits::max()); + MIJIN_ASSERT(str.length() <= std::numeric_limits::max(), "Binary string is too long."); const std::uint32_t length = static_cast(str.length()); StreamError error = write(length); if (error != StreamError::SUCCESS) { @@ -64,6 +82,35 @@ StreamError Stream::writeBinaryString(std::string_view str) return writeSpan(str.begin(), str.end()); } +mijin::Task Stream::c_readBinaryString(std::string& outString) +{ + std::uint32_t length; // NOLINT(cppcoreguidelines-init-variables) + StreamError error = co_await c_read(length); + if (error != StreamError::SUCCESS) { + co_return error; + } + + std::string result; + result.resize(length); + error = co_await c_readSpan(result.begin(), result.end()); + if (error != StreamError::SUCCESS) { + co_return error; + } + outString = std::move(result); + co_return StreamError::SUCCESS; +} + +mijin::Task Stream::c_writeBinaryString(std::string_view str) +{ + MIJIN_ASSERT(str.length() <= std::numeric_limits::max(), "Binary string is too long."); + const std::uint32_t length = static_cast(str.length()); + StreamError error = co_await c_write(length); + if (error != StreamError::SUCCESS) { + co_return error; + } + co_return co_await c_writeSpan(str.begin(), str.end()); +} + StreamError Stream::getTotalLength(std::size_t& outLength) { const StreamFeatures features = getFeatures(); @@ -106,7 +153,7 @@ StreamError Stream::readRest(TypelessBuffer& outBuffer) while (!isAtEnd()) { std::size_t bytesRead = 0; - if (StreamError error = readRaw(chunk, true, &bytesRead); error != StreamError::SUCCESS) + if (StreamError error = readRaw(chunk, {.partial = true}, &bytesRead); error != StreamError::SUCCESS) { return error; } @@ -118,6 +165,129 @@ StreamError Stream::readRest(TypelessBuffer& outBuffer) return StreamError::SUCCESS; } +mijin::Task Stream::c_readRest(TypelessBuffer& outBuffer) +{ + // first try to allocate everything at once + std::size_t length = 0; + if (StreamError lengthError = getTotalLength(length); lengthError == StreamError::SUCCESS) + { + MIJIN_ASSERT(getFeatures().tell, "How did you find the length if you cannot tell()?"); + length -= tell(); + outBuffer.resize(length); + if (StreamError error = co_await c_readRaw(outBuffer.data(), length); error != StreamError::SUCCESS) + { + co_return error; + } + co_return StreamError::SUCCESS; + } + + // could not determine the size, read chunk-wise + static constexpr std::size_t CHUNK_SIZE = 4096; + std::array chunk = {}; + + while (!isAtEnd()) + { + std::size_t bytesRead = 0; + if (StreamError error = co_await c_readRaw(chunk, {.partial = true}, &bytesRead); error != StreamError::SUCCESS) + { + co_return error; + } + + outBuffer.resize(outBuffer.byteSize() + bytesRead); + std::span bufferBytes = outBuffer.makeSpan(); + std::copy_n(chunk.begin(), bytesRead, bufferBytes.end() - static_cast(bytesRead)); + } + co_return StreamError::SUCCESS; +} + +StreamError Stream::readLine(std::string& outString) +{ + MIJIN_ASSERT(getFeatures().readOptions.peek, "Stream needs to support peeking."); + + static constexpr std::size_t BUFFER_SIZE = 4096; + std::array buffer; + + outString.clear(); + bool done = false; + while(!done) + { + // read into the buffer + std::size_t bytesRead = 0; + if (StreamError error = readRaw(buffer, {.partial = true, .peek = true}, &bytesRead); error != StreamError::SUCCESS) + { + return error; + } + // try to find a \n + auto begin = buffer.begin(); + auto end = buffer.begin() + bytesRead; + auto newline = std::find(begin, end, '\n'); + + if (newline != end) + { + // found the end + outString.append(begin, newline); + end = newline + 1; + done = true; + } + else + { + outString.append(begin, end); + } + + // read again, this time to skip + if (StreamError error = readSpan(begin, end); error != StreamError::SUCCESS) + { + return error; + } + } + + return StreamError::SUCCESS; +} + +mijin::Task Stream::c_readLine(std::string& outString) +{ + MIJIN_ASSERT(getFeatures().readOptions.peek, "Stream needs to support peeking."); + + static constexpr std::size_t BUFFER_SIZE = 4096; + std::array buffer; + + outString.clear(); + bool done = false; + while(!done) + { + // read into the buffer + std::size_t bytesRead = 0; + if (StreamError error = co_await c_readRaw(buffer, {.partial = true, .peek = true}, &bytesRead); error != StreamError::SUCCESS) + { + co_return error; + } + // try to find a \n + auto begin = buffer.begin(); + auto end = buffer.begin() + bytesRead; + auto newline = std::find(begin, end, '\n'); + + if (newline != end) + { + // found the end + outString.append(begin, newline); + end = newline + 1; + done = true; + } + else + { + outString.append(begin, end); + } + + // read again, this time to skip + if (StreamError error = co_await c_readSpan(begin, end); error != StreamError::SUCCESS) + { + co_return error; + } + } + + co_return StreamError::SUCCESS; +} + FileStream::~FileStream() { if (handle) { @@ -172,7 +342,7 @@ void FileStream::close() assert(result == 0); } -StreamError FileStream::readRaw(std::span buffer, bool partial, std::size_t* outBytesRead) +StreamError FileStream::readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) { assert(handle); assert(mode == FileOpenMode::READ || mode == FileOpenMode::READ_WRITE); @@ -181,12 +351,19 @@ StreamError FileStream::readRaw(std::span buffer, bool partial, st if (std::ferror(handle)) { return StreamError::IO_ERROR; } - if (!partial && readBytes < buffer.size()) { + if (!options.partial && readBytes < buffer.size()) { return StreamError::IO_ERROR; } if (outBytesRead != nullptr) { *outBytesRead = readBytes; } + if (options.peek) + { + if (StreamError error = seek(-static_cast(readBytes), SeekMode::RELATIVE); error != StreamError::SUCCESS) + { + return error; + } + } return StreamError::SUCCESS; } @@ -265,7 +442,11 @@ StreamFeatures FileStream::getFeatures() .read = (mode == FileOpenMode::READ), .write = (mode == FileOpenMode::WRITE || mode == FileOpenMode::APPEND || mode == FileOpenMode::READ_WRITE), .tell = true, - .seek = true + .seek = true, + .readOptions = { + .partial = true, + .peek = true + } }; } return {}; @@ -293,10 +474,10 @@ void MemoryStream::close() data_ = {}; } -StreamError MemoryStream::readRaw(std::span buffer, bool partial, std::size_t* outBytesRead) +StreamError MemoryStream::readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) { assert(isOpen()); - if (!partial && availableBytes() < buffer.size()) { + if (!options.partial && availableBytes() < buffer.size()) { return StreamError::IO_ERROR; // TODO: need more errors? } const std::size_t numBytes = std::min(buffer.size(), availableBytes()); @@ -304,7 +485,9 @@ StreamError MemoryStream::readRaw(std::span buffer, bool partial, if (outBytesRead) { *outBytesRead = numBytes; } - pos_ += numBytes; + if (!options.peek) { + pos_ += numBytes; + } return StreamError::SUCCESS; } @@ -362,7 +545,10 @@ StreamFeatures MemoryStream::getFeatures() .read = true, .write = canWrite_, .tell = true, - .seek = true + .seek = true, + .readOptions = { + .peek = true + } }; } diff --git a/source/mijin/io/stream.hpp b/source/mijin/io/stream.hpp index fddee1a..bc9615c 100644 --- a/source/mijin/io/stream.hpp +++ b/source/mijin/io/stream.hpp @@ -11,7 +11,9 @@ #include #include #include +#include "../async/coroutine.hpp" #include "../container/typeless_buffer.hpp" +#include "../types/result.hpp" #include "../util/exception.hpp" namespace mijin @@ -36,12 +38,21 @@ enum class SeekMode RELATIVE_TO_END }; +struct ReadOptions +{ + bool partial : 1 = false; + bool peek : 1 = false; +}; + struct StreamFeatures { bool read : 1 = false; bool write : 1 = false; bool tell : 1 = false; bool seek : 1 = false; + bool async : 1 = false; + + ReadOptions readOptions = {}; }; enum class FileOpenMode @@ -54,10 +65,11 @@ enum class FileOpenMode enum class [[nodiscard]] StreamError { - SUCCESS, - IO_ERROR, - NOT_SUPPORTED, - UNKNOWN_ERROR + SUCCESS = 0, + IO_ERROR = 1, + NOT_SUPPORTED = 2, + CONNECTION_CLOSED = 3, + UNKNOWN_ERROR = -1 }; class Stream @@ -66,7 +78,12 @@ public: virtual ~Stream() = default; public: - virtual StreamError readRaw(std::span buffer, bool partial = false, std::size_t* outBytesRead = nullptr) = 0; + virtual StreamError readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) = 0; + [[deprecated("Partial parameter has been replaced with options.")]] + StreamError readRaw(std::span buffer, bool partial, std::size_t* outBytesRead = nullptr) + { + return readRaw(buffer, {.partial = partial}, outBytesRead); + } virtual StreamError writeRaw(std::span buffer) = 0; virtual std::size_t tell() = 0; virtual StreamError seek(std::intptr_t pos, SeekMode seekMode = SeekMode::ABSOLUTE) = 0; @@ -74,74 +91,163 @@ public: virtual bool isAtEnd() = 0; virtual StreamFeatures getFeatures() = 0; - inline StreamError readRaw(void* outData, std::size_t bytes, bool partial = false, std::size_t* outBytesRead = nullptr) + // async interface (requires getFeatures().async to be set) + virtual mijin::Task c_readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr); + virtual mijin::Task c_writeRaw(std::span buffer); + + StreamError readRaw(void* outData, std::size_t bytes, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) { std::uint8_t* ptr = static_cast(outData); - return readRaw(std::span(ptr, ptr + bytes), partial, outBytesRead); + return readRaw(std::span(ptr, ptr + bytes), options, outBytesRead); + } + + [[deprecated("Partial parameter has been replaced with options.")]] + StreamError readRaw(void* outData, std::size_t bytes, bool partial, std::size_t* outBytesRead = nullptr) + { + std::uint8_t* ptr = static_cast(outData); + return readRaw(std::span(ptr, ptr + bytes), {.partial = partial}, outBytesRead); + } + + mijin::Task c_readRaw(void* outData, std::size_t bytes, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) + { + std::uint8_t* ptr = static_cast(outData); + return c_readRaw(std::span(ptr, ptr + bytes), options, outBytesRead); } template - inline StreamError readRaw(TRange& range, bool partial = false, std::size_t* outBytesRead = nullptr) + StreamError readRaw(TRange& range, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) { const std::size_t bytes = std::distance(range.begin(), range.end()) * sizeof(std::ranges::range_value_t); - return readRaw(&*range.begin(), bytes, partial, outBytesRead); + return readRaw(&*range.begin(), bytes, options, outBytesRead); } - inline StreamError writeRaw(const void* data, std::size_t bytes) + template + [[deprecated("Partial parameter has been replaced with options.")]] + StreamError readRaw(TRange& range, bool partial, std::size_t* outBytesRead = nullptr) + { + const std::size_t bytes = std::distance(range.begin(), range.end()) * sizeof(std::ranges::range_value_t); + return readRaw(&*range.begin(), bytes, {.partial = partial}, outBytesRead); + } + + template + mijin::Task c_readRaw(TRange& range, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) + { + const std::size_t bytes = std::distance(range.begin(), range.end()) * sizeof(std::ranges::range_value_t); + return c_readRaw(&*range.begin(), bytes, options, outBytesRead); + } + + StreamError writeRaw(const void* data, std::size_t bytes) { const std::uint8_t* ptr = static_cast(data); return writeRaw(std::span(ptr, ptr + bytes)); } + mijin::Task c_writeRaw(const void* data, std::size_t bytes) + { + const std::uint8_t* ptr = static_cast(data); + return c_writeRaw(std::span(ptr, ptr + bytes)); + } + template - inline StreamError writeRaw(const TRange& range) + StreamError writeRaw(const TRange& range) { const std::size_t bytes = std::distance(range.begin(), range.end()) * sizeof(std::ranges::range_value_t); return writeRaw(&*range.begin(), bytes); } - template - inline StreamError read(T& value) + template + mijin::Task c_writeRaw(const TRange& range) { - return readRaw(&value, sizeof(T)); + const std::size_t bytes = std::distance(range.begin(), range.end()) * sizeof(std::ranges::range_value_t); + return c_writeRaw(&*range.begin(), bytes); } template - inline StreamError readSpan(T& values) + StreamError read(T& value, const ReadOptions& options = {}) + { + MIJIN_ASSERT(!options.partial, "Cannot partially read a value."); + return readRaw(&value, sizeof(T), options); + } + + template + mijin::Task c_read(T& value, const ReadOptions& options = {}) + { + MIJIN_ASSERT(!options.partial, "Cannot partially read a value."); + return c_readRaw(&value, sizeof(T), options); + } + + template + StreamError readSpan(T& values, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) { auto asSpan = std::span(values); - return readRaw(asSpan.data(), asSpan.size_bytes()); + return readRaw(asSpan.data(), asSpan.size_bytes(), options, outBytesRead); + } + + template + mijin::Task c_readSpan(T& values, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) + { + auto asSpan = std::span(values); + return c_readRaw(asSpan.data(), asSpan.size_bytes(), options, outBytesRead); } template - inline StreamError readSpan(TItBegin&& begin, TItEnd&& end) + StreamError readSpan(TItBegin&& begin, TItEnd&& end, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) { auto asSpan = std::span(std::forward(begin), std::forward(end)); - return readRaw(asSpan.data(), asSpan.size_bytes()); + return readRaw(asSpan.data(), asSpan.size_bytes(), options, outBytesRead); + } + + template + mijin::Task c_readSpan(TItBegin&& begin, TItEnd&& end, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) + { + auto asSpan = std::span(std::forward(begin), std::forward(end)); + return c_readRaw(asSpan.data(), asSpan.size_bytes(), options, outBytesRead); } template - inline StreamError write(const T& value) + StreamError write(const T& value) requires(std::is_trivial_v) { return writeRaw(&value, sizeof(T)); } template - inline StreamError writeSpan(const T& values) + mijin::Task c_write(const T& value) requires(std::is_trivial_v) + { + return c_writeRaw(&value, sizeof(T)); + } + + template + StreamError writeSpan(const T& values) { auto asSpan = std::span(values); return writeRaw(asSpan.data(), asSpan.size_bytes()); } + template + mijin::Task c_writeSpan(const T& values) + { + auto asSpan = std::span(values); + return c_writeRaw(asSpan.data(), asSpan.size_bytes()); + } + template - inline StreamError writeSpan(TItBegin&& begin, TItEnd&& end) + StreamError writeSpan(TItBegin&& begin, TItEnd&& end) { return writeSpan(std::span(std::forward(begin), std::forward(end))); } + template + mijin::Task c_writeSpan(TItBegin&& begin, TItEnd&& end) + { + return c_writeSpan(std::span(std::forward(begin), std::forward(end))); + } + StreamError readBinaryString(std::string& outString); StreamError writeBinaryString(std::string_view str); + mijin::Task c_readBinaryString(std::string& outString); + mijin::Task c_writeBinaryString(std::string_view str); + [[deprecated("Use readBinaryString() or readAsString() instead.")]] inline StreamError readString(std::string& outString) { return readBinaryString(outString); } @@ -150,14 +256,26 @@ public: StreamError getTotalLength(std::size_t& outLength); StreamError readRest(TypelessBuffer& outBuffer); + mijin::Task c_readRest(TypelessBuffer& outBuffer); + + StreamError readLine(std::string& outString); + mijin::Task c_readLine(std::string& outString); template StreamError readAsString(std::basic_string& outString); - inline StreamError writeText(std::string_view str) + template + mijin::Task c_readAsString(std::basic_string& outString); + + StreamError writeText(std::string_view str) { return writeSpan(str); } + + mijin::Task c_writeText(std::string_view str) + { + return c_writeSpan(str); + } }; class FileStream : public Stream @@ -177,7 +295,7 @@ public: [[nodiscard]] inline bool isOpen() const { return handle != nullptr; } // Stream overrides - StreamError readRaw(std::span buffer, bool partial = false, std::size_t* outBytesRead = nullptr) override; + StreamError readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; StreamError writeRaw(std::span buffer) override; std::size_t tell() override; StreamError seek(std::intptr_t pos, SeekMode seekMode = SeekMode::ABSOLUTE) override; @@ -210,7 +328,7 @@ public: } // Stream overrides - StreamError readRaw(std::span buffer, bool partial = false, std::size_t* outBytesRead = nullptr) override; + StreamError readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; StreamError writeRaw(std::span buffer) override; std::size_t tell() override; StreamError seek(std::intptr_t pos, SeekMode seekMode = SeekMode::ABSOLUTE) override; @@ -220,6 +338,9 @@ private: void openROImpl(const void* data, std::size_t bytes); }; +template +using StreamResult = ResultBase; + // // public functions // @@ -251,7 +372,7 @@ StreamError Stream::readAsString(std::basic_string& outString) while (!isAtEnd()) { std::size_t bytesRead = 0; - if (StreamError error = readRaw(chunk, true, &bytesRead); error != StreamError::SUCCESS) + if (StreamError error = readRaw(chunk, {.partial = true}, &bytesRead); error != StreamError::SUCCESS) { return error; } @@ -260,6 +381,42 @@ StreamError Stream::readAsString(std::basic_string& outString) return StreamError::SUCCESS; } +template +mijin::Task Stream::c_readAsString(std::basic_string& outString) +{ + static_assert(sizeof(TChar) == 1, "Can only read to 8-bit character types (char, unsigned char or char8_t"); + + // first try to allocate everything at once + std::size_t length = 0; + if (StreamError lengthError = getTotalLength(length); lengthError == StreamError::SUCCESS) + { + MIJIN_ASSERT(getFeatures().tell, "How did you find the length if you cannot tell()?"); + length -= tell(); + outString.resize(length); + if (StreamError error = co_await c_readRaw(outString.data(), length); error != StreamError::SUCCESS) + { + co_return error; + } + co_return StreamError::SUCCESS; + } + + // could not determine the size, read chunk-wise + static constexpr std::size_t CHUNK_SIZE = 4096; + std::array chunk; + + outString.clear(); + while (!isAtEnd()) + { + std::size_t bytesRead = 0; + if (StreamError error = co_await c_readRaw(chunk, true, &bytesRead); error != StreamError::SUCCESS) + { + co_return error; + } + outString.append(chunk.data(), chunk.data() + bytesRead); + } + co_return StreamError::SUCCESS; +} + inline const char* errorName(StreamError error) noexcept { @@ -271,6 +428,8 @@ inline const char* errorName(StreamError error) noexcept return "IO error"; case StreamError::NOT_SUPPORTED: return "not supported"; + case StreamError::CONNECTION_CLOSED: + return "connection closed"; case StreamError::UNKNOWN_ERROR: return "unknown error"; } diff --git a/source/mijin/net/socket.cpp b/source/mijin/net/socket.cpp new file mode 100644 index 0000000..a12b26f --- /dev/null +++ b/source/mijin/net/socket.cpp @@ -0,0 +1,271 @@ + +#include "./socket.hpp" + +#include "../detect.hpp" + +#if MIJIN_TARGET_OS == MIJIN_OS_LINUX +#include +#include +#include +#include +#endif + +namespace mijin +{ +namespace +{ +inline constexpr int LISTEN_BACKLOG = 3; +StreamError translateErrno() noexcept +{ + switch (errno) + { + default: + return StreamError::UNKNOWN_ERROR; + } +} + +bool appendSocketFlags(int handle, int flags) noexcept +{ + const int currentFlags = fcntl(handle, F_GETFL); + if (currentFlags < 0) + { + return false; + } + return fcntl(handle, F_SETFL, currentFlags | flags) >= 0; +} + +bool removeSocketFlags(int handle, int flags) noexcept +{ + const int currentFlags = fcntl(handle, F_GETFL); + if (currentFlags < 0) + { + return false; + } + return fcntl(handle, F_SETFL, currentFlags & ~flags) >= 0; +} + +int readFlags(const ReadOptions& options) +{ + return (options.partial ? 0 : MSG_WAITALL) + | (options.peek ? MSG_PEEK : 0); +} +} + +StreamError TCPStream::readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) +{ + MIJIN_ASSERT(isOpen(), "Socket is not open."); + setAsync(false); + + const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options)); + if (bytesRead < 0) + { + return translateErrno(); + } + *outBytesRead = static_cast(bytesRead); + + return StreamError::SUCCESS; +} + +StreamError TCPStream::writeRaw(std::span buffer) +{ + MIJIN_ASSERT(isOpen(), "Socket is not open."); + setAsync(false); + + if (send(handle_, buffer.data(), buffer.size(), 0) < 0) + { + return translateErrno(); + } + + return StreamError::SUCCESS; +} + +mijin::Task TCPStream::c_readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) +{ + MIJIN_ASSERT(isOpen(), "Socket is not open."); + setAsync(true); + + while(true) + { + const ::ssize_t bytesRead = recv(handle_, buffer.data(), buffer.size(), readFlags(options)); + if (bytesRead >= 0) + { + if (outBytesRead != nullptr) { + *outBytesRead = static_cast(bytesRead); + } + co_return StreamError::SUCCESS; + } + else if (errno != EAGAIN) + { + co_return translateErrno(); + } + co_await mijin::c_suspend(); + } +} + +mijin::Task TCPStream::c_writeRaw(std::span buffer) +{ + MIJIN_ASSERT(isOpen(), "Socket is not open."); + setAsync(true); + + while (true) + { + if (send(handle_, buffer.data(), buffer.size(), 0) >= 0) + { + co_return StreamError::SUCCESS; + } + else if (errno != EAGAIN) + { + co_return translateErrno(); + } + co_await mijin::c_suspend(); + } +} + +void TCPStream::setAsync(bool async) +{ + if (async == async_) + { + return; + } + async_ = async; + + if (async) + { + appendSocketFlags(handle_, O_NONBLOCK); + } + else + { + removeSocketFlags(handle_, O_NONBLOCK); + } +} + +std::size_t TCPStream::tell() +{ + return 0; +} + +StreamError TCPStream::seek(std::intptr_t /* pos */, mijin::SeekMode /* seekMode */) +{ + return StreamError::NOT_SUPPORTED; +} + +void TCPStream::flush() +{ + +} + +bool TCPStream::isAtEnd() +{ + return !isOpen(); +} + +StreamFeatures TCPStream::getFeatures() +{ + return { + .read = true, + .write = true, + .tell = false, + .seek = false, + .async = true, + .readOptions = { + .partial = true, + .peek = true + } + }; +} + +StreamError TCPStream::open(const char* address, std::uint16_t port) noexcept +{ + MIJIN_ASSERT(!isOpen(), "Socket is already open."); + + handle_ = socket(AF_INET, SOCK_STREAM, 0); + if (handle_ < 0) + { + return translateErrno(); + } + sockaddr_in connectAddress = + { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = {inet_addr(address)} + }; + if (connect(handle_, reinterpret_cast(&connectAddress), sizeof(sockaddr_in)) < 0) + { + ::close(handle_); + handle_ = -1; + return translateErrno(); + } + + return StreamError::SUCCESS; +} + +void TCPStream::close() noexcept +{ + MIJIN_ASSERT(isOpen(), "Socket is not open."); + ::close(handle_); + handle_ = -1; +} + +TCPStream& TCPSocket::getStream() noexcept +{ + return stream_; +} + +StreamError TCPServerSocket::setup(const char* address, std::uint16_t port) noexcept +{ + MIJIN_ASSERT(!isListening(), "Socket is already listening."); + + handle_ = socket(AF_INET, SOCK_STREAM, 0); + if (handle_ < 0) + { + return translateErrno(); + } + sockaddr_in bindAddress = + { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = {inet_addr(address)} + }; + static const int ONE = 1; + if ((setsockopt(handle_, SOL_SOCKET, SO_REUSEADDR, &ONE, sizeof(int))) + || (bind(handle_, reinterpret_cast(&bindAddress), sizeof(sockaddr_in)) < 0) + || (listen(handle_, LISTEN_BACKLOG) < 0) + || !appendSocketFlags(handle_, O_NONBLOCK)) + { + close(); + return translateErrno(); + } + return StreamError::SUCCESS; +} + +void TCPServerSocket::close() noexcept +{ + MIJIN_ASSERT(isListening(), "Socket is not listening."); + + ::close(handle_); + handle_ = -1; +} + +Task>> TCPServerSocket::c_waitForConnection() noexcept +{ + while (isListening()) + { + sockaddr_in client = {0}; + socklen_t LENGTH = sizeof(sockaddr_in); + const int newSocket = accept(handle_, reinterpret_cast(&client), &LENGTH); + if (newSocket < 0) + { + if (errno != EAGAIN) + { + co_return translateErrno(); + } + co_await c_suspend(); + continue; + } + std::unique_ptr socket = std::make_unique(); + socket->stream_.handle_ = newSocket; + co_return socket; + } + co_return StreamError::CONNECTION_CLOSED; +} +} diff --git a/source/mijin/net/socket.hpp b/source/mijin/net/socket.hpp new file mode 100644 index 0000000..4ea09eb --- /dev/null +++ b/source/mijin/net/socket.hpp @@ -0,0 +1,100 @@ + +#pragma once + +#if !defined(MIJIN_NET_SOCKET_HPP_INCLUDED) +#define MIJIN_NET_SOCKET_HPP_INCLUDED 1 + +#include "../async/coroutine.hpp" +#include "../io/stream.hpp" + +namespace mijin +{ + +// +// public types +// + +class Socket +{ +protected: + Socket() noexcept = default; + Socket(const Socket&) noexcept = default; + Socket(Socket&&) noexcept = default; + + Socket& operator=(const Socket&) noexcept = default; + Socket& operator=(Socket&&) noexcept = default; +public: + virtual ~Socket() noexcept = default; + + virtual Stream& getStream() noexcept = 0; +}; + +class ServerSocket +{ +protected: + ServerSocket() noexcept = default; + ServerSocket(const ServerSocket&) noexcept = default; + ServerSocket(ServerSocket&&) noexcept = default; + + ServerSocket& operator=(const ServerSocket&) noexcept = default; + ServerSocket& operator=(ServerSocket&&) noexcept = default; +public: + virtual ~ServerSocket() noexcept = default; + + virtual void close() noexcept = 0; + virtual Task>> c_waitForConnection() noexcept = 0; +}; + +class TCPStream : public Stream +{ +private: + int handle_ = -1; + bool async_ = false; +public: + StreamError readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) override; + StreamError writeRaw(std::span buffer) override; + mijin::Task c_readRaw(std::span buffer, const ReadOptions& options, std::size_t *outBytesRead) override; + mijin::Task c_writeRaw(std::span buffer) override; + std::size_t tell() override; + StreamError seek(std::intptr_t pos, SeekMode seekMode = SeekMode::ABSOLUTE) override; + void flush() override; + bool isAtEnd() override; + StreamFeatures getFeatures() override; + + StreamError open(const char* address, std::uint16_t port) noexcept; + void close() noexcept; + [[nodiscard]] bool isOpen() const noexcept { return handle_ >= 0; } +private: + void setAsync(bool async); + + friend class TCPServerSocket; +}; + +class TCPSocket : public Socket +{ +private: + TCPStream stream_; +public: + TCPStream& getStream() noexcept override; + + StreamError open(const char* address, std::uint16_t port) noexcept { return stream_.open(address, port); } + void close() noexcept { stream_.close(); } + [[nodiscard]] bool isOpen() const noexcept { return stream_.isOpen(); } + + friend class TCPServerSocket; +}; + +class TCPServerSocket : public ServerSocket +{ +private: + int handle_ = -1; +public: + StreamError setup(const char* address, std::uint16_t port) noexcept; + void close() noexcept override; + [[nodiscard]] bool isListening() const noexcept { return handle_ >= 0; } + + Task>> c_waitForConnection() noexcept override; +}; +} + +#endif // !defined(MIJIN_NET_SOCKET_HPP_INCLUDED)