Skip to content

Commit

Permalink
ktls: send alerts (aws#4185)
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart authored Sep 7, 2023
1 parent a888cfc commit a7b0dfa
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 66 deletions.
4 changes: 2 additions & 2 deletions tests/testlib/s2n_ktls_test_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ S2N_CLEANUP_RESULT s2n_ktls_io_stuffer_pair_free(struct s2n_test_ktls_io_stuffer
return S2N_RESULT_OK;
}

S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io, uint8_t *expected_data,
uint16_t expected_len)
S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io,
const uint8_t *expected_data, uint16_t expected_len)
{
RESULT_ENSURE_REF(ktls_io);
RESULT_ENSURE_REF(expected_data);
Expand Down
4 changes: 2 additions & 2 deletions tests/testlib/s2n_ktls_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ S2N_RESULT s2n_test_init_ktls_io_stuffer(struct s2n_connection *server,
struct s2n_connection *client, struct s2n_test_ktls_io_stuffer_pair *io_pair);
S2N_CLEANUP_RESULT s2n_ktls_io_stuffer_free(struct s2n_test_ktls_io_stuffer *io);
S2N_CLEANUP_RESULT s2n_ktls_io_stuffer_pair_free(struct s2n_test_ktls_io_stuffer_pair *pair);
S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io, uint8_t *expected_data,
uint16_t expected_len);
S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io,
const uint8_t *expected_data, uint16_t expected_len);
S2N_RESULT s2n_test_validate_ancillary(struct s2n_test_ktls_io_stuffer *ktls_io,
uint8_t expected_record_type, uint16_t expected_len);
S2N_RESULT s2n_test_records_in_ancillary(struct s2n_test_ktls_io_stuffer *ktls_io,
Expand Down
190 changes: 155 additions & 35 deletions tests/unit/s2n_ktls_io_test.c

Large diffs are not rendered by default.

90 changes: 89 additions & 1 deletion tests/unit/s2n_shutdown_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
#include "tls/s2n_shutdown.c"

#include "s2n_test.h"
#include "testlib/s2n_ktls_test_utils.h"
#include "testlib/s2n_testlib.h"
#include "tls/s2n_alerts.h"
#include "utils/s2n_socket.h"

#define ALERT_LEN (sizeof(uint16_t))

Expand Down Expand Up @@ -614,7 +616,93 @@ int main(int argc, char **argv)
EXPECT_TRUE(s2n_connection_check_io_status(conn, S2N_IO_CLOSED));
EXPECT_FALSE(s2n_atomic_flag_test(&conn->close_notify_received));
};
}

/* Test: kTLS enabled */
{
/* Test: Successfully send alert */
{
DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
EXPECT_NOT_NULL(conn);
EXPECT_OK(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_SEND));

DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer out = { 0 },
s2n_ktls_io_stuffer_free);
EXPECT_OK(s2n_test_init_ktls_io_stuffer_send(conn, &out));

s2n_blocked_status blocked = S2N_NOT_BLOCKED;
EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked));
EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED);
EXPECT_TRUE(conn->alert_sent);
EXPECT_EQUAL(out.sendmsg_invoked_count, 1);
EXPECT_OK(s2n_test_validate_ancillary(&out, TLS_ALERT, S2N_ALERT_LENGTH));
EXPECT_OK(s2n_test_validate_data(&out,
close_notify_alert, sizeof(close_notify_alert)));

/* Repeating the shutdown does not resend the alert */
for (size_t i = 0; i < 5; i++) {
EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked));
EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED);
EXPECT_TRUE(conn->alert_sent);
EXPECT_EQUAL(out.sendmsg_invoked_count, 1);
}
};

/* Test: Successfully send alert after blocking */
{
/* One call does the partial write, the second blocks */
const size_t partial_write = 1;
const size_t second_write = sizeof(close_notify_alert) - partial_write;
EXPECT_TRUE(second_write > 0);

DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
EXPECT_NOT_NULL(conn);
EXPECT_OK(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_SEND));

DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer out = { 0 },
s2n_ktls_io_stuffer_free);
EXPECT_OK(s2n_test_init_ktls_io_stuffer_send(conn, &out));
EXPECT_SUCCESS(s2n_stuffer_free(&out.data_buffer));
EXPECT_SUCCESS(s2n_stuffer_alloc(&out.data_buffer, partial_write));

