From 13c9e7d4656f91ae1cb7a3798881dfd94bb2c790 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 5 Jun 2020 12:47:38 +0100 Subject: [PATCH] Implement timeouts for dap::Socket::connect Fixes: #24 --- CMakeLists.txt | 1 + include/dap/network.h | 12 +++-- src/network.cpp | 6 ++- src/socket.cpp | 119 ++++++++++++++++++++++++++++++++---------- src/socket.h | 6 ++- src/socket_test.cpp | 51 ++++++++++++++++++ 6 files changed, 161 insertions(+), 34 deletions(-) create mode 100644 src/socket_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e2f28f..9cdba73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,6 +195,7 @@ if(CPPDAP_BUILD_TESTS) ${CPPDAP_SRC_DIR}/network_test.cpp ${CPPDAP_SRC_DIR}/optional_test.cpp ${CPPDAP_SRC_DIR}/session_test.cpp + ${CPPDAP_SRC_DIR}/socket_test.cpp ${CPPDAP_SRC_DIR}/typeinfo_test.cpp ${CPPDAP_SRC_DIR}/variant_test.cpp ${CPPDAP_GOOGLETEST_DIR}/googletest/src/gtest-all.cc diff --git a/include/dap/network.h b/include/dap/network.h index d0d76e8..9d14f6b 100644 --- a/include/dap/network.h +++ b/include/dap/network.h @@ -24,12 +24,16 @@ class ReaderWriter; namespace net { // connect() connects to the given TCP address and port. -std::shared_ptr connect(const char* addr, int port); +// If timeoutMillis is non-zero and no connection was made before timeoutMillis +// milliseconds, then nullptr is returned. +std::shared_ptr connect(const char* addr, + int port, + uint32_t timeoutMillis = 0); // Server implements a basic TCP server. class Server { - // IgnoreErrors matches the OnError signature, and does nothing. - static inline void IgnoreErrors(const char*) {} + // ignoreErrors() matches the OnError signature, and does nothing. + static inline void ignoreErrors(const char*) {} public: using OnError = std::function; @@ -45,7 +49,7 @@ class Server { // onError will be called for any connection errors. virtual bool start(int port, const OnConnect& callback, - const OnError& onError = IgnoreErrors) = 0; + const OnError& onError = ignoreErrors) = 0; // stop() stops listening for connections. // stop() is implicitly called on destruction. diff --git a/src/network.cpp b/src/network.cpp index ff2a311..887d762 100644 --- a/src/network.cpp +++ b/src/network.cpp @@ -92,8 +92,10 @@ std::unique_ptr Server::create() { return std::unique_ptr(new Impl()); } -std::shared_ptr connect(const char* addr, int port) { - return Socket::connect(addr, std::to_string(port).c_str()); +std::shared_ptr connect(const char* addr, + int port, + uint32_t timeoutMillis) { + return Socket::connect(addr, std::to_string(port).c_str(), timeoutMillis); } } // namespace net diff --git a/src/socket.cpp b/src/socket.cpp index 506d7ae..653b384 100644 --- a/src/socket.cpp +++ b/src/socket.cpp @@ -32,6 +32,8 @@ namespace { std::atomic wsaInitCount = {0}; } // anonymous namespace #else +#include +#include namespace { using SOCKET = int; } // anonymous namespace @@ -39,27 +41,27 @@ using SOCKET = int; namespace { constexpr SOCKET InvalidSocket = static_cast(-1); +static void init() { +#if defined(_WIN32) + if (wsaInitCount++ == 0) { + WSADATA winsockData; + (void)WSAStartup(MAKEWORD(2, 2), &winsockData); + } +#endif +} + +static void term() { +#if defined(_WIN32) + if (--wsaInitCount == 0) { + WSACleanup(); + } +#endif +} + } // anonymous namespace class dap::Socket::Shared : public dap::ReaderWriter { public: - static void init() { -#if defined(_WIN32) - if (wsaInitCount++ == 0) { - WSADATA winsockData; - (void)WSAStartup(MAKEWORD(2, 2), &winsockData); - } -#endif - } - - static void term() { -#if defined(_WIN32) - if (--wsaInitCount == 0) { - WSACleanup(); - } -#endif - } - static std::shared_ptr create(const char* address, const char* port) { init(); @@ -123,24 +125,45 @@ class dap::Socket::Shared : public dap::ReaderWriter { setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (char*)&enable, sizeof(enable)); } - // dap::ReaderWriter compliance - bool isOpen() { + bool setBlocking(bool blocking) { SOCKET s = socket(); if (s == InvalidSocket) { return false; } +#if defined(_WIN32) + u_long mode = blocking ? 0 : 1; + return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR; +#else + auto arg = fcntl(s, F_GETFL, nullptr); + if (arg < 0) { + return false; + } + arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK); + return fcntl(s, F_SETFL, arg) >= 0; +#endif + } + + bool errored() { + SOCKET s = socket(); + if (s == InvalidSocket) { + return true; + } + char error = 0; socklen_t len = sizeof(error); getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len); if (error != 0) { sock.compare_exchange_weak(s, InvalidSocket); - return false; + return true; } - return true; + return false; } + // dap::ReaderWriter compliance + bool isOpen() { return !errored(); } + void close() { SOCKET s = sock.exchange(InvalidSocket); if (s != InvalidSocket) { @@ -195,7 +218,7 @@ Socket::Socket(const char* address, const char* port) return; } - if (listen(socket, 1) != 0) { + if (listen(socket, 0) != 0) { shared.reset(); return; } @@ -205,6 +228,7 @@ std::shared_ptr Socket::accept() const { if (shared) { SOCKET socket = shared->socket(); if (socket != InvalidSocket) { + init(); auto out = std::make_shared(::accept(socket, 0, 0)); out->setOptions(); return out; @@ -228,13 +252,54 @@ void Socket::close() const { } std::shared_ptr Socket::connect(const char* address, - const char* port) { + const char* port, + uint32_t timeoutMillis) { auto shared = Shared::create(address, port); - if (::connect(shared->socket(), shared->info->ai_addr, - (int)shared->info->ai_addrlen) == 0) { - return shared; + if (!shared) { + return nullptr; } - return {}; + + if (timeoutMillis == 0) { + if (::connect(shared->socket(), shared->info->ai_addr, + (int)shared->info->ai_addrlen) == 0) { + return shared; + } + return nullptr; + } + + auto s = shared->socket(); + if (s == InvalidSocket) { + return nullptr; + } + + if (!shared->setBlocking(false)) { + return nullptr; + } + + auto res = ::connect(s, shared->info->ai_addr, (int)shared->info->ai_addrlen); + if (res == 0) { + return shared->setBlocking(true) ? shared : nullptr; + } + + const auto microseconds = timeoutMillis * 1000; + + fd_set fdset; + FD_ZERO(&fdset); + FD_SET(s, &fdset); + + timeval tv; + tv.tv_sec = microseconds / 1000000; + tv.tv_usec = microseconds - (tv.tv_sec * 1000000); + res = select(static_cast(s + 1), nullptr, &fdset, nullptr, &tv); + if (res <= 0) { + return nullptr; + } + + if (shared->errored()) { + return nullptr; + } + + return shared->setBlocking(true) ? shared : nullptr; } } // namespace dap diff --git a/src/socket.h b/src/socket.h index ea722c6..ec5b0df 100644 --- a/src/socket.h +++ b/src/socket.h @@ -26,8 +26,12 @@ class Socket { public: class Shared; + // connect() connects to the given TCP address and port. + // If timeoutMillis is non-zero and no connection was made before + // timeoutMillis milliseconds, then nullptr is returned. static std::shared_ptr connect(const char* address, - const char* port); + const char* port, + uint32_t timeoutMillis); Socket(const char* address, const char* port); bool isOpen() const; diff --git a/src/socket_test.cpp b/src/socket_test.cpp new file mode 100644 index 0000000..c219bf8 --- /dev/null +++ b/src/socket_test.cpp @@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "socket.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include +#include + +TEST(Socket, ConnectTimeout) { + const char* port = "19021"; + const int timeoutMillis = 200; + const int maxAttempts = 1024; + + using namespace std::chrono; + + auto server = dap::Socket("localhost", port); + + std::vector> connections; + + for (int i = 0; i < maxAttempts; i++) { + auto start = system_clock::now(); + auto connection = dap::Socket::connect("localhost", port, timeoutMillis); + auto end = system_clock::now(); + + if (!connection) { + auto timeTakenMillis = duration_cast(end - start).count(); + ASSERT_GE(timeTakenMillis + 20, // +20ms for a bit of timing wiggle room + timeoutMillis); + return; + } + + // Keep hold of the connections to saturate any incoming socket buffers. + connections.emplace_back(std::move(connection)); + } + + FAIL() << "Failed to test timeout after " << maxAttempts << " attempts"; +}