Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Nov 4, 2021
1 parent 4ec6002 commit c30c99a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 31 deletions.
31 changes: 15 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ use pyo3::{
basic::CompareOp,
prelude::*,
types::{PyList, PyString, PyTuple, PyType},
PyObjectProtocol,
};

macro_rules! impl_py_object {
($t:ty) => {
#[pyproto]
impl PyObjectProtocol for $t {
macro_rules! py_object {
(impl $t:ty { $($rest:tt)* }) => {
#[pymethods]
impl $t {
$($rest)*

fn __str__(&self) -> String {
self.0.to_string()
}
Expand All @@ -42,13 +43,18 @@ macro_rules! impl_py_object {
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Id(egg::Id);

impl_py_object!(Id);
py_object!(impl Id {});

#[pyclass]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Var(egg::Var);

impl_py_object!(Var);
py_object!(impl Var {
#[new]
fn new(str: &PyString) -> Self {
Self::from_str(str.to_string_lossy().as_ref())
}
});

impl Var {
fn from_str(str: &str) -> Self {
Expand All @@ -57,14 +63,6 @@ impl Var {
}
}

#[pymethods]
impl Var {
#[new]
fn new(str: &PyString) -> Self {
Self::from_str(str.to_string_lossy().as_ref())
}
}

#[derive(Debug, Clone)]
struct PyLang {
obj: PyObject,
Expand Down Expand Up @@ -234,7 +232,7 @@ impl Rewrite {

#[pyclass]
#[derive(Default)]
pub struct EGraph {
struct EGraph {
egraph: egg::EGraph<PyLang, ()>,
}

Expand Down Expand Up @@ -344,6 +342,7 @@ where
#[pymodule]
fn snake_egg(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<EGraph>()?;
m.add_class::<ENode>()?;
m.add_class::<Id>()?;
m.add_class::<Var>()?;
m.add_class::<Pattern>()?;
Expand Down
16 changes: 1 addition & 15 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,6 @@
Add = namedtuple('Add', 'x y')
Mul = namedtuple('Mul', 'x y')


class ENode:
def __rshift__(self, other):
return Rewrite(self, other)


class Foo(tuple, ENode):
def __new__(cls, *args):
return super().__new__(cls, tuple(args))


print(inspect.getmro(Foo))


x, y, z = vars('x y z')

print(str(Add(x, y)))
Expand All @@ -33,7 +19,7 @@ def __new__(cls, *args):
Rewrite(Mul(x, Mul(y, z)), Mul(Mul(x, y), z)),
Rewrite(Add(x, 0), x),
Rewrite(Mul(x, 0), 0),
Foo(x, 1) >> x,
Rewrite(Mul(x, 1), x),
Rewrite(Add(x, x), Mul(x, 2)),
]

Expand Down

0 comments on commit c30c99a

Please sign in to comment.