Skip to content

Commit

Permalink
random greedy: use early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 8, 2024
1 parent d6dafc9 commit d1ea4f8
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct ContractionProcessor {
ssa_path: SSAPath,
track_flops: bool,
flops: Score,
flops_limit: Score,
}

/// given log(x) and log(y) compute log(x + y), without exponentiating both
Expand Down Expand Up @@ -195,6 +196,7 @@ impl ContractionProcessor {
let ssa = nodes.len() as Node;
let ssa_path: SSAPath = Vec::with_capacity(2 * ssa as usize - 1);
let flops: Score = 0.0;
let flops_limit: Score = Score::INFINITY;

ContractionProcessor {
nodes,
Expand All @@ -205,6 +207,7 @@ impl ContractionProcessor {
ssa_path,
track_flops,
flops,
flops_limit,
}
}

Expand Down Expand Up @@ -415,7 +418,7 @@ impl ContractionProcessor {
costmod: Option<f32>,
temperature: Option<f32>,
seed: Option<u64>,
) {
) -> bool {
let coeff_t = temperature.unwrap_or(0.0);
let log_coeff_a = f32::ln(costmod.unwrap_or(1.0));

Expand Down Expand Up @@ -483,6 +486,12 @@ impl ContractionProcessor {

// perform contraction:
let k = self.contract_nodes_given_legs(i, j, klegs.clone());

if self.track_flops && self.flops >= self.flops_limit {
// stop if we have reached the flops limit
return false;
}

node_sizes.insert(k, ksize);

for l in self.neighbors(k) {
Expand All @@ -498,6 +507,8 @@ impl ContractionProcessor {
c -= 1;
}
}
// success
return true;
}

/// Optimize the contraction order of all terms using a greedy algorithm
Expand Down Expand Up @@ -990,13 +1001,19 @@ fn optimize_random_greedy_track_flops(
};

// greedily contract each connected subgraph
cp.optimize_greedy(Some(costmod), Some(temperature), Some(seed));
let success = cp.optimize_greedy(Some(costmod), Some(temperature), Some(seed));

if !success {
continue;
}

// optimize any remaining disconnected terms
cp.optimize_remaining_by_size();

if cp.flops < best_flops {
best_flops = cp.flops;
best_path = Some(cp.ssa_path);
best_flops = cp.flops;
cp0.flops_limit = cp.flops;
}
}

Expand Down

0 comments on commit d1ea4f8

Please sign in to comment.