298 lines
9.1 KiB
C++
298 lines
9.1 KiB
C++
|
|
#include "./http.hpp"
|
|
|
|
#include <format>
|
|
|
|
#include "../util/iterators.hpp"
|
|
#include "../util/string.hpp"
|
|
|
|
#if defined(MIJIN_ENABLE_OPENSSL)
|
|
#include "./ssl.hpp"
|
|
#endif
|
|
|
|
// Helper macros for the HTTPStream coroutines in this file. They expect a `base_` stream member to be
// in scope and must be used inside a coroutine whose result can be constructed from a StreamError:
// on failure they `co_return` that error. Each one expands to exactly one statement
// (do { ... } while (false)), so it is safe as the body of an unbraced if/else.

// Writes `text` to base_; co_returns the StreamError if the write does not succeed.
#define MIJIN_HTTP_WRITE(text)                                                                          \
    do                                                                                                  \
    {                                                                                                   \
        if (const StreamError error = co_await base_->c_writeText(text); error != StreamError::SUCCESS) \
        {                                                                                               \
            co_return error;                                                                            \
        }                                                                                               \
    } while (false)

// Awaits `read` (an awaitable producing a StreamError); co_returns the error if it is not SUCCESS.
#define MIJIN_HTTP_CHECKREAD(read)                                                    \
    do                                                                                \
    {                                                                                 \
        if (const StreamError error = co_await (read); error != StreamError::SUCCESS) \
        {                                                                             \
            co_return error;                                                          \
        }                                                                             \
    } while (false)

// Reads one line from base_ into `text` and trims it in place. The read and the trim are wrapped in a
// single do/while so that e.g. `if (x) MIJIN_HTTP_READLINE(line);` never trims unconditionally.
#define MIJIN_HTTP_READLINE(text)                      \
    do                                                 \
    {                                                  \
        MIJIN_HTTP_CHECKREAD(base_->c_readLine(text)); \
        (text) = trim(text);                           \
    } while (false)
