Skip to content

Commit

Permalink
Merge pull request #378 from CopernicaMarketingSoftware/poll-replace-…
Browse files Browse the repository at this point in the history
…select

select only supports upto fd 1024, which can cause stack smashing if using higher ones
  • Loading branch information
EmielBruijntjes authored Oct 30, 2020
2 parents 79d8839 + 537ee3f commit 7c07ab1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 66 deletions.
79 changes: 30 additions & 49 deletions src/linux_tcp/poll.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
*/
#pragma once

/**
* Includes
*/
#include <poll.h>

/**
* Begin of namespace
*/
Expand All @@ -26,7 +31,7 @@ class Poll
* Set with just one filedescriptor
* @var fd_set
*/
fd_set _set;
pollfd _fd;

/**
* The socket filedescriptor
Expand All @@ -39,13 +44,10 @@ class Poll
* Constructor
* @param fd the filedescriptor that we're waiting on
*/
Poll(int fd) : _socket(fd)
Poll(int fd)
{
// initialize the set
FD_ZERO(&_set);

// add the one socket
FD_SET(_socket, &_set);
// set the fd
_fd.fd = fd;
}

/**
Expand All @@ -60,66 +62,45 @@ class Poll
virtual ~Poll() = default;

/**
* Wait until the filedescriptor becomes readable
* @param block block until readable
* Check if a file descriptor is readable.
* @return bool
*/
bool readable(bool block)
{
// wait for the socket
if (block) return select(_socket + 1, &_set, nullptr, nullptr, nullptr) > 0;

// we do not want to block, so we use a small timeout
struct timeval timeout;

// no timeout at all
timeout.tv_sec = timeout.tv_usec = 0;

// no timeout at all
return select(_socket + 1, &_set, nullptr, nullptr, &timeout) > 0;
bool readable()
{
// check for readable
_fd.events = POLLIN;
_fd.revents = 0;

// poll the fd with no timeout
return poll(&_fd, 1, 0) > 0;
}

/**
* Wait until the filedescriptor becomes writable
* @param block block until readable
* @return bool
*/
bool writable(bool block)
bool writable()
{
// wait for the socket
if (block) return select(_socket + 1, nullptr, &_set, nullptr, nullptr) > 0;
// check for readable
_fd.events = POLLOUT;
_fd.revents = 0;

// we do not want to block, so we use a small timeout
struct timeval timeout;

// no timeout at all
timeout.tv_sec = timeout.tv_usec = 0;

// no timeout at all
return select(_socket + 1, nullptr, &_set, nullptr, &timeout) > 0;
// poll the fd with no timeout
return poll(&_fd, 1, 0) > 0;
}

/**
* Wait until a filedescriptor becomes active (readable or writable)
* @param block block until readable
* @return bool
*/
bool active(bool block)
bool active()
{
// accommodate restrict qualifier on fd_set params
fd_set set2 = _set;
// check for readable
_fd.events = POLLIN | POLLOUT;
_fd.revents = 0;

// wait for the socket
if (block) return select(_socket + 1, &_set, &set2, nullptr, nullptr) > 0;

// we do not want to block, so we use a small timeout
struct timeval timeout;

// no timeout at all
timeout.tv_sec = timeout.tv_usec = 0;

// no timeout at all
return select(_socket + 1, &_set, &set2, nullptr, &timeout) > 0;
// poll the fd with no timeout
return poll(&_fd, 1, 0) > 0;
}
};

Expand Down
8 changes: 4 additions & 4 deletions src/linux_tcp/sslconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ class SslConnected : public TcpExtState
// object to poll a socket
Poll poll(_socket);

// wait until socket is readable, but do not block
return poll.readable(false);
// check if socket is readable
return poll.readable();
}

/**
Expand All @@ -233,8 +233,8 @@ class SslConnected : public TcpExtState
// object to poll a socket
Poll poll(_socket);

// wait until socket is writable, but do not block
return poll.writable(false);
// check if socket is writable
return poll.writable();
}

/**
Expand Down
23 changes: 10 additions & 13 deletions src/linux_tcp/tcpresolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "sslhandshake.h"
#include <thread>
#include <netinet/in.h>
#include <poll.h>

/**
* Set up namespace
Expand Down Expand Up @@ -99,8 +100,8 @@ class TcpResolver : public TcpExtState
// get address info
AddressInfo addresses(_hostname.data(), _port);

// an fdset to monitor for writability
fd_set writeset;
// the pollfd structure, needed for poll()
pollfd fd;

// iterate over the addresses
for (size_t i = 0; i < addresses.size(); ++i)
Expand All @@ -117,17 +118,13 @@ class TcpResolver : public TcpExtState
// try to connect non-blocking
if (connect(_socket, addresses[i]->ai_addr, addresses[i]->ai_addrlen) == 0) break;

// we set the timeout to a timeout, with 5 seconds as the default
struct timeval timeout{_timeout,0};

// reset the fdset
FD_ZERO(&writeset);

// set the fd to monitor for writing
FD_SET(_socket, &writeset);

// perform a select, wait for something to happen on one of the fds
int ret = select(_socket + 1, nullptr, &writeset, nullptr, &timeout);
// set the struct members
fd.fd = _socket;
fd.events = POLLOUT;
fd.revents = 0;

// perform the poll, with a very long time to allow the event to occur
int ret = poll(&fd, 1, _timeout * 1000);

// log the error for the time being
if (ret == 0) _error = "connection timed out";
Expand Down

0 comments on commit 7c07ab1

Please sign in to comment.