299 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			299 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
 | 
						|
#include "./http.hpp"
 | 
						|
 | 
						|
#include <format>
 | 
						|
 | 
						|
#include "../util/iterators.hpp"
 | 
						|
#include "../util/string.hpp"
 | 
						|
 | 
						|
#if defined(MIJIN_ENABLE_OPENSSL)
 | 
						|
#include "./ssl.hpp"
 | 
						|
#endif
 | 
						|
 | 
						|
// Coroutine helpers for HTTPStream. Each one awaits a stream operation on `base_` and, if it
// reports anything other than StreamError::SUCCESS, returns that error from the enclosing
// coroutine via co_return. Every macro is wrapped in do { } while(false) so it expands to
// exactly one statement and stays correct after an unbraced if/else.

// Writes `text` via base_->c_writeText().
#define MIJIN_HTTP_WRITE(text)                                                                          \
do                                                                                                      \
{                                                                                                       \
    if (const StreamError error = co_await base_->c_writeText(text); error != StreamError::SUCCESS)     \
    {                                                                                                   \
        co_return error;                                                                                \
    }                                                                                                   \
} while(false)

// Awaits the read operation `read`, whose result is stored as a StreamError.
#define MIJIN_HTTP_CHECKREAD(read)                                                                      \
do                                                                                                      \
{                                                                                                       \
    if (const StreamError error = co_await (read); error != StreamError::SUCCESS)                       \
    {                                                                                                   \
        co_return error;                                                                                \
    }                                                                                                   \
} while(false)

// Reads one line into the string variable `text`, then replaces it with trim(text).
#define MIJIN_HTTP_READLINE(text)                                                                       \
do                                                                                                      \
{                                                                                                       \
    MIJIN_HTTP_CHECKREAD(base_->c_readLine(text));                                                      \
    (text) = trim(text);                                                                                \
} while(false)
 | 
						|
 | 
						|
namespace mijin
 | 
						|
{
 | 
						|
namespace
 | 
						|
{
 | 
						|
inline constexpr std::size_t CONTENT_LENGTH_LIMIT = 100 << 20; // 100MiB
 | 
						|
bool parseHTTPVersion(std::string_view version, HTTPVersion& outVersion) MIJIN_NOEXCEPT
 | 
						|
{
 | 
						|
    std::vector<std::string_view> parts = split(version, ".");
 | 
						|
    if (parts.size() != 2)
 | 
						|
    {
 | 
						|
        return false;
 | 
						|
    }
 | 
						|
    return toNumber(parts[0], outVersion.major) && toNumber(parts[1], outVersion.minor);
 | 
						|
}
 | 
						|
}
 | 
						|
 | 
						|
// Sends `request` and then reads the reply from the same stream.
// Returns the parsed response, or the error reported by c_writeRequest() if sending failed
// (in which case no response is read).
Task<StreamResult<HTTPResponse>> HTTPStream::c_request(HTTPRequest request) MIJIN_NOEXCEPT
{
    const StreamError writeError = co_await c_writeRequest(request);
    if (writeError != StreamError::SUCCESS)
    {
        co_return writeError;
    }
    co_return co_await c_readResponse();
}
 | 
						|
 | 
						|
// Serializes `request` onto base_: request line, header fields, the empty line ending the
// header section, then the body (if any).
// - If the body is non-empty and the headers have no "content-length" entry (looked up with
//   headers.find(), so callers should use that exact lowercase spelling), one is added.
//   An existing "content-length" that is not a number equal to the body size yields
//   PROTOCOL_ERROR.
// - The request target and all header names/values must not contain CR or LF; otherwise
//   PROTOCOL_ERROR is returned before anything is written, because such text would inject
//   additional header lines into the message.
// Returns the first write error, or SUCCESS.
Task<StreamError> HTTPStream::c_writeRequest(const mijin::HTTPRequest& request) MIJIN_NOEXCEPT
{
    const auto hasLineBreak = [](std::string_view text)
    {
        return text.find_first_of("\r\n") != std::string_view::npos;
    };
    if (hasLineBreak(request.address))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    for (const auto& [key, value] : request.headers)
    {
        if (hasLineBreak(key) || hasLineBreak(value))
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
    }

    std::map<std::string, std::string> moreHeaders;
    if (!request.body.empty())
    {
        auto itLength = request.headers.find("content-length");
        if (itLength == request.headers.end())
        {
            moreHeaders.emplace("content-length", std::to_string(request.body.size()));
        }
        else
        {
            std::size_t headerValue = 0;
            if (!toNumber(itLength->second, headerValue) || headerValue != request.body.size())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
        }
    }

    // HTTP/1.1 (RFC 9112) terminates the request line, every header field and the empty
    // line ending the header section with CRLF; a bare LF is not guaranteed to be accepted.
    MIJIN_HTTP_WRITE(std::format("{} {} HTTP/{}.{}\r\n", request.method, request.address, request.version.major, request.version.minor));
    for (const auto& [key, value] : moreHeaders)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }
    for (const auto& [key, value] : request.headers)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }

    MIJIN_HTTP_WRITE("\r\n");
    if (!request.body.empty())
    {
        MIJIN_HTTP_WRITE(request.body);
    }

    co_return StreamError::SUCCESS;
}
 | 
						|
 | 
						|
