#include "./ssl.hpp"

#include <algorithm>
#include <array>
#include <cstdint>
#include <limits>
#include <mutex>

#include <openssl/ssl.h>

namespace mijin
{
namespace
{
// Size of the scratch buffers used to move ciphertext between the BIO pair and the base stream.
inline constexpr int BIO_BUFFER_SIZE = 4096;

// Largest length that can be handed to OpenSSL's int-sized SSL_read()/SSL_write() in a single call.
inline constexpr std::size_t MAX_SSL_CHUNK = static_cast<std::size_t>(std::numeric_limits<int>::max());

/// Creates the client context shared by all SSL streams: peer certificates are verified against the default
/// verify paths and nothing older than TLS 1.2 is negotiated.
/// @return the new context, or nullptr if any step of the setup failed
SSL_CTX* createSSLContext() noexcept
{
    SSL_CTX* context = SSL_CTX_new(TLS_client_method());
    if (context == nullptr)
    {
        return nullptr;
    }
    SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr);
    if (!SSL_CTX_set_default_verify_paths(context) || !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION))
    {
        SSL_CTX_free(context);
        return nullptr;
    }
    return context;
}

/// Owns the lazily created shared SSL_CTX and frees it when the holder is destroyed at program exit.
class SSLContextHolder
{
private:
    std::mutex mutex_;
    SSL_CTX* context_ = nullptr;
public:
    SSLContextHolder() noexcept = default;
    SSLContextHolder(const SSLContextHolder&) = delete;
    SSLContextHolder& operator=(const SSLContextHolder&) = delete;

    ~SSLContextHolder() noexcept
    {
        if (context_ != nullptr)
        {
            SSL_CTX_free(context_);
        }
    }

    /// Returns the shared context, creating it on first use; a failed creation is retried on the next call.
    /// @return the context, or nullptr if it could not be created
    SSL_CTX* get() noexcept
    {
        // Check and create under the same lock so concurrent first calls neither race on context_ nor create
        // two contexts.
        const std::scoped_lock lock(mutex_);
        if (context_ == nullptr)
        {
            context_ = createSSLContext();
        }
        return context_;
    }
};

/// Returns the process-wide client context, or nullptr if it could not be created.
SSL_CTX* getSSLContext() noexcept
{
    // Function-local static: thread-safe initialisation that does not depend on cross-TU static init order.
    static SSLContextHolder holder;
    return holder.get();
}
} // namespace

/// Opens a TLS client connection on top of `base` and performs the handshake, verifying the peer certificate
/// against `hostname`.
/// @param base stream carrying the ciphertext; must outlive this stream while it is open
/// @param hostname name sent via SNI and required to match the peer certificate; must not be empty
/// @return SUCCESS, the base stream's error, or UNKNOWN_ERROR. On failure the stream stays closed.
StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
{
    MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");

    // An empty name would clear the expected host in SSL_set1_host(), i.e. disable hostname verification.
    if (hostname.empty())
    {
        return StreamError::UNKNOWN_ERROR;
    }

    SSL_CTX* context = getSSLContext();
    if (context == nullptr)
    {
        return StreamError::UNKNOWN_ERROR;
    }

    SSL* ssl = SSL_new(context);
    if (ssl == nullptr)
    {
        return StreamError::UNKNOWN_ERROR;
    }

    // bioB is OpenSSL's end of the pair, bioA is ours: bytes written to one end are read from the other.
    BIO* bioA = nullptr;
    BIO* bioB = nullptr;
    if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0))
    {
        SSL_free(ssl);
        return StreamError::UNKNOWN_ERROR;
    }
    SSL_set_bio(ssl, bioB, bioB); // from here on SSL_free(ssl) also frees bioB

    // The handshake pumps data through bioToBase()/baseToBio(), which operate on the members.
    ssl_ = ssl;
    bioA_ = bioA;
    bioB_ = bioB;
    base_ = &base;

    // Releases everything acquired above and resets the members, so a failed open() leaves a cleanly closed
    // stream instead of dangling handles that a later close() would free a second time.
    const auto fail = [&](const StreamError error) noexcept
    {
        SSL_free(ssl);
        BIO_free_all(bioA);
        ssl_ = nullptr;
        bioA_ = nullptr;
        bioB_ = nullptr;
        base_ = nullptr;
        return error;
    };

    if (!SSL_set_tlsext_host_name(ssl, hostname.c_str()))
    {
        return fail(StreamError::UNKNOWN_ERROR);
    }
    if (!SSL_set1_host(ssl, hostname.c_str()))
    {
        return fail(StreamError::UNKNOWN_ERROR);
    }

    while (true)
    {
        const int result = SSL_connect(ssl);
        if (result == 1)
        {
            break;
        }
        switch (SSL_get_error(ssl, result))
        {
        case SSL_ERROR_WANT_READ:
        case SSL_ERROR_WANT_WRITE:
            // Send the handshake records OpenSSL produced first, then feed in whatever the peer sent so far.
            // baseToBio() does not block, so this loop polls until the peer answers.
            if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
            {
                return fail(error);
            }
            if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
            {
                return fail(error);
            }
            break;
        default:
            return fail(StreamError::UNKNOWN_ERROR);
        }
    }

    // With SSL_VERIFY_PEER a failed verification should already have aborted the handshake; double-check anyway.
    if (SSL_get_verify_result(ssl) != X509_V_OK)
    {
        return fail(StreamError::UNKNOWN_ERROR);
    }
    return StreamError::SUCCESS;
}

