More SSL stuff (still doesn't work :/).

This commit is contained in:
Patrick 2024-08-27 19:52:08 +02:00
parent 0be34a845a
commit a43f92fb58
7 changed files with 156 additions and 28 deletions

View File

@ -36,7 +36,47 @@ void Stream::flush() {}
mijin::Task<StreamError> Stream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{
co_return readRaw(buffer, options, outBytesRead);
std::size_t bytesToRead = buffer.size();
if (bytesToRead == 0)
{
co_return StreamError::SUCCESS;
}
while(true)
{
bool done = false;
std::size_t bytesRead = 0;
const StreamError error = readRaw(buffer.data() + buffer.size() - bytesToRead, bytesToRead,
{.partial = true, .noBlock = true}, &bytesToRead);
switch (error)
{
case StreamError::SUCCESS:
bytesToRead -= bytesRead;
if (options.partial || bytesToRead == 0)
{
done = true;
}
break;
case StreamError::WOULD_BLOCK:
if (options.noBlock)
{
co_return StreamError::WOULD_BLOCK;
}
break;
default:
co_return error;
}
if (done)
{
break;
}
co_await mijin::c_suspend();
}
if (outBytesRead != nullptr)
{
*outBytesRead = buffer.size() - bytesToRead;
}
co_return StreamError::SUCCESS;
}
mijin::Task<StreamError> Stream::c_writeRaw(std::span<const std::uint8_t> buffer)

View File

@ -70,6 +70,7 @@ enum class [[nodiscard]] StreamError
NOT_SUPPORTED = 2,
CONNECTION_CLOSED = 3,
PROTOCOL_ERROR = 4,
WOULD_BLOCK = 5,
UNKNOWN_ERROR = -1
};
@ -434,6 +435,8 @@ inline const char* errorName(StreamError error) noexcept
return "connection closed";
case StreamError::PROTOCOL_ERROR:
return "protocol error";
case StreamError::WOULD_BLOCK:
return "would block";
case StreamError::UNKNOWN_ERROR:
return "unknown error";
}

View File

@ -6,6 +6,10 @@
#include "../util/iterators.hpp"
#include "../util/string.hpp"
#if defined(MIJIN_ENABLE_OPENSSL)
#include "./ssl.hpp"
#endif
#define MIJIN_HTTP_WRITE(text) \
do \
{ \
@ -160,7 +164,12 @@ Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() noexcept
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
HTTPRequest request) noexcept
{
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
std::string hostname;
if (auto it = request.headers.find("host"); it != request.headers.end())
{
hostname = it->second;
}
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
{
co_return error;
}
@ -172,7 +181,7 @@ Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std
if (response.isError())
{
disconnect();
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
{
co_return error;
}
@ -242,7 +251,7 @@ void HTTPClient::disconnect() noexcept
socket_ = nullptr;
}
StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept
StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept
{
if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
{
@ -250,17 +259,36 @@ StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, b
}
disconnect();
MIJIN_ASSERT(!https, "HTTPS not supported yet.");
std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
{
return error;
}
socket_ = std::move(newSocket);
if (!https)
{
sslStream_.reset();
stream_.construct(socket_->getStream());
}
else
{
#if defined(MIJIN_ENABLE_OPENSSL)
std::unique_ptr<SSLStream> sslStream = std::make_unique<SSLStream>();
if (const StreamError error = sslStream->open(socket_->getStream(), hostname); error != StreamError::SUCCESS)
{
return error;
}
sslStream_ = std::move(sslStream);
stream_.construct(*sslStream_);
#else
return StreamError::NOT_SUPPORTED;
#endif
}
lastIP_ = address;
lastPort_ = port;
lastWasHttps_ = https;
socket_ = std::move(newSocket);
stream_.construct(socket_->getStream());
return StreamError::SUCCESS;
}
}

View File

@ -57,6 +57,7 @@ class HTTPClient
{
private:
std::unique_ptr<Socket> socket_;
std::unique_ptr<Stream> sslStream_;
mijin::BoxedObject<HTTPStream> stream_;
ip_address_t lastIP_;
std::uint16_t lastPort_ = 0;
@ -67,7 +68,7 @@ public:
Task<StreamResult<HTTPResponse>> c_request(const URL& url, HTTPRequest request = {}) noexcept;
void disconnect() noexcept;
private:
StreamError createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept;
StreamError createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept;
};
}

