diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java b/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java index 5bd2f8ef2..374398456 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/ConcatKDF.java @@ -15,6 +15,7 @@ */ package io.jsonwebtoken.impl.security; +import io.jsonwebtoken.impl.lang.Bytes; import io.jsonwebtoken.impl.lang.CheckedFunction; import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.security.SecurityException; @@ -114,43 +115,68 @@ public SecretKey deriveKey(final byte[] Z, final long derivedKeyBitLength, final long inputBitLength = bitLength(counter) + bitLength(Z) + bitLength(OtherInfo); Assert.state(inputBitLength <= MAX_HASH_INPUT_BIT_LENGTH, "Hash input is too large."); - byte[] derivedKeyBytes = jca().withMessageDigest(new CheckedFunction() { - @Override - public byte[] apply(MessageDigest md) throws Exception { - - final ByteArrayOutputStream stream = new ByteArrayOutputStream((int) derivedKeyByteLength); - - // Section 5.8.1.1, Process step #5. We depart from Java idioms here by starting iteration index at 1 - // (instead of 0) and continue to <= reps (instead of < reps) to match the NIST publication algorithm - // notation convention (so variables like Ki and kLast below match the NIST definitions). - for (long i = 1; i <= reps; i++) { - - // Section 5.8.1.1, Process step #5.1: - md.update(counter); - md.update(Z); - md.update(OtherInfo); - byte[] Ki = md.digest(); - - // Section 5.8.1.1, Process step #5.2: - increment(counter); - - // Section 5.8.1.1, Process step #6: - if (i == reps && kLastPartial) { - long leftmostBitLength = derivedKeyBitLength % hashBitLength; - int leftmostByteLength = (int) (leftmostBitLength / Byte.SIZE); - byte[] kLast = new byte[leftmostByteLength]; - System.arraycopy(Ki, 0, kLast, 0, kLast.length); - Ki = kLast; + final ClearableByteArrayOutputStream stream = new ClearableByteArrayOutputStream((int) derivedKeyByteLength); + byte[] derivedKeyBytes = EMPTY; + + try { + derivedKeyBytes = jca().withMessageDigest(new CheckedFunction() { + @Override + public byte[] apply(MessageDigest md) throws Exception { + + // Section 5.8.1.1, Process step #5. We depart from Java idioms here by starting iteration index at 1 + // (instead of 0) and continue to <= reps (instead of < reps) to match the NIST publication algorithm + // notation convention (so variables like Ki and kLast below match the NIST definitions). + for (long i = 1; i <= reps; i++) { + + // Section 5.8.1.1, Process step #5.1: + md.update(counter); + md.update(Z); + md.update(OtherInfo); + byte[] Ki = md.digest(); + + // Section 5.8.1.1, Process step #5.2: + increment(counter); + + // Section 5.8.1.1, Process step #6: + if (i == reps && kLastPartial) { + long leftmostBitLength = derivedKeyBitLength % hashBitLength; + int leftmostByteLength = (int) (leftmostBitLength / Byte.SIZE); + byte[] kLast = new byte[leftmostByteLength]; + System.arraycopy(Ki, 0, kLast, 0, kLast.length); + Ki = kLast; + } + + stream.write(Ki); } - stream.write(Ki); + // Section 5.8.1.1, Process step #7: + return stream.toByteArray(); } + }); + return new SecretKeySpec(derivedKeyBytes, AesAlgorithm.KEY_ALG_NAME); + } finally { + // key cleanup + Bytes.clear(derivedKeyBytes); // SecretKeySpec clones this, so we can clear it out safely + Bytes.clear(counter); + stream.reset(); + // we don't clear out 'Z', since that is the responsibility of the caller + } + } - // Section 5.8.1.1, Process step #7: - return stream.toByteArray(); - } - }); + /** + * Calling ByteArrayOutputStream.toByteArray returns a copy of the bytes, so this class allows us to completely + * zero-out the buffer upon reset (whereas BAOS just resets the position marker, leaving the bytes in tact) + */ + private static class ClearableByteArrayOutputStream extends ByteArrayOutputStream { - return new SecretKeySpec(derivedKeyBytes, AesAlgorithm.KEY_ALG_NAME); + public ClearableByteArrayOutputStream(int size) { + super(size); + } + + @Override + public synchronized void reset() { + super.reset(); + Bytes.clear(buf); // zero out internal buffer + } } } diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java b/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java index 8e7711318..97aec5cfd 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/EcdhKeyAlgorithm.java @@ -139,7 +139,11 @@ private SecretKey deriveKey(KeyRequest request, PublicKey publicKey, PrivateK byte[] apv = request.getHeader().getAgreementPartyVInfo(); byte[] OtherInfo = createOtherInfo(requiredCekBitLen, AlgorithmID, apu, apv); byte[] Z = generateZ(request, publicKey, privateKey); - return CONCAT_KDF.deriveKey(Z, requiredCekBitLen, OtherInfo); + try { + return CONCAT_KDF.deriveKey(Z, requiredCekBitLen, OtherInfo); + } finally { + Bytes.clear(Z); + } } @Override diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/HmacAesAeadAlgorithm.java b/impl/src/main/java/io/jsonwebtoken/impl/security/HmacAesAeadAlgorithm.java index 3f82040c1..69cf0e018 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/HmacAesAeadAlgorithm.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/HmacAesAeadAlgorithm.java @@ -96,7 +96,13 @@ public void encrypt(final AeadRequest req, final AeadResult res) { int halfCount = compositeKeyBytes.length / 2; // https://tools.ietf.org/html/rfc7518#section-5.2 byte[] macKeyBytes = Arrays.copyOfRange(compositeKeyBytes, 0, halfCount); byte[] encKeyBytes = Arrays.copyOfRange(compositeKeyBytes, halfCount, compositeKeyBytes.length); - final SecretKey encryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME); + final SecretKey encryptionKey; + try { + encryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME); + } finally { + Bytes.clear(encKeyBytes); + Bytes.clear(compositeKeyBytes); + } final InputStream plaintext = Assert.notNull(req.getPayload(), "Request content (plaintext) InputStream cannot be null."); @@ -121,9 +127,13 @@ public Object apply(Cipher cipher) throws Exception { byte[] aadBytes = aad == null ? Bytes.EMPTY : Streams.bytes(aad, "Unable to read AAD bytes."); - byte[] tag = sign(aadBytes, iv, Streams.of(copy.toByteArray()), macKeyBytes); - - res.setTag(tag).setIv(iv); + byte[] tag; + try { + tag = sign(aadBytes, iv, Streams.of(copy.toByteArray()), macKeyBytes); + res.setTag(tag).setIv(iv); + } finally { + Bytes.clear(macKeyBytes); + } } private byte[] sign(byte[] aad, byte[] iv, InputStream ciphertext, byte[] macKeyBytes) { @@ -162,7 +172,13 @@ public void decrypt(final DecryptAeadRequest req, final OutputStream plaintext) int halfCount = compositeKeyBytes.length / 2; // https://tools.ietf.org/html/rfc7518#section-5.2 byte[] macKeyBytes = Arrays.copyOfRange(compositeKeyBytes, 0, halfCount); byte[] encKeyBytes = Arrays.copyOfRange(compositeKeyBytes, halfCount, compositeKeyBytes.length); - final SecretKey decryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME); + final SecretKey decryptionKey; + try { + decryptionKey = new SecretKeySpec(encKeyBytes, KEY_ALG_NAME); + } finally { + Bytes.clear(encKeyBytes); + Bytes.clear(compositeKeyBytes); + } InputStream in = Assert.notNull(req.getPayload(), "Decryption request content (ciphertext) InputStream cannot be null."); @@ -174,7 +190,12 @@ public void decrypt(final DecryptAeadRequest req, final OutputStream plaintext) // Assert that the aad + iv + ciphertext provided, when signed, equals the tag provided, // thereby verifying none of it has been tampered with: byte[] aadBytes = aad == null ? Bytes.EMPTY : Streams.bytes(aad, "Unable to read AAD bytes."); - byte[] digest = sign(aadBytes, iv, in, macKeyBytes); + byte[] digest; + try { + digest = sign(aadBytes, iv, in, macKeyBytes); + } finally { + Bytes.clear(macKeyBytes); + } if (!MessageDigest.isEqual(digest, tag)) { //constant time comparison to avoid side-channel attacks String msg = "Ciphertext decryption failed: Authentication tag verification failed."; throw new SignatureException(msg);