chore: refactor constraint session logic into smaller pieces
triceo committed Sep 16, 2024
1 parent f09a244 commit 9d9d8fb
Showing 3 changed files with 105 additions and 46 deletions.
@@ -1,16 +1,13 @@
package ai.timefold.solver.core.impl.score.stream.bavet;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.constraint.ConstraintMatchTotal;
import ai.timefold.solver.core.api.score.constraint.Indictment;
import ai.timefold.solver.core.impl.score.director.stream.BavetConstraintStreamScoreDirectorFactory;
import ai.timefold.solver.core.impl.score.stream.bavet.common.PropagationQueue;
import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator;
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;

@@ -24,21 +21,17 @@
public final class BavetConstraintSession<Score_ extends Score<Score_>> {

private final AbstractScoreInliner<Score_> scoreInliner;
private final Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap;
private final Propagator[][] layeredNodes; // First level is the layer, second determines iteration order.
private final NodeNetwork nodeNetwork;
private final Map<Class<?>, AbstractForEachUniNode<Object>[]> effectiveClassToNodeArrayMap;

BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner) {
this(scoreInliner, Collections.emptyMap(), new Propagator[0][0]);
this(scoreInliner, NodeNetwork.EMPTY);
}

BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner,
Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap,
Propagator[][] layeredNodes) {
BavetConstraintSession(AbstractScoreInliner<Score_> scoreInliner, NodeNetwork nodeNetwork) {
this.scoreInliner = scoreInliner;
this.declaredClassToNodeMap = declaredClassToNodeMap;
this.layeredNodes = layeredNodes;
this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(declaredClassToNodeMap.size());
this.nodeNetwork = nodeNetwork;
this.effectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
}

public void insert(Object fact) {
@@ -52,12 +45,7 @@ private AbstractForEachUniNode<Object>[] findNodes(Class<?> factClass) {
// Map.computeIfAbsent() would have created lambdas on the hot path, this will not.
var nodeArray = effectiveClassToNodeArrayMap.get(factClass);
if (nodeArray == null) {
nodeArray = declaredClassToNodeMap.entrySet()
.stream()
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
.map(Map.Entry::getValue)
.flatMap(List::stream)
.toArray(AbstractForEachUniNode[]::new);
nodeArray = nodeNetwork.getApplicableForEachNodes(factClass);
effectiveClassToNodeArrayMap.put(factClass, nodeArray);
}
return nodeArray;
@@ -78,31 +66,10 @@ public void retract(Object fact) {
}

public Score_ calculateScore(int initScore) {
var layerCount = layeredNodes.length;
for (var layerIndex = 0; layerIndex < layerCount; layerIndex++) {
calculateScoreInLayer(layerIndex);
}
nodeNetwork.propagate();
return scoreInliner.extractScore(initScore);
}

private void calculateScoreInLayer(int layerIndex) {
var nodesInLayer = layeredNodes[layerIndex];
var nodeCount = nodesInLayer.length;
if (nodeCount == 1) {
nodesInLayer[0].propagateEverything();
} else {
for (var node : nodesInLayer) {
node.propagateRetracts();
}
for (var node : nodesInLayer) {
node.propagateUpdates();
}
for (var node : nodesInLayer) {
node.propagateInserts();
}
}
}

public AbstractScoreInliner<Score_> getScoreInliner() {
return scoreInliner;
}
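The findNodes() comment above ("Map.computeIfAbsent() would have created lambdas on the hot path, this will not.") points at a deliberate pattern that survives this refactor: an explicit get-then-put cache lookup instead of computeIfAbsent(), so no capturing lambda has to be allocated per call. A minimal self-contained sketch of that pattern, with invented names standing in for Timefold's node types:

import java.util.IdentityHashMap;
import java.util.Map;

final class NodeLookupSketch {

    private final Map<Class<?>, String[]> cache = new IdentityHashMap<>();

    // Hot path: plain get()/put() instead of
    // cache.computeIfAbsent(factClass, k -> expensiveLookup(k)),
    // whose this-capturing lambda would typically be allocated on every call.
    String[] findNodes(Class<?> factClass) {
        var nodes = cache.get(factClass);
        if (nodes == null) {
            nodes = expensiveLookup(factClass); // computed once per fact class, then cached
            cache.put(factClass, nodes);
        }
        return nodes;
    }

    private String[] expensiveLookup(Class<?> factClass) {
        return new String[] { factClass.getName() }; // stand-in for the real node scan
    }
}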
@@ -2,11 +2,11 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

@@ -28,6 +28,7 @@
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.core.impl.score.stream.common.ConstraintLibrary;
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;
import ai.timefold.solver.core.impl.util.CollectionUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -66,7 +67,7 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
var scoreDefinition = solutionDescriptor.<Score_> getScoreDefinition();
var zeroScore = scoreDefinition.getZeroScore();
var constraintStreamSet = new LinkedHashSet<BavetAbstractConstraintStream<Solution_>>();
var constraintWeightMap = new HashMap<Constraint, Score_>(constraintLibrary.getConstraints().size());
var constraintWeightMap = CollectionUtils.<Constraint, Score_> newHashMap(constraintLibrary.getConstraints().size());

// Only log constraint weights if logging is enabled; otherwise we don't need to build the string.
var constraintWeightLoggingEnabled = !scoreDirectorDerived && LOGGER.isEnabledForLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL);
@@ -118,6 +119,11 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
LOGGER.atLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL)
.log(constraintWeightString.toString().trim());
}
return new BavetConstraintSession<>(scoreInliner, buildNodeNetwork(constraintStreamSet, scoreInliner));
}

private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNetwork(
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, AbstractScoreInliner<Score_> scoreInliner) {
/*
* Build constraintStreamSet in reverse order to create downstream nodes first
* so every node only has final variables (some of which have downstream node method references).
@@ -162,7 +168,7 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
var layer = layerMap.get((long) i);
layeredNodes[i] = layer.toArray(new Propagator[0]);
}
return new BavetConstraintSession<>(scoreInliner, declaredClassToNodeMap, layeredNodes);
return new NodeNetwork(declaredClassToNodeMap, layeredNodes);
}

/**
@@ -180,7 +186,8 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
* @param buildHelper never null
* @return at least 0
*/
private long determineLayerIndex(AbstractNode node, NodeBuildHelper<Score_> buildHelper) {
private static <Score_ extends Score<Score_>> long determineLayerIndex(AbstractNode node,
NodeBuildHelper<Score_> buildHelper) {
if (node instanceof AbstractForEachUniNode<?>) { // ForEach nodes, and only they, are in layer 0.
return 0;
} else if (node instanceof AbstractJoinNode<?, ?, ?> joinNode) {
Expand All @@ -199,8 +206,8 @@ private long determineLayerIndex(AbstractNode node, NodeBuildHelper<Score_> buil
}
}

private long determineLayerIndexOfBinaryOperation(BavetStreamBinaryOperation<?> nodeCreator,
NodeBuildHelper<Score_> buildHelper) {
private static <Score_ extends Score<Score_>> long determineLayerIndexOfBinaryOperation(
BavetStreamBinaryOperation<?> nodeCreator, NodeBuildHelper<Score_> buildHelper) {
var leftParent = nodeCreator.getLeftParent();
var rightParent = nodeCreator.getRightParent();
var leftParentNode = buildHelper.findParentNode(leftParent);
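The buildSession() changes above split node-network assembly into a static buildNodeNetwork() helper, and determineLayerIndex() documents that ForEach nodes, and only they, sit in layer 0, with every other node placed by consulting its parent(s). The exact placement rule for the other node types is not visible in the truncated hunks, so the following is only a sketch under the assumption that a node lands one layer above its deepest parent; Node, layerIndex and the memo map are illustrative names, not Timefold's:

import java.util.List;
import java.util.Map;

final class LayerIndexSketch {

    // Hypothetical node model: a source node has no parents (layer 0, like the ForEach nodes);
    // every other node must sit above all of its parents so propagation can run layer by layer.
    record Node(String name, List<Node> parents) {
    }

    // Assumed rule: layer(source) = 0, otherwise layer(node) = 1 + max(layer(parent)).
    // The real determineLayerIndex() additionally special-cases joins and other binary operations.
    static long layerIndex(Node node, Map<Node, Long> memo) {
        var cached = memo.get(node);
        if (cached != null) {
            return cached;
        }
        var layer = node.parents().isEmpty()
                ? 0L
                : 1L + node.parents().stream().mapToLong(parent -> layerIndex(parent, memo)).max().orElse(0L);
        memo.put(node, layer);
        return layer;
    }
}

With some rule of this shape, the Propagator[][] built just above simply groups each layer's propagators by that index before handing them to the new NodeNetwork.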
@@ -0,0 +1,85 @@
package ai.timefold.solver.core.impl.score.stream.bavet;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator;
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;

/**
* Represents Bavet's network of nodes, specific to a particular session.
* Nodes only used by disabled constraints have already been removed.
*
* @param declaredClassToNodeMap starting nodes, one for each class used in the constraints;
* root nodes, layer index 0.
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer;
* propagation needs to happen in this order.
*/
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap, Propagator[][] layeredNodes) {

public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]);

public int forEachNodeCount() {
return declaredClassToNodeMap.size();
}

public int layerCount() {
return layeredNodes.length;
}

@SuppressWarnings("unchecked")
public AbstractForEachUniNode<Object>[] getApplicableForEachNodes(Class<?> factClass) {
return declaredClassToNodeMap.entrySet()
.stream()
.filter(entry -> entry.getKey().isAssignableFrom(factClass))
.map(Map.Entry::getValue)
.flatMap(List::stream)
.toArray(AbstractForEachUniNode[]::new);
}

public void propagate() {
for (var layerIndex = 0; layerIndex < layerCount(); layerIndex++) {
propagateInLayer(layeredNodes[layerIndex]);
}
}

private static void propagateInLayer(Propagator[] nodesInLayer) {
var nodeCount = nodesInLayer.length;
if (nodeCount == 1) {
nodesInLayer[0].propagateEverything();
} else {
for (var node : nodesInLayer) {
node.propagateRetracts();
}
for (var node : nodesInLayer) {
node.propagateUpdates();
}
for (var node : nodesInLayer) {
node.propagateInserts();
}
}
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof NodeNetwork that))
return false;
return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap)
&& Objects.deepEquals(layeredNodes, that.layeredNodes);
}

@Override
public int hashCode() {
return Objects.hash(declaredClassToNodeMap, Arrays.deepHashCode(layeredNodes));
}

@Override
public String toString() {
return this.getClass().getSimpleName() + " with " + forEachNodeCount() + " forEach nodes.";
}

}
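Because the record above is the piece that now owns propagation ordering, here is a self-contained, hypothetical demo of the contract propagate() applies: layers run in index order, a single-node layer propagates everything at once, and a multi-node layer runs all retracts, then all updates, then all inserts. MiniPropagator and LoggingPropagator are invented stand-ins, not Timefold types:

import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for Bavet's Propagator, used only to make the call ordering observable.
interface MiniPropagator {

    void propagateRetracts();

    void propagateUpdates();

    void propagateInserts();

    default void propagateEverything() {
        propagateRetracts();
        propagateUpdates();
        propagateInserts();
    }
}

public final class PropagationOrderDemo {

    // Records every call as "<node>:<phase>" so the phase ordering within a layer is visible.
    record LoggingPropagator(String name, List<String> log) implements MiniPropagator {

        public void propagateRetracts() {
            log.add(name + ":retract");
        }

        public void propagateUpdates() {
            log.add(name + ":update");
        }

        public void propagateInserts() {
            log.add(name + ":insert");
        }
    }

    public static void main(String[] args) {
        var log = new ArrayList<String>();
        MiniPropagator[][] layeredNodes = {
                { new LoggingPropagator("forEachA", log), new LoggingPropagator("forEachB", log) }, // layer 0
                { new LoggingPropagator("scorer", log) } // layer 1: single node, so propagateEverything()
        };
        propagate(layeredNodes);
        // Prints: [forEachA:retract, forEachB:retract, forEachA:update, forEachB:update,
        //          forEachA:insert, forEachB:insert, scorer:retract, scorer:update, scorer:insert]
        System.out.println(log);
    }

    // Same shape as NodeNetwork.propagate() / propagateInLayer() above, minus the real node types.
    static void propagate(MiniPropagator[][] layeredNodes) {
        for (var layer : layeredNodes) {
            if (layer.length == 1) {
                layer[0].propagateEverything();
            } else {
                for (var node : layer) {
                    node.propagateRetracts();
                }
                for (var node : layer) {
                    node.propagateUpdates();
                }
                for (var node : layer) {
                    node.propagateInserts();
                }
            }
        }
    }
}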
