101 lines
3.2 KiB
C++

#pragma once
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
#define MIJIN_NET_SSL_HPP_INCLUDED 1
#if !defined(MIJIN_ENABLE_OPENSSL)
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
#endif // !MIJIN_ENABLE_OPENSSL
#include <functional>
#include <limits>
#include <memory>
#include <type_traits>
#include "./openssl_wrappers.hpp"
#include "../internal/common.hpp"
#include "../io/stream.hpp"
namespace mijin
{
/**
 * Stream adapter that runs an OpenSSL session (ossl::Ssl) on top of another Stream.
 *
 * Encrypted data is shuttled between OpenSSL's BIO and the underlying stream by
 * bioToBase() / baseToBio(); runIOLoop() drives an SSL operation and pumps data
 * whenever OpenSSL reports SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE.
 *
 * NOTE(review): the underlying stream is held by raw, non-owning pointer (base_),
 * so it must outlive this object — presumably set by open() and cleared by
 * close(); confirm in the implementation.
 */
class SSLStream : public Stream
{
private:
    /// Underlying transport stream (non-owning). nullptr while not open; the
    /// destructor only calls close() when this is non-null.
    Stream* base_ = nullptr;
    /// OpenSSL session object that runIOLoop() operates on.
    ossl::Ssl ssl_;
    /// BIO used to exchange encrypted bytes with base_ — presumably the
    /// network side of a BIO pair; confirm in the implementation.
    ossl::Bio externalBio_;
public:
    /// Closes the session if it is still attached to a base stream.
    ~SSLStream() MIJIN_NOEXCEPT override
    {
        if (base_ != nullptr)
        {
            close();
        }
    }

    /// Attaches to `base` and starts an SSL session for `hostname`
    /// (implementation not visible here).
    StreamError open(Stream& base, const std::string& hostname) MIJIN_NOEXCEPT;
    /// Shuts the session down and detaches from the base stream (implementation not visible here).
    void close() MIJIN_NOEXCEPT;

    // Stream interface.
    StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
    StreamError writeRaw(std::span<const std::uint8_t> buffer) override;
    std::size_t tell() override;
    StreamError seek(std::intptr_t pos, SeekMode seekMode) override;
    void flush() override;
    bool isAtEnd() override;
    StreamFeatures getFeatures() override;
private:
    /// Moves pending bytes from the SSL BIO to the base stream.
    StreamError bioToBase() MIJIN_NOEXCEPT;
    /// Moves bytes from the base stream into the SSL BIO.
    StreamError baseToBio() MIJIN_NOEXCEPT;

    /**
     * Repeatedly invokes `func(ssl_, args...)` until it succeeds, fails with an
     * SSL error other than SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE, or the
     * try budget is exhausted. On WANT_READ / WANT_WRITE, data is pumped
     * between the SSL BIO and the base stream before retrying.
     *
     * @param func  Callable invoked as func(ssl_, args...) on every attempt. It must
     *              return either ossl::Error or a result type exposing
     *              isSuccess() and getError().
     * @param block If true, retry without a practical limit; otherwise give up
     *              after 10 attempts.
     * @param args  Extra arguments passed to func on every attempt.
     * @return The first successful result, otherwise the last ossl::Error observed.
     */
    template<typename TFunc, typename... TArgs>
    auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc&, ossl::Ssl&, TArgs&...>>
    {
        using result_t = std::decay_t<std::invoke_result_t<TFunc&, ossl::Ssl&, TArgs&...>>;
        const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
        ossl::Error error;
        for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
        {
            // func/args are passed as lvalues on purpose: they are reused on every
            // iteration, so std::forward-ing them here could move from them on the
            // first attempt and hand moved-from objects to later attempts.
            auto result = std::invoke(func, ssl_, args...);
            if constexpr (std::is_same_v<result_t, ossl::Error>)
            {
                // Record the outcome of *this* call before testing it.
                error = result;
                if (error.isSuccess())
                {
                    return error;
                }
            }
            else
            {
                // Result-like type: success carries the value, failure carries an ossl::Error.
                if (result.isSuccess())
                {
                    return result;
                }
                error = result.getError();
            }
            switch (error.sslError)
            {
            case SSL_ERROR_WANT_READ:
            case SSL_ERROR_WANT_WRITE:
                // Flush anything OpenSSL has already queued (e.g. a handshake message)
                // before pulling new input; otherwise the peer may never receive the
                // data it needs in order to answer.
                // NOTE(review): a transport failure is reported as the SSL error, so
                // the StreamError itself is lost to the caller.
                if (bioToBase() != StreamError::SUCCESS)
                {
                    return error;
                }
                if (baseToBio() != StreamError::SUCCESS)
                {
                    return error;
                }
                break;
            default:
                return error;
            }
        }
        return error;
    }
    // mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
    // mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
};
}
#endif // !defined(MIJIN_NET_SSL_HPP_INCLUDED)