More SSL stuff (still doesn't work :/).
This commit is contained in:
parent
0be34a845a
commit
a43f92fb58
@ -36,7 +36,47 @@ void Stream::flush() {}
|
||||
|
||||
mijin::Task<StreamError> Stream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
|
||||
{
|
||||
co_return readRaw(buffer, options, outBytesRead);
|
||||
std::size_t bytesToRead = buffer.size();
|
||||
if (bytesToRead == 0)
|
||||
{
|
||||
co_return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
while(true)
|
||||
{
|
||||
bool done = false;
|
||||
std::size_t bytesRead = 0;
|
||||
const StreamError error = readRaw(buffer.data() + buffer.size() - bytesToRead, bytesToRead,
|
||||
{.partial = true, .noBlock = true}, &bytesToRead);
|
||||
switch (error)
|
||||
{
|
||||
case StreamError::SUCCESS:
|
||||
bytesToRead -= bytesRead;
|
||||
if (options.partial || bytesToRead == 0)
|
||||
{
|
||||
done = true;
|
||||
}
|
||||
break;
|
||||
case StreamError::WOULD_BLOCK:
|
||||
if (options.noBlock)
|
||||
{
|
||||
co_return StreamError::WOULD_BLOCK;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
co_return error;
|
||||
}
|
||||
if (done)
|
||||
{
|
||||
break;
|
||||
}
|
||||
co_await mijin::c_suspend();
|
||||
}
|
||||
if (outBytesRead != nullptr)
|
||||
{
|
||||
*outBytesRead = buffer.size() - bytesToRead;
|
||||
}
|
||||
co_return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
mijin::Task<StreamError> Stream::c_writeRaw(std::span<const std::uint8_t> buffer)
|
||||
|
@ -70,6 +70,7 @@ enum class [[nodiscard]] StreamError
|
||||
NOT_SUPPORTED = 2,
|
||||
CONNECTION_CLOSED = 3,
|
||||
PROTOCOL_ERROR = 4,
|
||||
WOULD_BLOCK = 5,
|
||||
UNKNOWN_ERROR = -1
|
||||
};
|
||||
|
||||
@ -434,6 +435,8 @@ inline const char* errorName(StreamError error) noexcept
|
||||
return "connection closed";
|
||||
case StreamError::PROTOCOL_ERROR:
|
||||
return "protocol error";
|
||||
case StreamError::WOULD_BLOCK:
|
||||
return "would block";
|
||||
case StreamError::UNKNOWN_ERROR:
|
||||
return "unknown error";
|
||||
}
|
||||
|
@ -6,6 +6,10 @@
|
||||
#include "../util/iterators.hpp"
|
||||
#include "../util/string.hpp"
|
||||
|
||||
#if defined(MIJIN_ENABLE_OPENSSL)
|
||||
#include "./ssl.hpp"
|
||||
#endif
|
||||
|
||||
#define MIJIN_HTTP_WRITE(text) \
|
||||
do \
|
||||
{ \
|
||||
@ -160,7 +164,12 @@ Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() noexcept
|
||||
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
|
||||
HTTPRequest request) noexcept
|
||||
{
|
||||
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
|
||||
std::string hostname;
|
||||
if (auto it = request.headers.find("host"); it != request.headers.end())
|
||||
{
|
||||
hostname = it->second;
|
||||
}
|
||||
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
|
||||
{
|
||||
co_return error;
|
||||
}
|
||||
@ -172,7 +181,7 @@ Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std
|
||||
if (response.isError())
|
||||
{
|
||||
disconnect();
|
||||
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
|
||||
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
|
||||
{
|
||||
co_return error;
|
||||
}
|
||||
@ -242,7 +251,7 @@ void HTTPClient::disconnect() noexcept
|
||||
socket_ = nullptr;
|
||||
}
|
||||
|
||||
StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept
|
||||
StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept
|
||||
{
|
||||
if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
|
||||
{
|
||||
@ -250,17 +259,36 @@ StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, b
|
||||
}
|
||||
|
||||
disconnect();
|
||||
MIJIN_ASSERT(!https, "HTTPS not supported yet.");
|
||||
|
||||
std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
|
||||
if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
socket_ = std::move(newSocket);
|
||||
if (!https)
|
||||
{
|
||||
sslStream_.reset();
|
||||
stream_.construct(socket_->getStream());
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(MIJIN_ENABLE_OPENSSL)
|
||||
std::unique_ptr<SSLStream> sslStream = std::make_unique<SSLStream>();
|
||||
if (const StreamError error = sslStream->open(socket_->getStream(), hostname); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
sslStream_ = std::move(sslStream);
|
||||
stream_.construct(*sslStream_);
|
||||
#else
|
||||
return StreamError::NOT_SUPPORTED;
|
||||
#endif
|
||||
}
|
||||
|
||||
lastIP_ = address;
|
||||
lastPort_ = port;
|
||||
lastWasHttps_ = https;
|
||||
socket_ = std::move(newSocket);
|
||||
stream_.construct(socket_->getStream());
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
}
|
||||
|
@ -57,6 +57,7 @@ class HTTPClient
|
||||
{
|
||||
private:
|
||||
std::unique_ptr<Socket> socket_;
|
||||
std::unique_ptr<Stream> sslStream_;
|
||||
mijin::BoxedObject<HTTPStream> stream_;
|
||||
ip_address_t lastIP_;
|
||||
std::uint16_t lastPort_ = 0;
|
||||
@ -67,7 +68,7 @@ public:
|
||||
Task<StreamResult<HTTPResponse>> c_request(const URL& url, HTTPRequest request = {}) noexcept;
|
||||
void disconnect() noexcept;
|
||||
private:
|
||||
StreamError createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept;
|
||||
StreamError createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept;
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -257,6 +257,34 @@ public:
|
||||
return result;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
int getReadRequest() const noexcept
|
||||
{
|
||||
return BIO_get_read_request(handle_);
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
int getWriteGuarantee() const noexcept
|
||||
{
|
||||
return BIO_get_write_guarantee(handle_);
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
int getWritePending() const noexcept
|
||||
{
|
||||
return BIO_wpending(handle_);
|
||||
}
|
||||
|
||||
Error flush() const noexcept
|
||||
{
|
||||
ERR_clear_error();
|
||||
if (!BIO_flush(handle_))
|
||||
{
|
||||
return Error::current();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
static void upReferences(BIO* handle) noexcept
|
||||
{
|
||||
BIO_up_ref(handle);
|
||||
@ -371,6 +399,12 @@ public:
|
||||
return result;
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
int pending() const noexcept
|
||||
{
|
||||
return SSL_pending(handle_);
|
||||
}
|
||||
|
||||
static void upReferences(SSL* handle) noexcept
|
||||
{
|
||||
SSL_up_ref(handle);
|
||||
|
@ -93,7 +93,7 @@ StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
|
||||
externalBio_ = std::move(externalBio);
|
||||
base_ = &base;
|
||||
|
||||
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess())
|
||||
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
|
||||
{
|
||||
ssl_.free();
|
||||
externalBio.free();
|
||||
@ -116,7 +116,7 @@ void SSLStream::close() noexcept
|
||||
{
|
||||
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
||||
base_ = nullptr;
|
||||
(void) runIOLoop(&ossl::Ssl::shutdown);
|
||||
(void) runIOLoop(&ossl::Ssl::shutdown, true);
|
||||
ssl_.free();
|
||||
externalBio_.free();
|
||||
}
|
||||
@ -131,7 +131,7 @@ StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
std::size_t bytesToWrite = buffer.size();
|
||||
while (bytesToWrite > 0)
|
||||
{
|
||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite,
|
||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + buffer.size() - bytesToWrite,
|
||||
static_cast<int>(bytesToWrite));
|
||||
if (result.isError())
|
||||
{
|
||||
@ -156,15 +156,25 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
|
||||
if (!options.partial && options.noBlock && static_cast<std::size_t>(ssl_.pending()) < buffer.size())
|
||||
{
|
||||
return StreamError::WOULD_BLOCK;
|
||||
}
|
||||
|
||||
std::size_t bytesToRead = buffer.size();
|
||||
if (options.partial)
|
||||
{
|
||||
bytesToRead = std::min<std::size_t>(bytesToRead, ssl_.pending());
|
||||
}
|
||||
while (bytesToRead > 0)
|
||||
{
|
||||
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
|
||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead,
|
||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, !options.noBlock,
|
||||
buffer.data() + buffer.size() - bytesToRead,
|
||||
static_cast<int>(bytesToRead));
|
||||
if (result.isError())
|
||||
{
|
||||
@ -216,9 +226,9 @@ StreamFeatures SSLStream::getFeatures()
|
||||
.tell = false,
|
||||
.seek = false,
|
||||
.readOptions = {
|
||||
.partial = false,
|
||||
.peek = false,
|
||||
.noBlock = false
|
||||
.partial = true,
|
||||
.peek = true,
|
||||
.noBlock = true
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -247,18 +257,20 @@ StreamError SSLStream::bioToBase() noexcept
|
||||
StreamError SSLStream::baseToBio() noexcept
|
||||
{
|
||||
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
||||
std::size_t toRead = externalBio_.getWriteGuarantee();
|
||||
|
||||
while (true)
|
||||
{
|
||||
std::size_t bytes = 0;
|
||||
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size());
|
||||
std::size_t maxBytes = std::min(toRead, buffer.size());
|
||||
|
||||
if (maxBytes == 0)
|
||||
{
|
||||
// buffer is full
|
||||
return StreamError::SUCCESS;
|
||||
break;
|
||||
}
|
||||
|
||||
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes,{.partial = true, .noBlock = true},
|
||||
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true},
|
||||
&bytes); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
@ -274,6 +286,13 @@ StreamError SSLStream::baseToBio() noexcept
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
|
||||
toRead -= bytes;
|
||||
}
|
||||
|
||||
if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
} // namespace shiken
|
||||
|
@ -5,8 +5,8 @@
|
||||
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
||||
#define MIJIN_NET_SSL_HPP_INCLUDED 1
|
||||
|
||||
#if !MIJIN_ENABLE_OPENSSL
|
||||
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings."
|
||||
#if !defined(MIJIN_ENABLE_OPENSSL)
|
||||
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
|
||||
#endif // !MIJIN_ENABLE_OPENSSL
|
||||
|
||||
#include <memory>
|
||||
@ -45,13 +45,15 @@ private:
|
||||
|
||||
|
||||
template<typename TFunc, typename... TArgs>
|
||||
auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
|
||||
auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
|
||||
{
|
||||
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
|
||||
while (true)
|
||||
const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
|
||||
|
||||
ossl::Error error;
|
||||
for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
|
||||
{
|
||||
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
|
||||
ossl::Error error;
|
||||
if constexpr (std::is_same_v<result_t, ossl::Error>)
|
||||
{
|
||||
if (error.isSuccess())
|
||||
@ -86,6 +88,7 @@ private:
|
||||
return error;
|
||||
}
|
||||
}
|
||||
return error;
|
||||
}
|
||||
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
|
||||
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
|
||||
|
Loading…
x
Reference in New Issue
Block a user