From a43f92fb581b7aac0a5f202540376803308f8c0f Mon Sep 17 00:00:00 2001 From: Patrick Wuttke Date: Tue, 27 Aug 2024 19:52:08 +0200 Subject: [PATCH] More SSL stuff (still doesn't work :/). --- source/mijin/io/stream.cpp | 42 ++++++++++++++++++++++- source/mijin/io/stream.hpp | 3 ++ source/mijin/net/http.cpp | 40 ++++++++++++++++++---- source/mijin/net/http.hpp | 3 +- source/mijin/net/openssl_wrappers.hpp | 34 +++++++++++++++++++ source/mijin/net/ssl.cpp | 49 +++++++++++++++++++-------- source/mijin/net/ssl.hpp | 13 ++++--- 7 files changed, 156 insertions(+), 28 deletions(-) diff --git a/source/mijin/io/stream.cpp b/source/mijin/io/stream.cpp index 46b6784..2bb7ea1 100644 --- a/source/mijin/io/stream.cpp +++ b/source/mijin/io/stream.cpp @@ -36,7 +36,47 @@ void Stream::flush() {} mijin::Task Stream::c_readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) { - co_return readRaw(buffer, options, outBytesRead); + std::size_t bytesToRead = buffer.size(); + if (bytesToRead == 0) + { + co_return StreamError::SUCCESS; + } + + while(true) + { + bool done = false; + std::size_t bytesRead = 0; + const StreamError error = readRaw(buffer.data() + buffer.size() - bytesToRead, bytesToRead, + {.partial = true, .noBlock = true}, &bytesToRead); + switch (error) + { + case StreamError::SUCCESS: + bytesToRead -= bytesRead; + if (options.partial || bytesToRead == 0) + { + done = true; + } + break; + case StreamError::WOULD_BLOCK: + if (options.noBlock) + { + co_return StreamError::WOULD_BLOCK; + } + break; + default: + co_return error; + } + if (done) + { + break; + } + co_await mijin::c_suspend(); + } + if (outBytesRead != nullptr) + { + *outBytesRead = buffer.size() - bytesToRead; + } + co_return StreamError::SUCCESS; } mijin::Task Stream::c_writeRaw(std::span buffer) diff --git a/source/mijin/io/stream.hpp b/source/mijin/io/stream.hpp index 59276d2..b47604a 100644 --- a/source/mijin/io/stream.hpp +++ b/source/mijin/io/stream.hpp @@ -70,6 +70,7 @@ enum class [[nodiscard]] StreamError NOT_SUPPORTED = 2, CONNECTION_CLOSED = 3, PROTOCOL_ERROR = 4, + WOULD_BLOCK = 5, UNKNOWN_ERROR = -1 }; @@ -434,6 +435,8 @@ inline const char* errorName(StreamError error) noexcept return "connection closed"; case StreamError::PROTOCOL_ERROR: return "protocol error"; + case StreamError::WOULD_BLOCK: + return "would block"; case StreamError::UNKNOWN_ERROR: return "unknown error"; } diff --git a/source/mijin/net/http.cpp b/source/mijin/net/http.cpp index fcef3d9..07997fc 100644 --- a/source/mijin/net/http.cpp +++ b/source/mijin/net/http.cpp @@ -6,6 +6,10 @@ #include "../util/iterators.hpp" #include "../util/string.hpp" +#if defined(MIJIN_ENABLE_OPENSSL) +#include "./ssl.hpp" +#endif + #define MIJIN_HTTP_WRITE(text) \ do \ { \ @@ -160,7 +164,12 @@ Task> HTTPStream::c_readResponse() noexcept Task> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https, HTTPRequest request) noexcept { - if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS) + std::string hostname; + if (auto it = request.headers.find("host"); it != request.headers.end()) + { + hostname = it->second; + } + if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS) { co_return error; } @@ -172,7 +181,7 @@ Task> HTTPClient::c_request(ip_address_t address, std if (response.isError()) { disconnect(); - if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS) + if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS) { co_return error; } @@ -242,7 +251,7 @@ void HTTPClient::disconnect() noexcept socket_ = nullptr; } -StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept +StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept { if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_) { @@ -250,17 +259,36 @@ StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, b } disconnect(); - MIJIN_ASSERT(!https, "HTTPS not supported yet."); + std::unique_ptr newSocket = std::make_unique(); if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS) { return error; } + socket_ = std::move(newSocket); + if (!https) + { + sslStream_.reset(); + stream_.construct(socket_->getStream()); + } + else + { +#if defined(MIJIN_ENABLE_OPENSSL) + std::unique_ptr sslStream = std::make_unique(); + if (const StreamError error = sslStream->open(socket_->getStream(), hostname); error != StreamError::SUCCESS) + { + return error; + } + sslStream_ = std::move(sslStream); + stream_.construct(*sslStream_); +#else + return StreamError::NOT_SUPPORTED; +#endif + } + lastIP_ = address; lastPort_ = port; lastWasHttps_ = https; - socket_ = std::move(newSocket); - stream_.construct(socket_->getStream()); return StreamError::SUCCESS; } } diff --git a/source/mijin/net/http.hpp b/source/mijin/net/http.hpp index dacda00..4fc9964 100644 --- a/source/mijin/net/http.hpp +++ b/source/mijin/net/http.hpp @@ -57,6 +57,7 @@ class HTTPClient { private: std::unique_ptr socket_; + std::unique_ptr sslStream_; mijin::BoxedObject stream_; ip_address_t lastIP_; std::uint16_t lastPort_ = 0; @@ -67,7 +68,7 @@ public: Task> c_request(const URL& url, HTTPRequest request = {}) noexcept; void disconnect() noexcept; private: - StreamError createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept; + StreamError createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept; }; } diff --git a/source/mijin/net/openssl_wrappers.hpp b/source/mijin/net/openssl_wrappers.hpp index d0baaa7..b1f53cb 100644 --- a/source/mijin/net/openssl_wrappers.hpp +++ b/source/mijin/net/openssl_wrappers.hpp @@ -257,6 +257,34 @@ public: return result; } + [[nodiscard]] + int getReadRequest() const noexcept + { + return BIO_get_read_request(handle_); + } + + [[nodiscard]] + int getWriteGuarantee() const noexcept + { + return BIO_get_write_guarantee(handle_); + } + + [[nodiscard]] + int getWritePending() const noexcept + { + return BIO_wpending(handle_); + } + + Error flush() const noexcept + { + ERR_clear_error(); + if (!BIO_flush(handle_)) + { + return Error::current(); + } + return {}; + } + static void upReferences(BIO* handle) noexcept { BIO_up_ref(handle); @@ -371,6 +399,12 @@ public: return result; } + [[nodiscard]] + int pending() const noexcept + { + return SSL_pending(handle_); + } + static void upReferences(SSL* handle) noexcept { SSL_up_ref(handle); diff --git a/source/mijin/net/ssl.cpp b/source/mijin/net/ssl.cpp index 516f20d..c4059aa 100644 --- a/source/mijin/net/ssl.cpp +++ b/source/mijin/net/ssl.cpp @@ -93,7 +93,7 @@ StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept externalBio_ = std::move(externalBio); base_ = &base; - if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess()) + if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess()) { ssl_.free(); externalBio.free(); @@ -116,7 +116,7 @@ void SSLStream::close() noexcept { MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open."); base_ = nullptr; - (void) runIOLoop(&ossl::Ssl::shutdown); + (void) runIOLoop(&ossl::Ssl::shutdown, true); ssl_.free(); externalBio_.free(); } @@ -131,7 +131,7 @@ StreamError SSLStream::writeRaw(std::span buffer) std::size_t bytesToWrite = buffer.size(); while (bytesToWrite > 0) { - const ossl::Result result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite, + const ossl::Result result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + buffer.size() - bytesToWrite, static_cast(bytesToWrite)); if (result.isError()) { @@ -156,15 +156,25 @@ StreamError SSLStream::readRaw(std::span buffer, const mijin::Read return StreamError::SUCCESS; } + if (const StreamError error = baseToBio(); error != StreamError::SUCCESS) + { + return error; + } + + if (!options.partial && options.noBlock && static_cast(ssl_.pending()) < buffer.size()) + { + return StreamError::WOULD_BLOCK; + } + std::size_t bytesToRead = buffer.size(); + if (options.partial) + { + bytesToRead = std::min(bytesToRead, ssl_.pending()); + } while (bytesToRead > 0) { - if (const StreamError error = baseToBio(); error != StreamError::SUCCESS) - { - return error; - } - - const ossl::Result result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead, + const ossl::Result result = runIOLoop(&ossl::Ssl::read, !options.noBlock, + buffer.data() + buffer.size() - bytesToRead, static_cast(bytesToRead)); if (result.isError()) { @@ -216,9 +226,9 @@ StreamFeatures SSLStream::getFeatures() .tell = false, .seek = false, .readOptions = { - .partial = false, - .peek = false, - .noBlock = false + .partial = true, + .peek = true, + .noBlock = true } }; } @@ -247,18 +257,20 @@ StreamError SSLStream::bioToBase() noexcept StreamError SSLStream::baseToBio() noexcept { std::array buffer; + std::size_t toRead = externalBio_.getWriteGuarantee(); + while (true) { std::size_t bytes = 0; - std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size()); + std::size_t maxBytes = std::min(toRead, buffer.size()); if (maxBytes == 0) { // buffer is full - return StreamError::SUCCESS; + break; } - if (const StreamError error = base_->readRaw(buffer.data(), maxBytes,{.partial = true, .noBlock = true}, + if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true}, &bytes); error != StreamError::SUCCESS) { return error; @@ -274,6 +286,13 @@ StreamError SSLStream::baseToBio() noexcept return StreamError::UNKNOWN_ERROR; } MIJIN_ASSERT(result.getValue() == static_cast(bytes), "BIO reported more bytes in buffer than it actually accepted?"); + toRead -= bytes; } + + if (const ossl::Error error = externalBio_.flush(); !error.isSuccess()) + { + return StreamError::UNKNOWN_ERROR; + } + return StreamError::SUCCESS; } } // namespace shiken diff --git a/source/mijin/net/ssl.hpp b/source/mijin/net/ssl.hpp index 7750699..360ee0d 100644 --- a/source/mijin/net/ssl.hpp +++ b/source/mijin/net/ssl.hpp @@ -5,8 +5,8 @@ #if !defined(MIJIN_NET_SSL_HPP_INCLUDED) #define MIJIN_NET_SSL_HPP_INCLUDED 1 -#if !MIJIN_ENABLE_OPENSSL -#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings." +#if !defined(MIJIN_ENABLE_OPENSSL) +#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings." #endif // !MIJIN_ENABLE_OPENSSL #include @@ -45,13 +45,15 @@ private: template - auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t> + auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t> { using result_t = std::decay_t>; - while (true) + const std::size_t maxTries = block ? std::numeric_limits::max() : 10; + + ossl::Error error; + for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum) { auto result = std::invoke(std::forward(func), ssl_, std::forward(args)...); - ossl::Error error; if constexpr (std::is_same_v) { if (error.isSuccess()) @@ -86,6 +88,7 @@ private: return error; } } + return error; } // mijin::Task c_readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; // mijin::Task c_writeRaw(std::span buffer) override;