diff --git a/pom.xml b/pom.xml index 7ce3d9f..40aa004 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,11 @@ metrics-core 3.0.2 + + org.keycloak + keycloak-common + 18.0.0 + diff --git a/src/main/java/com/tarento/commenthub/authentication/model/KeyData.java b/src/main/java/com/tarento/commenthub/authentication/model/KeyData.java new file mode 100644 index 0000000..dc69bc6 --- /dev/null +++ b/src/main/java/com/tarento/commenthub/authentication/model/KeyData.java @@ -0,0 +1,31 @@ +package com.tarento.commenthub.authentication.model; + +import java.security.PublicKey; + +public class KeyData { + + private String keyId; + private PublicKey publicKey; + + public KeyData(String keyId, PublicKey publicKey) { + this.keyId = keyId; + this.publicKey = publicKey; + } + + public String getKeyId() { + return keyId; + } + + public void setKeyId(String keyId) { + this.keyId = keyId; + } + + public PublicKey getPublicKey() { + return publicKey; + } + + public void setPublicKey(PublicKey publicKey) { + this.publicKey = publicKey; + } + +} diff --git a/src/main/java/com/tarento/commenthub/authentication/util/AccessTokenValidator.java b/src/main/java/com/tarento/commenthub/authentication/util/AccessTokenValidator.java new file mode 100644 index 0000000..11e6b4b --- /dev/null +++ b/src/main/java/com/tarento/commenthub/authentication/util/AccessTokenValidator.java @@ -0,0 +1,163 @@ +package com.tarento.commenthub.authentication.util; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.tarento.commenthub.constant.Constants; +import com.tarento.commenthub.transactional.utils.PropertiesCache; +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import org.apache.commons.lang3.StringUtils; +import org.keycloak.common.util.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +@Component +public class AccessTokenValidator { + + @Autowired + KeyManager keyManager; + + private static Logger logger = LoggerFactory.getLogger(AccessTokenValidator.class.getName()); + private static final ObjectMapper mapper = new ObjectMapper(); + private static PropertiesCache cache = PropertiesCache.getInstance(); + private static final String REALM_URL = + cache.getProperty(Constants.SSO_URL) + "realms/" + cache.getProperty(Constants.SSO_REALM); + + + /** + * Validates the provided JWT token. + * + * @param token The JWT token to be validated. + * @return A map containing the token body if the token is valid and not expired, otherwise an + * empty map. + */ + private Map validateToken(String token) { + try { + // Split the token into its elements + String[] tokenElements = token.split("\\."); + // Check if the token has at least three elements + if (tokenElements.length < 3) { + throw new IllegalArgumentException("Invalid token format"); + } + // Extract header, body, and signature from token elements + String header = tokenElements[0]; + String body = tokenElements[1]; + String signature = tokenElements[2]; + // Concatenate header and body to form the payload + String payload = header + Constants.DOT_SEPARATOR + body; + // Parse header data from base64 encoded header + Map headerData = mapper.readValue(new String(decodeFromBase64(header)), + new TypeReference>() { + }); + String keyId = headerData.get("kid").toString(); + // Verify the token signature + boolean isValid = CryptoUtil.verifyRSASign(payload, decodeFromBase64(signature), + keyManager.getPublicKey(keyId).getPublicKey(), Constants.SHA_256_WITH_RSA); + // If token signature is valid, parse token body and check expiration + if (isValid) { + Map tokenBody = mapper.readValue(new String(decodeFromBase64(body)), + new TypeReference>() { + }); + if (isExpired((Integer) tokenBody.get("exp"))) { + logger.error("Token expired: {}", token); + return Collections.emptyMap(); + } + return tokenBody; + } + } catch (IOException | IllegalArgumentException e) { + logger.error("Error validating token: {}", e.getMessage()); + } catch (Exception ex) { + logger.error("Unexpected error validating token: {}", ex.getMessage()); + } + return Collections.emptyMap(); + } + + + /** + * Verifies the user token and extracts the user ID from it. + * + * @param token The user token to be verified. + * @return The user ID extracted from the token, or UNAUTHORIZED if verification fails or an + * exception occurs. + */ + public String verifyUserToken(String token) { + // Initialize user ID to UNAUTHORIZED + String userId = Constants.UNAUTHORIZED_USER; + try { + // Validate the token and obtain its payload + Map payload = validateToken(token); + // Check if payload is not empty and issuer is valid + if (!payload.isEmpty() && checkIss((String) payload.get("iss"))) { + // Extract user ID from payload + userId = (String) payload.get(Constants.SUB); + // If user ID is not blank, extract the actual user ID + if (StringUtils.isNotBlank(userId)) { + userId = userId.substring(userId.lastIndexOf(":") + 1); + } + } + } catch (Exception ex) { + logger.error("Exception in verifyUserAccessToken: verify ", ex); + } + return userId; + } + + /** + * Checks if the issuer of the token matches the predefined realm URL. + * + * @param iss The issuer extracted from the token. + * @return true if the issuer matches the realm URL, false otherwise. + */ + private boolean checkIss(String iss) { + // Check if the realm URL is blank or if the issuer does not match the realm URL + if (StringUtils.isBlank(REALM_URL) || !REALM_URL.equalsIgnoreCase(iss)) { + logger.warn("Issuer does not match the expected realm URL. Issuer: {}, Expected: {}", iss, + REALM_URL); + return false; + } + logger.info("Issuer validation successful. Issuer: {}", iss); + return true; + } + + + private boolean isExpired(Integer expiration) { + return (Time.currentTime() > expiration); + } + + private byte[] decodeFromBase64(String data) { + return Base64Util.decode(data, 11); + } + + /** + * Fetches the user ID from the provided access token. + * + * @param accessToken The access token from which to fetch the user ID. + * @return The user ID fetched from the access token, or null if the token is invalid or an + * exception occurs. + */ + public String fetchUserIdFromAccessToken(String accessToken) { + // Initialize clientAccessTokenId to null + String clientAccessTokenId = null; + // Check if the accessToken is not null + if (accessToken != null) { + try { + // Verify the access token to fetch the user ID + clientAccessTokenId = verifyUserToken(accessToken); + // If the user ID is UNAUTHORIZED, set it to null + if (Constants.UNAUTHORIZED_USER.equalsIgnoreCase(clientAccessTokenId)) { + clientAccessTokenId = null; + } + } catch (Exception ex) { + String errMsg = + "Exception occurred while fetching the userid from the access token. Exception: " + + ex.getMessage(); + logger.error(errMsg, ex); + clientAccessTokenId = null; + } + } + return clientAccessTokenId; + } + +} diff --git a/src/main/java/com/tarento/commenthub/authentication/util/Base64Util.java b/src/main/java/com/tarento/commenthub/authentication/util/Base64Util.java new file mode 100644 index 0000000..bc55c17 --- /dev/null +++ b/src/main/java/com/tarento/commenthub/authentication/util/Base64Util.java @@ -0,0 +1,703 @@ +package com.tarento.commenthub.authentication.util; + +import java.io.UnsupportedEncodingException; + +public class Base64Util { + + public static final int DEFAULT = 0; + + /** + * Encoder flag bit to omit the padding '=' characters at the end of the output (if any). + */ + public static final int NO_PADDING = 1; + + /** + * Encoder flag bit to omit all line terminators (i.e., the output will be on one long line). + */ + public static final int NO_WRAP = 2; + + /** + * Encoder flag bit to indicate lines should be terminated with a CRLF pair instead of just an LF. + * Has no effect if {@code NO_WRAP} is specified as well. + */ + public static final int CRLF = 4; + + /** + * Encoder/decoder flag bit to indicate using the "URL and filename safe" variant of Base64 (see + * RFC 3548 section 4) where {@code -} and {@code _} are used in place of {@code +} and + * {@code /}. + */ + public static final int URL_SAFE = 8; + + /** + * Flag to pass to {Base64OutputStream} to indicate that it should not close the output stream it + * is wrapping when it itself is closed. + */ + public static final int NO_CLOSE = 16; + + // -------------------------------------------------------- + // shared code + // -------------------------------------------------------- + + private Base64Util() { + } // don't instantiate + + // -------------------------------------------------------- + // decoding + // -------------------------------------------------------- + + /** + * Decode the Base64-encoded data in input and return the data in a new byte array. + *

+ *

The padding '=' characters at the end are considered optional, but + * if any are present, there must be the correct number of them. + * + * @param str the input String to decode, which is converted to bytes using the default charset + * @param flags controls certain features of the decoded output. Pass {@code DEFAULT} to decode + * standard Base64. + * @throws IllegalArgumentException if the input contains incorrect padding + */ + public static byte[] decode(String str, int flags) { + return decode(str.getBytes(), flags); + } + + /** + * Decode the Base64-encoded data in input and return the data in a new byte array. + *

+ *

The padding '=' characters at the end are considered optional, but + * if any are present, there must be the correct number of them. + * + * @param input the input array to decode + * @param flags controls certain features of the decoded output. Pass {@code DEFAULT} to decode + * standard Base64. + * @throws IllegalArgumentException if the input contains incorrect padding + */ + public static byte[] decode(byte[] input, int flags) { + return decode(input, 0, input.length, flags); + } + + /** + * Decode the Base64-encoded data in input and return the data in a new byte array. + *

+ *

The padding '=' characters at the end are considered optional, but + * if any are present, there must be the correct number of them. + * + * @param input the data to decode + * @param offset the position within the input array at which to start + * @param len the number of bytes of input to decode + * @param flags controls certain features of the decoded output. Pass {@code DEFAULT} to decode + * standard Base64. + * @throws IllegalArgumentException if the input contains incorrect padding + */ + public static byte[] decode(byte[] input, int offset, int len, int flags) { + // Allocate space for the most data the input could represent. + // (It could contain less if it contains whitespace, etc.) + Decoder decoder = new Decoder(flags, new byte[len * 3 / 4]); + + if (!decoder.process(input, offset, len, true)) { + throw new IllegalArgumentException("bad base-64"); + } + + // Maybe we got lucky and allocated exactly enough output space. + if (decoder.op == decoder.output.length) { + return decoder.output; + } + + // Need to shorten the array, so allocate a new one of the + // right size and copy. + byte[] temp = new byte[decoder.op]; + System.arraycopy(decoder.output, 0, temp, 0, decoder.op); + return temp; + } + + /** + * Base64-encode the given data and return a newly allocated String with the result. + * + * @param input the data to encode + * @param flags controls certain features of the encoded output. Passing {@code DEFAULT} results + * in output that adheres to RFC 2045. + */ + public static String encodeToString(byte[] input, int flags) { + try { + return new String(encode(input, flags), "US-ASCII"); + } catch (UnsupportedEncodingException e) { + // US-ASCII is guaranteed to be available. + throw new AssertionError(e); + } + } + + // -------------------------------------------------------- + // encoding + // -------------------------------------------------------- + + /** + * Base64-encode the given data and return a newly allocated String with the result. + * + * @param input the data to encode + * @param offset the position within the input array at which to start + * @param len the number of bytes of input to encode + * @param flags controls certain features of the encoded output. Passing {@code DEFAULT} results + * in output that adheres to RFC 2045. + */ + public static String encodeToString(byte[] input, int offset, int len, int flags) { + try { + return new String(encode(input, offset, len, flags), "US-ASCII"); + } catch (UnsupportedEncodingException e) { + // US-ASCII is guaranteed to be available. + throw new AssertionError(e); + } + } + + /** + * Base64-encode the given data and return a newly allocated byte[] with the result. + * + * @param input the data to encode + * @param flags controls certain features of the encoded output. Passing {@code DEFAULT} results + * in output that adheres to RFC 2045. + */ + public static byte[] encode(byte[] input, int flags) { + return encode(input, 0, input.length, flags); + } + + /** + * Base64-encode the given data and return a newly allocated byte[] with the result. + * + * @param input the data to encode + * @param offset the position within the input array at which to start + * @param len the number of bytes of input to encode + * @param flags controls certain features of the encoded output. Passing {@code DEFAULT} results + * in output that adheres to RFC 2045. + */ + public static byte[] encode(byte[] input, int offset, int len, int flags) { + Encoder encoder = new Encoder(flags, null); + + // Compute the exact length of the array we will produce. + int output_len = len / 3 * 4; + + // Account for the tail of the data and the padding bytes, if any. + if (encoder.do_padding) { + if (len % 3 > 0) { + output_len += 4; + } + } else { + switch (len % 3) { + case 0: + break; + case 1: + output_len += 2; + break; + case 2: + output_len += 3; + break; + } + } + + // Account for the newlines, if any. + if (encoder.do_newline && len > 0) { + output_len += (((len - 1) / (3 * Encoder.LINE_GROUPS)) + 1) * + (encoder.do_cr ? 2 : 1); + } + + encoder.output = new byte[output_len]; + encoder.process(input, offset, len, true); + + assert encoder.op == output_len; + + return encoder.output; + } + + /* package */ static abstract class Coder { + + public byte[] output; + public int op; + + /** + * Encode/decode another block of input data. this.output is provided by the caller, and must + * be big enough to hold all the coded data. On exit, this.opwill be set to the length of the + * coded data. + * + * @param finish true if this is the final call to process for this object. Will finalize the + * coder state and include any final bytes in the output. + * @return true if the input so far is good; false if some error has been detected in the input + * stream.. + */ + public abstract boolean process(byte[] input, int offset, int len, boolean finish); + + /** + * @return the maximum number of bytes a call to process() could produce for the given number of + * input bytes. This may be an overestimate. + */ + public abstract int maxOutputSize(int len); + } + + /* package */ static class Decoder extends Coder { + + /** + * Lookup table for turning bytes into their position in the Base64 alphabet. + */ + private static final int DECODE[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -2, -1, -1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + }; + + /** + * Decode lookup table for the "web safe" variant (RFC 3548 sec. 4) where - and _ replace + and + * /. + */ + private static final int DECODE_WEBSAFE[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -2, -1, -1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, 63, + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + }; + + /** + * Non-data values in the DECODE arrays. + */ + private static final int SKIP = -1; + private static final int EQUALS = -2; + final private int[] alphabet; + /** + * States 0-3 are reading through the next input tuple. State 4 is having read one '=' and + * expecting exactly one more. State 5 is expecting no more data or padding characters in the + * input. State 6 is the error state; an error has been detected in the input and no future + * input can "fix" it. + */ + private int state; // state number (0 to 6) + private int value; + + public Decoder(int flags, byte[] output) { + this.output = output; + + alphabet = ((flags & URL_SAFE) == 0) ? DECODE : DECODE_WEBSAFE; + state = 0; + value = 0; + } + + /** + * @return an overestimate for the number of bytes {@code len} bytes could decode to. + */ + public int maxOutputSize(int len) { + return len * 3 / 4 + 10; + } + + /** + * Decode another block of input data. + * + * @return true if the state machine is still healthy. false if bad base-64 data has been + * detected in the input stream. + */ + public boolean process(byte[] input, int offset, int len, boolean finish) { + if (this.state == 6) { + return false; + } + + int p = offset; + len += offset; + + // Using local variables makes the decoder about 12% + // faster than if we manipulate the member variables in + // the loop. (Even alphabet makes a measurable + // difference, which is somewhat surprising to me since + // the member variable is final.) + int state = this.state; + int value = this.value; + int op = 0; + final byte[] output = this.output; + final int[] alphabet = this.alphabet; + + while (p < len) { + // Try the fast path: we're starting a new tuple and the + // next four bytes of the input stream are all data + // bytes. This corresponds to going through states + // 0-1-2-3-0. We expect to use this method for most of + // the data. + // + // If any of the next four bytes of input are non-data + // (whitespace, etc.), value will end up negative. (All + // the non-data values in decode are small negative + // numbers, so shifting any of them up and or'ing them + // together will result in a value with its top bit set.) + // + // You can remove this whole block and the output should + // be the same, just slower. + if (state == 0) { + while (p + 4 <= len && + (value = ((alphabet[input[p] & 0xff] << 18) | + (alphabet[input[p + 1] & 0xff] << 12) | + (alphabet[input[p + 2] & 0xff] << 6) | + (alphabet[input[p + 3] & 0xff]))) >= 0) { + output[op + 2] = (byte) value; + output[op + 1] = (byte) (value >> 8); + output[op] = (byte) (value >> 16); + op += 3; + p += 4; + } + if (p >= len) { + break; + } + } + + // The fast path isn't available -- either we've read a + // partial tuple, or the next four input bytes aren't all + // data, or whatever. Fall back to the slower state + // machine implementation. + + int d = alphabet[input[p++] & 0xff]; + + switch (state) { + case 0: + if (d >= 0) { + value = d; + ++state; + } else if (d != SKIP) { + this.state = 6; + return false; + } + break; + + case 1: + if (d >= 0) { + value = (value << 6) | d; + ++state; + } else if (d != SKIP) { + this.state = 6; + return false; + } + break; + + case 2: + if (d >= 0) { + value = (value << 6) | d; + ++state; + } else if (d == EQUALS) { + // Emit the last (partial) output tuple; + // expect exactly one more padding character. + output[op++] = (byte) (value >> 4); + state = 4; + } else if (d != SKIP) { + this.state = 6; + return false; + } + break; + + case 3: + if (d >= 0) { + // Emit the output triple and return to state 0. + value = (value << 6) | d; + output[op + 2] = (byte) value; + output[op + 1] = (byte) (value >> 8); + output[op] = (byte) (value >> 16); + op += 3; + state = 0; + } else if (d == EQUALS) { + // Emit the last (partial) output tuple; + // expect no further data or padding characters. + output[op + 1] = (byte) (value >> 2); + output[op] = (byte) (value >> 10); + op += 2; + state = 5; + } else if (d != SKIP) { + this.state = 6; + return false; + } + break; + + case 4: + if (d == EQUALS) { + ++state; + } else if (d != SKIP) { + this.state = 6; + return false; + } + break; + + case 5: + if (d != SKIP) { + this.state = 6; + return false; + } + break; + } + } + + if (!finish) { + // We're out of input, but a future call could provide + // more. + this.state = state; + this.value = value; + this.op = op; + return true; + } + + // Done reading input. Now figure out where we are left in + // the state machine and finish up. + + switch (state) { + case 0: + // Output length is a multiple of three. Fine. + break; + case 1: + // Read one extra input byte, which isn't enough to + // make another output byte. Illegal. + this.state = 6; + return false; + case 2: + // Read two extra input bytes, enough to emit 1 more + // output byte. Fine. + output[op++] = (byte) (value >> 4); + break; + case 3: + // Read three extra input bytes, enough to emit 2 more + // output bytes. Fine. + output[op++] = (byte) (value >> 10); + output[op++] = (byte) (value >> 2); + break; + case 4: + // Read one padding '=' when we expected 2. Illegal. + this.state = 6; + return false; + case 5: + // Read all the padding '='s we expected and no more. + // Fine. + break; + } + + this.state = state; + this.op = op; + return true; + } + } + + /* package */ static class Encoder extends Coder { + + /** + * Emit a new line every this many output tuples. Corresponds to a 76-character line length + * (the maximum allowable according to + * RFC 2045). + */ + public static final int LINE_GROUPS = 19; + + /** + * Lookup table for turning Base64 alphabet positions (6 bits) into output bytes. + */ + private static final byte ENCODE[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/', + }; + + /** + * Lookup table for turning Base64 alphabet positions (6 bits) into output bytes. + */ + private static final byte ENCODE_WEBSAFE[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_', + }; + final public boolean do_padding; + final public boolean do_newline; + final public boolean do_cr; + final private byte[] tail; + final private byte[] alphabet; + /* package */ int tailLen; + private int count; + + public Encoder(int flags, byte[] output) { + this.output = output; + + do_padding = (flags & NO_PADDING) == 0; + do_newline = (flags & NO_WRAP) == 0; + do_cr = (flags & CRLF) != 0; + alphabet = ((flags & URL_SAFE) == 0) ? ENCODE : ENCODE_WEBSAFE; + + tail = new byte[2]; + tailLen = 0; + + count = do_newline ? LINE_GROUPS : -1; + } + + /** + * @return an overestimate for the number of bytes {@code len} bytes could encode to. + */ + public int maxOutputSize(int len) { + return len * 8 / 5 + 10; + } + + public boolean process(byte[] input, int offset, int len, boolean finish) { + // Using local variables makes the encoder about 9% faster. + final byte[] alphabet = this.alphabet; + final byte[] output = this.output; + int op = 0; + int count = this.count; + + int p = offset; + len += offset; + int v = -1; + + // First we need to concatenate the tail of the previous call + // with any input bytes available now and see if we can empty + // the tail. + + switch (tailLen) { + case 0: + // There was no tail. + break; + + case 1: + if (p + 2 <= len) { + // A 1-byte tail with at least 2 bytes of + // input available now. + v = ((tail[0] & 0xff) << 16) | + ((input[p++] & 0xff) << 8) | + (input[p++] & 0xff); + tailLen = 0; + } + ; + break; + + case 2: + if (p + 1 <= len) { + // A 2-byte tail with at least 1 byte of input. + v = ((tail[0] & 0xff) << 16) | + ((tail[1] & 0xff) << 8) | + (input[p++] & 0xff); + tailLen = 0; + } + break; + } + + if (v != -1) { + output[op++] = alphabet[(v >> 18) & 0x3f]; + output[op++] = alphabet[(v >> 12) & 0x3f]; + output[op++] = alphabet[(v >> 6) & 0x3f]; + output[op++] = alphabet[v & 0x3f]; + if (--count == 0) { + if (do_cr) { + output[op++] = '\r'; + } + output[op++] = '\n'; + count = LINE_GROUPS; + } + } + + // At this point either there is no tail, or there are fewer + // than 3 bytes of input available. + + // The main loop, turning 3 input bytes into 4 output bytes on + // each iteration. + while (p + 3 <= len) { + v = ((input[p] & 0xff) << 16) | + ((input[p + 1] & 0xff) << 8) | + (input[p + 2] & 0xff); + output[op] = alphabet[(v >> 18) & 0x3f]; + output[op + 1] = alphabet[(v >> 12) & 0x3f]; + output[op + 2] = alphabet[(v >> 6) & 0x3f]; + output[op + 3] = alphabet[v & 0x3f]; + p += 3; + op += 4; + if (--count == 0) { + if (do_cr) { + output[op++] = '\r'; + } + output[op++] = '\n'; + count = LINE_GROUPS; + } + } + + if (finish) { + // Finish up the tail of the input. Note that we need to + // consume any bytes in tail before any bytes + // remaining in input; there should be at most two bytes + // total. + + if (p - tailLen == len - 1) { + int t = 0; + v = ((tailLen > 0 ? tail[t++] : input[p++]) & 0xff) << 4; + tailLen -= t; + output[op++] = alphabet[(v >> 6) & 0x3f]; + output[op++] = alphabet[v & 0x3f]; + if (do_padding) { + output[op++] = '='; + output[op++] = '='; + } + if (do_newline) { + if (do_cr) { + output[op++] = '\r'; + } + output[op++] = '\n'; + } + } else if (p - tailLen == len - 2) { + int t = 0; + v = (((tailLen > 1 ? tail[t++] : input[p++]) & 0xff) << 10) | + (((tailLen > 0 ? tail[t++] : input[p++]) & 0xff) << 2); + tailLen -= t; + output[op++] = alphabet[(v >> 12) & 0x3f]; + output[op++] = alphabet[(v >> 6) & 0x3f]; + output[op++] = alphabet[v & 0x3f]; + if (do_padding) { + output[op++] = '='; + } + if (do_newline) { + if (do_cr) { + output[op++] = '\r'; + } + output[op++] = '\n'; + } + } else if (do_newline && op > 0 && count != LINE_GROUPS) { + if (do_cr) { + output[op++] = '\r'; + } + output[op++] = '\n'; + } + + assert tailLen == 0; + assert p == len; + } else { + // Save the leftovers in tail to be consumed on the next + // call to encodeInternal. + + if (p == len - 1) { + tail[tailLen++] = input[p]; + } else if (p == len - 2) { + tail[tailLen++] = input[p]; + tail[tailLen++] = input[p + 1]; + } + } + + this.op = op; + this.count = count; + + return true; + } + } + +} diff --git a/src/main/java/com/tarento/commenthub/authentication/util/CryptoUtil.java b/src/main/java/com/tarento/commenthub/authentication/util/CryptoUtil.java new file mode 100644 index 0000000..1314eb8 --- /dev/null +++ b/src/main/java/com/tarento/commenthub/authentication/util/CryptoUtil.java @@ -0,0 +1,47 @@ +package com.tarento.commenthub.authentication.util; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.PublicKey; +import java.security.Signature; +import java.security.SignatureException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class CryptoUtil { + + private static final Charset US_ASCII = StandardCharsets.US_ASCII; + private static final Logger logger = LoggerFactory.getLogger(CryptoUtil.class.getName()); + + private CryptoUtil() { + } + + /** + * Verifies an RSA signature using the provided payload, signature, public key, and algorithm. + * + * @param payLoad The payload to be verified. + * @param signature The signature to be verified. + * @param key The public key used for verification. + * @param algorithm The algorithm used for verification. + * @return true if the signature is valid, false otherwise. + */ + public static boolean verifyRSASign(String payLoad, byte[] signature, PublicKey key, + String algorithm) { + Signature sign; + try { + // Initialize a Signature instance with the provided algorithm + sign = Signature.getInstance(algorithm); + // Initialize the Signature instance with the public key for verification + sign.initVerify(key); + // Update the Signature instance with the payload bytes + sign.update(payLoad.getBytes(US_ASCII)); + return sign.verify(signature); + } catch (NoSuchAlgorithmException | InvalidKeyException | SignatureException e) { + logger.error("An error occurred during RSA signature verification: {}", e.getMessage(), e); + return false; + } + } + +} diff --git a/src/main/java/com/tarento/commenthub/authentication/util/KeyManager.java b/src/main/java/com/tarento/commenthub/authentication/util/KeyManager.java new file mode 100644 index 0000000..f74fa90 --- /dev/null +++ b/src/main/java/com/tarento/commenthub/authentication/util/KeyManager.java @@ -0,0 +1,81 @@ +package com.tarento.commenthub.authentication.util; + +import com.tarento.commenthub.authentication.model.KeyData; +import com.tarento.commenthub.constant.Constants; +import com.tarento.commenthub.transactional.utils.PropertiesCache; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyFactory; +import java.security.PublicKey; +import java.security.spec.X509EncodedKeySpec; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import javax.annotation.PostConstruct; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +@Component +public class KeyManager { + + private static final Logger logger = LoggerFactory.getLogger(KeyManager.class.getName()); + private static final PropertiesCache propertiesCache = PropertiesCache.getInstance(); + + private static final Map keyMap = new HashMap<>(); + + @PostConstruct + public void init() { + // Read the content of the file and load it as a PublicKey + String basePath = propertiesCache.getProperty(Constants.ACCESS_TOKEN_PUBLICKEY_BASEPATH); + try (Stream walk = Files.walk(Paths.get(basePath))) { + List result = + walk.filter(Files::isRegularFile).map(Path::toString).collect(Collectors.toList()); + result.forEach(file -> { + try { + Path path = Paths.get(file); + List lines = Files.readAllLines(path, StandardCharsets.UTF_8); + String content = String.join("", lines); + KeyData keyData = new KeyData(path.getFileName().toString(), loadPublicKey(content)); + // Store the KeyData object in the keyMap + keyMap.put(path.getFileName().toString(), keyData); + } catch (Exception e) { + logger.error("KeyManager:init: exception in reading public keys ", e); + } + }); + } catch (Exception e) { + logger.error("KeyManager:init: exception in loading publickeys ", e); + } + } + + public KeyData getPublicKey(String keyId) { + return keyMap.get(keyId); + } + + + /** + * Loads a public key from a string representation. + * + * @param key The string representation of the public key + * @return The loaded public key + * @throws Exception If there's an error during the loading process + */ + public static PublicKey loadPublicKey(String key) throws Exception { + String publicKey = new String(key.getBytes(), StandardCharsets.UTF_8); + // Remove header and footer from the key string + publicKey = publicKey.replaceAll("(-+BEGIN PUBLIC KEY-+)", ""); + publicKey = publicKey.replaceAll("(-+END PUBLIC KEY-+)", ""); + publicKey = publicKey.replaceAll("[\\r\\n]+", ""); + // Decode the key string from Base64 + byte[] keyBytes = Base64Util.decode(publicKey.getBytes("UTF-8"), Base64Util.DEFAULT); + // Convert the key bytes to a PublicKey object + X509EncodedKeySpec x509publicKey = new X509EncodedKeySpec(keyBytes); + KeyFactory kf = KeyFactory.getInstance("RSA"); + return kf.generatePublic(x509publicKey); + } + +} diff --git a/src/main/java/com/tarento/commenthub/constant/Constants.java b/src/main/java/com/tarento/commenthub/constant/Constants.java index a34bb3d..0285684 100644 --- a/src/main/java/com/tarento/commenthub/constant/Constants.java +++ b/src/main/java/com/tarento/commenthub/constant/Constants.java @@ -3,6 +3,7 @@ public class Constants { + private Constants() { } @@ -94,6 +95,18 @@ private Constants() { public static final String CREATED_DATE = "createdDate"; public static final String OFFSET = "offset"; public static final String LIMIT = "limit"; - - + public static final String X_AUTH_TOKEN = "x-authenticated-user-token"; + public static final String SSO_URL = "sso.url"; + public static final String SSO_REALM = "sso.realm"; + public static final String DOT_SEPARATOR = "."; + public static final String UNAUTHORIZED_USER = "Unauthorized"; + public static final String SUB = "sub"; + public static final String SHA_256_WITH_RSA = "SHA256withRSA"; + public static final String ACCESS_TOKEN_PUBLICKEY_BASEPATH = "accesstoken.publickey.basepath"; + public static final String INVALID_USER = "Invalid user"; + public static final String NOT_FOUND = "Not found"; + public static final String NOT_ACTIVE_STATUS = "Only active coments can be reported"; + public static final String REPORTED_BY = "reportedBy"; + public static final String NOT_SUSPENDED_STATUS = "Only reported coments can be deleted by admin"; + public static final String DELETED_BY = "deletedBy"; } \ No newline at end of file diff --git a/src/main/java/com/tarento/commenthub/controller/CommentController.java b/src/main/java/com/tarento/commenthub/controller/CommentController.java index b05413a..83cc6ba 100644 --- a/src/main/java/com/tarento/commenthub/controller/CommentController.java +++ b/src/main/java/com/tarento/commenthub/controller/CommentController.java @@ -24,6 +24,7 @@ import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; @@ -73,14 +74,15 @@ public Comment deleteComment( @PathVariable String commentId, @RequestParam(name = "entityType") String entityType, @RequestParam(name = "entityId") String entityId, - @RequestParam(name = "workflow") String workflow) { + @RequestParam(name = "workflow") String workflow, + @RequestHeader(Constants.X_AUTH_TOKEN) String token) { CommentTreeIdentifierDTO commentTreeIdentifierDTO = new CommentTreeIdentifierDTO(); commentTreeIdentifierDTO.setEntityType(entityType); commentTreeIdentifierDTO.setEntityId(entityId); commentTreeIdentifierDTO.setWorkflow(workflow); - return commentService.deleteCommentById(commentId, commentTreeIdentifierDTO); + return commentService.deleteCommentById(commentId, commentTreeIdentifierDTO, token); } @PostMapping("/v1/setStatusToResolved") @@ -129,4 +131,24 @@ public ResponseEntity search(@RequestBody List commentIds) { return new ResponseEntity<>(response, response.getResponseCode()); } + @PostMapping("/report") + public ResponseEntity report(@RequestBody Map request, + @RequestHeader(Constants.X_AUTH_TOKEN) String token) { + ApiResponse response = commentService.reportComment(request, token); + if (response.getResponseCode().equals(HttpStatus.NOT_FOUND) && response.getResult().isEmpty()) { + return new ResponseEntity<>(response, HttpStatus.OK); + } + return new ResponseEntity<>(response, response.getResponseCode()); + } + + @PostMapping("/delete/reported") + public ResponseEntity delete(@RequestBody Map request, + @RequestHeader(Constants.X_AUTH_TOKEN) String token) { + ApiResponse response = commentService.deleteReportedComments(request, token); + if (response.getResponseCode().equals(HttpStatus.NOT_FOUND) && response.getResult().isEmpty()) { + return new ResponseEntity<>(response, HttpStatus.OK); + } + return new ResponseEntity<>(response, response.getResponseCode()); + } + } diff --git a/src/main/java/com/tarento/commenthub/service/CommentService.java b/src/main/java/com/tarento/commenthub/service/CommentService.java index 054eee2..e47ab35 100644 --- a/src/main/java/com/tarento/commenthub/service/CommentService.java +++ b/src/main/java/com/tarento/commenthub/service/CommentService.java @@ -22,7 +22,8 @@ public interface CommentService { CommentsResoponseDTO getComments(CommentTreeIdentifierDTO commentTreeIdentifierDTO); - Comment deleteCommentById(String commentId, CommentTreeIdentifierDTO commentTreeIdentifierDTO); + Comment deleteCommentById(String commentId, CommentTreeIdentifierDTO commentTreeIdentifierDTO, + String token); ApiResponse likeComment(Map likePayload); @@ -32,4 +33,8 @@ public interface CommentService { ApiResponse paginatedComment(SearchCriteria searchCriteria); ApiResponse listOfComments(List commentIds); + + ApiResponse reportComment(Map request, String token); + + ApiResponse deleteReportedComments(Map request, String token); } diff --git a/src/main/java/com/tarento/commenthub/service/impl/CommentServiceImpl.java b/src/main/java/com/tarento/commenthub/service/impl/CommentServiceImpl.java index 5a55938..d2165cf 100644 --- a/src/main/java/com/tarento/commenthub/service/impl/CommentServiceImpl.java +++ b/src/main/java/com/tarento/commenthub/service/impl/CommentServiceImpl.java @@ -14,6 +14,7 @@ import com.networknt.schema.JsonSchema; import com.networknt.schema.JsonSchemaFactory; import com.networknt.schema.ValidationMessage; +import com.tarento.commenthub.authentication.util.AccessTokenValidator; import com.tarento.commenthub.constant.Constants; import com.tarento.commenthub.dto.CommentTreeIdentifierDTO; import com.tarento.commenthub.dto.CommentsResoponseDTO; @@ -88,6 +89,9 @@ public class CommentServiceImpl implements CommentService { @Autowired private CommentTreeRepository commentTreeRepository; + @Autowired + private AccessTokenValidator accessTokenValidator; + @Override public ResponseDTO addFirstCommentToCreateTree(JsonNode payload) { validatePayload(Constants.ADD_FIRST_COMMENT_PAYLOAD_VALIDATION_FILE, payload); @@ -138,18 +142,23 @@ public ResponseDTO updateExistingComment(JsonNode paylaod) { throw new CommentException( Constants.ERROR, "To update an existing comment, please provide a valid commentId."); } - log.info("commentId: " + paylaod.get(Constants.COMMENT_ID).asText()); Optional optComment = commentRepository.findById(paylaod.get(Constants.COMMENT_ID).asText()); - if (!optComment.isPresent() || !optComment.get().getStatus().equalsIgnoreCase(Status.ACTIVE.name())) { throw new CommentException( Constants.ERROR, "The requested comment was not found or has been deleted."); } + if (!paylaod.get(Constants.COMMENT_DATA).get(Constants.COMMENT_SOURCE).get(Constants.USER_ID) + .asText() + .equalsIgnoreCase(optComment.get().getCommentData().get(Constants.COMMENT_SOURCE) + .get(Constants.USER_ID).asText())) { + throw new CommentException( + Constants.ERROR, "No access to edit the comment"); + } Comment commentToBeUpdated = optComment.get(); commentToBeUpdated.setCommentData(paylaod.get(Constants.COMMENT_DATA)); @@ -195,13 +204,24 @@ public CommentsResoponseDTO getComments(CommentTreeIdentifierDTO commentTreeIden @Override public Comment deleteCommentById( - String commentId, CommentTreeIdentifierDTO commentTreeIdentifierDTO) { + String commentId, CommentTreeIdentifierDTO commentTreeIdentifierDTO, String token) { log.info("CommentServiceImpl::deleteCommentById: Deleting comment with ID: {}", commentId); + String userId = accessTokenValidator.verifyUserToken(token); + if (StringUtils.isBlank(userId) || userId.equalsIgnoreCase(Constants.UNAUTHORIZED_USER)) { + throw new CommentException(Constants.ERROR, "Not a valid user"); + } Optional fetchedComment = commentRepository.findById(commentId); if (!fetchedComment.isPresent()) { throw new CommentException(Constants.ERROR, "No such comment found"); } Comment comment = fetchedComment.get(); + if (!userId + .equalsIgnoreCase(comment.getCommentData().get(Constants.COMMENT_SOURCE) + .get(Constants.USER_ID).asText())) { + throw new CommentException( + Constants.ERROR, "No access to edit the comment"); + + } if (!comment.getStatus().equalsIgnoreCase(Status.ACTIVE.name())) { throw new CommentException( Constants.ERROR, "You are trying to delete an already deleted comment"); @@ -479,6 +499,80 @@ public ApiResponse listOfComments(List commentIds) { return response; } + @Override + public ApiResponse reportComment(Map request, String token) { + log.info("CommentServiceImpl:reportComment::inside the method"); + ApiResponse response = new ApiResponse(); + response.setResponseCode(HttpStatus.OK); + String userId = accessTokenValidator.verifyUserToken(token); + if (StringUtils.isBlank(userId) || userId.equalsIgnoreCase(Constants.UNAUTHORIZED_USER)) { + return returnErrorMsg(Constants.INVALID_USER, HttpStatus.BAD_REQUEST, response); + } + String error = validateReportCommentPayload(request); + if (StringUtils.isNotBlank(error)) { + return returnErrorMsg(error, HttpStatus.BAD_REQUEST, response); + } + Optional fetchedComment = commentRepository.findById( + (String) request.get(Constants.COMMENT_ID)); + if (!fetchedComment.isPresent()) { + return returnErrorMsg(Constants.NOT_FOUND, HttpStatus.NOT_FOUND, response); + } + Comment comment = fetchedComment.get(); + if (!comment.getStatus().equalsIgnoreCase(Status.ACTIVE.name())) { + return returnErrorMsg(Constants.NOT_ACTIVE_STATUS, HttpStatus.NOT_FOUND, response); + } + ObjectNode commentData = (ObjectNode) comment.getCommentData(); + commentData.put(Constants.REPORTED_BY, userId); + comment.setStatus(Status.SUSPENDED.name().toLowerCase()); + comment = commentRepository.save(comment); + response.setResult(objectMapper.convertValue(comment, Map.class)); + return response; + } + + @Override + public ApiResponse deleteReportedComments(Map request, String token) { + log.info("CommentServiceImpl:reportComment::inside the method"); + ApiResponse response = new ApiResponse(); + response.setResponseCode(HttpStatus.OK); + String userId = accessTokenValidator.verifyUserToken(token); + if (StringUtils.isBlank(userId) || userId.equalsIgnoreCase(Constants.UNAUTHORIZED_USER)) { + return returnErrorMsg(Constants.INVALID_USER, HttpStatus.BAD_REQUEST, response); + } + String error = validateReportCommentPayload(request); + if (StringUtils.isNotBlank(error)) { + return returnErrorMsg(error, HttpStatus.BAD_REQUEST, response); + } + Optional fetchedComment = commentRepository.findById( + (String) request.get(Constants.COMMENT_ID)); + if (!fetchedComment.isPresent()) { + return returnErrorMsg(Constants.NOT_FOUND, HttpStatus.NOT_FOUND, response); + } + Comment comment = fetchedComment.get(); + if (!comment.getStatus().equalsIgnoreCase(Status.SUSPENDED.name())) { + return returnErrorMsg(Constants.NOT_SUSPENDED_STATUS, HttpStatus.NOT_FOUND, response); + } + ObjectNode commentData = (ObjectNode) comment.getCommentData(); + commentData.put(Constants.DELETED_BY, userId); + comment.setStatus(Status.INACTIVE.name().toLowerCase()); + comment = commentRepository.save(comment); + response.setResult(objectMapper.convertValue(comment, Map.class)); + return response; + } + + private String validateReportCommentPayload(Map request) { + StringBuffer str = new StringBuffer(); + List errList = new ArrayList<>(); + + if (request.containsKey(Constants.COMMENT_ID) && + StringUtils.isBlank((String) request.get(Constants.COMMENT_ID))){ + errList.add(Constants.COMMENT_TREE_ID); + } + if (!errList.isEmpty()) { + str.append("Failed Due To Missing Params - ").append(errList).append("."); + } + return str.toString(); + } + private ApiResponse returnErrorMsg(String error, HttpStatus type, ApiResponse response){ response.setResponseCode(type); response.getParams().setErr(error); diff --git a/src/main/java/com/tarento/commenthub/utility/Status.java b/src/main/java/com/tarento/commenthub/utility/Status.java index 4e407e5..3e6503b 100644 --- a/src/main/java/com/tarento/commenthub/utility/Status.java +++ b/src/main/java/com/tarento/commenthub/utility/Status.java @@ -2,5 +2,6 @@ public enum Status { ACTIVE, - INACTIVE + INACTIVE, + SUSPENDED } \ No newline at end of file diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 3e8f022..ce9f390 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -40,4 +40,8 @@ spring.data.cassandra.local-datacenter=DC1 cassandra.config.host=localhost default.page.size=20 -default.offset=0 \ No newline at end of file +default.offset=0 + +accesstoken.publickey.basepath= +sso.url= +sso.realm= \ No newline at end of file