#include "./http.hpp"

#include <cstddef>
#include <cstdint>
#include <format>
#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "../util/iterators.hpp"
#include "../util/string.hpp"

// Writes `text` to the underlying stream and propagates any stream error out of the enclosing coroutine.
#define MIJIN_HTTP_WRITE(text) \
do \
{ \
    if (const StreamError error = co_await base_->c_writeText(text); error != StreamError::SUCCESS) \
    { \
        co_return error; \
    } \
} while(false)

// Awaits a read operation and propagates any stream error out of the enclosing coroutine.
#define MIJIN_HTTP_CHECKREAD(read) \
do \
{ \
    if (const StreamError error = co_await read; error != StreamError::SUCCESS) \
    { \
        co_return error; \
    } \
} while(false)

// Reads one line into `text` and trims it. Wrapped in do/while so it expands to a single statement.
#define MIJIN_HTTP_READLINE(text) \
do \
{ \
    MIJIN_HTTP_CHECKREAD(base_->c_readLine(text)); \
    text = trim(text); \
} while(false)

namespace mijin
{
namespace
{
inline constexpr std::size_t CONTENT_LENGTH_LIMIT = 100 << 20; // 100MiB

// Parses the "<major>.<minor>" part of an HTTP version (i.e. without the "HTTP/" prefix) into `outVersion`.
// Returns false if the text does not consist of exactly two numeric components.
bool parseHTTPVersion(std::string_view version, HTTPVersion& outVersion) noexcept
{
    const auto parts = split(version, ".");
    if (parts.size() != 2)
    {
        return false;
    }
    return toNumber(parts[0], outVersion.major) && toNumber(parts[1], outVersion.minor);
}

// Returns true if `text` contains a CR or LF. Such characters must never be written into the request line
// or a header field, otherwise a caller-supplied value could terminate the line and inject extra headers.
bool containsLineBreak(std::string_view text) noexcept
{
    return text.find_first_of("\r\n") != std::string_view::npos;
}
}

// Sends `request` over the underlying stream and reads the server's response.
// Returns the first stream error encountered, or the parsed response.
Task<StreamResult<HTTPResponse>> HTTPStream::c_request(HTTPRequest request) noexcept
{
    if (const StreamError error = co_await c_writeRequest(request); error != StreamError::SUCCESS)
    {
        co_return error;
    }
    co_return co_await c_readResponse();
}

// Serializes `request` (request line, headers, blank line, body) onto the underlying stream.
// A content-length header is added automatically for non-empty bodies. If one is given explicitly, it must
// match the body size. Returns PROTOCOL_ERROR if the request is malformed, otherwise the result of the writes.
Task<StreamError> HTTPStream::c_writeRequest(const mijin::HTTPRequest& request) noexcept
{
    // Validate everything before writing anything, so a rejected request never leaves a partial request on the wire.
    if (containsLineBreak(request.address))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    for (const auto& [key, value] : request.headers)
    {
        if (containsLineBreak(key) || containsLineBreak(value))
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
    }

    std::map<std::string, std::string> moreHeaders;
    if (!request.body.empty())
    {
        auto itLength = request.headers.find("content-length");
        if (itLength == request.headers.end())
        {
            moreHeaders.emplace("content-length", std::to_string(request.body.size()));
        }
        else
        {
            std::size_t headerValue = 0;
            if (!toNumber(itLength->second, headerValue) || headerValue != request.body.size())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
        }
    }

    // HTTP/1.1 requires CRLF as the line terminator for the request line and header fields.
    MIJIN_HTTP_WRITE(std::format("{} {} HTTP/{}.{}\r\n", request.method, request.address, request.version.major, request.version.minor));
    for (const auto& [key, value] : moreHeaders)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }
    for (const auto& [key, value] : request.headers)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }
    MIJIN_HTTP_WRITE("\r\n");
    if (!request.body.empty())
    {
        MIJIN_HTTP_WRITE(request.body);
    }

    co_return StreamError::SUCCESS;
}

