Skip to content

Commit

Permalink
Fix Explanation Initialization to Share Subexpressions (egraphs-good#155
Browse files Browse the repository at this point in the history
)
  • Loading branch information
oflatt authored Jan 12, 2022
1 parent cfdce7f commit d0acf72
Showing 1 changed file with 68 additions and 19 deletions.
87 changes: 68 additions & 19 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub type FlatExplanation<L> = Vec<FlatTerm<L>>;

// given two adjacent nodes and the direction of the proof
type ExplainCache<L> = HashMap<(Id, Id), Rc<TreeTerm<L>>>;
type NodeExplanationCache<L> = HashMap<Id, Rc<TreeTerm<L>>>;

/** A data structure representing an explanation that two terms are equivalent.
Expand Down Expand Up @@ -726,13 +727,23 @@ impl<L: Language> FlatTerm<L> {
}

impl<L: Language> Explain<L> {
fn node_to_explanation(&self, node_id: Id) -> TreeTerm<L> {
let node = self.explainfind[usize::from(node_id)].node.clone();
let children = node.fold(vec![], |mut sofar, child| {
sofar.push(vec![Rc::new(self.node_to_explanation(child))]);
sofar
});
TreeTerm::new(node, children)
fn node_to_explanation(
&self,
node_id: Id,
cache: &mut NodeExplanationCache<L>,
) -> Rc<TreeTerm<L>> {
if let Some(existing) = cache.get(&node_id) {
existing.clone()
} else {
let node = self.explainfind[usize::from(node_id)].node.clone();
let children = node.fold(vec![], |mut sofar, child| {
sofar.push(vec![self.node_to_explanation(child, cache)]);
sofar
});
let res = Rc::new(TreeTerm::new(node, children));
cache.insert(node_id, res.clone());
res
}
}

fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm<L> {
Expand Down Expand Up @@ -942,7 +953,8 @@ impl<L: Language> Explain<L> {
let left_added = self.add_expr(left, memo, unionfind);
let right_added = self.add_match(right, &subst, memo, unionfind);
let mut cache = Default::default();
Explanation::new(self.explain_enodes(left_added, right_added, &mut cache))
let mut enode_cache = Default::default();
Explanation::new(self.explain_enodes(left_added, right_added, &mut cache, &mut enode_cache))
}

pub(crate) fn explain_equivalence(
Expand All @@ -952,10 +964,13 @@ impl<L: Language> Explain<L> {
memo: &HashMap<L, Id>,
unionfind: &mut UnionFind,
) -> Explanation<L> {
println!("{:?}", left);
let left_added = self.add_expr(left, memo, unionfind);
let right_added = self.add_expr(right, memo, unionfind);
println!("done!");
let mut cache = Default::default();
Explanation::new(self.explain_enodes(left_added, right_added, &mut cache))
let mut enode_cache = Default::default();
Explanation::new(self.explain_enodes(left_added, right_added, &mut cache, &mut enode_cache))
}

pub(crate) fn explain_existance(
Expand All @@ -966,10 +981,12 @@ impl<L: Language> Explain<L> {
) -> Explanation<L> {
let left_added = self.add_expr(left, memo, unionfind);
let mut cache = Default::default();
let mut enode_cache = Default::default();
Explanation::new(self.explain_enode_existance(
left_added,
Rc::new(self.node_to_explanation(left_added)),
self.node_to_explanation(left_added, &mut enode_cache),
&mut cache,
&mut enode_cache,
))
}

Expand All @@ -982,10 +999,12 @@ impl<L: Language> Explain<L> {
) -> Explanation<L> {
let left_added = self.add_match(left, &subst, memo, unionfind);
let mut cache = Default::default();
let mut enode_cache = Default::default();
Explanation::new(self.explain_enode_existance(
left_added,
Rc::new(self.node_to_explanation(left_added)),
self.node_to_explanation(left_added, &mut enode_cache),
&mut cache,
&mut enode_cache,
))
}

Expand Down Expand Up @@ -1033,13 +1052,14 @@ impl<L: Language> Explain<L> {
node: Id,
rest_of_proof: Rc<TreeTerm<L>>,
cache: &mut ExplainCache<L>,
enode_cache: &mut NodeExplanationCache<L>,
) -> TreeExplanation<L> {
let graphnode = &self.explainfind[usize::from(node)];
let existance = graphnode.existance_node;
let existance_node = &self.explainfind[usize::from(existance)];
// case 1)
if existance == node {
return vec![Rc::new(self.node_to_explanation(node)), rest_of_proof];
return vec![self.node_to_explanation(node, enode_cache), rest_of_proof];
}

// case 2)
Expand All @@ -1055,13 +1075,21 @@ impl<L: Language> Explain<L> {
}
return self.explain_enode_existance(
existance,
self.explain_adjacent(existance, node, direction, justification, cache),
self.explain_adjacent(
existance,
node,
direction,
justification,
cache,
enode_cache,
),
cache,
enode_cache,
);
}

// case 3)
let mut new_rest_of_proof = self.node_to_explanation(existance);
let mut new_rest_of_proof = (*self.node_to_explanation(existance, enode_cache)).clone();
let mut index_of_child = 0;
let mut found = false;
existance_node.node.for_each(|child| {
Expand All @@ -1077,20 +1105,26 @@ impl<L: Language> Explain<L> {
assert!(found);
new_rest_of_proof.child_proofs[index_of_child].push(rest_of_proof);

self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache)
self.explain_enode_existance(existance, Rc::new(new_rest_of_proof), cache, enode_cache)
}

fn explain_enodes(
&self,
left: Id,
right: Id,
cache: &mut ExplainCache<L>,
node_explanation_cache: &mut NodeExplanationCache<L>,
) -> TreeExplanation<L> {
let mut proof = vec![Rc::new(self.node_to_explanation(left))];
println!("({}, {})", left, right);

let mut proof = vec![self.node_to_explanation(left, node_explanation_cache)];
println!("node to explanation");
let ancestor = self.common_ancestor(left, right);
let left_nodes = self.get_nodes(left, ancestor);
let right_nodes = self.get_nodes(right, ancestor);

println!("got nodes");

for (i, node) in left_nodes
.iter()
.chain(right_nodes.iter().rev())
Expand All @@ -1104,7 +1138,14 @@ impl<L: Language> Explain<L> {
std::mem::swap(&mut next, &mut current);
}

proof.push(self.explain_adjacent(current, next, direction, &node.justification, cache));
proof.push(self.explain_adjacent(
current,
next,
direction,
&node.justification,
cache,
node_explanation_cache,
));
}
proof
}
Expand All @@ -1116,16 +1157,19 @@ impl<L: Language> Explain<L> {
rule_direction: bool,
justification: &Justification,
cache: &mut ExplainCache<L>,
node_explanation_cache: &mut NodeExplanationCache<L>,
) -> Rc<TreeTerm<L>> {
let fingerprint = (current, next);
println!("finger {} {}", current, next);

if let Some(answer) = cache.get(&fingerprint) {
return answer.clone();
}

let term = match justification {
Justification::Rule(name) => {
let mut rewritten = self.node_to_explanation(next);
let mut rewritten =
(*self.node_to_explanation(next, node_explanation_cache)).clone();
if rule_direction {
rewritten.forward_rule = Some(*name);
} else {
Expand All @@ -1146,7 +1190,12 @@ impl<L: Language> Explain<L> {
.iter()
.zip(next_node.children().iter())
{
subproofs.push(self.explain_enodes(*left_child, *right_child, cache));
subproofs.push(self.explain_enodes(
*left_child,
*right_child,
cache,
node_explanation_cache,
));
}
Rc::new(TreeTerm::new(current_node.clone(), subproofs))
}
Expand Down

0 comments on commit d0acf72

Please sign in to comment.