View File

@ -257,6 +257,34 @@ public:
return result;
}
[[nodiscard]]
int getReadRequest() const noexcept
{
return BIO_get_read_request(handle_);
}
[[nodiscard]]
int getWriteGuarantee() const noexcept
{
return BIO_get_write_guarantee(handle_);
}
[[nodiscard]]
int getWritePending() const noexcept
{
return BIO_wpending(handle_);
}
Error flush() const noexcept
{
ERR_clear_error();
if (!BIO_flush(handle_))
{
return Error::current();
}
return {};
}
static void upReferences(BIO* handle) noexcept
{
BIO_up_ref(handle);
@ -371,6 +399,12 @@ public:
return result;
}
[[nodiscard]]
int pending() const noexcept
{
return SSL_pending(handle_);
}
static void upReferences(SSL* handle) noexcept
{
SSL_up_ref(handle);

View File

@ -93,7 +93,7 @@ StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
externalBio_ = std::move(externalBio);
base_ = &base;
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess())
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
{
ssl_.free();
externalBio.free();
@ -116,7 +116,7 @@ void SSLStream::close() noexcept
{
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
base_ = nullptr;
(void) runIOLoop(&ossl::Ssl::shutdown);
(void) runIOLoop(&ossl::Ssl::shutdown, true);
ssl_.free();
externalBio_.free();
}
@ -131,7 +131,7 @@ StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
std::size_t bytesToWrite = buffer.size();
while (bytesToWrite > 0)
{
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite,
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + buffer.size() - bytesToWrite,
static_cast<int>(bytesToWrite));
if (result.isError())
{
@ -156,15 +156,25 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
return StreamError::SUCCESS;
}
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
{
return error;
}
if (!options.partial && options.noBlock && static_cast<std::size_t>(ssl_.pending()) < buffer.size())
{
return StreamError::WOULD_BLOCK;
}
std::size_t bytesToRead = buffer.size();
if (options.partial)
{
bytesToRead = std::min<std::size_t>(bytesToRead, ssl_.pending());
}
while (bytesToRead > 0)
{
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
{
return error;
}
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead,
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, !options.noBlock,
buffer.data() + buffer.size() - bytesToRead,
static_cast<int>(bytesToRead));
if (result.isError())
{
@ -216,9 +226,9 @@ StreamFeatures SSLStream::getFeatures()
.tell = false,
.seek = false,
.readOptions = {
.partial = false,
.peek = false,
.noBlock = false
.partial = true,
.peek = true,
.noBlock = true
}
};
}
@ -247,18 +257,20 @@ StreamError SSLStream::bioToBase() noexcept
StreamError SSLStream::baseToBio() noexcept
{
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
std::size_t toRead = externalBio_.getWriteGuarantee();
while (true)
{
std::size_t bytes = 0;
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size());
std::size_t maxBytes = std::min(toRead, buffer.size());
if (maxBytes == 0)
{
// buffer is full
return StreamError::SUCCESS;
break;
}
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes,{.partial = true, .noBlock = true},
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true},
&bytes); error != StreamError::SUCCESS)
{
return error;
@ -274,6 +286,13 @@ StreamError SSLStream::baseToBio() noexcept
return StreamError::UNKNOWN_ERROR;
}
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
toRead -= bytes;
}
if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
{
return StreamError::UNKNOWN_ERROR;
}
return StreamError::SUCCESS;
}
} // namespace shiken

View File

@ -5,8 +5,8 @@
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
#define MIJIN_NET_SSL_HPP_INCLUDED 1
#if !MIJIN_ENABLE_OPENSSL
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings."
#if !defined(MIJIN_ENABLE_OPENSSL)
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
#endif // !MIJIN_ENABLE_OPENSSL
#include <memory>
@ -45,13 +45,15 @@ private:
template<typename TFunc, typename... TArgs>
auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
{
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
while (true)
const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
ossl::Error error;
for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
{
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
ossl::Error error;
if constexpr (std::is_same_v<result_t, ossl::Error>)
{
if (error.isSuccess())
@ -86,6 +88,7 @@ private:
return error;
}
}
return error;
}
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;