From 0be34a845a484e6a14d55aff743b5ddcfb7e7e74 Mon Sep 17 00:00:00 2001 From: Patrick Wuttke Date: Thu, 22 Aug 2024 00:30:38 +0200 Subject: [PATCH] Added wrapper for openssl types. --- source/mijin/net/openssl_wrappers.hpp | 407 ++++++++++++++++++++++++++ source/mijin/net/ssl.cpp | 194 ++++++------ source/mijin/net/ssl.hpp | 58 +++- 3 files changed, 545 insertions(+), 114 deletions(-) create mode 100644 source/mijin/net/openssl_wrappers.hpp diff --git a/source/mijin/net/openssl_wrappers.hpp b/source/mijin/net/openssl_wrappers.hpp new file mode 100644 index 0000000..d0baaa7 --- /dev/null +++ b/source/mijin/net/openssl_wrappers.hpp @@ -0,0 +1,407 @@ + + +#pragma once + +#if !defined(MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED) +#define MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED 1 + +#include +#include + +#include +#include +#include + +#include "../debug/assert.hpp" +#include "../types/result.hpp" + +namespace ossl +{ +struct ErrorFrame +{ + std::string message; + std::string file; + std::string function; + std::string data; + unsigned long numeric = 0; + int line = 0; + int flags = 0; +}; + +struct [[nodiscard]] Error +{ + int sslError = SSL_ERROR_NONE; + std::vector frames; + + [[nodiscard]] + bool isSuccess() const noexcept { return sslError == SSL_ERROR_NONE; } + + static inline Error current(int sslError = -1) noexcept; + static inline Error current(SSL* handle, int result) noexcept { return current(SSL_get_error(handle, result)); } +}; +template +using Result = mijin::ResultBase; + +// callback typedefs +using verify_callback_t = int (*) (int, X509_STORE_CTX *); + +template +class Base +{ +protected: + using base_t = Base; + + THandle handle_ = nullptr; +protected: + explicit Base(THandle handle) noexcept : handle_(handle) {} +public: + Base() noexcept = default; + Base(const Base& other) noexcept : handle_(other.handle_) + { + if (handle_) + { + TActual::upReferences(handle_); + } + } + Base(Base&& other) noexcept : handle_(std::exchange(other.handle_, {})) {} + + ~Base() noexcept + { + static_cast(*this).free(); + } + + TActual& operator=(const Base& other) noexcept + { + if (this == &other) + { + return static_cast(*this); + } + static_cast(*this).free(); + handle_ = other.handle_; + if (handle_) + { + TActual::upReferences(handle_); + } + return static_cast(*this); + } + + TActual& operator=(Base&& other) noexcept + { + if (this == &other) + { + return static_cast(*this); + } + static_cast(*this).free(); + handle_ = std::exchange(other.handle_, {}); + return static_cast(*this); + } + auto operator<=>(const Base&) const noexcept = default; + operator bool() const noexcept { return static_cast(handle_); } + bool operator!() const noexcept { return !static_cast(handle_); } + + [[nodiscard]] + THandle getHandle() const noexcept { return handle_; } + + [[nodiscard]] + THandle releaseHandle() noexcept { return std::exchange(handle_, nullptr); } +}; + +class X509Store : public Base +{ +public: + using Base::Base; + Error create() noexcept + { + MIJIN_ASSERT(handle_ == nullptr, "X509 Store already created."); + ERR_clear_error(); + handle_ = X509_STORE_new(); + if (handle_ == nullptr) + { + return Error::current(); + } + return {}; + } + + void free() noexcept + { + if (handle_ != nullptr) + { + X509_STORE_free(handle_); + handle_ = nullptr; + } + } + + Error loadFile(const char* file) const noexcept + { + ERR_clear_error(); + if (!X509_STORE_load_file(handle_, file)) + { + return Error::current(); + } + return {}; + } + + static void upReferences(X509_STORE* handle) noexcept + { + X509_STORE_up_ref(handle); + } +}; + +class Context : public Base +{ +public: + Error create(const SSL_METHOD* method) noexcept + { + MIJIN_ASSERT(handle_ == nullptr, "Context already created."); + ERR_clear_error(); + handle_ = SSL_CTX_new(method); + if (handle_ == nullptr) + { + return Error::current(); + } + return {}; + } + + void free() noexcept + { + if (handle_ == nullptr) + { + return; + } + SSL_CTX_free(handle_); + handle_ = nullptr; + } + + void setVerify(int mode, verify_callback_t callback = nullptr) const noexcept + { + SSL_CTX_set_verify(handle_, mode, callback); + } + + void setCertStore(X509Store store) const noexcept + { + SSL_CTX_set_cert_store(handle_, store.releaseHandle()); + } + + Error setMinProtoVersion(int version) const noexcept + { + ERR_clear_error(); + if (!SSL_CTX_set_min_proto_version(handle_, version)) + { + return Error::current(); + } + return {}; + } + + static void upReferences(SSL_CTX* handle) noexcept + { + SSL_CTX_up_ref(handle); + } +}; + +class Bio : public Base +{ +public: + Error createPair(Bio& otherBio, std::size_t writeBuf = 0, std::size_t otherWriteBuf = 0) noexcept + { + MIJIN_ASSERT(handle_ == nullptr, "Ssl already created."); + MIJIN_ASSERT(otherBio.handle_ == nullptr, "Ssl already created."); + ERR_clear_error(); + if (!BIO_new_bio_pair(&handle_, writeBuf, &otherBio.handle_, otherWriteBuf)) + { + return Error::current(); + } + return {}; + } + + void free() noexcept + { + if (handle_ == nullptr) + { + return; + } + BIO_free_all(handle_); + handle_ = nullptr; + } + + [[nodiscard]] + std::size_t ctrlPending() const noexcept + { + return BIO_ctrl_pending(handle_); + } + + [[nodiscard]] + std::size_t ctrlWPending() const noexcept + { + return BIO_ctrl_wpending(handle_); + } + + Result write(const void* data, int length) const noexcept + { + ERR_clear_error(); + const int result = BIO_write(handle_, data, length); + if (result <= 0) + { + return Error::current(); + } + return result; + } + + Result read(void* data, int length) const noexcept + { + ERR_clear_error(); + const int result = BIO_read(handle_, data, length); + if (result <= 0) + { + return Error::current(); + } + return result; + } + + static void upReferences(BIO* handle) noexcept + { + BIO_up_ref(handle); + } +}; + +class Ssl : public Base +{ +public: + Error create(const Context& context) noexcept + { + MIJIN_ASSERT(handle_ == nullptr, "Ssl already created."); + ERR_clear_error(); + handle_ = SSL_new(context.getHandle()); + if (handle_ == nullptr) + { + return Error::current(); + } + return {}; + } + + void free() noexcept + { + if (handle_ == nullptr) + { + return; + } + SSL_free(handle_); + handle_ = nullptr; + } + + void setBio(Bio readBio, Bio writeBio) const noexcept + { + SSL_set_bio(handle_, readBio.releaseHandle(), writeBio.releaseHandle()); + } + + void setBio(Bio&& bio) const noexcept + { + BIO* bioHandle = bio.releaseHandle(); + SSL_set_bio(handle_, bioHandle, bioHandle); + } + + Error setTLSExtHostname(const char* hostname) const noexcept + { + ERR_clear_error(); + if (const int result = SSL_set_tlsext_host_name(handle_, hostname); result != 1) + { + return Error::current(handle_, result); + } + return {}; + } + + Error setHost(const char* hostname) const noexcept + { + ERR_clear_error(); + if (const int result = SSL_set1_host(handle_, hostname); result != 1) + { + return Error::current(handle_, result); + } + return {}; + } + + Error connect() const noexcept + { + ERR_clear_error(); + if (const int result = SSL_connect(handle_); result != 1) + { + return Error::current(handle_, result); + } + return {}; + } + + Error shutdown() const noexcept + { + ERR_clear_error(); + if (const int result = SSL_shutdown(handle_); result != 1) + { + if (result == 0) + { + return Error{.sslError = SSL_ERROR_WANT_WRITE}; // TODO? + } + return Error::current(handle_, result); + } + return {}; + } + + [[nodiscard]] + long getVerifyResult() const noexcept + { + return SSL_get_verify_result(handle_); + } + + Result write(const void* data, int length) const noexcept + { + ERR_clear_error(); + const int result = SSL_write(handle_, data, length); + if (result <= 0) + { + return Error::current(handle_, result); + } + return result; + } + + Result read(void* data, int length) const noexcept + { + ERR_clear_error(); + const int result = SSL_read(handle_, data, length); + if (result <= 0) + { + return Error::current(handle_, result); + } + return result; + } + + static void upReferences(SSL* handle) noexcept + { + SSL_up_ref(handle); + } +}; + +Error Error::current(int sslError_) noexcept +{ + Error error = { + .sslError = sslError_ + }; + const char* file = nullptr; + int line = 0; + const char* func = nullptr; + const char* data = nullptr; + int flags = 0; + + while (const unsigned long numeric = ERR_get_error_all(&file, &line, &func, &data, &flags)) + { + error.frames.push_back({ + .message = ERR_error_string(numeric, nullptr), + .file = file != nullptr ? file : "", + .function = func != nullptr ? func : "", + .data = data != nullptr ? data : "", + .line = line, + .flags = flags + }); + } + + return error; +} +} + +#endif // !defined(MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED) diff --git a/source/mijin/net/ssl.cpp b/source/mijin/net/ssl.cpp index 1a997e4..516f20d 100644 --- a/source/mijin/net/ssl.cpp +++ b/source/mijin/net/ssl.cpp @@ -2,7 +2,6 @@ #include "./ssl.hpp" #include -#include namespace mijin @@ -10,131 +9,103 @@ namespace mijin namespace { inline constexpr int BIO_BUFFER_SIZE = 4096; -SSL_CTX* getSSLContext(bool create = true) noexcept +ossl::Result getSSLContext() noexcept { - static SSL_CTX* context = nullptr; + static ossl::Context context; static std::mutex contextMutex; - if (create && context == nullptr) + if (!context) { const std::unique_lock contextLock(contextMutex); - if (context != nullptr) + + if (context) { - return context; + return &context; } - context = SSL_CTX_new(SSLv23_client_method()); - SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr); - if (!SSL_CTX_set_default_verify_paths(context) - || !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION)) + + ossl::Context newContext; + if (const ossl::Error error = newContext.create(SSLv23_client_method()); !error.isSuccess()) { - SSL_CTX_free(context); - context = nullptr; - return nullptr; + return error; } + newContext.setVerify(SSL_VERIFY_PEER); + + ossl::X509Store store; + if (const ossl::Error error = store.create(); !error.isSuccess()) + { + return error; + } + if (const ossl::Error error = store.loadFile("/etc/ssl/cert.pem"); !error.isSuccess()) + { + return error; + } + newContext.setCertStore(std::move(store)); + if (const ossl::Error error = newContext.setMinProtoVersion(TLS1_2_VERSION); !error.isSuccess()) + { + return error; + } + + context = std::move(newContext); } - return context; + return &context; } - -class SSLCleanupHelper -{ -public: - ~SSLCleanupHelper() noexcept - { - SSL_CTX* context = getSSLContext(false); - if (context != nullptr) - { - SSL_CTX_free(context); - } - } -} gSSLCleanupHelper; } StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept { MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open."); - SSL_CTX* context = getSSLContext(); - if (context == nullptr) + const ossl::Result contextResult = getSSLContext(); + if (contextResult.isError()) + { + // TODO: convert/print error + return StreamError::UNKNOWN_ERROR; + } + ossl::Context& context = *contextResult.getValue(); + + ossl::Ssl ssl; + if (const ossl::Error error = ssl.create(context); !error.isSuccess()) { return StreamError::UNKNOWN_ERROR; } - SSL* ssl = SSL_new(context); - if (ssl == nullptr) + ossl::Bio externalBio; + ossl::Bio internalBio; + if (const ossl::Error error = internalBio.createPair(externalBio, BIO_BUFFER_SIZE, BIO_BUFFER_SIZE); !error.isSuccess()) + { + return StreamError::UNKNOWN_ERROR; + } + ssl.setBio(std::move(internalBio)); + + if (const ossl::Error error = ssl.setTLSExtHostname(hostname.c_str()); !error.isSuccess()) { return StreamError::UNKNOWN_ERROR; } - BIO* bioA; - BIO* bioB; - if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0)) + if (const ossl::Error error = ssl.setHost(hostname.c_str()); !error.isSuccess()) { - SSL_free(ssl); - return StreamError::UNKNOWN_ERROR; - } - SSL_set_bio(ssl, bioB, bioB); - - if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) - { - SSL_free(ssl); - BIO_free_all(bioA); return StreamError::UNKNOWN_ERROR; } - if (!SSL_set1_host(ssl, hostname.c_str())) - { - SSL_free(ssl); - BIO_free_all(bioA); - return StreamError::UNKNOWN_ERROR; - } - - ssl_ = ssl; - bioA_ = bioA; - bioB_ = bioB; + // these need to be initialized for connecting + ssl_ = std::move(ssl); + externalBio_ = std::move(externalBio); base_ = &base; - while(true) + + if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess()) { - if (const int result = SSL_connect(ssl); result < 1) - { - const int err = SSL_get_error(ssl, result); - [[maybe_unused]] int rrA = BIO_get_read_request(bioA); - [[maybe_unused]] int rrB = BIO_get_read_request(bioB); - [[maybe_unused]] int wgA = BIO_get_write_guarantee(bioA); - [[maybe_unused]] int wgB = BIO_get_write_guarantee(bioB); - switch (err) - { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - if(const StreamError error = baseToBio(); error != StreamError::SUCCESS) - { - SSL_free(ssl); - BIO_free_all(bioA); - return error; - } - if (const StreamError error = bioToBase(); error != StreamError::SUCCESS) - { - SSL_free(ssl); - BIO_free_all(bioA); - return error; - } - break; - default: - SSL_free(ssl); - BIO_free_all(bioA); - return StreamError::UNKNOWN_ERROR; - } - } - else - { - break; - } + ssl_.free(); + externalBio.free(); + base_ = nullptr; + return StreamError::UNKNOWN_ERROR; // TODO: translate } - if (SSL_get_verify_result(ssl) != X509_V_OK) + if (ssl_.getVerifyResult() != X509_V_OK) { - SSL_free(ssl); - BIO_free_all(bioA); + ssl_.free(); + externalBio.free(); + base_ = nullptr; return StreamError::UNKNOWN_ERROR; } @@ -145,9 +116,9 @@ void SSLStream::close() noexcept { MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open."); base_ = nullptr; - SSL_shutdown(static_cast(ssl_)); - SSL_free(static_cast(ssl_)); - BIO_free_all(static_cast(bioA_)); + (void) runIOLoop(&ossl::Ssl::shutdown); + ssl_.free(); + externalBio_.free(); } StreamError SSLStream::writeRaw(std::span buffer) @@ -157,16 +128,16 @@ StreamError SSLStream::writeRaw(std::span buffer) return StreamError::SUCCESS; } - SSL* ssl = static_cast(ssl_); std::size_t bytesToWrite = buffer.size(); while (bytesToWrite > 0) { - const int result = SSL_write(ssl, buffer.data() + buffer.size() - bytesToWrite, static_cast(bytesToWrite)); - if (result <= 0) + const ossl::Result result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite, + static_cast(bytesToWrite)); + if (result.isError()) { return StreamError::UNKNOWN_ERROR; } - bytesToWrite -= result; + bytesToWrite -= result.getValue(); if (const StreamError error = bioToBase(); error != StreamError::SUCCESS) { @@ -184,7 +155,6 @@ StreamError SSLStream::readRaw(std::span buffer, const mijin::Read { return StreamError::SUCCESS; } - SSL* ssl = static_cast(ssl_); std::size_t bytesToRead = buffer.size(); while (bytesToRead > 0) @@ -193,11 +163,15 @@ StreamError SSLStream::readRaw(std::span buffer, const mijin::Read { return error; } - const int result = SSL_read(ssl, buffer.data() + buffer.size() - bytesToRead, static_cast(bytesToRead)); - if (result <= 0) + + const ossl::Result result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead, + static_cast(bytesToRead)); + if (result.isError()) { return StreamError::UNKNOWN_ERROR; } + + bytesToRead -= result.getValue(); } // TODO: options and outBytesRead @@ -251,13 +225,12 @@ StreamFeatures SSLStream::getFeatures() StreamError SSLStream::bioToBase() noexcept { - BIO* bio = static_cast(bioA_); std::array buffer; - std::size_t bytes = std::min(BIO_ctrl_pending(bio), buffer.size()); + std::size_t bytes = std::min(externalBio_.ctrlPending(), buffer.size()); while (bytes > 0) { - const int result = BIO_read(bio, buffer.data(), static_cast(bytes)); - if (result <= 0) + const ossl::Result result = externalBio_.read(buffer.data(), static_cast(bytes)); + if (result.isError()) { return StreamError::UNKNOWN_ERROR; } @@ -266,19 +239,18 @@ StreamError SSLStream::bioToBase() noexcept return error; } - bytes = BIO_ctrl_pending(bio); + bytes = externalBio_.ctrlPending(); } return StreamError::SUCCESS; } StreamError SSLStream::baseToBio() noexcept { - BIO* bio = static_cast(bioA_); std::array buffer; while (true) { std::size_t bytes = 0; - std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - BIO_ctrl_wpending(bio), buffer.size()); + std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size()); if (maxBytes == 0) { @@ -296,12 +268,12 @@ StreamError SSLStream::baseToBio() noexcept // nothing more to read return StreamError::SUCCESS; } - const int result = BIO_write(bio, buffer.data(), static_cast(bytes)); - if (result <= 0) + const ossl::Result result = externalBio_.write(buffer.data(), static_cast(bytes)); + if (result.isError()) { return StreamError::UNKNOWN_ERROR; } - MIJIN_ASSERT(result == static_cast(bytes), "BIO reported more bytes in buffer than it actually accepted?"); + MIJIN_ASSERT(result.getValue() == static_cast(bytes), "BIO reported more bytes in buffer than it actually accepted?"); } } } // namespace shiken diff --git a/source/mijin/net/ssl.hpp b/source/mijin/net/ssl.hpp index 47f786b..7750699 100644 --- a/source/mijin/net/ssl.hpp +++ b/source/mijin/net/ssl.hpp @@ -10,6 +10,7 @@ #endif // !MIJIN_ENABLE_OPENSSL #include +#include "./openssl_wrappers.hpp" #include "../io/stream.hpp" namespace mijin @@ -18,10 +19,16 @@ class SSLStream : public Stream { private: Stream* base_ = nullptr; - void* ssl_ = nullptr; - void* bioA_ = nullptr; - void* bioB_ = nullptr; + ossl::Ssl ssl_; + ossl::Bio externalBio_; public: + ~SSLStream() noexcept override + { + if (base_ != nullptr) + { + close(); + } + } StreamError open(Stream& base, const std::string& hostname) noexcept; void close() noexcept; @@ -35,6 +42,51 @@ public: private: StreamError bioToBase() noexcept; StreamError baseToBio() noexcept; + + + template + auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t> + { + using result_t = std::decay_t>; + while (true) + { + auto result = std::invoke(std::forward(func), ssl_, std::forward(args)...); + ossl::Error error; + if constexpr (std::is_same_v) + { + if (error.isSuccess()) + { + return error; + } + error = result; + } + else + { + // assume result type + if (result.isSuccess()) + { + return result; + } + error = result.getError(); + } + switch (error.sslError) + { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + if(const StreamError streamError = baseToBio(); streamError != StreamError::SUCCESS) + { + return error; + } + if (const StreamError streamError = bioToBase(); streamError != StreamError::SUCCESS) + { + return error; + } + break; + default: + return error; + } + } + } // mijin::Task c_readRaw(std::span buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; // mijin::Task c_writeRaw(std::span buffer) override; };