Skip to content

Commit 3f8f0e2

Browse files
Merge pull request RustPython#524 from palaviv/Improve-set-5
Add set.{__iter__,__ior__,__iand__,__isub__,__ixor__}
2 parents 2d19486 + fc10560 commit 3f8f0e2

File tree

2 files changed

+95
-7
lines changed

2 files changed

+95
-7
lines changed

tests/snippets/set.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,64 @@ def __hash__(self):
9999
assert a == set([1,2,3,4,5])
100100
assert_raises(TypeError, lambda: a.update(1))
101101

102+
a = set([1,2,3])
103+
b = set()
104+
for e in a:
105+
assert e == 1 or e == 2 or e == 3
106+
b.add(e)
107+
assert a == b
108+
109+
a = set([1,2,3])
110+
a |= set([3,4,5])
111+
assert a == set([1,2,3,4,5])
112+
try:
113+
a |= 1
114+
except TypeError:
115+
pass
116+
else:
117+
assert False, "TypeError not raised"
118+
102119
a = set([1,2,3])
103120
a.intersection_update([2,3,4,5])
104121
assert a == set([2,3])
105122
assert_raises(TypeError, lambda: a.intersection_update(1))
106123

124+
a = set([1,2,3])
125+
a &= set([2,3,4,5])
126+
assert a == set([2,3])
127+
try:
128+
a &= 1
129+
except TypeError:
130+
pass
131+
else:
132+
assert False, "TypeError not raised"
133+
107134
a = set([1,2,3])
108135
a.difference_update([3,4,5])
109136
assert a == set([1,2])
110137
assert_raises(TypeError, lambda: a.difference_update(1))
111138

139+
a = set([1,2,3])
140+
a -= set([3,4,5])
141+
assert a == set([1,2])
142+
try:
143+
a -= 1
144+
except TypeError:
145+
pass
146+
else:
147+
assert False, "TypeError not raised"
148+
112149
a = set([1,2,3])
113150
a.symmetric_difference_update([3,4,5])
114151
assert a == set([1,2,4,5])
115152
assert_raises(TypeError, lambda: a.difference_update(1))
153+
154+
a = set([1,2,3])
155+
a ^= set([3,4,5])
156+
assert a == set([1,2,4,5])
157+
try:
158+
a ^= 1
159+
except TypeError:
160+
pass
161+
else:
162+
assert False, "TypeError not raised"

vm/src/obj/objset.rs

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,11 @@ fn set_pop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
433433
}
434434

435435
fn set_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
436+
set_ior(vm, args)?;
437+
Ok(vm.get_none())
438+
}
439+
440+
fn set_ior(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
436441
arg_check!(
437442
vm,
438443
args,
@@ -447,17 +452,27 @@ fn set_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
447452
while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) {
448453
insert_into_set(vm, elements, &v)?;
449454
}
450-
Ok(vm.get_none())
451455
}
452-
_ => Err(vm.new_type_error("set.update is called with no other".to_string())),
456+
_ => return Err(vm.new_type_error("set.update is called with no other".to_string())),
453457
}
458+
Ok(zelf.clone())
454459
}
455460

456461
fn set_intersection_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
462+
set_combine_update_inner(vm, args, SetCombineOperation::Intersection)?;
463+
Ok(vm.get_none())
464+
}
465+
466+
fn set_iand(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
457467
set_combine_update_inner(vm, args, SetCombineOperation::Intersection)
458468
}
459469

460470
fn set_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
471+
set_combine_update_inner(vm, args, SetCombineOperation::Difference)?;
472+
Ok(vm.get_none())
473+
}
474+
475+
fn set_isub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
461476
set_combine_update_inner(vm, args, SetCombineOperation::Difference)
462477
}
463478

@@ -486,13 +501,18 @@ fn set_combine_update_inner(
486501
elements.remove(&element.0.clone());
487502
}
488503
}
489-
Ok(vm.get_none())
490504
}
491-
_ => Err(vm.new_type_error("".to_string())),
505+
_ => return Err(vm.new_type_error("".to_string())),
492506
}
507+
Ok(zelf.clone())
493508
}
494509

495510
fn set_symmetric_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
511+
set_ixor(vm, args)?;
512+
Ok(vm.get_none())
513+
}
514+
515+
fn set_ixor(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
496516
arg_check!(
497517
vm,
498518
args,
@@ -514,11 +534,27 @@ fn set_symmetric_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) ->
514534
elements.remove(&element.0.clone());
515535
}
516536
}
517-
518-
Ok(vm.get_none())
519537
}
520-
_ => Err(vm.new_type_error("".to_string())),
538+
_ => return Err(vm.new_type_error("".to_string())),
521539
}
540+
541+
Ok(zelf.clone())
542+
}
543+
544+
fn set_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
545+
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.set_type()))]);
546+
547+
let items = get_elements(zelf).values().map(|x| x.clone()).collect();
548+
let set_list = vm.ctx.new_list(items);
549+
let iter_obj = PyObject::new(
550+
PyObjectPayload::Iterator {
551+
position: 0,
552+
iterated_obj: set_list,
553+
},
554+
vm.ctx.iter_type(),
555+
);
556+
557+
Ok(iter_obj)
522558
}
523559

524560
fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -593,21 +629,26 @@ pub fn init(context: &PyContext) {
593629
context.set_attr(&set_type, "copy", context.new_rustfunc(set_copy));
594630
context.set_attr(&set_type, "pop", context.new_rustfunc(set_pop));
595631
context.set_attr(&set_type, "update", context.new_rustfunc(set_update));
632+
context.set_attr(&set_type, "__ior__", context.new_rustfunc(set_ior));
596633
context.set_attr(
597634
&set_type,
598635
"intersection_update",
599636
context.new_rustfunc(set_intersection_update),
600637
);
638+
context.set_attr(&set_type, "__iand__", context.new_rustfunc(set_iand));
601639
context.set_attr(
602640
&set_type,
603641
"difference_update",
604642
context.new_rustfunc(set_difference_update),
605643
);
644+
context.set_attr(&set_type, "__isub__", context.new_rustfunc(set_isub));
606645
context.set_attr(
607646
&set_type,
608647
"symmetric_difference_update",
609648
context.new_rustfunc(set_symmetric_difference_update),
610649
);
650+
context.set_attr(&set_type, "__ixor__", context.new_rustfunc(set_ixor));
651+
context.set_attr(&set_type, "__iter__", context.new_rustfunc(set_iter));
611652

612653
let frozenset_type = &context.frozenset_type;
613654

0 commit comments

Comments
 (0)