Skip to content

Commit

Permalink
allow sizes to be most generally be float
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Sep 1, 2023
1 parent 2afb7e8 commit ce998b2
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ fn compute_legs(ilegs: &Legs, jlegs: &Legs, appearances: &Vec<Count>) -> Legs {
}

fn compute_size(legs: &Legs, sizes: &Vec<Score>) -> Score {
// legs.iter().map(|&(ix, _)| sizes[ix as usize]).product()
legs.iter().map(|&(ix, _)| sizes[ix as usize]).sum()
}

Expand Down Expand Up @@ -131,7 +130,7 @@ impl ContractionProcessor {
fn new(
inputs: Vec<Vec<char>>,
output: Vec<char>,
size_dict: Dict<char, u32>,
size_dict: Dict<char, f32>,
) -> ContractionProcessor {
let mut nodes: Dict<Node, Legs> = Dict::default();
let mut edges: Dict<Ix, Vec<Node>> = Dict::default();
Expand Down Expand Up @@ -367,16 +366,15 @@ impl ContractionProcessor {
fn optimize_greedy(&mut self, costmod: Option<f32>, temperature: Option<f32>) {
let mut rng = rand::thread_rng();
let coeff_t = temperature.unwrap_or(0.0);
let coeff_a = costmod.unwrap_or(1.0);
let log_coeff_a = f32::ln(coeff_a);
let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));

let mut local_score = |sa: Score, sb: Score, sab: Score| -> Score {
let gumbel = if coeff_t.ne(&(0.0 as f32)) {
-f32::ln(-f32::ln(rng.gen()))
let gumbel = if coeff_t != 0.0 {
coeff_t * -f32::ln(-f32::ln(rng.gen()))
} else {
0.0 as f32
};
logsub(sab, logadd(sa, sb) + log_coeff_a) - coeff_t * gumbel
logsub(sab, log_coeff_a + logadd(sa, sb)) - gumbel
};

// cache all current nodes sizes as we go
Expand Down Expand Up @@ -756,7 +754,7 @@ impl ContractionProcessor {
fn find_subgraphs(
inputs: Vec<Vec<char>>,
output: Vec<char>,
size_dict: Dict<char, u32>,
size_dict: Dict<char, f32>,
) -> Vec<Vec<Node>> {
let cp = ContractionProcessor::new(inputs, output, size_dict);
cp.subgraphs()
Expand All @@ -767,7 +765,7 @@ fn find_subgraphs(
fn optimize_simplify(
inputs: Vec<Vec<char>>,
output: Vec<char>,
size_dict: Dict<char, u32>,
size_dict: Dict<char, f32>,
) -> SSAPath {
let mut cp = ContractionProcessor::new(inputs, output, size_dict);
cp.simplify();
Expand All @@ -779,7 +777,7 @@ fn optimize_simplify(
fn optimize_greedy(
inputs: Vec<Vec<char>>,
output: Vec<char>,
size_dict: Dict<char, u32>,
size_dict: Dict<char, f32>,
costmod: Option<f32>,
temperature: Option<f32>,
simplify: Option<bool>,
Expand All @@ -799,7 +797,7 @@ fn optimize_greedy(
fn optimize_optimal(
inputs: Vec<Vec<char>>,
output: Vec<char>,
size_dict: Dict<char, u32>,
size_dict: Dict<char, f32>,
minimize: Option<String>,
factor: Option<Score>,
cost_cap: Option<Score>,
Expand Down

0 comments on commit ce998b2

Please sign in to comment.