Skip to content

Commit

Permalink
Add option serialization option to separate primitive results into th…
Browse files Browse the repository at this point in the history
…eir own e-classes

Fixes #244
  • Loading branch information
saulshanabrook committed Oct 1, 2023
1 parent 43a6a3b commit 5005e65
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 21 deletions.
11 changes: 9 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1355,8 +1355,15 @@ impl EGraph {
}

/// Serializes the egraph for export to graphviz.
pub fn serialize_for_graphviz(&self) -> egraph_serialize::EGraph {
let mut serialized = self.serialize(SerializeConfig::default());
pub fn serialize_for_graphviz(
&self,
split_primitive_outputs: bool,
) -> egraph_serialize::EGraph {
let config = SerializeConfig {
split_primitive_outputs,
..Default::default()
};
let mut serialized = self.serialize(config);
serialized.inline_leaves();
serialized
}
Expand Down
26 changes: 20 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ struct Args {
to_dot: bool,
#[clap(long)]
to_svg: bool,
#[clap(long)]
serialize_split_primitive_outputs: bool,
}

#[allow(clippy::disallowed_macros)]
Expand Down Expand Up @@ -161,21 +163,33 @@ fn main() {
}
}
}

// if we are splitting primitive outputs, add `-split` to the end of the file name
let serialiaze_filename = if args.serialize_split_primitive_outputs {
input.with_file_name(format!(
"{}-split",
input.file_stem().unwrap().to_str().unwrap()
))
} else {
input.clone()
};
if args.to_json {
let json_path = input.with_extension("json");
let serialized = egraph.serialize(SerializeConfig::default());
let json_path = serialiaze_filename.with_extension("json");
let config = SerializeConfig {
split_primitive_outputs: args.serialize_split_primitive_outputs,
..SerializeConfig::default()
};
let serialized = egraph.serialize(config);
serialized.to_json_file(json_path).unwrap();
}

if args.to_dot || args.to_svg {
let serialized = egraph.serialize_for_graphviz();
let serialized = egraph.serialize_for_graphviz(args.serialize_split_primitive_outputs);
if args.to_dot {
let dot_path = input.with_extension("dot");
let dot_path = serialiaze_filename.with_extension("dot");
serialized.to_dot_file(dot_path).unwrap()
}
if args.to_svg {
let svg_path = input.with_extension("svg");
let svg_path = serialiaze_filename.with_extension("svg");
serialized.to_svg_file(svg_path).unwrap()
}
}
Expand Down
46 changes: 35 additions & 11 deletions src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub struct SerializeConfig {
pub max_calls_per_function: Option<usize>,
// Whether to include temporary functions in the serialized graph
pub include_temporary_functions: bool,
// Whether to split primitive output values into their own e-classes with the function
pub split_primitive_outputs: bool,
}

impl Default for SerializeConfig {
Expand All @@ -23,14 +25,15 @@ impl Default for SerializeConfig {
max_functions: Some(40),
max_calls_per_function: Some(40),
include_temporary_functions: false,
split_primitive_outputs: false,
}
}
}