/* One call does the partial write, the second blocks */
size_t expected_calls = 2;

/* Initial shutdown blocks */
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
EXPECT_FAILURE_WITH_ERRNO(s2n_shutdown_send(conn, &blocked),
S2N_ERR_IO_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_WRITE);
EXPECT_TRUE(conn->alert_sent);
EXPECT_EQUAL(out.sendmsg_invoked_count, expected_calls);
EXPECT_OK(s2n_test_validate_ancillary(&out, TLS_ALERT, partial_write));
EXPECT_OK(s2n_test_validate_data(&out, close_notify_alert, partial_write));

/* Unblock the output stuffer */
out.data_buffer.growable = true;
expected_calls++;
EXPECT_SUCCESS(s2n_stuffer_wipe(&out.ancillary_buffer));

/* Second shutdown succeeds */
EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked));
EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED);
EXPECT_TRUE(conn->alert_sent);
EXPECT_EQUAL(out.sendmsg_invoked_count, expected_calls);
EXPECT_OK(s2n_test_validate_ancillary(&out, TLS_ALERT, second_write));
EXPECT_OK(s2n_test_validate_data(&out, close_notify_alert,
sizeof(close_notify_alert)));

/* Repeating the shutdown does not resend the alert */
for (size_t i = 0; i < 5; i++) {
EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked));
EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED);
EXPECT_TRUE(conn->alert_sent);
EXPECT_EQUAL(out.sendmsg_invoked_count, expected_calls);
}
};
};
};

END_TEST();
}
6 changes: 6 additions & 0 deletions tls/s2n_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -853,11 +853,17 @@ int s2n_connection_use_corked_io(struct s2n_connection *conn)

uint64_t s2n_connection_get_wire_bytes_in(struct s2n_connection *conn)
{
if (conn->ktls_recv_enabled) {
return 0;
}
return conn->wire_bytes_in;
}

uint64_t s2n_connection_get_wire_bytes_out(struct s2n_connection *conn)
{
if (conn->ktls_send_enabled) {
return 0;
}
return conn->wire_bytes_out;
}

Expand Down
27 changes: 14 additions & 13 deletions tls/s2n_ktls.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ static int s2n_ktls_disabled_read(void *io_context, uint8_t *buf, uint32_t len)
POSIX_BAIL(S2N_ERR_IO);
}

static int s2n_ktls_disabled_write(void *io_context, const uint8_t *buf, uint32_t len)
{
POSIX_BAIL(S2N_ERR_IO);
}

static S2N_RESULT s2n_ktls_validate(struct s2n_connection *conn, s2n_ktls_mode ktls_mode)
{
RESULT_ENSURE_REF(conn);
Expand Down Expand Up @@ -244,6 +239,18 @@ static S2N_RESULT s2n_ktls_configure_socket(struct s2n_connection *conn, s2n_ktl
return S2N_RESULT_OK;
}

S2N_RESULT s2n_ktls_configure_connection(struct s2n_connection *conn, s2n_ktls_mode ktls_mode)
{
if (ktls_mode == S2N_KTLS_MODE_SEND) {
conn->ktls_send_enabled = true;
conn->send = s2n_ktls_send_cb;
} else {
conn->ktls_recv_enabled = true;
conn->recv = s2n_ktls_disabled_read;
}
return S2N_RESULT_OK;
}

/*
* Since kTLS is an optimization, it is possible to continue operation
* by using userspace TLS if kTLS is not supported.
Expand All @@ -265,10 +272,7 @@ int s2n_connection_ktls_enable_send(struct s2n_connection *conn)
}

POSIX_GUARD_RESULT(s2n_ktls_configure_socket(conn, S2N_KTLS_MODE_SEND));

conn->ktls_send_enabled = true;
/* kTLS now handles I/O for the connection */
conn->send = s2n_ktls_disabled_write;
POSIX_GUARD_RESULT(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_SEND));

return S2N_SUCCESS;
}
Expand All @@ -287,10 +291,7 @@ int s2n_connection_ktls_enable_recv(struct s2n_connection *conn)
}

POSIX_GUARD_RESULT(s2n_ktls_configure_socket(conn, S2N_KTLS_MODE_RECV));