// Reads and parses one response (status line, headers and, if present, a Content-Length delimited body).
// Header names are lower-cased; folded (continuation) header lines are joined to the previous value with a space.
// Returns PROTOCOL_ERROR for malformed input or a body larger than CONTENT_LENGTH_LIMIT.
// NOTE(review): Transfer-Encoding (e.g. chunked) is not handled, and the body of a response to a HEAD request
// cannot be distinguished here, because the request method is not known at this point.
Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() noexcept
{
    std::string line;

    // Status line: HTTP-version SP status-code SP [ reason-phrase ]
    MIJIN_HTTP_READLINE(line);
    const auto statusParts = split(line, " ", {.limitParts = 3});
    // The reason phrase is optional. Trimming also removes the separator in front of an empty one, so 2 parts are valid.
    if (statusParts.size() != 2 && statusParts.size() != 3)
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (!statusParts[0].starts_with("HTTP/"))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }

    HTTPResponse response;
    if (!parseHTTPVersion(statusParts[0].substr(5), response.version) || !toNumber(statusParts[1], response.status))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (statusParts.size() == 3)
    {
        response.statusMessage = statusParts[2];
    }

    // end() means "no header read yet". A default-constructed iterator must not be compared with end().
    auto lastHeader = response.headers.end();
    while (true)
    {
        MIJIN_HTTP_CHECKREAD(base_->c_readLine(line));
        // A folded line starts with whitespace. This must be checked before trimming, which would remove it.
        const bool isContinuation = !line.empty() && (line[0] == ' ' || line[0] == '\t');
        line = trim(line);
        if (line.empty())
        {
            break;
        }
        if (isContinuation)
        {
            if (lastHeader == response.headers.end())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
            lastHeader->second.push_back(' ');
            lastHeader->second.append(line);
            continue; // a folded line is not a "key: value" header of its own
        }
        const auto headerParts = split(line, ":", {.limitParts = 2});
        if (headerParts.size() != 2)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        lastHeader = response.headers.emplace(toLower(trim(headerParts[0])), trim(headerParts[1]));
    }

    // 1xx, 204 and 304 responses never carry a body, even if a Content-Length header is present
    // (RFC 9112, section 6.3). Reading one would block or consume the next response on this connection.
    const bool mayHaveBody = !(response.status >= 100 && response.status < 200)
        && response.status != 204
        && response.status != 304;
    auto itContentLength = response.headers.find("content-length");
    if (mayHaveBody && itContentLength != response.headers.end())
    {
        std::size_t contentLength = 0;
        if (!toNumber(itContentLength->second, contentLength))
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        if (contentLength > CONTENT_LENGTH_LIMIT)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        response.content.resize(contentLength);
        MIJIN_HTTP_CHECKREAD(base_->c_readRaw(response.content));
    }

    co_return response;
}

// Sends `request` to address:port, reusing the current connection if it targets the same endpoint.
// Adds "connection: keep-alive" unless the caller set a connection header. If the request fails, the
// connection is re-established and the request is sent once more (a kept-alive connection may have been closed
// by the peer in the meantime).
// NOTE(review): the retry also resends non-idempotent requests (e.g. POST) that the server may already have received.
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https, HTTPRequest request) noexcept
{
    if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
    {
        co_return error;
    }

    if (!request.headers.contains("connection"))
    {
        request.headers.emplace("connection", "keep-alive");
    }

    auto response = co_await stream_->c_request(request);
    if (response.isError())
    {
        disconnect();
        if (const StreamError error = createSocket(address, port, https); error != StreamError::SUCCESS)
        {
            co_return error;
        }
        response = co_await stream_->c_request(std::move(request));
    }
    if (response.isError())
    {
        // After a failed exchange the stream may be desynchronized, so don't keep it for the next request.
        disconnect();
    }
    co_return response;
}

// Sends `request` to the host given by `url`. Only "http"/"https" schemes and literal IP hosts are supported.
// Fills in the default port (80/443), a "host" header if missing, and the request target ("/" if empty).
// Returns UNKNOWN_ERROR for unsupported or unresolvable URLs.
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(const URL& url, HTTPRequest request) noexcept
{
    if (url.getHost().empty())
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    std::uint16_t port = url.getPort();
    bool https = false;
    if (equalsIgnoreCase(url.getScheme(), "http"))
    {
        port = (port != 0) ? port : 80;
    }
    else if (equalsIgnoreCase(url.getScheme(), "https"))
    {
        port = (port != 0) ? port : 443;
        https = true;
    }
    else
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    auto ipAddress = ipAddressFromString(url.getHost()); // TODO: lookup host
    if (ipAddress.empty())
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    if (!request.headers.contains("host"))
    {
        request.headers.emplace("host", url.getHost());
    }

    request.address = url.getPathQueryFragment();
    if (request.address.empty())
    {
        request.address = "/";
    }

    co_return co_await c_request(*ipAddress, port, https, std::move(request));
}

// Closes the current connection, if any. The stream wrapping the socket is destroyed before the socket itself.
void HTTPClient::disconnect() noexcept
{
    if (socket_ == nullptr)
    {
        return;
    }
    stream_.destroy();
    socket_ = nullptr;
}

// Ensures a connection to address:port exists, reusing the current one if it targets the same endpoint.
// Returns the error from opening the socket, or UNKNOWN_ERROR for HTTPS (not supported yet).
StreamError HTTPClient::createSocket(ip_address_t address, std::uint16_t port, bool https) noexcept
{
    if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
    {
        return StreamError::SUCCESS;
    }

    disconnect();

    MIJIN_ASSERT(!https, "HTTPS not supported yet.");
    if (https)
    {
        // Never fall back to plaintext on an HTTPS endpoint, even if the assertion above is compiled out.
        return StreamError::UNKNOWN_ERROR;
    }

    auto newSocket = std::make_unique<TCPSocket>();
    if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
    {
        return error;
    }

    lastIP_ = address;
    lastPort_ = port;
    lastWasHttps_ = https;
    socket_ = std::move(newSocket);
    stream_.construct(socket_->getStream());
    return StreamError::SUCCESS;
}
}

#undef MIJIN_HTTP_WRITE
#undef MIJIN_HTTP_CHECKREAD
#undef MIJIN_HTTP_READLINE