mijin2/source/mijin/net/openssl_wrappers.hpp

442 lines
10 KiB
C++

#pragma once
#if !defined(MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED)
#define MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED 1
#include <cstddef>
#include <string>
#include <utility>
#include <vector>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509_vfy.h>
#include "../debug/assert.hpp"
#include "../types/result.hpp"
namespace ossl
{
/// One entry taken from the OpenSSL error queue.
/// NOTE: member order matters — Error::current() builds frames with designated
/// initializers, which must follow declaration order.
struct ErrorFrame
{
    std::string message;       // human-readable error text
    std::string file;          // source file reported by OpenSSL
    std::string function;      // function name reported by OpenSSL
    std::string data;          // additional error data string
    unsigned long numeric{0};  // raw OpenSSL error code
    int line{0};               // source line reported by OpenSSL
    int flags{0};              // ERR_* flags accompanying `data`
};
struct [[nodiscard]] Error
{
int sslError = SSL_ERROR_NONE;
std::vector<ErrorFrame> frames;
[[nodiscard]]
bool isSuccess() const noexcept { return sslError == SSL_ERROR_NONE; }
static inline Error current(int sslError = -1) noexcept;
static inline Error current(SSL* handle, int result) noexcept { return current(SSL_get_error(handle, result)); }
};
template<typename TSuccess>
using Result = mijin::ResultBase<TSuccess, Error>;
// callback typedefs
using verify_callback_t = int (*) (int, X509_STORE_CTX *);
/// CRTP base for OpenSSL handle wrappers.
/// TActual must provide:
///   - `void free() noexcept`: drop this object's reference (if any) and null `handle_`;
///   - `static void upReferences(THandle) noexcept`: add one reference to a handle.
/// Copying shares the handle and calls upReferences. Moving transfers it and leaves
/// the source empty. Destruction calls TActual::free().
/// NOTE(review): ~Base invokes TActual::free() after the TActual part has been destroyed.
/// This only works because the derived wrappers hold no state of their own.
template<typename TActual, typename THandle>
class Base
{
protected:
    using base_t = Base<TActual, THandle>;

    THandle handle_ = nullptr;

protected:
    /// Adopts `handle` as-is, without adding a reference.
    explicit Base(THandle handle) noexcept : handle_(handle) {}

private:
    TActual& derived() noexcept { return static_cast<TActual&>(*this); }

    /// Adds a reference to `handle` unless it is empty.
    static void retain(THandle handle) noexcept
    {
        if (handle)
        {
            TActual::upReferences(handle);
        }
    }

public:
    Base() noexcept = default;

    Base(const Base& other) noexcept : handle_(other.handle_)
    {
        retain(handle_);
    }

    Base(Base&& other) noexcept : handle_(other.releaseHandle()) {}

    ~Base() noexcept
    {
        derived().free();
    }

    /// Drops the current reference, then shares `other`'s handle.
    TActual& operator=(const Base& other) noexcept
    {
        if (this != &other)
        {
            derived().free();
            handle_ = other.handle_;
            retain(handle_);
        }
        return derived();
    }

    /// Drops the current reference, then takes over `other`'s handle (leaving it empty).
    TActual& operator=(Base&& other) noexcept
    {
        if (this != &other)
        {
            derived().free();
            handle_ = other.releaseHandle();
        }
        return derived();
    }

    auto operator<=>(const Base&) const noexcept = default;

    operator bool() const noexcept { return static_cast<bool>(handle_); }
    bool operator!() const noexcept { return !operator bool(); }

    /// Returns the raw handle without transferring ownership.
    [[nodiscard]]
    THandle getHandle() const noexcept { return handle_; }

    /// Gives up ownership: returns the handle and leaves this wrapper empty.
    [[nodiscard]]
    THandle releaseHandle() noexcept
    {
        return std::exchange(handle_, nullptr);
    }
};
/// Wrapper around an OpenSSL X509_STORE. Copies share the store through X509_STORE_up_ref.
class X509Store : public Base<X509Store, X509_STORE*>
{
public:
    using Base::Base;

    /// Allocates a new store. The wrapper must be empty beforehand.
    Error create() noexcept
    {
        MIJIN_ASSERT(handle_ == nullptr, "X509 Store already created.");
        ERR_clear_error();
        handle_ = X509_STORE_new();
        return handle_ != nullptr ? Error{} : Error::current();
    }

    /// Releases this wrapper's reference, if it holds one.
    void free() noexcept
    {
        if (handle_ == nullptr)
        {
            return;
        }
        X509_STORE_free(std::exchange(handle_, nullptr));
    }

    /// Loads certificates from `file` into the store via X509_STORE_load_file.
    Error loadFile(const char* file) const noexcept
    {
        ERR_clear_error();
        const bool loaded = X509_STORE_load_file(handle_, file) != 0;
        return loaded ? Error{} : Error::current();
    }

