From 9003ee55b9d406b995d4c12d3d2d76e94fd90411 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Mon, 8 Jun 2020 18:31:21 +0100 Subject: [PATCH] Socket: Use the RWMutex to fix TSAN error ... about `close()`ing the socket on one thread while in a blocking `recv()` or `send()` call on another thread. Fixes #35 --- src/socket.cpp | 227 ++++++++++++++++++++++++++----------------------- 1 file changed, 121 insertions(+), 106 deletions(-) diff --git a/src/socket.cpp b/src/socket.cpp index 653b384..fe27ff7 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -14,6 +14,8 @@ #include "socket.h" +#include "rwmutex.h" + #if defined(_WIN32) #include #include @@ -41,7 +43,7 @@ using SOCKET = int; namespace { constexpr SOCKET InvalidSocket = static_cast(-1); -static void init() { +void init() { #if defined(_WIN32) if (wsaInitCount++ == 0) { WSADATA winsockData; @@ -50,7 +52,7 @@ static void init() { #endif } -static void term() { +void term() { #if defined(_WIN32) if (--wsaInitCount == 0) { WSACleanup(); @@ -58,6 +60,30 @@ static void term() { #endif } +bool setBlocking(SOCKET s, bool blocking) { +#if defined(_WIN32) + u_long mode = blocking ? 0 : 1; + return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR; +#else + auto arg = fcntl(s, F_GETFL, nullptr); + if (arg < 0) { + return false; + } + arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK); + return fcntl(s, F_SETFL, arg) >= 0; +#endif +} + +bool errored(SOCKET s) { + if (s == InvalidSocket) { + return true; + } + char error = 0; + socklen_t len = sizeof(error); + getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len); + return error != 0; +} + } // anonymous namespace class dap::Socket::Shared : public dap::ReaderWriter { @@ -87,8 +113,8 @@ class dap::Socket::Shared : public dap::ReaderWriter { return nullptr; } - Shared(SOCKET socket) : info(nullptr), sock(socket) {} - Shared(addrinfo* info, SOCKET socket) : info(info), sock(socket) {} + Shared(SOCKET socket) : info(nullptr), s(socket) {} + Shared(addrinfo* info, SOCKET socket) : info(info), s(socket) {} ~Shared() { freeaddrinfo(info); @@ -96,10 +122,14 @@ class dap::Socket::Shared : public dap::ReaderWriter { term(); } - SOCKET socket() { return sock.load(); } + template + void lock(FUNCTION&& f) { + RLock l(mutex); + f(s, info); + } void setOptions() { - SOCKET s = socket(); + RLock l(mutex); if (s == InvalidSocket) { return; } @@ -125,59 +155,42 @@ class dap::Socket::Shared : public dap::ReaderWriter { setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable)); } - bool setBlocking(bool blocking) { - SOCKET s = socket(); - if (s == InvalidSocket) { - return false; + // dap::ReaderWriter compliance + bool isOpen() { + { + RLock l(mutex); + if ((s != InvalidSocket) && !errored(s)) { + return true; + } } - -#if defined(_WIN32) - u_long mode = blocking ? 0 : 1; - return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR; -#else - auto arg = fcntl(s, F_GETFL, nullptr); - if (arg < 0) { - return false; - } - arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK); - return fcntl(s, F_SETFL, arg) >= 0; -#endif - } - - bool errored() { - SOCKET s = socket(); - if (s == InvalidSocket) { - return true; - } - - char error = 0; - socklen_t len = sizeof(error); - getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len); - if (error != 0) { - sock.compare_exchange_weak(s, InvalidSocket); - return true; - } - + WLock lock(mutex); + s = InvalidSocket; return false; } - // dap::ReaderWriter compliance - bool isOpen() { return !errored(); } - void close() { - SOCKET s = sock.exchange(InvalidSocket); +#if !defined(_WIN32) + { + RLock l(mutex); + if (s != InvalidSocket) { + ::shutdown(s, SHUT_RDWR); + } + } +#endif + + WLock l(mutex); if (s != InvalidSocket) { #if defined(_WIN32) closesocket(s); #else - ::shutdown(s, SHUT_RDWR); ::close(s); #endif + s = InvalidSocket; } } size_t read(void* buffer, size_t bytes) { - SOCKET s = socket(); + RLock lock(mutex); if (s == InvalidSocket) { return 0; } @@ -187,7 +200,7 @@ class dap::Socket::Shared : public dap::ReaderWriter { } bool write(const void* buffer, size_t bytes) { - SOCKET s = socket(); + RLock lock(mutex); if (s == InvalidSocket) { return false; } @@ -198,44 +211,43 @@ class dap::Socket::Shared : public dap::ReaderWriter { static_cast(bytes), 0) > 0; } - addrinfo* const info; - private: - std::atomic sock = {InvalidSocket}; + addrinfo* const info; + SOCKET s = InvalidSocket; + RWMutex mutex; }; namespace dap { Socket::Socket(const char* address, const char* port) : shared(Shared::create(address, port)) { - if (!shared) { - return; - } - auto socket = shared->socket(); + if (shared) { + shared->lock([&](SOCKET socket, const addrinfo* info) { + if (bind(socket, info->ai_addr, (int)info->ai_addrlen) != 0) { + shared.reset(); + return; + } - if (bind(socket, shared->info->ai_addr, (int)shared->info->ai_addrlen) != 0) { - shared.reset(); - return; - } - - if (listen(socket, 0) != 0) { - shared.reset(); - return; + if (listen(socket, 0) != 0) { + shared.reset(); + return; + } + }); } } std::shared_ptr Socket::accept() const { + std::shared_ptr out; if (shared) { - SOCKET socket = shared->socket(); - if (socket != InvalidSocket) { - init(); - auto out = std::make_shared(::accept(socket, 0, 0)); - out->setOptions(); - return out; - } + shared->lock([&](SOCKET socket, const addrinfo*) { + if (socket != InvalidSocket) { + init(); + out = std::make_shared(::accept(socket, 0, 0)); + out->setOptions(); + } + }); } - - return {}; + return out; } bool Socket::isOpen() const { @@ -259,47 +271,50 @@ std::shared_ptr Socket::connect(const char* address, return nullptr; } - if (timeoutMillis == 0) { - if (::connect(shared->socket(), shared->info->ai_addr, - (int)shared->info->ai_addrlen) == 0) { - return shared; + std::shared_ptr out; + shared->lock([&](SOCKET socket, const addrinfo* info) { + if (socket == InvalidSocket) { + return; } + + if (timeoutMillis == 0) { + if (::connect(socket, info->ai_addr, (int)info->ai_addrlen) == 0) { + out = shared; + } + return; + } + + if (!setBlocking(socket, false)) { + return; + } + + auto res = ::connect(socket, info->ai_addr, (int)info->ai_addrlen); + if (res == 0) { + if (setBlocking(socket, true)) { + out = shared; + } + } else { + const auto microseconds = timeoutMillis * 1000; + + fd_set fdset; + FD_ZERO(&fdset); + FD_SET(socket, &fdset); + + timeval tv; + tv.tv_sec = microseconds / 1000000; + tv.tv_usec = microseconds - (tv.tv_sec * 1000000); + res = select(static_cast(socket + 1), nullptr, &fdset, nullptr, &tv); + if (res > 0 && !errored(socket) && setBlocking(socket, true)) { + out = shared; + } + } + }); + + if (!out) { return nullptr; } - auto s = shared->socket(); - if (s == InvalidSocket) { - return nullptr; - } - - if (!shared->setBlocking(false)) { - return nullptr; - } - - auto res = ::connect(s, shared->info->ai_addr, (int)shared->info->ai_addrlen); - if (res == 0) { - return shared->setBlocking(true) ? shared : nullptr; - } - - const auto microseconds = timeoutMillis * 1000; - - fd_set fdset; - FD_ZERO(&fdset); - FD_SET(s, &fdset); - - timeval tv; - tv.tv_sec = microseconds / 1000000; - tv.tv_usec = microseconds - (tv.tv_sec * 1000000); - res = select(static_cast(s + 1), nullptr, &fdset, nullptr, &tv); - if (res <= 0) { - return nullptr; - } - - if (shared->errored()) { - return nullptr; - } - - return shared->setBlocking(true) ? shared : nullptr; + return out->isOpen() ? out : nullptr; } } // namespace dap