diff --git a/examples/simple_net_client_server.cpp b/examples/simple_net_client_server.cpp index 2e5380e..c93cc9c 100644 --- a/examples/simple_net_client_server.cpp +++ b/examples/simple_net_client_server.cpp @@ -30,6 +30,11 @@ int main(int, char*[]) { auto onClientConnected = [&](const std::shared_ptr& socket) { auto session = dap::Session::create(); + + // Set the session to close on invalid data. This ensures that data received over the network + // receives a baseline level of validation before being processed. + session->setOnInvalidData(dap::kClose); + session->bind(socket); // The Initialize request is the first message sent from the client and diff --git a/include/dap/session.h b/include/dap/session.h index 3933886..96db04b 100644 --- a/include/dap/session.h +++ b/include/dap/session.h @@ -103,6 +103,14 @@ ResponseOrError& ResponseOrError::operator=(ResponseOrError&& other) { // Session //////////////////////////////////////////////////////////////////////////////// +// An enum flag that controls how the Session handles invalid data. +enum OnInvalidData { + // Ignore invalid data. + kIgnore, + // Close the underlying reader when invalid data is received. + kClose, +}; + // Session implements a DAP client or server endpoint. // The general usage is as follows: // (1) Create a session with Session::create(). @@ -144,6 +152,9 @@ class Session { // create() constructs and returns a new Session. static std::unique_ptr create(); + // Sets how the Session handles invalid data. + virtual void setOnInvalidData(OnInvalidData) = 0; + // onError() registers a error handler that will be called whenever a protocol // error is encountered. // Only one error handler can be bound at any given time, and later calls diff --git a/src/content_stream.cpp b/src/content_stream.cpp index e7c6628..c5264aa 100644 --- a/src/content_stream.cpp +++ b/src/content_stream.cpp @@ -24,12 +24,15 @@ namespace dap { //////////////////////////////////////////////////////////////////////////////// // ContentReader //////////////////////////////////////////////////////////////////////////////// -ContentReader::ContentReader(const std::shared_ptr& reader) - : reader(reader) {} +ContentReader::ContentReader( + const std::shared_ptr& reader, + OnInvalidData on_invalid_data /* = OnInvalidData::kIgnore */) + : reader(reader), on_invalid_data(on_invalid_data) {} ContentReader& ContentReader::operator=(ContentReader&& rhs) noexcept { buf = std::move(rhs.buf); reader = std::move(rhs.reader); + on_invalid_data = std::move(rhs.on_invalid_data); return *this; } @@ -45,8 +48,14 @@ void ContentReader::close() { std::string ContentReader::read() { // Find Content-Length header prefix - if (!scan("Content-Length:")) { - return ""; + if (on_invalid_data == kClose) { + if (!match("Content-Length:")) { + return badHeader(); + } + } else { + if (!scan("Content-Length:")) { + return ""; + } } // Skip whitespace and tabs while (matchAny(" \t")) { @@ -64,10 +73,12 @@ std::string ContentReader::read() { if (len == 0) { return ""; } + // Expect \r\n\r\n if (!match("\r\n\r\n")) { - return ""; + return badHeader(); } + // Read message if (!buffer(len)) { return ""; @@ -149,6 +160,13 @@ bool ContentReader::buffer(size_t bytes) { return true; } +std::string ContentReader::badHeader() { + if (on_invalid_data == kClose) { + close(); + } + return ""; +} + //////////////////////////////////////////////////////////////////////////////// // ContentWriter //////////////////////////////////////////////////////////////////////////////// diff --git a/src/content_stream.h b/src/content_stream.h index f01fef7..eee998e 100644 --- a/src/content_stream.h +++ b/src/content_stream.h @@ -21,6 +21,8 @@ #include +#include "dap/session.h" + namespace dap { // Forward declarations @@ -30,7 +32,8 @@ class Writer; class ContentReader { public: ContentReader() = default; - ContentReader(const std::shared_ptr&); + ContentReader(const std::shared_ptr&, + const OnInvalidData on_invalid_data = kIgnore); ContentReader& operator=(ContentReader&&) noexcept; bool isOpen(); @@ -44,9 +47,11 @@ class ContentReader { bool match(const char* str); char matchAny(const char* chars); bool buffer(size_t bytes); + std::string badHeader(); std::shared_ptr reader; std::deque buf; + OnInvalidData on_invalid_data; }; class ContentWriter { diff --git a/src/content_stream_test.cpp b/src/content_stream_test.cpp index 0cd2edb..3f00472 100644 --- a/src/content_stream_test.cpp +++ b/src/content_stream_test.cpp @@ -94,3 +94,33 @@ TEST(ContentStreamTest, PartialReadAndParse) { ASSERT_EQ(cs.read(), "Content payload number one"); ASSERT_EQ(cs.read(), ""); } + +TEST(ContentStreamTest, HttpRequest) { + const char* const part1 = + "POST / HTTP/1.1\r\n" + "Host: localhost:8001\r\n" + "Connection: keep-alive\r\n" + "Content-Length: 99\r\n"; + const char* const part2 = + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" + "Content-Type: text/plain;charset=UTF-8\r\n" + "Accept: */*\r\n" + "Origin: null\r\n" + "Sec-Fetch-Site: cross-site\r\n" + "Sec-Fetch-Mode: cors\r\n" + "Sec-Fetch-Dest: empty\r\n" + "Accept-Encoding: gzip, deflate, br\r\n" + "Accept-Language: en-US,en;q=0.9\r\n" + "\r\n" + "{\"type\":\"request\",\"command\":\"launch\",\"arguments\":{\"cmd\":\"/" + "bin/sh -c 'echo remote code execution'\"}}"; + + auto sb = dap::StringBuffer::create(); + sb->write(part1); + sb->write(part2); + + dap::ContentReader cr(std::move(sb), dap::kClose); + ASSERT_EQ(cr.read(), ""); + ASSERT_FALSE(cr.isOpen()); +} diff --git a/src/session.cpp b/src/session.cpp index d88a697..5bf22c9 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -35,6 +35,10 @@ namespace { class Impl : public dap::Session { public: + void setOnInvalidData(dap::OnInvalidData onInvalidData_) override { + this->onInvalidData = onInvalidData_; + } + void onError(const ErrorHandler& handler) override { handlers.put(handler); } void registerHandler(const dap::TypeInfo* typeinfo, @@ -69,7 +73,7 @@ class Impl : public dap::Session { return; } - reader = dap::ContentReader(r); + reader = dap::ContentReader(r, this->onInvalidData); writer = dap::ContentWriter(w); } @@ -490,6 +494,7 @@ class Impl : public dap::Session { dap::Chan inbox; std::atomic nextSeq = {1}; std::mutex sendMutex; + dap::OnInvalidData onInvalidData = dap::kIgnore; }; } // anonymous namespace