#include "./http.hpp"

#include <format>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "../util/iterators.hpp"
#include "../util/string.hpp"

#if defined(MIJIN_ENABLE_OPENSSL)
#include "./ssl.hpp"
#endif

// Writes `text` to base_ and returns the error from the enclosing coroutine if the write fails.
#define MIJIN_HTTP_WRITE(text) \
do \
{ \
    if (const StreamError error = co_await base_->c_writeText(text); error != StreamError::SUCCESS) \
    { \
        co_return error; \
    } \
} while(false)

// Awaits the read operation `read` and returns its error from the enclosing coroutine if it fails.
#define MIJIN_HTTP_CHECKREAD(read) \
do \
{ \
    if (const StreamError error = co_await read; error != StreamError::SUCCESS) \
    { \
        co_return error; \
    } \
} while(false)

// Reads one line from base_ into `text`, then replaces it with trim(text).
#define MIJIN_HTTP_READLINE(text) MIJIN_HTTP_CHECKREAD(base_->c_readLine(text)); text = trim(text)

namespace mijin
{
namespace
{
// Responses announcing a larger content-length are rejected instead of being buffered.
inline constexpr std::size_t CONTENT_LENGTH_LIMIT = 100 << 20; // 100MiB

// Parses the "<major>.<minor>" part of an HTTP version (the text following "HTTP/").
// Returns false if the text does not split into exactly two parts or if either part
// cannot be converted by toNumber().
bool parseHTTPVersion(std::string_view version, HTTPVersion& outVersion) MIJIN_NOEXCEPT
{
    const auto parts = split(version, ".");
    if (parts.size() != 2)
    {
        return false;
    }
    return toNumber(parts[0], outVersion.major) && toNumber(parts[1], outVersion.minor);
}

// Returns whether a response with this status code is always terminated by the empty line
// after its header fields, i.e. never has a body, whatever its headers say
// (RFC 9112, section 6.3: 1xx, 204 and 304).
template<typename TStatus>
bool statusForbidsBody(TStatus status) MIJIN_NOEXCEPT
{
    return (status >= 100 && status < 200) || status == 204 || status == 304;
}
}

// Writes `request` to the underlying stream and then reads the response to it.
Task<StreamResult<HTTPResponse>> HTTPStream::c_request(HTTPRequest request) MIJIN_NOEXCEPT
{
    if (const StreamError error = co_await c_writeRequest(request); error != StreamError::SUCCESS)
    {
        co_return error;
    }
    co_return co_await c_readResponse();
}

// Serializes `request` as an HTTP/1.x request: request line, header fields, empty line, body.
// Every line is terminated with CRLF, as RFC 9112 requires of senders.
// Header names are looked up in lower case ("content-length"). For a non-empty body a
// content-length field is added unless the caller supplied one. A caller-supplied value
// that does not match the body size yields PROTOCOL_ERROR before anything is written.
Task<StreamError> HTTPStream::c_writeRequest(const mijin::HTTPRequest& request) MIJIN_NOEXCEPT
{
    bool addContentLength = false;
    if (!request.body.empty())
    {
        const auto itLength = request.headers.find("content-length");
        if (itLength == request.headers.end())
        {
            addContentLength = true;
        }
        else
        {
            std::size_t headerValue = 0;
            if (!toNumber(itLength->second, headerValue) || headerValue != request.body.size())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
        }
    }

    MIJIN_HTTP_WRITE(std::format("{} {} HTTP/{}.{}\r\n", request.method, request.address, request.version.major, request.version.minor));
    if (addContentLength)
    {
        MIJIN_HTTP_WRITE(std::format("content-length: {}\r\n", request.body.size()));
    }
    for (const auto& [key, value] : request.headers)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }
    MIJIN_HTTP_WRITE("\r\n");
    if (!request.body.empty())
    {
        MIJIN_HTTP_WRITE(request.body);
    }
    co_return StreamError::SUCCESS;
}