// Reads one HTTP response from base_: the status line, the header section, and, if a
// "content-length" header is present and the status allows a body, exactly that many body
// bytes (rejecting lengths above CONTENT_LENGTH_LIMIT).
// Header names are trimmed and lowercased, values are trimmed. A header line that starts
// with SP/HTAB (obsolete line folding) is appended, after a single space, to the previous
// header's value.
// Returns PROTOCOL_ERROR for malformed input, or the first read error from base_.
// TODO: "transfer-encoding: chunked" is not decoded; such a body is left unread in the stream.
// NOTE(review): responses to HEAD requests have no body either, but the request method is
// not known here.
Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() MIJIN_NOEXCEPT
{
    std::string line;
    MIJIN_HTTP_READLINE(line);

    // status-line = HTTP-version SP status-code SP [ reason-phrase ]   (RFC 9112, section 4)
    // The reason phrase may be empty; once the line is trimmed, such a line has only two parts.
    std::vector<std::string_view> parts = split(line, " ", {.limitParts = 3});
    if (parts.size() != 2 && parts.size() != 3)
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (!parts[0].starts_with("HTTP/"))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }

    HTTPResponse response;
    if (!parseHTTPVersion(parts[0].substr(5), response.version)
        || !toNumber(parts[1], response.status))
    {
        co_return StreamError::PROTOCOL_ERROR;
    }
    if (parts.size() == 3)
    {
        response.statusMessage = parts[2];
    }

    // Start at end(): a default-constructed iterator is singular and must not be compared.
    auto lastHeader = response.headers.end();
    while (true)
    {
        // Read the untrimmed line: leading whitespace is what marks a folded continuation
        // line, and MIJIN_HTTP_READLINE would strip it.
        MIJIN_HTTP_CHECKREAD(base_->c_readLine(line));
        const auto content = trim(line);
        if (content.empty())
        {
            break; // the empty line terminates the header section
        }
        if (line[0] == ' ' || line[0] == '\t')
        {
            // continuation of the previous header's value
            if (lastHeader == response.headers.end())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
            lastHeader->second.push_back(' ');
            lastHeader->second.append(content);
            continue;
        }
        parts = split(content, ":", {.limitParts = 2});
        if (parts.size() != 2)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        lastHeader = response.headers.emplace(toLower(trim(parts[0])), trim(parts[1]));
    }

    // 1xx, 204 and 304 responses never carry a body, even when they announce a
    // content-length (RFC 9112, section 6.3); reading one would block forever.
    const bool mayHaveBody = !((response.status >= 100 && response.status < 200)
                               || response.status == 204 || response.status == 304);
    auto itContentLength = response.headers.find("content-length");
    if (mayHaveBody && itContentLength != response.headers.end())
    {
        std::size_t contentLength = 0;
        if (!toNumber(itContentLength->second, contentLength))
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        if (contentLength > CONTENT_LENGTH_LIMIT)
        {
            co_return StreamError::PROTOCOL_ERROR;
        }
        response.body.resize(contentLength);
        MIJIN_HTTP_CHECKREAD(base_->c_readRaw(response.body));
    }

    co_return response;
}
 | 
						|
 | 
						|
// Sends `request` to `address`:`port` (over TLS if `https`) and returns the response.
// - The value of the "host" header, if present, is passed to createSocket() as the hostname
//   (used there for the TLS handshake).
// - "connection: keep-alive" is added unless the caller already set a "connection" header.
// - If the request fails on a connection that was reused from an earlier call, it is retried
//   once on a fresh connection, since the server may have closed the idle connection.
//   Failures on a freshly opened connection are not retried: the server may already have
//   processed the request, and resending would duplicate non-idempotent requests.
// - After a failed request the connection is dropped, because the stream may be left in the
//   middle of a message and must not be reused.
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
                                                       HTTPRequest request) MIJIN_NOEXCEPT
{
    std::string hostname;
    if (auto it = request.headers.find("host"); it != request.headers.end())
    {
        hostname = it->second;
    }

    // Same condition under which createSocket() keeps the existing connection.
    const bool reusingConnection = socket_ != nullptr && address == lastIP_ && port == lastPort_
                                   && https == lastWasHttps_;
    if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
    {
        co_return error;
    }
    if (!request.headers.contains("connection"))
    {
        request.headers.emplace("connection", "keep-alive");
    }

    StreamResult<HTTPResponse> response = co_await stream_->c_request(request);
    if (response.isError() && reusingConnection)
    {
        disconnect();
        if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
        {
            co_return error;
        }
        response = co_await stream_->c_request(request);
    }
    if (response.isError())
    {
        disconnect();
    }
    co_return response;
}
 | 
						|
 | 
						|
