diff --git a/src/evaluation.rs b/src/evaluation.rs
index 4efb044..42fc78c 100644
--- a/src/evaluation.rs
+++ b/src/evaluation.rs
@@ -37,6 +37,20 @@ impl Flag {
             Flag::TablebaseWin | Flag::TablebaseDraw | Flag::TablebaseLoss
         )
     }
+
+    /// Maps this flag to a playout score: terminal and tablebase flags yield fixed
+    /// multiples of `SCALE` (or 0 for draws); `Standard` passes `eval` through.
+    #[must_use]
+    pub fn adjust_eval(self, eval: i64) -> i64 {
+        match self {
+            Flag::TerminalWin => 2 * SCALE as i64,
+            Flag::TerminalLoss => -2 * SCALE as i64,
+            Flag::TablebaseWin => SCALE as i64,
+            Flag::TablebaseLoss => -SCALE as i64,
+            Flag::TerminalDraw | Flag::TablebaseDraw => 0,
+            Flag::Standard => eval,
+        }
+    }
 }
 
 #[cfg(feature = "value-net")]
diff --git a/src/options.rs b/src/options.rs
index 0dc3cc1..415ebb2 100644
--- a/src/options.rs
+++ b/src/options.rs
@@ -88,12 +88,19 @@
 static HASH: UciOption = UciOption::spin("Hash", 16, 1, 2 << 24);
 static THREADS: UciOption = UciOption::spin("Threads", 1, 1, 255);
 static MULTI_PV: UciOption = UciOption::spin("MultiPV", 1, 1, 255);
 static SYZYGY_PATH: UciOption = UciOption::string("SyzygyPath", "");
+
 static CPUCT: UciOption = UciOption::spin("CPuct", 16, 1, 2 << 16);
 static CPUCT_TAU: UciOption = UciOption::spin("CPuctTau", 84, 0, 100);
 static CVISITS_SELECTION: UciOption = UciOption::spin("CVisitsSelection", 1, 0, 100);
 static POLICY_TEMPERATURE: UciOption = UciOption::spin("PolicyTemperature", 100, 0, 2 << 16);
 static POLICY_TEMPERATURE_ROOT: UciOption = UciOption::spin("PolicyTemperatureRoot", 1450, 0, 2 << 16);
+
+static TM_MIN_M: UciOption = UciOption::spin("TMMinM", 10, 0, 2 << 16);
+static TM_MAX_M: UciOption = UciOption::spin("TMMaxM", 500, 0, 2 << 16);
+static TM_VISITS_BASE: UciOption = UciOption::spin("TMVisitsBase", 140, 0, 2 << 16);
+static TM_VISITS_M: UciOption = UciOption::spin("TMVisitsM", 139, 0, 2 << 16);
+
 static CHESS960: UciOption = UciOption::check("UCI_Chess960", false);
 static POLICY_ONLY: UciOption = UciOption::check("PolicyOnly", false);
 static SHOW_MOVESLEFT: UciOption = UciOption::check("UCI_ShowMovesLeft", false);
