#include "./ssl.hpp"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <mutex>
#include <span>
#include <string>

namespace mijin
{
namespace
{
// Size of each write buffer of the BIO pair created in SSLStream::open(), and of the scratch buffers used by
// bioToBase()/baseToBio() to shuttle bytes between that BIO pair and the base stream.
inline constexpr int BIO_BUFFER_SIZE = 4096;

// The ossl read/write wrappers are called with int lengths (see the static_casts below); larger requests are
// split into chunks no bigger than this.
inline constexpr std::size_t MAX_OSSL_IO_CHUNK = static_cast<std::size_t>(std::numeric_limits<int>::max());

// Returns the process-wide client SSL context, creating it on the first successful call.
// The context verifies the peer (SSL_VERIFY_PEER), trusts the CA certificates loaded from /etc/ssl/cert.pem and
// requires at least TLS 1.2. If creation fails, the error is returned and a later call retries the creation.
ossl::Result<ossl::Context*> getSSLContext() MIJIN_NOEXCEPT
{
    static ossl::Context context;
    static std::mutex contextMutex;

    // Lock unconditionally: testing `context` outside the mutex would race with the assignment at the end of
    // this function (the object is not atomic).
    const std::scoped_lock contextLock(contextMutex);
    if (context)
    {
        return &context;
    }

    ossl::Context newContext;
    // TLS_client_method() is the current name of the deprecated SSLv23_client_method() alias.
    if (const ossl::Error error = newContext.create(TLS_client_method()); !error.isSuccess())
    {
        return error;
    }
    newContext.setVerify(SSL_VERIFY_PEER);

    ossl::X509Store store;
    if (const ossl::Error error = store.create(); !error.isSuccess())
    {
        return error;
    }
    // NOTE(review): the system CA bundle lives in different places on different platforms/distributions;
    // consider making this path configurable.
    if (const ossl::Error error = store.loadFile("/etc/ssl/cert.pem"); !error.isSuccess())
    {
        return error;
    }
    newContext.setCertStore(std::move(store));

    if (const ossl::Error error = newContext.setMinProtoVersion(TLS1_2_VERSION); !error.isSuccess())
    {
        return error;
    }

    context = std::move(newContext);
    return &context;
}
} // namespace

// Opens a TLS client session on top of `base`, using `hostname` for SNI and for certificate host-name checking.
// Performs the handshake immediately; on any failure the stream is left closed and UNKNOWN_ERROR is returned.
StreamError SSLStream::open(Stream& base, const std::string& hostname) MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");

    const auto contextResult = getSSLContext();
    if (contextResult.isError())
    {
        // TODO: convert/print error
        return StreamError::UNKNOWN_ERROR;
    }
    ossl::Context& context = *contextResult.getValue();

    ossl::Ssl ssl;
    if (const ossl::Error error = ssl.create(context); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }

    // The internal end of the pair is handed to OpenSSL; the external end is drained/filled by
    // bioToBase()/baseToBio().
    ossl::Bio externalBio;
    ossl::Bio internalBio;
    if (const ossl::Error error = internalBio.createPair(externalBio, BIO_BUFFER_SIZE, BIO_BUFFER_SIZE); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }
    ssl.setBio(std::move(internalBio));

    if (const ossl::Error error = ssl.setTLSExtHostname(hostname.c_str()); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }
    if (const ossl::Error error = ssl.setHost(hostname.c_str()); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }

    // these need to be initialized for connecting
    ssl_ = std::move(ssl);
    externalBio_ = std::move(externalBio);
    base_ = &base;

    // Roll back the member initialization above. The locals `ssl`/`externalBio` have been moved from and must
    // not be used for this.
    const auto abortOpen = [this]()
    {
        ssl_.free();
        externalBio_.free();
        base_ = nullptr;
        return StreamError::UNKNOWN_ERROR;
    };

    if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
    {
        return abortOpen(); // TODO: translate
    }
    if (ssl_.getVerifyResult() != X509_V_OK)
    {
        return abortOpen();
    }
    return StreamError::SUCCESS;
}

// Closes the session: attempts a TLS shutdown (best effort, result ignored), then releases the SSL object and
// the BIO and detaches from the base stream.
void SSLStream::close() MIJIN_NOEXCEPT
{
    MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");

    // open() sets base_ before running the I/O loop because the loop needs it; keep it set for the shutdown too.
    (void) runIOLoop(&ossl::Ssl::shutdown, true);
    base_ = nullptr;
    ssl_.free();
    externalBio_.free();
}

// Encrypts and writes all of `buffer`, pushing the produced TLS records to the base stream after each write.
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    std::size_t bytesWritten = 0;
    while (bytesWritten < buffer.size())
    {
        const std::size_t chunk = std::min(buffer.size() - bytesWritten, MAX_OSSL_IO_CHUNK);
        const auto result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + bytesWritten, static_cast<int>(chunk));
        // A non-positive count would otherwise stall the loop or wrap the unsigned counter.
        if (result.isError() || result.getValue() <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        bytesWritten += static_cast<std::size_t>(result.getValue());

        if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
        {
            return error;
        }
    }
    return StreamError::SUCCESS;
}

