101 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
 | 
						|
 | 
						|
#pragma once
 | 
						|
 | 
						|
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
 | 
						|
#define MIJIN_NET_SSL_HPP_INCLUDED 1
 | 
						|
 | 
						|
#if !defined(MIJIN_ENABLE_OPENSSL)
 | 
						|
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
 | 
						|
#endif // !MIJIN_ENABLE_OPENSSL
 | 
						|
 | 
						|
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <span>
#include <string>
#include <type_traits>

#include "./openssl_wrappers.hpp"
#include "../internal/common.hpp"
#include "../io/stream.hpp"
 | 
						|
 | 
						|
namespace mijin
 | 
						|
{
 | 
						|
class SSLStream : public Stream
 | 
						|
{
 | 
						|
private:
 | 
						|
    Stream* base_ = nullptr;
 | 
						|
    ossl::Ssl ssl_;
 | 
						|
    ossl::Bio externalBio_;
 | 
						|
public:
 | 
						|
    ~SSLStream() MIJIN_NOEXCEPT override
 | 
						|
    {
 | 
						|
        if (base_ != nullptr)
 | 
						|
        {
 | 
						|
            close();
 | 
						|
        }
 | 
						|
    }
 | 
						|
    StreamError open(Stream& base, const std::string& hostname) MIJIN_NOEXCEPT;
 | 
						|
    void close() MIJIN_NOEXCEPT;
 | 
						|
 | 
						|
    StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
 | 
						|
    StreamError writeRaw(std::span<const std::uint8_t> buffer) override;
 | 
						|
    std::size_t tell() override;
 | 
						|
    StreamError seek(std::intptr_t pos, SeekMode seekMode) override;
 | 
						|
    void flush() override;
 | 
						|
    bool isAtEnd() override;
 | 
						|
    StreamFeatures getFeatures() override;
 | 
						|
private:
 | 
						|
    StreamError bioToBase() MIJIN_NOEXCEPT;
 | 
						|
    StreamError baseToBio() MIJIN_NOEXCEPT;
 | 
						|
 | 
						|
 | 
						|
    template<typename TFunc, typename... TArgs>
 | 
						|
    auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
 | 
						|
    {
 | 
						|
        using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
 | 
						|
        const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
 | 
						|
 | 
						|
        ossl::Error error;
 | 
						|
        for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
 | 
						|
        {
 | 
						|
            auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
 | 
						|
            if constexpr (std::is_same_v<result_t, ossl::Error>)
 | 
						|
            {
 | 
						|
                if (error.isSuccess())
 | 
						|
                {
 | 
						|
                    return error;
 | 
						|
                }
 | 
						|
                error = result;
 | 
						|
            }
 | 
						|
            else
 | 
						|
            {
 | 
						|
                // assume result type
 | 
						|
                if (result.isSuccess())
 | 
						|
                {
 | 
						|
                    return result;
 | 
						|
                }
 | 
						|
                error = result.getError();
 | 
						|
            }
 | 
						|
            switch (error.sslError)
 | 
						|
            {
 | 
						|
                case SSL_ERROR_WANT_READ:
 | 
						|
                case SSL_ERROR_WANT_WRITE:
 | 
						|
                    if(const StreamError streamError = baseToBio(); streamError != StreamError::SUCCESS)
 | 
						|
                    {
 | 
						|
                        return error;
 | 
						|
                    }
 | 
						|
                    if (const StreamError streamError = bioToBase(); streamError != StreamError::SUCCESS)
 | 
						|
                    {
 | 
						|
                        return error;
 | 
						|
                    }
 | 
						|
                    break;
 | 
						|
                default:
 | 
						|
                    return error;
 | 
						|
            }
 | 
						|
        }
 | 
						|
        return error;
 | 
						|
    }
 | 
						|
    // mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
 | 
						|
    // mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
 | 
						|
};
 | 
						|
}
 | 
						|
 | 
						|
#endif // !defined(MIJIN_NET_SSL_HPP_INCLUDED)
 | 
						|
 |