@@ -108,6 +115,10 @@ static ALL_OPTIONS: &[UciOption] = &[
     CVISITS_SELECTION,
     POLICY_TEMPERATURE,
     POLICY_TEMPERATURE_ROOT,
+    TM_MIN_M,
+    TM_MAX_M,
+    TM_VISITS_BASE,
+    TM_VISITS_M,
     CHESS960,
     POLICY_ONLY,
     SHOW_MOVESLEFT,
@@ -187,6 +198,17 @@ pub struct SearchOptions {
     pub is_policy_only: bool,
     pub show_movesleft: bool,
     pub mcts_options: MctsOptions,
+    pub time_management_options: TimeManagementOptions,
+}
+
+/// Time-management tuning values, populated from the `TM*` UCI options.
+#[allow(clippy::module_name_repetitions)]
+#[derive(Debug, Clone, Copy)]
+pub struct TimeManagementOptions {
+    pub min_m: f32,
+    pub max_m: f32,
+    pub visits_base: f32,
+    pub visits_m: f32,
 }
 
 impl From<&UciOptionMap> for MctsOptions {
@@ -200,6 +222,17 @@ impl From<&UciOptionMap> for MctsOptions {
     }
 }
 
+impl From<&UciOptionMap> for TimeManagementOptions {
+    fn from(map: &UciOptionMap) -> Self {
+        Self {
+            min_m: map.get_f32(&TM_MIN_M),
+            max_m: map.get_f32(&TM_MAX_M),
+            visits_base: map.get_f32(&TM_VISITS_BASE),
+            visits_m: map.get_f32(&TM_VISITS_M),
+        }
+    }
+}
+
 impl Default for SearchOptions {
     fn default() -> Self {
         SearchOptions::from(&UciOptionMap::default())
@@ -217,6 +250,7 @@ impl From<&UciOptionMap> for SearchOptions {
             is_policy_only: map.get_and(&POLICY_ONLY, |s| s.parse().ok()),
             show_movesleft: map.get_and(&SHOW_MOVESLEFT, |s| s.parse().ok()),
             mcts_options: MctsOptions::from(map),
+            time_management_options: TimeManagementOptions::from(map),
         }
     }
 }
diff --git a/src/search.rs b/src/search.rs
index b5039fa..af76897 100644
--- a/src/search.rs
+++ b/src/search.rs
@@ -18,10 +18,12 @@ const MOVE_OVERHEAD: Duration = Duration::from_millis(50);
 
 pub const SCALE: f32 = 256. * 256.;
 
+#[must_use]
 #[derive(Copy, Clone, Debug)]
 pub struct TimeManagement {
     start: Instant,
-    end: Option<Instant>,
+    soft_limit: Option<Duration>,
+    hard_limit: Option<Duration>,
     node_limit: usize,
 }
 
@@ -32,39 +34,43 @@ impl Default for TimeManagement {
 }
 
 impl TimeManagement {
-    #[must_use]
+    /// Starts the clock now with `d` as the hard limit and no soft limit.
     pub fn from_duration(d: Duration) -> Self {
-        let start = Instant::now();
-        let end = Some(start + d);
-        let node_limit = usize::MAX;
+        Self {
+            start: Instant::now(),
+            soft_limit: None,
+            hard_limit: Some(d),
+            node_limit: usize::MAX,
+        }
+    }
 
+    /// Starts the clock now with both a soft and a hard limit.
+    pub fn from_limits(soft: Duration, hard: Duration) -> Self {
         Self {
-            start,
-            end,
-            node_limit,
+            start: Instant::now(),
+            soft_limit: Some(soft),
+            hard_limit: Some(hard),
+            node_limit: usize::MAX,
         }
     }
 
-    #[must_use]
     pub fn infinite() -> Self {
-        let start = Instant::now();
-        let end = None;
-        let node_limit = usize::MAX;
-
         Self {
-            start,
-            end,
-            node_limit,
+            start: Instant::now(),
+            soft_limit: None,
+            hard_limit: None,
+            node_limit: usize::MAX,
         }
     }
 
     #[must_use]
-    pub fn is_after_end(&self) -> bool {
-        if let Some(end) = self.end {
-            Instant::now() > end
-        } else {
-            false
-        }
+    pub fn soft_limit(&self) -> Option<Duration> {
+        self.soft_limit
+    }
+
+    #[must_use]
+    pub fn hard_limit(&self) -> Option<Duration> {
+        self.hard_limit
     }
 
     #[must_use]
@@ -84,12 +90,14 @@ impl TimeManagement {
 
 pub struct ThreadData<'a> {
     pub allocator: LRAllocator<'a>,
+    pub playouts: usize,
 }
 
 impl<'a> ThreadData<'a> {
     fn create(tree: &'a SearchTree) -> Self {
         Self {
             allocator: tree.allocator(),
+            playouts: 0,
         }
     }
 }
@@ -226,11 +234,14 @@ impl Search {
                 move_time_fraction = (m + 2).min(move_time_fraction);
             }
 
-            let ideal_think_time =
-                (r + move_time_fraction * increment - MOVE_OVERHEAD) / move_time_fraction;
-            let max_think_time = r / 3;
+            // `Duration` subtraction panics on underflow, so saturate when less than
+            // `MOVE_OVERHEAD` remains on the clock.
+            let r = r.saturating_sub(MOVE_OVERHEAD);
+
+            let soft_limit = (r + move_time_fraction * increment) / move_time_fraction;
+            let hard_limit = r / 3;
 
-            think_time = TimeManagement::from_duration(ideal_think_time.min(max_think_time));
+            think_time = TimeManagement::from_limits(soft_limit.min(hard_limit), hard_limit);
         }
 
         if self.search_options.is_policy_only {
diff --git a/src/search_tree.rs b/src/search_tree.rs
index d82bc3b..7b51737 100644
--- a/src/search_tree.rs
+++ b/src/search_tree.rs
@@ -9,7 +9,7 @@ use std::sync::atomic::{
 use crate::arena::Error as ArenaError;
 use crate::chess;
 use crate::evaluation::{self, Flag};
-use crate::options::{MctsOptions, SearchOptions};
+use crate::options::{MctsOptions, SearchOptions, TimeManagementOptions};
 use crate::search::{eval_in_cp, ThreadData};
 use crate::search::{TimeManagement, SCALE};
 use crate::state::State;
@@ -290,10 +290,7 @@ impl SearchTree {
         loop {
             self.ttable.wait_if_flipping();
 
-            if node.is_terminal() {
-                break;
-            }
-            if node.hots().is_empty() {
+            if node.is_terminal() || node.hots().is_empty() {
                 break;
             }
             if node.is_tablebase() && state.halfmove_clock() == 0 {
@@ -337,7 +334,6 @@ impl SearchTree {
                 || state.drawn_by_fifty_move_rule()
                 || state.board().is_insufficient_material()
             {
-                evaln = 0;
                 node = &DRAW_NODE;
                 break;
             }
@@ -354,21 +350,15 @@ impl SearchTree {
             node = new_node;
         }
 
-        evaln = match node.flag {
-            Flag::TerminalWin => 2 * SCALE as i64,
-            Flag::TerminalLoss => -2 * SCALE as i64,
-            Flag::TablebaseWin => SCALE as i64,
-            Flag::TablebaseLoss => -SCALE as i64,
-            Flag::TerminalDraw | Flag::TablebaseDraw => 0,
-            Flag::Standard => evaln,
-        };
+        evaln = node.flag.adjust_eval(evaln);
 
         Self::finish_playout(&path, evaln);
 
         let depth = path.len();
         let num_nodes = self.num_nodes.fetch_add(depth, Ordering::Relaxed) + depth;
         self.max_depth.fetch_max(depth, Ordering::Relaxed);
-        let playouts = self.playouts.fetch_add(1, Ordering::Relaxed) + 1;
+        self.playouts.fetch_add(1, Ordering::Relaxed);
+        tld.playouts += 1;
 
         if node.is_tablebase() {
             self.tb_hits.fetch_add(1, Ordering::Relaxed);
@@ -378,13 +368,31 @@ impl SearchTree {
             return false;
         }
 
-        if playouts % 128 == 0
-            && (time_management.is_after_end() || stop_signal.load(Ordering::Relaxed))
-        {
+        if tld.playouts % 128 == 0 && stop_signal.load(Ordering::Relaxed) {
             return false;
         }
 
-        if playouts % 65536 == 0 {
+        if tld.playouts % 128 == 0 {
+            // Only the limit checks below use `elapsed`, so compute it on polling
+            // iterations rather than after every playout.
+            let elapsed = time_management.elapsed();
+
+            if let Some(hard_limit) = time_management.hard_limit() {
+                if elapsed >= hard_limit {
+                    return false;
+                }
+            }
+
+            if let Some(soft_limit) = time_management.soft_limit() {
+                let opts = &self.search_options.time_management_options;
+
+                if elapsed >= soft_limit.mul_f32(self.soft_time_multiplier(opts)) {
+                    return false;
+                }
+            }
+        }
+
+        if tld.playouts % 65536 == 0 {
             let elapsed = time_management.elapsed().as_secs();
 
             let next_info = self.next_info.fetch_max(elapsed, Ordering::Relaxed);
@@ -461,6 +469,27 @@ impl SearchTree {
         )[0]
     }
 
+    /// Multiplier applied to the soft time limit, derived from the fraction of root
+    /// visits that went to the child picked by `select_child_by_rewards`.
+    fn soft_time_multiplier(&self, opts: &TimeManagementOptions) -> f32 {
+        let root = self.root_node();
+        let root_visits = root.visits() as f32;
+
+        // Without root visits the ratio would be 0/0 = NaN, and a NaN multiplier makes
+        // `Duration::mul_f32` panic in the caller; fall back to a fraction of zero.
+        let bm_frac = if root_visits > 0.0 {
+            root.select_child_by_rewards().visits() as f32 / root_visits
+        } else {
+            0.0
+        };
+
+        let m = (opts.visits_base - bm_frac) * opts.visits_m;
+
+        // `f32::clamp` panics when `min > max`, which independently settable UCI options
+        // allow; chained `max`/`min` bounds the value without panicking.
+        m.max(opts.min_m).min(opts.max_m)
+    }
+
     pub fn print_info(&self, time_management: &TimeManagement) {
         let mut info_str = String::with_capacity(256);
 