Skip to content

Commit

Permalink
add python test suite, fix repeated index bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Sep 5, 2023
1 parent 9d54a9c commit 2e39c09
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 27 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ rustc-hash = "1.1"
codegen-units = 1
lto = true
opt-level = 3
panic = "abort"
16 changes: 11 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
[build-system]
requires = ["maturin>=0.15,<0.16"]
build-backend = "maturin"

[project]
name = "cotengrust"
requires-python = ">=3.7"
version = "0.1.0"
description = "Fast contraction ordering primitives for tensor networks."
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
license = { file = "LICENSE" }
authors = [
{name = "Johnnie Gray", email = "johnniemcgray@gmail.com"}

]

[build-system]
requires = ["maturin>=0.15,<0.16"]
build-backend = "maturin"

[tool.maturin]
features = ["pyo3/extension-module"]
99 changes: 78 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type SubContraction = (Legs, Score, BitPath);
/// helper struct to build contractions from bottom up
struct ContractionProcessor {
nodes: Dict<Node, Legs>,
edges: Dict<Ix, Vec<Node>>,
edges: Dict<Ix, BTreeSet<Node>>,
appearances: Vec<Count>,
sizes: Vec<Score>,
ssa: Node,
Expand Down Expand Up @@ -133,7 +133,7 @@ impl ContractionProcessor {
size_dict: Dict<char, f32>,
) -> ContractionProcessor {
let mut nodes: Dict<Node, Legs> = Dict::default();
let mut edges: Dict<Ix, Vec<Node>> = Dict::default();
let mut edges: Dict<Ix, BTreeSet<Node>> = Dict::default();
let mut indmap: Dict<char, Ix> = Dict::default();
let mut sizes: Vec<Score> = Vec::with_capacity(size_dict.len());
let mut appearances: Vec<Count> = Vec::with_capacity(size_dict.len());
Expand All @@ -147,7 +147,7 @@ impl ContractionProcessor {
None => {
// index not parsed yet
indmap.insert(ind, c);
edges.insert(c, vec![i as Node]);
edges.insert(c, std::iter::once(i as Node).collect());
appearances.push(1);
sizes.push(f32::log(size_dict[&ind] as f32, 2.0));
legs.push((c, 1));
Expand All @@ -156,7 +156,7 @@ impl ContractionProcessor {
Some(&ix) => {
// index already present
appearances[ix as usize] += 1;
edges.get_mut(&ix).unwrap().push(i as Node);
edges.get_mut(&ix).unwrap().insert(i as Node);
legs.push((ix, 1));
}
};
Expand Down Expand Up @@ -204,11 +204,15 @@ impl ContractionProcessor {
fn pop_node(&mut self, i: Node) -> Legs {
let legs = self.nodes.remove(&i).unwrap();
for (ix, _) in legs.iter() {
let nodes = self.edges.get_mut(&ix).unwrap();
if nodes.len() == 1 {
let enodes = match self.edges.get_mut(&ix) {
Some(enodes) => enodes,
// if repeated index, might have already been removed
None => continue,
};
enodes.remove(&i);
if enodes.len() == 0 {
// last node with this index -> remove from map
self.edges.remove(&ix);
} else {
nodes.retain(|&j| j != i);
}
}
legs
Expand All @@ -221,8 +225,8 @@ impl ContractionProcessor {
for (ix, _) in &legs {
self.edges
.entry(*ix)
.and_modify(|nodes| nodes.push(i))
.or_insert(vec![i]);
.and_modify(|nodes| {nodes.insert(i);})
.or_insert(std::iter::once(i as Node).collect());
}
self.nodes.insert(i, legs);
i
Expand Down Expand Up @@ -393,6 +397,8 @@ impl ContractionProcessor {

// get the initial candidate contractions
for ix_nodes in self.edges.values() {
// convert to vector for combinational indexing
let ix_nodes: Vec<Node> = ix_nodes.iter().cloned().collect();
// for all combinations of nodes with a connected edge
for ip in 0..ix_nodes.len() {
let i = ix_nodes[ip];
Expand Down Expand Up @@ -617,7 +623,7 @@ impl ContractionProcessor {
subgraph: Vec<Node>,
minimize: Option<String>,
cost_cap: Option<Score>,
allow_outer: Option<bool>,
search_outer: Option<bool>,
) {
// parse the minimize argument
let minimize = minimize.unwrap_or("flops".to_string());
Expand All @@ -642,7 +648,7 @@ impl ContractionProcessor {
minimize
),
};
let allow_outer = allow_outer.unwrap_or(false);
let search_outer = search_outer.unwrap_or(false);

// storage for each possible contraction to reach subgraph of size m
let mut contractions: Vec<Dict<Subgraph, SubContraction>> =
Expand Down Expand Up @@ -691,8 +697,8 @@ impl ContractionProcessor {
let mut temp_legs: Legs = Vec::with_capacity(ilegs.len() + jlegs.len());
ip = 0;
jp = 0;
// if allow_outer -> we will never skip
skip_because_outer = !allow_outer;
// if search_outer -> we will never skip
skip_because_outer = !search_outer;
while ip < ilegs.len() && jp < jlegs.len() {
if ilegs[ip].0 < jlegs[jp].0 {
// index only appears in ilegs
Expand Down Expand Up @@ -784,16 +790,44 @@ impl ContractionProcessor {
&mut self,
minimize: Option<String>,
cost_cap: Option<Score>,
allow_outer: Option<bool>,
search_outer: Option<bool>,
) {
for subgraph in self.subgraphs() {
self.optimize_optimal_connected(subgraph, minimize.clone(), cost_cap, allow_outer);
self.optimize_optimal_connected(subgraph, minimize.clone(), cost_cap, search_outer);
}
}
}

// --------------------------- PYTHON FUNCTIONS ---------------------------- //

#[pyfunction]
#[pyo3()]
fn ssa_to_linear(ssa_path: SSAPath, n: Option<usize>) -> SSAPath {
let n = match n {
Some(n) => n,
None => ssa_path.iter().map(|v| v.len()).sum::<usize>() + ssa_path.len() + 1,
};
let mut ids: Vec<Node> = (0..n).map(|i| i as Node).collect();
let mut path: SSAPath = Vec::with_capacity(2 * n - 1);
let mut ssa = n as Node;
for scon in ssa_path {
// find the locations of the ssa ids in the list of ids
let mut con: Vec<Node> = scon
.iter()
.map(|s| ids.binary_search(s).unwrap() as Node)
.collect();
// remove the ssa ids from the list
con.sort();
for j in con.iter().rev() {
ids.remove(*j as usize);
}
path.push(con);
ids.push(ssa);
ssa += 1;
}
path
}

#[pyfunction]
#[pyo3()]
fn find_subgraphs(
Expand All @@ -811,10 +845,16 @@ fn optimize_simplify(
inputs: Vec<Vec<char>>,
output: Vec<char>,
size_dict: Dict<char, f32>,
use_ssa: Option<bool>,
) -> SSAPath {
let n = inputs.len();
let mut cp = ContractionProcessor::new(inputs, output, size_dict);
cp.simplify();
cp.ssa_path
if use_ssa.unwrap_or(false) {
cp.ssa_path
} else {
ssa_to_linear(cp.ssa_path, Some(n))
}
}

#[pyfunction]
Expand All @@ -826,15 +866,23 @@ fn optimize_greedy(
costmod: Option<f32>,
temperature: Option<f32>,
simplify: Option<bool>,
use_ssa: Option<bool>,
) -> Vec<Vec<Node>> {
let n = inputs.len();
let mut cp = ContractionProcessor::new(inputs, output, size_dict);
if simplify.unwrap_or(true) {
// perform simplifications
cp.simplify();
}
// greddily contract each connected subgraph
cp.optimize_greedy(costmod, temperature);
// optimize any remaining disconnected terms
cp.optimize_remaining_by_size();
cp.ssa_path
if use_ssa.unwrap_or(false) {
cp.ssa_path
} else {
ssa_to_linear(cp.ssa_path, Some(n))
}
}

#[pyfunction]
Expand All @@ -845,22 +893,31 @@ fn optimize_optimal(
size_dict: Dict<char, f32>,
minimize: Option<String>,
cost_cap: Option<Score>,
allow_outer: Option<bool>,
search_outer: Option<bool>,
simplify: Option<bool>,
use_ssa: Option<bool>,
) -> Vec<Vec<Node>> {
let n = inputs.len();
let mut cp = ContractionProcessor::new(inputs, output, size_dict);
if simplify.unwrap_or(true) {
// perform simplifications
cp.simplify();
}
cp.optimize_optimal(minimize, cost_cap, allow_outer);
// optimally contract each connected subgraph
cp.optimize_optimal(minimize, cost_cap, search_outer);
// optimize any remaining disconnected terms
cp.optimize_remaining_by_size();
cp.ssa_path
if use_ssa.unwrap_or(false) {
cp.ssa_path
} else {
ssa_to_linear(cp.ssa_path, Some(n))
}
}

/// A Python module implemented in Rust.
#[pymodule]
fn cotengrust(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(ssa_to_linear, m)?)?;
m.add_function(wrap_pyfunction!(find_subgraphs, m)?)?;
m.add_function(wrap_pyfunction!(optimize_simplify, m)?)?;
m.add_function(wrap_pyfunction!(optimize_greedy, m)?)?;
Expand Down
Loading

0 comments on commit 2e39c09

Please sign in to comment.