Skip to content

Commit

Permalink
Drop all data in websocket buffer after close frame
Browse files Browse the repository at this point in the history
Otherwise if there are still bytes, it will assume they define a frame
header and start to parse it.
  • Loading branch information
halfgaar committed Oct 29, 2023
1 parent b41b9d8 commit 34dc32b
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 0 deletions.
1 change: 1 addition & 0 deletions FlashMQTests/tst_maintests.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ private slots:
void testWebsocketCorruptLengthFrame();
void testWebsocketHugePing();
void testWebsocketManyBigPingFrames();
void testWebsocketClose();
};


Expand Down
107 changes: 107 additions & 0 deletions FlashMQTests/websockettests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,3 +638,110 @@ void MainTests::testWebsocketManyBigPingFrames()
QVERIFY2(false, ex.what());
}
}

void MainTests::testWebsocketClose()
{
try
{
Settings settings;
PluginLoader pluginLoader;
std::shared_ptr<SubscriptionStore> store(new SubscriptionStore());
std::shared_ptr<ThreadData> t(new ThreadData(0, settings, pluginLoader));

// Kind of a hack...
Authentication auth(settings);
ThreadGlobals::assign(&auth);
ThreadGlobals::assignThreadData(t.get());

int listen_socket = socket(AF_INET, SOCK_STREAM, 0);
FileCloser listener_closer(listen_socket);

int optval = 1;
check<std::runtime_error>(setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &optval, sizeof(optval)));

BindAddr bindAddr = getBindAddr(AF_INET, "127.0.0.1", 22123);

check<std::runtime_error>(bind(listen_socket, bindAddr.p.get(), bindAddr.len));
check<std::runtime_error>(listen(listen_socket, 64));

int client_socket = socket(AF_INET, SOCK_STREAM, 0);
int flags = fcntl(listen_socket, F_GETFL);
check<std::runtime_error>(fcntl(client_socket, F_SETFL, flags | O_NONBLOCK ));

std::shared_ptr<Client> c1(new Client(client_socket, t, nullptr, true, false, nullptr, settings, false));
std::shared_ptr<Client> client = c1;
t->giveClient(std::move(c1));

::connect(client_socket, bindAddr.p.get(), bindAddr.len);

int socket_to_client = accept(listen_socket, nullptr, nullptr);
FileCloser socket_to_client_closer(socket_to_client);

if (socket_to_client < 0)
throw std::runtime_error("Couldn't accept socket.");

flags = fcntl(listen_socket, F_GETFL);
check<std::runtime_error>(fcntl(socket_to_client, F_SETFL, flags | O_NONBLOCK ));

int error = 0;
socklen_t optlen = sizeof(int);
int count = 0;
do
{
check<std::runtime_error>(getsockopt(client_socket, SOL_SOCKET, SO_ERROR, &error, &optlen));
}
while(error == EINPROGRESS && count++ < 1000);

if (error > 0 && error != EINPROGRESS)
throw std::runtime_error(strerror(error));

std::ifstream input("plainwebsocketpacket1_handshake.dat", std::ios::binary);
std::vector<unsigned char> websocketstart(std::istreambuf_iterator<char>(input), {});

{
write(socket_to_client, websocketstart.data(), websocketstart.size());
client->readFdIntoBuffer();
client->writeBufIntoFd();
std::vector<char> answer = readFromSocket(socket_to_client, true);
std::string answer_string(answer.begin(), answer.end());

QVERIFY(startsWith(answer_string, "HTTP/1.1 101 Switching Protocols"));

}

// We now have an upgraded connection, and can test websocket frame decoding.

{
{
size_t l = 0;
std::vector<char> closeFrame(1024);
closeFrame[l++] = 0x08; // opcode 8 = close
closeFrame[l++] = 0x00; // Unmasked. payload length;
write(socket_to_client, closeFrame.data(), l);
}

{
size_t l = 0;
std::vector<char> pingFrameWithPayload(1024);
pingFrameWithPayload[l++] = 0x09; // opcode 9
pingFrameWithPayload[l++] = 0x05; // Unmasked. payload length;
pingFrameWithPayload[l++] = 'h';
pingFrameWithPayload[l++] = 'e';
pingFrameWithPayload[l++] = 'l';
pingFrameWithPayload[l++] = 'l';
pingFrameWithPayload[l++] = 'o';

write(socket_to_client, pingFrameWithPayload.data(), l);
}

pollFd(client_socket, true);
bool connectionStatus = client->readFdIntoBuffer();

QVERIFY(!connectionStatus);
}
}
catch (std::exception &ex)
{
QVERIFY2(false, ex.what());
}
}
1 change: 1 addition & 0 deletions iowrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ ssize_t IoWrapper::websocketBytesToReadBuffer(void *buf, const size_t nbytes, Io

// There may be a UTF8 string with a reason in the packet still, but ignoring that for now.
incompleteWebsocketRead.reset();
websocketPendingBytes.reset();
}
else
{
Expand Down

0 comments on commit 34dc32b

Please sign in to comment.