Added wrapper for openssl types.

This commit is contained in:
2024-08-22 00:30:38 +02:00
parent f761f2fb07
commit 0be34a845a
3 changed files with 545 additions and 114 deletions

View File

@@ -2,7 +2,6 @@
#include "./ssl.hpp"
#include <mutex>
#include <openssl/ssl.h>
namespace mijin
@@ -10,131 +9,103 @@ namespace mijin
namespace
{
inline constexpr int BIO_BUFFER_SIZE = 4096;
SSL_CTX* getSSLContext(bool create = true) noexcept
ossl::Result<ossl::Context*> getSSLContext() noexcept
{
static SSL_CTX* context = nullptr;
static ossl::Context context;
static std::mutex contextMutex;
if (create && context == nullptr)
if (!context)
{
const std::unique_lock contextLock(contextMutex);
if (context != nullptr)
if (context)
{
return context;
return &context;
}
context = SSL_CTX_new(SSLv23_client_method());
SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr);
if (!SSL_CTX_set_default_verify_paths(context)
|| !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION))
ossl::Context newContext;
if (const ossl::Error error = newContext.create(SSLv23_client_method()); !error.isSuccess())
{
SSL_CTX_free(context);
context = nullptr;
return nullptr;
return error;
}
newContext.setVerify(SSL_VERIFY_PEER);
ossl::X509Store store;
if (const ossl::Error error = store.create(); !error.isSuccess())
{
return error;
}
if (const ossl::Error error = store.loadFile("/etc/ssl/cert.pem"); !error.isSuccess())
{
return error;
}
newContext.setCertStore(std::move(store));
if (const ossl::Error error = newContext.setMinProtoVersion(TLS1_2_VERSION); !error.isSuccess())
{
return error;
}
context = std::move(newContext);
}
return context;
return &context;
}
class SSLCleanupHelper
{
public:
~SSLCleanupHelper() noexcept
{
SSL_CTX* context = getSSLContext(false);
if (context != nullptr)
{
SSL_CTX_free(context);
}
}
} gSSLCleanupHelper;
}
StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
{
MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");
SSL_CTX* context = getSSLContext();
if (context == nullptr)
const ossl::Result<ossl::Context*> contextResult = getSSLContext();
if (contextResult.isError())
{
// TODO: convert/print error
return StreamError::UNKNOWN_ERROR;
}
ossl::Context& context = *contextResult.getValue();
ossl::Ssl ssl;
if (const ossl::Error error = ssl.create(context); !error.isSuccess())
{
return StreamError::UNKNOWN_ERROR;
}
SSL* ssl = SSL_new(context);
if (ssl == nullptr)
ossl::Bio externalBio;
ossl::Bio internalBio;
if (const ossl::Error error = internalBio.createPair(externalBio, BIO_BUFFER_SIZE, BIO_BUFFER_SIZE); !error.isSuccess())
{
return StreamError::UNKNOWN_ERROR;
}
ssl.setBio(std::move(internalBio));
if (const ossl::Error error = ssl.setTLSExtHostname(hostname.c_str()); !error.isSuccess())
{
return StreamError::UNKNOWN_ERROR;
}
BIO* bioA;
BIO* bioB;
if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0))
if (const ossl::Error error = ssl.setHost(hostname.c_str()); !error.isSuccess())
{
SSL_free(ssl);
return StreamError::UNKNOWN_ERROR;
}
SSL_set_bio(ssl, bioB, bioB);
if (!SSL_set_tlsext_host_name(ssl, hostname.c_str()))
{
SSL_free(ssl);
BIO_free_all(bioA);
return StreamError::UNKNOWN_ERROR;
}
if (!SSL_set1_host(ssl, hostname.c_str()))
{
SSL_free(ssl);
BIO_free_all(bioA);
return StreamError::UNKNOWN_ERROR;
}
ssl_ = ssl;
bioA_ = bioA;
bioB_ = bioB;
// these need to be initialized for connecting
ssl_ = std::move(ssl);
externalBio_ = std::move(externalBio);
base_ = &base;
while(true)
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess())
{
if (const int result = SSL_connect(ssl); result < 1)
{
const int err = SSL_get_error(ssl, result);
[[maybe_unused]] int rrA = BIO_get_read_request(bioA);
[[maybe_unused]] int rrB = BIO_get_read_request(bioB);
[[maybe_unused]] int wgA = BIO_get_write_guarantee(bioA);
[[maybe_unused]] int wgB = BIO_get_write_guarantee(bioB);
switch (err)
{
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
if(const StreamError error = baseToBio(); error != StreamError::SUCCESS)
{
SSL_free(ssl);
BIO_free_all(bioA);
return error;
}
if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
{
SSL_free(ssl);
BIO_free_all(bioA);
return error;
}
break;
default:
SSL_free(ssl);
BIO_free_all(bioA);
return StreamError::UNKNOWN_ERROR;
}
}
else
{
break;
}
ssl_.free();
externalBio.free();
base_ = nullptr;
return StreamError::UNKNOWN_ERROR; // TODO: translate
}
if (SSL_get_verify_result(ssl) != X509_V_OK)
if (ssl_.getVerifyResult() != X509_V_OK)
{
SSL_free(ssl);
BIO_free_all(bioA);
ssl_.free();
externalBio.free();
base_ = nullptr;
return StreamError::UNKNOWN_ERROR;
}
@@ -145,9 +116,9 @@ void SSLStream::close() noexcept
{
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
base_ = nullptr;
SSL_shutdown(static_cast<SSL*>(ssl_));
SSL_free(static_cast<SSL*>(ssl_));
BIO_free_all(static_cast<BIO*>(bioA_));
(void) runIOLoop(&ossl::Ssl::shutdown);
ssl_.free();
externalBio_.free();
}
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
@@ -157,16 +128,16 @@ StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
return StreamError::SUCCESS;
}
SSL* ssl = static_cast<SSL*>(ssl_);
std::size_t bytesToWrite = buffer.size();
while (bytesToWrite > 0)
{
const int result = SSL_write(ssl, buffer.data() + buffer.size() - bytesToWrite, static_cast<int>(bytesToWrite));
if (result <= 0)
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite,
static_cast<int>(bytesToWrite));
if (result.isError())
{
return StreamError::UNKNOWN_ERROR;
}
bytesToWrite -= result;
bytesToWrite -= result.getValue();
if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
{
@@ -184,7 +155,6 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
{
return StreamError::SUCCESS;
}
SSL* ssl = static_cast<SSL*>(ssl_);
std::size_t bytesToRead = buffer.size();
while (bytesToRead > 0)
@@ -193,11 +163,15 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
{
return error;
}
const int result = SSL_read(ssl, buffer.data() + buffer.size() - bytesToRead, static_cast<int>(bytesToRead));
if (result <= 0)
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead,
static_cast<int>(bytesToRead));
if (result.isError())
{
return StreamError::UNKNOWN_ERROR;
}
bytesToRead -= result.getValue();
}
// TODO: options and outBytesRead
@@ -251,13 +225,12 @@ StreamFeatures SSLStream::getFeatures()
StreamError SSLStream::bioToBase() noexcept
{
BIO* bio = static_cast<BIO*>(bioA_);
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
std::size_t bytes = std::min(BIO_ctrl_pending(bio), buffer.size());
std::size_t bytes = std::min(externalBio_.ctrlPending(), buffer.size());
while (bytes > 0)
{
const int result = BIO_read(bio, buffer.data(), static_cast<int>(bytes));
if (result <= 0)
const ossl::Result<int> result = externalBio_.read(buffer.data(), static_cast<int>(bytes));
if (result.isError())
{
return StreamError::UNKNOWN_ERROR;
}
@@ -266,19 +239,18 @@ StreamError SSLStream::bioToBase() noexcept
return error;
}
bytes = BIO_ctrl_pending(bio);
bytes = externalBio_.ctrlPending();
}
return StreamError::SUCCESS;
}
StreamError SSLStream::baseToBio() noexcept
{
BIO* bio = static_cast<BIO*>(bioA_);
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
while (true)
{
std::size_t bytes = 0;
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - BIO_ctrl_wpending(bio), buffer.size());
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size());
if (maxBytes == 0)
{
@@ -296,12 +268,12 @@ StreamError SSLStream::baseToBio() noexcept
// nothing more to read
return StreamError::SUCCESS;
}
const int result = BIO_write(bio, buffer.data(), static_cast<int>(bytes));
if (result <= 0)
const ossl::Result<int> result = externalBio_.write(buffer.data(), static_cast<int>(bytes));
if (result.isError())
{
return StreamError::UNKNOWN_ERROR;
}
MIJIN_ASSERT(result == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
}
}
} // namespace shiken