From 369260b41bccb2c73d0cd9d83fed780c61aeb371 Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Tue, 7 May 2024 23:19:04 -0700 Subject: [PATCH] remove float match literal --- src/lib.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ad4946e..6741aa5 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -963,8 +963,15 @@ fn optimize_random_greedy_track_flops( use_ssa: Option, ) -> (Vec>, Score) { py.allow_threads(|| { - let (costmodmin, costmodmax) = costmod.unwrap_or((0.1, 4.0)); - let (tempmin, tempmax) = temperature.unwrap_or((0.001, 1.0)); + let (costmod_min, costmod_max) = costmod.unwrap_or((0.1, 4.0)); + let costmod_diff = (costmod_max - costmod_min).abs(); + let is_const_costmod = costmod_diff < Score::EPSILON; + + let (temp_min, temp_max) = temperature.unwrap_or((0.001, 1.0)); + let log_temp_min = Score::ln(temp_min); + let log_temp_max = Score::ln(temp_max); + let log_temp_diff = (log_temp_max - log_temp_min).abs(); + let is_const_temp = log_temp_diff < Score::EPSILON; let mut rng = match seed { Some(seed) => rand::rngs::StdRng::seed_from_u64(seed), @@ -982,22 +989,21 @@ fn optimize_random_greedy_track_flops( let mut best_path = None; let mut best_flops = f32::INFINITY; - let logtempmin = f32::ln(tempmin); - let logtempmax = f32::ln(tempmax); - for seed in seeds { let mut cp = cp0.clone(); // uniform sample for costmod - let costmod = match costmodmax - costmodmin { - 0.0 => costmodmin, - diff => costmodmin + rng.gen::() * diff, + let costmod = if is_const_costmod { + costmod_min + } else { + costmod_min + rng.gen::() * costmod_diff }; // log-uniform sample for temperature - let temperature = match logtempmax - logtempmin { - 0.0 => tempmin, - diff => f32::exp(logtempmin + rng.gen::() * diff), + let temperature = if is_const_temp { + temp_min + } else { + f32::exp(log_temp_min + rng.gen::() * log_temp_diff) }; // greedily contract each connected subgraph