299 lines
7.6 KiB
C++
299 lines
7.6 KiB
C++
|
|
#include "./ssl.hpp"

#include <algorithm>
#include <array>
#include <limits>
#include <mutex>
|
|
|
|
|
|
namespace mijin
|
|
{
|
|
namespace
|
|
{
|
|
// Capacity (in bytes) requested for each half of the BIO pair created in SSLStream::open(),
// and the size of the scratch buffers used to shuttle data between the BIO pair and the base stream.
constexpr int BIO_BUFFER_SIZE{4096};
|
|
// Returns the shared client SSL context, creating it on first use.
//
// The context verifies the peer (SSL_VERIFY_PEER), trusts the CA bundle loaded from
// "/etc/ssl/cert.pem" and requires at least TLS 1.2. If any setup step fails, the error
// is returned and `context` stays empty, so the next call retries the setup.
//
// NOTE(review): the CA bundle path is hard-coded and may not exist on every platform
// (e.g. Debian-based systems ship /etc/ssl/certs/ca-certificates.crt) — confirm.
ossl::Result<ossl::Context*> getSSLContext() noexcept
{
    static ossl::Context context;
    static std::mutex contextMutex;

    // Take the lock before inspecting `context`. An unlocked "fast path" check would race
    // with the assignment at the end of this function (a data race on a non-atomic object
    // is undefined behaviour). This is called once per SSLStream::open(), so always locking
    // costs nothing measurable.
    const std::unique_lock contextLock(contextMutex);
    if (context)
    {
        return &context;
    }

    ossl::Context newContext;
    if (const ossl::Error error = newContext.create(SSLv23_client_method()); !error.isSuccess())
    {
        return error;
    }
    newContext.setVerify(SSL_VERIFY_PEER);

    ossl::X509Store store;
    if (const ossl::Error error = store.create(); !error.isSuccess())
    {
        return error;
    }
    if (const ossl::Error error = store.loadFile("/etc/ssl/cert.pem"); !error.isSuccess())
    {
        return error;
    }
    newContext.setCertStore(std::move(store));

    if (const ossl::Error error = newContext.setMinProtoVersion(TLS1_2_VERSION); !error.isSuccess())
    {
        return error;
    }

    // Publish only a fully configured context.
    context = std::move(newContext);
    return &context;
}
|
|
}
|
|
|
|
// Opens a TLS client connection on top of `base`, which must already be connected.
//
// Creates an SSL object from the shared context and wires it to an in-memory BIO pair.
// The internal half is owned by the SSL object; the external half (externalBio_) is what
// this stream reads from and writes to when talking to `base`. The SNI name and the
// expected certificate host are set to `hostname`. Then the handshake runs and the peer
// certificate verification result is checked. On any failure the stream is left closed
// (base_ == nullptr) and an error is returned.
StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
{
    MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");

    const ossl::Result<ossl::Context*> contextResult = getSSLContext();
    if (contextResult.isError())
    {
        // TODO: convert/print error
        return StreamError::UNKNOWN_ERROR;
    }
    ossl::Context& context = *contextResult.getValue();

    ossl::Ssl ssl;
    if (const ossl::Error error = ssl.create(context); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }

    ossl::Bio externalBio;
    ossl::Bio internalBio;
    if (const ossl::Error error = internalBio.createPair(externalBio, BIO_BUFFER_SIZE, BIO_BUFFER_SIZE); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }
    ssl.setBio(std::move(internalBio));

    if (const ossl::Error error = ssl.setTLSExtHostname(hostname.c_str()); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }

    if (const ossl::Error error = ssl.setHost(hostname.c_str()); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }

    // these need to be initialized for connecting
    ssl_ = std::move(ssl);
    externalBio_ = std::move(externalBio);
    base_ = &base;

    // Rolls the member state back to "closed" if the handshake or verification fails.
    // This must release the member `externalBio_`: the local `externalBio` was moved
    // from above, so freeing the local would leave the member BIO alive.
    const auto resetState = [this]()
    {
        ssl_.free();
        externalBio_.free();
        base_ = nullptr;
    };

    if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
    {
        resetState();
        return StreamError::UNKNOWN_ERROR; // TODO: translate
    }

    // Refuse the connection unless certificate verification succeeded.
    if (ssl_.getVerifyResult() != X509_V_OK)
    {
        resetState();
        return StreamError::UNKNOWN_ERROR;
    }

    return StreamError::SUCCESS;
}
|
|
|
|
void SSLStream::close() noexcept
|
|
{
|
|
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
|
base_ = nullptr;
|
|
(void) runIOLoop(&ossl::Ssl::shutdown, true);
|
|
ssl_.free();
|
|
externalBio_.free();
|
|
}
|
|
|
|
// Encrypts and sends the whole of `buffer`, blocking until every byte has been accepted.
//
// After each successful SSL write, the encrypted output is forwarded to the base stream
// via bioToBase(). Returns UNKNOWN_ERROR if the SSL write fails, or the base stream's
// error if forwarding fails.
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    if (buffer.empty())
    {
        return StreamError::SUCCESS;
    }

    // The SSL write length is an `int`. Clamp each chunk so that buffers larger than
    // INT_MAX are written in several passes instead of passing a wrapped (negative) length.
    constexpr std::size_t MAX_CHUNK = static_cast<std::size_t>(std::numeric_limits<int>::max());

    std::size_t bytesToWrite = buffer.size();
    while (bytesToWrite > 0)
    {
        const std::uint8_t* chunk = buffer.data() + (buffer.size() - bytesToWrite);
        const std::size_t chunkSize = std::min(bytesToWrite, MAX_CHUNK);
        const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, true, chunk, static_cast<int>(chunkSize));
        if (result.isError())
        {
            return StreamError::UNKNOWN_ERROR;
        }
        bytesToWrite -= static_cast<std::size_t>(result.getValue());

        // Push the encrypted records produced by this write to the base stream.
        if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
        {
            return error;
        }
    }

    return StreamError::SUCCESS;
}
|
|
|
|
// Reads decrypted application data into `buffer`.
//
// Encrypted data available on the base stream is first moved into the BIO pair
// (baseToBio()). Then:
//  - non-partial reads fill the whole buffer. With noBlock they return WOULD_BLOCK unless
//    enough decrypted bytes are already pending in the SSL object.
//  - partial reads return what is already decrypted (up to buffer.size()). If nothing is
//    pending, a blocking partial read performs one read that waits for data, while a
//    non-blocking one reads nothing.
// `*outBytesRead` (if given) receives the number of bytes actually stored in `buffer`
// (0 on error).
//
// NOTE(review): options.peek is not implemented; data is always consumed.
// NOTE(review): ssl_.pending() only counts already-decrypted bytes, not encrypted bytes
// sitting in the BIO pair, so noBlock reads may report less than is really available.
StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::ReadOptions& options,
                               std::size_t* outBytesRead)
{
    // Start from "nothing read" so callers never observe a stale or fabricated count.
    if (outBytesRead != nullptr)
    {
        *outBytesRead = 0;
    }
    if (buffer.empty())
    {
        return StreamError::SUCCESS;
    }

    if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
    {
        return error;
    }

    const std::size_t pending = static_cast<std::size_t>(ssl_.pending());
    if (!options.partial && options.noBlock && pending < buffer.size())
    {
        return StreamError::WOULD_BLOCK;
    }

    std::size_t bytesToRead = buffer.size();
    bool stopAfterFirstRead = false;
    if (options.partial)
    {
        if (pending > 0)
        {
            bytesToRead = std::min(bytesToRead, pending);
        }
        else if (options.noBlock)
        {
            // Nothing decrypted yet and the caller must not block.
            bytesToRead = 0;
        }
        else
        {
            // Blocking partial read: wait for whatever the next read delivers.
            stopAfterFirstRead = true;
        }
    }

    // The SSL read length is an `int`; never pass more than INT_MAX per call.
    constexpr std::size_t MAX_CHUNK = static_cast<std::size_t>(std::numeric_limits<int>::max());

    std::size_t totalRead = 0;
    while (totalRead < bytesToRead)
    {
        const std::size_t chunkSize = std::min(bytesToRead - totalRead, MAX_CHUNK);
        const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, !options.noBlock,
                                                   buffer.data() + totalRead,
                                                   static_cast<int>(chunkSize));
        if (result.isError())
        {
            return StreamError::UNKNOWN_ERROR;
        }
        totalRead += static_cast<std::size_t>(result.getValue());
        if (stopAfterFirstRead)
        {
            break;
        }
    }

    if (outBytesRead != nullptr)
    {
        *outBytesRead = totalRead;
    }
    return StreamError::SUCCESS;
}
|
|
|
|
std::size_t SSLStream::tell()
|
|
{
|
|
MIJIN_ERROR("SSLStream does not support tell().");
|
|
return 0;
|
|
}
|
|
|
|
// Seeking is not possible on a TLS connection: flag the unsupported call via
// MIJIN_ERROR and report NOT_SUPPORTED. The arguments are ignored.
StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode)
{
    (void) pos;
    (void) seekMode;
    MIJIN_ERROR("SSLStream does not support seek().");
    return StreamError::NOT_SUPPORTED;
}
|
|
|
|
void SSLStream::flush()
|
|
{
|
|
base_->flush();
|
|
}
|
|
|
|
bool SSLStream::isAtEnd()
|
|
{
|
|
return base_->isAtEnd();
|
|
}
|
|
|
|
// Describes what this stream supports: reading and writing, no tell/seek.
// Partial and non-blocking reads are supported by readRaw().
StreamFeatures SSLStream::getFeatures()
{
    return {
        .read = true,
        .write = true,
        .tell = false,
        .seek = false,
        .readOptions = {
            .partial = true,
            // readRaw() always consumes what it reads; peeking is not implemented,
            // so it must not be advertised.
            .peek = false,
            .noBlock = true
        }
    };
}
|
|
|
|
// Drains all encrypted output the SSL object has written into the BIO pair and
// writes it to the base stream.
//
// Returns UNKNOWN_ERROR if reading the BIO fails, or the base stream's error if
// writing fails.
StreamError SSLStream::bioToBase() noexcept
{
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;

    // Clamp on every iteration, not just the first, so a read can never exceed the buffer.
    std::size_t bytes = std::min(externalBio_.ctrlPending(), buffer.size());
    while (bytes > 0)
    {
        const ossl::Result<int> result = externalBio_.read(buffer.data(), static_cast<int>(bytes));
        if (result.isError())
        {
            return StreamError::UNKNOWN_ERROR;
        }

        // Forward exactly the number of bytes the BIO produced (the count, not the Result wrapper).
        const std::size_t bytesRead = static_cast<std::size_t>(result.getValue());
        if (const StreamError error = base_->writeRaw(buffer.data(), bytesRead); error != StreamError::SUCCESS)
        {
            return error;
        }

        bytes = std::min(externalBio_.ctrlPending(), buffer.size());
    }
    return StreamError::SUCCESS;
}
|
|
|
|
// Moves encrypted data that is available on the base stream into the BIO pair so the
// SSL object can process it.
//
// Reads from the base stream with partial, non-blocking reads, never more than the
// BIO pair reports it can accept (getWriteGuarantee()). Stops when the BIO is full or
// the base stream has nothing more right now, then flushes the BIO on either exit.
StreamError SSLStream::baseToBio() noexcept
{
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
    std::size_t toRead = externalBio_.getWriteGuarantee();

    // toRead == 0 means the BIO buffer is full.
    while (toRead > 0)
    {
        const std::size_t maxBytes = std::min(toRead, buffer.size());
        std::size_t bytes = 0;
        if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true},
                                                     &bytes); error != StreamError::SUCCESS)
        {
            return error;
        }
        if (bytes == 0)
        {
            // Nothing more to read right now. Leave the loop so the BIO still gets flushed.
            break;
        }

        const ossl::Result<int> result = externalBio_.write(buffer.data(), static_cast<int>(bytes));
        if (result.isError())
        {
            return StreamError::UNKNOWN_ERROR;
        }
        MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
        toRead -= bytes;
    }

    if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }
    return StreamError::SUCCESS;
}
|
|
} // namespace mijin
|