Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround TLS fragmented records #32

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
CMakeCache.txt
src/.vscode/
*.o
*test_tls
2 changes: 1 addition & 1 deletion src/netguard/include/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ uint16_t calc_checksum(uint16_t start, const uint8_t *buffer, size_t length);

int compare_u32(uint32_t seq1, uint32_t seq2);

int sdk_int(JNIEnv *env);
//int sdk_int(JNIEnv *env);

void hex2bytes(const char *hex, uint8_t *buffer);

Expand Down
50 changes: 27 additions & 23 deletions src/netguard/tls_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <netinet/ip6.h>
#include "platform.h"
#include "tls.h"
#include "util.h"

static int parse_tls_server_name(const uint8_t *data, const size_t data_len, char *server_name);
static int parse_extensions(const uint8_t*, size_t, char *);
Expand Down Expand Up @@ -69,19 +70,18 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha
}

/* TLS record length */
// uint16_t len = ((size_t)data[3] << 8) + (size_t)data[4] + TLS_HEADER_LEN;
size_t len = ntohs(*((uint16_t *) (data + 3))) + TLS_HEADER_LEN;
// data_len = MIN(len, data_len);
log_print(PLATFORM_LOG_PRIORITY_INFO, "data len %d, record len %d\n", data_len, len);
if (data_len < len) {
// purposely don't return as we have checks later on
log_print(PLATFORM_LOG_PRIORITY_WARN, "TLS data length smaller than expected, proceed anyways");
}

/* handshake */
size_t pos = TLS_HEADER_LEN;
// if (pos + 1 > data_len) {
// return -5;
// }
if (pos + 1 > data_len) {
return -5;
}

if (data[pos] != 0x1) {
// not a client hello
Expand All @@ -98,18 +98,17 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha
pos += 38;

// Session ID
// if (pos + 1 > data_len) return -7;
if (pos + 1 > data_len) return -7;
len = (size_t)data[pos];
pos += 1 + len;

/* Cipher Suites */
// if (pos + 2 > data_len) return -8;
// len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1];
if (pos + 2 > data_len) return -8;
len = ntohs(*((uint16_t *) (data + pos)));
pos += 2 + len;

/* Compression Methods */
// if (pos + 1 > data_len) return -9;
if (pos + 1 > data_len) return -9;
len = (size_t)data[pos];
pos += 1 + len;

Expand All @@ -119,16 +118,17 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha
}

/* Extensions */
// if (pos + 2 > data_len) {
// return -11;
// }
// len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1];
if (pos + 2 > data_len) {
return -11;
}
len = ntohs(*((uint16_t *) (data + pos)));
pos += 2;

// if (pos + len > data_len) {
if (pos + len > data_len) {
// Possibly a TLS fragmented record, continue anyways to see if we find SNI in the fragment
log_print(PLATFORM_LOG_PRIORITY_WARN, "Out of bounds at extensions length, pos(%d) + len(%d) > data_len(%d)", pos, len, data_len);
// return -12;
// }
}
return parse_extensions(data + pos, len, server_name);
}