conn->ktls_recv_enabled = true;
/* kTLS now handles I/O for the connection */
conn->recv = s2n_ktls_disabled_read;
POSIX_GUARD_RESULT(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_RECV));

return S2N_SUCCESS;
}
8 changes: 6 additions & 2 deletions tls/s2n_ktls.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ typedef enum {
bool s2n_ktls_is_supported_on_platform();
S2N_RESULT s2n_ktls_get_file_descriptor(struct s2n_connection *conn, s2n_ktls_mode ktls_mode, int *fd);

S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, const struct iovec *msg_iov,
int s2n_ktls_send_cb(void *io_context, const uint8_t *buf, uint32_t len);
S2N_RESULT s2n_ktls_sendmsg(void *io_context, uint8_t record_type, const struct iovec *msg_iov,
size_t msg_iovlen, s2n_blocked_status *blocked, size_t *bytes_written);
S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf,
S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf,
size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read);

ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iovec *bufs,
ssize_t count, ssize_t offs, s2n_blocked_status *blocked);
int s2n_ktls_record_writev(struct s2n_connection *conn, uint8_t content_type,
const struct iovec *in, int in_count, size_t offs, size_t to_write);

/* These functions will be part of the public API. */
int s2n_connection_ktls_enable_send(struct s2n_connection *conn);
Expand All @@ -61,3 +64,4 @@ S2N_RESULT s2n_ktls_set_sendmsg_cb(struct s2n_connection *conn, s2n_ktls_sendmsg
void *send_ctx);
S2N_RESULT s2n_ktls_set_recvmsg_cb(struct s2n_connection *conn, s2n_ktls_recvmsg_fn recv_cb,
void *recv_ctx);
S2N_RESULT s2n_ktls_configure_connection(struct s2n_connection *conn, s2n_ktls_mode ktls_mode);
63 changes: 55 additions & 8 deletions tls/s2n_ktls_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,11 @@ S2N_RESULT s2n_ktls_get_control_data(struct msghdr *msg, int cmsg_type, uint8_t
return S2N_RESULT_OK;
}

S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, const struct iovec *msg_iov,
S2N_RESULT s2n_ktls_sendmsg(void *io_context, uint8_t record_type, const struct iovec *msg_iov,
size_t msg_iovlen, s2n_blocked_status *blocked, size_t *bytes_written)
{
RESULT_ENSURE_REF(bytes_written);
RESULT_ENSURE_REF(blocked);
RESULT_ENSURE_REF(conn);
RESULT_ENSURE(msg_iov != NULL || msg_iovlen == 0, S2N_ERR_NULL);

*blocked = S2N_BLOCKED_ON_WRITE;
Expand All @@ -206,7 +205,7 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co
RESULT_GUARD(s2n_ktls_set_control_data(&msg, control_data, sizeof(control_data),
S2N_TLS_SET_RECORD_TYPE, record_type));

ssize_t result = s2n_sendmsg_fn(conn->send_io_context, &msg);
ssize_t result = s2n_sendmsg_fn(io_context, &msg);
if (result < 0) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
RESULT_BAIL(S2N_ERR_IO_BLOCKED);
Expand All @@ -219,13 +218,12 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co
return S2N_RESULT_OK;
}

