299 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			299 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| 
 | |
#include "./ssl.hpp"

#include <algorithm>
#include <array>
#include <limits>
#include <mutex>
 | |
| 
 | |
| 
 | |
| namespace mijin
 | |
| {
 | |
| namespace
 | |
| {
 | |
| inline constexpr int BIO_BUFFER_SIZE = 4096;
 | |
| ossl::Result<ossl::Context*> getSSLContext() MIJIN_NOEXCEPT
 | |
| {
 | |
|     static ossl::Context context;
 | |
|     static std::mutex contextMutex;
 | |
| 
 | |
|     if (!context)
 | |
|     {
 | |
|         const std::unique_lock contextLock(contextMutex);
 | |
| 
 | |
|         if (context)
 | |
|         {
 | |
|             return &context;
 | |
|         }
 | |
| 
 | |
|         ossl::Context newContext;
 | |
|         if (const ossl::Error error = newContext.create(SSLv23_client_method()); !error.isSuccess())
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
|         newContext.setVerify(SSL_VERIFY_PEER);
 | |
| 
 | |
|         ossl::X509Store store;
 | |
|         if (const ossl::Error error = store.create(); !error.isSuccess())
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
|         if (const ossl::Error error = store.loadFile("/etc/ssl/cert.pem"); !error.isSuccess())
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
|         newContext.setCertStore(std::move(store));
 | |
|         if (const ossl::Error error = newContext.setMinProtoVersion(TLS1_2_VERSION); !error.isSuccess())
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
| 
 | |
|         context = std::move(newContext);
 | |
|     }
 | |
| 
 | |
|     return &context;
 | |
| }
 | |
| }
 | |
| 
 | |
| StreamError SSLStream::open(Stream& base, const std::string& hostname) MIJIN_NOEXCEPT
 | |
| {
 | |
|     MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");
 | |
| 
 | |
|     const ossl::Result<ossl::Context*> contextResult = getSSLContext();
 | |
|     if (contextResult.isError())
 | |
|     {
 | |
|         // TODO: convert/print error
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
|     ossl::Context& context = *contextResult.getValue();
 | |
| 
 | |
|     ossl::Ssl ssl;
 | |
|     if (const ossl::Error error = ssl.create(context); !error.isSuccess())
 | |
|     {
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
| 
 | |
|     ossl::Bio externalBio;
 | |
|     ossl::Bio internalBio;
 | |
|     if (const ossl::Error error = internalBio.createPair(externalBio, BIO_BUFFER_SIZE, BIO_BUFFER_SIZE); !error.isSuccess())
 | |
|     {
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
|     ssl.setBio(std::move(internalBio));
 | |
| 
 | |
|     if (const ossl::Error error = ssl.setTLSExtHostname(hostname.c_str()); !error.isSuccess())
 | |
|     {
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
| 
 | |
|     if (const ossl::Error error = ssl.setHost(hostname.c_str()); !error.isSuccess())
 | |
|     {
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
| 
 | |
|     // these need to be initialized for connecting
 | |
|     ssl_ = std::move(ssl);
 | |
|     externalBio_ = std::move(externalBio);
 | |
|     base_ = &base;
 | |
| 
 | |
|     if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
 | |
|     {
 | |
|         ssl_.free();
 | |
|         externalBio.free();
 | |
|         base_ = nullptr;
 | |
|         return StreamError::UNKNOWN_ERROR; // TODO: translate
 | |
|     }
 | |
| 
 | |
|     if (ssl_.getVerifyResult() != X509_V_OK)
 | |
|     {
 | |
|         ssl_.free();
 | |
|         externalBio.free();
 | |
|         base_ = nullptr;
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
| 
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| void SSLStream::close() MIJIN_NOEXCEPT
 | |
| {
 | |
|     MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
 | |
|     base_ = nullptr;
 | |
|     (void) runIOLoop(&ossl::Ssl::shutdown, true);
 | |
|     ssl_.free();
 | |
|     externalBio_.free();
 | |
| }
 | |
| 
 | |
| StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
 | |
| {
 | |
|     if (buffer.empty())
 | |
|     {
 | |
|         return StreamError::SUCCESS;
 | |
|     }
 | |
| 
 | |
|     std::size_t bytesToWrite = buffer.size();
 | |
|     while (bytesToWrite > 0)
 | |
|     {
 | |
|         const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + buffer.size() - bytesToWrite,
 | |
|                                                    static_cast<int>(bytesToWrite));
 | |
|         if (result.isError())
 | |
|         {
 | |
|             return StreamError::UNKNOWN_ERROR;
 | |
|         }
 | |
|         bytesToWrite -= result.getValue();
 | |
| 
 | |
|         if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::ReadOptions& options,
 | |
|                                std::size_t* outBytesRead)
 | |
| {
 | |
|     if (buffer.empty())
 | |
|     {
 | |
|         return StreamError::SUCCESS;
 | |
|     }
 | |
| 
 | |
|     if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
 | |
|     {
 | |
|         return error;
 | |
|     }
 | |
| 
 | |
|     if (!options.partial && options.noBlock && static_cast<std::size_t>(ssl_.pending()) < buffer.size())
 | |
|     {
 | |
|         return StreamError::WOULD_BLOCK;
 | |
|     }
 | |
| 
 | |
|     std::size_t bytesToRead = buffer.size();
 | |
|     if (options.partial)
 | |
|     {
 | |
|         bytesToRead = std::min<std::size_t>(bytesToRead, ssl_.pending());
 | |
|     }
 | |
|     while (bytesToRead > 0)
 | |
|     {
 | |
|         const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, !options.noBlock,
 | |
|                                                    buffer.data() + buffer.size() - bytesToRead,
 | |
|                                                    static_cast<int>(bytesToRead));
 | |
|         if (result.isError())
 | |
|         {
 | |
|             return StreamError::UNKNOWN_ERROR;
 | |
|         }
 | |
| 
 | |
|         bytesToRead -= result.getValue();
 | |
|     }
 | |
| 
 | |
|     // TODO: options and outBytesRead
 | |
|     (void) options;
 | |
|     if (outBytesRead != nullptr)
 | |
|     {
 | |
|         *outBytesRead = buffer.size();
 | |
|     }
 | |
| 
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| std::size_t SSLStream::tell()
 | |
| {
 | |
|     MIJIN_ERROR("SSLStream does not support tell().");
 | |
|     return 0;
 | |
| }
 | |
| 
 | |
| StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode)
 | |
| {
 | |
|     MIJIN_ERROR("SSLStream does not support tell().");
 | |
|     (void) pos;
 | |
|     (void) seekMode;
 | |
|     return StreamError::NOT_SUPPORTED;
 | |
| }
 | |
| 
 | |
| void SSLStream::flush()
 | |
| {
 | |
|     base_->flush();
 | |
| }
 | |
| 
 | |
| bool SSLStream::isAtEnd()
 | |
| {
 | |
|     return base_->isAtEnd();
 | |
| }
 | |
| 
 | |
| StreamFeatures SSLStream::getFeatures()
 | |
| {
 | |
|     return {
 | |
|         .read = true,
 | |
|         .write = true,
 | |
|         .tell = false,
 | |
|         .seek = false,
 | |
|         .readOptions = {
 | |
|             .partial = true,
 | |
|             .peek = true,
 | |
|             .noBlock = true
 | |
|         }
 | |
|     };
 | |
| }
 | |
| 
 | |
| StreamError SSLStream::bioToBase() MIJIN_NOEXCEPT
 | |
| {
 | |
|     std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
 | |
|     std::size_t bytes = std::min(externalBio_.ctrlPending(), buffer.size());
 | |
|     while (bytes > 0)
 | |
|     {
 | |
|         const ossl::Result<int> result = externalBio_.read(buffer.data(), static_cast<int>(bytes));
 | |
|         if (result.isError())
 | |
|         {
 | |
|             return StreamError::UNKNOWN_ERROR;
 | |
|         }
 | |
|         if (const StreamError error = base_->writeRaw(buffer.data(), result); error != StreamError::SUCCESS)
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
| 
 | |
|         bytes = externalBio_.ctrlPending();
 | |
|     }
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| 
 | |
| StreamError SSLStream::baseToBio() MIJIN_NOEXCEPT
 | |
| {
 | |
|     std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
 | |
|     std::size_t toRead = externalBio_.getWriteGuarantee();
 | |
| 
 | |
|     while (true)
 | |
|     {
 | |
|         std::size_t bytes = 0;
 | |
|         std::size_t maxBytes = std::min(toRead, buffer.size());
 | |
| 
 | |
|         if (maxBytes == 0)
 | |
|         {
 | |
|             // buffer is full
 | |
|             break;
 | |
|         }
 | |
| 
 | |
|         if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true},
 | |
|                                                      &bytes); error != StreamError::SUCCESS)
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
|         if (bytes == 0)
 | |
|         {
 | |
|             // nothing more to read
 | |
|             return StreamError::SUCCESS;
 | |
|         }
 | |
|         const ossl::Result<int> result = externalBio_.write(buffer.data(), static_cast<int>(bytes));
 | |
|         if (result.isError())
 | |
|         {
 | |
|             return StreamError::UNKNOWN_ERROR;
 | |
|         }
 | |
|         MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
 | |
|         toRead -= bytes;
 | |
|     }
 | |
| 
 | |
|     if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
 | |
|     {
 | |
|         return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
} // namespace mijin
 |