Skip to content

Commit

Permalink
Refactor: check limit before send, not after
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart committed Dec 14, 2023
1 parent 97db458 commit 942bd95
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 54 deletions.
25 changes: 19 additions & 6 deletions tests/unit/s2n_ktls_io_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,6 @@ int main(int argc, char **argv)
EXPECT_OK(s2n_assert_seq_num_equal(seq_num, expected_seq_num));

/* Test: Send enough data to hit the encryption limit */
expected_seq_num += large_test_data_records;
EXPECT_FAILURE_WITH_ERRNO(
s2n_send(conn, large_test_data, sizeof(large_test_data), &blocked),
S2N_ERR_KTLS_KEY_LIMIT);
Expand All @@ -1284,23 +1283,37 @@ int main(int argc, char **argv)
EXPECT_FAILURE_WITH_ERRNO(
s2n_send(conn, large_test_data, 1, &blocked),
S2N_ERR_KTLS_KEY_LIMIT);
EXPECT_OK(s2n_assert_seq_num_equal(seq_num, test_encryption_limit + 1));
EXPECT_OK(s2n_assert_seq_num_equal(seq_num, test_encryption_limit));
};

/* Test: Limit not tracked with TLS1.2 */
{
conn->actual_protocol_version = S2N_TLS12;

DEFER_CLEANUP(struct s2n_blob seq_num = { 0 }, s2n_blob_zero);
EXPECT_OK(s2n_connection_get_sequence_number(conn, conn->mode, &seq_num));

EXPECT_EQUAL(s2n_send(conn, large_test_data, 1, &blocked), 1);
/* Sequence number not incremented with TLS1.2 */
conn->actual_protocol_version = S2N_TLS12;
EXPECT_EQUAL(
s2n_send(conn, large_test_data, sizeof(large_test_data), &blocked),
sizeof(large_test_data));
EXPECT_OK(s2n_assert_seq_num_equal(seq_num, 0));

/* Sequence number incremented with TLS1.3 */
conn->actual_protocol_version = S2N_TLS13;
EXPECT_EQUAL(
s2n_send(conn, large_test_data, sizeof(large_test_data), &blocked),
sizeof(large_test_data));
EXPECT_OK(s2n_assert_seq_num_equal(seq_num, 0));
EXPECT_OK(s2n_assert_seq_num_equal(seq_num, test_encryption_limit));

/* Passing the limit with TLS1.3 is an error */
conn->actual_protocol_version = S2N_TLS13;
EXPECT_FAILURE_WITH_ERRNO(
s2n_send(conn, large_test_data, 1, &blocked),
S2N_ERR_KTLS_KEY_LIMIT);

/* Passing the limit with TLS1.2 is NOT an error */
conn->actual_protocol_version = S2N_TLS12;
EXPECT_EQUAL(s2n_send(conn, large_test_data, 1, &blocked), 1);
};
}
};
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/s2n_safety_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,57 @@ int main(int argc, char **argv)
CHECK_OVF(s2n_add_overflow, uint32_t, 100, ACTUAL_MAX - 99);
CHECK_OVF(s2n_add_overflow, uint32_t, 100, ACTUAL_MAX - 1);