// Reads one HTTP/1.x response from the underlying stream.
// Header names are lower-cased; obsolete line folding (continuation lines) is appended to the
// previous field value, separated by a space. A body is only read when a content-length
// field is present, the value is within CONTENT_LENGTH_LIMIT, and the status code allows a body.
// Bodies framed differently (e.g. chunked transfer coding) are not read.
// NOTE(review): responses to HEAD requests cannot be recognised here. A HEAD response with
// content-length would make this wait for a body that never arrives.
Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() MIJIN_NOEXCEPT
{
    std::string line;

    // Status line: HTTP-version SP status-code SP [ reason-phrase ]
    MIJIN_HTTP_READLINE(line);
    auto parts = split(line, " ", {.limitParts = 3});
    // The reason phrase is optional. Trimming drops the SP in front of an empty one,
    // so both two and three parts are valid.
    if (parts.size() != 2 && parts.size() != 3)
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (!parts[0].starts_with("HTTP/"))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    HTTPResponse response;
    if (!parseHTTPVersion(parts[0].substr(5), response.version) || !toNumber(parts[1], response.status))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (parts.size() == 3)
    {
        response.statusMessage = parts[2];
    }

    // end() means "no header field read yet", so a leading continuation line can be rejected.
    auto lastHeader = response.headers.end();
    while (true)
    {
        MIJIN_HTTP_CHECKREAD(base_->c_readLine(line));
        // A line starting with SP/HTAB continues the previous field value. This must be
        // checked before trimming, which would remove exactly that whitespace.
        const bool isContinuation = !line.empty() && (line[0] == ' ' || line[0] == '\t');
        line = trim(line);
        if (line.empty())
        {
            break; // end of the header section
        }
        if (isContinuation)
        {
            if (lastHeader == response.headers.end())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
            lastHeader->second.push_back(' ');
            lastHeader->second.append(line);
            continue;
        }
        parts = split(line, ":", {.limitParts = 2});
        if (parts.size() != 2)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        lastHeader = response.headers.emplace(toLower(trim(parts[0])), trim(parts[1]));
    }

    if (statusForbidsBody(response.status))
    {
        co_return response;
    }
    auto itContentLength = response.headers.find("content-length");
    if (itContentLength != response.headers.end())
    {
        std::size_t contentLength = 0;
        if (!toNumber(itContentLength->second, contentLength))
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        if (contentLength > CONTENT_LENGTH_LIMIT)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        response.body.resize(contentLength);
        MIJIN_HTTP_CHECKREAD(base_->c_readRaw(response.body));
    }
    co_return response;
}

// Sends `request` to address:port and returns the response.
// A connection to the same address/port/scheme is reused when one is open. A
// "connection: keep-alive" field is added unless the caller set one. If the exchange fails,
// a fresh connection is opened and the request is sent once more. If that also fails, the
// connection is dropped so the next request starts clean.
// The value of the "host" field (if any) is passed to createSocket and, for https, to the
// TLS stream.
// NOTE(review): the retry resends the request even if it is not idempotent.
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https, HTTPRequest request) MIJIN_NOEXCEPT
{
    std::string hostname;
    if (auto it = request.headers.find("host"); it != request.headers.end())
    {
        hostname = it->second;
    }
    if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
    {
        co_return error;
    }
    if (!request.headers.contains("connection"))
    {
        request.headers.emplace("connection", "keep-alive");
    }

    StreamResult<HTTPResponse> response = co_await stream_->c_request(request);
    if (response.isError())
    {
        // A reused connection may have been closed by the server in the meantime:
        // reconnect and try once more.
        disconnect();
        if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
        {
            co_return error;
        }
        response = co_await stream_->c_request(request);
        if (response.isError())
        {
            // Don't keep a connection in an unknown state around for the next request.
            disconnect();
        }
    }
    co_return response;
}