/// Sends a TLS close_notify (best effort) and releases all OpenSSL resources. The base stream is not closed.
void SSLStream::close() noexcept
{
    MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
    if (base_ == nullptr)
    {
        return;
    }

    SSL* ssl = static_cast<SSL*>(ssl_);
    // SSL_shutdown() only queues the close_notify alert in the BIO pair; forward it to the peer. Errors are
    // ignored since the connection is being torn down either way.
    SSL_shutdown(ssl);
    (void) bioToBase();

    SSL_free(ssl); // also frees bioB
    BIO_free_all(static_cast<BIO*>(bioA_));
    ssl_ = nullptr;
    bioA_ = nullptr;
    bioB_ = nullptr;
    base_ = nullptr;
}

/// Encrypts all of `buffer` and writes the resulting records to the base stream.
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
{
    if (buffer.empty())
    {
        return StreamError::SUCCESS;
    }

    SSL* ssl = static_cast<SSL*>(ssl_);
    std::size_t bytesWritten = 0;
    while (bytesWritten < buffer.size())
    {
        const std::size_t chunkSize = std::min(buffer.size() - bytesWritten, MAX_SSL_CHUNK);
        const int result = SSL_write(ssl, buffer.data() + bytesWritten, static_cast<int>(chunkSize));
        if (result > 0)
        {
            bytesWritten += static_cast<std::size_t>(result);
        }
        else
        {
            switch (SSL_get_error(ssl, result))
            {
            case SSL_ERROR_WANT_WRITE:
                // The BIO pair is full (large writes span several records). It is drained below; the retry
                // reuses the same pointer and length, as OpenSSL requires.
                break;
            case SSL_ERROR_WANT_READ:
                // OpenSSL needs input from the peer before it can continue (e.g. a post-handshake message).
                if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
                {
                    return error;
                }
                if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
                {
                    return error;
                }
                continue;
            default:
                return StreamError::UNKNOWN_ERROR;
            }
        }
        if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
        {
            return error;
        }
    }
    return StreamError::SUCCESS;
}

