Added wrapper for openssl types.
This commit is contained in:
parent
f761f2fb07
commit
0be34a845a
407
source/mijin/net/openssl_wrappers.hpp
Normal file
407
source/mijin/net/openssl_wrappers.hpp
Normal file
@ -0,0 +1,407 @@
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED)
|
||||
#define MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED 1
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/x509_vfy.h>
|
||||
|
||||
#include "../debug/assert.hpp"
|
||||
#include "../types/result.hpp"
|
||||
|
||||
namespace ossl
|
||||
{
|
||||
struct ErrorFrame
|
||||
{
|
||||
std::string message;
|
||||
std::string file;
|
||||
std::string function;
|
||||
std::string data;
|
||||
unsigned long numeric = 0;
|
||||
int line = 0;
|
||||
int flags = 0;
|
||||
};
|
||||
|
||||
struct [[nodiscard]] Error
|
||||
{
|
||||
int sslError = SSL_ERROR_NONE;
|
||||
std::vector<ErrorFrame> frames;
|
||||
|
||||
[[nodiscard]]
|
||||
bool isSuccess() const noexcept { return sslError == SSL_ERROR_NONE; }
|
||||
|
||||
static inline Error current(int sslError = -1) noexcept;
|
||||
static inline Error current(SSL* handle, int result) noexcept { return current(SSL_get_error(handle, result)); }
|
||||
};
|
||||
template<typename TSuccess>
|
||||
using Result = mijin::ResultBase<TSuccess, Error>;
|
||||
|
||||
// callback typedefs
|
||||
using verify_callback_t = int (*) (int, X509_STORE_CTX *);
|
||||
|
||||
template<typename TActual, typename THandle>
|
||||
class Base
|
||||
{
|
||||
protected:
|
||||
using base_t = Base<TActual, THandle>;
|
||||
|
||||
THandle handle_ = nullptr;
|
||||
protected:
|
||||
explicit Base(THandle handle) noexcept : handle_(handle) {}
|
||||
public:
|
||||
Base() noexcept = default;
|
||||
Base(const Base& other) noexcept : handle_(other.handle_)
|
||||
{
|
||||
if (handle_)
|
||||
{
|
||||
TActual::upReferences(handle_);
|
||||
}
|
||||
}
|
||||
Base(Base&& other) noexcept : handle_(std::exchange(other.handle_, {})) {}
|
||||
|
||||
~Base() noexcept
|
||||
{
|
||||
static_cast<TActual&>(*this).free();
|
||||
}
|
||||
|
||||
TActual& operator=(const Base& other) noexcept
|
||||
{
|
||||
if (this == &other)
|
||||
{
|
||||
return static_cast<TActual&>(*this);
|
||||
}
|
||||
static_cast<TActual&>(*this).free();
|
||||
handle_ = other.handle_;
|
||||
if (handle_)
|
||||
{
|
||||
TActual::upReferences(handle_);
|
||||
}
|
||||
return static_cast<TActual&>(*this);
|
||||
}
|
||||
|
||||
TActual& operator=(Base&& other) noexcept
|
||||
{
|
||||
if (this == &other)
|
||||
{
|
||||
return static_cast<TActual&>(*this);
|
||||
}
|
||||
static_cast<TActual&>(*this).free();
|
||||
handle_ = std::exchange(other.handle_, {});
|
||||
return static_cast<TActual&>(*this);
|
||||
}
|
||||
auto operator<=>(const Base&) const noexcept = default;
|
||||
operator bool() const noexcept { return static_cast<bool>(handle_); }
|
||||
bool operator!() const noexcept { return !static_cast<bool>(handle_); }
|
||||
|
||||
[[nodiscard]]
|
||||
THandle getHandle() const noexcept { return handle_; }
|
||||
|
||||
[[nodiscard]]
|
||||
THandle releaseHandle() noexcept { return std::exchange(handle_, nullptr); }
|
||||
};
|
||||
|
||||
class X509Store : public Base<X509Store, X509_STORE*>
|
||||
{
|
||||
public:
|
||||
using Base::Base;
|
||||
Error create() noexcept
|
||||
{
|
||||
MIJIN_ASSERT(handle_ == nullptr, "X509 Store already created.");
|
||||
ERR_clear_error();
|
||||
handle_ = X509_STORE_new();
|
||||
if (handle_ == nullptr)
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void free() noexcept
|
||||
{
|
||||
if (handle_ != nullptr)
|
||||
{
|
||||
X509_STORE_free(handle_);
|
||||
handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
Error loadFile(const char* file) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (!X509_STORE_load_file(handle_, file))
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static void upReferences(X509_STORE* handle) noexcept
|
||||
{
|
||||
X509_STORE_up_ref(handle);
|
||||
}
|
||||
};
|
||||
|
||||
class Context : public Base<Context, SSL_CTX*>
|
||||
{
|
||||
public:
|
||||
Error create(const SSL_METHOD* method) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(handle_ == nullptr, "Context already created.");
|
||||
ERR_clear_error();
|
||||
handle_ = SSL_CTX_new(method);
|
||||
if (handle_ == nullptr)
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void free() noexcept
|
||||
{
|
||||
if (handle_ == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
SSL_CTX_free(handle_);
|
||||
handle_ = nullptr;
|
||||
}
|
||||
|
||||
void setVerify(int mode, verify_callback_t callback = nullptr) const noexcept
|
||||
{
|
||||
SSL_CTX_set_verify(handle_, mode, callback);
|
||||
}
|
||||
|
||||
void setCertStore(X509Store store) const noexcept
|
||||
{
|
||||
SSL_CTX_set_cert_store(handle_, store.releaseHandle());
|
||||
}
|
||||
|
||||
Error setMinProtoVersion(int version) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (!SSL_CTX_set_min_proto_version(handle_, version))
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static void upReferences(SSL_CTX* handle) noexcept
|
||||
{
|
||||
SSL_CTX_up_ref(handle);
|
||||
}
|
||||
};
|
||||
|
||||
class Bio : public Base<Bio, BIO*>
|
||||
{
|
||||
public:
|
||||
Error createPair(Bio& otherBio, std::size_t writeBuf = 0, std::size_t otherWriteBuf = 0) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(handle_ == nullptr, "Ssl already created.");
|
||||
MIJIN_ASSERT(otherBio.handle_ == nullptr, "Ssl already created.");
|
||||
ERR_clear_error();
|
||||
if (!BIO_new_bio_pair(&handle_, writeBuf, &otherBio.handle_, otherWriteBuf))
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void free() noexcept
|
||||
{
|
||||
if (handle_ == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
BIO_free_all(handle_);
|
||||
handle_ = nullptr;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
std::size_t ctrlPending() const noexcept
|
||||
{
|
||||
return BIO_ctrl_pending(handle_);
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
std::size_t ctrlWPending() const noexcept
|
||||
{
|
||||
return BIO_ctrl_wpending(handle_);
|
||||
}
|
||||
|
||||
Result<int> write(const void* data, int length) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
const int result = BIO_write(handle_, data, length);
|
||||
if (result <= 0)
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Result<int> read(void* data, int length) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
const int result = BIO_read(handle_, data, length);
|
||||
if (result <= 0)
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void upReferences(BIO* handle) noexcept
|
||||
{
|
||||
BIO_up_ref(handle);
|
||||
}
|
||||
};
|
||||
|
||||
class Ssl : public Base<Ssl, SSL*>
|
||||
{
|
||||
public:
|
||||
Error create(const Context& context) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(handle_ == nullptr, "Ssl already created.");
|
||||
ERR_clear_error();
|
||||
handle_ = SSL_new(context.getHandle());
|
||||
if (handle_ == nullptr)
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void free() noexcept
|
||||
{
|
||||
if (handle_ == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
SSL_free(handle_);
|
||||
handle_ = nullptr;
|
||||
}
|
||||
|
||||
void setBio(Bio readBio, Bio writeBio) const noexcept
|
||||
{
|
||||
SSL_set_bio(handle_, readBio.releaseHandle(), writeBio.releaseHandle());
|
||||
}
|
||||
|
||||
void setBio(Bio&& bio) const noexcept
|
||||
{
|
||||
BIO* bioHandle = bio.releaseHandle();
|
||||
SSL_set_bio(handle_, bioHandle, bioHandle);
|
||||
}
|
||||
|
||||
Error setTLSExtHostname(const char* hostname) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (const int result = SSL_set_tlsext_host_name(handle_, hostname); result != 1)
|
||||
{
|
||||
return Error::current(handle_, result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
Error setHost(const char* hostname) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (const int result = SSL_set1_host(handle_, hostname); result != 1)
|
||||
{
|
||||
return Error::current(handle_, result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
Error connect() const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (const int result = SSL_connect(handle_); result != 1)
|
||||
{
|
||||
return Error::current(handle_, result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
Error shutdown() const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (const int result = SSL_shutdown(handle_); result != 1)
|
||||
{
|
||||
if (result == 0)
|
||||
{
|
||||
return Error{.sslError = SSL_ERROR_WANT_WRITE}; // TODO?
|
||||
}
|
||||
return Error::current(handle_, result);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
long getVerifyResult() const noexcept
|
||||
{
|
||||
return SSL_get_verify_result(handle_);
|
||||
}
|
||||
|
||||
Result<int> write(const void* data, int length) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
const int result = SSL_write(handle_, data, length);
|
||||
if (result <= 0)
|
||||
{
|
||||
return Error::current(handle_, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Result<int> read(void* data, int length) const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
const int result = SSL_read(handle_, data, length);
|
||||
if (result <= 0)
|
||||
{
|
||||
return Error::current(handle_, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void upReferences(SSL* handle) noexcept
|
||||
{
|
||||
SSL_up_ref(handle);
|
||||
}
|
||||
};
|
||||
|
||||
Error Error::current(int sslError_) noexcept
|
||||
{
|
||||
Error error = {
|
||||
.sslError = sslError_
|
||||
};
|
||||
const char* file = nullptr;
|
||||
int line = 0;
|
||||
const char* func = nullptr;
|
||||
const char* data = nullptr;
|
||||
int flags = 0;
|
||||
|
||||
while (const unsigned long numeric = ERR_get_error_all(&file, &line, &func, &data, &flags))
|
||||
{
|
||||
error.frames.push_back({
|
||||
.message = ERR_error_string(numeric, nullptr),
|
||||
.file = file != nullptr ? file : "",
|
||||
.function = func != nullptr ? func : "",
|
||||
.data = data != nullptr ? data : "",
|
||||
.line = line,
|
||||
.flags = flags
|
||||
});
|
||||
}
|
||||
|
||||
return error;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // !defined(MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED)
|
@ -2,7 +2,6 @@
|
||||
#include "./ssl.hpp"
|
||||
|
||||
#include <mutex>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
|
||||
namespace mijin
|
||||
@ -10,131 +9,103 @@ namespace mijin
|
||||
namespace
|
||||
{
|
||||
inline constexpr int BIO_BUFFER_SIZE = 4096;
|
||||
SSL_CTX* getSSLContext(bool create = true) noexcept
|
||||
ossl::Result<ossl::Context*> getSSLContext() noexcept
|
||||
{
|
||||
static SSL_CTX* context = nullptr;
|
||||
static ossl::Context context;
|
||||
static std::mutex contextMutex;
|
||||
|
||||
if (create && context == nullptr)
|
||||
if (!context)
|
||||
{
|
||||
const std::unique_lock contextLock(contextMutex);
|
||||
if (context != nullptr)
|
||||
|
||||
if (context)
|
||||
{
|
||||
return context;
|
||||
return &context;
|
||||
}
|
||||
context = SSL_CTX_new(SSLv23_client_method());
|
||||
SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr);
|
||||
if (!SSL_CTX_set_default_verify_paths(context)
|
||||
|| !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION))
|
||||
|
||||
ossl::Context newContext;
|
||||
if (const ossl::Error error = newContext.create(SSLv23_client_method()); !error.isSuccess())
|
||||
{
|
||||
SSL_CTX_free(context);
|
||||
context = nullptr;
|
||||
return nullptr;
|
||||
return error;
|
||||
}
|
||||
newContext.setVerify(SSL_VERIFY_PEER);
|
||||
|
||||
ossl::X509Store store;
|
||||
if (const ossl::Error error = store.create(); !error.isSuccess())
|
||||
{
|
||||
return error;
|
||||
}
|
||||
if (const ossl::Error error = store.loadFile("/etc/ssl/cert.pem"); !error.isSuccess())
|
||||
{
|
||||
return error;
|
||||
}
|
||||
newContext.setCertStore(std::move(store));
|
||||
if (const ossl::Error error = newContext.setMinProtoVersion(TLS1_2_VERSION); !error.isSuccess())
|
||||
{
|
||||
return error;
|
||||
}
|
||||
|
||||
context = std::move(newContext);
|
||||
}
|
||||
|
||||
return context;
|
||||
return &context;
|
||||
}
|
||||
|
||||
class SSLCleanupHelper
|
||||
{
|
||||
public:
|
||||
~SSLCleanupHelper() noexcept
|
||||
{
|
||||
SSL_CTX* context = getSSLContext(false);
|
||||
if (context != nullptr)
|
||||
{
|
||||
SSL_CTX_free(context);
|
||||
}
|
||||
}
|
||||
} gSSLCleanupHelper;
|
||||
}
|
||||
|
||||
StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");
|
||||
|
||||
SSL_CTX* context = getSSLContext();
|
||||
if (context == nullptr)
|
||||
const ossl::Result<ossl::Context*> contextResult = getSSLContext();
|
||||
if (contextResult.isError())
|
||||
{
|
||||
// TODO: convert/print error
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
ossl::Context& context = *contextResult.getValue();
|
||||
|
||||
ossl::Ssl ssl;
|
||||
if (const ossl::Error error = ssl.create(context); !error.isSuccess())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
SSL* ssl = SSL_new(context);
|
||||
if (ssl == nullptr)
|
||||
ossl::Bio externalBio;
|
||||
ossl::Bio internalBio;
|
||||
if (const ossl::Error error = internalBio.createPair(externalBio, BIO_BUFFER_SIZE, BIO_BUFFER_SIZE); !error.isSuccess())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
ssl.setBio(std::move(internalBio));
|
||||
|
||||
if (const ossl::Error error = ssl.setTLSExtHostname(hostname.c_str()); !error.isSuccess())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
BIO* bioA;
|
||||
BIO* bioB;
|
||||
if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0))
|
||||
if (const ossl::Error error = ssl.setHost(hostname.c_str()); !error.isSuccess())
|
||||
{
|
||||
SSL_free(ssl);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
SSL_set_bio(ssl, bioB, bioB);
|
||||
|
||||
if (!SSL_set_tlsext_host_name(ssl, hostname.c_str()))
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
if (!SSL_set1_host(ssl, hostname.c_str()))
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
ssl_ = ssl;
|
||||
bioA_ = bioA;
|
||||
bioB_ = bioB;
|
||||
// these need to be initialized for connecting
|
||||
ssl_ = std::move(ssl);
|
||||
externalBio_ = std::move(externalBio);
|
||||
base_ = &base;
|
||||
while(true)
|
||||
|
||||
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess())
|
||||
{
|
||||
if (const int result = SSL_connect(ssl); result < 1)
|
||||
{
|
||||
const int err = SSL_get_error(ssl, result);
|
||||
[[maybe_unused]] int rrA = BIO_get_read_request(bioA);
|
||||
[[maybe_unused]] int rrB = BIO_get_read_request(bioB);
|
||||
[[maybe_unused]] int wgA = BIO_get_write_guarantee(bioA);
|
||||
[[maybe_unused]] int wgB = BIO_get_write_guarantee(bioB);
|
||||
switch (err)
|
||||
{
|
||||
case SSL_ERROR_WANT_READ:
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
if(const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return error;
|
||||
}
|
||||
if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return error;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
ssl_.free();
|
||||
externalBio.free();
|
||||
base_ = nullptr;
|
||||
return StreamError::UNKNOWN_ERROR; // TODO: translate
|
||||
}
|
||||
|
||||
if (SSL_get_verify_result(ssl) != X509_V_OK)
|
||||
if (ssl_.getVerifyResult() != X509_V_OK)
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
ssl_.free();
|
||||
externalBio.free();
|
||||
base_ = nullptr;
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
@ -145,9 +116,9 @@ void SSLStream::close() noexcept
|
||||
{
|
||||
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
||||
base_ = nullptr;
|
||||
SSL_shutdown(static_cast<SSL*>(ssl_));
|
||||
SSL_free(static_cast<SSL*>(ssl_));
|
||||
BIO_free_all(static_cast<BIO*>(bioA_));
|
||||
(void) runIOLoop(&ossl::Ssl::shutdown);
|
||||
ssl_.free();
|
||||
externalBio_.free();
|
||||
}
|
||||
|
||||
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
@ -157,16 +128,16 @@ StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
SSL* ssl = static_cast<SSL*>(ssl_);
|
||||
std::size_t bytesToWrite = buffer.size();
|
||||
while (bytesToWrite > 0)
|
||||
{
|
||||
const int result = SSL_write(ssl, buffer.data() + buffer.size() - bytesToWrite, static_cast<int>(bytesToWrite));
|
||||
if (result <= 0)
|
||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite,
|
||||
static_cast<int>(bytesToWrite));
|
||||
if (result.isError())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
bytesToWrite -= result;
|
||||
bytesToWrite -= result.getValue();
|
||||
|
||||
if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
|
||||
{
|
||||
@ -184,7 +155,6 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
|
||||
{
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
SSL* ssl = static_cast<SSL*>(ssl_);
|
||||
|
||||
std::size_t bytesToRead = buffer.size();
|
||||
while (bytesToRead > 0)
|
||||
@ -193,11 +163,15 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
|
||||
{
|
||||
return error;
|
||||
}
|
||||
const int result = SSL_read(ssl, buffer.data() + buffer.size() - bytesToRead, static_cast<int>(bytesToRead));
|
||||
if (result <= 0)
|
||||
|
||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead,
|
||||
static_cast<int>(bytesToRead));
|
||||
if (result.isError())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
bytesToRead -= result.getValue();
|
||||
}
|
||||
|
||||
// TODO: options and outBytesRead
|
||||
@ -251,13 +225,12 @@ StreamFeatures SSLStream::getFeatures()
|
||||
|
||||
StreamError SSLStream::bioToBase() noexcept
|
||||
{
|
||||
BIO* bio = static_cast<BIO*>(bioA_);
|
||||
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
||||
std::size_t bytes = std::min(BIO_ctrl_pending(bio), buffer.size());
|
||||
std::size_t bytes = std::min(externalBio_.ctrlPending(), buffer.size());
|
||||
while (bytes > 0)
|
||||
{
|
||||
const int result = BIO_read(bio, buffer.data(), static_cast<int>(bytes));
|
||||
if (result <= 0)
|
||||
const ossl::Result<int> result = externalBio_.read(buffer.data(), static_cast<int>(bytes));
|
||||
if (result.isError())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
@ -266,19 +239,18 @@ StreamError SSLStream::bioToBase() noexcept
|
||||
return error;
|
||||
}
|
||||
|
||||
bytes = BIO_ctrl_pending(bio);
|
||||
bytes = externalBio_.ctrlPending();
|
||||
}
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
StreamError SSLStream::baseToBio() noexcept
|
||||
{
|
||||
BIO* bio = static_cast<BIO*>(bioA_);
|
||||
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
||||
while (true)
|
||||
{
|
||||
std::size_t bytes = 0;
|
||||
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - BIO_ctrl_wpending(bio), buffer.size());
|
||||
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size());
|
||||
|
||||
if (maxBytes == 0)
|
||||
{
|
||||
@ -296,12 +268,12 @@ StreamError SSLStream::baseToBio() noexcept
|
||||
// nothing more to read
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
const int result = BIO_write(bio, buffer.data(), static_cast<int>(bytes));
|
||||
if (result <= 0)
|
||||
const ossl::Result<int> result = externalBio_.write(buffer.data(), static_cast<int>(bytes));
|
||||
if (result.isError())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
MIJIN_ASSERT(result == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
|
||||
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
|
||||
}
|
||||
}
|
||||
} // namespace shiken
|
||||
|
@ -10,6 +10,7 @@
|
||||
#endif // !MIJIN_ENABLE_OPENSSL
|
||||
|
||||
#include <memory>
|
||||
#include "./openssl_wrappers.hpp"
|
||||
#include "../io/stream.hpp"
|
||||
|
||||
namespace mijin
|
||||
@ -18,10 +19,16 @@ class SSLStream : public Stream
|
||||
{
|
||||
private:
|
||||
Stream* base_ = nullptr;
|
||||
void* ssl_ = nullptr;
|
||||
void* bioA_ = nullptr;
|
||||
void* bioB_ = nullptr;
|
||||
ossl::Ssl ssl_;
|
||||
ossl::Bio externalBio_;
|
||||
public:
|
||||
~SSLStream() noexcept override
|
||||
{
|
||||
if (base_ != nullptr)
|
||||
{
|
||||
close();
|
||||
}
|
||||
}
|
||||
StreamError open(Stream& base, const std::string& hostname) noexcept;
|
||||
void close() noexcept;
|
||||
|
||||
@ -35,6 +42,51 @@ public:
|
||||
private:
|
||||
StreamError bioToBase() noexcept;
|
||||
StreamError baseToBio() noexcept;
|
||||
|
||||
|
||||
template<typename TFunc, typename... TArgs>
|
||||
auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
|
||||
{
|
||||
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
|
||||
while (true)
|
||||
{
|
||||
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
|
||||
ossl::Error error;
|
||||
if constexpr (std::is_same_v<result_t, ossl::Error>)
|
||||
{
|
||||
if (error.isSuccess())
|
||||
{
|
||||
return error;
|
||||
}
|
||||
error = result;
|
||||
}
|
||||
else
|
||||
{
|
||||
// assume result type
|
||||
if (result.isSuccess())
|
||||
{
|
||||
return result;
|
||||
}
|
||||
error = result.getError();
|
||||
}
|
||||
switch (error.sslError)
|
||||
{
|
||||
case SSL_ERROR_WANT_READ:
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
if(const StreamError streamError = baseToBio(); streamError != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
if (const StreamError streamError = bioToBase(); streamError != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return error;
|
||||
}
|
||||
}
|
||||
}
|
||||
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
|
||||
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user