Skip to content

Commit 16832c8

Browse files
Merge pull request RustPython#479 from calixteman/fix_equality
Fix issue with equality
2 parents 70d5cdb + 43d9fc5 commit 16832c8

File tree

4 files changed

+83
-10
lines changed

4 files changed

+83
-10
lines changed

tests/snippets/list.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,46 @@
7474
foo = bar = [1]
7575
foo += [2]
7676
assert (foo, bar) == ([1, 2], [1, 2])
77+
78+
79+
x = [1]
80+
x.append(x)
81+
assert x in x
82+
assert x.index(x) == 1
83+
assert x.count(x) == 1
84+
x.remove(x)
85+
assert x not in x
86+
87+
class Foo(object):
88+
def __eq__(self, x):
89+
return False
90+
91+
foo = Foo()
92+
foo1 = Foo()
93+
x = [1, foo, 2, foo, []]
94+
assert x == x
95+
assert foo in x
96+
assert 2 in x
97+
assert x.index(foo) == 1
98+
assert x.count(foo) == 2
99+
assert x.index(2) == 2
100+
assert [] in x
101+
assert x.index([]) == 4
102+
assert foo1 not in x
103+
x.remove(foo)
104+
assert x.index(foo) == 2
105+
assert x.count(foo) == 1
106+
107+
x = []
108+
x.append(x)
109+
assert x == x
110+
111+
a = [1, 2, 3]
112+
b = [1, 2, 3]
113+
c = [a, b]
114+
a.append(c)
115+
b.append(c)
116+
117+
assert a == b
118+
119+
assert [foo] == [foo]

tests/snippets/tuple.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,10 @@
2727

2828
assert (None, "", 1).index(1) == 2
2929
assert 1 in (None, "", 1)
30+
31+
class Foo(object):
32+
def __eq__(self, x):
33+
return False
34+
35+
foo = Foo()
36+
assert (foo,) == (foo,)

vm/src/obj/objlist.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::super::pyobject::{
2-
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
2+
IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult,
3+
TypeProtocol,
34
};
45
use super::super::vm::{ReprGuard, VirtualMachine};
56
use super::objbool;
@@ -66,6 +67,10 @@ fn list_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
6667
required = [(zelf, Some(vm.ctx.list_type())), (other, None)]
6768
);
6869

70+
if zelf.is(&other) {
71+
return Ok(vm.ctx.new_bool(true));
72+
}
73+
6974
let result = if objtype::isinstance(other, &vm.ctx.list_type()) {
7075
let zelf = get_elements(zelf);
7176
let other = get_elements(other);
@@ -249,9 +254,13 @@ fn list_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
249254
let elements = get_elements(zelf);
250255
let mut count: usize = 0;
251256
for element in elements.iter() {
252-
let is_eq = vm._eq(element.clone(), value.clone())?;
253-
if objbool::boolval(vm, is_eq)? {
257+
if value.is(&element) {
254258
count += 1;
259+
} else {
260+
let is_eq = vm._eq(element.clone(), value.clone())?;
261+
if objbool::boolval(vm, is_eq)? {
262+
count += 1;
263+
}
255264
}
256265
}
257266
Ok(vm.context().new_int(count))
@@ -277,6 +286,9 @@ fn list_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
277286
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
278287
);
279288
for (index, element) in get_elements(list).iter().enumerate() {
289+
if needle.is(&element) {
290+
return Ok(vm.context().new_int(index));
291+
}
280292
let py_equal = vm._eq(needle.clone(), element.clone())?;
281293
if objbool::get_value(&py_equal) {
282294
return Ok(vm.context().new_int(index));
@@ -350,6 +362,9 @@ fn list_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
350362
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
351363
);
352364
for element in get_elements(list).iter() {
365+
if needle.is(&element) {
366+
return Ok(vm.new_bool(true));
367+
}
353368
match vm._eq(needle.clone(), element.clone()) {
354369
Ok(value) => {
355370
if objbool::get_value(&value) {
@@ -431,9 +446,12 @@ fn list_remove(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
431446
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
432447
);
433448

434-
let mut elements = get_mut_elements(list);
435449
let mut ri: Option<usize> = None;
436-
for (index, element) in elements.iter().enumerate() {
450+
for (index, element) in get_elements(list).iter().enumerate() {
451+
if needle.is(&element) {
452+
ri = Some(index);
453+
break;
454+
}
437455
let py_equal = vm._eq(needle.clone(), element.clone())?;
438456
if objbool::get_value(&py_equal) {
439457
ri = Some(index);
@@ -442,6 +460,7 @@ fn list_remove(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
442460
}
443461

444462
if let Some(index) = ri {
463+
let mut elements = get_mut_elements(list);
445464
elements.remove(index);
446465
Ok(vm.get_none())
447466
} else {

vm/src/obj/objsequence.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use super::super::pyobject::{PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol};
1+
use super::super::pyobject::{
2+
IdProtocol, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
3+
};
24
use super::super::vm::VirtualMachine;
35
use super::objbool;
46
use super::objint;
@@ -178,10 +180,12 @@ pub fn seq_equal(
178180
) -> Result<bool, PyObjectRef> {
179181
if zelf.len() == other.len() {
180182
for (a, b) in Iterator::zip(zelf.iter(), other.iter()) {
181-
let eq = vm._eq(a.clone(), b.clone())?;
182-
let value = objbool::boolval(vm, eq)?;
183-
if !value {
184-
return Ok(false);
183+
if !a.is(&b) {
184+
let eq = vm._eq(a.clone(), b.clone())?;
185+
let value = objbool::boolval(vm, eq)?;
186+
if !value {
187+
return Ok(false);
188+
}
185189
}
186190
}
187191
Ok(true)

0 commit comments

Comments
 (0)