diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java b/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java index be32f6661..8c68d929c 100644 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java @@ -137,14 +137,16 @@ private boolean compareResource(Resource requestResource, String policyResource) if (Objects.equals(requestResource.getResourceStr(), policyResource)) { return true; } - WildcardTrie trie = new WildcardTrie(); - trie.add(policyResource); - return trie.matches(requestResource.getResourceStr(), matchSingleCharacterWildcard()); + return new WildcardTrie(wildcardOpts()) + .withPattern(policyResource) + .matches(requestResource.getResourceStr()); } - private boolean matchSingleCharacterWildcard() { + private WildcardTrie.MatchOptions wildcardOpts() { CDAConfiguration config = cdaConfiguration; - return config != null && config.isMatchSingleCharacterWildcard(); + return WildcardTrie.MatchOptions.builder() + .useSingleCharWildcard(config != null && config.isMatchSingleCharacterWildcard()) + .build(); } private Operation parseOperation(String operationStr) throws PolicyException { diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java index 663250a1f..da3ae031f 100644 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java @@ -7,151 +7,183 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import lombok.Builder; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.Value; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Supplier; +@RequiredArgsConstructor public class WildcardTrie { - private static final String GLOB_WILDCARD = "*"; - private static final String SINGLE_CHAR_WILDCARD = "?"; - private final Map children = new DefaultHashMap<>(WildcardTrie::new); + private Node root; - private boolean isTerminal; - private boolean isGlobWildcard; - private boolean isSingleCharWildcard; + private final MatchOptions opts; - public void add(String subject) { - add(subject, true); + private static String cleanPattern(@NonNull String s) { + // for example "abc***def" can be reduced to "abc*def" + return s.replaceAll(String.format("\\%s+", WildcardType.GLOB.val), WildcardType.GLOB.val); } - private WildcardTrie add(String subject, boolean isTerminal) { - if (subject == null || subject.isEmpty()) { - this.isTerminal |= isTerminal; - return this; - } - StringBuilder currPrefix = new StringBuilder(subject.length()); - for (int i = 0; i < subject.length(); i++) { - char c = subject.charAt(i); - if (c == GLOB_WILDCARD.charAt(0)) { - return addGlobWildcard(subject, currPrefix.toString(), isTerminal); - } - if (c == SINGLE_CHAR_WILDCARD.charAt(0)) { - return addSingleCharWildcard(subject, currPrefix.toString(), isTerminal); + public WildcardTrie withPattern(@NonNull String s) { + root = new Node(); + withPattern(root, cleanPattern(s)); + return this; + } + + private Node withPattern(@NonNull Node n, @NonNull String s) { + StringBuilder token = new StringBuilder(s.length()); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if (isWildcard(s.charAt(i))) { + WildcardType type = WildcardType.from(c); + Node node = n.children.get(type.val()); + node.wildcardType = type; + if (i == s.length() - 1) { + // we've reached the last token + return node; + } + return withPattern(node, s.substring(token.length() + 2)); + } else { + token.append(c); } - currPrefix.append(c); } - WildcardTrie node = children.get(currPrefix.toString()); - node.isTerminal |= isTerminal; - return node; + // use remaining (non-wildcard) chars as last token + if (token.length() > 0) { + return n.children.get(token.toString()); + } else { + return n; + } } - private WildcardTrie addGlobWildcard(String subject, String currPrefix, boolean isTerminal) { - WildcardTrie node = this; - node = node.add(currPrefix, false); - node = node.children.get(GLOB_WILDCARD); - node.isGlobWildcard = true; - // wildcard at end of subject is terminal - if (subject.length() - currPrefix.length() == 1) { - node.isTerminal = isTerminal; - return node; - } - return node.add(subject.substring(currPrefix.length() + 2), true); + private boolean isWildcard(char c) { + WildcardType type = WildcardType.from(c); + if (type == null) { + return false; + } + if (type == WildcardType.SINGLE) { + return opts.useSingleCharWildcard; + } + return true; } - private WildcardTrie addSingleCharWildcard(String subject, String currPrefix, boolean isTerminal) { - WildcardTrie node = this; - node = node.add(currPrefix, false); - node = node.children.get(SINGLE_CHAR_WILDCARD); - node.isSingleCharWildcard = true; - // wildcard at end of subject is terminal - if (subject.length() - currPrefix.length() == 1) { - node.isTerminal = isTerminal; - return node; - } - return node.add(subject.substring(currPrefix.length() + 1), true); + public boolean matches(@NonNull String s) { + if (root == null) { + return s.isEmpty(); + } + return matches(root, s); } - public boolean matches(String s) { - return matches(s, true); + private boolean matches(@NonNull Node n, String s) { + if (n.isTerminal()) { + if (n.isWildcard()) { + switch (n.wildcardType) { + case SINGLE: + return s.length() == 1; + case GLOB: + return true; + default: + throw new UnsupportedOperationException("wildcard type " + n.wildcardType.name() + " not supported"); + } + } else { + return s.isEmpty(); + } + } + + for (String token : n.children.keySet()) { + Node child = n.children.get(token); + + if (n.isWildcard()) { // parent is a wildcard + switch (n.wildcardType) { + case SINGLE: + // skip over one character for single wildcard + return matches(child, s.substring(1)); + case GLOB: + // consume the input string to find a match + return allIndicesOf(s, token).stream() + .anyMatch(tokenIndex -> + matches(child, s.substring(tokenIndex + token.length())) + ); + default: + throw new UnsupportedOperationException("wildcard type " + n.wildcardType.name() + " not supported"); + } + } + + if (child.isWildcard()) { + // skip past the wildcard node, + // on the next iteration we need to figure out + // the part the wildcard matched (if at all). + return matches(child, s); + } else { + // match found, keep following this trie branch + if (s.startsWith(token)) { + return matches(child, s.substring(token.length())); + } + } + } + + return false; } - public boolean matches(String s, boolean matchSingleCharWildcard) { - if (s == null) { - return children.isEmpty(); + private static List allIndicesOf(@NonNull String s, @NonNull String sub) { + List indices = new ArrayList<>(); + int i = s.indexOf(sub); + while (i >= 0) { + indices.add(i); + i = s.indexOf(sub, i + sub.length()); } + return indices; + } + + @Value + @Builder + public static class MatchOptions { + boolean useSingleCharWildcard; + } + + enum WildcardType { + GLOB("*"), + SINGLE("?"); - if ((isWildcard() && isTerminal) || (isTerminal && s.isEmpty())) { - return true; + private final String val; + + WildcardType(@NonNull String val) { + this.val = val; } - boolean childMatchesWildcard = children - .values() - .stream() - .filter(WildcardTrie::isWildcard) - .filter(childNode -> matchSingleCharWildcard || !childNode.isSingleCharWildcard) - .anyMatch(childNode -> childNode.matches(s, matchSingleCharWildcard)); - if (childMatchesWildcard) { - return true; + public static WildcardType from(char c) { + return Arrays.stream(WildcardType.values()) + .filter(t -> t.charVal() == c) + .findFirst() + .orElse(null); } - if (matchSingleCharWildcard) { - boolean childMatchesSingleCharWildcard = children - .values() - .stream() - .filter(childNode -> childNode.isSingleCharWildcard) - .anyMatch(childNode -> childNode.matches(s, matchSingleCharWildcard)); - if (childMatchesSingleCharWildcard) { - return true; - } + public String val() { + return val; } - boolean childMatchesRegularCharacters = children - .keySet() - .stream() - .filter(s::startsWith) - .anyMatch(childToken -> { - WildcardTrie childNode = children.get(childToken); - String rest = s.substring(childToken.length()); - return childNode.matches(rest, matchSingleCharWildcard); - }); - if (childMatchesRegularCharacters) { - return true; - } - - if (isWildcard() && !isTerminal) { - return findMatchingChildSuffixesAfterWildcard(s, matchSingleCharWildcard) - .entrySet() - .stream() - .anyMatch((e) -> { - String suffix = e.getKey(); - WildcardTrie childNode = e.getValue(); - return childNode.matches(suffix, matchSingleCharWildcard); - }); + public char charVal() { + return val.charAt(0); } - return false; } - private Map findMatchingChildSuffixesAfterWildcard(String s, boolean matchSingleCharWildcard) { - Map matchingSuffixes = new HashMap<>(); - for (Map.Entry e : children.entrySet()) { - String childToken = e.getKey(); - WildcardTrie childNode = e.getValue(); - int suffixIndex = s.indexOf(childToken); - if (matchSingleCharWildcard && suffixIndex > 1) { - continue; - } - while (suffixIndex >= 0) { - matchingSuffixes.put(s.substring(suffixIndex + childToken.length()), childNode); - suffixIndex = s.indexOf(childToken, suffixIndex + 1); - } + private class Node { + private final Map children = new DefaultHashMap<>(Node::new); + private WildcardType wildcardType; + + public boolean isWildcard() { + return wildcardType != null; } - return matchingSuffixes; - } - private boolean isWildcard() { - return isGlobWildcard || isSingleCharWildcard; + public boolean isTerminal() { + return children.isEmpty(); + } } @SuppressFBWarnings("EQ_DOESNT_OVERRIDE_EQUALS") diff --git a/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java b/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java index bd0b639f4..d195d8f95 100644 --- a/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java +++ b/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java @@ -43,8 +43,8 @@ static Stream validMatches() { @MethodSource("validMatches") @ParameterizedTest void GIVEN_trie_with_wildcards_WHEN_valid_matches_provided_THEN_pass(String pattern, List matches) { - WildcardTrie trie = new WildcardTrie(); - trie.add(pattern); + WildcardTrie.MatchOptions opts = WildcardTrie.MatchOptions.builder().useSingleCharWildcard(true).build(); + WildcardTrie trie = new WildcardTrie(opts).withPattern(pattern); matches.forEach(m -> assertTrue(trie.matches(m))); } @@ -70,8 +70,8 @@ static Stream invalidMatches() { @MethodSource("invalidMatches") @ParameterizedTest void GIVEN_trie_with_wildcards_WHEN_invalid_matches_provided_THEN_fail(String pattern, List matches) { - WildcardTrie trie = new WildcardTrie(); - trie.add(pattern); + WildcardTrie.MatchOptions opts = WildcardTrie.MatchOptions.builder().useSingleCharWildcard(true).build(); + WildcardTrie trie = new WildcardTrie(opts).withPattern(pattern); matches.forEach(m -> assertFalse(trie.matches(m))); } }