More SSL stuff (still doesn't work :/).

This commit is contained in:
Patrick 2024-08-27 19:52:08 +02:00
parent 0be34a845a
commit a43f92fb58
7 changed files with 156 additions and 28 deletions

View File

@ -36,7 +36,47 @@ void Stream::flush() {}
mijin::Task<StreamError> Stream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) mijin::Task<StreamError> Stream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
{ {
co_return readRaw(buffer, options, outBytesRead); std::size_t bytesToRead = buffer.size();
if (bytesToRead == 0)
{
co_return StreamError::SUCCESS;
}
while(true)
{
bool done = false;
std::size_t bytesRead = 0;
const StreamError error = readRaw(buffer.data() + buffer.size() - bytesToRead, bytesToRead,
{.partial = true, .noBlock = true}, &bytesToRead);
switch (error)
{
case StreamError::SUCCESS:
bytesToRead -= bytesRead;
if (options.partial || bytesToRead == 0)
{
done = true;
}
break;
case StreamError::WOULD_BLOCK:
if (options.noBlock)
{
co_return StreamError::WOULD_BLOCK;
}
break;
default:
co_return error;
}
if (done)
{
break;
}
co_await mijin::c_suspend();
}
if (outBytesRead != nullptr)
{
*outBytesRead = buffer.size() - bytesToRead;
}
co_return StreamError::SUCCESS;
} }
mijin::Task<StreamError> Stream::c_writeRaw(std::span<const std::uint8_t> buffer) mijin::Task<StreamError> Stream::c_writeRaw(std::span<const std::uint8_t> buffer)

View File

@ -70,6 +70,7 @@ enum class [[nodiscard]] StreamError
NOT_SUPPORTED = 2, NOT_SUPPORTED = 2,
CONNECTION_CLOSED = 3, CONNECTION_CLOSED = 3,
PROTOCOL_ERROR = 4, PROTOCOL_ERROR = 4,
WOULD_BLOCK = 5,
UNKNOWN_ERROR = -1 UNKNOWN_ERROR = -1
}; };
@ -434,6 +435,8 @@ inline const char* errorName(StreamError error) noexcept
return "connection closed"; return "connection closed";
case StreamError::PROTOCOL_ERROR: case StreamError::PROTOCOL_ERROR:
return "protocol error"; return "protocol error";
case StreamError::WOULD_BLOCK:
return "would block";
case StreamError::UNKNOWN_ERROR: case StreamError::UNKNOWN_ERROR:
return "unknown error"; return "unknown error";
} }

View File

@ -6,6 +6,10 @@
#include "../util/iterators.hpp" #include "../util/iterators.hpp"
#include "../util/string.hpp" #include "../util/string.hpp"
#if defined(MIJIN_ENABLE_OPENSSL)
#include "./ssl.hpp"
#endif
#define MIJIN_HTTP_WRITE(text) \ #define MIJIN_HTTP_WRITE(text) \
do \ do \
{ \ { \
@ -160,7 +164,12 @@ Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() noexcept
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https, Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
HTTPRequest request) noexcept HTTPRequest request) noexcept
{ {
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS) std::string hostname;
if (auto it = request.headers.find("host"); it != request.headers.end())
{
hostname = it->second;
}
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
{ {
co_return error; co_return error;
} }
@ -172,7 +181,7 @@ Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std
if (response.isError()) if (response.isError())
{ {
disconnect(); disconnect();
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS) if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
{ {
co_return error; co_return error;
} }
@ -242,7 +251,7 @@ void HTTPClient::disconnect() noexcept
socket_ = nullptr; socket_ = nullptr;
} }
StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept
{ {
if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_) if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
{ {
@ -250,17 +259,36 @@ StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, b
} }
disconnect(); disconnect();
MIJIN_ASSERT(!https, "HTTPS not supported yet.");
std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>(); std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS) if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
{ {
return error; return error;
} }
socket_ = std::move(newSocket);
if (!https)
{
sslStream_.reset();
stream_.construct(socket_->getStream());
}
else
{
#if defined(MIJIN_ENABLE_OPENSSL)
std::unique_ptr<SSLStream> sslStream = std::make_unique<SSLStream>();
if (const StreamError error = sslStream->open(socket_->getStream(), hostname); error != StreamError::SUCCESS)
{
return error;
}
sslStream_ = std::move(sslStream);
stream_.construct(*sslStream_);
#else
return StreamError::NOT_SUPPORTED;
#endif
}
lastIP_ = address; lastIP_ = address;
lastPort_ = port; lastPort_ = port;
lastWasHttps_ = https; lastWasHttps_ = https;
socket_ = std::move(newSocket);
stream_.construct(socket_->getStream());
return StreamError::SUCCESS; return StreamError::SUCCESS;
} }
} }

View File

@ -57,6 +57,7 @@ class HTTPClient
{ {
private: private:
std::unique_ptr<Socket> socket_; std::unique_ptr<Socket> socket_;
std::unique_ptr<Stream> sslStream_;
mijin::BoxedObject<HTTPStream> stream_; mijin::BoxedObject<HTTPStream> stream_;
ip_address_t lastIP_; ip_address_t lastIP_;
std::uint16_t lastPort_ = 0; std::uint16_t lastPort_ = 0;
@ -67,7 +68,7 @@ public:
Task<StreamResult<HTTPResponse>> c_request(const URL& url, HTTPRequest request = {}) noexcept; Task<StreamResult<HTTPResponse>> c_request(const URL& url, HTTPRequest request = {}) noexcept;
void disconnect() noexcept; void disconnect() noexcept;
private: private:
StreamError createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept; StreamError createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept;
}; };
} }