/* Test: S2N_ADD_IS_OVERFLOW_SAFE */
{
const size_t num = 100;

uint64_t success_test_values[][3] = {
{ 0, 0, 0 },
{ 1, 0, 1 },
{ 0, 0, UINT8_MAX },
{ 1, 1, UINT8_MAX },
{ UINT8_MAX, 0, UINT8_MAX },
{ UINT8_MAX - num, num, UINT8_MAX },
{ UINT8_MAX / 2, UINT8_MAX / 2, UINT8_MAX },
{ 1, 1, UINT64_MAX },
{ UINT64_MAX, 0, UINT64_MAX },
{ UINT64_MAX - num, num, UINT64_MAX },
{ UINT64_MAX / 2, UINT64_MAX / 2, UINT64_MAX },
};
for (size_t i = 0; i < s2n_array_len(success_test_values); i++) {
uint64_t v1 = success_test_values[i][0];
uint64_t v2 = success_test_values[i][1];
uint64_t max = success_test_values[i][2];
EXPECT_TRUE(S2N_ADD_IS_OVERFLOW_SAFE(v1, v2, max));
EXPECT_TRUE(S2N_ADD_IS_OVERFLOW_SAFE(v2, v1, max));
}

uint64_t failure_test_values[][3] = {
{ 1, 0, 0 },
{ UINT8_MAX, 0, 0 },
{ UINT64_MAX, 0, UINT8_MAX },
{ UINT64_MAX, UINT64_MAX, UINT8_MAX },
{ UINT8_MAX, 1, UINT8_MAX },
{ UINT8_MAX - 1, UINT8_MAX - 1, UINT8_MAX },
{ UINT16_MAX, 1, UINT16_MAX },
{ UINT64_MAX, 1, UINT64_MAX },
{ UINT8_MAX, num, UINT8_MAX },
{ UINT16_MAX, num, UINT16_MAX },
{ UINT64_MAX, num, UINT64_MAX },
{ UINT8_MAX, UINT8_MAX, UINT8_MAX },
{ UINT16_MAX, UINT16_MAX, UINT16_MAX },
{ UINT64_MAX, UINT64_MAX, UINT64_MAX },
{ UINT64_MAX - num, UINT64_MAX - num, UINT64_MAX },
};
for (size_t i = 0; i < s2n_array_len(failure_test_values); i++) {
uint64_t v1 = failure_test_values[i][0];
uint64_t v2 = failure_test_values[i][1];
uint64_t max = failure_test_values[i][2];
EXPECT_FALSE(S2N_ADD_IS_OVERFLOW_SAFE(v1, v2, max));
EXPECT_FALSE(S2N_ADD_IS_OVERFLOW_SAFE(v2, v1, max));
}
}

END_TEST();
return 0;
}
126 changes: 126 additions & 0 deletions tests/unit/s2n_send_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -600,5 +600,131 @@ int main(int argc, char **argv)
EXPECT_EQUAL(conn->out.blob.size, out_size[S2N_MFL_DEFAULT]);
};

/* Test: s2n_sendv_with_offset_total_size */
{
const struct iovec test_multiple_bufs[] = {
{ .iov_len = 0 },
{ .iov_len = 1 },
{ .iov_len = 2 },
{ .iov_len = 0 },
{ .iov_len = 14 },
{ .iov_len = 0 },
{ .iov_len = 3 },
{ .iov_len = 0 },
};
const size_t test_multiple_bufs_total_size = 20;

/* Safety */
{
size_t out = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(NULL, 0, 0, NULL),
S2N_ERR_NULL);
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(NULL, 1, 0, &out),
S2N_ERR_NULL);
}

/* No iovecs */
{
size_t out = 0;
EXPECT_OK(s2n_sendv_with_offset_total_size(NULL, 0, 0, &out));
EXPECT_EQUAL(out, 0);
}

/* Array of zero-length iovecs */
{
const struct iovec test_bufs[10] = { 0 };
size_t out = 0;
EXPECT_OK(s2n_sendv_with_offset_total_size(
test_bufs, s2n_array_len(test_bufs), 0, &out));
EXPECT_EQUAL(out, 0);
}

/* Single iovec */
{
const size_t expected_size = 10;
const struct iovec test_buf = { .iov_len = expected_size };
size_t out = 0;
EXPECT_OK(s2n_sendv_with_offset_total_size(&test_buf, 1, 0, &out));
EXPECT_EQUAL(out, expected_size);
}

/* Single iovec with offset */
{
const struct iovec test_buf = { .iov_len = 10 };
const ssize_t offset = 5;
size_t out = 0;
EXPECT_OK(s2n_sendv_with_offset_total_size(&test_buf, 1, offset, &out));
EXPECT_EQUAL(out, test_buf.iov_len - offset);
}

/* Multiple iovecs */
{
size_t out = 0;
EXPECT_OK(s2n_sendv_with_offset_total_size(
test_multiple_bufs, s2n_array_len(test_multiple_bufs), 0, &out));
EXPECT_EQUAL(out, test_multiple_bufs_total_size);
}

/* Multiple iovecs with offset */
{
const size_t offset = 10;
size_t out = 0;
EXPECT_OK(s2n_sendv_with_offset_total_size(
test_multiple_bufs, s2n_array_len(test_multiple_bufs), offset, &out));
EXPECT_EQUAL(out, test_multiple_bufs_total_size - offset);
}

/* Offset with no data */
{
const struct iovec test_bufs[10] = { 0 };
size_t out = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(NULL, 0, 1, &out),
S2N_ERR_INVALID_ARGUMENT);
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(test_bufs, 0, 1, &out),
S2N_ERR_INVALID_ARGUMENT);
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(test_bufs, s2n_array_len(test_bufs), 1, &out),
S2N_ERR_INVALID_ARGUMENT);
}

