Socket: Use the RWMutex to fix TSAN error
... about `close()`ing the socket on one thread while in a blocking `recv()` or `send()` call on another thread. Fixes #35
This commit is contained in:
parent
53a62fd794
commit
9003ee55b9
227
src/socket.cpp
227
src/socket.cpp
@ -14,6 +14,8 @@
|
|||||||
|
|
||||||
#include "socket.h"
|
#include "socket.h"
|
||||||
|
|
||||||
|
#include "rwmutex.h"
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
#include <winsock2.h>
|
#include <winsock2.h>
|
||||||
#include <ws2tcpip.h>
|
#include <ws2tcpip.h>
|
||||||
@ -41,7 +43,7 @@ using SOCKET = int;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
|
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
|
||||||
static void init() {
|
void init() {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
if (wsaInitCount++ == 0) {
|
if (wsaInitCount++ == 0) {
|
||||||
WSADATA winsockData;
|
WSADATA winsockData;
|
||||||
@ -50,7 +52,7 @@ static void init() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void term() {
|
void term() {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
if (--wsaInitCount == 0) {
|
if (--wsaInitCount == 0) {
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
@ -58,6 +60,30 @@ static void term() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool setBlocking(SOCKET s, bool blocking) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
u_long mode = blocking ? 0 : 1;
|
||||||
|
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
|
||||||
|
#else
|
||||||
|
auto arg = fcntl(s, F_GETFL, nullptr);
|
||||||
|
if (arg < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
|
||||||
|
return fcntl(s, F_SETFL, arg) >= 0;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
bool errored(SOCKET s) {
|
||||||
|
if (s == InvalidSocket) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
char error = 0;
|
||||||
|
socklen_t len = sizeof(error);
|
||||||
|
getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
|
||||||
|
return error != 0;
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
class dap::Socket::Shared : public dap::ReaderWriter {
|
class dap::Socket::Shared : public dap::ReaderWriter {
|
||||||
@ -87,8 +113,8 @@ class dap::Socket::Shared : public dap::ReaderWriter {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Shared(SOCKET socket) : info(nullptr), sock(socket) {}
|
Shared(SOCKET socket) : info(nullptr), s(socket) {}
|
||||||
Shared(addrinfo* info, SOCKET socket) : info(info), sock(socket) {}
|
Shared(addrinfo* info, SOCKET socket) : info(info), s(socket) {}
|
||||||
|
|
||||||
~Shared() {
|
~Shared() {
|
||||||
freeaddrinfo(info);
|
freeaddrinfo(info);
|
||||||
@ -96,10 +122,14 @@ class dap::Socket::Shared : public dap::ReaderWriter {
|
|||||||
term();
|
term();
|
||||||
}
|
}
|
||||||
|
|
||||||
SOCKET socket() { return sock.load(); }
|
template <typename FUNCTION>
|
||||||
|
void lock(FUNCTION&& f) {
|
||||||
|
RLock l(mutex);
|
||||||
|
f(s, info);
|
||||||
|
}
|
||||||
|
|
||||||
void setOptions() {
|
void setOptions() {
|
||||||
SOCKET s = socket();
|
RLock l(mutex);
|
||||||
if (s == InvalidSocket) {
|
if (s == InvalidSocket) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -125,59 +155,42 @@ class dap::Socket::Shared : public dap::ReaderWriter {
|
|||||||
setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable));
|
setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool setBlocking(bool blocking) {
|
// dap::ReaderWriter compliance
|
||||||
SOCKET s = socket();
|
bool isOpen() {
|
||||||
if (s == InvalidSocket) {
|
{
|
||||||
return false;
|
RLock l(mutex);
|
||||||
|
if ((s != InvalidSocket) && !errored(s)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
WLock lock(mutex);
|
||||||
#if defined(_WIN32)
|
s = InvalidSocket;
|
||||||
u_long mode = blocking ? 0 : 1;
|
|
||||||
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
|
|
||||||
#else
|
|
||||||
auto arg = fcntl(s, F_GETFL, nullptr);
|
|
||||||
if (arg < 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
|
|
||||||
return fcntl(s, F_SETFL, arg) >= 0;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
bool errored() {
|
|
||||||
SOCKET s = socket();
|
|
||||||
if (s == InvalidSocket) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
char error = 0;
|
|
||||||
socklen_t len = sizeof(error);
|
|
||||||
getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
|
|
||||||
if (error != 0) {
|
|
||||||
sock.compare_exchange_weak(s, InvalidSocket);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// dap::ReaderWriter compliance
|
|
||||||
bool isOpen() { return !errored(); }
|
|
||||||
|
|
||||||
void close() {
|
void close() {
|
||||||
SOCKET s = sock.exchange(InvalidSocket);
|
#if !defined(_WIN32)
|
||||||
|
{
|
||||||
|
RLock l(mutex);
|
||||||
|
if (s != InvalidSocket) {
|
||||||
|
::shutdown(s, SHUT_RDWR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
WLock l(mutex);
|
||||||
if (s != InvalidSocket) {
|
if (s != InvalidSocket) {
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
closesocket(s);
|
closesocket(s);
|
||||||
#else
|
#else
|
||||||
::shutdown(s, SHUT_RDWR);
|
|
||||||
::close(s);
|
::close(s);
|
||||||
#endif
|
#endif
|
||||||
|
s = InvalidSocket;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t read(void* buffer, size_t bytes) {
|
size_t read(void* buffer, size_t bytes) {
|
||||||
SOCKET s = socket();
|
RLock lock(mutex);
|
||||||
if (s == InvalidSocket) {
|
if (s == InvalidSocket) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -187,7 +200,7 @@ class dap::Socket::Shared : public dap::ReaderWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool write(const void* buffer, size_t bytes) {
|
bool write(const void* buffer, size_t bytes) {
|
||||||
SOCKET s = socket();
|
RLock lock(mutex);
|
||||||
if (s == InvalidSocket) {
|
if (s == InvalidSocket) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -198,44 +211,43 @@ class dap::Socket::Shared : public dap::ReaderWriter {
|
|||||||
static_cast<int>(bytes), 0) > 0;
|
static_cast<int>(bytes), 0) > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
addrinfo* const info;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::atomic<SOCKET> sock = {InvalidSocket};
|
addrinfo* const info;
|
||||||
|
SOCKET s = InvalidSocket;
|
||||||
|
RWMutex mutex;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace dap {
|
namespace dap {
|
||||||
|
|
||||||
Socket::Socket(const char* address, const char* port)
|
Socket::Socket(const char* address, const char* port)
|
||||||
: shared(Shared::create(address, port)) {
|
: shared(Shared::create(address, port)) {
|
||||||
if (!shared) {
|
if (shared) {
|
||||||
return;
|
shared->lock([&](SOCKET socket, const addrinfo* info) {
|
||||||
}
|
if (bind(socket, info->ai_addr, (int)info->ai_addrlen) != 0) {
|
||||||
auto socket = shared->socket();
|
shared.reset();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (bind(socket, shared->info->ai_addr, (int)shared->info->ai_addrlen) != 0) {
|
if (listen(socket, 0) != 0) {
|
||||||
shared.reset();
|
shared.reset();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
if (listen(socket, 0) != 0) {
|
|
||||||
shared.reset();
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ReaderWriter> Socket::accept() const {
|
std::shared_ptr<ReaderWriter> Socket::accept() const {
|
||||||
|
std::shared_ptr<Shared> out;
|
||||||
if (shared) {
|
if (shared) {
|
||||||
SOCKET socket = shared->socket();
|
shared->lock([&](SOCKET socket, const addrinfo*) {
|
||||||
if (socket != InvalidSocket) {
|
if (socket != InvalidSocket) {
|
||||||
init();
|
init();
|
||||||
auto out = std::make_shared<Shared>(::accept(socket, 0, 0));
|
out = std::make_shared<Shared>(::accept(socket, 0, 0));
|
||||||
out->setOptions();
|
out->setOptions();
|
||||||
return out;
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
return out;
|
||||||
return {};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Socket::isOpen() const {
|
bool Socket::isOpen() const {
|
||||||
@ -259,47 +271,50 @@ std::shared_ptr<ReaderWriter> Socket::connect(const char* address,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (timeoutMillis == 0) {
|
std::shared_ptr<ReaderWriter> out;
|
||||||
if (::connect(shared->socket(), shared->info->ai_addr,
|
shared->lock([&](SOCKET socket, const addrinfo* info) {
|
||||||
(int)shared->info->ai_addrlen) == 0) {
|
if (socket == InvalidSocket) {
|
||||||
return shared;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (timeoutMillis == 0) {
|
||||||
|
if (::connect(socket, info->ai_addr, (int)info->ai_addrlen) == 0) {
|
||||||
|
out = shared;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!setBlocking(socket, false)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = ::connect(socket, info->ai_addr, (int)info->ai_addrlen);
|
||||||
|
if (res == 0) {
|
||||||
|
if (setBlocking(socket, true)) {
|
||||||
|
out = shared;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const auto microseconds = timeoutMillis * 1000;
|
||||||
|
|
||||||
|
fd_set fdset;
|
||||||
|
FD_ZERO(&fdset);
|
||||||
|
FD_SET(socket, &fdset);
|
||||||
|
|
||||||
|
timeval tv;
|
||||||
|
tv.tv_sec = microseconds / 1000000;
|
||||||
|
tv.tv_usec = microseconds - (tv.tv_sec * 1000000);
|
||||||
|
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
|
||||||
|
if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
|
||||||
|
out = shared;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!out) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto s = shared->socket();
|
return out->isOpen() ? out : nullptr;
|
||||||
if (s == InvalidSocket) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!shared->setBlocking(false)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto res = ::connect(s, shared->info->ai_addr, (int)shared->info->ai_addrlen);
|
|
||||||
if (res == 0) {
|
|
||||||
return shared->setBlocking(true) ? shared : nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto microseconds = timeoutMillis * 1000;
|
|
||||||
|
|
||||||
fd_set fdset;
|
|
||||||
FD_ZERO(&fdset);
|
|
||||||
FD_SET(s, &fdset);
|
|
||||||
|
|
||||||
timeval tv;
|
|
||||||
tv.tv_sec = microseconds / 1000000;
|
|
||||||
tv.tv_usec = microseconds - (tv.tv_sec * 1000000);
|
|
||||||
res = select(static_cast<int>(s + 1), nullptr, &fdset, nullptr, &tv);
|
|
||||||
if (res <= 0) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shared->errored()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return shared->setBlocking(true) ? shared : nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dap
|
} // namespace dap
|
||||||
|
Loading…
x
Reference in New Issue
Block a user