Skip to content

Commit

Permalink
GH-5167 correctly handle rsx:targetShape with complex paths
Browse files Browse the repository at this point in the history
  • Loading branch information
hmottestad committed Nov 6, 2024
1 parent 92a5f68 commit 4c27043
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
package org.eclipse.rdf4j.sail.shacl.ast;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
Expand All @@ -31,7 +33,7 @@ public class StatementMatcher {
private final Variable<IRI> predicate;
private final Variable<? extends Value> object;

// private final Set<String> varNames;
// private final Set<String> varNames;
private final Targetable origin;

private final Set<String> inheritedVarNames;
Expand Down Expand Up @@ -216,37 +218,47 @@ private static String formatForToString(String field, String name, Value value)

private StatementMatcher swap(Variable<?> existingVariable, Variable<?> newVariable) {
String subjectName = getSubjectName();
String subjectBasename = getSubjectBasename();
Resource subjectValue = getSubjectValue();

String predicateName = getPredicateName();
String predicateBasename = getPredicateBasename();
IRI predicateValue = getPredicateValue();

String objectName = getObjectName();
String objectBasename = getObjectBasename();
Value objectValue = getObjectValue();

boolean changed = false;

if (Objects.equals(existingVariable.name, subjectName)
&& Objects.equals(existingVariable.value, subjectValue)) {
changed = true;
subjectName = newVariable.name;
subjectValue = (Resource) newVariable.value;
subjectBasename = newVariable.baseName;
}

if (Objects.equals(existingVariable.name, predicateName)
&& Objects.equals(existingVariable.value, predicateValue)) {
changed = true;
predicateName = newVariable.name;
predicateValue = (IRI) newVariable.value;
predicateBasename = newVariable.baseName;
}

if (Objects.equals(existingVariable.name, objectName) && Objects.equals(existingVariable.value, objectValue)) {
changed = true;
objectName = newVariable.name;
objectValue = newVariable.value;
objectBasename = newVariable.baseName;
}

if (changed) {
assert subset.isEmpty();
return new StatementMatcher(new Variable<>(subjectName, subjectValue),
new Variable<>(predicateName, predicateValue), new Variable<>(objectName, objectValue), origin,
return new StatementMatcher(new Variable<>(subjectName, subjectValue, subjectBasename),
new Variable<>(predicateName, predicateValue, predicateBasename),
new Variable<>(objectName, objectValue, objectBasename), origin,
inheritedVarNames);
}
return this;
Expand All @@ -268,6 +280,10 @@ public String getSubjectName() {
return subject.name;
}

/**
 * Returns the base name recorded on the subject variable.
 *
 * @return the subject's {@code baseName}; may be {@code null} (other code in this class null-checks it
 *         before use)
 */
public String getSubjectBasename() {
return subject.baseName;
}

/**
 * Returns the value held by the subject variable.
 *
 * @return the subject's {@code value} as a {@link Resource}
 */
public Resource getSubjectValue() {
return subject.value;
}
Expand All @@ -280,6 +296,10 @@ public String getPredicateName() {
return predicate.name;
}

/**
 * Returns the base name recorded on the predicate variable.
 *
 * @return the predicate's {@code baseName}; may be {@code null} (other code in this class null-checks it
 *         before use)
 */
public String getPredicateBasename() {
return predicate.baseName;
}

/**
 * Returns the value held by the predicate variable.
 *
 * @return the predicate's {@code value} as an {@link IRI}
 */
public IRI getPredicateValue() {
return predicate.value;
}
Expand All @@ -292,6 +312,10 @@ public String getObjectName() {
return object.name;
}

/**
 * Returns the base name recorded on the object variable.
 *
 * @return the object's {@code baseName}; may be {@code null} (other code in this class null-checks it
 *         before use)
 */
public String getObjectBasename() {
return object.baseName;
}

/**
 * Returns the value held by the object variable.
 *
 * @return the object's {@code value}
 */
public Value getObjectValue() {
return object.value;
}
Expand Down Expand Up @@ -324,6 +348,7 @@ public int hashCode() {

public String getSparqlValuesDecl(Set<String> varNamesRestriction, boolean addInheritedVarNames,
Set<String> varNamesInQueryFragment) {

StringBuilder sb = new StringBuilder("VALUES ( ");
if (subject.name != null && varNamesRestriction.contains(subject.name) ||
subject.baseName != null && varNamesRestriction.contains(subject.baseName)) {
Expand Down Expand Up @@ -362,13 +387,13 @@ public String getSparqlValuesDecl(Set<String> varNamesRestriction, boolean addIn
return sb.toString();
}

public Set<String> getVarNames(Set<String> varNamesRestriction, boolean addInheritedVarNames,
public LinkedHashSet<String> getVarNames(Set<String> varNamesRestriction, boolean addInheritedVarNames,
Set<String> varNamesInQueryFragment) {
if (varNamesRestriction.isEmpty()) {
return Set.of();
return new LinkedHashSet<>();
}

HashSet<String> ret = new HashSet<>();
LinkedHashSet<String> ret = new LinkedHashSet<>();
if (subject.name != null && varNamesRestriction.contains(subject.name)
&& varNamesInQueryFragment.contains(subject.name)) {
ret.add(subject.name);
Expand Down Expand Up @@ -448,6 +473,26 @@ public boolean hasObject(Variable<Value> variable) {
return variable.name.equals(object.name);
}

/**
 * Exposes the inherited variable names without giving callers access to the internal set.
 *
 * @return a read-only snapshot of {@code inheritedVarNames}; the copy is taken at call time, so later changes
 *         to the internal set are not reflected in it
 */
public Set<String> getInheritedVarNames() {
    Set<String> snapshot = new HashSet<>(inheritedVarNames);
    return Collections.unmodifiableSet(snapshot);
}

/**
 * Collects the names of the subject, predicate and object variables of this matcher.
 *
 * @return an unmodifiable set holding every non-{@code null} variable name; empty when none of the three
 *         variables has a name
 */
public Set<String> getVarNames() {
    Set<String> collected = new HashSet<>();

    // Positions without a name are skipped rather than added as null.
    for (String candidate : new String[] { subject.name, predicate.name, object.name }) {
        if (candidate != null) {
            collected.add(candidate);
        }
    }

    return Collections.unmodifiableSet(collected);
}

public static class StableRandomVariableProvider {

// We just need a random base that isn't used elsewhere in the ShaclSail, but we don't want it to be stable so
Expand All @@ -471,9 +516,12 @@ public StableRandomVariableProvider(String prefix) {
* increments of one.
*
* @param inputQuery the query string that should be normalized
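* @param protectedVars variables whose names must be kept as they are; a generated variable name that occurs
*                      in one of them is never renamed
* @param union         statement matchers whose variable names are rewritten in step with the query string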
* @return a normalized query string
*/
public static String normalize(String inputQuery) {
public static String normalize(String inputQuery, List<? extends Variable> protectedVars,
List<StatementMatcher> union) {

if (!inputQuery.contains(BASE)) {
return inputQuery;
}
Expand All @@ -499,18 +547,30 @@ public static String normalize(String inputQuery) {
if (lowest == 0 && incrementsOfOne) {
return inputQuery;
}
String joinedProtectedVars = protectedVars.stream()
.map(Variable::getName)
.filter(Objects::nonNull)
.filter(s -> s.contains(BASE))
.collect(Collectors.joining());

return normalizeRange(inputQuery, lowest, highest);
return normalizeRange(inputQuery, lowest, highest, joinedProtectedVars, union);
}

private static String normalizeRange(String inputQuery, int lowest, int highest) {
private static String normalizeRange(String inputQuery, int lowest, int highest, String joinedProtectedVars,
List<StatementMatcher> union) {

String normalizedQuery = inputQuery;
for (int i = 0; i <= highest; i++) {
if (!normalizedQuery.contains(BASE + i + "_")) {
String replacement = BASE + i + "_";
if (!normalizedQuery.contains(replacement)) {
for (int j = Math.max(i + 1, lowest); j <= highest; j++) {
if (normalizedQuery.contains(BASE + j + "_")) {
normalizedQuery = normalizedQuery.replace(BASE + j + "_", BASE + i + "_");
String original = BASE + j + "_";
if (normalizedQuery.contains(original)) {
if (joinedProtectedVars.contains(original)) {
continue;
}
normalizedQuery = normalizedQuery.replace(original, replacement);
replaceInStatementMatcher(union, original, replacement);
break;
}
}
Expand All @@ -520,6 +580,13 @@ private static String normalizeRange(String inputQuery, int lowest, int highest)
return normalizedQuery;
}

/**
 * Applies one variable-name substitution to every statement matcher in the list.
 *
 * @param statementMatchers the matchers to update in place
 * @param original          the text to look for inside variable names
 * @param replacement       the text that takes its place
 */
private static void replaceInStatementMatcher(List<StatementMatcher> statementMatchers, String original,
        String replacement) {
    statementMatchers.forEach(matcher -> matcher.replaceVariableName(original, replacement));
}

public Variable<Value> next() {
counter++;

Expand All @@ -538,6 +605,29 @@ public Variable<Value> current() {
}
}

/**
 * Rewrites, in place, the {@code name} and {@code baseName} of the subject, predicate and object variables,
 * replacing every occurrence of {@code original} with {@code replacement}.
 *
 * @param original    the text to look for inside the variable names
 * @param replacement the text that takes its place
 */
private void replaceVariableName(String original, String replacement) {
    // Visit subject, predicate, object in that order; for each, handle name first and baseName second.
    for (Variable<?> variable : new Variable<?>[] { subject, predicate, object }) {
        String currentName = variable.name;
        if (currentName != null && currentName.contains(original)) {
            variable.name = currentName.replace(original, replacement);
        }

        String currentBaseName = variable.baseName;
        if (currentBaseName != null && currentBaseName.contains(original)) {
            variable.baseName = currentBaseName.replace(original, replacement);
        }
    }
}

public static class Variable<T extends Value> {
public static final Variable<Value> VALUE = new Variable<>("value");
public static final Variable<Value> THIS = new Variable<>("this");
Expand All @@ -562,6 +652,12 @@ public Variable(Variable<?> baseVariable, String name) {
this.baseName = baseVariable.name;
}

/**
 * Creates a variable with all three fields given explicitly.
 *
 * @param name     the variable name
 * @param value    the value held by the variable
 * @param baseName the base name to record for this variable
 */
public Variable(String name, T value, String baseName) {
this.name = name;
this.value = value;
this.baseName = baseName;
}

/**
 * Creates a variable from a value only; this constructor does not assign {@code name} or {@code baseName}.
 *
 * @param value the value held by the variable
 */
public Variable(T value) {
this.value = value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public BindSelect(SailConnection connection, Resource[] dataGraph, SparqlFragmen
throw new IllegalStateException();
}

this.query = StatementMatcher.StableRandomVariableProvider.normalize(query.getFragment());
this.query = StatementMatcher.StableRandomVariableProvider.normalize(query.getFragment(), vars, List.of());
this.prefixes = query.getNamespacesForSparql();
this.direction = direction;
this.includePropertyShapeValues = includePropertyShapeValues;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.eclipse.rdf4j.sail.shacl.ast.planNodes;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

Expand Down Expand Up @@ -64,7 +65,7 @@ public BulkedExternalInnerJoin(PlanNode leftNode, SailConnection connection, Res

this.leftNode = PlanNodeHelper.handleSorting(this, leftNode);
this.query = query.getNamespacesForSparql() + StatementMatcher.StableRandomVariableProvider
.normalize(query.getFragment());
.normalize(query.getFragment(), List.of(), List.of());
this.connection = connection;
assert this.connection != null;
this.skipBasedOnPreviousConnection = skipBasedOnPreviousConnection;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.eclipse.rdf4j.sail.shacl.ast.planNodes;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

Expand Down Expand Up @@ -49,7 +50,7 @@ public BulkedExternalLeftOuterJoin(PlanNode leftNode, SailConnection connection,
leftNode = PlanNodeHelper.handleSorting(this, leftNode);
this.leftNode = leftNode;
this.query = query.getNamespacesForSparql()
+ StatementMatcher.StableRandomVariableProvider.normalize(query.getFragment());
+ StatementMatcher.StableRandomVariableProvider.normalize(query.getFragment(), List.of(), List.of());
this.connection = connection;
assert this.connection != null;
this.mapper = mapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.eclipse.rdf4j.sail.shacl.ast.planNodes;

import java.util.List;
import java.util.Objects;
import java.util.function.Function;

Expand Down Expand Up @@ -53,7 +54,7 @@ public ExternalFilterByQuery(SailConnection connection, Resource[] dataGraph, Pl

this.queryString = queryFragment.getNamespacesForSparql()
+ StatementMatcher.StableRandomVariableProvider.normalize("SELECT " + queryVariable.asSparqlVariable()
+ " WHERE {\n" + queryFragment.getFragment() + "\n}");
+ " WHERE {\n" + queryFragment.getFragment() + "\n}", List.of(queryVariable), List.of());
try {
this.query = SparqlQueryParserCache.get(queryString);
} catch (MalformedQueryException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.eclipse.rdf4j.sail.shacl.ast.planNodes;

import java.util.List;
import java.util.Objects;
import java.util.function.Function;

Expand Down Expand Up @@ -67,10 +68,11 @@ public Select(SailConnection connection, SparqlFragment queryFragment, String or

if (!sorted && fragment.trim().startsWith("select ")) {
this.query = queryFragment.getNamespacesForSparql() + "\n"
+ StatementMatcher.StableRandomVariableProvider.normalize(fragment);
+ StatementMatcher.StableRandomVariableProvider.normalize(fragment, List.of(), List.of());
} else {
this.query = queryFragment.getNamespacesForSparql() + "\n" + StatementMatcher.StableRandomVariableProvider
.normalize("select * where {\n" + fragment + "\n}" + (sorted ? " order by " + orderBy : ""));
.normalize("select * where {\n" + fragment + "\n}" + (sorted ? " order by " + orderBy : ""),
List.of(), List.of());
}

dataset = PlanNodeHelper.asDefaultGraphDataset(dataGraph);
Expand All @@ -87,7 +89,7 @@ public Select(SailConnection connection, String query, Function<BindingSet, Vali
this.connection = connection;
assert this.connection != null;
this.mapper = mapper;
this.query = StatementMatcher.StableRandomVariableProvider.normalize(query);
this.query = StatementMatcher.StableRandomVariableProvider.normalize(query, List.of(), List.of());
this.dataset = PlanNodeHelper.asDefaultGraphDataset(dataGraph);

this.sorted = false;
Expand Down
Loading

0 comments on commit 4c27043

Please sign in to comment.