More SSL stuff (still doesn't work :/).
This commit is contained in:
parent
0be34a845a
commit
a43f92fb58
@ -36,7 +36,47 @@ void Stream::flush() {}
|
|||||||
|
|
||||||
mijin::Task<StreamError> Stream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
|
mijin::Task<StreamError> Stream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
|
||||||
{
|
{
|
||||||
co_return readRaw(buffer, options, outBytesRead);
|
std::size_t bytesToRead = buffer.size();
|
||||||
|
if (bytesToRead == 0)
|
||||||
|
{
|
||||||
|
co_return StreamError::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
while(true)
|
||||||
|
{
|
||||||
|
bool done = false;
|
||||||
|
std::size_t bytesRead = 0;
|
||||||
|
const StreamError error = readRaw(buffer.data() + buffer.size() - bytesToRead, bytesToRead,
|
||||||
|
{.partial = true, .noBlock = true}, &bytesToRead);
|
||||||
|
switch (error)
|
||||||
|
{
|
||||||
|
case StreamError::SUCCESS:
|
||||||
|
bytesToRead -= bytesRead;
|
||||||
|
if (options.partial || bytesToRead == 0)
|
||||||
|
{
|
||||||
|
done = true;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case StreamError::WOULD_BLOCK:
|
||||||
|
if (options.noBlock)
|
||||||
|
{
|
||||||
|
co_return StreamError::WOULD_BLOCK;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
co_return error;
|
||||||
|
}
|
||||||
|
if (done)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
co_await mijin::c_suspend();
|
||||||
|
}
|
||||||
|
if (outBytesRead != nullptr)
|
||||||
|
{
|
||||||
|
*outBytesRead = buffer.size() - bytesToRead;
|
||||||
|
}
|
||||||
|
co_return StreamError::SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
mijin::Task<StreamError> Stream::c_writeRaw(std::span<const std::uint8_t> buffer)
|
mijin::Task<StreamError> Stream::c_writeRaw(std::span<const std::uint8_t> buffer)
|
||||||
|
@ -70,6 +70,7 @@ enum class [[nodiscard]] StreamError
|
|||||||
NOT_SUPPORTED = 2,
|
NOT_SUPPORTED = 2,
|
||||||
CONNECTION_CLOSED = 3,
|
CONNECTION_CLOSED = 3,
|
||||||
PROTOCOL_ERROR = 4,
|
PROTOCOL_ERROR = 4,
|
||||||
|
WOULD_BLOCK = 5,
|
||||||
UNKNOWN_ERROR = -1
|
UNKNOWN_ERROR = -1
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -434,6 +435,8 @@ inline const char* errorName(StreamError error) noexcept
|
|||||||
return "connection closed";
|
return "connection closed";
|
||||||
case StreamError::PROTOCOL_ERROR:
|
case StreamError::PROTOCOL_ERROR:
|
||||||
return "protocol error";
|
return "protocol error";
|
||||||
|
case StreamError::WOULD_BLOCK:
|
||||||
|
return "would block";
|
||||||
case StreamError::UNKNOWN_ERROR:
|
case StreamError::UNKNOWN_ERROR:
|
||||||
return "unknown error";
|
return "unknown error";
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,10 @@
|
|||||||
#include "../util/iterators.hpp"
|
#include "../util/iterators.hpp"
|
||||||
#include "../util/string.hpp"
|
#include "../util/string.hpp"
|
||||||
|
|
||||||
|
#if defined(MIJIN_ENABLE_OPENSSL)
|
||||||
|
#include "./ssl.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
#define MIJIN_HTTP_WRITE(text) \
|
#define MIJIN_HTTP_WRITE(text) \
|
||||||
do \
|
do \
|
||||||
{ \
|
{ \
|
||||||
@ -160,7 +164,12 @@ Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() noexcept
|
|||||||
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
|
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
|
||||||
HTTPRequest request) noexcept
|
HTTPRequest request) noexcept
|
||||||
{
|
{
|
||||||
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
|
std::string hostname;
|
||||||
|
if (auto it = request.headers.find("host"); it != request.headers.end())
|
||||||
|
{
|
||||||
|
hostname = it->second;
|
||||||
|
}
|
||||||
|
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
|
||||||
{
|
{
|
||||||
co_return error;
|
co_return error;
|
||||||
}
|
}
|
||||||
@ -172,7 +181,7 @@ Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std
|
|||||||
if (response.isError())
|
if (response.isError())
|
||||||
{
|
{
|
||||||
disconnect();
|
disconnect();
|
||||||
if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
|
if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
|
||||||
{
|
{
|
||||||
co_return error;
|
co_return error;
|
||||||
}
|
}
|
||||||
@ -242,7 +251,7 @@ void HTTPClient::disconnect() noexcept
|
|||||||
socket_ = nullptr;
|
socket_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept
|
StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept
|
||||||
{
|
{
|
||||||
if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
|
if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
|
||||||
{
|
{
|
||||||
@ -250,17 +259,36 @@ StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, b
|
|||||||
}
|
}
|
||||||
|
|
||||||
disconnect();
|
disconnect();
|
||||||
MIJIN_ASSERT(!https, "HTTPS not supported yet.");
|
|
||||||
std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
|
std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
|
||||||
if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
|
if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
|
||||||
{
|
{
|
||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
|
socket_ = std::move(newSocket);
|
||||||
|
if (!https)
|
||||||
|
{
|
||||||
|
sslStream_.reset();
|
||||||
|
stream_.construct(socket_->getStream());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
#if defined(MIJIN_ENABLE_OPENSSL)
|
||||||
|
std::unique_ptr<SSLStream> sslStream = std::make_unique<SSLStream>();
|
||||||
|
if (const StreamError error = sslStream->open(socket_->getStream(), hostname); error != StreamError::SUCCESS)
|
||||||
|
{
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
sslStream_ = std::move(sslStream);
|
||||||
|
stream_.construct(*sslStream_);
|
||||||
|
#else
|
||||||
|
return StreamError::NOT_SUPPORTED;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
lastIP_ = address;
|
lastIP_ = address;
|
||||||
lastPort_ = port;
|
lastPort_ = port;
|
||||||
lastWasHttps_ = https;
|
lastWasHttps_ = https;
|
||||||
socket_ = std::move(newSocket);
|
|
||||||
stream_.construct(socket_->getStream());
|
|
||||||
return StreamError::SUCCESS;
|
return StreamError::SUCCESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -57,6 +57,7 @@ class HTTPClient
|
|||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<Socket> socket_;
|
std::unique_ptr<Socket> socket_;
|
||||||
|
std::unique_ptr<Stream> sslStream_;
|
||||||
mijin::BoxedObject<HTTPStream> stream_;
|
mijin::BoxedObject<HTTPStream> stream_;
|
||||||
ip_address_t lastIP_;
|
ip_address_t lastIP_;
|
||||||
std::uint16_t lastPort_ = 0;
|
std::uint16_t lastPort_ = 0;
|
||||||
@ -67,7 +68,7 @@ public:
|
|||||||
Task<StreamResult<HTTPResponse>> c_request(const URL& url, HTTPRequest request = {}) noexcept;
|
Task<StreamResult<HTTPResponse>> c_request(const URL& url, HTTPRequest request = {}) noexcept;
|
||||||
void disconnect() noexcept;
|
void disconnect() noexcept;
|
||||||
private:
|
private:
|
||||||
StreamError createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept;
|
StreamError createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,6 +257,34 @@ public:
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]]
|
||||||
|
int getReadRequest() const noexcept
|
||||||
|
{
|
||||||
|
return BIO_get_read_request(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]]
|
||||||
|
int getWriteGuarantee() const noexcept
|
||||||
|
{
|
||||||
|
return BIO_get_write_guarantee(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]]
|
||||||
|
int getWritePending() const noexcept
|
||||||
|
{
|
||||||
|
return BIO_wpending(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
Error flush() const noexcept
|
||||||
|
{
|
||||||
|
ERR_clear_error();
|
||||||
|
if (!BIO_flush(handle_))
|
||||||
|
{
|
||||||
|
return Error::current();
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
static void upReferences(BIO* handle) noexcept
|
static void upReferences(BIO* handle) noexcept
|
||||||
{
|
{
|
||||||
BIO_up_ref(handle);
|
BIO_up_ref(handle);
|
||||||
@ -371,6 +399,12 @@ public:
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]]
|
||||||
|
int pending() const noexcept
|
||||||
|
{
|
||||||
|
return SSL_pending(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
static void upReferences(SSL* handle) noexcept
|
static void upReferences(SSL* handle) noexcept
|
||||||
{
|
{
|
||||||
SSL_up_ref(handle);
|
SSL_up_ref(handle);
|
||||||
|
@ -93,7 +93,7 @@ StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
|
|||||||
externalBio_ = std::move(externalBio);
|
externalBio_ = std::move(externalBio);
|
||||||
base_ = &base;
|
base_ = &base;
|
||||||
|
|
||||||
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect); !error.isSuccess())
|
if (const ossl::Error error = runIOLoop(&ossl::Ssl::connect, true); !error.isSuccess())
|
||||||
{
|
{
|
||||||
ssl_.free();
|
ssl_.free();
|
||||||
externalBio.free();
|
externalBio.free();
|
||||||
@ -116,7 +116,7 @@ void SSLStream::close() noexcept
|
|||||||
{
|
{
|
||||||
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
||||||
base_ = nullptr;
|
base_ = nullptr;
|
||||||
(void) runIOLoop(&ossl::Ssl::shutdown);
|
(void) runIOLoop(&ossl::Ssl::shutdown, true);
|
||||||
ssl_.free();
|
ssl_.free();
|
||||||
externalBio_.free();
|
externalBio_.free();
|
||||||
}
|
}
|
||||||
@ -131,7 +131,7 @@ StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
|
|||||||
std::size_t bytesToWrite = buffer.size();
|
std::size_t bytesToWrite = buffer.size();
|
||||||
while (bytesToWrite > 0)
|
while (bytesToWrite > 0)
|
||||||
{
|
{
|
||||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, buffer.data() + buffer.size() - bytesToWrite,
|
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::write, true, buffer.data() + buffer.size() - bytesToWrite,
|
||||||
static_cast<int>(bytesToWrite));
|
static_cast<int>(bytesToWrite));
|
||||||
if (result.isError())
|
if (result.isError())
|
||||||
{
|
{
|
||||||
@ -156,15 +156,25 @@ StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::Read
|
|||||||
return StreamError::SUCCESS;
|
return StreamError::SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::size_t bytesToRead = buffer.size();
|
|
||||||
while (bytesToRead > 0)
|
|
||||||
{
|
|
||||||
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
||||||
{
|
{
|
||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, buffer.data() + buffer.size() - bytesToRead,
|
if (!options.partial && options.noBlock && static_cast<std::size_t>(ssl_.pending()) < buffer.size())
|
||||||
|
{
|
||||||
|
return StreamError::WOULD_BLOCK;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t bytesToRead = buffer.size();
|
||||||
|
if (options.partial)
|
||||||
|
{
|
||||||
|
bytesToRead = std::min<std::size_t>(bytesToRead, ssl_.pending());
|
||||||
|
}
|
||||||
|
while (bytesToRead > 0)
|
||||||
|
{
|
||||||
|
const ossl::Result<int> result = runIOLoop(&ossl::Ssl::read, !options.noBlock,
|
||||||
|
buffer.data() + buffer.size() - bytesToRead,
|
||||||
static_cast<int>(bytesToRead));
|
static_cast<int>(bytesToRead));
|
||||||
if (result.isError())
|
if (result.isError())
|
||||||
{
|
{
|
||||||
@ -216,9 +226,9 @@ StreamFeatures SSLStream::getFeatures()
|
|||||||
.tell = false,
|
.tell = false,
|
||||||
.seek = false,
|
.seek = false,
|
||||||
.readOptions = {
|
.readOptions = {
|
||||||
.partial = false,
|
.partial = true,
|
||||||
.peek = false,
|
.peek = true,
|
||||||
.noBlock = false
|
.noBlock = true
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -247,18 +257,20 @@ StreamError SSLStream::bioToBase() noexcept
|
|||||||
StreamError SSLStream::baseToBio() noexcept
|
StreamError SSLStream::baseToBio() noexcept
|
||||||
{
|
{
|
||||||
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
||||||
|
std::size_t toRead = externalBio_.getWriteGuarantee();
|
||||||
|
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
std::size_t bytes = 0;
|
std::size_t bytes = 0;
|
||||||
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - externalBio_.ctrlWPending(), buffer.size());
|
std::size_t maxBytes = std::min(toRead, buffer.size());
|
||||||
|
|
||||||
if (maxBytes == 0)
|
if (maxBytes == 0)
|
||||||
{
|
{
|
||||||
// buffer is full
|
// buffer is full
|
||||||
return StreamError::SUCCESS;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes,{.partial = true, .noBlock = true},
|
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes, {.partial = true, .noBlock = true},
|
||||||
&bytes); error != StreamError::SUCCESS)
|
&bytes); error != StreamError::SUCCESS)
|
||||||
{
|
{
|
||||||
return error;
|
return error;
|
||||||
@ -274,6 +286,13 @@ StreamError SSLStream::baseToBio() noexcept
|
|||||||
return StreamError::UNKNOWN_ERROR;
|
return StreamError::UNKNOWN_ERROR;
|
||||||
}
|
}
|
||||||
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
|
MIJIN_ASSERT(result.getValue() == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
|
||||||
|
toRead -= bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (const ossl::Error error = externalBio_.flush(); !error.isSuccess())
|
||||||
|
{
|
||||||
|
return StreamError::UNKNOWN_ERROR;
|
||||||
|
}
|
||||||
|
return StreamError::SUCCESS;
|
||||||
}
|
}
|
||||||
} // namespace shiken
|
} // namespace shiken
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
||||||
#define MIJIN_NET_SSL_HPP_INCLUDED 1
|
#define MIJIN_NET_SSL_HPP_INCLUDED 1
|
||||||
|
|
||||||
#if !MIJIN_ENABLE_OPENSSL
|
#if !defined(MIJIN_ENABLE_OPENSSL)
|
||||||
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings."
|
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your SCons environment settings."
|
||||||
#endif // !MIJIN_ENABLE_OPENSSL
|
#endif // !MIJIN_ENABLE_OPENSSL
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
@ -45,13 +45,15 @@ private:
|
|||||||
|
|
||||||
|
|
||||||
template<typename TFunc, typename... TArgs>
|
template<typename TFunc, typename... TArgs>
|
||||||
auto runIOLoop(TFunc&& func, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
|
auto runIOLoop(TFunc&& func, bool block, TArgs&&... args) -> std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>
|
||||||
{
|
{
|
||||||
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
|
using result_t = std::decay_t<std::invoke_result_t<TFunc, ossl::Ssl&, TArgs...>>;
|
||||||
while (true)
|
const std::size_t maxTries = block ? std::numeric_limits<std::size_t>::max() : 10;
|
||||||
|
|
||||||
|
ossl::Error error;
|
||||||
|
for (std::size_t tryNum = 0; tryNum < maxTries; ++tryNum)
|
||||||
{
|
{
|
||||||
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
|
auto result = std::invoke(std::forward<TFunc>(func), ssl_, std::forward<TArgs>(args)...);
|
||||||
ossl::Error error;
|
|
||||||
if constexpr (std::is_same_v<result_t, ossl::Error>)
|
if constexpr (std::is_same_v<result_t, ossl::Error>)
|
||||||
{
|
{
|
||||||
if (error.isSuccess())
|
if (error.isSuccess())
|
||||||
@ -86,6 +88,7 @@ private:
|
|||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return error;
|
||||||
}
|
}
|
||||||
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
|
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
|
||||||
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
|
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user