    static void upReferences(X509_STORE* handle) noexcept { X509_STORE_up_ref(handle); }
};
/// Wrapper around an OpenSSL SSL_CTX. Copies share the context through SSL_CTX_up_ref.
class Context : public Base<Context, SSL_CTX*>
{
public:
    /// Creates the context for `method`. The wrapper must be empty beforehand.
    Error create(const SSL_METHOD* method) noexcept
    {
        MIJIN_ASSERT(handle_ == nullptr, "Context already created.");
        ERR_clear_error();
        handle_ = SSL_CTX_new(method);
        return handle_ != nullptr ? Error{} : Error::current();
    }

    /// Releases this wrapper's reference, if it holds one.
    void free() noexcept
    {
        if (handle_ != nullptr)
        {
            SSL_CTX_free(std::exchange(handle_, nullptr));
        }
    }

    /// Forwards to SSL_CTX_set_verify.
    void setVerify(int mode, verify_callback_t callback = nullptr) const noexcept
    {
        SSL_CTX_set_verify(handle_, mode, callback);
    }

    /// Installs `store` as the context's certificate store. The reference held by
    /// `store` is released from the wrapper and handed to SSL_CTX_set_cert_store.
    void setCertStore(X509Store store) const noexcept
    {
        X509_STORE* const handedOver = store.releaseHandle();
        SSL_CTX_set_cert_store(handle_, handedOver);
    }

    /// Sets the minimum accepted protocol version (e.g. TLS1_2_VERSION).
    Error setMinProtoVersion(int version) const noexcept
    {
        ERR_clear_error();
        const bool applied = SSL_CTX_set_min_proto_version(handle_, version) != 0;
        return applied ? Error{} : Error::current();
    }

    static void upReferences(SSL_CTX* handle) noexcept { SSL_CTX_up_ref(handle); }
};
/// Wrapper around an OpenSSL BIO. Copies share the BIO through BIO_up_ref.
/// free() releases this wrapper's reference with BIO_free_all.
class Bio : public Base<Bio, BIO*>
{
public:
    /// Creates a connected BIO pair: `*this` and `otherBio` become its two ends.
    /// writeBuf / otherWriteBuf are passed to BIO_new_bio_pair as the buffer sizes
    /// (per the OpenSSL docs, 0 selects the default size). Both wrappers must be empty.
    Error createPair(Bio& otherBio, std::size_t writeBuf = 0, std::size_t otherWriteBuf = 0) noexcept
    {
        MIJIN_ASSERT(handle_ == nullptr, "Bio already created.");
        MIJIN_ASSERT(otherBio.handle_ == nullptr, "Other bio already created.");
        ERR_clear_error();
        if (!BIO_new_bio_pair(&handle_, writeBuf, &otherBio.handle_, otherWriteBuf))
        {
            return Error::current();
        }
        return {};
    }

    /// Releases this wrapper's reference, if it holds one.
    void free() noexcept
    {
        if (handle_ == nullptr)
        {
            return;
        }
        BIO_free_all(handle_);
        handle_ = nullptr;
    }

    /// Number of bytes pending for reading (BIO_ctrl_pending).
    [[nodiscard]]
    std::size_t ctrlPending() const noexcept
    {
        return BIO_ctrl_pending(handle_);
    }

    /// Number of bytes pending for writing (BIO_ctrl_wpending).
    [[nodiscard]]
    std::size_t ctrlWPending() const noexcept
    {
        return BIO_ctrl_wpending(handle_);
    }

    /// Writes up to `length` bytes and returns the number written.
    /// Any result <= 0 (including a would-retry condition) is reported as an Error
    /// with sslError == -1 plus whatever is on the OpenSSL error queue.
    Result<int> write(const void* data, int length) const noexcept
    {
        ERR_clear_error();
        const int result = BIO_write(handle_, data, length);
        if (result <= 0)
        {
            return Error::current();
        }
        return result;
    }

    /// Reads up to `length` bytes and returns the number read.
    /// Like write(), any result <= 0 is reported as an Error with sslError == -1.
    Result<int> read(void* data, int length) const noexcept
    {
        ERR_clear_error();
        const int result = BIO_read(handle_, data, length);
        if (result <= 0)
        {
            return Error::current();
        }
        return result;
    }

    [[nodiscard]]
    int getReadRequest() const noexcept
    {
        return BIO_get_read_request(handle_);
    }

    [[nodiscard]]
    int getWriteGuarantee() const noexcept
    {
        return BIO_get_write_guarantee(handle_);
    }

    [[nodiscard]]
    int getWritePending() const noexcept
    {
        return BIO_wpending(handle_);
    }

    /// Flushes buffered output. BIO_flush returns 1 on success and 0 *or -1* on
    /// failure, so every non-positive result must be treated as an error.
    Error flush() const noexcept
    {
        ERR_clear_error();
        if (BIO_flush(handle_) <= 0)
        {
            return Error::current();
        }
        return {};
    }

