Skip to content

Commit

Permalink
📦 NEW: Add Visits/Node TM
Browse files Browse the repository at this point in the history
  • Loading branch information
ianagbip1oti committed Sep 29, 2024
1 parent b539b89 commit a519e1c
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 45 deletions.
12 changes: 12 additions & 0 deletions src/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ impl Flag {
Flag::TablebaseWin | Flag::TablebaseDraw | Flag::TablebaseLoss
)
}

/// Maps this flag to the evaluation that should be backed up for a node.
///
/// Decisive and drawn flags override `eval` with a fixed score expressed in
/// units of `SCALE`: terminal results are worth two units, tablebase results
/// one unit, and any draw is worth zero. `Flag::Standard` returns `eval`
/// unchanged.
#[must_use]
pub fn adjust_eval(self, eval: i64) -> i64 {
    // One "unit" of decisive score, truncated from the floating-point scale.
    let unit = SCALE as i64;

    match self {
        Flag::Standard => eval,
        Flag::TerminalDraw | Flag::TablebaseDraw => 0,
        Flag::TablebaseWin => unit,
        Flag::TablebaseLoss => -unit,
        Flag::TerminalWin => 2 * unit,
        Flag::TerminalLoss => -2 * unit,
    }
}
}

#[cfg(feature = "value-net")]
Expand Down
33 changes: 33 additions & 0 deletions src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,19 @@ static HASH: UciOption = UciOption::spin("Hash", 16, 1, 2 << 24);
// General engine options.
// NOTE(review): the numeric arguments to `UciOption::spin` are presumably
// (default, min, max) — confirm against its definition.
static THREADS: UciOption = UciOption::spin("Threads", 1, 1, 255);
static MULTI_PV: UciOption = UciOption::spin("MultiPV", 1, 1, 255);
static SYZYGY_PATH: UciOption = UciOption::string("SyzygyPath", "<empty>");

// Search tuning parameters.
// NOTE(review): presumably consumed by `MctsOptions` — confirm.
static CPUCT: UciOption = UciOption::spin("CPuct", 16, 1, 2 << 16);
static CPUCT_TAU: UciOption = UciOption::spin("CPuctTau", 84, 0, 100);
static CVISITS_SELECTION: UciOption = UciOption::spin("CVisitsSelection", 1, 0, 100);
static POLICY_TEMPERATURE: UciOption = UciOption::spin("PolicyTemperature", 100, 0, 2 << 16);
static POLICY_TEMPERATURE_ROOT: UciOption =
UciOption::spin("PolicyTemperatureRoot", 1450, 0, 2 << 16);

// Time-management tuning parameters. Each one is read with `get_f32` into the
// matching field of `TimeManagementOptions` (TMMinM -> min_m, TMMaxM -> max_m,
// TMVisitsBase -> visits_base, TMVisitsM -> visits_m).
// NOTE(review): whether `get_f32` rescales the raw spin value is not visible
// here — confirm before interpreting these defaults.
static TM_MIN_M: UciOption = UciOption::spin("TMMinM", 10, 0, 2 << 16);
static TM_MAX_M: UciOption = UciOption::spin("TMMaxM", 500, 0, 2 << 16);
static TM_VISITS_BASE: UciOption = UciOption::spin("TMVisitsBase", 140, 0, 2 << 16);
static TM_VISITS_M: UciOption = UciOption::spin("TMVisitsM", 139, 0, 2 << 16);

