101 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| 
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
 | |
| #define MIJIN_NET_SSL_HPP_INCLUDED 1
 | |
| 
 | |
| #if !defined(MIJIN_ENABLE_OPENSSL)
 | |
| #error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
 | |
| #endif // !MIJIN_ENABLE_OPENSSL
 | |
| 
 | |
| #include <memory>
 | |
| #include "./openssl_wrappers.hpp"
 | |
| #include "../internal/common.hpp"
 | |
| #include "../io/stream.hpp"
 | |
| 
 | |
| namespace mijin
 | |
| {
 | |
| class SSLStream : public Stream
 | |
| {
 | |
| private:
 | |
|     Stream* base_ = nullptr;
 | |
|     ossl::Ssl ssl_;
 | |
|     ossl::Bio externalBio_;
 | |
| public:
 | |
|     ~SSLStream() MIJIN_NOEXCEPT override
 | |
|     {
 | |
|         if (base_ != nullptr)
 | |
|         {
 | |
|             close();
 | |
|         }
 | |
|     }
 | |
|     StreamError open(Stream& base, const std::string& hostname) MIJIN_NOEXCEPT;
 | |
|     void close() MIJIN_NOEXCEPT;
 | |
| 
 | |
|     StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
 | |
|     StreamError writeRaw(std::span<const std::uint8_t> buffer) override;
 | |
|     std::size_t tell() override;
 | |
|     StreamError seek(std::intptr_t pos, SeekMode seekMode) override;
 | |
|     void flush() override;
 | |
|     bool isAtEnd() override;
 | |
|     StreamFeatures getFeatures() override;
 | |
| private:
 | |
|     StreamError bioToBase() MIJIN_NOEXCEPT;
 | |
|     StreamError baseToBio() MIJIN_NOEXCEPT;
 | |
| 
 | |
| 
 | |
|     template<typename TFunc, typename... TArgs>
 | |
|     auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
 | |
|     {
 | |
|         using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
 | |
|         const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
 | |
| 
 | |
|         ossl::Error error;
 | |
|         for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
 | |
|         {
 | |
|             auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
 | |
|             if constexpr (std::is_same_v<result_t, ossl::Error>)
 | |
|             {
 | |
|                 if (error.isSuccess())
 | |
|                 {
 | |
|                     return error;
 | |
|                 }
 | |
|                 error = result;
 | |
|             }
 | |
|             else
 | |
|             {
 | |
|                 // assume result type
 | |
|                 if (result.isSuccess())
 | |
|                 {
 | |
|                     return result;
 | |
|                 }
 | |
|                 error = result.getError();
 | |
|             }
 | |
|             switch (error.sslError)
 | |
|             {
 | |
|                 case SSL_ERROR_WANT_READ:
 | |
|                 case SSL_ERROR_WANT_WRITE:
 | |
|                     if(const StreamError streamError = baseToBio(); streamError != StreamError::SUCCESS)
 | |
|                     {
 | |
|                         return error;
 | |
|                     }
 | |
|                     if (const StreamError streamError = bioToBase(); streamError != StreamError::SUCCESS)
 | |
|                     {
 | |
|                         return error;
 | |
|                     }
 | |
|                     break;
 | |
|                 default:
 | |
|                     return error;
 | |
|             }
 | |
|         }
 | |
|         return error;
 | |
|     }
 | |
|     // mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
 | |
|     // mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
 | |
| };
 | |
| }
 | |
| 
 | |
| #endif // !defined(MIJIN_NET_SSL_HPP_INCLUDED)
 | |
| 
 |