101 lines
3.2 KiB
C++
101 lines
3.2 KiB
C++
|
|
|
|
#pragma once
|
|
|
|
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
|
#define MIJIN_NET_SSL_HPP_INCLUDED 1
|
|
|
|
#if !defined(MIJIN_ENABLE_OPENSSL)
|
|
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
|
|
#endif // !MIJIN_ENABLE_OPENSSL
|
|
|
|
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <span>
#include <string>
#include <type_traits>

#include "./openssl_wrappers.hpp"
#include "../internal/common.hpp"
#include "../io/stream.hpp"
|
|
|
|
namespace mijin
|
|
{
|
|
class SSLStream : public Stream
|
|
{
|
|
private:
|
|
Stream* base_ = nullptr;
|
|
ossl::Ssl ssl_;
|
|
ossl::Bio externalBio_;
|
|
public:
|
|
~SSLStream() MIJIN_NOEXCEPT override
|
|
{
|
|
if (base_ != nullptr)
|
|
{
|
|
close();
|
|
}
|
|
}
|
|
StreamError open(Stream& base, const std::string& hostname) MIJIN_NOEXCEPT;
|
|
void close() MIJIN_NOEXCEPT;
|
|
|
|
StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
|
|
StreamError writeRaw(std::span<const std::uint8_t> buffer) override;
|
|
std::size_t tell() override;
|
|
StreamError seek(std::intptr_t pos, SeekMode seekMode) override;
|
|
void flush() override;
|
|
bool isAtEnd() override;
|
|
StreamFeatures getFeatures() override;
|
|
private:
|
|
StreamError bioToBase() MIJIN_NOEXCEPT;
|
|
StreamError baseToBio() MIJIN_NOEXCEPT;
|
|
|
|
|
|
template<typename TFunc, typename... TArgs>
|
|
auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
|
|
{
|
|
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
|
|
const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
|
|
|
|
ossl::Error error;
|
|
for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
|
|
{
|
|
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
|
|
if constexpr (std::is_same_v<result_t, ossl::Error>)
|
|
{
|
|
if (error.isSuccess())
|
|
{
|
|
return error;
|
|
}
|
|
error = result;
|
|
}
|
|
else
|
|
{
|
|
// assume result type
|
|
if (result.isSuccess())
|
|
{
|
|
return result;
|
|
}
|
|
error = result.getError();
|
|
}
|
|
switch (error.sslError)
|
|
{
|
|
case SSL_ERROR_WANT_READ:
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if(const StreamError streamError = baseToBio(); streamError != StreamError::SUCCESS)
|
|
{
|
|
return error;
|
|
}
|
|
if (const StreamError streamError = bioToBase(); streamError != StreamError::SUCCESS)
|
|
{
|
|
return error;
|
|
}
|
|
break;
|
|
default:
|
|
return error;
|
|
}
|
|
}
|
|
return error;
|
|
}
|
|
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
|
|
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
|
|
};
|
|
}
|
|
|
|
#endif // !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
|
|