Skip to content

Commit 28f8cdb

Browse files
committed
Define __eq__ for basic types.
1 parent c06c3ef commit 28f8cdb

File tree

7 files changed

+86
-3
lines changed

7 files changed

+86
-3
lines changed

tests/snippets/math.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
assert a ** 3 == 64
1212
assert a * 3 == 12
1313
assert a / 2 == 2
14+
assert 2 == a / 2
1415
# assert a % 3 == 1
1516
assert a - 3 == 1
1617
assert -a == -4

vm/src/obj/objfloat.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@ fn set_value(obj: &PyObjectRef, value: f64) {
4242
obj.borrow_mut().kind = PyObjectKind::Float { value };
4343
}
4444

45+
fn float_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
46+
arg_check!(
47+
vm,
48+
args,
49+
required = [(zelf, Some(vm.ctx.float_type())), (other, None)]
50+
);
51+
let zelf = get_value(zelf);
52+
let result = if objtype::isinstance(other.clone(), vm.ctx.float_type()) {
53+
let other = get_value(other);
54+
zelf == other
55+
} else if objtype::isinstance(other.clone(), vm.ctx.int_type()) {
56+
let other = objint::get_value(other) as f64;
57+
zelf == other
58+
} else {
59+
false
60+
};
61+
Ok(vm.ctx.new_bool(result))
62+
}
63+
4564
fn float_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
4665
arg_check!(
4766
vm,
@@ -97,6 +116,7 @@ fn float_pow(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
97116

98117
pub fn init(context: &PyContext) {
99118
let ref float_type = context.float_type;
119+
float_type.set_attr("__eq__", context.new_rustfunc(float_eq));
100120
float_type.set_attr("__add__", context.new_rustfunc(float_add));
101121
float_type.set_attr("__init__", context.new_rustfunc(float_init));
102122
float_type.set_attr("__pow__", context.new_rustfunc(float_pow));

vm/src/obj/objint.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,26 @@ impl FromPyObjectRef for i32 {
4949
}
5050
}
5151

52+
fn int_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
53+
arg_check!(
54+
vm,
55+
args,
56+
required = [(zelf, Some(vm.ctx.int_type())), (other, None)]
57+
);
58+
let result = if objtype::isinstance(other.clone(), vm.ctx.int_type()) {
59+
let zelf = i32::from_pyobj(zelf);
60+
let other = i32::from_pyobj(other);
61+
zelf == other
62+
} else if objtype::isinstance(other.clone(), vm.ctx.float_type()) {
63+
let zelf = i32::from_pyobj(zelf) as f64;
64+
let other = objfloat::get_value(other);
65+
zelf == other
66+
} else {
67+
false
68+
};
69+
Ok(vm.ctx.new_bool(result))
70+
}
71+
5272
fn int_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
5373
arg_check!(
5474
vm,
@@ -144,6 +164,7 @@ fn int_pow(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
144164

145165
pub fn init(context: &PyContext) {
146166
let ref int_type = context.int_type;
167+
int_type.set_attr("__eq__", context.new_rustfunc(int_eq));
147168
int_type.set_attr("__add__", context.new_rustfunc(int_add));
148169
int_type.set_attr("__init__", context.new_rustfunc(int_init));
149170
int_type.set_attr("__mod__", context.new_rustfunc(int_mod));

vm/src/obj/objstr.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use super::objtype;
88

99
pub fn init(context: &PyContext) {
1010
let ref str_type = context.str_type;
11+
str_type.set_attr("__eq__", context.new_rustfunc(str_eq));
1112
str_type.set_attr("__add__", context.new_rustfunc(str_add));
1213
str_type.set_attr("__len__", context.new_rustfunc(str_len));
1314
str_type.set_attr("__mul__", context.new_rustfunc(str_mul));
@@ -23,6 +24,21 @@ pub fn get_value(obj: &PyObjectRef) -> String {
2324
}
2425
}
2526

27+
fn str_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
28+
arg_check!(
29+
vm,
30+
args,
31+
required = [(a, Some(vm.ctx.str_type())), (b, None)]
32+
);
33+
34+
let result = if objtype::isinstance(b.clone(), vm.ctx.str_type()) {
35+
get_value(a) == get_value(b)
36+
} else {
37+
false
38+
};
39+
Ok(vm.ctx.new_bool(result))
40+
}
41+
2642
fn str_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
2743
arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]);
2844
Ok(s.clone())

vm/src/obj/objtype.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ pub fn init(context: &PyContext) {
2929
}
3030

3131
fn type_mro(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
32-
println!("{:?}", args);
3332
arg_check!(
3433
vm,
3534
args,

vm/src/objbool.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ pub fn init(context: &PyContext) {
3535
let ref bool_type = context.bool_type;
3636
bool_type.set_attr("__new__", context.new_rustfunc(bool_new));
3737
bool_type.set_attr("__str__", context.new_rustfunc(bool_str));
38+
bool_type.set_attr("__eq__", context.new_rustfunc(bool_eq));
39+
}
40+
41+
fn bool_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
42+
arg_check!(
43+
vm,
44+
args,
45+
required = [(zelf, Some(vm.ctx.bool_type())), (other, None)]
46+
);
47+
48+
let result = if objtype::isinstance(zelf.clone(), vm.ctx.bool_type()) {
49+
get_value(zelf) == get_value(other)
50+
} else {
51+
false
52+
};
53+
Ok(vm.ctx.new_bool(result))
3854
}
3955

4056
// Retrieve inner int value:

vm/src/objobject.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,16 @@ pub fn create_object(type_type: PyObjectRef, object_type: PyObjectRef, dict_type
2929
(*object_type.borrow_mut()).typ = Some(type_type.clone());
3030
}
3131

32-
fn obj_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
32+
fn object_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
33+
arg_check!(
34+
vm,
35+
args,
36+
required = [(zelf, Some(vm.ctx.object())), (other, None)]
37+
);
38+
Ok(vm.ctx.new_bool(zelf.is(other)))
39+
}
40+
41+
fn object_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
3342
arg_check!(vm, args, required = [(obj, Some(vm.ctx.object()))]);
3443
let type_name = objtype::get_type_name(&obj.typ());
3544
let address = obj.get_id();
@@ -39,8 +48,9 @@ fn obj_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
3948
pub fn init(context: &PyContext) {
4049
let ref object = context.object;
4150
object.set_attr("__new__", context.new_rustfunc(new_instance));
51+
object.set_attr("__eq__", context.new_rustfunc(object_eq));
4252
object.set_attr("__dict__", context.new_member_descriptor(object_dict));
43-
object.set_attr("__str__", context.new_rustfunc(obj_str));
53+
object.set_attr("__str__", context.new_rustfunc(object_str));
4454
}
4555

4656
fn object_dict(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

0 commit comments

Comments
 (0)