diff --git a/.gitignore b/.gitignore index 62a5fc2..0074786 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ CMakeCache.txt src/.vscode/ *.o +*test_tls diff --git a/src/netguard/include/util.h b/src/netguard/include/util.h index 7f41aaf..b8e4e6f 100644 --- a/src/netguard/include/util.h +++ b/src/netguard/include/util.h @@ -7,7 +7,7 @@ uint16_t calc_checksum(uint16_t start, const uint8_t *buffer, size_t length); int compare_u32(uint32_t seq1, uint32_t seq2); -int sdk_int(JNIEnv *env); +//int sdk_int(JNIEnv *env); void hex2bytes(const char *hex, uint8_t *buffer); diff --git a/src/netguard/tls_parser.c b/src/netguard/tls_parser.c index 92100cc..e4d8e8c 100644 --- a/src/netguard/tls_parser.c +++ b/src/netguard/tls_parser.c @@ -10,6 +10,7 @@ #include #include "platform.h" #include "tls.h" +#include "util.h" static int parse_tls_server_name(const uint8_t *data, const size_t data_len, char *server_name); static int parse_extensions(const uint8_t*, size_t, char *); @@ -69,9 +70,8 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha } /* TLS record length */ -// uint16_t len = ((size_t)data[3] << 8) + (size_t)data[4] + TLS_HEADER_LEN; size_t len = ntohs(*((uint16_t *) (data + 3))) + TLS_HEADER_LEN; -// data_len = MIN(len, data_len); + log_print(PLATFORM_LOG_PRIORITY_INFO, "data len %d, record len %d\n", data_len, len); if (data_len < len) { // purposely don't return as we have checks later on log_print(PLATFORM_LOG_PRIORITY_WARN, "TLS data length smaller than expected, proceed anyways"); @@ -79,9 +79,9 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha /* handshake */ size_t pos = TLS_HEADER_LEN; -// if (pos + 1 > data_len) { -// return -5; -// } + if (pos + 1 > data_len) { + return -5; + } if (data[pos] != 0x1) { // not a client hello @@ -98,18 +98,17 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha pos += 38; // Session ID -// if (pos + 1 > data_len) return -7; + if (pos + 1 > data_len) return -7; len = (size_t)data[pos]; pos += 1 + len; /* Cipher Suites */ -// if (pos + 2 > data_len) return -8; -// len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1]; + if (pos + 2 > data_len) return -8; len = ntohs(*((uint16_t *) (data + pos))); pos += 2 + len; /* Compression Methods */ -// if (pos + 1 > data_len) return -9; + if (pos + 1 > data_len) return -9; len = (size_t)data[pos]; pos += 1 + len; @@ -119,16 +118,17 @@ static int parse_tls_server_name(const uint8_t *data, const size_t data_len, cha } /* Extensions */ -// if (pos + 2 > data_len) { -// return -11; -// } -// len = ((size_t)data[pos] << 8) + (size_t)data[pos + 1]; + if (pos + 2 > data_len) { + return -11; + } len = ntohs(*((uint16_t *) (data + pos))); pos += 2; -// if (pos + len > data_len) { + if (pos + len > data_len) { + // Possibly a TLS fragmented record, continue anyways to see if we find SNI in the fragment + log_print(PLATFORM_LOG_PRIORITY_WARN, "Out of bounds at extensions length, pos(%d) + len(%d) > data_len(%d)", pos, len, data_len); // return -12; -// } + } return parse_extensions(data + pos, len, server_name); } @@ -139,15 +139,14 @@ static int parse_extensions(const uint8_t *data, size_t data_len, char *hostname /* Parse each 4 bytes for the extension header */ while (pos + 4 <= data_len) { /* Extension Length */ -// len = ((size_t)data[pos + 2] << 8) +(size_t)data[pos + 3]; len = ntohs(*((uint16_t *) (data + pos + 2))); /* Check if it's a server name extension */ if (data[pos] == 0x00 && data[pos + 1] == 0x00) { /* There can be only one extension of each type, so we break our state and move p to beinnging of the extension here */ -// if (pos + 4 + len > data_len) -// return -20; + if (pos + 4 + len > data_len) + return -20; return parse_server_name_extension(data + pos + 4, len, hostname); } pos += 4 + len; /* Advance to the next extension header */ @@ -164,12 +163,11 @@ static int parse_server_name_extension(const uint8_t *data, size_t data_len, cha size_t len; while (pos + 3 < data_len) { -// len = ((size_t)data[pos + 1] << 8) + (size_t)data[pos + 2]; len = ntohs(*((uint16_t *) (data + pos + 1))); -// if (pos + 3 + len > data_len) { -// return -30; -// } + if (pos + 3 + len > data_len) { + return -30; + } switch (data[pos]) { /* name type */ case 0x00: /* host_name */ @@ -180,7 +178,13 @@ static int parse_server_name_extension(const uint8_t *data, size_t data_len, cha } strncpy(hostname, (const char *)(data + pos + 3), len); (hostname)[len] = '\0'; - return len; + if (is_valid_utf8(hostname)) { + return len; + } else { + log_print(PLATFORM_LOG_PRIORITY_WARN, "invalid UTF-8"); + *hostname = 0; + return -34; + } default: log_print(PLATFORM_LOG_PRIORITY_DEBUG, "Unknown server name extension name type: %d", data[pos]); } diff --git a/src/test/Makefile b/src/test/Makefile index 82cf727..b4215d1 100644 --- a/src/test/Makefile +++ b/src/test/Makefile @@ -1,7 +1,7 @@ CC = gcc CFLAGS = -Wall -Wimplicit-function-declaration -I../netguard/include -SRC = test_tls.c ../netguard/tls_parser.c +SRC = test_tls.c stubs.c ../netguard/tls_parser.c OBJ = $(SRC:.c=.o) EXECUTABLE = test_tls diff --git a/src/test/stubs.c b/src/test/stubs.c new file mode 100644 index 0000000..2e3123a --- /dev/null +++ b/src/test/stubs.c @@ -0,0 +1,3 @@ +int is_valid_utf8(const char *str) { + return 1; +} \ No newline at end of file diff --git a/src/test/test_tls.c b/src/test/test_tls.c index 4351ccb..b107d47 100644 --- a/src/test/test_tls.c +++ b/src/test/test_tls.c @@ -478,6 +478,28 @@ const unsigned char wrong_sni_length[] = { 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74 }; +// Fragmented SNI +const unsigned char fragmentedSNI1[] = { + 0x16, 0x3, 0x1, 0x0, 0x6e, 0x1, 0x0, 0x0, + 0x6d, 0x3, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x2f, + 0x0, 0xff, 0x1, 0x0, 0x0, 0x40, 0x0, 0x0, + 0x0, 0xe, 0x0, 0xc, 0x0, 0x0, 0x9, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, + 0x0, 0xd, 0x0, 0xc, 0x0, 0xa, 0x6, 0x1, + 0x5, 0x1, 0x4, 0x1, 0x3, 0x1, 0x2, 0x1, + 0x0, 0x32, 0x0, 0x1a, 0x0, 0x18, 0x6, 0x1, + 0x5, 0x1, 0x4, 0x1, 0x3, 0x1, 0x2, 0x1, + 0x1, 0x1, 0x8, 0x4, 0x8, 0x5, 0x8, 0x6, + 0x8, 0x9, 0x8 +}; +const unsigned char fragmentedSNI2[] = { + 0x16, 0x3, 0x1, 0x0, 0x3, 0xa, 0x8, 0xb +}; + int main() { uint8_t *pkt = (uint8_t *)good_data_1; @@ -547,7 +569,7 @@ int main() { error = get_server_name(pkt, sizeof(bad_data_2), pkt, sn); assert(strcmp("localhost", sn) != 0); assert(strlen(sn) == 0); - assert(error == -31); + assert(error == -30); pkt = (uint8_t *)bad_data_3; memset(sn, 0, FQDN_LENGTH); @@ -563,7 +585,23 @@ int main() { error = get_server_name(pkt, sizeof(wrong_sni_length), pkt, sn); assert(strcmp("localhost", sn) != 0); assert(strlen(sn) == 0); - assert(error == -33); + assert(error == -30); + + pkt = (uint8_t *)fragmentedSNI2; + memset(sn, 0, FQDN_LENGTH); + *sn = 0; + error = get_server_name(pkt, sizeof(fragmentedSNI2), pkt, sn); + assert(strcmp("localhost", sn) != 0); + assert(strlen(sn) == 0); + assert(error == -6); + + pkt = (uint8_t *)fragmentedSNI1; + memset(sn, 0, FQDN_LENGTH); + *sn = 0; + error = get_server_name(pkt, sizeof(fragmentedSNI1), pkt, sn); + assert(strcmp("localhost", sn) == 0); + assert(strlen(sn) == 9); + assert(error == 9); return 0; }