SSLStream (WIP)
This commit is contained in:
parent
0acadf994d
commit
f761f2fb07
8
SModule
8
SModule
@ -27,6 +27,14 @@ if env['BUILD_TYPE'] == 'debug':
|
||||
cppdefines += ['MIJIN_DEBUG=1', 'MIJIN_CHECKED_ITERATORS=1']
|
||||
|
||||
|
||||
# SSL libs
|
||||
if env.get('MIJIN_ENABLE_OPENSSL'):
|
||||
cppdefines.append('MIJIN_ENABLE_OPENSSL=1')
|
||||
mijin_sources.extend(Split("""
|
||||
source/mijin/net/ssl.cpp
|
||||
"""))
|
||||
|
||||
|
||||
lib_mijin = env.UnityStaticLibrary(
|
||||
target = env['LIB_DIR'] + '/mijin',
|
||||
source = mijin_sources,
|
||||
|
@ -6,5 +6,9 @@
|
||||
"winsock2":
|
||||
{
|
||||
"condition": "target_os == 'nt'"
|
||||
},
|
||||
"openssl":
|
||||
{
|
||||
"condition": "getenv('MIJIN_ENABLE_OPENSSL')"
|
||||
}
|
||||
}
|
||||
|
@ -41,6 +41,7 @@ struct ReadOptions
|
||||
{
|
||||
bool partial : 1 = false;
|
||||
bool peek : 1 = false;
|
||||
bool noBlock : 1 = false;
|
||||
};
|
||||
|
||||
struct StreamFeatures
|
||||
|
@ -204,11 +204,20 @@ Task<StreamResult<HTTPResponse>> HTTPClient::c_request(const URL& url, HTTPReque
|
||||
co_return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
Optional<ip_address_t> ipAddress = ipAddressFromString(url.getHost());
|
||||
// TODO: lookup host
|
||||
if (ipAddress.empty())
|
||||
{
|
||||
StreamResult<std::vector<ip_address_t>> addresses = co_await c_resolveHostname(url.getHost());
|
||||
if (addresses.isError())
|
||||
{
|
||||
co_return addresses.getError();
|
||||
}
|
||||
else if (addresses->empty())
|
||||
{
|
||||
co_return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
// TODO: try all addresses
|
||||
ipAddress = addresses->front();
|
||||
}
|
||||
|
||||
if (!request.headers.contains("host"))
|
||||
{
|
||||
|
@ -62,6 +62,18 @@ inline Optional<ip_address_t> ipAddressFromString(std::string_view stringView) n
|
||||
|
||||
[[nodiscard]]
|
||||
Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(std::string hostname) noexcept;
|
||||
|
||||
[[nodiscard]]
|
||||
inline Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(std::string_view hostname) noexcept
|
||||
{
|
||||
return c_resolveHostname(std::string(hostname.begin(), hostname.end()));
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
inline Task<StreamResult<std::vector<ip_address_t>>> c_resolveHostname(const char* hostname) noexcept
|
||||
{
|
||||
return c_resolveHostname(std::string(hostname));
|
||||
}
|
||||
}
|
||||
|
||||
#endif // !defined(MIJIN_NET_IP_HPP_INCLUDED)
|
||||
|
@ -183,14 +183,25 @@ StreamError translateWinError() noexcept
|
||||
StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
|
||||
{
|
||||
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
||||
setAsync(false);
|
||||
setNoblock(options.noBlock);
|
||||
|
||||
const long bytesRead = osRecv(handle_, buffer, readFlags(options));
|
||||
if (bytesRead < 0)
|
||||
{
|
||||
if (!options.noBlock || errno != EAGAIN)
|
||||
{
|
||||
return translateErrno();
|
||||
}
|
||||
if (outBytesRead != nullptr)
|
||||
{
|
||||
*outBytesRead = 0;
|
||||
}
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
if (outBytesRead != nullptr)
|
||||
{
|
||||
*outBytesRead = static_cast<std::size_t>(bytesRead);
|
||||
}
|
||||
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
@ -198,7 +209,7 @@ StreamError TCPStream::readRaw(std::span<std::uint8_t> buffer, const ReadOptions
|
||||
StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
{
|
||||
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
||||
setAsync(false);
|
||||
setNoblock(false);
|
||||
|
||||
if (osSend(handle_, buffer, 0) < 0)
|
||||
{
|
||||
@ -211,7 +222,7 @@ StreamError TCPStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
mijin::Task<StreamError> TCPStream::c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead)
|
||||
{
|
||||
MIJIN_ASSERT(isOpen(), "Socket is not open.");
|
||||
setAsync(true);
|
||||
setNoblock(true);
|
||||
|
||||
if (buffer.empty())
|
||||
{
|
||||
@ -249,7 +260,7 @@ Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
|
||||
co_return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
setAsync(true);
|
||||
setNoblock(true);
|
||||
|
||||
while (true)
|
||||
{
|
||||
@ -270,7 +281,7 @@ Task<StreamError> TCPStream::c_writeRaw(std::span<const std::uint8_t> buffer)
|
||||
}
|
||||
}
|
||||
|
||||
void TCPStream::setAsync(bool async)
|
||||
void TCPStream::setNoblock(bool async)
|
||||
{
|
||||
if (async == async_)
|
||||
{
|
||||
|
@ -78,7 +78,7 @@ public:
|
||||
void close() noexcept;
|
||||
[[nodiscard]] bool isOpen() const noexcept { return handle_ != INVALID_SOCKET_HANDLE; }
|
||||
private:
|
||||
void setAsync(bool async);
|
||||
void setNoblock(bool async);
|
||||
|
||||
friend class TCPServerSocket;
|
||||
};
|
||||
|
307
source/mijin/net/ssl.cpp
Normal file
307
source/mijin/net/ssl.cpp
Normal file
@ -0,0 +1,307 @@
|
||||
|
||||
#include "./ssl.hpp"
|
||||
|
||||
#include <mutex>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
|
||||
namespace mijin
|
||||
{
|
||||
namespace
|
||||
{
|
||||
inline constexpr int BIO_BUFFER_SIZE = 4096;
|
||||
SSL_CTX* getSSLContext(bool create = true) noexcept
|
||||
{
|
||||
static SSL_CTX* context = nullptr;
|
||||
static std::mutex contextMutex;
|
||||
|
||||
if (create && context == nullptr)
|
||||
{
|
||||
const std::unique_lock contextLock(contextMutex);
|
||||
if (context != nullptr)
|
||||
{
|
||||
return context;
|
||||
}
|
||||
context = SSL_CTX_new(SSLv23_client_method());
|
||||
SSL_CTX_set_verify(context, SSL_VERIFY_PEER, nullptr);
|
||||
if (!SSL_CTX_set_default_verify_paths(context)
|
||||
|| !SSL_CTX_set_min_proto_version(context, TLS1_2_VERSION))
|
||||
{
|
||||
SSL_CTX_free(context);
|
||||
context = nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
class SSLCleanupHelper
|
||||
{
|
||||
public:
|
||||
~SSLCleanupHelper() noexcept
|
||||
{
|
||||
SSL_CTX* context = getSSLContext(false);
|
||||
if (context != nullptr)
|
||||
{
|
||||
SSL_CTX_free(context);
|
||||
}
|
||||
}
|
||||
} gSSLCleanupHelper;
|
||||
}
|
||||
|
||||
StreamError SSLStream::open(Stream& base, const std::string& hostname) noexcept
|
||||
{
|
||||
MIJIN_ASSERT(base_ == nullptr, "SSL stream is already open.");
|
||||
|
||||
SSL_CTX* context = getSSLContext();
|
||||
if (context == nullptr)
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
SSL* ssl = SSL_new(context);
|
||||
if (ssl == nullptr)
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
BIO* bioA;
|
||||
BIO* bioB;
|
||||
if (!BIO_new_bio_pair(&bioB, 0, &bioA, 0))
|
||||
{
|
||||
SSL_free(ssl);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
SSL_set_bio(ssl, bioB, bioB);
|
||||
|
||||
if (!SSL_set_tlsext_host_name(ssl, hostname.c_str()))
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
if (!SSL_set1_host(ssl, hostname.c_str()))
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
ssl_ = ssl;
|
||||
bioA_ = bioA;
|
||||
bioB_ = bioB;
|
||||
base_ = &base;
|
||||
while(true)
|
||||
{
|
||||
if (const int result = SSL_connect(ssl); result < 1)
|
||||
{
|
||||
const int err = SSL_get_error(ssl, result);
|
||||
[[maybe_unused]] int rrA = BIO_get_read_request(bioA);
|
||||
[[maybe_unused]] int rrB = BIO_get_read_request(bioB);
|
||||
[[maybe_unused]] int wgA = BIO_get_write_guarantee(bioA);
|
||||
[[maybe_unused]] int wgB = BIO_get_write_guarantee(bioB);
|
||||
switch (err)
|
||||
{
|
||||
case SSL_ERROR_WANT_READ:
|
||||
case SSL_ERROR_WANT_WRITE:
|
||||
if(const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return error;
|
||||
}
|
||||
if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return error;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (SSL_get_verify_result(ssl) != X509_V_OK)
|
||||
{
|
||||
SSL_free(ssl);
|
||||
BIO_free_all(bioA);
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
void SSLStream::close() noexcept
|
||||
{
|
||||
MIJIN_ASSERT(base_ != nullptr, "SSL stream is not open.");
|
||||
base_ = nullptr;
|
||||
SSL_shutdown(static_cast<SSL*>(ssl_));
|
||||
SSL_free(static_cast<SSL*>(ssl_));
|
||||
BIO_free_all(static_cast<BIO*>(bioA_));
|
||||
}
|
||||
|
||||
StreamError SSLStream::writeRaw(std::span<const std::uint8_t> buffer)
|
||||
{
|
||||
if (buffer.empty())
|
||||
{
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
SSL* ssl = static_cast<SSL*>(ssl_);
|
||||
std::size_t bytesToWrite = buffer.size();
|
||||
while (bytesToWrite > 0)
|
||||
{
|
||||
const int result = SSL_write(ssl, buffer.data() + buffer.size() - bytesToWrite, static_cast<int>(bytesToWrite));
|
||||
if (result <= 0)
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
bytesToWrite -= result;
|
||||
|
||||
if (const StreamError error = bioToBase(); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
}
|
||||
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
StreamError SSLStream::readRaw(std::span<std::uint8_t> buffer, const mijin::ReadOptions& options,
|
||||
std::size_t* outBytesRead)
|
||||
{
|
||||
if (buffer.empty())
|
||||
{
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
SSL* ssl = static_cast<SSL*>(ssl_);
|
||||
|
||||
std::size_t bytesToRead = buffer.size();
|
||||
while (bytesToRead > 0)
|
||||
{
|
||||
if (const StreamError error = baseToBio(); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
const int result = SSL_read(ssl, buffer.data() + buffer.size() - bytesToRead, static_cast<int>(bytesToRead));
|
||||
if (result <= 0)
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: options and outBytesRead
|
||||
(void) options;
|
||||
if (outBytesRead != nullptr)
|
||||
{
|
||||
*outBytesRead = buffer.size();
|
||||
}
|
||||
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
std::size_t SSLStream::tell()
|
||||
{
|
||||
MIJIN_ERROR("SSLStream does not support tell().");
|
||||
return 0;
|
||||
}
|
||||
|
||||
StreamError SSLStream::seek(std::intptr_t pos, SeekMode seekMode)
|
||||
{
|
||||
MIJIN_ERROR("SSLStream does not support tell().");
|
||||
(void) pos;
|
||||
(void) seekMode;
|
||||
return StreamError::NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
void SSLStream::flush()
|
||||
{
|
||||
base_->flush();
|
||||
}
|
||||
|
||||
bool SSLStream::isAtEnd()
|
||||
{
|
||||
return base_->isAtEnd();
|
||||
}
|
||||
|
||||
StreamFeatures SSLStream::getFeatures()
|
||||
{
|
||||
return {
|
||||
.read = true,
|
||||
.write = true,
|
||||
.tell = false,
|
||||
.seek = false,
|
||||
.readOptions = {
|
||||
.partial = false,
|
||||
.peek = false,
|
||||
.noBlock = false
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
StreamError SSLStream::bioToBase() noexcept
|
||||
{
|
||||
BIO* bio = static_cast<BIO*>(bioA_);
|
||||
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
||||
std::size_t bytes = std::min(BIO_ctrl_pending(bio), buffer.size());
|
||||
while (bytes > 0)
|
||||
{
|
||||
const int result = BIO_read(bio, buffer.data(), static_cast<int>(bytes));
|
||||
if (result <= 0)
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
if (const StreamError error = base_->writeRaw(buffer.data(), result); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
|
||||
bytes = BIO_ctrl_pending(bio);
|
||||
}
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
StreamError SSLStream::baseToBio() noexcept
|
||||
{
|
||||
BIO* bio = static_cast<BIO*>(bioA_);
|
||||
std::array<std::uint8_t, BIO_BUFFER_SIZE> buffer;
|
||||
while (true)
|
||||
{
|
||||
std::size_t bytes = 0;
|
||||
std::size_t maxBytes = std::min(BIO_BUFFER_SIZE - BIO_ctrl_wpending(bio), buffer.size());
|
||||
|
||||
if (maxBytes == 0)
|
||||
{
|
||||
// buffer is full
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
|
||||
if (const StreamError error = base_->readRaw(buffer.data(), maxBytes,{.partial = true, .noBlock = true},
|
||||
&bytes); error != StreamError::SUCCESS)
|
||||
{
|
||||
return error;
|
||||
}
|
||||
if (bytes == 0)
|
||||
{
|
||||
// nothing more to read
|
||||
return StreamError::SUCCESS;
|
||||
}
|
||||
const int result = BIO_write(bio, buffer.data(), static_cast<int>(bytes));
|
||||
if (result <= 0)
|
||||
{
|
||||
return StreamError::UNKNOWN_ERROR;
|
||||
}
|
||||
MIJIN_ASSERT(result == static_cast<int>(bytes), "BIO reported more bytes in buffer than it actually accepted?");
|
||||
}
|
||||
}
|
||||
} // namespace shiken
|
44
source/mijin/net/ssl.hpp
Normal file
44
source/mijin/net/ssl.hpp
Normal file
@ -0,0 +1,44 @@
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
||||
#define MIJIN_NET_SSL_HPP_INCLUDED 1
|
||||
|
||||
#if !MIJIN_ENABLE_OPENSSL
|
||||
#error "SSL support not enabled. Set MIJIN_ENABLE_OPENSSL to True in your environment settings."
|
||||
#endif // !MIJIN_ENABLE_OPENSSL
|
||||
|
||||
#include <memory>
|
||||
#include "../io/stream.hpp"
|
||||
|
||||
namespace mijin
|
||||
{
|
||||
class SSLStream : public Stream
|
||||
{
|
||||
private:
|
||||
Stream* base_ = nullptr;
|
||||
void* ssl_ = nullptr;
|
||||
void* bioA_ = nullptr;
|
||||
void* bioB_ = nullptr;
|
||||
public:
|
||||
StreamError open(Stream& base, const std::string& hostname) noexcept;
|
||||
void close() noexcept;
|
||||
|
||||
StreamError readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options, std::size_t* outBytesRead) override;
|
||||
StreamError writeRaw(std::span<const std::uint8_t> buffer) override;
|
||||
std::size_t tell() override;
|
||||
StreamError seek(std::intptr_t pos, SeekMode seekMode) override;
|
||||
void flush() override;
|
||||
bool isAtEnd() override;
|
||||
StreamFeatures getFeatures() override;
|
||||
private:
|
||||
StreamError bioToBase() noexcept;
|
||||
StreamError baseToBio() noexcept;
|
||||
// mijin::Task<StreamError> c_readRaw(std::span<std::uint8_t> buffer, const ReadOptions& options = {}, std::size_t* outBytesRead = nullptr) override;
|
||||
// mijin::Task<StreamError> c_writeRaw(std::span<const std::uint8_t> buffer) override;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // !defined(MIJIN_NET_SSL_HPP_INCLUDED)
|
||||
|
Loading…
x
Reference in New Issue
Block a user