308 lines
7.6 KiB
C++
308 lines
7.6 KiB
C++
|
|
#include "./ssl.hpp"

#include <algorithm>
#include <array>
#include <limits>
#include <mutex>

#include <openssl/ssl.h>
|
|
|
|
|
|
namespace mijin
|
|
{
|
|
namespace
|
|
{
|
|
inline constexpr int BIO_BUFFER_SIZE = 4096;
|
|
SSL_CTX* getSSLContext(bool create = true) noexcept
|
|
{
|
|
static SSL_CTX* context = nullptr;
|
|
static std::mutex contextMutex;
|
|
|
|
if (create && context == nullptr)
|
|
{
|
|
const std::unique_lock contextLock(contextMutex);
|
|
if (context != nullptr)
|
|
{
|
|
return context;
|
|
}
|
|
context = SSL_CTX_new(SSLv23_client_method());
|
|
SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr);
|
|
if (!SSL_CTX_set_default_verify_paths(context)
|
|
|| !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION))
|
|
{
|
|
SSL_CTX_free(context);
|
|
context = nullptr;
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
return context;
|
|
}
|
|
|
|
class SSLCleanupHelper
|
|
{
|
|
public:
|
|
~SSLCleanupHelper() noexcept
|
|
{
|
|
SSL_CTX* context = getSSLContext(false);
|
|
if (context != nullptr)
|
|
{
|
|
SSL_CTX_free(context);
|
|
}
|
|
}
|
|
} gSSLCleanupHelper;
|
|
}
|
|
|
|
/**
 * Starts a TLS client session on top of an already connected stream and performs the handshake.
 *
 * OpenSSL talks to an in-memory BIO pair. bioB is owned by the SSL object; bioA is the
 * application side. bioToBase() drains ciphertext from bioA into @p base, and baseToBio()
 * feeds ciphertext read from @p base into bioA.
 *
 * @param base     Stream carrying the encrypted traffic; it is referenced, not owned.
 * @param hostname Server name, used for SNI and for certificate host name verification.
 * @return StreamError::SUCCESS once the handshake completed and the peer certificate verified.
 *         On any failure, all resources are released and the stream is left closed.
 */
StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
{
    MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");

    SSL_CTX* context = getSSLContext();
    if (context == nullptr)
    {
        return StreamError::UNKNOWN_ERROR;
    }

    SSL* ssl = SSL_new(context);
    if (ssl == nullptr)
    {
        return StreamError::UNKNOWN_ERROR;
    }

    BIO* bioA = nullptr;
    BIO* bioB = nullptr;
    if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0))
    {
        SSL_free(ssl);
        return StreamError::UNKNOWN_ERROR;
    }
    // ssl takes ownership of bioB, so SSL_free() releases it.
    SSL_set_bio(ssl, bioB, bioB);

    // Releases everything acquired above and resets the members. Without the reset, a failed
    // open() would leave dangling pointers behind that a later close() would free again.
    const auto fail = [&](StreamError error) noexcept
    {
        SSL_free(ssl);
        BIO_free_all(bioA);
        ssl_ = nullptr;
        bioA_ = nullptr;
        bioB_ = nullptr;
        base_ = nullptr;
        return error;
    };

    if (!SSL_set_tlsext_host_name(ssl, hostname.c_str()))
    {
        return fail(StreamError::UNKNOWN_ERROR);
    }
    if (!SSL_set1_host(ssl, hostname.c_str()))
    {
        return fail(StreamError::UNKNOWN_ERROR);
    }

    // bioToBase()/baseToBio() operate on the members, so set them before the handshake.
    ssl_ = ssl;
    bioA_ = bioA;
    bioB_ = bioB;
    base_ = &base;

    while (true)
    {
        const int result = SSL_connect(ssl);
        if (result > 0)
        {
            break;
        }

        const int err = SSL_get_error(ssl, result);
        if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE)
        {
            return fail(StreamError::UNKNOWN_ERROR);
        }

        // Send whatever the handshake produced (e.g. the ClientHello) before polling for input.
        if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
        {
            return fail(error);
        }
        // If the base stream is exhausted while OpenSSL still needs input, no more data can
        // arrive, so fail instead of spinning forever.
        if (err == SSL_ERROR_WANT_READ && base_->isAtEnd())
        {
            return fail(StreamError::UNKNOWN_ERROR);
        }
        if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
        {
            return fail(error);
        }
    }

    // With SSL_VERIFY_PEER a failed verification already aborts the handshake; this is a
    // defensive double check.
    if (SSL_get_verify_result(ssl) != X509_V_OK)
    {
        return fail(StreamError::UNKNOWN_ERROR);
    }

    return StreamError::SUCCESS;
}
|
|
|
|
void SSLStream::close() noexcept
|
|
{
|
|
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
|
base_ = nullptr;
|
|
SSL_shutdown(static_cast<SSL*>(ssl_));
|
|
SSL_free(static_cast<SSL*>(ssl_));
|
|
BIO_free_all(static_cast<BIO*>(bioA_));
|
|
}
|
|
|
|
/**
 * Encrypts @p buffer and forwards the resulting ciphertext to the base stream.
 *
 * Keeps going until OpenSSL has accepted the whole buffer. If the BIO pair fills up,
 * SSL_write() reports SSL_ERROR_WANT_WRITE. The queued ciphertext is then drained into the
 * base stream, and the call is repeated with the same arguments, as SSL_write() requires.
 *
 * @param buffer Plaintext to send; an empty buffer is a no-op.
 * @return StreamError::SUCCESS if everything was written, otherwise the first error reported
 *         by OpenSSL (as UNKNOWN_ERROR) or by the base stream.
 */
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    if (buffer.empty())
    {
        return StreamError::SUCCESS;
    }

    SSL* ssl = static_cast<SSL*>(ssl_);
    std::size_t bytesWritten = 0;
    while (bytesWritten < buffer.size())
    {
        // SSL_write() takes an int length, so larger buffers are written in INT_MAX pieces.
        const std::size_t chunkSize = std::min<std::size_t>(
            buffer.size() - bytesWritten, static_cast<std::size_t>(std::numeric_limits<int>::max()));
        const int result = SSL_write(ssl, buffer.data() + bytesWritten, static_cast<int>(chunkSize));
        if (result <= 0)
        {
            switch (SSL_get_error(ssl, result))
            {
            case SSL_ERROR_WANT_WRITE:
                // The BIO pair is full: make room by sending its contents, then retry.
                if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
                {
                    return error;
                }
                continue;
            case SSL_ERROR_WANT_READ:
                // OpenSSL needs peer data first (e.g. renegotiation). Fail on EOF rather than spin.
                if (base_->isAtEnd())
                {
                    return StreamError::UNKNOWN_ERROR;
                }
                if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
                {
                    return error;
                }
                continue;
            default:
                return StreamError::UNKNOWN_ERROR;
            }
        }
        bytesWritten += static_cast<std::size_t>(result);

        if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
        {
            return error;
        }
    }

    return StreamError::SUCCESS;
}
|
|
|
|
/**
 * Reads and decrypts exactly buffer.size() bytes from the TLS stream.
 *
 * Ciphertext is pulled from the base stream as needed. OpenSSL asking for more input
 * (SSL_ERROR_WANT_READ) simply means the data has not arrived yet, and the read is retried.
 * The call fails if the base stream ends before the buffer could be filled.
 *
 * @param buffer       Destination for the plaintext.
 * @param options      Currently ignored; getFeatures() advertises no partial/peek/noBlock support.
 * @param outBytesRead If not null, receives the number of bytes stored in @p buffer on success.
 * @return StreamError::SUCCESS if the buffer was filled, otherwise an error.
 */
StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::ReadOptions& options,
                               std::size_t* outBytesRead)
{
    // TODO: honour options (partial / peek / noBlock)
    (void) options;

    if (buffer.empty())
    {
        if (outBytesRead != nullptr)
        {
            *outBytesRead = 0;
        }
        return StreamError::SUCCESS;
    }

    SSL* ssl = static_cast<SSL*>(ssl_);
    std::size_t bytesRead = 0;
    while (bytesRead < buffer.size())
    {
        if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
        {
            return error;
        }

        // SSL_read() takes an int length, so larger buffers are filled in INT_MAX pieces.
        const std::size_t chunkSize = std::min<std::size_t>(
            buffer.size() - bytesRead, static_cast<std::size_t>(std::numeric_limits<int>::max()));
        const int result = SSL_read(ssl, buffer.data() + bytesRead, static_cast<int>(chunkSize));
        if (result > 0)
        {
            // Advance past the received data (previously missing, so every read overwrote the start).
            bytesRead += static_cast<std::size_t>(result);
            continue;
        }

        switch (SSL_get_error(ssl, result))
        {
        case SSL_ERROR_WANT_READ:
            // Everything available was consumed. If the base stream has ended, nothing more can come.
            if (base_->isAtEnd())
            {
                return StreamError::UNKNOWN_ERROR;
            }
            break;  // retry; baseToBio() at the top of the loop pulls new data
        case SSL_ERROR_WANT_WRITE:
            if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
            {
                return error;
            }
            break;
        default:
            return StreamError::UNKNOWN_ERROR;
        }
    }

    if (outBytesRead != nullptr)
    {
        *outBytesRead = bytesRead;
    }

    return StreamError::SUCCESS;
}
|
|
|
|
std::size_t SSLStream::tell()
|
|
{
|
|
MIJIN_ERROR("SSLStream does not support tell().");
|
|
return 0;
|
|
}
|
|
|
|
/// TLS streams are strictly sequential, so seeking is rejected.
/// @return Always StreamError::NOT_SUPPORTED.
StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode)
{
    // The message previously said "tell()" (copy-paste error).
    MIJIN_ERROR("SSLStream does not support seek().");
    (void) pos;
    (void) seekMode;
    return StreamError::NOT_SUPPORTED;
}
|
|
|
|
void SSLStream::flush()
|
|
{
|
|
base_->flush();
|
|
}
|
|
|
|
bool SSLStream::isAtEnd()
|
|
{
|
|
return base_->isAtEnd();
|
|
}
|
|
|
|
/// Describes the stream's capabilities: sequential read/write only. There is no position
/// to tell or seek to, and readRaw() ignores the partial/peek/noBlock read options.
StreamFeatures SSLStream::getFeatures()
{
    StreamFeatures features{};
    features.read = true;
    features.write = true;
    features.tell = false;
    features.seek = false;
    features.readOptions.partial = false;
    features.readOptions.peek = false;
    features.readOptions.noBlock = false;
    return features;
}
|
|
|
|
/**
 * Drains all ciphertext OpenSSL has queued in the BIO pair and writes it to the base
 * stream, in chunks of at most BIO_BUFFER_SIZE bytes.
 *
 * @return StreamError::SUCCESS once nothing is pending. Otherwise UNKNOWN_ERROR if
 *         BIO_read() fails, or the error returned by the base stream's writeRaw().
 */
