diff --git a/src/network.cpp b/src/network.cpp index 887d762..613c234 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -16,6 +16,7 @@ #include "socket.h" +#include #include #include #include @@ -24,7 +25,7 @@ namespace { class Impl : public dap::net::Server { public: - Impl() {} + Impl() : stopped{true} {} ~Impl() { stop(); } @@ -41,17 +42,18 @@ class Impl : public dap::net::Server { return false; } - running = true; + stopped = false; thread = std::thread([=] { - do { + while (true) { if (auto rw = socket->accept()) { onConnect(rw); continue; } - if (!isRunning()) { + if (!stopped) { onError("Failed to accept connection"); } - } while (false); + break; + }; }); return true; @@ -63,23 +65,19 @@ class Impl : public dap::net::Server { } private: - bool isRunning() { - std::unique_lock lock(mutex); - return running; - } + bool isRunning() { return !stopped; } void stopWithLock() { - if (running) { + if (!stopped.exchange(true)) { socket->close(); thread.join(); - running = false; } } std::mutex mutex; std::thread thread; std::unique_ptr socket; - bool running = false; + std::atomic stopped; OnError errorHandler; }; diff --git a/src/network_test.cpp b/src/network_test.cpp index 8e965f1..57bb0a9 100644 --- a/src/network_test.cpp +++ b/src/network_test.cpp @@ -25,6 +25,8 @@ namespace { +constexpr int port = 19021; + bool write(const std::shared_ptr& w, const std::string& s) { return w->write(s.data(), s.size()) && w->write("\0", 1); } @@ -44,7 +46,6 @@ std::string read(const std::shared_ptr& r) { } // anonymous namespace TEST(Network, ClientServer) { - const int port = 19021; dap::Chan done; auto server = dap::net::Server::create(); if (!server->start( @@ -59,15 +60,51 @@ TEST(Network, ClientServer) { return; } - for (int i = 0; i < 10; i++) { - if (auto client = dap::net::connect("localhost", port)) { - ASSERT_TRUE(write(client, "client to server")); - ASSERT_EQ(read(client), "server to client"); - break; - } + for (int i = 0; i < 5; i++) { + auto client = dap::net::connect("localhost", port); + ASSERT_NE(client, nullptr) << "Failed to connect client " << i; + ASSERT_TRUE(write(client, "client to server")); + ASSERT_EQ(read(client), "server to client"); + done.take(); std::this_thread::sleep_for(std::chrono::seconds(1)); } - done.take(); + server.reset(); +} + +TEST(Network, ServerRepeatStopAndRestart) { + dap::Chan done; + auto onConnect = [&](const std::shared_ptr& rw) { + ASSERT_EQ(read(rw), "client to server"); + ASSERT_TRUE(write(rw, "server to client")); + done.put(true); + }; + auto onError = [&](const char* err) { FAIL() << "Server error: " << err; }; + + auto server = dap::net::Server::create(); + if (!server->start(port, onConnect, onError)) { + FAIL() << "Couldn't start server"; + return; + } + + server->stop(); + server->stop(); + server->stop(); + + if (!server->start(port, onConnect, onError)) { + FAIL() << "Couldn't restart server"; + return; + } + + auto client = dap::net::connect("localhost", port); + ASSERT_NE(client, nullptr) << "Failed to connect"; + ASSERT_TRUE(write(client, "client to server")); + ASSERT_EQ(read(client), "server to client"); + done.take(); + + server->stop(); + server->stop(); + server->stop(); + server.reset(); } diff --git a/src/socket.cpp b/src/socket.cpp index 998fed9..1211310 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -174,7 +174,17 @@ class dap::Socket::Shared : public dap::ReaderWriter { if (s != InvalidSocket) { #if defined(_WIN32) closesocket(s); +#elif __APPLE__ + // ::shutdown() *should* be sufficient to unblock ::accept(), but + // apparently on macos it can return ENOTCONN and ::accept() continues + // to block indefinitely. + // Note: There is a race here. Calling ::close() frees the socket ID, + // which may be reused before `s` is assigned InvalidSocket. + ::shutdown(s, SHUT_RDWR); + ::close(s); #else + // ::shutdown() to unblock ::accept(). We'll actually close the socket + // under lock below. ::shutdown(s, SHUT_RDWR); #endif } @@ -182,7 +192,7 @@ class dap::Socket::Shared : public dap::ReaderWriter { WLock l(mutex); if (s != InvalidSocket) { -#if !defined(_WIN32) +#if !defined(_WIN32) && !defined(__APPLE__) ::close(s); #endif s = InvalidSocket; @@ -240,10 +250,13 @@ std::shared_ptr Socket::accept() const { std::shared_ptr out; if (shared) { shared->lock([&](SOCKET socket, const addrinfo*) { - if (socket != InvalidSocket) { + if (socket != InvalidSocket && !errored(socket)) { init(); - out = std::make_shared(::accept(socket, 0, 0)); - out->setOptions(); + auto accepted = ::accept(socket, 0, 0); + if (accepted != InvalidSocket) { + out = std::make_shared(accepted); + out->setOptions(); + } } }); }