/// Reads and decrypts exactly `buffer.size()` bytes, polling the base stream until enough data has arrived.
StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::ReadOptions& options, std::size_t* outBytesRead)
{
    // TODO: options (partial/peek/noBlock) are not supported yet; getFeatures() advertises none of them.
    (void) options;

    if (buffer.empty())
    {
        if (outBytesRead != nullptr)
        {
            *outBytesRead = 0;
        }
        return StreamError::SUCCESS;
    }

    SSL* ssl = static_cast<SSL*>(ssl_);
    std::size_t bytesRead = 0;
    while (bytesRead < buffer.size())
    {
        const std::size_t chunkSize = std::min(buffer.size() - bytesRead, MAX_SSL_CHUNK);
        const int result = SSL_read(ssl, buffer.data() + bytesRead, static_cast<int>(chunkSize));
        if (result > 0)
        {
            bytesRead += static_cast<std::size_t>(result);
            continue;
        }
        switch (SSL_get_error(ssl, result))
        {
        case SSL_ERROR_WANT_READ:
            // No complete record buffered yet: send anything OpenSSL queued, then pull more ciphertext in.
            // baseToBio() does not block, so this polls until data arrives.
            if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
            {
                return error;
            }
            if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
            {
                return error;
            }
            break;
        case SSL_ERROR_WANT_WRITE:
            if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
            {
                return error;
            }
            break;
        default:
            return StreamError::UNKNOWN_ERROR;
        }
    }

    if (outBytesRead != nullptr)
    {
        *outBytesRead = bytesRead;
    }
    return StreamError::SUCCESS;
}

std::size_t SSLStream::tell()
{
    MIJIN_ERROR("SSLStream does not support tell().");
    return 0;
}

StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode)
{
    MIJIN_ERROR("SSLStream does not support seek().");
    (void) pos;
    (void) seekMode;
    return StreamError::NOT_SUPPORTED;
}

void SSLStream::flush()
{
    base_->flush();
}

/// True only if the base stream is exhausted AND no data is still buffered on the TLS side.
bool SSLStream::isAtEnd()
{
    // Decrypted plaintext (or ciphertext already pulled into the BIO pair) can still be read even when the base
    // stream has nothing more to offer.
    if (SSL_has_pending(static_cast<SSL*>(ssl_)) != 0 || BIO_ctrl_pending(static_cast<BIO*>(bioB_)) > 0)
    {
        return false;
    }
    return base_->isAtEnd();
}

StreamFeatures SSLStream::getFeatures()
{
    return {
        .read = true,
        .write = true,
        .tell = false,
        .seek = false,
        .readOptions = {
            .partial = false,
            .peek = false,
            .noBlock = false
        }
    };
}

/// Forwards all ciphertext OpenSSL has queued in the BIO pair to the base stream.
StreamError SSLStream::bioToBase() noexcept
{
    BIO* bio = static_cast<BIO*>(bioA_);
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
    // Clamp on every iteration: the BIO pair can hold considerably more than one buffer's worth.
    for (std::size_t pending = BIO_ctrl_pending(bio); pending > 0; pending = BIO_ctrl_pending(bio))
    {
        const std::size_t chunkSize = std::min(pending, buffer.size());
        const int result = BIO_read(bio, buffer.data(), static_cast<int>(chunkSize));
        if (result <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        if (const StreamError error = base_->writeRaw(buffer.data(), static_cast<std::size_t>(result)); error != StreamError::SUCCESS)
        {
            return error;
        }
    }
    return StreamError::SUCCESS;
}

/// Moves whatever ciphertext is immediately available on the base stream (non-blocking) into the BIO pair,
/// stopping when the base has nothing more or the BIO pair is full.
StreamError SSLStream::baseToBio() noexcept
{
    BIO* bio = static_cast<BIO*>(bioA_);
    std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
    while (true)
    {
        // Only request what the BIO pair is guaranteed to accept, so nothing read from the base is dropped.
        const std::size_t maxBytes = std::min(BIO_ctrl_get_write_guarantee(bio), buffer.size());
        if (maxBytes == 0)
        {
            // buffer is full
            return StreamError::SUCCESS;
        }

        std::size_t bytes = 0;
        if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true}, &bytes); error != StreamError::SUCCESS)
        {
            return error;
        }
        if (bytes == 0)
        {
            // nothing more to read
            return StreamError::SUCCESS;
        }

        const int result = BIO_write(bio, buffer.data(), static_cast<int>(bytes));
        if (result <= 0)
        {
            return StreamError::UNKNOWN_ERROR;
        }
        MIJIN_ASSERT(result == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
    }
}
} // namespace mijin