Skip to content

Commit

Permalink
Clean up; add a couple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Apr 28, 2022
1 parent 3f471ba commit f3cfa4e
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions src/multipattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,18 @@ impl<L: Language, A: Analysis<L>> Applier<L, A> for MultiPattern<L> {
}

fn vars(&self) -> Vec<Var> {
let mut bound_vars = HashSet::default();
let mut vars = vec![];
// TODO are unbound binding vars allowed?
for (_v, pat) in &self.asts {
for (bv, pat) in &self.asts {
for n in pat.as_ref() {
if let ENodeOrVar::Var(v) = n {
vars.push(*v)
// using vars that are already bound doesn't count
if !bound_vars.contains(v) {
vars.push(*v)
}
}
}
bound_vars.insert(bv);
}
vars.sort();
vars.dedup();
Expand All @@ -193,11 +197,22 @@ mod tests {
type EGraph = crate::EGraph<S, ()>;

impl EGraph {
fn add_string(self: &mut Self, s: &str) -> Id {
fn add_string(&mut self, s: &str) -> Id {
self.add_expr(&s.parse().unwrap())
}
}

#[test]
#[should_panic = "unbound var ?z"]
fn bad_unbound_var() {
let _: Rewrite<S, ()> = rewrite!("foo"; "?x = (foo ?y)" |- "?x = ?z");
}

#[test]
fn ok_unbound_var() {
let _: Rewrite<S, ()> = rewrite!("foo"; "?x = (foo ?y)" |- "?z = (baz ?y), ?x = ?z");
}

#[test]
fn multi_patterns() {
crate::init_logger();
Expand Down Expand Up @@ -254,17 +269,21 @@ mod tests {
let z1 = egraph.add_string("(tag z ctx2)");
egraph.union(x1, y1);
egraph.union(y2, z2);
let rules = vec![
rewrite!("context-transfer"; "?x = (tag ?a ?ctx1) = (tag ?b ?ctx1), ?t = (lte ?ctx1 ?ctx2), ?a1 = (tag ?a ?ctx2), ?b1 = (tag ?b ?ctx2)" |- "?a1 = ?b1"),
];
let rules = vec![rewrite!("context-transfer";
"?x = (tag ?a ?ctx1) = (tag ?b ?ctx1),
?t = (lte ?ctx1 ?ctx2),
?a1 = (tag ?a ?ctx2),
?b1 = (tag ?b ?ctx2)"
|-
"?a1 = ?b1")];
let runner = Runner::default().with_egraph(egraph).run(&rules);
assert_eq!(runner.egraph.find(x1), runner.egraph.find(y1));
assert_eq!(runner.egraph.find(y2), runner.egraph.find(z2));

assert_eq!(runner.egraph.find(x2), runner.egraph.find(y2));
assert_eq!(runner.egraph.find(x2), runner.egraph.find(z2));

assert!(runner.egraph.find(y1) != runner.egraph.find(z1));
assert!(runner.egraph.find(x1) != runner.egraph.find(z1));
assert_ne!(runner.egraph.find(y1), runner.egraph.find(z1));
assert_ne!(runner.egraph.find(x1), runner.egraph.find(z1));
}
}

0 comments on commit f3cfa4e

Please sign in to comment.