299 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			299 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| 
 | |
| #include "./http.hpp"
 | |
| 
 | |
| #include <format>
 | |
| 
 | |
| #include "../util/iterators.hpp"
 | |
| #include "../util/string.hpp"
 | |
| 
 | |
| #if defined(MIJIN_ENABLE_OPENSSL)
 | |
| #include "./ssl.hpp"
 | |
| #endif
 | |
| 
 | |
// Helper macros for the coroutines in this file. Each one expands to exactly one statement (so it is safe inside an
// unbraced if/else) and co_returns the StreamError from the enclosing coroutine when the awaited operation fails.

// Writes `text` to base_.
#define MIJIN_HTTP_WRITE(text)                                                                        \
do                                                                                                    \
{                                                                                                     \
    if (const StreamError error = co_await base_->c_writeText(text); error != StreamError::SUCCESS)   \
    {                                                                                                 \
        co_return error;                                                                              \
    }                                                                                                 \
} while(false)

// Awaits `read` (any awaitable producing a StreamError).
#define MIJIN_HTTP_CHECKREAD(read)                                                                    \
do                                                                                                    \
{                                                                                                     \
    if (const StreamError error = co_await (read); error != StreamError::SUCCESS)                     \
    {                                                                                                 \
        co_return error;                                                                              \
    }                                                                                                 \
} while(false)

// Reads one line from base_ into `text` and trims it. Wrapped in do/while so both steps form a single statement.
#define MIJIN_HTTP_READLINE(text)                                                                     \
do                                                                                                    \
{                                                                                                     \
    MIJIN_HTTP_CHECKREAD(base_->c_readLine(text));                                                    \
    (text) = trim(text);                                                                              \
} while(false)
 | |
| 
 | |
| namespace mijin
 | |
| {
 | |
| namespace
 | |
| {
 | |
| inline constexpr std::size_t CONTENT_LENGTH_LIMIT = 100 << 20; // 100MiB
 | |
| bool parseHTTPVersion(std::string_view version, HTTPVersion& outVersion) MIJIN_NOEXCEPT
 | |
| {
 | |
|     std::vector<std::string_view> parts = split(version, ".");
 | |
|     if (parts.size() != 2)
 | |
|     {
 | |
|         return false;
 | |
|     }
 | |
|     return toNumber(parts[0], outVersion.major) && toNumber(parts[1], outVersion.minor);
 | |
| }
 | |
| }
 | |
| 
 | |
| Task<StreamResult<HTTPResponse>> HTTPStream::c_request(HTTPRequest request) MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (const StreamError error = co_await c_writeRequest(request); error != StreamError::SUCCESS)
 | |
|     {
 | |
|         co_return error;
 | |
|     }
 | |
|     co_return co_await c_readResponse();
 | |
| }
 | |
| 
 | |
/// Serializes `request` as an HTTP/1.x request onto base_.
/// If the request has a body and no "content-length" header, one is generated. A user-supplied content-length
/// that does not parse or does not match the body size yields PROTOCOL_ERROR before anything is written.
/// Any write error is returned immediately.
Task<StreamError> HTTPStream::c_writeRequest(const mijin::HTTPRequest& request) MIJIN_NOEXCEPT
{
    // Headers generated here; they are written before the user-supplied ones.
    std::map<std::string, std::string> moreHeaders;
    if (!request.body.empty())
    {
        auto itLength = request.headers.find("content-length");
        if (itLength == request.headers.end())
        {
            moreHeaders.emplace("content-length", std::to_string(request.body.size()));
        }
        else
        {
            // A mismatching length would make the peer misframe the message, so refuse to send it.
            std::size_t headerValue = 0;
            if (!toNumber(itLength->second, headerValue) || headerValue != request.body.size())
            {
                co_return StreamError::PROTOCOL_ERROR;
            }
        }
    }

    // HTTP/1.x lines are terminated by CRLF (RFC 9112, section 2.1); a bare LF is only tolerated by lenient servers.
    MIJIN_HTTP_WRITE(std::format("{} {} HTTP/{}.{}\r\n", request.method, request.address, request.version.major, request.version.minor));
    for (const auto& [key, value] : moreHeaders)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }
    for (const auto& [key, value] : request.headers)
    {
        MIJIN_HTTP_WRITE(std::format("{}: {}\r\n", key, value));
    }

    // An empty line terminates the header section.
    MIJIN_HTTP_WRITE("\r\n");
    if (!request.body.empty())
    {
        MIJIN_HTTP_WRITE(request.body);
    }

    co_return StreamError::SUCCESS;
}
 | |
