From 283fc940bc05a326e9e5fff691353f71d4b46356 Mon Sep 17 00:00:00 2001 From: Joseph Cosentino Date: Fri, 2 Feb 2024 12:24:12 -0800 Subject: [PATCH] chore: revert back to nucleus trie --- .../auth/PermissionEvaluationUtils.java | 6 +- .../clientdevices/auth/util/WildcardTrie.java | 349 ------------------ 2 files changed, 3 insertions(+), 352 deletions(-) delete mode 100644 src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java b/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java index ad95b143c..11d360657 100644 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java @@ -5,11 +5,11 @@ package com.aws.greengrass.clientdevices.auth; +import com.aws.greengrass.authorization.WildcardTrie; import com.aws.greengrass.clientdevices.auth.configuration.GroupManager; import com.aws.greengrass.clientdevices.auth.configuration.Permission; import com.aws.greengrass.clientdevices.auth.exception.PolicyException; import com.aws.greengrass.clientdevices.auth.session.Session; -import com.aws.greengrass.clientdevices.auth.util.WildcardTrie; import com.aws.greengrass.logging.api.Logger; import com.aws.greengrass.logging.impl.LogManager; import com.aws.greengrass.util.Utils; @@ -37,7 +37,6 @@ public final class PermissionEvaluationUtils { private static final Pattern SERVICE_RESOURCE_PATTERN = Pattern.compile( String.format(SERVICE_RESOURCE_FORMAT, SERVICE_PATTERN_STRING, SERVICE_RESOURCE_TYPE_PATTERN_STRING, SERVICE_RESOURCE_NAME_PATTERN_STRING), Pattern.UNICODE_CHARACTER_CLASS); - private final WildcardTrie wildcardTrie = new WildcardTrie(); private final GroupManager groupManager; /** @@ -132,7 +131,8 @@ private boolean compareResource(Resource requestResource, String policyResource) return true; } - wildcardTrie.set(policyResource); + WildcardTrie wildcardTrie = new WildcardTrie(); + wildcardTrie.add(policyResource); return wildcardTrie.matchesStandard(requestResource.getResourceStr()); } diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java deleted file mode 100644 index 0c39ee19f..000000000 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package com.aws.greengrass.clientdevices.auth.util; - -import com.aws.greengrass.authorization.AuthorizationHandler.ResourceLookupPolicy; - -import java.util.HashMap; -import java.util.Map; - -/** - * Copied from nucleus with some customizations for performance. - */ -public class WildcardTrie { - protected static final String GLOB_WILDCARD = "*"; - protected static final String MQTT_MULTILEVEL_WILDCARD = "#"; - protected static final String MQTT_SINGLELEVEL_WILDCARD = "+"; - protected static final String MQTT_LEVEL_SEPARATOR = "/"; - protected static final char nullChar = '\0'; - protected static final char escapeChar = '$'; - protected static final char wildcardChar = GLOB_WILDCARD.charAt(0); - protected static final char multiLevelWildcardChar = MQTT_MULTILEVEL_WILDCARD.charAt(0); - protected static final char singleLevelWildcardChar = MQTT_SINGLELEVEL_WILDCARD.charAt(0); - protected static final char levelSeparatorChar = MQTT_LEVEL_SEPARATOR.charAt(0); - - private boolean isTerminal; - private boolean isTerminalLevel; - private boolean isWildcard; - private boolean isMQTTWildcard; - private boolean matchAll; - private final Map children = new HashMap<>(); - - private void clear() { - isTerminal = false; - isTerminalLevel = false; - isWildcard = false; - isMQTTWildcard = false; - matchAll = false; - children.clear(); - } - - public void set(String subject) { - clear(); - add(subject); - } - - /** - * Add allowed resources for a particular operation. - * - A new node is created for every occurrence of a wildcard (*, #, +). - * - Only nodes with valid usage of wildcards are marked with isWildcard or isMQTTWildcard. - * - Any other characters are grouped together to form a node. - * - Just a '*' or '#' creates a Node setting matchAll to true and would match all resources - * - * @param subject resource pattern - */ - public void add(String subject) { - if (subject == null) { - return; - } - if (subject.equals(GLOB_WILDCARD)) { - WildcardTrie initial = getOrInitializeChild(this, GLOB_WILDCARD); - initial.matchAll = true; - initial.isTerminal = true; - initial.isWildcard = true; - return; - } - if (subject.equals(MQTT_MULTILEVEL_WILDCARD)) { - WildcardTrie initial = getOrInitializeChild(this, MQTT_MULTILEVEL_WILDCARD); - initial.matchAll = true; - initial.isTerminal = true; - initial.isMQTTWildcard = true; - return; - } - if (subject.equals(MQTT_SINGLELEVEL_WILDCARD)) { - WildcardTrie initial = getOrInitializeChild(this, MQTT_SINGLELEVEL_WILDCARD); - initial.isTerminal = true; - initial.isMQTTWildcard = true; - return; - } - if (subject.startsWith(MQTT_SINGLELEVEL_WILDCARD + MQTT_LEVEL_SEPARATOR)) { - WildcardTrie initial = getOrInitializeChild(this, MQTT_SINGLELEVEL_WILDCARD); - initial.isMQTTWildcard = true; - initial.add(subject.substring(1), true); - } - - add(subject, true); - } - - @SuppressWarnings("PMD.AvoidDeeplyNestedIfStmts") - private WildcardTrie add(String subject, boolean isTerminal) { - if (subject.isEmpty()) { - this.isTerminal |= isTerminal; - return this; - } - int subjectLength = subject.length(); - WildcardTrie current = this; - StringBuilder sb = new StringBuilder(subjectLength); - for (int i = 0; i < subjectLength; i++) { - char currentChar = subject.charAt(i); - // Create separate Nodes for wildcards *, # and + - // Also tag them wildcard if its a valid usage - if (currentChar == wildcardChar) { - current = current.add(sb.toString(), false); - current = getOrInitializeChild(current, GLOB_WILDCARD); - current.isWildcard = true; - // If the string ends with *, then the wildcard is a terminal - if (i == subjectLength - 1) { - current.isTerminal = isTerminal; - return current; - } - return current.add(subject.substring(i + 1), true); - } - if (currentChar == multiLevelWildcardChar) { - WildcardTrie terminalLevel = current.add(sb.toString(), false); - current = getOrInitializeChild(terminalLevel, MQTT_MULTILEVEL_WILDCARD); - if (i == subjectLength - 1) { - current.isTerminal = true; - // check if # wildcard usage is valid - if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar) { - current.isMQTTWildcard = true; - current.matchAll = true; - terminalLevel.isTerminalLevel = true; - } - return current; - } - return current.add(subject.substring(i + 1), true); - } - if (currentChar == singleLevelWildcardChar) { - current = current.add(sb.toString(), false); - current = getOrInitializeChild(current, MQTT_SINGLELEVEL_WILDCARD); - if (i == subjectLength - 1) { - current.isTerminal = true; - // check if '+' wildcard usage is valid - // if it's used at the last level - if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar) { - current.isMQTTWildcard = true; - } - return current; - } - // check if '+' wildcard usage is valid - // if it's used in middle levels - if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar - && subject.charAt(i + 1) == levelSeparatorChar) { - current.isMQTTWildcard = true; - } - return current.add(subject.substring(i + 1), true); - } - if (currentChar == escapeChar) { - char actualChar = getActualChar(subject.substring(i)); - if (actualChar != nullChar) { - sb.append(actualChar); - i = i + 3; - continue; - } - } - sb.append(currentChar); - } - // Handle non-wildcard value - current = getOrInitializeChild(current, sb.toString()); - current.isTerminal |= isTerminal; - return current; - } - - /** - * The method tries to parse the given string using escape sequence ${c} (where c is a character to be escaped) - * and returns the character c if the pattern is matched. In any other scenario it returns null character ('\0') - * - * @param str string provided to get - */ - static char getActualChar(String str) { - if (str.length() < 4) { - return nullChar; - } - // Match the escape format ${c} - if (str.charAt(0) == escapeChar && str.charAt(1) == '{' && str.charAt(3) == '}') { - return str.charAt(2); - } - return nullChar; - } - - /** - * Match given string to the corresponding allowed resources trie. MQTT wildcards are not processed. - * - * @param str string to match. - */ - @SuppressWarnings({"PMD.UselessParentheses", "PMD.CollapsibleIfStatements"}) - public boolean matchesStandard(String str) { - if (str == null) { - return true; - } - if ((isWildcard && isTerminal) || (isTerminal && str.isEmpty())) { - return true; - } - - boolean hasMatch = false; - Map matchingChildren = new HashMap<>(); - for (Map.Entry e : children.entrySet()) { - // Succeed fast - if (hasMatch) { - return true; - } - String key = e.getKey(); - WildcardTrie value = e.getValue(); - - // Process * wildcards - if (value.isWildcard && key.equals(GLOB_WILDCARD)) { - hasMatch = value.matchesStandard(str); - continue; - } - - // Match normal characters - if (str.startsWith(key)) { - hasMatch = value.matchesStandard(str.substring(key.length())); - // Succeed fast - if (hasMatch) { - return true; - } - } - - // If I'm a wildcard, then I need to maybe chomp many characters to match my children - if (isWildcard) { - int foundChildIndex = str.indexOf(key); - int keyLength = key.length(); - while (foundChildIndex >= 0) { - matchingChildren.put(str.substring(foundChildIndex + keyLength), value); - foundChildIndex = str.indexOf(key, foundChildIndex + 1); - } - } - } - // Succeed fast - if (hasMatch) { - return true; - } - if (isWildcard && !matchingChildren.isEmpty()) { - return matchingChildren.entrySet().stream().anyMatch((e) -> e.getValue().matchesStandard(e.getKey())); - } - - return false; - } - - /** - * Match given string to the corresponding allowed resources trie. MQTT wildcards are processed only if - * its a valid usage, otherwise treated as normal characters. - * - * @param str string to match - */ - @SuppressWarnings({"PMD.UselessParentheses", "PMD.CollapsibleIfStatements"}) - public boolean matchesMQTT(String str) { - if (str == null) { - return true; - } - if ((isWildcard && isTerminal) || (isTerminal && str.isEmpty())) { - return true; - } - if (isMQTTWildcard) { - if (matchAll || (isTerminal && (str.indexOf(MQTT_LEVEL_SEPARATOR) == -1))) { - return true; - } - } - - boolean hasMatch = false; - Map matchingChildren = new HashMap<>(); - for (Map.Entry e : children.entrySet()) { - // Succeed fast - if (hasMatch) { - return true; - } - String key = e.getKey(); - WildcardTrie value = e.getValue(); - - // Process *, # and + wildcards (only process MQTT wildcards that have valid usages) - if ((value.isWildcard && key.equals(GLOB_WILDCARD)) - || (value.isMQTTWildcard && (key.equals(MQTT_SINGLELEVEL_WILDCARD) - || key.equals(MQTT_MULTILEVEL_WILDCARD)))) { - hasMatch = value.matchesMQTT(str); - continue; - } - - // Match normal characters - if (str.startsWith(key)) { - hasMatch = value.matchesMQTT(str.substring(key.length())); - // Succeed fast - if (hasMatch) { - return true; - } - } - - // Check if it's terminalLevel to allow matching of string without "/" in the end - // "abc/#" should match "abc". - // "abc/*xy/#" should match "abc/12xy" - String terminalKey = key.substring(0, key.length() - 1); - if (value.isTerminalLevel) { - if (str.equals(terminalKey)) { - return true; - } - if (str.endsWith(terminalKey)) { - key = terminalKey; - } - } - - int keyLength = key.length(); - // If I'm a wildcard, then I need to maybe chomp many characters to match my children - if (isWildcard) { - int foundChildIndex = str.indexOf(key); - while (foundChildIndex >= 0 && foundChildIndex < str.length()) { - matchingChildren.put(str.substring(foundChildIndex + keyLength), value); - foundChildIndex = str.indexOf(key, foundChildIndex + 1); - } - } - // If I'm a MQTT wildcard (specifically +, as # is already covered), - // then I need to maybe chomp many characters to match my children - if (isMQTTWildcard) { - int foundChildIndex = str.indexOf(key); - // Matched characters inside + should not contain a "/" - while (foundChildIndex >= 0 - && foundChildIndex < str.length() - && (str.substring(0,foundChildIndex).indexOf(MQTT_LEVEL_SEPARATOR) == -1)) { - matchingChildren.put(str.substring(foundChildIndex + keyLength), value); - foundChildIndex = str.indexOf(key, foundChildIndex + 1); - } - } - } - // Succeed fast - if (hasMatch) { - return true; - } - if ((isWildcard || isMQTTWildcard) && !matchingChildren.isEmpty()) { - return matchingChildren.entrySet().stream().anyMatch((e) -> e.getValue().matchesMQTT(e.getKey())); - } - - return false; - } - - public boolean matches(String str, ResourceLookupPolicy lookupPolicy) { - return lookupPolicy == ResourceLookupPolicy.MQTT_STYLE ? matchesMQTT(str) - : matchesStandard(str); - } - - private static WildcardTrie getOrInitializeChild(WildcardTrie trie, String key) { - WildcardTrie child = trie.children.get(key); - if (child == null) { - child = new WildcardTrie(); - trie.children.put(key, child); - } - return child; - } -}