StreamError SSLStream::bioToBase() noexcept
{
    BIO* bio = static_cast<BIO*>(bioA_);
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
    while (true)
    {
        // Clamp on every iteration: the BIO pair can hold more than one buffer's worth.
        // Previously only the first read was clamped, so later reads could overflow `buffer`.
        const std::size_t bytes = std::min(BIO_ctrl_pending(bio), buffer.size());
        if (bytes == 0)
        {
            return StreamError::SUCCESS;
        }

        const int result = BIO_read(bio, buffer.data(), static_cast<int>(bytes));
        if (result <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        if (const StreamError error = base_->writeRaw(buffer.data(), result); error != StreamError::SUCCESS)
        {
            return error;
        }
    }
}
|
|
|
|
/**
 * Moves whatever ciphertext the base stream has available right now (non-blocking, partial
 * reads) into the application side of the BIO pair, where OpenSSL can consume it.
 *
 * At most BIO_BUFFER_SIZE bytes are kept queued in the BIO. The function returns
 * StreamError::SUCCESS once that limit is reached or the base stream yields no more data.
 * A base-stream read error is returned as is; a failed BIO_write() becomes UNKNOWN_ERROR.
 */
StreamError SSLStream::baseToBio() noexcept
{
    BIO* const appBio = static_cast<BIO*>(bioA_);
    std::array<std::uint8_t, BIO_BUFFER_SIZE> chunk;

    for (;;)
    {
        const std::size_t freeSpace = std::min(BIO_BUFFER_SIZE - BIO_ctrl_wpending(appBio), chunk.size());
        if (freeSpace == 0)
        {
            // the BIO already holds BIO_BUFFER_SIZE bytes
            return StreamError::SUCCESS;
        }

        std::size_t received = 0;
        const StreamError readError = base_->readRaw(chunk.data(), freeSpace, {.partial = true, .noBlock = true},
                                                     &received);
        if (readError != StreamError::SUCCESS)
        {
            return readError;
        }
        if (received == 0)
        {
            // the base stream has nothing more to offer for now
            return StreamError::SUCCESS;
        }

        const int written = BIO_write(appBio, chunk.data(), static_cast<int>(received));
        if (written <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        MIJIN_ASSERT(written == static_cast<int>(received), "BIO reported more bytes in buffer than it actually accepted?");
    }
}
|
|
} // namespace mijin
|