// Sends `request` to the server addressed by `url` and returns its response.
// - The scheme must be "http" or "https" (compared via equalsIgnoreCase); a URL port of 0
//   falls back to 80 or 443 respectively.
// - A host that ipAddressFromString() cannot parse is resolved via c_resolveHostname(), and
//   the first returned address is used.
// - A "host" header carrying the URL's host is added unless the request already has one, and
//   the request target becomes the URL's path/query/fragment, or "/" when that is empty.
// Returns UNKNOWN_ERROR for an empty host, an unsupported scheme or an empty resolution
// result, the resolver's error if resolution fails, and otherwise the result of the
// address-based c_request() overload.
Task<StreamResult<HTTPResponse>> HTTPClient::c_request(const URL& url, HTTPRequest request) MIJIN_NOEXCEPT
{
    if (url.getHost().empty())
    {
        co_return StreamError::UNKNOWN_ERROR;
    }

    const std::uint16_t explicitPort = url.getPort();
    std::uint16_t defaultPort = 0;
    bool https = false;
    if (equalsIgnoreCase(url.getScheme(), "http"))
    {
        defaultPort = 80;
    }
    else if (equalsIgnoreCase(url.getScheme(), "https"))
    {
        defaultPort = 443;
        https = true;
    }
    else
    {
        co_return StreamError::UNKNOWN_ERROR;
    }
    const std::uint16_t port = (explicitPort == 0) ? defaultPort : explicitPort;

    Optional<ip_address_t> ipAddress = ipAddressFromString(url.getHost());
    if (ipAddress.empty())
    {
        // Not an IP literal, so look the name up.
        StreamResult<std::vector<ip_address_t>> resolved = co_await c_resolveHostname(url.getHost());
        if (resolved.isError())
        {
            co_return resolved.getError();
        }
        if (resolved->empty())
        {
            co_return StreamError::UNKNOWN_ERROR;
        }
        // TODO: try all addresses instead of only the first one
        ipAddress = resolved->front();
    }

    if (!request.headers.contains("host"))
    {
        request.headers.emplace("host", url.getHost());
    }
    request.address = url.getPathQueryFragment();
    if (request.address.empty())
    {
        request.address = "/";
    }
    co_return co_await c_request(*ipAddress, port, https, std::move(request));
}
 | 
						|
 | 
						|
// Closes the current connection, if any.
// Teardown runs in reverse order of construction: stream_ wraps sslStream_ (for https
// connections), and sslStream_ was opened on the socket's stream. Releasing the socket first
// would leave sslStream_ referring to a destroyed socket until it is replaced or reset later.
void HTTPClient::disconnect() MIJIN_NOEXCEPT
{
    if (socket_ == nullptr)
    {
        return;
    }
    stream_.destroy();
    sslStream_.reset();
    socket_ = nullptr;
}
 | 
						|
 | 
						|
// Ensures socket_/stream_ are connected to `address`:`port`, over TLS when `https` is set.
// An existing connection is kept if it already targets the same address, port and TLS mode
// (note: `hostname` does not take part in this check). Otherwise the old connection is
// closed and a new one is opened. `hostname` is only passed to SSLStream::open().
// Member state (socket_, sslStream_, stream_, lastIP_/lastPort_/lastWasHttps_) is committed
// only after the connection is fully established. On failure socket_ stays null, so a later
// call can never mistake a half-initialized connection (socket set, stream_ not constructed)
// for a reusable one.
// Returns SUCCESS, the error from opening the socket or the TLS stream, or NOT_SUPPORTED for
// https when built without MIJIN_ENABLE_OPENSSL.
StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) MIJIN_NOEXCEPT
{
    if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
    {
        return StreamError::SUCCESS;
    }

    disconnect();

    std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
    if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
    {
        return error;
    }

    if (!https)
    {
        socket_ = std::move(newSocket);
        sslStream_.reset();
        stream_.construct(socket_->getStream());
    }
    else
    {
#if defined(MIJIN_ENABLE_OPENSSL)
        // The TCPSocket lives on the heap, so the stream reference handed to SSLStream stays
        // valid when ownership is later moved into socket_.
        std::unique_ptr<SSLStream> newSslStream = std::make_unique<SSLStream>();
        if (const StreamError error = newSslStream->open(newSocket->getStream(), hostname); error != StreamError::SUCCESS)
        {
            // Locals are destroyed in reverse declaration order: the TLS stream goes before its socket.
            return error;
        }
        socket_ = std::move(newSocket);
        sslStream_ = std::move(newSslStream);
        stream_.construct(*sslStream_);
#else
        (void) hostname;
        return StreamError::NOT_SUPPORTED; // newSocket is closed here, nothing is committed
#endif
    }

    lastIP_ = address;
    lastPort_ = port;
    lastWasHttps_ = https;
    return StreamError::SUCCESS;
}
 | 
						|
}
 | 
						|
 | 
						|
// The HTTP helper macros are private to this translation unit; remove them again.
#undef MIJIN_HTTP_READLINE
#undef MIJIN_HTTP_CHECKREAD
#undef MIJIN_HTTP_WRITE