// Sends `request` to the host named by `url` (scheme http or https only).
// The port defaults to 80/443 when the URL has none. A non-literal host is resolved and the
// first resulting address is used. A "host" field is added unless the caller set one, and
// the request target is set to the URL's path/query/fragment ("/" if empty).
// Returns UNKNOWN_ERROR for a missing host, an unsupported scheme, or a host that resolves
// to no addresses.
// NOTE(review): a non-default port is not included in the Host field (RFC 9110, section 7.2
// expects "host:port"). The same value is also used as the TLS host name, so changing it
// needs care.
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(const URL& url, HTTPRequest request) MIJIN_NOEXCEPT
{
    if (url.getHost().empty())
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    std::uint16_t port = url.getPort();
    bool https = false;
    if (equalsIgnoreCase(url.getScheme(), "http"))
    {
        port = (port != 0) ? port : 80;
    }
    else if (equalsIgnoreCase(url.getScheme(), "https"))
    {
        port = (port != 0) ? port : 443;
        https = true;
    }
    else
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    auto ipAddress = ipAddressFromString(url.getHost());
    if (ipAddress.empty())
    {
        auto addresses = co_await c_resolveHostname(url.getHost());
        if (addresses.isError())
        {
            co_return addresses.getError();
        }
        if (addresses->empty())
        {
            co_return StreamError::UNKNOWN_ERROR;
        }
        // TODO: try all addresses
        ipAddress = addresses->front();
    }

    if (!request.headers.contains("host"))
    {
        request.headers.emplace("host", url.getHost());
    }
    request.address = url.getPathQueryFragment();
    if (request.address.empty())
    {
        request.address = "/";
    }
    co_return co_await c_request(*ipAddress, port, https, std::move(request));
}

// Closes the current connection, if any.
// Invariant (kept by createSocket): socket_ is non-null exactly when stream_ is constructed.
void HTTPClient::disconnect() MIJIN_NOEXCEPT
{
    if (socket_ == nullptr)
    {
        return;
    }
    // Tear down in reverse order of construction. stream_ wraps the TLS stream (if any),
    // which in turn wraps the socket's stream, so no layer outlives the one it references.
    stream_.destroy();
    sslStream_.reset();
    socket_ = nullptr;
}

// Ensures a connection to address:port using the requested scheme.
// An open connection with the same address, port and scheme is kept. Otherwise it is
// closed and a new one is opened. For https, a TLS stream is set up on top of the socket
// using `hostname`.
// On any failure the client is left disconnected (socket_ null, stream_ not constructed).
// Returns NOT_SUPPORTED for https when built without MIJIN_ENABLE_OPENSSL.
// NOTE(review): the hostname is not part of the reuse check, so a TLS connection can be
// reused for a different host name that maps to the same address and port.
StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) MIJIN_NOEXCEPT
{
    if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
    {
        return StreamError::SUCCESS;
    }
    disconnect();

#if !defined(MIJIN_ENABLE_OPENSSL)
    if (https)
    {
        (void) hostname;
        return StreamError::NOT_SUPPORTED;
    }
#endif

    std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
    if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
    {
        return error;
    }

    // Only store into the members once every layer is set up. An early return then destroys
    // the locals (TLS stream first, then the socket) and leaves the client disconnected.
    if (!https)
    {
        socket_ = std::move(newSocket);
        stream_.construct(socket_->getStream());
    }
    else
    {
#if defined(MIJIN_ENABLE_OPENSSL)
        std::unique_ptr<SSLStream> sslStream = std::make_unique<SSLStream>();
        if (const StreamError error = sslStream->open(newSocket->getStream(), hostname); error != StreamError::SUCCESS)
        {
            return error;
        }
        socket_ = std::move(newSocket);
        sslStream_ = std::move(sslStream);
        stream_.construct(*sslStream_);
#endif
        // Without OpenSSL this branch is unreachable: https returned NOT_SUPPORTED above.
    }

    lastIP_ = address;
    lastPort_ = port;
    lastWasHttps_ = https;
    return StreamError::SUCCESS;
}
}

#undef MIJIN_HTTP_WRITE
#undef MIJIN_HTTP_CHECKREAD
#undef MIJIN_HTTP_READLINE