|
|
|
|
namespace mijin
|
|
{
|
|
namespace
|
|
{
|
|
inline constexpr std::size_t CONTENT_LENGTH_LIMIT = 100 << 20; // 100MiB
|
|
bool parseHTTPVersion(std::string_view version, HTTPVersion& outVersion) noexcept
|
|
{
|
|
std::vector<std::string_view> parts = split(version, ".");
|
|
if (parts.size() != 2)
|
|
{
|
|
return false;
|
|
}
|
|
return toNumber(parts[0], outVersion.major) && toNumber(parts[1], outVersion.minor);
|
|
}
|
|
}
|
|
|
|
/// Sends `request` and, only if it was written completely, reads back the response.
///
/// @param request the request to send; taken by value, so the coroutine frame owns its own copy
/// @return the parsed response, or the first StreamError from writing or reading
Task<StreamResult<HTTPResponse>> HTTPStream::c_request(HTTPRequest request) noexcept
{
    const StreamError sendResult = co_await c_writeRequest(request);
    if (sendResult == StreamError::SUCCESS)
    {
        co_return co_await c_readResponse();
    }
    co_return sendResult;
}
|
|
|
|
/// Serializes `request` onto the underlying stream as an HTTP/1.x request message.
///
/// Writes, in order: the request line, any implied headers, the caller's headers, the empty line that
/// ends the header section, and finally the body (if any). Every line ends with CRLF: RFC 9112
/// section 2.1 defines the message with CRLF terminators, and tolerating a bare LF is only something a
/// *recipient* MAY do (section 2.2), so sending bare LF can break strict servers and proxies.
///
/// If the body is non-empty and the headers contain no "content-length", one is added with the body
/// size. If the caller supplied "content-length", it must parse as a number equal to the body size;
/// otherwise PROTOCOL_ERROR is returned before anything is written.
/// NOTE(review): the lookup uses the lower-case key "content-length"; this presumes header keys are
/// stored lower-case (or the map compares case-insensitively) — confirm against HTTPRequest.
///
/// @param request the request to send
/// @return SUCCESS, PROTOCOL_ERROR for an inconsistent content-length, or the first write error
Task<StreamError> HTTPStream::c_writeRequest(const mijin::HTTPRequest& request) noexcept
{
    // Headers we add on the caller's behalf; written before the caller's own headers.
    std::map<std::string, std::string> moreHeaders;
    if (!request.body.empty())
    {
        auto itLength = request.headers.find("content-length");
        if (itLength == request.headers.end())
        {
            moreHeaders.emplace("content-length", std::to_string(request.body.size()));
        }
        else
        {
            // A caller-provided length must match what we are actually about to send.
            std::size_t headerValue = 0;
            if (!toNumber(itLength->second, headerValue) || headerValue != request.body.size())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
        }
    }

    MIJIN_HTTP_WRITE(std::format("{} {} HTTP/{}.{}\r\n", request.method, request.address, request.version.major, request.version.minor));
    for (const auto& [key, value] : moreHeaders)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }
    for (const auto& [key, value] : request.headers)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }

    // Empty line terminating the header section.
    MIJIN_HTTP_WRITE("\r\n");
    if (!request.body.empty())
    {
        MIJIN_HTTP_WRITE(request.body);
    }

    co_return StreamError::SUCCESS;
}
|
|
|
|
/// Reads one HTTP/1.x response from the underlying stream: the status line, the header section and,
/// if a content-length header is present, exactly that many bytes of body.
///
/// Header names are lower-cased and values trimmed. An obsolete folded line (a header line starting
/// with SP or HTAB, RFC 9112 section 5.2) is appended to the previous header's value, separated by a
/// single space. A status line without a reason phrase is accepted; statusMessage then stays empty.
///
/// NOTE(review): no body is read when content-length is absent (e.g. chunked transfer-encoding or
/// read-until-close); response.content stays empty in that case.
///
/// @return the parsed response; PROTOCOL_ERROR for a malformed status line or header, an unparsable
///         content-length, or one above CONTENT_LENGTH_LIMIT; or the first read error of the stream
Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() noexcept
{
    std::string line;
    MIJIN_HTTP_READLINE(line);

    // status-line = HTTP-version SP status-code SP [ reason-phrase ]   (RFC 9112 section 4)
    // The reason phrase may be empty, and trimming the line also removes the SP before it, so both
    // two and three parts are valid.
    std::vector<std::string_view> parts = split(line, " ", {.limitParts = 3});
    if (parts.size() < 2 || parts.size() > 3)
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (!parts[0].starts_with("HTTP/"))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }

    HTTPResponse response;
    if (!parseHTTPVersion(parts[0].substr(5), response.version)
        || !toNumber(parts[1], response.status))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (parts.size() == 3)
    {
        response.statusMessage = parts[2];
    }

    // Most recently added header, so folded continuation lines can extend it. It must start out as
    // end(): a default-constructed iterator is singular, and even comparing it with end() is undefined.
    auto lastHeader = response.headers.end();
    while (true)
    {
        // Read without trimming first: trimming strips the leading SP/HTAB that marks a continuation
        // line, which would make the folding branch below unreachable.
        MIJIN_HTTP_CHECKREAD(base_->c_readLine(line));
        const bool isContinuation = !line.empty() && (line[0] == ' ' || line[0] == '\t');
        line = trim(line);
        if (line.empty())
        {
            break; // empty line ends the header section
        }
        if (isContinuation)
        {
            if (lastHeader == response.headers.end())
            {
                // Folded line with no header before it.
                co_return StreamError::PROTOCOL_ERROR;
            }
            lastHeader->second.push_back(' ');
            lastHeader->second.append(line);
            continue; // folded text is not a "name: value" line of its own
        }

        parts = split(line, ":", {.limitParts = 2});
        if (parts.size() != 2)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        lastHeader = response.headers.emplace(toLower(trim(parts[0])), trim(parts[1]));
    }

    auto itContentLength = response.headers.find("content-length");
    if (itContentLength != response.headers.end())
    {
        std::size_t contentLength = 0;
        if (!toNumber(itContentLength->second, contentLength))
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        // Refuse to buffer absurdly large bodies announced by the peer.
        if (contentLength > CONTENT_LENGTH_LIMIT)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        response.content.resize(contentLength);
        MIJIN_HTTP_CHECKREAD(base_->c_readRaw(response.content));
    }

    co_return response;
}
|
|
|
|
/// Sends `request` to `address`:`port` (over TLS if `https`), reusing the current connection when
/// createSocket reports that it already points at this endpoint.
///
/// The "host" header, if present, is handed to createSocket as the hostname (which passes it to the
/// TLS layer for https). A "connection: keep-alive" header is added unless the caller set
/// "connection" themselves. If the exchange fails, the connection is dropped, re-established once,
/// and the request is sent again.
///
/// NOTE(review): the retry happens for every method and every error — including non-idempotent
/// requests such as POST (RFC 9110 section 9.2.2 advises against automatically retrying those) and
/// failures on a connection that was just opened.
///
/// @return the response, a connection error from createSocket, or the result of the retried request
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
                                                       HTTPRequest request) noexcept
{
    std::string hostname;
    const auto hostHeader = request.headers.find("host");
    if (hostHeader != request.headers.end())
    {
        hostname = hostHeader->second;
    }

    StreamError connectResult = createSocket(address, hostname, port, https);
    if (connectResult != StreamError::SUCCESS)
    {
        co_return connectResult;
    }

    if (!request.headers.contains("connection"))
    {
        request.headers.emplace("connection", "keep-alive");
    }

    StreamResult<HTTPResponse> firstAttempt = co_await stream_->c_request(request);
    if (!firstAttempt.isError())
    {
        co_return firstAttempt;
    }

    // Presumably the peer closed a kept-alive connection in the meantime: reconnect and try once more.
    disconnect();
    connectResult = createSocket(address, hostname, port, https);
    if (connectResult != StreamError::SUCCESS)
    {
        co_return connectResult;
    }
    co_return co_await stream_->c_request(request);
}
|
|
|
|
/// Performs `request` against `url`: checks the scheme, resolves the host if it is not an IP literal,
/// fills in the "host" header and the request target, then delegates to the address-based overload.
///
/// @param url an http:// or https:// URL; a non-zero port in the URL overrides the default (80/443)
/// @param request the request to send; its `address` is replaced by the URL's path and query, and a
///        "host" header is added unless one is already present
/// @return the response; UNKNOWN_ERROR for a URL without host, an unsupported scheme or a resolver
///         result without addresses; or any resolution/connection/transfer error
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(const URL& url, HTTPRequest request) noexcept
{
    if (url.getHost().empty())
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    std::uint16_t port = url.getPort();
    bool https = false;
    if (equalsIgnoreCase(url.getScheme(), "http"))
    {
        port = (port != 0) ? port : 80;
    }
    else if (equalsIgnoreCase(url.getScheme(), "https"))
    {
        port = (port != 0) ? port : 443;
        https = true;
    }
    else
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    Optional<ip_address_t> ipAddress = ipAddressFromString(url.getHost());
    if (ipAddress.empty())
    {
        StreamResult<std::vector<ip_address_t>> addresses = co_await c_resolveHostname(url.getHost());
        if (addresses.isError())
        {
            co_return addresses.getError();
        }
        if (addresses->empty())
        {
            co_return StreamError::UNKNOWN_ERROR;
        }
        // TODO: try all addresses
        ipAddress = addresses->front();
    }

    if (!request.headers.contains("host"))
    {
        request.headers.emplace("host", url.getHost());
    }

    // origin-form request target (RFC 9112 section 3.2.1) is absolute-path [ "?" query ]: the fragment
    // is client-side only and must not be sent, so cut everything from the first '#'.
    // NOTE(review): assumes getPathQueryFragment() returns the raw (still percent-encoded) text, in
    // which a literal '#' can only start the fragment — confirm against URL.
    request.address = url.getPathQueryFragment();
    if (const std::size_t fragmentStart = request.address.find('#'); fragmentStart != std::string::npos)
    {
        request.address.erase(fragmentStart);
    }
    if (request.address.empty())
    {
        request.address = "/";
    }

    co_return co_await c_request(*ipAddress, port, https, std::move(request));
}
|
|
|
|
void HTTPClient::disconnect() noexcept
|
|
{
|
|
if (socket_ == nullptr)
|
|
{
|
|
return;
|
|
}
|
|
stream_.destroy();
|
|
socket_ = nullptr;
|
|
}
|
|
|
|
/// Makes sure there is an open connection to `address`:`port` (TLS if `https`) and that stream_ is
/// constructed on top of it.
///
/// If the current connection already targets the same address, port and TLS setting, it is reused.
/// Otherwise the old connection is closed and a new one is built entirely in locals. socket_,
/// sslStream_ and stream_ are only updated once every step has succeeded. A failed TLS handshake
/// therefore never leaves socket_ set while stream_ is unconstructed. That state would later take the
/// reuse fast path, or make disconnect() destroy an object that was never constructed.
///
/// NOTE(review): `hostname` is not part of the reuse check, so a different TLS server name on the same
/// endpoint keeps the existing connection.
///
/// @param address  peer IP address
/// @param hostname server name passed to SSLStream::open for https; unused for plain http
/// @param port     peer TCP port
/// @param https    whether to wrap the connection in TLS
/// @return SUCCESS; the error from TCPSocket::open or SSLStream::open; or NOT_SUPPORTED for https when
///         built without MIJIN_ENABLE_OPENSSL
StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) noexcept
{
    if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
    {
        return StreamError::SUCCESS;
    }

    disconnect();

#if !defined(MIJIN_ENABLE_OPENSSL)
    if (https)
    {
        // No TLS support compiled in: fail before opening a socket that could not be used anyway.
        return StreamError::NOT_SUPPORTED;
    }
#endif

    std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
    if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
    {
        return error;
    }

#if defined(MIJIN_ENABLE_OPENSSL)
    // Declared after newSocket, so on an early return it is destroyed before the socket it was opened on.
    std::unique_ptr<SSLStream> newSslStream;
    if (https)
    {
        newSslStream = std::make_unique<SSLStream>();
        if (const StreamError error = newSslStream->open(newSocket->getStream(), hostname); error != StreamError::SUCCESS)
        {
            return error;
        }
    }
#endif

    // Commit: everything is open, publish the new connection. Moving the unique_ptr does not move the
    // TCPSocket itself, so the stream reference handed to the TLS layer above stays valid.
    socket_ = std::move(newSocket);
    if (https)
    {
#if defined(MIJIN_ENABLE_OPENSSL)
        sslStream_ = std::move(newSslStream);
        stream_.construct(*sslStream_);
#endif
        // Without MIJIN_ENABLE_OPENSSL this branch is unreachable: https was rejected above.
    }
    else
    {
        sslStream_.reset();
        stream_.construct(socket_->getStream());
    }

    lastIP_ = address;
    lastPort_ = port;
    lastWasHttps_ = https;
    return StreamError::SUCCESS;
}
|
|
}
|
|
|
|
// The MIJIN_HTTP_* helper macros are private to this file; drop them so they cannot leak into
// anything compiled after this point (presumably relevant for unity/jumbo builds).
#undef MIJIN_HTTP_READLINE
#undef MIJIN_HTTP_CHECKREAD
#undef MIJIN_HTTP_WRITE