    static void upReferences(BIO* handle) noexcept
    {
        BIO_up_ref(handle);
    }
};
/// Wrapper around an OpenSSL SSL connection object. Copies share it through SSL_up_ref.
/// Failing I/O-style calls are classified with SSL_get_error (Error::current(handle, result)).
class Ssl : public Base<Ssl, SSL*>
{
public:
    /// Allocates a connection object from `context`. The wrapper must be empty beforehand.
    Error create(const Context& context) noexcept
    {
        MIJIN_ASSERT(handle_ == nullptr, "Ssl already created.");
        ERR_clear_error();
        handle_ = SSL_new(context.getHandle());
        return handle_ == nullptr ? Error::current() : Error{};
    }

    /// Releases this wrapper's reference, if it holds one.
    void free() noexcept
    {
        if (handle_ != nullptr)
        {
            SSL_free(std::exchange(handle_, nullptr));
        }
    }

    /// Attaches separate read and write BIOs. Both handles are released from their
    /// wrappers and handed to SSL_set_bio.
    void setBio(Bio readBio, Bio writeBio) const noexcept
    {
        BIO* const readHandle = readBio.releaseHandle();
        BIO* const writeHandle = writeBio.releaseHandle();
        SSL_set_bio(handle_, readHandle, writeHandle);
    }

    /// Uses a single BIO for both directions. One handle reference is handed to SSL_set_bio.
    void setBio(Bio&& bio) const noexcept
    {
        BIO* const shared = bio.releaseHandle();
        SSL_set_bio(handle_, shared, shared);
    }

    /// Sets the TLS server-name (SNI) extension via SSL_set_tlsext_host_name.
    Error setTLSExtHostname(const char* hostname) const noexcept
    {
        ERR_clear_error();
        const int result = SSL_set_tlsext_host_name(handle_, hostname);
        return result == 1 ? Error{} : Error::current(handle_, result);
    }

    /// Sets the expected peer host name for certificate verification (SSL_set1_host).
    Error setHost(const char* hostname) const noexcept
    {
        ERR_clear_error();
        const int result = SSL_set1_host(handle_, hostname);
        return result == 1 ? Error{} : Error::current(handle_, result);
    }

    /// Runs / continues the client handshake. Any result other than 1 becomes an Error.
    Error connect() const noexcept
    {
        ERR_clear_error();
        const int result = SSL_connect(handle_);
        return result == 1 ? Error{} : Error::current(handle_, result);
    }

    /// Runs / continues SSL_shutdown.
    /// 1 -> success. 0 -> reported as SSL_ERROR_WANT_WRITE (shutdown not yet finished).
    /// Anything else -> classified via SSL_get_error.
    Error shutdown() const noexcept
    {
        ERR_clear_error();
        const int result = SSL_shutdown(handle_);
        switch (result)
        {
        case 1:
            return {};
        case 0:
            return Error{.sslError = SSL_ERROR_WANT_WRITE}; // TODO?
        default:
            return Error::current(handle_, result);
        }
    }

    [[nodiscard]]
    long getVerifyResult() const noexcept
    {
        return SSL_get_verify_result(handle_);
    }

    /// Writes up to `length` bytes of application data and returns the count written.
    Result<int> write(const void* data, int length) const noexcept
    {
        ERR_clear_error();
        if (const int written = SSL_write(handle_, data, length); written > 0)
        {
            return written;
        }
        else
        {
            return Error::current(handle_, written);
        }
    }

    /// Reads up to `length` bytes of application data and returns the count read.
    Result<int> read(void* data, int length) const noexcept
    {
        ERR_clear_error();
        if (const int received = SSL_read(handle_, data, length); received > 0)
        {
            return received;
        }
        else
        {
            return Error::current(handle_, received);
        }
    }

    /// Bytes already decrypted and buffered inside the SSL object (SSL_pending).
    [[nodiscard]]
    int pending() const noexcept
    {
        return SSL_pending(handle_);
    }

    static void upReferences(SSL* handle) noexcept { SSL_up_ref(handle); }
};
/// Builds an Error with code `sslError_` and drains the calling thread's OpenSSL
/// error queue into `frames`, oldest entry first.
Error Error::current(int sslError_) noexcept
{
    Error error = {
        .sslError = sslError_
    };
    const char* file = nullptr;
    int line = 0;
    const char* func = nullptr;
    const char* data = nullptr;
    int flags = 0;
    // ERR_error_string(e, nullptr) writes into a shared static buffer and is documented
    // as not thread-safe. ERR_error_string_n writes a bounded, NUL-terminated string
    // into our own buffer instead (256 bytes is the size OpenSSL documents as sufficient).
    char message[256] = {};
    while (const unsigned long numeric = ERR_get_error_all(&file, &line, &func, &data, &flags))
    {
        ERR_error_string_n(numeric, message, sizeof(message));
        error.frames.push_back({
            .message = message,
            .file = file != nullptr ? file : "",
            .function = func != nullptr ? func : "",
            .data = data != nullptr ? data : "",
            .numeric = numeric, // previously never stored, leaving ErrorFrame::numeric at 0
            .line = line,
            .flags = flags
        });
    }
    return error;
}
}
#endif // !defined(MIJIN_NET_OPENSSL_WRAPPERS_HPP_INCLUDED)