S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf,
S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf,
size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read)
{
RESULT_ENSURE_REF(record_type);
RESULT_ENSURE_REF(bytes_read);
RESULT_ENSURE_REF(blocked);
RESULT_ENSURE_REF(conn);
RESULT_ENSURE_REF(buf);
/* Ensure that buf_len is > 0 since trying to receive 0 bytes does not
* make sense and a return value of `0` from recvmsg is treated as EOF.
Expand Down Expand Up @@ -254,7 +252,7 @@ S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, u
msg.msg_controllen = sizeof(control_data);
msg.msg_control = control_data;

ssize_t result = s2n_recvmsg_fn(conn->recv_io_context, &msg);
ssize_t result = s2n_recvmsg_fn(io_context, &msg);
if (result < 0) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
RESULT_BAIL(S2N_ERR_IO_BLOCKED);
Expand Down Expand Up @@ -304,6 +302,7 @@ static S2N_RESULT s2n_ktls_new_iovecs_with_offset(const struct iovec *bufs,
ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iovec *bufs,
ssize_t count_in, ssize_t offs_in, s2n_blocked_status *blocked)
{
POSIX_ENSURE_REF(conn);
POSIX_ENSURE(count_in >= 0, S2N_ERR_INVALID_ARGUMENT);
size_t count = count_in;
POSIX_ENSURE(offs_in >= 0, S2N_ERR_INVALID_ARGUMENT);
Expand All @@ -319,7 +318,55 @@ ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iov
}

size_t bytes_written = 0;
POSIX_GUARD_RESULT(s2n_ktls_sendmsg(conn, TLS_APPLICATION_DATA, bufs, count,
blocked, &bytes_written));
POSIX_GUARD_RESULT(s2n_ktls_sendmsg(conn->send_io_context, TLS_APPLICATION_DATA,
bufs, count, blocked, &bytes_written));
return bytes_written;
}

int s2n_ktls_send_cb(void *io_context, const uint8_t *buf, uint32_t len)
{
/* For now, all control records are assumed to be alerts.
* We can set the record_type on the io_context in the future.
*/
const uint8_t record_type = TLS_ALERT;

const struct iovec iov = {
.iov_base = (void *) (uintptr_t) buf,
.iov_len = len,
};
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
size_t bytes_written = 0;

POSIX_GUARD_RESULT(s2n_ktls_sendmsg(io_context, record_type, &iov, 1,
&blocked, &bytes_written));

POSIX_ENSURE_LTE(bytes_written, len);
return bytes_written;
}

int s2n_ktls_record_writev(struct s2n_connection *conn, uint8_t content_type,
const struct iovec *in, int in_count, size_t offs, size_t to_write)
{
POSIX_ENSURE_REF(conn);
POSIX_ENSURE(in_count > 0, S2N_ERR_INVALID_ARGUMENT);
size_t count = in_count;
POSIX_ENSURE_REF(in);

/* Currently, ktls only supports sending alerts.
* To also support handshake messages, we would need a way to track record_type.
* We could add a field to the send io context.
*/
POSIX_ENSURE(content_type == TLS_ALERT, S2N_ERR_UNIMPLEMENTED);

/* When stuffers automatically resize, they allocate a potentially large
* chunk of memory to avoid repeated resizes.
* Since ktls only uses conn->out for control messages (alerts and eventually
* handshake messages), we expect infrequent small writes with conn->out
* freed in between. Since we're therefore more concerned with the size of
* the allocation than the frequency, use a more accurate size for each write.
*/
POSIX_GUARD(s2n_stuffer_resize_if_empty(&conn->out, to_write));

POSIX_GUARD(s2n_stuffer_writev_bytes(&conn->out, in, count, offs, to_write));
return to_write;
}
5 changes: 5 additions & 0 deletions tls/s2n_record_write.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "tls/s2n_cipher_suites.h"
#include "tls/s2n_connection.h"
#include "tls/s2n_crypto.h"
#include "tls/s2n_ktls.h"
#include "tls/s2n_record.h"
#include "utils/s2n_blob.h"
#include "utils/s2n_random.h"
Expand Down Expand Up @@ -247,6 +248,10 @@ static inline int s2n_record_encrypt(

int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const struct iovec *in, int in_count, size_t offs, size_t to_write)
{
if (conn->ktls_send_enabled) {
return s2n_ktls_record_writev(conn, content_type, in, in_count, offs, to_write);
}

struct s2n_blob iv = { 0 };
uint8_t padding = 0;
uint16_t block_size = 0;
Expand Down
6 changes: 3 additions & 3 deletions tls/s2n_send.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ ssize_t s2n_sendv_with_offset_impl(struct s2n_connection *conn, const struct iov
POSIX_ENSURE(s2n_connection_check_io_status(conn, S2N_IO_WRITABLE), S2N_ERR_CLOSED);
POSIX_ENSURE(!s2n_connection_is_quic_enabled(conn), S2N_ERR_UNSUPPORTED_WITH_QUIC);

/* Flush any pending I/O */
POSIX_GUARD(s2n_flush(conn, blocked));

if (conn->ktls_send_enabled) {
return s2n_ktls_sendv_with_offset(conn, bufs, count, offs, blocked);
}

/* Flush any pending I/O */
POSIX_GUARD(s2n_flush(conn, blocked));

/* Acknowledge consumed and flushed user data as sent */
user_data_sent = conn->current_user_data_consumed;

Expand Down

0 comments on commit a7b0dfa

Please sign in to comment.