Implement timeouts for dap::Socket::connect

Fixes: #24
This commit is contained in:
Ben Clayton 2020-06-05 12:47:38 +01:00
parent 261d62d91e
commit 13c9e7d465
6 changed files with 161 additions and 34 deletions

View File

@ -195,6 +195,7 @@ if(CPPDAP_BUILD_TESTS)
${CPPDAP_SRC_DIR}/network_test.cpp ${CPPDAP_SRC_DIR}/network_test.cpp
${CPPDAP_SRC_DIR}/optional_test.cpp ${CPPDAP_SRC_DIR}/optional_test.cpp
${CPPDAP_SRC_DIR}/session_test.cpp ${CPPDAP_SRC_DIR}/session_test.cpp
${CPPDAP_SRC_DIR}/socket_test.cpp
${CPPDAP_SRC_DIR}/typeinfo_test.cpp ${CPPDAP_SRC_DIR}/typeinfo_test.cpp
${CPPDAP_SRC_DIR}/variant_test.cpp ${CPPDAP_SRC_DIR}/variant_test.cpp
${CPPDAP_GOOGLETEST_DIR}/googletest/src/gtest-all.cc ${CPPDAP_GOOGLETEST_DIR}/googletest/src/gtest-all.cc

View File

@ -24,12 +24,16 @@ class ReaderWriter;
namespace net { namespace net {
// connect() connects to the given TCP address and port. // connect() connects to the given TCP address and port.
std::shared_ptr<ReaderWriter> connect(const char* addr, int port); // If timeoutMillis is non-zero and no connection was made before timeoutMillis
// milliseconds, then nullptr is returned.
std::shared_ptr<ReaderWriter> connect(const char* addr,
int port,
uint32_t timeoutMillis = 0);
// Server implements a basic TCP server. // Server implements a basic TCP server.
class Server { class Server {
// IgnoreErrors matches the OnError signature, and does nothing. // ignoreErrors() matches the OnError signature, and does nothing.
static inline void IgnoreErrors(const char*) {} static inline void ignoreErrors(const char*) {}
public: public:
using OnError = std::function<void(const char*)>; using OnError = std::function<void(const char*)>;
@ -45,7 +49,7 @@ class Server {
// onError will be called for any connection errors. // onError will be called for any connection errors.
virtual bool start(int port, virtual bool start(int port,
const OnConnect& callback, const OnConnect& callback,
const OnError& onError = IgnoreErrors) = 0; const OnError& onError = ignoreErrors) = 0;
// stop() stops listening for connections. // stop() stops listening for connections.
// stop() is implicitly called on destruction. // stop() is implicitly called on destruction.

View File

@ -92,8 +92,10 @@ std::unique_ptr<Server> Server::create() {
return std::unique_ptr<Server>(new Impl()); return std::unique_ptr<Server>(new Impl());
} }
std::shared_ptr<ReaderWriter> connect(const char* addr, int port) { std::shared_ptr<ReaderWriter> connect(const char* addr,
return Socket::connect(addr, std::to_string(port).c_str()); int port,
uint32_t timeoutMillis) {
return Socket::connect(addr, std::to_string(port).c_str(), timeoutMillis);
} }
} // namespace net } // namespace net

View File

@ -32,6 +32,8 @@ namespace {
std::atomic<int> wsaInitCount = {0}; std::atomic<int> wsaInitCount = {0};
} // anonymous namespace } // anonymous namespace
#else #else
#include <fcntl.h>
#include <unistd.h>
namespace { namespace {
using SOCKET = int; using SOCKET = int;
} // anonymous namespace } // anonymous namespace
@ -39,27 +41,27 @@ using SOCKET = int;
namespace { namespace {
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1); constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
static void init() {
#if defined(_WIN32)
if (wsaInitCount++ == 0) {
WSADATA winsockData;
(void)WSAStartup(MAKEWORD(2, 2), &winsockData);
}
#endif
}
static void term() {
#if defined(_WIN32)
if (--wsaInitCount == 0) {
WSACleanup();
}
#endif
}
} // anonymous namespace } // anonymous namespace
class dap::Socket::Shared : public dap::ReaderWriter { class dap::Socket::Shared : public dap::ReaderWriter {
public: public:
static void init() {
#if defined(_WIN32)
if (wsaInitCount++ == 0) {
WSADATA winsockData;
(void)WSAStartup(MAKEWORD(2, 2), &winsockData);
}
#endif
}
static void term() {
#if defined(_WIN32)
if (--wsaInitCount == 0) {
WSACleanup();
}
#endif
}
static std::shared_ptr<Shared> create(const char* address, const char* port) { static std::shared_ptr<Shared> create(const char* address, const char* port) {
init(); init();
@ -123,24 +125,45 @@ class dap::Socket::Shared : public dap::ReaderWriter {
setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable)); setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable));
} }
// dap::ReaderWriter compliance bool setBlocking(bool blocking) {
bool isOpen() {
SOCKET s = socket(); SOCKET s = socket();
if (s == InvalidSocket) { if (s == InvalidSocket) {
return false; return false;
} }
#if defined(_WIN32)
u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
#else
auto arg = fcntl(s, F_GETFL, nullptr);
if (arg < 0) {
return false;
}
arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
return fcntl(s, F_SETFL, arg) >= 0;
#endif
}
bool errored() {
SOCKET s = socket();
if (s == InvalidSocket) {
return true;
}
char error = 0; char error = 0;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len); getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
if (error != 0) { if (error != 0) {
sock.compare_exchange_weak(s, InvalidSocket); sock.compare_exchange_weak(s, InvalidSocket);
return false; return true;
} }
return true; return false;
} }
// dap::ReaderWriter compliance
bool isOpen() { return !errored(); }
void close() { void close() {
SOCKET s = sock.exchange(InvalidSocket); SOCKET s = sock.exchange(InvalidSocket);
if (s != InvalidSocket) { if (s != InvalidSocket) {
@ -195,7 +218,7 @@ Socket::Socket(const char* address, const char* port)
return; return;
} }
if (listen(socket, 1) != 0) { if (listen(socket, 0) != 0) {
shared.reset(); shared.reset();
return; return;
} }
@ -205,6 +228,7 @@ std::shared_ptr<ReaderWriter> Socket::accept() const {
if (shared) { if (shared) {
SOCKET socket = shared->socket(); SOCKET socket = shared->socket();
if (socket != InvalidSocket) { if (socket != InvalidSocket) {
init();
auto out = std::make_shared<Shared>(::accept(socket, 0, 0)); auto out = std::make_shared<Shared>(::accept(socket, 0, 0));
out->setOptions(); out->setOptions();
return out; return out;
@ -228,13 +252,54 @@ void Socket::close() const {
} }
std::shared_ptr<ReaderWriter> Socket::connect(const char* address, std::shared_ptr<ReaderWriter> Socket::connect(const char* address,
const char* port) { const char* port,
uint32_t timeoutMillis) {
auto shared = Shared::create(address, port); auto shared = Shared::create(address, port);
if (::connect(shared->socket(), shared->info->ai_addr, if (!shared) {
(int)shared->info->ai_addrlen) == 0) { return nullptr;
return shared;
} }
return {};
if (timeoutMillis == 0) {
if (::connect(shared->socket(), shared->info->ai_addr,
(int)shared->info->ai_addrlen) == 0) {
return shared;
}
return nullptr;
}
auto s = shared->socket();
if (s == InvalidSocket) {
return nullptr;
}
if (!shared->setBlocking(false)) {
return nullptr;
}
auto res = ::connect(s, shared->info->ai_addr, (int)shared->info->ai_addrlen);
if (res == 0) {
return shared->setBlocking(true) ? shared : nullptr;
}
const auto microseconds = timeoutMillis * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(s, &fdset);
timeval tv;
tv.tv_sec = microseconds / 1000000;
tv.tv_usec = microseconds - (tv.tv_sec * 1000000);
res = select(static_cast<int>(s + 1), nullptr, &fdset, nullptr, &tv);
if (res <= 0) {
return nullptr;
}
if (shared->errored()) {
return nullptr;
}
return shared->setBlocking(true) ? shared : nullptr;
} }
} // namespace dap } // namespace dap

View File

@ -26,8 +26,12 @@ class Socket {
public: public:
class Shared; class Shared;
// connect() connects to the given TCP address and port.
// If timeoutMillis is non-zero and no connection was made before
// timeoutMillis milliseconds, then nullptr is returned.
static std::shared_ptr<ReaderWriter> connect(const char* address, static std::shared_ptr<ReaderWriter> connect(const char* address,
const char* port); const char* port,
uint32_t timeoutMillis);
Socket(const char* address, const char* port); Socket(const char* address, const char* port);
bool isOpen() const; bool isOpen() const;

51
src/socket_test.cpp Normal file
View File

@ -0,0 +1,51 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "socket.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <chrono>
#include <vector>
TEST(Socket, ConnectTimeout) {
const char* port = "19021";
const int timeoutMillis = 200;
const int maxAttempts = 1024;
using namespace std::chrono;
auto server = dap::Socket("localhost", port);
std::vector<std::shared_ptr<dap::ReaderWriter>> connections;
for (int i = 0; i < maxAttempts; i++) {
auto start = system_clock::now();
auto connection = dap::Socket::connect("localhost", port, timeoutMillis);
auto end = system_clock::now();
if (!connection) {
auto timeTakenMillis = duration_cast<milliseconds>(end - start).count();
ASSERT_GE(timeTakenMillis + 20, // +20ms for a bit of timing wiggle room
timeoutMillis);
return;
}
// Keep hold of the connections to saturate any incoming socket buffers.
connections.emplace_back(std::move(connection));
}
FAIL() << "Failed to test timeout after " << maxAttempts << " attempts";
}