Expand All @@ -139,15 +139,14 @@ static int parse_extensions(const uint8_t *data, size_t data_len, char *hostname
/* Parse each 4 bytes for the extension header */
while (pos + 4 <= data_len) {
/* Extension Length */
// len = ((size_t)data[pos + 2] << 8) +(size_t)data[pos + 3];
len = ntohs(*((uint16_t *) (data + pos + 2)));

/* Check if it's a server name extension */
if (data[pos] == 0x00 && data[pos + 1] == 0x00) {
/* There can be only one extension of each type, so we break
our state and move p to beinnging of the extension here */
// if (pos + 4 + len > data_len)
// return -20;
if (pos + 4 + len > data_len)
return -20;
return parse_server_name_extension(data + pos + 4, len, hostname);
}
pos += 4 + len; /* Advance to the next extension header */
Expand All @@ -164,12 +163,11 @@ static int parse_server_name_extension(const uint8_t *data, size_t data_len, cha
size_t len;

while (pos + 3 < data_len) {
// len = ((size_t)data[pos + 1] << 8) + (size_t)data[pos + 2];
len = ntohs(*((uint16_t *) (data + pos + 1)));

// if (pos + 3 + len > data_len) {
// return -30;
// }
if (pos + 3 + len > data_len) {
return -30;
}

switch (data[pos]) { /* name type */
case 0x00: /* host_name */
Expand All @@ -180,7 +178,13 @@ static int parse_server_name_extension(const uint8_t *data, size_t data_len, cha
}
strncpy(hostname, (const char *)(data + pos + 3), len);
(hostname)[len] = '\0';
return len;
if (is_valid_utf8(hostname)) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a check we had before that got lost when we re-wrote the TLS parsing.
We've seen some instances where the hostname is not UTF-8. The JNI interface (is_domain_blocked) uses NewStringUTF to pass the domain name to JVM, this is just to ensure we never get there with non UTF string.

return len;
} else {
log_print(PLATFORM_LOG_PRIORITY_WARN, "invalid UTF-8");
*hostname = 0;
return -34;
}
default:
log_print(PLATFORM_LOG_PRIORITY_DEBUG, "Unknown server name extension name type: %d", data[pos]);
}
Expand Down
2 changes: 1 addition & 1 deletion src/test/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
CC = gcc
CFLAGS = -Wall -Wimplicit-function-declaration -I../netguard/include

SRC = test_tls.c ../netguard/tls_parser.c
SRC = test_tls.c stubs.c ../netguard/tls_parser.c
OBJ = $(SRC:.c=.o)
EXECUTABLE = test_tls

Expand Down
3 changes: 3 additions & 0 deletions src/test/stubs.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
int is_valid_utf8(const char *str) {
return 1;
}
42 changes: 40 additions & 2 deletions src/test/test_tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,28 @@ const unsigned char wrong_sni_length[] = {
0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74
};

// Fragmented SNI
const unsigned char fragmentedSNI1[] = {
0x16, 0x3, 0x1, 0x0, 0x6e, 0x1, 0x0, 0x0,
0x6d, 0x3, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x2f,
0x0, 0xff, 0x1, 0x0, 0x0, 0x40, 0x0, 0x0,
0x0, 0xe, 0x0, 0xc, 0x0, 0x0, 0x9, 0x6c,
0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74,
0x0, 0xd, 0x0, 0xc, 0x0, 0xa, 0x6, 0x1,
0x5, 0x1, 0x4, 0x1, 0x3, 0x1, 0x2, 0x1,
0x0, 0x32, 0x0, 0x1a, 0x0, 0x18, 0x6, 0x1,
0x5, 0x1, 0x4, 0x1, 0x3, 0x1, 0x2, 0x1,
0x1, 0x1, 0x8, 0x4, 0x8, 0x5, 0x8, 0x6,
0x8, 0x9, 0x8
};
const unsigned char fragmentedSNI2[] = {
0x16, 0x3, 0x1, 0x0, 0x3, 0xa, 0x8, 0xb
};


int main() {
uint8_t *pkt = (uint8_t *)good_data_1;
Expand Down Expand Up @@ -547,7 +569,7 @@ int main() {
error = get_server_name(pkt, sizeof(bad_data_2), pkt, sn);
assert(strcmp("localhost", sn) != 0);
assert(strlen(sn) == 0);
assert(error == -31);
assert(error == -30);

pkt = (uint8_t *)bad_data_3;
memset(sn, 0, FQDN_LENGTH);
Expand All @@ -563,7 +585,23 @@ int main() {
error = get_server_name(pkt, sizeof(wrong_sni_length), pkt, sn);
assert(strcmp("localhost", sn) != 0);
assert(strlen(sn) == 0);
assert(error == -33);
assert(error == -30);

pkt = (uint8_t *)fragmentedSNI2;
memset(sn, 0, FQDN_LENGTH);
*sn = 0;
error = get_server_name(pkt, sizeof(fragmentedSNI2), pkt, sn);
assert(strcmp("localhost", sn) != 0);
assert(strlen(sn) == 0);
assert(error == -6);

pkt = (uint8_t *)fragmentedSNI1;
memset(sn, 0, FQDN_LENGTH);
*sn = 0;
error = get_server_name(pkt, sizeof(fragmentedSNI1), pkt, sn);
assert(strcmp("localhost", sn) == 0);
assert(strlen(sn) == 9);
assert(error == 9);

return 0;
}
Loading