Skip to content

Commit

Permalink
wolfcrypt/src/rsa.c: add support for callback RSA PKCS padding
Browse files Browse the repository at this point in the history
- adds support for passing padding information to RSA callback for
  device-supported  RSA padding.
- compile-time gated by -DWOLF_CRYPTO_CB_NOPAD
  • Loading branch information
space88man committed Jul 15, 2024
1 parent 8c0a218 commit b1398d2
Showing 1 changed file with 92 additions and 16 deletions.
108 changes: 92 additions & 16 deletions wolfcrypt/src/rsa.c
Original file line number Diff line number Diff line change
Expand Up @@ -3111,9 +3111,11 @@ int RsaFunctionCheckIn(const byte* in, word32 inLen, RsaKey* key,

static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out,
word32* outLen, int type, RsaKey* key, WC_RNG* rng,
int checkSmallCt)
int checkSmallCt, RsaPadding *padding)
{
int ret = 0;
const byte *in2 = NULL;
int in2Len = 0;
(void)rng;
(void)checkSmallCt;

Expand All @@ -3127,7 +3129,25 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out,
if (key->devId != INVALID_DEVID)
#endif
{
ret = wc_CryptoCb_Rsa(in, inLen, out, outLen, type, key, rng);
#ifndef WOLF_CRYPTO_CB_NOPAD
/* skip software padding - callback supports padding */
ret = wc_CryptoCb_RsaPadding(in, inLen, out, outLen, type, key, rng, padding);
#else
/* perform software padding here - callback does not support padding */
in2Len = inLen;
in2 = in;
if (type == RSA_PUBLIC_ENCRYPT || type == RSA_PRIVATE_ENCRYPT) {
in2Len = wc_RsaEncryptSize(key);
ret = wc_RsaPad_ex(in, inLen, out, in2Len, padding->pad_value, rng, padding->pad_type,
padding->hash, padding->mgf, padding->label, padding->labelSz, padding->saltLen,
mp_count_bits(&key->n), key->heap);

in2 = out;
}
/* ensure callbacks use raw RSA */
padding->pad_type = WC_RSA_NO_PAD;
ret = wc_CryptoCb_RsaPadding(in2, in2Len, out, outLen, type, key, rng, padding);
#endif
#ifndef WOLF_CRYPTO_CB_ONLY_RSA
if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
return ret;
Expand All @@ -3147,12 +3167,24 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out,
#else /* !WOLF_CRYPTO_CB_ONLY_RSA */
SAVE_VECTOR_REGISTERS(return _svr_ret;);

/* perform software padding here */
in2Len = inLen;
in2 = in;
if (type == RSA_PUBLIC_ENCRYPT || type == RSA_PRIVATE_ENCRYPT) {
in2Len = wc_RsaEncryptSize(key);
ret = wc_RsaPad_ex(in, inLen, out, in2Len, padding->pad_value, rng, padding->pad_type,
padding->hash, padding->mgf, padding->label, padding->labelSz, padding->saltLen,
mp_count_bits(&key->n), key->heap);

in2 = out;
}

#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(TEST_UNPAD_CONSTANT_TIME) && \
!defined(NO_RSA_BOUNDS_CHECK)
if (type == RSA_PRIVATE_DECRYPT &&
key->state == RSA_STATE_DECRYPT_EXPTMOD) {

ret = RsaFunctionCheckIn(in, inLen, key, checkSmallCt);
ret = RsaFunctionCheckIn(out, inLen, key, checkSmallCt);
if (ret != 0) {
RESTORE_VECTOR_REGISTERS();
return ret;
Expand All @@ -3164,18 +3196,18 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out,
#if defined(WOLFSSL_ASYNC_CRYPT) && defined(WC_ASYNC_ENABLE_RSA)
if (key->asyncDev.marker == WOLFSSL_ASYNC_MARKER_RSA &&
key->n.raw.len > 0) {
ret = wc_RsaFunctionAsync(in, inLen, out, outLen, type, key, rng);
ret = wc_RsaFunctionAsync(in2, in2Len, out, outLen, type, key, rng);
}
else
#endif
#ifdef WC_RSA_NONBLOCK
if (key->nb) {
ret = wc_RsaFunctionNonBlock(in, inLen, out, outLen, type, key);
ret = wc_RsaFunctionNonBlock(in2, in2Len, out, outLen, type, key);
}
else
#endif
{
ret = wc_RsaFunctionSync(in, inLen, out, outLen, type, key, rng);
ret = wc_RsaFunctionSync(in2, in2Len, out, outLen, type, key, rng);
}

RESTORE_VECTOR_REGISTERS();
Expand All @@ -3194,6 +3226,7 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out,
key->state = RSA_STATE_NONE;
wc_RsaCleanup(key);
}
(void)padding;
return ret;
#endif /* !WOLF_CRYPTO_CB_ONLY_RSA */
}
Expand All @@ -3202,7 +3235,19 @@ int wc_RsaFunction(const byte* in, word32 inLen, byte* out,
word32* outLen, int type, RsaKey* key, WC_RNG* rng)
{
/* Always check for ciphertext of 0 or 1. (Shouldn't for OAEP decrypt.) */
return wc_RsaFunction_ex(in, inLen, out, outLen, type, key, rng, 1);
RsaPadding padding;

XMEMSET(&padding, 0, sizeof(RsaPadding));
padding.pad_type = WC_RSA_NO_PAD;

return wc_RsaFunction_ex(in, inLen, out, outLen, type, key, rng, 1, &padding);
}

int wc_RsaFunctionPad(const byte* in, word32 inLen, byte* out,
word32* outLen, int type, RsaKey* key, WC_RNG* rng, RsaPadding *padding)
{
/* Always check for ciphertext of 0 or 1. (Shouldn't for OAEP decrypt.) */
return wc_RsaFunction_ex(in, inLen, out, outLen, type, key, rng, 1, padding);
}

#ifndef WOLFSSL_RSA_VERIFY_ONLY
Expand Down Expand Up @@ -3232,9 +3277,11 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out,
byte* label, word32 labelSz, int saltLen,
WC_RNG* rng)
{

int ret = 0;
int sz;
int state;
RsaPadding padding;

if (in == NULL || inLen == 0 || out == NULL || key == NULL) {
return BAD_FUNC_ARG;
Expand Down Expand Up @@ -3332,21 +3379,32 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out,
#endif /* WOLFSSL_SE050 */

key->state = RSA_STATE_ENCRYPT_PAD;
ret = wc_RsaPad_ex(in, inLen, out, (word32)sz, pad_value, rng, pad_type,
hash, mgf, label, labelSz, saltLen,
mp_count_bits(&key->n), key->heap);
if (ret < 0) {
break;
if (pad_type == WC_RSA_NO_PAD) {
ret = wc_RsaPad_ex(in, inLen, out, (word32)sz, pad_value, rng, pad_type,
hash, mgf, label, labelSz, saltLen,
mp_count_bits(&key->n), key->heap);
if (ret < 0) {
break;
}
}

XMEMSET(&padding, 0, sizeof(RsaPadding));
padding.pad_value = pad_value;
padding.pad_type = pad_type;
padding.hash = hash;
padding.mgf = mgf;
padding.label = label;
padding.labelSz = labelSz;
padding.saltLen = saltLen;

key->state = RSA_STATE_ENCRYPT_EXPTMOD;
FALL_THROUGH;

case RSA_STATE_ENCRYPT_EXPTMOD:

key->dataLen = outLen;
ret = wc_RsaFunction(out, (word32)sz, out, &key->dataLen, rsa_type, key,
rng);
ret = wc_RsaFunctionPad(in, (word32)inLen, out, &key->dataLen, rsa_type, key,
rng, &padding);

if (ret >= 0 || ret == WC_NO_ERR_TRACE(WC_PENDING_E)) {
key->state = RSA_STATE_ENCRYPT_RES;
Expand Down Expand Up @@ -3408,8 +3466,10 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
byte* label, word32 labelSz, int saltLen,
WC_RNG* rng)
{

int ret = WC_NO_ERR_TRACE(RSA_WRONG_TYPE_E);
byte* pad = NULL;
RsaPadding padding;

if (in == NULL || inLen == 0 || out == NULL || key == NULL) {
return BAD_FUNC_ARG;
Expand Down Expand Up @@ -3520,14 +3580,25 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
FALL_THROUGH;

case RSA_STATE_DECRYPT_EXPTMOD:

XMEMSET(&padding, 0, sizeof(padding));
padding.pad_type = pad_type;
padding.pad_value = pad_value;
padding.hash = hash;
padding.mgf = mgf;
padding.label = label;
padding.labelSz = labelSz;
padding.saltLen = saltLen;
padding.unpadded = 0;

#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
ret = wc_RsaFunction_ex(key->data, inLen, key->data, &key->dataLen,
rsa_type, key, rng,
pad_type != WC_RSA_OAEP_PAD);
pad_type != WC_RSA_OAEP_PAD, &padding);
#else
ret = wc_RsaFunction_ex(in, inLen, out, &key->dataLen, rsa_type, key,
rng, pad_type != WC_RSA_OAEP_PAD);
rng, pad_type != WC_RSA_OAEP_PAD, &padding);
#endif

if (ret >= 0 || ret == WC_NO_ERR_TRACE(WC_PENDING_E)) {
Expand All @@ -3536,6 +3607,10 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
if (ret < 0) {
break;
}
if (padding.unpadded) {
ret = key->dataLen;
goto unpadded;
}

FALL_THROUGH;

Expand Down Expand Up @@ -3628,6 +3703,7 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
return ret;
}

unpadded:
key->state = RSA_STATE_NONE;
wc_RsaCleanup(key);

Expand Down

0 comments on commit b1398d2

Please sign in to comment.