impl EGraph {
/// Serialize the egraph into a format that can be read by the egraph-serialize crate.
///
/// There are multiple different semantically valid ways to do this.
/// There are multiple different semantically valid ways to do this. This is how this implementation does it:
///
/// For node costs:
/// - Primitives: 1.0
Expand All @@ -41,6 +44,7 @@ impl EGraph {
/// - Functions: Function name + hash of input values
/// - Args which are eq sorts: Choose one ID from the e-class, distribute roughly evenly.
/// - Args and outputs values which are primitives: Sort name + hash of value
/// Notes: If `split_primitive_returns` is true, then each output value will be the function node id + `-output`
///
/// For e-classes:
/// - Eq sorts: Use the canonical ID of the e-class
Expand All @@ -53,7 +57,7 @@ impl EGraph {
/// - Edges in the visualization will be well distributed (used for animating changes in the visualization)
/// (Note that this will be changed in `<https://github.com/egraphs-good/egglog/pull/158>` so that edges point to exact nodes instead of looking up the e-class)
pub fn serialize(&self, config: SerializeConfig) -> egraph_serialize::EGraph {
// First collect a list of all the calls we want to serialize, into the function decl, the inputs, and the output, and if its an eq sort
// First collect a list of all the calls we want to serialize as (function decl, inputs, the output, the node id)
let all_calls: Vec<(&FunctionDecl, &ValueVec, &Value, egraph_serialize::NodeId)> = self
.functions
.values()
Expand Down Expand Up @@ -85,6 +89,8 @@ impl EGraph {

// Then create a mapping from each canonical e-class ID to the set of node IDs in that e-class
// Note that this is only for e-classes, primitives have e-classes equal to their node ID
// This is for when we need to find what node ID to use for an edge to an e-class, we can rotate them evenly
// amoung all possible options.
let mut node_ids: NodeIDs = all_calls
.iter()
.filter_map(|(_decl, _input, output, node_id)| {
Expand All @@ -103,13 +109,21 @@ impl EGraph {
.push_back(node_id.clone());
acc
});

let mut egraph = egraph_serialize::EGraph::default();
for (decl, input, output, node_id) in all_calls {
let eclass = self.serialize_value(&mut egraph, &mut node_ids, output).0;
let prim_node_id = if config.split_primitive_outputs {
Some(format!("{}-value", node_id.clone()))
} else {
None
};
let eclass = self
.serialize_value(&mut egraph, &mut node_ids, output, prim_node_id)
.0;
let children: Vec<_> = input
.iter()
// Filter out children which don't have an ID, meaning that we skipped emitting them due to size constraints
.filter_map(|v| self.serialize_value(&mut egraph, &mut node_ids, v).1)
.filter_map(|v| self.serialize_value(&mut egraph, &mut node_ids, v, None).1)
.collect();
egraph.nodes.insert(
node_id,
Expand All @@ -126,11 +140,16 @@ impl EGraph {

/// Serialize the value and return the eclass and node ID
/// If this is a primitive value, we will add the node to the data, but if it is an eclass, we will not
/// When this is called on the output of a node, we only use the e-class to know which e-class its a part of
/// When this is called on an input of a node, we only use the node ID to know which node to point to.
fn serialize_value(
&self,
egraph: &mut egraph_serialize::EGraph,
node_ids: &mut NodeIDs,
value: &Value,
// The node ID to use for a primitve value, if this is None, use the hash of the value and the sort name
// Set iff `split_primitive_outputs` is set and this is an output of a function.
prim_node_id: Option<String>,
) -> (egraph_serialize::ClassId, Option<egraph_serialize::NodeId>) {
let sort = self.get_sort(value).unwrap();
let (class_id, node_id): (egraph_serialize::ClassId, Option<egraph_serialize::NodeId>) =
Expand All @@ -140,16 +159,21 @@ impl EGraph {
let class_id: egraph_serialize::ClassId = canonical.to_string().into();
(class_id.clone(), get_node_id(node_ids, class_id))
} else {
let sort_name = sort.name().to_string();
let node_id_str = format!("{}-{}", sort_name, hash_values(vec![*value].as_slice()));
let (eclass, node_id): (egraph_serialize::ClassId, egraph_serialize::NodeId) =
(node_id_str.clone().into(), node_id_str.into());
let (class_id, node_id): (egraph_serialize::ClassId, egraph_serialize::NodeId) =
if let Some(node_id) = prim_node_id {
(node_id.clone().into(), node_id.into())
} else {
let sort_name = sort.name().to_string();
let node_id_str =
format!("{}-{}", sort_name, hash_values(vec![*value].as_slice()));
(node_id_str.clone().into(), node_id_str.into())
};
// Add node for value
{
let children: Vec<egraph_serialize::NodeId> = sort
.inner_values(value)
.into_iter()
.filter_map(|(_, v)| self.serialize_value(egraph, node_ids, &v).1)
.filter_map(|(_, v)| self.serialize_value(egraph, node_ids, &v, None).1)
.collect();
// If this is a container sort, use the name, otherwise use the value
let op: String = if sort.is_container_sort() {
Expand All @@ -162,13 +186,13 @@ impl EGraph {
node_id.clone(),
egraph_serialize::Node {
op,
eclass: eclass.clone(),
eclass: class_id.clone(),
cost: NotNan::new(0.0).unwrap(),
children,
},
);
};
(eclass, Some(node_id))
(class_id, Some(node_id))
};
egraph.class_data.insert(
class_id.clone(),
Expand Down
4 changes: 3 additions & 1 deletion tests/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ impl Run {
log::info!(" {}", msg);
}
// Test graphviz dot generation
egraph.serialize_for_graphviz().to_dot();
egraph.serialize_for_graphviz(false).to_dot();
// Also try splitting
egraph.serialize_for_graphviz(true).to_dot();
}
}
Err(err) => {
Expand Down
2 changes: 1 addition & 1 deletion web-demo/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn run_program(input: &str) -> Result {
match egraph.parse_and_run_program(input) {
Ok(outputs) => {
log::info!("egg ok, {} outputs", outputs.len());
let serialized = egraph.serialize_for_graphviz();
let serialized = egraph.serialize_for_graphviz(false);
Result {
text: outputs.join("<br>"),
dot: serialized.to_dot(),
Expand Down

0 comments on commit 5005e65

Please sign in to comment.