Skip to content

Commit

Permalink
[KOGITO #2163] ContainsAny&ConstainsAll
Browse files Browse the repository at this point in the history
  • Loading branch information
fjtirado committed Dec 16, 2024
1 parent c34f9c9 commit 6ef4d43
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ private AttributeFilter<?> mapJsonArgument(String attribute, String key, Object
case LIKE:
return jsonFilter(like(sb.toString(), value.toString()));
case CONTAINS_ALL:
return filterValueList(value, val -> containsAll(sb.toString(), val));
return jsonFilter(filterValueList(value, val -> containsAll(sb.toString(), val)));
case CONTAINS_ANY:
return filterValueList(value, val -> containsAny(sb.toString(), val));
return jsonFilter(filterValueList(value, val -> containsAny(sb.toString(), val)));
case EQUAL:
default:
return jsonFilter(equalTo(sb.toString(), value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ void testJsonMapperContains() {
jsonFilter(contains("variables.workflowdata.number", 1)));
}

@Test
void testJsonMapperContainsAny() {
assertThat(mapper.mapJsonArgument("variables").apply(Map.of("workflowdata", Map.of("number", Map.of("containsAny", List.of(1, 2, 3)))))).containsExactly(
jsonFilter(containsAny("variables.workflowdata.number", List.of(1, 2, 3))));
}

@Test
void testJsonMapperContainsAll() {
assertThat(mapper.mapJsonArgument("variables").apply(Map.of("workflowdata", Map.of("number", Map.of("containsAll", List.of(1, 2, 3)))))).containsExactly(
jsonFilter(containsAll("variables.workflowdata.number", List.of(1, 2, 3))));
}

@Test
void testJsonMapperLike() {
assertThat(mapper.mapJsonArgument("variables").apply(Map.of("workflowdata", Map.of("number", Map.of("like", "kk"))))).containsExactly(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.kie.kogito.index.postgresql;

import java.util.Iterator;
import java.util.List;

import org.hibernate.dialect.function.StandardSQLFunction;
Expand All @@ -30,12 +31,21 @@

public class ContainsSQLFunction extends StandardSQLFunction {

static final String NAME = "contains";
static final String CONTAINS_NAME = "contains";
static final String CONTAINS_ALL_NAME = "containsAll";
static final String CONTAINS_ANY_NAME = "containsAny";

static final String CONTAINS_SEQ = "??";
static final String CONTAINS_ALL_SEQ = "??&";
static final String CONTAINS_ANY_SEQ = "??|";

private final String operator;

private static final BasicTypeReference<Boolean> RETURN_TYPE = new BasicTypeReference<>("boolean", Boolean.class, SqlTypes.BOOLEAN);

public ContainsSQLFunction() {
super(NAME, RETURN_TYPE);
ContainsSQLFunction(String name, String operator) {
super(name, RETURN_TYPE);
this.operator = operator;
}

@Override
Expand All @@ -44,9 +54,23 @@ public void render(
List<? extends SqlAstNode> args,
ReturnableType<?> returnType,
SqlAstTranslator<?> translator) {
args.get(0).accept(translator);
sqlAppender.append(" ?? ");
args.get(1).accept(translator);
int size = args.size();
if (size < 2) {
throw new IllegalArgumentException("Function " + getName() + " requires at least two arguments");
}
Iterator<? extends SqlAstNode> iter = args.iterator();
iter.next().accept(translator);
sqlAppender.append(' ');
sqlAppender.append(operator);
sqlAppender.append(' ');
if (size == 2) {
iter.next().accept(translator);
} else {
sqlAppender.append("array[");
do {
iter.next().accept(translator);
sqlAppender.append(iter.hasNext() ? ',' : ']');
} while (iter.hasNext());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@

import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;

import static org.kie.kogito.index.postgresql.ContainsSQLFunction.*;

public class CustomFunctionsContributor implements FunctionContributor {

@Override
public void contributeFunctions(FunctionContributions functionContributions) {
functionContributions.getFunctionRegistry()
.register(ContainsSQLFunction.NAME, new ContainsSQLFunction());
SqmFunctionRegistry registry = functionContributions.getFunctionRegistry();
registry.register(CONTAINS_NAME, new ContainsSQLFunction(CONTAINS_NAME, CONTAINS_SEQ));
registry.register(CONTAINS_ANY_NAME, new ContainsSQLFunction(CONTAINS_ANY_NAME, CONTAINS_ANY_SEQ));
registry.register(CONTAINS_ALL_NAME, new ContainsSQLFunction(CONTAINS_ALL_NAME, CONTAINS_ALL_SEQ));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.kie.kogito.persistence.api.query.AttributeFilter;

Expand Down Expand Up @@ -74,11 +75,22 @@ public static Predicate buildPredicate(AttributeFilter<?> filter, Root<?> root,
return buildPathExpression(builder, root, filter.getAttribute(), isString).in(values.stream().map(o -> buildObjectExpression(builder, o, isString)).collect(Collectors.toList()));
case CONTAINS:
return builder.isTrue(
builder.function("contains", Boolean.class, buildPathExpression(builder, root, filter.getAttribute(), false), builder.literal(filter.getValue())));
builder.function(ContainsSQLFunction.CONTAINS_NAME, Boolean.class, buildPathExpression(builder, root, filter.getAttribute(), false), builder.literal(filter.getValue())));
case CONTAINS_ANY:
return containsPredicate(filter, root, builder, ContainsSQLFunction.CONTAINS_ANY_NAME);
case CONTAINS_ALL:
return containsPredicate(filter, root, builder, ContainsSQLFunction.CONTAINS_ALL_NAME);
}
throw new UnsupportedOperationException("Filter " + filter + " is not supported");
}

private static Predicate containsPredicate(AttributeFilter<?> filter, Root<?> root, CriteriaBuilder builder, String name) {
return builder.isTrue(
builder.function(name, Boolean.class,
Stream.concat(Stream.of(buildPathExpression(builder, root, filter.getAttribute(), false)), ((List<?>) filter.getValue()).stream().map(o -> builder.literal(o)))
.toArray(Expression[]::new)));
}

private static Expression buildObjectExpression(CriteriaBuilder builder, Object value, boolean isString) {
return isString ? builder.literal(value) : builder.function("to_jsonb", Object.class, builder.literal(value));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,13 @@ void testProcessInstanceVariables() {
processInstanceId);
queryAndAssert(assertNotId(), storage, singletonList(jsonFilter(contains("variables.traveller.aliases", "TheDummyThing"))), null, null, null,
processInstanceId);
queryAndAssert(assertWithId(), storage, singletonList(jsonFilter(containsAny("variables.traveller.aliases", List.of("TheRealThing", "TheDummyThing")))), null, null, null,
processInstanceId);
queryAndAssert(assertNotId(), storage, singletonList(jsonFilter(containsAny("variables.traveller.aliases", List.of("TheRedPandaThing", "TheDummyThing")))), null, null, null,
processInstanceId);
queryAndAssert(assertWithId(), storage, singletonList(jsonFilter(containsAll("variables.traveller.aliases", List.of("Super", "Astonishing", "TheRealThing")))), null, null, null,
processInstanceId);
queryAndAssert(assertNotId(), storage, singletonList(jsonFilter(containsAll("variables.traveller.aliases", List.of("Super", "TheDummyThing")))), null, null, null,
processInstanceId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ public static <T> AttributeFilter<List<T>> in(String attribute, List<T> values)
return new AttributeFilter<>(attribute, FilterCondition.IN, values);
}

public static AttributeFilter<List<String>> containsAny(String attribute, List<String> values) {
public static <T> AttributeFilter<List<T>> containsAny(String attribute, List<T> values) {
return new AttributeFilter<>(attribute, FilterCondition.CONTAINS_ANY, values);
}

public static AttributeFilter<List<String>> containsAll(String attribute, List<String> values) {
public static <T> AttributeFilter<List<T>> containsAll(String attribute, List<T> values) {
return new AttributeFilter<>(attribute, FilterCondition.CONTAINS_ALL, values);
}

Expand Down

0 comments on commit 6ef4d43

Please sign in to comment.