Skip to content

Commit

Permalink
add draw utility penalty in puct selection
Browse files Browse the repository at this point in the history
Add @hzyhhzy's idea of reducing exploration for drawish child nodes. This seems to reduce the amount of useless search spent on highly drawish nodes, so the search can focus more on important branches.

For MCTS search,
Passed LTC on f15:
TC: 60+0.6
Total/Win/Draw/Lose: 2212 / 743 / 1059 / 410
PTNML: 31 / 151 / 495 / 312 / 117
WinRate: 57.53%
ELO: 52.41[41.58, 63.54]

Passed VVLTC on f15:
TC: 360+3.6
Total/Win/Draw/Lose: 5048 / 1162 / 3257 / 629
PTNML: 30 / 282 / 1446 / 657 / 109
WinRate: 55.28%
ELO: 36.58[28.33, 44.97]
  • Loading branch information
dhbloo committed Oct 13, 2024
1 parent e73658a commit ed79856
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 18 deletions.
3 changes: 3 additions & 0 deletions Rapfi/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ int MaxNonPVRootmovesToPrint = 10;
int NumNodesAfterSingularRoot = 100;
/// The power of two number of shards that the node table has.
int NumNodeTableShardsPowerOfTwo = 10;
/// The ratio by which to decrease utility when the child draw rate is high.
float DrawUtilityPenalty = 0.35f;

// Time management options

Expand Down Expand Up @@ -430,6 +432,7 @@ void Config::readSearch(const cpptoml::table &t)
t.get_as<int>("num_nodes_after_singular_root").value_or(NumNodesAfterSingularRoot);
NumNodeTableShardsPowerOfTwo =
t.get_as<int>("num_node_table_shards_power_of_two").value_or(NumNodeTableShardsPowerOfTwo);
DrawUtilityPenalty = t.get_as<double>("draw_utility_penalty").value_or(DrawUtilityPenalty);

// Read time management options
if (auto tm = t.get_table("timectl")) {
Expand Down
15 changes: 8 additions & 7 deletions Rapfi/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,14 @@ extern int NumIterationAfterMate;
extern int NumIterationAfterSingularRoot;
extern int MaxSearchDepth;

extern bool ExpandWhenFirstEvaluate;
extern int MaxNumVisitsPerPlayout;
extern int NodesToPrintMCTSRootmoves;
extern int TimeToPrintMCTSRootmoves;
extern int MaxNonPVRootmovesToPrint;
extern int NumNodesAfterSingularRoot;
extern int NumNodeTableShardsPowerOfTwo;
extern bool ExpandWhenFirstEvaluate;
extern int MaxNumVisitsPerPlayout;
extern int NodesToPrintMCTSRootmoves;
extern int TimeToPrintMCTSRootmoves;
extern int MaxNonPVRootmovesToPrint;
extern int NumNodesAfterSingularRoot;
extern int NumNodeTableShardsPowerOfTwo;
extern float DrawUtilityPenalty;

// -------------------------------------------------
// Time management options
Expand Down
22 changes: 11 additions & 11 deletions Rapfi/search/mcts/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,27 @@ namespace Search::MCTS {

constexpr float MaxNewVisitsProp = 0.36f;

constexpr float CpuctExploration = 0.39f;
constexpr float CpuctExplorationLog = 0.98f;
constexpr float CpuctExplorationBase = 340;
constexpr float CpuctExploration = 0.35f;
constexpr float CpuctExplorationLog = 1.02f;
constexpr float CpuctExplorationBase = 328;

constexpr float CpuctUtilityStdevScale = 0.043f;
constexpr float CpuctUtilityVarPrior = 0.16f;
constexpr float CpuctUtilityVarPriorWeight = 1.87f;
constexpr float CpuctUtilityStdevScale = 0.05f;
constexpr float CpuctUtilityVarPrior = 0.15f;
constexpr float CpuctUtilityVarPriorWeight = 1.80f;

constexpr float FpuReductionMax = 0.06f;
constexpr float FpuReductionMax = 0.055f;
constexpr float FpuLossProp = 0.0008f;
constexpr float RootFpuReductionMax = 0.073f;
constexpr float RootFpuLossProp = 0.0036f;
constexpr float FpuUtilityBlendPow = 0.84f;
constexpr float RootFpuReductionMax = 0.07f;
constexpr float RootFpuLossProp = 0.003f;
constexpr float FpuUtilityBlendPow = 0.75f;

constexpr uint32_t MinTranspositionSkipVisits = 11;

constexpr bool UseLCBForBestmoveSelection = true;
constexpr float LCBStdevs = 6.28f;
constexpr float LCBMinVisitProp = 0.1f;

constexpr float PolicyTemperature = 0.91f;
constexpr float PolicyTemperature = 0.90f;
constexpr float RootPolicyTemperature = 1.05f;

} // namespace Search::MCTS
13 changes: 13 additions & 0 deletions Rapfi/search/mcts/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ inline float fpuValue(float parentAvgUtility, float parentRawUtility, float expl

/// Compute PUCT selection value with the given child statistics.
inline float puctSelectionValue(float childUtility,
float childDraw,
float parentDraw,
float childPolicy,
uint32_t childVisits,
uint32_t childVirtualVisits,
Expand All @@ -265,6 +267,11 @@ inline float puctSelectionValue(float childUtility,
float U = cpuctExploration * childPolicy / (1 + childVisits);
float Q = childUtility;

// Reduce utility value for drawish child nodes for PUCT selection
// Encourage exploration for less drawish child nodes
if (Config::DrawUtilityPenalty != 0)
Q -= Config::DrawUtilityPenalty * childDraw * (1 - parentDraw);

// Account for virtual losses
if (childVirtualVisits > 0)
Q = (Q * childVisits - childVirtualVisits) / (childVisits + childVirtualVisits);
Expand Down Expand Up @@ -302,6 +309,7 @@ std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
SearchThread *thisThread = board.thisThread();

uint32_t parentVisits = node.getVisits();
float parentDraw = node.getD();
float cpuctExploration = cpuctExplorationFactor(parentVisits);

// Apply dynamic cpuct scaling based on parent utility variance if needed
Expand Down Expand Up @@ -346,7 +354,10 @@ std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
uint32_t childVisits = childEdge.getVisits();
uint32_t childVirtualVisits = childNode->getVirtualVisits();
float childUtility = -childNode->getQ();
float childDraw = childNode->getD();
float selectionValue = puctSelectionValue(childUtility,
childDraw,
parentDraw,
childPolicy,
childVisits,
childVirtualVisits,
Expand All @@ -368,6 +379,8 @@ std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
uint32_t childVisits = 0; // Unexplored edge must has zero edge visit
uint32_t childVirtualVisits = 0; // Unexplored edge must has zero virtual visit
float selectionValue = puctSelectionValue(fpuUtility,
parentDraw,
parentDraw,
childPolicy,
childVisits,
childVirtualVisits,
Expand Down

0 comments on commit ed79856

Please sign in to comment.