diff --git a/include/dap/session.h b/include/dap/session.h index ddb4d01..3933886 100644 --- a/include/dap/session.h +++ b/include/dap/session.h @@ -137,6 +137,10 @@ class Session { // errors. using ErrorHandler = std::function; + // ClosedHandler is the type of callback function used to signal that a + // connected endpoint has closed. + using ClosedHandler = std::function; + // create() constructs and returns a new Session. static std::unique_ptr create(); @@ -205,9 +209,13 @@ class Session { // bind() connects this Session to an endpoint using connect(), and then // starts processing incoming messages with startProcessingMessages(). - inline void bind(const std::shared_ptr&, - const std::shared_ptr&); - inline void bind(const std::shared_ptr&); + // onClose is the optional callback which will be called when the session + // endpoint has been closed. + inline void bind(const std::shared_ptr& reader, + const std::shared_ptr& writer, + const ClosedHandler& onClose); + inline void bind(const std::shared_ptr& readerWriter, + const ClosedHandler& onClose); ////////////////////////////////////////////////////////////////////////////// // Note: @@ -227,9 +235,11 @@ class Session { // startProcessingMessages() starts a new thread to receive and dispatch // incoming messages. + // onClose is the optional callback which will be called when the session + // endpoint has been closed. // Note: This method is used for explicit control over message handling. // Most users will use bind() instead of calling this method directly. - virtual void startProcessingMessages() = 0; + virtual void startProcessingMessages(const ClosedHandler& onClose = {}) = 0; // getPayload() blocks until the next incoming message is received, returning // the payload or an empty function if the connection was lost. The returned @@ -423,13 +433,15 @@ void Session::connect(const std::shared_ptr& rw) { } void Session::bind(const std::shared_ptr& r, - const std::shared_ptr& w) { + const std::shared_ptr& w, + const ClosedHandler& onClose = {}) { connect(r, w); - startProcessingMessages(); + startProcessingMessages(onClose); } -void Session::bind(const std::shared_ptr& rw) { - bind(rw, rw); +void Session::bind(const std::shared_ptr& rw, + const ClosedHandler& onClose = {}) { + bind(rw, rw, onClose); } } // namespace dap diff --git a/src/session.cpp b/src/session.cpp index b521fc2..d88a697 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -65,7 +65,7 @@ class Impl : public dap::Session { void connect(const std::shared_ptr& r, const std::shared_ptr& w) override { if (isBound.exchange(true)) { - handlers.error("Session is already bound!"); + handlers.error("Session::connect called twice"); return; } @@ -73,13 +73,21 @@ class Impl : public dap::Session { writer = dap::ContentWriter(w); } - void startProcessingMessages() override { - recvThread = std::thread([this] { + void startProcessingMessages( + const ClosedHandler& onClose /* = {} */) override { + if (isProcessingMessages.exchange(true)) { + handlers.error("Session::startProcessingMessages() called twice"); + return; + } + recvThread = std::thread([this, onClose] { while (reader.isOpen()) { if (auto payload = getPayload()) { inbox.put(std::move(payload)); } } + if (onClose) { + onClose(); + } }); dispatchThread = std::thread([this] { @@ -398,17 +406,17 @@ class Impl : public dap::Session { // "body" is an optional field for some events, such as "Terminated Event". bool body_ok = true; d->field("body", [&](dap::Deserializer* d) { - if (!typeinfo->deserialize(d, data)) { - body_ok = false; - } - return true; + if (!typeinfo->deserialize(d, data)) { + body_ok = false; + } + return true; }); if (!body_ok) { - handlers.error("Failed to deserialize event '%s' body", event.c_str()); - typeinfo->destruct(data); - delete[] data; - return {}; + handlers.error("Failed to deserialize event '%s' body", event.c_str()); + typeinfo->destruct(data); + delete[] data; + return {}; } return [=] { @@ -471,6 +479,7 @@ class Impl : public dap::Session { } std::atomic isBound = {false}; + std::atomic isProcessingMessages = {false}; dap::ContentReader reader; dap::ContentWriter writer; diff --git a/src/session_test.cpp b/src/session_test.cpp index eeb8fe3..361152e 100644 --- a/src/session_test.cpp +++ b/src/session_test.cpp @@ -579,3 +579,47 @@ TEST_F(SessionTest, Concurrency) { client.reset(); server.reset(); } + +TEST_F(SessionTest, OnClientClosed) { + std::mutex mutex; + std::condition_variable cv; + bool clientClosed = false; + + auto client2server = dap::pipe(); + auto server2client = dap::pipe(); + + client->bind(server2client, client2server); + server->bind(client2server, server2client, [&] { + std::unique_lock lock(mutex); + clientClosed = true; + cv.notify_all(); + }); + + client.reset(); + + // Wait for the client closed handler to be called. + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return static_cast(clientClosed); }); +} + +TEST_F(SessionTest, OnServerClosed) { + std::mutex mutex; + std::condition_variable cv; + bool serverClosed = false; + + auto client2server = dap::pipe(); + auto server2client = dap::pipe(); + + client->bind(server2client, client2server, [&] { + std::unique_lock lock(mutex); + serverClosed = true; + cv.notify_all(); + }); + server->bind(client2server, server2client); + + server.reset(); + + // Wait for the client closed handler to be called. + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return static_cast(serverClosed); }); +}