/* Offset larger than available data */
{
const struct iovec test_buf = { .iov_len = 10 };
size_t out = 0;

ssize_t test_buf_offset = test_buf.iov_len + 1;
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(&test_buf, 1, test_buf_offset, &out),
S2N_ERR_INVALID_ARGUMENT);

ssize_t test_multiple_bufs_offset = test_multiple_bufs_total_size + 1;
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(test_multiple_bufs,
s2n_array_len(test_multiple_bufs), test_multiple_bufs_offset, &out),
S2N_ERR_INVALID_ARGUMENT);
}

/* Too much data to count
*
* This isn't really practically possible since an application would need
* to allocate more than SIZE_MAX memory for the iovec buffers, but we
* should ensure that the inputs don't cause unexpected behavior.
*/
{
const struct iovec test_bufs[] = {
{ .iov_len = SIZE_MAX },
{ .iov_len = 1 },
};
size_t out = 0;
EXPECT_ERROR_WITH_ERRNO(
s2n_sendv_with_offset_total_size(test_bufs, s2n_array_len(test_bufs), 0, &out),
S2N_ERR_INVALID_ARGUMENT);
}
};

END_TEST();
}
85 changes: 53 additions & 32 deletions tls/s2n_ktls_io.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,21 @@ S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf
return S2N_RESULT_OK;
}

/* The RFC defines the encryption limits in terms of "full-size records" sent.
* We can estimate the number of "full-sized records" sent by assuming that
* all records are full-sized.
*/
static S2N_RESULT s2n_ktls_estimate_records(size_t bytes, uint64_t *estimate)
{
RESULT_ENSURE_REF(estimate);
uint64_t records = bytes / S2N_TLS_MAXIMUM_FRAGMENT_LENGTH;
if (bytes % S2N_TLS_MAXIMUM_FRAGMENT_LENGTH) {
records++;
}
*estimate = records;
return S2N_RESULT_OK;
}

/* ktls does not currently support updating keys, so we should kill the connection
* when the key encryption limit is reached. We could get the current record
* sequence number from the kernel with getsockopt, but that requires a surprisingly
Expand All @@ -277,34 +292,51 @@ S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf
* Instead, we track the estimated sequence number and enforce the limit based
* on that estimate.
*/
static S2N_RESULT s2n_ktls_enforce_estimated_record_limit(
struct s2n_connection *conn, size_t bytes_written)
static S2N_RESULT s2n_ktls_check_estimated_record_limit(
struct s2n_connection *conn, size_t bytes_requested)
{
RESULT_ENSURE_REF(conn);
if (conn->actual_protocol_version < S2N_TLS13) {
return S2N_RESULT_OK;
}

uint64_t new_records_sent = 0;
RESULT_GUARD(s2n_ktls_estimate_records(bytes_requested, &new_records_sent));

uint64_t old_records_sent = 0;
struct s2n_blob seq_num = { 0 };
RESULT_GUARD(s2n_connection_get_sequence_number(conn, conn->mode, &seq_num));
RESULT_GUARD_POSIX(s2n_sequence_number_to_uint64(&seq_num, &old_records_sent));

/* The RFC states the encryption limits in terms of "full-size records" sent.
* We can estimate the number of "full-sized records" sent by assuming that
* all records are full-sized.
*/
while (bytes_written > 0) {
RESULT_GUARD_POSIX(s2n_increment_sequence_number(&seq_num));
bytes_written -= MIN(bytes_written, S2N_TLS_MAXIMUM_FRAGMENT_LENGTH);
}

uint64_t records_sent = 0;
RESULT_GUARD_POSIX(s2n_sequence_number_to_uint64(&seq_num, &records_sent));
RESULT_ENSURE(S2N_ADD_IS_OVERFLOW_SAFE(old_records_sent, new_records_sent, UINT64_MAX),
S2N_ERR_KTLS_KEY_LIMIT);
uint64_t total_records_sent = old_records_sent + new_records_sent;

RESULT_ENSURE_REF(conn->secure);
RESULT_ENSURE_REF(conn->secure->cipher_suite);
RESULT_ENSURE_REF(conn->secure->cipher_suite->record_alg);
uint64_t encryption_limit = conn->secure->cipher_suite->record_alg->encryption_limit;
RESULT_ENSURE(records_sent <= encryption_limit, S2N_ERR_KTLS_KEY_LIMIT);
RESULT_ENSURE(total_records_sent <= encryption_limit, S2N_ERR_KTLS_KEY_LIMIT);
return S2N_RESULT_OK;
}

