Skip to content

Commit

Permalink
Merge pull request #7302 from dgarske/pk_psk
Browse files Browse the repository at this point in the history
Support for Public Key (PK) callbacks with PSK
  • Loading branch information
JacobBarthelmeh authored Mar 13, 2024
2 parents d2fd937 + 11303ab commit 1e054b9
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 137 deletions.
153 changes: 80 additions & 73 deletions src/internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -31464,23 +31464,13 @@ int SendClientKeyExchange(WOLFSSL* ssl)
case psk_kea:
{
byte* pms = ssl->arrays->preMasterSecret;
int cbret = (int)ssl->options.client_psk_cb(ssl,
ssl->arrays->psk_keySz = ssl->options.client_psk_cb(ssl,
ssl->arrays->server_hint, ssl->arrays->client_identity,
MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN);

if (cbret == 0 || cbret > MAX_PSK_KEY_LEN) {
if (cbret != USE_HW_PSK) {
ERROR_OUT(PSK_KEY_ERROR, exit_scke);
}
}

if (cbret == USE_HW_PSK) {
/* USE_HW_PSK indicates that the hardware has the PSK
* and generates the premaster secret. */
ssl->arrays->psk_keySz = 0;
}
else {
ssl->arrays->psk_keySz = (word32)cbret;
if (ssl->arrays->psk_keySz == 0 ||
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
ERROR_OUT(PSK_KEY_ERROR, exit_scke);
}

/* Ensure the buffer is null-terminated. */
Expand All @@ -31492,7 +31482,7 @@ int SendClientKeyExchange(WOLFSSL* ssl)
XMEMCPY(args->encSecret, ssl->arrays->client_identity,
args->encSz);
ssl->options.peerAuthGood = 1;
if (cbret != USE_HW_PSK) {
if ((int)ssl->arrays->psk_keySz > 0) {
/* CLIENT: Pre-shared Key for peer authentication. */

/* make psk pre master secret */
Expand All @@ -31508,8 +31498,8 @@ int SendClientKeyExchange(WOLFSSL* ssl)
ssl->arrays->preMasterSz = (ssl->arrays->psk_keySz * 2)
+ (2 * OPAQUE16_LEN);
ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->psk_keySz = 0; /* No further need */
}
ssl->arrays->psk_keySz = 0; /* No further need */
break;
}
#endif /* !NO_PSK */
Expand All @@ -31520,12 +31510,14 @@ int SendClientKeyExchange(WOLFSSL* ssl)
args->output = args->encSecret;

ssl->arrays->psk_keySz = ssl->options.client_psk_cb(ssl,
ssl->arrays->server_hint, ssl->arrays->client_identity,
MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN);
ssl->arrays->server_hint, ssl->arrays->client_identity,
MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN);
if (ssl->arrays->psk_keySz == 0 ||
ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
ERROR_OUT(PSK_KEY_ERROR, exit_scke);
}

ssl->arrays->client_identity[MAX_PSK_ID_LEN] = '\0'; /* null term */
esSz = (word32)XSTRLEN(ssl->arrays->client_identity);

Expand Down Expand Up @@ -31601,12 +31593,14 @@ int SendClientKeyExchange(WOLFSSL* ssl)

/* Send PSK client identity */
ssl->arrays->psk_keySz = ssl->options.client_psk_cb(ssl,
ssl->arrays->server_hint, ssl->arrays->client_identity,
MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN);
ssl->arrays->server_hint, ssl->arrays->client_identity,
MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN);
if (ssl->arrays->psk_keySz == 0 ||
ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
ERROR_OUT(PSK_KEY_ERROR, exit_scke);
}

ssl->arrays->client_identity[MAX_PSK_ID_LEN] = '\0'; /* null term */
esSz = (word32)XSTRLEN(ssl->arrays->client_identity);
if (esSz > MAX_PSK_ID_LEN) {
Expand All @@ -31626,7 +31620,7 @@ int SendClientKeyExchange(WOLFSSL* ssl)
args->length = MAX_ENCRYPT_SZ;

/* Create shared ECC key leaving room at the beginning
of buffer for size of shared key. */
* of buffer for size of shared key. */
ssl->arrays->preMasterSz = ENCRYPT_LEN - OPAQUE16_LEN;

#ifdef HAVE_CURVE25519
Expand Down Expand Up @@ -32017,13 +32011,15 @@ int SendClientKeyExchange(WOLFSSL* ssl)
pms += ssl->arrays->preMasterSz;

/* make psk pre master secret */
/* length of key + length 0s + length of key + key */
c16toa((word16)ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz +=
ssl->arrays->psk_keySz + OPAQUE16_LEN;
ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
if ((int)ssl->arrays->psk_keySz > 0) {
/* length of key + length 0s + length of key + key */
c16toa((word16)ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz +=
ssl->arrays->psk_keySz + OPAQUE16_LEN;
ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
}
ssl->arrays->psk_keySz = 0; /* No further need */
break;
}
Expand All @@ -32044,18 +32040,19 @@ int SendClientKeyExchange(WOLFSSL* ssl)
args->encSz += args->length + OPAQUE8_LEN;

/* Create pre master secret is the concatenation of
eccSize + eccSharedKey + pskSize + pskKey */
* eccSize + eccSharedKey + pskSize + pskKey */
c16toa((word16)ssl->arrays->preMasterSz, pms);
ssl->arrays->preMasterSz += OPAQUE16_LEN;
pms += ssl->arrays->preMasterSz;

c16toa((word16)ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz +=
ssl->arrays->psk_keySz + OPAQUE16_LEN;
if ((int)ssl->arrays->psk_keySz > 0) {
c16toa((word16)ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz += ssl->arrays->psk_keySz + OPAQUE16_LEN;

ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
}
ssl->arrays->psk_keySz = 0; /* No further need */
break;
}
Expand Down Expand Up @@ -38691,31 +38688,35 @@ static int DefTicketEncCb(WOLFSSL* ssl, byte key_name[WOLFSSL_TICKET_NAME_SZ],
MAX_PSK_KEY_LEN);

if (ssl->arrays->psk_keySz == 0 ||
ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
#if defined(WOLFSSL_EXTRA_ALERTS) || \
defined(WOLFSSL_PSK_IDENTITY_ALERT)
SendAlert(ssl, alert_fatal,
unknown_psk_identity);
#endif
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
#if defined(WOLFSSL_EXTRA_ALERTS) || \
defined(WOLFSSL_PSK_IDENTITY_ALERT)
SendAlert(ssl, alert_fatal,
unknown_psk_identity);
#endif
ERROR_OUT(PSK_KEY_ERROR, exit_dcke);
}
/* SERVER: Pre-shared Key for peer authentication. */
ssl->options.peerAuthGood = 1;

/* make psk pre master secret */
/* length of key + length 0s + length of key + key */
c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
if ((int)ssl->arrays->psk_keySz > 0) {
/* length of key + length 0s + length of key + key */
c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;

XMEMSET(pms, 0, ssl->arrays->psk_keySz);
pms += ssl->arrays->psk_keySz;
XMEMSET(pms, 0, ssl->arrays->psk_keySz);
pms += ssl->arrays->psk_keySz;

c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;

XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz =
(ssl->arrays->psk_keySz * 2) + (OPAQUE16_LEN * 2);
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz = (ssl->arrays->psk_keySz * 2) +
(OPAQUE16_LEN * 2);
}
ssl->arrays->psk_keySz = 0; /* no further need */
break;
}
#endif /* !NO_PSK */
Expand Down Expand Up @@ -39530,24 +39531,27 @@ static int DefTicketEncCb(WOLFSSL* ssl, byte key_name[WOLFSSL_TICKET_NAME_SZ],
MAX_PSK_KEY_LEN);

if (ssl->arrays->psk_keySz == 0 ||
ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
#if defined(WOLFSSL_EXTRA_ALERTS) || \
defined(WOLFSSL_PSK_IDENTITY_ALERT)
SendAlert(ssl, alert_fatal,
unknown_psk_identity);
#endif
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
#if defined(WOLFSSL_EXTRA_ALERTS) || \
defined(WOLFSSL_PSK_IDENTITY_ALERT)
SendAlert(ssl, alert_fatal,
unknown_psk_identity);
#endif
ERROR_OUT(PSK_KEY_ERROR, exit_dcke);
}
/* SERVER: Pre-shared Key for peer authentication. */
ssl->options.peerAuthGood = 1;

c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;
if ((int)ssl->arrays->psk_keySz > 0) {
c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;

XMEMCPY(pms, ssl->arrays->psk_key,
ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz += ssl->arrays->psk_keySz +
OPAQUE16_LEN;
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz += ssl->arrays->psk_keySz + OPAQUE16_LEN;
ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
}
ssl->arrays->psk_keySz = 0; /* no further need */
break;
}
#endif /* !NO_DH && !NO_PSK */
Expand All @@ -39573,18 +39577,21 @@ static int DefTicketEncCb(WOLFSSL* ssl, byte key_name[WOLFSSL_TICKET_NAME_SZ],
MAX_PSK_KEY_LEN);

if (ssl->arrays->psk_keySz == 0 ||
ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
ERROR_OUT(PSK_KEY_ERROR, exit_dcke);
}
/* SERVER: Pre-shared Key for peer authentication. */
ssl->options.peerAuthGood = 1;
if ((int)ssl->arrays->psk_keySz > 0) {
c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;

c16toa((word16) ssl->arrays->psk_keySz, pms);
pms += OPAQUE16_LEN;

XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz +=
ssl->arrays->psk_keySz + OPAQUE16_LEN;
XMEMCPY(pms, ssl->arrays->psk_key, ssl->arrays->psk_keySz);
ssl->arrays->preMasterSz += ssl->arrays->psk_keySz + OPAQUE16_LEN;
ForceZero(ssl->arrays->psk_key, ssl->arrays->psk_keySz);
}
ssl->arrays->psk_keySz = 0; /* no further need */
break;
}
#endif /* (HAVE_ECC || CURVE25519 || CURVE448) && !NO_PSK */
Expand Down
62 changes: 29 additions & 33 deletions src/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -13341,7 +13341,7 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
else
#endif
if (ssl->options.client_psk_cb != NULL ||
ssl->options.client_psk_tls13_cb != NULL) {
ssl->options.client_psk_tls13_cb != NULL) {
/* Default cipher suite. */
byte cipherSuite0 = TLS13_BYTE;
byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
Expand All @@ -13363,42 +13363,38 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
ssl->arrays->server_hint, ssl->arrays->client_identity,
MAX_PSK_ID_LEN, ssl->arrays->psk_key, MAX_PSK_KEY_LEN);
}
#if defined(OPENSSL_EXTRA)
/* OpenSSL treats 0 as a PSK key length of 0
* and meaning no PSK available.
*/
if (ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
return PSK_KEY_ERROR;
}
if (ssl->arrays->psk_keySz > 0) {
#else
if (ssl->arrays->psk_keySz == 0 ||
ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN) {
return PSK_KEY_ERROR;
if (
#ifndef OPENSSL_EXTRA
/* OpenSSL treats a PSK key length of 0
* to indicate no PSK available.
*/
ssl->arrays->psk_keySz == 0 ||
#endif
(ssl->arrays->psk_keySz > MAX_PSK_KEY_LEN &&
(int)ssl->arrays->psk_keySz != USE_HW_PSK)) {
ret = PSK_KEY_ERROR;
}
#endif
ssl->arrays->client_identity[MAX_PSK_ID_LEN] = '\0';

ssl->options.cipherSuite0 = cipherSuite0;
ssl->options.cipherSuite = cipherSuite;
(void)cipherSuiteFlags;
ret = SetCipherSpecs(ssl);
if (ret != 0)
return ret;
else {
ssl->arrays->client_identity[MAX_PSK_ID_LEN] = '\0';

ret = TLSX_PreSharedKey_Use(&ssl->extensions,
(byte*)ssl->arrays->client_identity,
(word16)XSTRLEN(ssl->arrays->client_identity),
0, ssl->specs.mac_algorithm,
cipherSuite0, cipherSuite, 0,
NULL, ssl->heap);
ssl->options.cipherSuite0 = cipherSuite0;
ssl->options.cipherSuite = cipherSuite;
(void)cipherSuiteFlags;
ret = SetCipherSpecs(ssl);
if (ret == 0) {
ret = TLSX_PreSharedKey_Use(
&ssl->extensions,
(byte*)ssl->arrays->client_identity,
(word16)XSTRLEN(ssl->arrays->client_identity),
0, ssl->specs.mac_algorithm,
cipherSuite0, cipherSuite, 0,
NULL, ssl->heap);
}
if (ret == 0)
usingPSK = 1;
}
if (ret != 0)
return ret;

usingPSK = 1;
#if defined(OPENSSL_EXTRA)
}
#endif
}
#endif /* !NO_PSK */
#if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK)
Expand Down
Loading

0 comments on commit 1e054b9

Please sign in to comment.