diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 587e47c4b4..76a58d377c 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -3111,9 +3111,11 @@ int RsaFunctionCheckIn(const byte* in, word32 inLen, RsaKey* key, static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out, word32* outLen, int type, RsaKey* key, WC_RNG* rng, - int checkSmallCt) + int checkSmallCt, RsaPadding *padding) { int ret = 0; + const byte *in2 = NULL; + int in2Len = 0; (void)rng; (void)checkSmallCt; @@ -3127,7 +3129,25 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out, if (key->devId != INVALID_DEVID) #endif { - ret = wc_CryptoCb_Rsa(in, inLen, out, outLen, type, key, rng); +#ifndef WOLF_CRYPTO_CB_NOPAD + /* skip software padding - callback supports padding */ + ret = wc_CryptoCb_RsaPadding(in, inLen, out, outLen, type, key, rng, padding); +#else + /* perform software padding here - callback does not support padding */ + in2Len = inLen; + in2 = in; + if (type == RSA_PUBLIC_ENCRYPT || type == RSA_PRIVATE_ENCRYPT) { + in2Len = wc_RsaEncryptSize(key); + ret = wc_RsaPad_ex(in, inLen, out, in2Len, padding->pad_value, rng, padding->pad_type, + padding->hash, padding->mgf, padding->label, padding->labelSz, padding->saltLen, + mp_count_bits(&key->n), key->heap); + + in2 = out; + } + /* ensure callbacks use raw RSA */ + padding->pad_type = WC_RSA_NO_PAD; + ret = wc_CryptoCb_RsaPadding(in2, in2Len, out, outLen, type, key, rng, padding); +#endif #ifndef WOLF_CRYPTO_CB_ONLY_RSA if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) return ret; @@ -3147,12 +3167,24 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out, #else /* !WOLF_CRYPTO_CB_ONLY_RSA */ SAVE_VECTOR_REGISTERS(return _svr_ret;); + /* perform software padding here */ + in2Len = inLen; + in2 = in; + if (type == RSA_PUBLIC_ENCRYPT || type == RSA_PRIVATE_ENCRYPT) { + in2Len = wc_RsaEncryptSize(key); + ret = wc_RsaPad_ex(in, inLen, out, in2Len, padding->pad_value, rng, padding->pad_type, + padding->hash, padding->mgf, padding->label, padding->labelSz, padding->saltLen, + mp_count_bits(&key->n), key->heap); + + in2 = out; + } + #if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(TEST_UNPAD_CONSTANT_TIME) && \ !defined(NO_RSA_BOUNDS_CHECK) if (type == RSA_PRIVATE_DECRYPT && key->state == RSA_STATE_DECRYPT_EXPTMOD) { - ret = RsaFunctionCheckIn(in, inLen, key, checkSmallCt); + ret = RsaFunctionCheckIn(out, inLen, key, checkSmallCt); if (ret != 0) { RESTORE_VECTOR_REGISTERS(); return ret; @@ -3164,18 +3196,18 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out, #if defined(WOLFSSL_ASYNC_CRYPT) && defined(WC_ASYNC_ENABLE_RSA) if (key->asyncDev.marker == WOLFSSL_ASYNC_MARKER_RSA && key->n.raw.len > 0) { - ret = wc_RsaFunctionAsync(in, inLen, out, outLen, type, key, rng); + ret = wc_RsaFunctionAsync(in2, in2Len, out, outLen, type, key, rng); } else #endif #ifdef WC_RSA_NONBLOCK if (key->nb) { - ret = wc_RsaFunctionNonBlock(in, inLen, out, outLen, type, key); + ret = wc_RsaFunctionNonBlock(in2, in2Len, out, outLen, type, key); } else #endif { - ret = wc_RsaFunctionSync(in, inLen, out, outLen, type, key, rng); + ret = wc_RsaFunctionSync(in2, in2Len, out, outLen, type, key, rng); } RESTORE_VECTOR_REGISTERS(); @@ -3194,6 +3226,7 @@ static int wc_RsaFunction_ex(const byte* in, word32 inLen, byte* out, key->state = RSA_STATE_NONE; wc_RsaCleanup(key); } + (void)padding; return ret; #endif /* !WOLF_CRYPTO_CB_ONLY_RSA */ } @@ -3202,7 +3235,19 @@ int wc_RsaFunction(const byte* in, word32 inLen, byte* out, word32* outLen, int type, RsaKey* key, WC_RNG* rng) { /* Always check for ciphertext of 0 or 1. (Shouldn't for OAEP decrypt.) */ - return wc_RsaFunction_ex(in, inLen, out, outLen, type, key, rng, 1); + RsaPadding padding; + + XMEMSET(&padding, 0, sizeof(RsaPadding)); + padding.pad_type = WC_RSA_NO_PAD; + + return wc_RsaFunction_ex(in, inLen, out, outLen, type, key, rng, 1, &padding); +} + +int wc_RsaFunctionPad(const byte* in, word32 inLen, byte* out, + word32* outLen, int type, RsaKey* key, WC_RNG* rng, RsaPadding *padding) +{ + /* Always check for ciphertext of 0 or 1. (Shouldn't for OAEP decrypt.) */ + return wc_RsaFunction_ex(in, inLen, out, outLen, type, key, rng, 1, padding); } #ifndef WOLFSSL_RSA_VERIFY_ONLY @@ -3232,9 +3277,11 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out, byte* label, word32 labelSz, int saltLen, WC_RNG* rng) { + int ret = 0; int sz; int state; + RsaPadding padding; if (in == NULL || inLen == 0 || out == NULL || key == NULL) { return BAD_FUNC_ARG; @@ -3332,21 +3379,32 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out, #endif /* WOLFSSL_SE050 */ key->state = RSA_STATE_ENCRYPT_PAD; - ret = wc_RsaPad_ex(in, inLen, out, (word32)sz, pad_value, rng, pad_type, - hash, mgf, label, labelSz, saltLen, - mp_count_bits(&key->n), key->heap); - if (ret < 0) { - break; + if (pad_type == WC_RSA_NO_PAD) { + ret = wc_RsaPad_ex(in, inLen, out, (word32)sz, pad_value, rng, pad_type, + hash, mgf, label, labelSz, saltLen, + mp_count_bits(&key->n), key->heap); + if (ret < 0) { + break; + } } + XMEMSET(&padding, 0, sizeof(RsaPadding)); + padding.pad_value = pad_value; + padding.pad_type = pad_type; + padding.hash = hash; + padding.mgf = mgf; + padding.label = label; + padding.labelSz = labelSz; + padding.saltLen = saltLen; + key->state = RSA_STATE_ENCRYPT_EXPTMOD; FALL_THROUGH; case RSA_STATE_ENCRYPT_EXPTMOD: key->dataLen = outLen; - ret = wc_RsaFunction(out, (word32)sz, out, &key->dataLen, rsa_type, key, - rng); + ret = wc_RsaFunctionPad(in, (word32)inLen, out, &key->dataLen, rsa_type, key, + rng, &padding); if (ret >= 0 || ret == WC_NO_ERR_TRACE(WC_PENDING_E)) { key->state = RSA_STATE_ENCRYPT_RES; @@ -3408,8 +3466,10 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, byte* label, word32 labelSz, int saltLen, WC_RNG* rng) { + int ret = WC_NO_ERR_TRACE(RSA_WRONG_TYPE_E); byte* pad = NULL; + RsaPadding padding; if (in == NULL || inLen == 0 || out == NULL || key == NULL) { return BAD_FUNC_ARG; @@ -3520,14 +3580,25 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, FALL_THROUGH; case RSA_STATE_DECRYPT_EXPTMOD: + + XMEMSET(&padding, 0, sizeof(padding)); + padding.pad_type = pad_type; + padding.pad_value = pad_value; + padding.hash = hash; + padding.mgf = mgf; + padding.label = label; + padding.labelSz = labelSz; + padding.saltLen = saltLen; + padding.unpadded = 0; + #if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \ !defined(WOLFSSL_NO_MALLOC) ret = wc_RsaFunction_ex(key->data, inLen, key->data, &key->dataLen, rsa_type, key, rng, - pad_type != WC_RSA_OAEP_PAD); + pad_type != WC_RSA_OAEP_PAD, &padding); #else ret = wc_RsaFunction_ex(in, inLen, out, &key->dataLen, rsa_type, key, - rng, pad_type != WC_RSA_OAEP_PAD); + rng, pad_type != WC_RSA_OAEP_PAD, &padding); #endif if (ret >= 0 || ret == WC_NO_ERR_TRACE(WC_PENDING_E)) { @@ -3536,6 +3607,10 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, if (ret < 0) { break; } + if (padding.unpadded) { + ret = key->dataLen; + goto unpadded; + } FALL_THROUGH; @@ -3628,6 +3703,7 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, return ret; } + unpadded: key->state = RSA_STATE_NONE; wc_RsaCleanup(key);