// Boolean (check) options.
static CHESS960: UciOption = UciOption::check("UCI_Chess960", false);
static POLICY_ONLY: UciOption = UciOption::check("PolicyOnly", false);
static SHOW_MOVESLEFT: UciOption = UciOption::check("UCI_ShowMovesLeft", false);
Expand All @@ -108,6 +115,10 @@ static ALL_OPTIONS: &[UciOption] = &[
CVISITS_SELECTION,
POLICY_TEMPERATURE,
POLICY_TEMPERATURE_ROOT,
TM_MIN_M,
TM_MAX_M,
TM_VISITS_BASE,
TM_VISITS_M,
CHESS960,
POLICY_ONLY,
SHOW_MOVESLEFT,
Expand Down Expand Up @@ -187,6 +198,16 @@ pub struct SearchOptions {
pub is_policy_only: bool,
pub show_movesleft: bool,
pub mcts_options: MctsOptions,
pub time_management_options: TimeManagementOptions,
}

/// Tunable parameters for the soft-time-limit multiplier.
///
/// The multiplier is computed as
/// `(visits_base - best_move_visit_fraction) * visits_m` and then bounded by
/// `min_m` and `max_m`.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy)]
pub struct TimeManagementOptions {
/// Lower bound applied to the soft-limit multiplier.
pub min_m: f32,
/// Upper bound applied to the soft-limit multiplier.
pub max_m: f32,
/// Value from which the best move's share of root visits is subtracted.
pub visits_base: f32,
/// Factor applied to `visits_base - best_move_visit_fraction`.
pub visits_m: f32,
}

impl From<&UciOptionMap> for MctsOptions {
Expand All @@ -200,6 +221,17 @@ impl From<&UciOptionMap> for MctsOptions {
}
}

impl From<&UciOptionMap> for TimeManagementOptions {
    /// Reads the four time-management options out of `map` as `f32` values.
    fn from(map: &UciOptionMap) -> Self {
        let min_m = map.get_f32(&TM_MIN_M);
        let max_m = map.get_f32(&TM_MAX_M);
        let visits_base = map.get_f32(&TM_VISITS_BASE);
        let visits_m = map.get_f32(&TM_VISITS_M);

        Self {
            min_m,
            max_m,
            visits_base,
            visits_m,
        }
    }
}

impl Default for SearchOptions {
fn default() -> Self {
SearchOptions::from(&UciOptionMap::default())
Expand All @@ -217,6 +249,7 @@ impl From<&UciOptionMap> for SearchOptions {
is_policy_only: map.get_and(&POLICY_ONLY, |s| s.parse().ok()),
show_movesleft: map.get_and(&SHOW_MOVESLEFT, |s| s.parse().ok()),
mcts_options: MctsOptions::from(map),
time_management_options: TimeManagementOptions::from(map),
}
}
}
59 changes: 33 additions & 26 deletions src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ const MOVE_OVERHEAD: Duration = Duration::from_millis(50);

/// Scale factor for evaluations: `256 * 256 = 65536`.
pub const SCALE: f32 = 256. * 256.;

#[must_use]
#[derive(Copy, Clone, Debug)]
pub struct TimeManagement {
start: Instant,
end: Option<Instant>,
soft_limit: Option<Duration>,
hard_limit: Option<Duration>,
node_limit: usize,
}

Expand All @@ -32,39 +34,41 @@ impl Default for TimeManagement {
}

impl TimeManagement {
#[must_use]
pub fn from_duration(d: Duration) -> Self {
let start = Instant::now();
let end = Some(start + d);
let node_limit = usize::MAX;
Self {
start: Instant::now(),
soft_limit: None,
hard_limit: Some(d),
node_limit: usize::MAX,
}
}

pub fn from_limits(soft: Duration, hard: Duration) -> Self {
Self {
start,
end,
node_limit,
start: Instant::now(),
soft_limit: Some(soft),
hard_limit: Some(hard),
node_limit: usize::MAX,
}
}

#[must_use]
pub fn infinite() -> Self {
let start = Instant::now();
let end = None;
let node_limit = usize::MAX;

Self {
start,
end,
node_limit,
start: Instant::now(),
soft_limit: None,
hard_limit: None,
node_limit: usize::MAX,
}
}

#[must_use]
pub fn is_after_end(&self) -> bool {
if let Some(end) = self.end {
Instant::now() > end
} else {
false
}
/// Returns the soft time limit, or `None` when no soft limit is set.
pub fn soft_limit(&self) -> Option<Duration> {
self.soft_limit
}

/// Returns the hard time limit, or `None` when no hard limit is set.
#[must_use]
pub fn hard_limit(&self) -> Option<Duration> {
self.hard_limit
}

#[must_use]
Expand All @@ -84,12 +88,14 @@ impl TimeManagement {

/// Per-thread state for a search worker.
pub struct ThreadData<'a> {
/// Allocator handle obtained from the search tree via `SearchTree::allocator`.
pub allocator: LRAllocator<'a>,
/// Playout counter, initialised to zero.
/// NOTE(review): presumably incremented once per playout by the owning
/// thread and used to throttle periodic checks — confirm in the playout loop.
pub playouts: usize,
}

impl<'a> ThreadData<'a> {
    /// Builds fresh thread data for `tree`: an allocator taken from the tree
    /// and a playout counter starting at zero.
    fn create(tree: &'a SearchTree) -> Self {
        let allocator = tree.allocator();
        let playouts = 0;

        Self {
            allocator,
            playouts,
        }
    }
}
Expand Down Expand Up @@ -226,11 +232,12 @@ impl Search {
move_time_fraction = (m + 2).min(move_time_fraction);
}

let ideal_think_time =
(r + move_time_fraction * increment - MOVE_OVERHEAD) / move_time_fraction;
let max_think_time = r / 3;
let r = r - MOVE_OVERHEAD;

let soft_limit = (r + move_time_fraction * increment) / move_time_fraction;
let hard_limit = r / 3;

think_time = TimeManagement::from_duration(ideal_think_time.min(max_think_time));
think_time = TimeManagement::from_limits(soft_limit.min(hard_limit), hard_limit);
}

if self.search_options.is_policy_only {
Expand Down
57 changes: 38 additions & 19 deletions src/search_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::sync::atomic::{
use crate::arena::Error as ArenaError;
use crate::chess;
use crate::evaluation::{self, Flag};
use crate::options::{MctsOptions, SearchOptions};
use crate::options::{MctsOptions, SearchOptions, TimeManagementOptions};
use crate::search::{eval_in_cp, ThreadData};
use crate::search::{TimeManagement, SCALE};
use crate::state::State;
Expand Down Expand Up @@ -290,10 +290,7 @@ impl SearchTree {
loop {
self.ttable.wait_if_flipping();

if node.is_terminal() {
break;
}
if node.hots().is_empty() {
if node.is_terminal() || node.hots().is_empty() {
break;
}
if node.is_tablebase() && state.halfmove_clock() == 0 {
Expand Down Expand Up @@ -337,7 +334,6 @@ impl SearchTree {
|| state.drawn_by_fifty_move_rule()
|| state.board().is_insufficient_material()
{
evaln = 0;
node = &DRAW_NODE;
break;
}
Expand All @@ -354,21 +350,15 @@ impl SearchTree {
node = new_node;
}

evaln = match node.flag {
Flag::TerminalWin => 2 * SCALE as i64,
Flag::TerminalLoss => -2 * SCALE as i64,
Flag::TablebaseWin => SCALE as i64,
Flag::TablebaseLoss => -SCALE as i64,
Flag::TerminalDraw | Flag::TablebaseDraw => 0,
Flag::Standard => evaln,
};
evaln = node.flag.adjust_eval(evaln);

Self::finish_playout(&path, evaln);

let depth = path.len();
let num_nodes = self.num_nodes.fetch_add(depth, Ordering::Relaxed) + depth;
self.max_depth.fetch_max(depth, Ordering::Relaxed);
let playouts = self.playouts.fetch_add(1, Ordering::Relaxed) + 1;
self.playouts.fetch_add(1, Ordering::Relaxed);
tld.playouts += 1;

if node.is_tablebase() {
self.tb_hits.fetch_add(1, Ordering::Relaxed);
Expand All @@ -378,13 +368,29 @@ impl SearchTree {
return false;
}

if playouts % 128 == 0
&& (time_management.is_after_end() || stop_signal.load(Ordering::Relaxed))
{
if tld.playouts % 128 == 0 && stop_signal.load(Ordering::Relaxed) {
return false;
}

if playouts % 65536 == 0 {
let elapsed = time_management.elapsed();

if tld.playouts % 128 == 0 {
if let Some(hard_limit) = time_management.hard_limit() {
if elapsed >= hard_limit {
return false;
}
}

if let Some(soft_limit) = time_management.soft_limit() {
let opts = &self.search_options.time_management_options;

if elapsed >= soft_limit.mul_f32(self.soft_time_multiplier(opts)) {
return false;
}
}
}

if tld.playouts % 65536 == 0 {
let elapsed = time_management.elapsed().as_secs();

let next_info = self.next_info.fetch_max(elapsed, Ordering::Relaxed);
Expand Down Expand Up @@ -461,6 +467,19 @@ impl SearchTree {
)[0]
}

fn soft_time_multiplier(&self, opts: &TimeManagementOptions) -> f32 {
let mut m = 1.0;

let bm_frac = self.root_node().select_child_by_rewards().visits() as f32
/ self.root_node().visits() as f32;

m *= (opts.visits_base - bm_frac) * opts.visits_m;

m = m.clamp(opts.min_m, opts.max_m);

m
}

pub fn print_info(&self, time_management: &TimeManagement) {
let mut info_str = String::with_capacity(256);

Expand Down

0 comments on commit a519e1c

Please sign in to comment.