static S2N_RESULT s2n_ktls_set_estimated_sequence_number(
struct s2n_connection *conn, size_t bytes_written)
{
RESULT_ENSURE_REF(conn);
if (conn->actual_protocol_version < S2N_TLS13) {
return S2N_RESULT_OK;
}

uint64_t new_records_sent = 0;
RESULT_GUARD(s2n_ktls_estimate_records(bytes_written, &new_records_sent));

struct s2n_blob seq_num = { 0 };
RESULT_GUARD(s2n_connection_get_sequence_number(conn, conn->mode, &seq_num));

for (size_t i = 0; i < new_records_sent; i++) {
RESULT_GUARD_POSIX(s2n_increment_sequence_number(&seq_num));
}
return S2N_RESULT_OK;
}

Expand Down Expand Up @@ -387,6 +419,10 @@ ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iov
POSIX_ENSURE(offs_in >= 0, S2N_ERR_INVALID_ARGUMENT);
size_t offs = offs_in;

ssize_t total_bytes = 0;
POSIX_GUARD_RESULT(s2n_sendv_with_offset_total_size(bufs, count_in, offs_in, &total_bytes));
POSIX_GUARD_RESULT(s2n_ktls_check_estimated_record_limit(conn, total_bytes));

DEFER_CLEANUP(struct s2n_blob new_bufs = { 0 }, s2n_free_or_wipe);
uint8_t new_bufs_mem[S2N_MAX_STACK_IOVECS_MEM] = { 0 };
POSIX_GUARD(s2n_blob_init(&new_bufs, new_bufs_mem, sizeof(new_bufs_mem)));
Expand All @@ -398,11 +434,7 @@ ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iov
POSIX_GUARD_RESULT(s2n_ktls_sendmsg(conn->send_io_context, TLS_APPLICATION_DATA,
bufs, count, blocked, &bytes_written));

/* Unlike s2n_sendfile, here we could calculate the number of bytes to be sent
* before actually sending them. However, we instead choose to maintain consistent
* behavior across our send methods and always check for the limit after the send.
*/
POSIX_GUARD_RESULT(s2n_ktls_enforce_estimated_record_limit(conn, bytes_written));
POSIX_GUARD_RESULT(s2n_ktls_set_estimated_sequence_number(conn, bytes_written));
return bytes_written;
}

Expand Down Expand Up @@ -466,6 +498,7 @@ int s2n_sendfile(struct s2n_connection *conn, int in_fd, off_t offset, size_t co
*bytes_written = 0;
POSIX_ENSURE_REF(conn);
POSIX_ENSURE(conn->ktls_send_enabled, S2N_ERR_KTLS_UNSUPPORTED_CONN);
POSIX_GUARD_RESULT(s2n_ktls_check_estimated_record_limit(conn, count));

int out_fd = 0;
POSIX_GUARD_RESULT(s2n_ktls_get_file_descriptor(conn, S2N_KTLS_MODE_SEND, &out_fd));
Expand All @@ -480,20 +513,8 @@ int s2n_sendfile(struct s2n_connection *conn, int in_fd, off_t offset, size_t co
POSIX_BAIL(S2N_ERR_UNIMPLEMENTED);
#endif

POSIX_GUARD_RESULT(s2n_ktls_set_estimated_sequence_number(conn, *bytes_written));
*blocked = S2N_NOT_BLOCKED;

/* Because we pass the input file descriptor to the kernel without examining
* it, we don't know how many bytes actually need to be sent. We therefore
* can't verify that the send is safe with respect to the encryption limit
* before sending the records. Instead, we raise a fatal error afterwards if
* the send violated the encryption limit.
*
* An application should treat S2N_ERR_KTLS_KEY_LIMIT as a very high severity
* error, as it indicates that the application is violating the requirements
* for using TLS1.3 with ktls without a kernel patch to enable KeyUpdates,
* and is therefore operating unsafely.
*/
POSIX_GUARD_RESULT(s2n_ktls_enforce_estimated_record_limit(conn, *bytes_written));
return S2N_SUCCESS;
}

Expand Down
Loading

0 comments on commit 942bd95

Please sign in to comment.