Socket: Use the RWMutex to fix TSAN error

... about `close()`ing the socket on one thread while in a blocking `recv()` or `send()` call on another thread.

Fixes #35
This commit is contained in:
Ben Clayton 2020-06-08 18:31:21 +01:00
parent 53a62fd794
commit 9003ee55b9

View File

@ -14,6 +14,8 @@
#include "socket.h" #include "socket.h"
#include "rwmutex.h"
#if defined(_WIN32) #if defined(_WIN32)
#include <winsock2.h> #include <winsock2.h>
#include <ws2tcpip.h> #include <ws2tcpip.h>
@ -41,7 +43,7 @@ using SOCKET = int;
namespace { namespace {
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1); constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
static void init() { void init() {
#if defined(_WIN32) #if defined(_WIN32)
if (wsaInitCount++ == 0) { if (wsaInitCount++ == 0) {
WSADATA winsockData; WSADATA winsockData;
@ -50,7 +52,7 @@ static void init() {
#endif #endif
} }
static void term() { void term() {
#if defined(_WIN32) #if defined(_WIN32)
if (--wsaInitCount == 0) { if (--wsaInitCount == 0) {
WSACleanup(); WSACleanup();
@ -58,6 +60,30 @@ static void term() {
#endif #endif
} }
bool setBlocking(SOCKET s, bool blocking) {
#if defined(_WIN32)
u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
#else
auto arg = fcntl(s, F_GETFL, nullptr);
if (arg < 0) {
return false;
}
arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
return fcntl(s, F_SETFL, arg) >= 0;
#endif
}
bool errored(SOCKET s) {
if (s == InvalidSocket) {
return true;
}
char error = 0;
socklen_t len = sizeof(error);
getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
return error != 0;
}
} // anonymous namespace } // anonymous namespace
class dap::Socket::Shared : public dap::ReaderWriter { class dap::Socket::Shared : public dap::ReaderWriter {
@ -87,8 +113,8 @@ class dap::Socket::Shared : public dap::ReaderWriter {
return nullptr; return nullptr;
} }
Shared(SOCKET socket) : info(nullptr), sock(socket) {} Shared(SOCKET socket) : info(nullptr), s(socket) {}
Shared(addrinfo* info, SOCKET socket) : info(info), sock(socket) {} Shared(addrinfo* info, SOCKET socket) : info(info), s(socket) {}
~Shared() { ~Shared() {
freeaddrinfo(info); freeaddrinfo(info);
@ -96,10 +122,14 @@ class dap::Socket::Shared : public dap::ReaderWriter {
term(); term();
} }
SOCKET socket() { return sock.load(); } template <typename FUNCTION>
void lock(FUNCTION&& f) {
RLock l(mutex);
f(s, info);
}
void setOptions() { void setOptions() {
SOCKET s = socket(); RLock l(mutex);
if (s == InvalidSocket) { if (s == InvalidSocket) {
return; return;
} }
@ -125,59 +155,42 @@ class dap::Socket::Shared : public dap::ReaderWriter {
setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable)); setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable));
} }
bool setBlocking(bool blocking) { // dap::ReaderWriter compliance
SOCKET s = socket(); bool isOpen() {
if (s == InvalidSocket) { {
return false; RLock l(mutex);
if ((s != InvalidSocket) && !errored(s)) {
return true;
}
} }
WLock lock(mutex);
#if defined(_WIN32) s = InvalidSocket;
u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
#else
auto arg = fcntl(s, F_GETFL, nullptr);
if (arg < 0) {
return false;
}
arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
return fcntl(s, F_SETFL, arg) >= 0;
#endif
}
bool errored() {
SOCKET s = socket();
if (s == InvalidSocket) {
return true;
}
char error = 0;
socklen_t len = sizeof(error);
getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
if (error != 0) {
sock.compare_exchange_weak(s, InvalidSocket);
return true;
}
return false; return false;
} }
// dap::ReaderWriter compliance
bool isOpen() { return !errored(); }
void close() { void close() {
SOCKET s = sock.exchange(InvalidSocket); #if !defined(_WIN32)
{
RLock l(mutex);
if (s != InvalidSocket) {
::shutdown(s, SHUT_RDWR);
}
}
#endif
WLock l(mutex);
if (s != InvalidSocket) { if (s != InvalidSocket) {
#if defined(_WIN32) #if defined(_WIN32)
closesocket(s); closesocket(s);
#else #else
::shutdown(s, SHUT_RDWR);
::close(s); ::close(s);
#endif #endif
s = InvalidSocket;
} }
} }
size_t read(void* buffer, size_t bytes) { size_t read(void* buffer, size_t bytes) {
SOCKET s = socket(); RLock lock(mutex);
if (s == InvalidSocket) { if (s == InvalidSocket) {
return 0; return 0;
} }
@ -187,7 +200,7 @@ class dap::Socket::Shared : public dap::ReaderWriter {
} }
bool write(const void* buffer, size_t bytes) { bool write(const void* buffer, size_t bytes) {
SOCKET s = socket(); RLock lock(mutex);
if (s == InvalidSocket) { if (s == InvalidSocket) {
return false; return false;
} }
@ -198,44 +211,43 @@ class dap::Socket::Shared : public dap::ReaderWriter {
static_cast<int>(bytes), 0) > 0; static_cast<int>(bytes), 0) > 0;
} }
addrinfo* const info;
private: private:
std::atomic<SOCKET> sock = {InvalidSocket}; addrinfo* const info;
SOCKET s = InvalidSocket;
RWMutex mutex;
}; };
namespace dap { namespace dap {
Socket::Socket(const char* address, const char* port) Socket::Socket(const char* address, const char* port)
: shared(Shared::create(address, port)) { : shared(Shared::create(address, port)) {
if (!shared) { if (shared) {
return; shared->lock([&](SOCKET socket, const addrinfo* info) {
} if (bind(socket, info->ai_addr, (int)info->ai_addrlen) != 0) {
auto socket = shared->socket(); shared.reset();
return;
}
if (bind(socket, shared->info->ai_addr, (int)shared->info->ai_addrlen) != 0) { if (listen(socket, 0) != 0) {
shared.reset(); shared.reset();
return; return;
} }
});
if (listen(socket, 0) != 0) {
shared.reset();
return;
} }
} }
std::shared_ptr<ReaderWriter> Socket::accept() const { std::shared_ptr<ReaderWriter> Socket::accept() const {
std::shared_ptr<Shared> out;
if (shared) { if (shared) {
SOCKET socket = shared->socket(); shared->lock([&](SOCKET socket, const addrinfo*) {
if (socket != InvalidSocket) { if (socket != InvalidSocket) {
init(); init();
auto out = std::make_shared<Shared>(::accept(socket, 0, 0)); out = std::make_shared<Shared>(::accept(socket, 0, 0));
out->setOptions(); out->setOptions();
return out; }
} });
} }
return out;
return {};
} }
bool Socket::isOpen() const { bool Socket::isOpen() const {
@ -259,47 +271,50 @@ std::shared_ptr<ReaderWriter> Socket::connect(const char* address,
return nullptr; return nullptr;
} }
if (timeoutMillis == 0) { std::shared_ptr<ReaderWriter> out;
if (::connect(shared->socket(), shared->info->ai_addr, shared->lock([&](SOCKET socket, const addrinfo* info) {
(int)shared->info->ai_addrlen) == 0) { if (socket == InvalidSocket) {
return shared; return;
} }
if (timeoutMillis == 0) {
if (::connect(socket, info->ai_addr, (int)info->ai_addrlen) == 0) {
out = shared;
}
return;
}
if (!setBlocking(socket, false)) {
return;
}
auto res = ::connect(socket, info->ai_addr, (int)info->ai_addrlen);
if (res == 0) {
if (setBlocking(socket, true)) {
out = shared;
}
} else {
const auto microseconds = timeoutMillis * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(socket, &fdset);
timeval tv;
tv.tv_sec = microseconds / 1000000;
tv.tv_usec = microseconds - (tv.tv_sec * 1000000);
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
out = shared;
}
}
});
if (!out) {
return nullptr; return nullptr;
} }
auto s = shared->socket(); return out->isOpen() ? out : nullptr;
if (s == InvalidSocket) {
return nullptr;
}
if (!shared->setBlocking(false)) {
return nullptr;
}
auto res = ::connect(s, shared->info->ai_addr, (int)shared->info->ai_addrlen);
if (res == 0) {
return shared->setBlocking(true) ? shared : nullptr;
}
const auto microseconds = timeoutMillis * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(s, &fdset);
timeval tv;
tv.tv_sec = microseconds / 1000000;
tv.tv_usec = microseconds - (tv.tv_sec * 1000000);
res = select(static_cast<int>(s + 1), nullptr, &fdset, nullptr, &tv);
if (res <= 0) {
return nullptr;
}
if (shared->errored()) {
return nullptr;
}
return shared->setBlocking(true) ? shared : nullptr;
} }
} // namespace dap } // namespace dap