diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java index 97dd82a3ca..965ce5ee51 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java @@ -1,8 +1,6 @@ package ai.timefold.solver.core.impl.score.stream.bavet; -import java.util.Collections; import java.util.IdentityHashMap; -import java.util.List; import java.util.Map; import ai.timefold.solver.core.api.score.Score; @@ -10,7 +8,6 @@ import ai.timefold.solver.core.api.score.constraint.Indictment; import ai.timefold.solver.core.impl.score.director.stream.BavetConstraintStreamScoreDirectorFactory; import ai.timefold.solver.core.impl.score.stream.bavet.common.PropagationQueue; -import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator; import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode; import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner; @@ -24,21 +21,17 @@ public final class BavetConstraintSession> { private final AbstractScoreInliner scoreInliner; - private final Map, List>> declaredClassToNodeMap; - private final Propagator[][] layeredNodes; // First level is the layer, second determines iteration order. + private final NodeNetwork nodeNetwork; private final Map, AbstractForEachUniNode[]> effectiveClassToNodeArrayMap; BavetConstraintSession(AbstractScoreInliner scoreInliner) { - this(scoreInliner, Collections.emptyMap(), new Propagator[0][0]); + this(scoreInliner, NodeNetwork.EMPTY); } - BavetConstraintSession(AbstractScoreInliner scoreInliner, - Map, List>> declaredClassToNodeMap, - Propagator[][] layeredNodes) { + BavetConstraintSession(AbstractScoreInliner scoreInliner, NodeNetwork nodeNetwork) { this.scoreInliner = scoreInliner; - this.declaredClassToNodeMap = declaredClassToNodeMap; - this.layeredNodes = layeredNodes; - this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(declaredClassToNodeMap.size()); + this.nodeNetwork = nodeNetwork; + this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); } public void insert(Object fact) { @@ -52,12 +45,7 @@ private AbstractForEachUniNode[] findNodes(Class factClass) { // Map.computeIfAbsent() would have created lambdas on the hot path, this will not. var nodeArray = effectiveClassToNodeArrayMap.get(factClass); if (nodeArray == null) { - nodeArray = declaredClassToNodeMap.entrySet() - .stream() - .filter(entry -> entry.getKey().isAssignableFrom(factClass)) - .map(Map.Entry::getValue) - .flatMap(List::stream) - .toArray(AbstractForEachUniNode[]::new); + nodeArray = nodeNetwork.getApplicableForEachNodes(factClass); effectiveClassToNodeArrayMap.put(factClass, nodeArray); } return nodeArray; @@ -78,31 +66,10 @@ public void retract(Object fact) { } public Score_ calculateScore(int initScore) { - var layerCount = layeredNodes.length; - for (var layerIndex = 0; layerIndex < layerCount; layerIndex++) { - calculateScoreInLayer(layerIndex); - } + nodeNetwork.propagate(); return scoreInliner.extractScore(initScore); } - private void calculateScoreInLayer(int layerIndex) { - var nodesInLayer = layeredNodes[layerIndex]; - var nodeCount = nodesInLayer.length; - if (nodeCount == 1) { - nodesInLayer[0].propagateEverything(); - } else { - for (var node : nodesInLayer) { - node.propagateRetracts(); - } - for (var node : nodesInLayer) { - node.propagateUpdates(); - } - for (var node : nodesInLayer) { - node.propagateInserts(); - } - } - } - public AbstractScoreInliner getScoreInliner() { return scoreInliner; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java index 82628552ee..d08b39e212 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java @@ -2,11 +2,11 @@ import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; @@ -28,6 +28,7 @@ import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode; import ai.timefold.solver.core.impl.score.stream.common.ConstraintLibrary; import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner; +import ai.timefold.solver.core.impl.util.CollectionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,7 +67,7 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, bo var scoreDefinition = solutionDescriptor. getScoreDefinition(); var zeroScore = scoreDefinition.getZeroScore(); var constraintStreamSet = new LinkedHashSet>(); - var constraintWeightMap = new HashMap(constraintLibrary.getConstraints().size()); + var constraintWeightMap = CollectionUtils. newHashMap(constraintLibrary.getConstraints().size()); // Only log constraint weights if logging is enabled; otherwise we don't need to build the string. var constraintWeightLoggingEnabled = !scoreDirectorDerived && LOGGER.isEnabledForLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL); @@ -118,6 +119,11 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, bo LOGGER.atLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL) .log(constraintWeightString.toString().trim()); } + return new BavetConstraintSession<>(scoreInliner, buildNodeNetwork(constraintStreamSet, scoreInliner)); + } + + private static > NodeNetwork buildNodeNetwork( + Set> constraintStreamSet, AbstractScoreInliner scoreInliner) { /* * Build constraintStreamSet in reverse order to create downstream nodes first * so every node only has final variables (some of which have downstream node method references). @@ -162,7 +168,7 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, bo var layer = layerMap.get((long) i); layeredNodes[i] = layer.toArray(new Propagator[0]); } - return new BavetConstraintSession<>(scoreInliner, declaredClassToNodeMap, layeredNodes); + return new NodeNetwork(declaredClassToNodeMap, layeredNodes); } /** @@ -180,7 +186,8 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, bo * @param buildHelper never null * @return at least 0 */ - private long determineLayerIndex(AbstractNode node, NodeBuildHelper buildHelper) { + private static > long determineLayerIndex(AbstractNode node, + NodeBuildHelper buildHelper) { if (node instanceof AbstractForEachUniNode) { // ForEach nodes, and only they, are in layer 0. return 0; } else if (node instanceof AbstractJoinNode joinNode) { @@ -199,8 +206,8 @@ private long determineLayerIndex(AbstractNode node, NodeBuildHelper buil } } - private long determineLayerIndexOfBinaryOperation(BavetStreamBinaryOperation nodeCreator, - NodeBuildHelper buildHelper) { + private static > long determineLayerIndexOfBinaryOperation( + BavetStreamBinaryOperation nodeCreator, NodeBuildHelper buildHelper) { var leftParent = nodeCreator.getLeftParent(); var rightParent = nodeCreator.getRightParent(); var leftParentNode = buildHelper.findParentNode(leftParent); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/NodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/NodeNetwork.java new file mode 100644 index 0000000000..83511b5ecf --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/NodeNetwork.java @@ -0,0 +1,85 @@ +package ai.timefold.solver.core.impl.score.stream.bavet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator; +import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode; + +/** + * Represents Bavet's network of nodes, specific to a particular session. + * Nodes only used by disabled constraints have already been removed. + * + * @param declaredClassToNodeMap starting nodes, one for each class used in the constraints; + * root nodes, layer index 0. + * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; + * propagation needs to happen in this order. + */ +record NodeNetwork(Map, List>> declaredClassToNodeMap, Propagator[][] layeredNodes) { + + public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]); + + public int forEachNodeCount() { + return declaredClassToNodeMap.size(); + } + + public int layerCount() { + return layeredNodes.length; + } + + @SuppressWarnings("unchecked") + public AbstractForEachUniNode[] getApplicableForEachNodes(Class factClass) { + return declaredClassToNodeMap.entrySet() + .stream() + .filter(entry -> entry.getKey().isAssignableFrom(factClass)) + .map(Map.Entry::getValue) + .flatMap(List::stream) + .toArray(AbstractForEachUniNode[]::new); + } + + public void propagate() { + for (var layerIndex = 0; layerIndex < layerCount(); layerIndex++) { + propagateInLayer(layeredNodes[layerIndex]); + } + } + + private static void propagateInLayer(Propagator[] nodesInLayer) { + var nodeCount = nodesInLayer.length; + if (nodeCount == 1) { + nodesInLayer[0].propagateEverything(); + } else { + for (var node : nodesInLayer) { + node.propagateRetracts(); + } + for (var node : nodesInLayer) { + node.propagateUpdates(); + } + for (var node : nodesInLayer) { + node.propagateInserts(); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof NodeNetwork that)) + return false; + return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap) + && Objects.deepEquals(layeredNodes, that.layeredNodes); + } + + @Override + public int hashCode() { + return Objects.hash(declaredClassToNodeMap, Arrays.deepHashCode(layeredNodes)); + } + + @Override + public String toString() { + return this.getClass().getSimpleName() + " with " + forEachNodeCount() + " forEach nodes."; + } + +}