From f761f2fb078c7734b05762b9b38805230a8a717b Mon Sep 17 00:00:00 2001 From: Patrick Wuttke Date: Wed, 21 Aug 2024 09:35:49 +0200 Subject: [PATCH] SSLStream (WIP) --- SModule | 8 + dependencies.json | 4 + source/mijin/io/stream.hpp | 1 + source/mijin/net/http.cpp | 13 +- source/mijin/net/ip.hpp | 12 ++ source/mijin/net/socket.cpp | 25 ++- source/mijin/net/socket.hpp | 2 +- source/mijin/net/ssl.cpp | 307 ++++++++++++++++++++++++++++++++++++ source/mijin/net/ssl.hpp | 44 ++++++ 9 files changed, 406 insertions(+), 10 deletions(-) create mode 100644 source/mijin/net/ssl.cpp create mode 100644 source/mijin/net/ssl.hpp diff --git a/SModule b/SModule index 57ab3f2..30b6d6f 100644 --- a/SModule +++ b/SModule @@ -27,6 +27,14 @@ if env['BUILD_TYPE'] == 'debug': cppdefines += ['MIJIN_DEBUG=1', 'MIJIN_CHECKED_ITERATORS=1'] +# SSL libs +if env.get('MIJIN_ENABLE_OPENSSL'): + cppdefines.append('MIJIN_ENABLE_OPENSSL=1') + mijin_sources.extend(Split(""" + source/mijin/net/ssl.cpp + """)) + + lib_mijin = env.UnityStaticLibrary( target = env['LIB_DIR'] + '/mijin', source = mijin_sources, diff --git a/dependencies.json b/dependencies.json index e2931e7..ed61f6d 100644 --- a/dependencies.json +++ b/dependencies.json @@ -6,5 +6,9 @@ "winsock2": { "condition": "target_os == 'nt'" + }, + "openssl": + { + "condition": "getenv('MIJIN_ENABLE_OPENSSL')" } } diff --git a/source/mijin/io/stream.hpp b/source/mijin/io/stream.hpp index 22d7094..59276d2 100644 --- a/source/mijin/io/stream.hpp +++ b/source/mijin/io/stream.hpp @@ -41,6 +41,7 @@ struct ReadOptions { bool partial : 1 = false; bool peek : 1 = false; + bool noBlock : 1 = false; }; struct StreamFeatures diff --git a/source/mijin/net/http.cpp b/source/mijin/net/http.cpp index 60ac471..fcef3d9 100644 --- a/source/mijin/net/http.cpp +++ b/source/mijin/net/http.cpp @@ -204,10 +204,19 @@ Task> HTTPClient::c_request(const URL& url, HTTPReque co_return StreamError::UNKNOWN_ERROR; } Optional ipAddress = ipAddressFromString(url.getHost()); - // TODO: lookup host if (ipAddress.empty()) { - co_return StreamError::UNKNOWN_ERROR; + StreamResult> addresses = co_await c_resolveHostname(url.getHost()); + if (addresses.isError()) + { + co_return addresses.getError(); + } + else if (addresses->empty()) + { + co_return StreamError::UNKNOWN_ERROR; + } + // TODO: try all addresses + ipAddress = addresses->front(); } if (!request.headers.contains("host")) diff --git a/source/mijin/net/ip.hpp b/source/mijin/net/ip.hpp index c4faa12..07ac797 100644 --- a/source/mijin/net/ip.hpp +++ b/source/mijin/net/ip.hpp @@ -62,6 +62,18 @@ inline Optional ipAddressFromString(std::string_view stringView) n [[nodiscard]] Task>> c_resolveHostname(std::string hostname) noexcept; + +[[nodiscard]] +inline Task>> c_resolveHostname(std::string_view hostname) noexcept +{ + return c_resolveHostname(std::string(hostname.begin(), hostname.end())); +} + +[[nodiscard]] +inline Task>> c_resolveHostname(const char* hostname) noexcept +{ + return c_resolveHostname(std::string(hostname)); +} } #endif // !defined(MIJIN_NET_IP_HPP_INCLUDED) diff --git a/source/mijin/net/socket.cpp b/source/mijin/net/socket.cpp index e1e3531..617bf10 100644 --- a/source/mijin/net/socket.cpp +++ b/source/mijin/net/socket.cpp @@ -183,14 +183,25 @@ StreamError translateWinError() noexcept StreamError TCPStream::readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) { MIJIN_ASSERT(isOpen(), "Socket is not open."); - setAsync(false); + setNoblock(options.noBlock); const long bytesRead = osRecv(handle_, buffer, readFlags(options)); if (bytesRead < 0) { - return translateErrno(); + if (!options.noBlock || errno != EAGAIN) + { + return translateErrno(); + } + if (outBytesRead != nullptr) + { + *outBytesRead = 0; + } + return StreamError::SUCCESS; + } + if (outBytesRead != nullptr) + { + *outBytesRead = static_cast(bytesRead); } - *outBytesRead = static_cast(bytesRead); return StreamError::SUCCESS; } @@ -198,7 +209,7 @@ StreamError TCPStream::readRaw(std::span buffer, const ReadOptions StreamError TCPStream::writeRaw(std::span buffer) { MIJIN_ASSERT(isOpen(), "Socket is not open."); - setAsync(false); + setNoblock(false); if (osSend(handle_, buffer, 0) < 0) { @@ -211,7 +222,7 @@ StreamError TCPStream::writeRaw(std::span buffer) mijin::Task TCPStream::c_readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) { MIJIN_ASSERT(isOpen(), "Socket is not open."); - setAsync(true); + setNoblock(true); if (buffer.empty()) { @@ -249,7 +260,7 @@ Task TCPStream::c_writeRaw(std::span buffer) co_return StreamError::SUCCESS; } - setAsync(true); + setNoblock(true); while (true) { @@ -270,7 +281,7 @@ Task TCPStream::c_writeRaw(std::span buffer) } } -void TCPStream::setAsync(bool async) +void TCPStream::setNoblock(bool async) { if (async == async_) { diff --git a/source/mijin/net/socket.hpp b/source/mijin/net/socket.hpp index 39672e9..8cf35ff 100644 --- a/source/mijin/net/socket.hpp +++ b/source/mijin/net/socket.hpp @@ -78,7 +78,7 @@ public: void close() noexcept; [[nodiscard]] bool isOpen() const noexcept { return handle_ != INVALID_SOCKET_HANDLE; } private: - void setAsync(bool async); + void setNoblock(bool async); friend class TCPServerSocket; }; diff --git a/source/mijin/net/ssl.cpp b/source/mijin/net/ssl.cpp new file mode 100644 index 0000000..1a997e4 --- /dev/null +++ b/source/mijin/net/ssl.cpp @@ -0,0 +1,307 @@ + +#include "./ssl.hpp" + +#include +#include + + +namespace mijin +{ +namespace +{ +inline constexpr int BIO_BUFFER_SIZE = 4096; +SSL_CTX* getSSLContext(bool create = true) noexcept +{ + static SSL_CTX* context = nullptr; + static std::mutex contextMutex; + + if (create && context == nullptr) + { + const std::unique_lock contextLock(contextMutex); + if (context != nullptr) + { + return context; + } + context = SSL_CTX_new(SSLv23_client_method()); + SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr); + if (!SSL_CTX_set_default_verify_paths(context) + || !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION)) + { + SSL_CTX_free(context); + context = nullptr; + return nullptr; + } + } + + return context; +} + +class SSLCleanupHelper +{ +public: + ~SSLCleanupHelper() noexcept + { + SSL_CTX* context = getSSLContext(false); + if (context != nullptr) + { + SSL_CTX_free(context); + } + } +} gSSLCleanupHelper; +} + +StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept +{ + MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open."); + + SSL_CTX* context = getSSLContext(); + if (context == nullptr) + { + return StreamError::UNKNOWN_ERROR; + } + + SSL* ssl = SSL_new(context); + if (ssl == nullptr) + { + return StreamError::UNKNOWN_ERROR; + } + + BIO* bioA; + BIO* bioB; + if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0)) + { + SSL_free(ssl); + return StreamError::UNKNOWN_ERROR; + } + SSL_set_bio(ssl, bioB, bioB); + + if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) + { + SSL_free(ssl); + BIO_free_all(bioA); + return StreamError::UNKNOWN_ERROR; + } + + if (!SSL_set1_host(ssl, hostname.c_str())) + { + SSL_free(ssl); + BIO_free_all(bioA); + return StreamError::UNKNOWN_ERROR; + } + + ssl_ = ssl; + bioA_ = bioA; + bioB_ = bioB; + base_ = &base; + while(true) + { + if (const int result = SSL_connect(ssl); result < 1) + { + const int err = SSL_get_error(ssl, result); + [[maybe_unused]] int rrA = BIO_get_read_request(bioA); + [[maybe_unused]] int rrB = BIO_get_read_request(bioB); + [[maybe_unused]] int wgA = BIO_get_write_guarantee(bioA); + [[maybe_unused]] int wgB = BIO_get_write_guarantee(bioB); + switch (err) + { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + if(const StreamError error = baseToBio(); error != StreamError::SUCCESS) + { + SSL_free(ssl); + BIO_free_all(bioA); + return error; + } + if (const StreamError error = bioToBase(); error != StreamError::SUCCESS) + { + SSL_free(ssl); + BIO_free_all(bioA); + return error; + } + break; + default: + SSL_free(ssl); + BIO_free_all(bioA); + return StreamError::UNKNOWN_ERROR; + } + } + else + { + break; + } + } + + if (SSL_get_verify_result(ssl) != X509_V_OK) + { + SSL_free(ssl); + BIO_free_all(bioA); + return StreamError::UNKNOWN_ERROR; + } + + return StreamError::SUCCESS; +} + +void SSLStream::close() noexcept +{ + MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open."); + base_ = nullptr; + SSL_shutdown(static_cast(ssl_)); + SSL_free(static_cast(ssl_)); + BIO_free_all(static_cast(bioA_)); +} + +StreamError SSLStream::writeRaw(std::span buffer) +{ + if (buffer.empty()) + { + return StreamError::SUCCESS; + } + + SSL* ssl = static_cast(ssl_); + std::size_t bytesToWrite = buffer.size(); + while (bytesToWrite > 0) + { + const int result = SSL_write(ssl, buffer.data() + buffer.size() - bytesToWrite, static_cast(bytesToWrite)); + if (result <= 0) + { + return StreamError::UNKNOWN_ERROR; + } + bytesToWrite -= result; + + if (const StreamError error = bioToBase(); error != StreamError::SUCCESS) + { + return error; + } + } + + return StreamError::SUCCESS; +} + +StreamError SSLStream::readRaw(std::span buffer, const mijin::ReadOptions& options, + std::size_t* outBytesRead) +{ + if (buffer.empty()) + { + return StreamError::SUCCESS; + } + SSL* ssl = static_cast(ssl_); + + std::size_t bytesToRead = buffer.size(); + while (bytesToRead > 0) + { + if (const StreamError error = baseToBio(); error != StreamError::SUCCESS) + { + return error; + } + const int result = SSL_read(ssl, buffer.data() + buffer.size() - bytesToRead, static_cast(bytesToRead)); + if (result <= 0) + { + return StreamError::UNKNOWN_ERROR; + } + } + + // TODO: options and outBytesRead + (void) options; + if (outBytesRead != nullptr) + { + *outBytesRead = buffer.size(); + } + + return StreamError::SUCCESS; +} + +std::size_t SSLStream::tell() +{ + MIJIN_ERROR("SSLStream does not support tell()."); + return 0; +} + +StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode) +{ + MIJIN_ERROR("SSLStream does not support tell()."); + (void) pos; + (void) seekMode; + return StreamError::NOT_SUPPORTED; +} + +void SSLStream::flush() +{ + base_->flush(); +} + +bool SSLStream::isAtEnd() +{ + return base_->isAtEnd(); +} + +StreamFeatures SSLStream::getFeatures() +{ + return { + .read = true, + .write = true, + .tell = false, + .seek = false, + .readOptions = { + .partial = false, + .peek = false, + .noBlock = false + } + }; +} + +StreamError SSLStream::bioToBase() noexcept +{ + BIO* bio = static_cast(bioA_); + std::array buffer; + std::size_t bytes = std::min(BIO_ctrl_pending(bio), buffer.size()); + while (bytes > 0) + { + const int result = BIO_read(bio, buffer.data(), static_cast(bytes)); + if (result <= 0) + { + return StreamError::UNKNOWN_ERROR; + } + if (const StreamError error = base_->writeRaw(buffer.data(), result); error != StreamError::SUCCESS) + { + return error; + } + + bytes = BIO_ctrl_pending(bio); + } + return StreamError::SUCCESS; +} + +StreamError SSLStream::baseToBio() noexcept +{ + BIO* bio = static_cast(bioA_); + std::array buffer; + while (true) + { + std::size_t bytes = 0; + std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - BIO_ctrl_wpending(bio), buffer.size()); + + if (maxBytes == 0) + { + // buffer is full + return StreamError::SUCCESS; + } + + if (const StreamError error = base_->readRaw(buffer.data(), maxBytes,{.partial = true, .noBlock = true}, + &bytes); error != StreamError::SUCCESS) + { + return error; + } + if (bytes == 0) + { + // nothing more to read + return StreamError::SUCCESS; + } + const int result = BIO_write(bio, buffer.data(), static_cast(bytes)); + if (result <= 0) + { + return StreamError::UNKNOWN_ERROR; + } + MIJIN_ASSERT(result == static_cast(bytes), "BIO reported more bytes in buffer than it actually accepted?"); + } +} +} // namespace shiken diff --git a/source/mijin/net/ssl.hpp b/source/mijin/net/ssl.hpp new file mode 100644 index 0000000..47f786b --- /dev/null +++ b/source/mijin/net/ssl.hpp @@ -0,0 +1,44 @@ + + +#pragma once + +#if !defined(MIJIN_NET_SSL_HPP_INCLUDED) +#define MIJIN_NET_SSL_HPP_INCLUDED 1 + +#if !MIJIN_ENABLE_OPENSSL +#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings." +#endif // !MIJIN_ENABLE_OPENSSL + +#include +#include "../io/stream.hpp" + +namespace mijin +{ +class SSLStream : public Stream +{ +private: + Stream* base_ = nullptr; + void* ssl_ = nullptr; + void* bioA_ = nullptr; + void* bioB_ = nullptr; +public: + StreamError open(Stream& base, const std::string& hostname) noexcept; + void close() noexcept; + + StreamError readRaw(std::span buffer, const ReadOptions& options, std::size_t* outBytesRead) override; + StreamError writeRaw(std::span buffer) override; + std::size_t tell() override; + StreamError seek(std::intptr_t pos, SeekMode seekMode) override; + void flush() override; + bool isAtEnd() override; + StreamFeatures getFeatures() override; +private: + StreamError bioToBase() noexcept; + StreamError baseToBio() noexcept; + // mijin::Task c_readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; + // mijin::Task c_writeRaw(std::span buffer) override; +}; +} + +#endif // !defined(MIJIN_NET_SSL_HPP_INCLUDED) +