| 
 | |
| Task<StreamResult<HTTPResponse>> HTTPStream::c_readResponse() MIJIN_NOEXCEPT
 | |
| {
 | |
|     std::string line;
 | |
|     MIJIN_HTTP_READLINE(line);
 | |
| 
 | |
|     std::vector<std::string_view> parts = split(line, " ", {.limitParts = 3});
 | |
|     if (parts.size() != 3)
 | |
|     {
 | |
|         co_return StreamError::PROTOCOL_ERROR;
 | |
|     }
 | |
|     if (!parts[0].starts_with("HTTP/"))
 | |
|     {
 | |
|         co_return StreamError::PROTOCOL_ERROR;
 | |
|     }
 | |
| 
 | |
|     HTTPResponse response;
 | |
|     if (!parseHTTPVersion(parts[0].substr(5), response.version)
 | |
|         || !toNumber(parts[1], response.status))
 | |
|     {
 | |
|         co_return StreamError::PROTOCOL_ERROR;
 | |
|     }
 | |
|     response.statusMessage = parts[2];
 | |
| 
 | |
|     decltype(response.headers)::iterator lastHeader;
 | |
|     while (true)
 | |
|     {
 | |
|         MIJIN_HTTP_READLINE(line);
 | |
|         if (line.empty()) {
 | |
|             break;
 | |
|         }
 | |
|         if (line[0] == ' ' || line[0] == '\t')
 | |
|         {
 | |
|             // continuation
 | |
|             if (lastHeader == response.headers.end())
 | |
|             {
 | |
|                 co_return StreamError::PROTOCOL_ERROR;
 | |
|             }
 | |
|             lastHeader->second.push_back(' ');
 | |
|             lastHeader->second.append(trim(line));
 | |
|         }
 | |
|         parts = split(line, ":", {.limitParts = 2});
 | |
|         if (parts.size() != 2)
 | |
|         {
 | |
|             co_return StreamError::PROTOCOL_ERROR;
 | |
|         }
 | |
|         lastHeader = response.headers.emplace(toLower(trim(parts[0])), trim(parts[1]));
 | |
|     }
 | |
| 
 | |
|     auto itContentLength = response.headers.find("content-length");
 | |
|     if (itContentLength != response.headers.end())
 | |
|     {
 | |
|         std::size_t contentLength = 0;
 | |
|         if (!toNumber(itContentLength->second, contentLength))
 | |
|         {
 | |
|             co_return StreamError::PROTOCOL_ERROR;
 | |
|         }
 | |
|         if (contentLength > CONTENT_LENGTH_LIMIT)
 | |
|         {
 | |
|             co_return StreamError::PROTOCOL_ERROR;
 | |
|         }
 | |
|         response.body.resize(contentLength);
 | |
|         MIJIN_HTTP_CHECKREAD(base_->c_readRaw(response.body));
 | |
|     }
 | |
| 
 | |
|     co_return response;
 | |
| }
 | |
| 
 | |
| Task<StreamResult<HTTPResponse>> HTTPClient::c_request(ip_address_t address, std::uint16_t port, bool https,
 | |
|                                                        HTTPRequest request) MIJIN_NOEXCEPT
 | |
| {
 | |
|     std::string hostname;
 | |
|     if (auto it = request.headers.find("host"); it != request.headers.end())
 | |
|     {
 | |
|         hostname = it->second;
 | |
|     }
 | |
|     if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
 | |
|     {
 | |
|         co_return error;
 | |
|     }
 | |
|     if (!request.headers.contains("connection"))
 | |
|     {
 | |
|         request.headers.emplace("connection", "keep-alive");
 | |
|     }
 | |
|     StreamResult<HTTPResponse> response = co_await stream_->c_request(request);
 | |
|     if (response.isError())
 | |
|     {
 | |
|         disconnect();
 | |
|         if (const StreamError error = createSocket(address, hostname, port, https); error != StreamError::SUCCESS)
 | |
|         {
 | |
|             co_return error;
 | |
|         }
 | |
|         response = co_await stream_->c_request(request);
 | |
|     }
 | |
|     co_return response;
 | |
| }
 | |