// Reads decrypted application data into the front of `buffer`.
//  - By default, blocks until the whole buffer is filled.
//  - options.noBlock without options.partial: returns WOULD_BLOCK unless OpenSSL already holds enough decrypted
//    data for the whole buffer.
//  - options.partial: performs a single read. If OpenSSL already holds decrypted data, at most that much is
//    returned. Otherwise a blocking read returns whatever the next SSL read yields, and a non-blocking read
//    returns 0 bytes.
//  - options.peek is not supported (reading consumes the data) and yields NOT_SUPPORTED.
// If non-null, outBytesRead receives the number of bytes stored at the start of `buffer` (also on error).
StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::ReadOptions& options, std::size_t* outBytesRead)
{
    if (outBytesRead != nullptr)
    {
        *outBytesRead = 0;
    }
    if (buffer.empty())
    {
        return StreamError::SUCCESS;
    }
    if (options.peek)
    {
        return StreamError::NOT_SUPPORTED;
    }
    if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
    {
        return error;
    }

    // SSL_pending() only counts bytes OpenSSL has already decrypted, not raw records waiting in the BIO.
    const std::size_t decryptedPending = static_cast<std::size_t>(ssl_.pending());
    if (!options.partial && options.noBlock && decryptedPending < buffer.size())
    {
        return StreamError::WOULD_BLOCK;
    }

    std::size_t bytesToRead = buffer.size();
    if (options.partial)
    {
        if (decryptedPending > 0)
        {
            bytesToRead = std::min(bytesToRead, decryptedPending);
        }
        else if (options.noBlock)
        {
            bytesToRead = 0;
        }
        // else: blocking partial read with nothing decrypted yet - let one read decrypt the next record,
        // otherwise SSL_pending() would stay 0 forever.
    }

    std::size_t bytesRead = 0;
    while (bytesRead < bytesToRead)
    {
        const std::size_t chunk = std::min(bytesToRead - bytesRead, MAX_OSSL_IO_CHUNK);
        const auto result = runIOLoop(&ossl::Ssl::read, !options.noBlock, buffer.data() + bytesRead, static_cast<int>(chunk));
        // A non-positive count would otherwise stall the loop or wrap the unsigned counter.
        if (result.isError() || result.getValue() <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        bytesRead += static_cast<std::size_t>(result.getValue());
        if (outBytesRead != nullptr)
        {
            *outBytesRead = bytesRead;
        }
        if (options.partial)
        {
            break; // partial reads return after the first successful read
        }
    }
    return StreamError::SUCCESS;
}

std::size_t SSLStream::tell()
{
    MIJIN_ERROR("SSLStream does not support tell().");
    return 0;
}

StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode)
{
    MIJIN_ERROR("SSLStream does not support seek().");
    (void) pos;
    (void) seekMode;
    return StreamError::NOT_SUPPORTED;
}

// Forwards to the base stream; writeRaw() already moves every produced TLS record into the base stream.
void SSLStream::flush()
{
    base_->flush();
}

// At end only when the base stream is exhausted and OpenSSL holds no already-decrypted data.
// NOTE(review): encrypted bytes already moved into the BIO but not yet decrypted are not accounted for here.
bool SSLStream::isAtEnd()
{
    return base_->isAtEnd() && ssl_.pending() == 0;
}

StreamFeatures SSLStream::getFeatures()
{
    return {
        .read = true,
        .write = true,
        .tell = false,
        .seek = false,
        .readOptions = {
            .partial = true,
            .peek = false, // readRaw() consumes data and rejects peek requests
            .noBlock = true
        }
    };
}

// Moves all TLS output waiting in the external BIO into the base stream.
StreamError SSLStream::bioToBase() MIJIN_NOEXCEPT
{
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
    while (true)
    {
        // Clamp on every iteration so a read never exceeds the scratch buffer.
        const std::size_t bytes = std::min(static_cast<std::size_t>(externalBio_.ctrlPending()), buffer.size());
        if (bytes == 0)
        {
            break;
        }
        const auto result = externalBio_.read(buffer.data(), static_cast<int>(bytes));
        if (result.isError() || result.getValue() <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        if (const StreamError error = base_->writeRaw(buffer.data(), static_cast<std::size_t>(result.getValue()));
            error != StreamError::SUCCESS)
        {
            return error;
        }
    }
    return StreamError::SUCCESS;
}

// Moves whatever the base stream can deliver without blocking into the external BIO, up to the amount the BIO
// guarantees to accept, so that OpenSSL can consume it.
StreamError SSLStream::baseToBio() MIJIN_NOEXCEPT
{
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
    std::size_t toRead = externalBio_.getWriteGuarantee();
    while (toRead > 0) // toRead == 0: the BIO buffer is full
    {
        const std::size_t maxBytes = std::min(toRead, buffer.size());
        std::size_t bytes = 0;
        if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true}, &bytes);
            error != StreamError::SUCCESS)
        {
            return error;
        }
        if (bytes == 0)
        {
            // nothing more to read right now
            break;
        }

        const auto result = externalBio_.write(buffer.data(), static_cast<int>(bytes));
        if (result.isError())
        {
            return StreamError::UNKNOWN_ERROR;
        }
        MIJIN_ASSERT(static_cast<std::size_t>(result.getValue()) == bytes, "BIO reported more bytes in buffer than it actually accepted?");
        toRead -= bytes;
    }

    if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
    {
        return StreamError::UNKNOWN_ERROR;
    }
    return StreamError::SUCCESS;
}
} // namespace mijin