View File

@ -257,6 +257,34 @@ public:
return result; return result;
} }
[[nodiscard]]
int getReadRequest() const noexcept
{
return BIO_get_read_request(handle_);
}
[[nodiscard]]
int getWriteGuarantee() const noexcept
{
return BIO_get_write_guarantee(handle_);
}
[[nodiscard]]
int getWritePending() const noexcept
{
return BIO_wpending(handle_);
}
Error flush() const noexcept
{
ERR_clear_error();
if (!BIO_flush(handle_))
{
return Error::current();
}
return {};
}
static void upReferences(BIO* handle) noexcept static void upReferences(BIO* handle) noexcept
{ {
BIO_up_ref(handle); BIO_up_ref(handle);
@ -371,6 +399,12 @@ public:
return result; return result;
} }
[[nodiscard]]
int pending() const noexcept
{
return SSL_pending(handle_);
}
static void upReferences(SSL* handle) noexcept static void upReferences(SSL* handle) noexcept
{ {
SSL_up_ref(handle); SSL_up_ref(handle);

View File

@ -93,7 +93,7 @@ StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
externalBio_ = std::move(externalBio); externalBio_ = std::move(externalBio);
base_ = &base; base_ = &base;
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess()) if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
{ {
ssl_.free(); ssl_.free();
externalBio.free(); externalBio.free();
@ -116,7 +116,7 @@ void SSLStream::close() noexcept
{ {
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open."); MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
base_ = nullptr; base_ = nullptr;
(void) runIOLoop(&ossl::Ssl::shutdown); (void) runIOLoop(&ossl::Ssl::shutdown, true);
ssl_.free(); ssl_.free();
externalBio_.free(); externalBio_.free();
} }
@ -131,7 +131,7 @@ StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
std::size_t bytesToWrite = buffer.size(); std::size_t bytesToWrite = buffer.size();
while (bytesToWrite > 0) while (bytesToWrite > 0)
{ {
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite, const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + buffer.size() - bytesToWrite,
static_cast<int>(bytesToWrite)); static_cast<int>(bytesToWrite));
if (result.isError()) if (result.isError())
{ {
@ -156,15 +156,25 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
return StreamError::SUCCESS; return StreamError::SUCCESS;
} }
std::size_t bytesToRead = buffer.size();
while (bytesToRead > 0)
{
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS) if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
{ {
return error; return error;
} }
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead, if (!options.partial && options.noBlock && static_cast<std::size_t>(ssl_.pending()) < buffer.size())
{
return StreamError::WOULD_BLOCK;
}
std::size_t bytesToRead = buffer.size();
if (options.partial)
{
bytesToRead = std::min<std::size_t>(bytesToRead, ssl_.pending());
}
while (bytesToRead > 0)
{
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, !options.noBlock,
buffer.data() + buffer.size() - bytesToRead,
static_cast<int>(bytesToRead)); static_cast<int>(bytesToRead));
if (result.isError()) if (result.isError())
{ {
@ -216,9 +226,9 @@ StreamFeatures SSLStream::getFeatures()
.tell = false, .tell = false,
.seek = false, .seek = false,
.readOptions = { .readOptions = {
.partial = false, .partial = true,
.peek = false, .peek = true,
.noBlock = false .noBlock = true
} }
}; };
} }
@ -247,15 +257,17 @@ StreamError SSLStream::bioToBase() noexcept
StreamError SSLStream::baseToBio() noexcept StreamError SSLStream::baseToBio() noexcept
{ {
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer; std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
std::size_t toRead = externalBio_.getWriteGuarantee();
while (true) while (true)
{ {
std::size_t bytes = 0; std::size_t bytes = 0;
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size()); std::size_t maxBytes = std::min(toRead, buffer.size());
if (maxBytes == 0) if (maxBytes == 0)
{ {
// buffer is full // buffer is full
return StreamError::SUCCESS; break;
} }
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true}, if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true},
@ -274,6 +286,13 @@ StreamError SSLStream::baseToBio() noexcept
return StreamError::UNKNOWN_ERROR; return StreamError::UNKNOWN_ERROR;
} }
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?"); MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
toRead -= bytes;
} }
if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
{
return StreamError::UNKNOWN_ERROR;
}
return StreamError::SUCCESS;
} }
} // namespace shiken } // namespace shiken

View File

@ -5,8 +5,8 @@
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED) #if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
#define MIJIN_NET_SSL_HPP_INCLUDED 1 #define MIJIN_NET_SSL_HPP_INCLUDED 1
#if !MIJIN_ENABLE_OPENSSL #if !defined(MIJIN_ENABLE_OPENSSL)
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings." #error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
#endif // !MIJIN_ENABLE_OPENSSL #endif // !MIJIN_ENABLE_OPENSSL
#include <memory> #include <memory>
@ -45,13 +45,15 @@ private:
template<typename TFunc, typename... TArgs> template<typename TFunc, typename... TArgs>
auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>> auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
{ {
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>; using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
while (true) const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
ossl::Error error;
for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
{ {
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...); auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
ossl::Error error;
if constexpr (std::is_same_v<result_t, ossl::Error>) if constexpr (std::is_same_v<result_t, ossl::Error>)
{ {
if (error.isSuccess()) if (error.isSuccess())
@ -86,6 +88,7 @@ private:
return error; return error;
} }
} }
return error;
} }
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override; // mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override; // mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;