Skip to content

Commit

Permalink
working on better tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Aug 20, 2024
1 parent 539a02a commit fdcd657
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 36 deletions.
22 changes: 0 additions & 22 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
///
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
eprintln!("Adding {:?} directly", expr);
self.add_expr_uncanonical_with_reason(expr, ExistsOrReason::Reason(ExistenceReason::Direct))
}

Expand Down Expand Up @@ -1228,27 +1227,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
(self.find(id1), did_union)
}

/// Like `union_instantiations`, but assumes that the `from_pat` and substitution
/// is guaranteed to match the egraph already.
/// Using this method makes existence explanations more precise.
pub fn union_instantiations_guaranteed_match(
&mut self,
from_pat: &PatternAst<L>,
to_pat: &PatternAst<L>,
subst: &Subst,
rule_name: impl Into<Symbol>,
) -> (Id, bool) {
// add the lhs without an existence reason,
// assuming it matches
let id1 = self.add_instantiation_noncanonical(from_pat, subst, None);
// add the rhs, making it equal to the lhs
let id2 =
self.add_instantiation_noncanonical(to_pat, subst, Some(ExistenceReason::EqualTo(id1)));

let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
(self.find(id1), did_union)
}

/// Like [`EGraph::union_instantiations`] but assumes that the `from_term` is a
/// term that the `rule_name` rule matched.
///
Expand Down
8 changes: 3 additions & 5 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ impl<L: Language> Explanation<L> {
assert!(has_forward ^ has_backward);

if has_forward {
eprintln!("Checking rewrite forward from {:?} to {:?}", current, next);
assert!(self.check_rewrite_at(current, next, &rule_table, true));
} else {
assert!(self.check_rewrite_at(current, next, &rule_table, false));
Expand Down Expand Up @@ -1326,22 +1327,19 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
}
ExistenceReason::EqualTo(adjacent_id) => {
let adjacent_node = &self.explainfind[usize::from(adjacent_id)];
// The node should be directly adjacent to another node
let connection = if node.parent_connection.next == adjacent_id {
let mut connection = node.parent_connection.clone();
connection.is_rewrite_forward = !connection.is_rewrite_forward;
std::mem::swap(&mut connection.next, &mut connection.current);
connection
} else {
assert!(
adjacent_node.parent_connection.next == term,
"existence reason between two nodes failed: not directly adjacent."
);
assert_eq!(node.parent_connection.next, adjacent_id);
adjacent_node.parent_connection.clone()
};

let adj = self.explain_adjacent(connection, cache, enode_cache, false);
let mut exp = self.explain_term_existence(adjacent_id, adj, cache, enode_cache);

exp.push(rest_of_proof);
exp
}
Expand Down
2 changes: 1 addition & 1 deletion src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ macro_rules! test_fn {
&[$( $goal.parse().unwrap() ),+],
None $(.or(Some($check_fn)))?,
check,
true $(&& $check_existence_explanations)?,
false $(|| $check_existence_explanations)?,
)
}};
}
23 changes: 15 additions & 8 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl Analysis<Math> for ConstantFold {
let data = egraph[id].data.clone();
if let Some((c, pat)) = data {
if egraph.are_explanations_enabled() {
egraph.union_instantiations_guaranteed_match(
egraph.union_instantiations(
&pat,
&format!("{}", c).parse().unwrap(),
&Default::default(),
Expand Down Expand Up @@ -227,15 +227,17 @@ egg::test_fn! {
egg::test_fn! {
#[should_panic(expected = "Could not prove goal 0")]
math_fail, rules(),
"(+ x y)" => "(/ x y)"
"(+ x y)" => "(/ x y)",
@existence false
}

egg::test_fn! {math_simplify_add, rules(), "(+ x (+ x (+ x x)))" => "(* 4 x)" }
egg::test_fn! {math_powers, rules(), "(* (pow 2 x) (pow 2 y))" => "(pow 2 (+ x y))"}

egg::test_fn! {
math_simplify_const, rules(),
"(+ 1 (- a (* (- 2 1) a)))" => "1"
"(+ 1 (- a (* (- 2 1) a)))" => "1",
@existence false
}

egg::test_fn! {
Expand All @@ -249,24 +251,27 @@ egg::test_fn! {
2)))"#
=>
"(/ 1 (sqrt five))"
@existence false
}

egg::test_fn! {
math_simplify_factor, rules(),
"(* (+ x 3) (+ x 1))"
=>
"(+ (+ (* x x) (* 4 x)) 3)"
@existence false
}

egg::test_fn! {math_diff_same, rules(), "(d x x)" => "1"}
// Existence proofs don't support analysis, so we turn tests for them off
egg::test_fn! {math_diff_same, rules(), "(d x x)" => "1"}
egg::test_fn! {math_diff_different, rules(), "(d x y)" => "0"}
egg::test_fn! {math_diff_simple1, rules(), "(d x (+ 1 (* 2 x)))" => "2"}
egg::test_fn! {math_diff_simple2, rules(), "(d x (+ 1 (* y x)))" => "y"}
egg::test_fn! {math_diff_ln, rules(), "(d x (ln x))" => "(/ 1 x)"}

egg::test_fn! {
diff_power_simple, rules(),
"(d x (pow x 3))" => "(* 3 (pow x 2))"
"(d x (pow x 3))" => "(* 3 (pow x 2))",
@existence false
}

egg::test_fn! {
Expand All @@ -280,11 +285,13 @@ egg::test_fn! {
.with_expr(&"(* x (- (* 3 x) 14))".parse().unwrap()),
"(d x (- (pow x 3) (* 7 (pow x 2))))"
=>
"(* x (- (* 3 x) 14))"
"(* x (- (* 3 x) 14))",
@existence false
}

egg::test_fn! {
integ_one, rules(), "(i 1 x)" => "x"
integ_one, rules(), "(i 1 x)" => "x",
@existence false
}

egg::test_fn! {
Expand Down
189 changes: 189 additions & 0 deletions tests/math_no_analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
//! Since existence proofs don't support analysis,
//! we test egg without analysis here.

use egg::{rewrite as rw, *};
use ordered_float::NotNan;

pub type EGraph = egg::EGraph<Math, ()>;
pub type Rewrite = egg::Rewrite<Math, ()>;

pub type Constant = NotNan<f64>;

define_language! {
pub enum Math {
"d" = Diff([Id; 2]),
"i" = Integral([Id; 2]),

"+" = Add([Id; 2]),
"-" = Sub([Id; 2]),
"*" = Mul([Id; 2]),
"/" = Div([Id; 2]),
"pow" = Pow([Id; 2]),
"ln" = Ln(Id),
"sqrt" = Sqrt(Id),

"sin" = Sin(Id),
"cos" = Cos(Id),

Constant(Constant),
Symbol(Symbol),
}
}

// You could use egg::AstSize, but this is useful for debugging, since
// it will really try to get rid of the Diff operator
pub struct MathCostFn;
impl egg::CostFunction<Math> for MathCostFn {
type Cost = usize;
fn cost<C>(&mut self, enode: &Math, mut costs: C) -> Self::Cost
where
C: FnMut(Id) -> Self::Cost,
{
let op_cost = match enode {
Math::Diff(..) => 100,
Math::Integral(..) => 100,
_ => 1,
};
enode.fold(op_cost, |sum, i| sum + costs(i))
}
}

#[rustfmt::skip]
pub fn rules() -> Vec<Rewrite> { vec![
rw!("add-1-1"; "(+ 1 1)" => "2"),
rw!("add-0-r"; "(+ ?a 0)" => "?a"),
rw!("add-0-l"; "(+ 0 ?a)" => "?a"),
rw!("add-2-2"; "(+ 2 2)" => "4"),
rw!("add-3-1"; "(+ 3 1)" => "4"),
rw!("sub-0-r"; "(- ?a 0)" => "?a"),
rw!("sub-0-1"; "(- 0 1)" => "-1"),
rw!("sub-1-0"; "(- 1 0)" => "1"),
rw!("sub-1-1"; "(- 1 1)" => "0"),
rw!("add-2-neg1"; "(+ 2 -1)" => "1"),
rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
rw!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
rw!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),

rw!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
rw!("div-canon"; "(/ ?a ?b)" => "(* ?a (pow ?b -1))"),
// rw!("canon-sub"; "(+ ?a (* -1 ?b))" => "(- ?a ?b)"),
// rw!("canon-div"; "(* ?a (pow ?b -1))" => "(/ ?a ?b)" if is_not_zero("?b")),

rw!("zero-add"; "(+ ?a 0)" => "?a"),
rw!("zero-mul"; "(* ?a 0)" => "0"),
rw!("one-mul"; "(* ?a 1)" => "?a"),

rw!("add-zero"; "?a" => "(+ ?a 0)"),
rw!("mul-one"; "?a" => "(* ?a 1)"),

rw!("cancel-sub"; "(- ?a ?a)" => "0"),
rw!("cancel-div"; "(/ ?a ?a)" => "1"),

rw!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
rw!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),

rw!("pow-mul"; "(* (pow ?a ?b) (pow ?a ?c))" => "(pow ?a (+ ?b ?c))"),
rw!("pow0"; "(pow ?x 0)" => "1"),
rw!("pow1"; "(pow ?x 1)" => "?x"),
rw!("pow2"; "(pow ?x 2)" => "(* ?x ?x)"),
rw!("pow-recip"; "(pow ?x -1)" => "(/ 1 ?x)"),
rw!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1"),

rw!("d-variable"; "(d ?x ?x)" => "1"),
rw!("d-constant"; "(d ?x ?c)" => "0"),

rw!("d-add"; "(d ?x (+ ?a ?b))" => "(+ (d ?x ?a) (d ?x ?b))"),
rw!("d-mul"; "(d ?x (* ?a ?b))" => "(+ (* ?a (d ?x ?b)) (* ?b (d ?x ?a)))"),

rw!("d-sin"; "(d ?x (sin ?x))" => "(cos ?x)"),
rw!("d-cos"; "(d ?x (cos ?x))" => "(* -1 (sin ?x))"),

rw!("d-ln"; "(d ?x (ln ?x))" => "(/ 1 ?x)"),

rw!("d-power";
"(d ?x (pow ?f ?g))" =>
"(* (pow ?f ?g)
(+ (* (d ?x ?f)
(/ ?g ?f))
(* (d ?x ?g)
(ln ?f))))"
),

rw!("i-one"; "(i 1 ?x)" => "?x"),
rw!("i-power-const"; "(i (pow ?x ?c) ?x)" =>
"(/ (pow ?x (+ ?c 1)) (+ ?c 1))"),
rw!("i-cos"; "(i (cos ?x) ?x)" => "(sin ?x)"),
rw!("i-sin"; "(i (sin ?x) ?x)" => "(* -1 (cos ?x))"),
rw!("i-sum"; "(i (+ ?f ?g) ?x)" => "(+ (i ?f ?x) (i ?g ?x))"),
rw!("i-dif"; "(i (- ?f ?g) ?x)" => "(- (i ?f ?x) (i ?g ?x))"),
rw!("i-parts"; "(i (* ?a ?b) ?x)" =>
"(- (* ?a (i ?b ?x)) (i (* (d ?x ?a) (i ?b ?x)) ?x))"),
]}

egg::test_fn! {
existence_associate_adds, [
rw!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
rw!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"),
],
runner = Runner::default()
.with_iter_limit(7)
.with_scheduler(SimpleScheduler),
"(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 (+ 6 7))))))"
=>
"(+ 7 (+ 6 (+ 5 (+ 4 (+ 3 (+ 2 1))))))"
@check |r: Runner<Math, ()>| assert_eq!(r.egraph.number_of_classes(), 127),
@existence true
}

egg::test_fn! {
#[should_panic(expected = "Could not prove goal 0")]
existence_fail, rules(),
"(+ x y)" => "(/ x y)",
@existence true
}

egg::test_fn! {existence_simplify_add, rules(), "(+ x (+ x (+ x x)))" => "(* 4 x)", @existence true }
egg::test_fn! {existence_powers, rules(), "(* (pow 2 x) (pow 2 y))" => "(pow 2 (+ x y))", @existence true}

egg::test_fn! {
existence_simplify_const, rules(),
"(+ 1 (- a (* (- 2 1) a)))" => "1",
@existence true
}

egg::test_fn! {
existence_simplify_factor, rules(),
"(* (+ x 3) (+ x 1))"
=>
"(+ (+ (* x x) (* 4 x)) 3)"
@existence true
}

egg::test_fn! {existence_diff_same, rules(), "(d x x)" => "1", @existence true}
egg::test_fn! {existence_diff_different, rules(), "(d x y)" => "0", @existence true}
egg::test_fn! {existence_diff_simple2, rules(), "(d x (+ 1 (* y x)))" => "y", @existence true}
egg::test_fn! {existence_diff_ln, rules(), "(d x (ln x))" => "(/ 1 x)", @existence true}

egg::test_fn! {
existence_diff_power_simple, rules(),
"(d x (pow x 3))" => "(* 3 (pow x 2))",
@existence true
}

egg::test_fn! {
existence_integ_one, rules(), "(i 1 x)" => "x",
@existence true
}

egg::test_fn! {
existence_integ_sin, rules(), "(i (cos x) x)" => "(sin x)", @existence true
}

egg::test_fn! {
existence_integ_x, rules(), "(i (pow x 1) x)" => "(/ (pow x 2) 2)", @existence true
}

egg::test_fn! {
existence_integ_part1, rules(), "(i (* x (cos x)) x)" => "(+ (* x (sin x)) (cos x))", @existence true
}

0 comments on commit fdcd657

Please sign in to comment.