| 
 | |
| Task<StreamResult<HTTPResponse>> HTTPClient::c_request(const URL& url, HTTPRequest request) MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (url.getHost().empty())
 | |
|     {
 | |
|         co_return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
| 
 | |
|     std::uint16_t port = url.getPort();
 | |
|     bool https = false;
 | |
|     if (equalsIgnoreCase(url.getScheme(), "http"))
 | |
|     {
 | |
|         port = (port != 0) ? port : 80;
 | |
|     }
 | |
|     else if (equalsIgnoreCase(url.getScheme(), "https"))
 | |
|     {
 | |
|         port = (port != 0) ? port : 443;
 | |
|         https = true;
 | |
|     }
 | |
|     else
 | |
|     {
 | |
|         co_return StreamError::UNKNOWN_ERROR;
 | |
|     }
 | |
|     Optional<ip_address_t> ipAddress = ipAddressFromString(url.getHost());
 | |
|     if (ipAddress.empty())
 | |
|     {
 | |
|         StreamResult<std::vector<ip_address_t>> addresses = co_await c_resolveHostname(url.getHost());
 | |
|         if (addresses.isError())
 | |
|         {
 | |
|             co_return addresses.getError();
 | |
|         }
 | |
|         else if (addresses->empty())
 | |
|         {
 | |
|             co_return StreamError::UNKNOWN_ERROR;
 | |
|         }
 | |
|         // TODO: try all addresses
 | |
|         ipAddress = addresses->front();
 | |
|     }
 | |
| 
 | |
|     if (!request.headers.contains("host"))
 | |
|     {
 | |
|         request.headers.emplace("host", url.getHost());
 | |
|     }
 | |
|     request.address = url.getPathQueryFragment();
 | |
|     if (request.address.empty())
 | |
|     {
 | |
|         request.address = "/";
 | |
|     }
 | |
| 
 | |
|     co_return co_await c_request(*ipAddress, port, https, std::move(request));
 | |
| }
 | |
| 
 | |
| void HTTPClient::disconnect() MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (socket_ == nullptr)
 | |
|     {
 | |
|         return;
 | |
|     }
 | |
|     stream_.destroy();
 | |
|     socket_ = nullptr;
 | |
| }
 | |
| 
 | |
| StreamError HTTPClient::createSocket(ip_address_t address, const std::string& hostname, std::uint16_t port, bool https) MIJIN_NOEXCEPT
 | |
| {
 | |
|     if (socket_ != nullptr && address == lastIP_ && port == lastPort_ && https == lastWasHttps_)
 | |
|     {
 | |
|         return StreamError::SUCCESS;
 | |
|     }
 | |
| 
 | |
|     disconnect();
 | |
| 
 | |
|     std::unique_ptr<TCPSocket> newSocket = std::make_unique<TCPSocket>();
 | |
|     if (const StreamError error = newSocket->open(address, port); error != StreamError::SUCCESS)
 | |
|     {
 | |
|         return error;
 | |
|     }
 | |
|     socket_ = std::move(newSocket);
 | |
|     if (!https)
 | |
|     {
 | |
|         sslStream_.reset();
 | |
|         stream_.construct(socket_->getStream());
 | |
|     }
 | |
|     else
 | |
|     {
 | |
| #if defined(MIJIN_ENABLE_OPENSSL)
 | |
|         std::unique_ptr<SSLStream> sslStream = std::make_unique<SSLStream>();
 | |
|         if (const StreamError error = sslStream->open(socket_->getStream(), hostname); error != StreamError::SUCCESS)
 | |
|         {
 | |
|             return error;
 | |
|         }
 | |
|         sslStream_ = std::move(sslStream);
 | |
|         stream_.construct(*sslStream_);
 | |
| #else
 | |
|         (void) hostname;
 | |
|         return StreamError::NOT_SUPPORTED;
 | |
| #endif
 | |
|     }
 | |
| 
 | |
|     lastIP_ = address;
 | |
|     lastPort_ = port;
 | |
|     lastWasHttps_ = https;
 | |
|     return StreamError::SUCCESS;
 | |
| }
 | |
| }
 | |
| 
 | |
| #undef MIJIN_HTTP_WRITE
 | |
| #undef MIJIN_HTTP_CHECKREAD
 | |
| #undef MIJIN_HTTP_READLINE |