Skip to content

Commit 7dd1eec

Browse files
Merge pull request RustPython#517 from palaviv/set-more-funcs
Add set.{pop,update,intersection_update,difference_update,symmetric_difference_update}
2 parents 565023f + 21b6616 commit 7dd1eec

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

tests/snippets/set.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,30 @@ def __hash__(self):
8686
b.clear()
8787
assert len(a) == 3
8888
assert len(b) == 0
89+
90+
a = set([1,2])
91+
b = a.pop()
92+
assert b in [1,2]
93+
c = a.pop()
94+
assert (c in [1,2] and c != b)
95+
assert_raises(KeyError, lambda: a.pop())
96+
97+
a = set([1,2,3])
98+
a.update([3,4,5])
99+
assert a == set([1,2,3,4,5])
100+
assert_raises(TypeError, lambda: a.update(1))
101+
102+
a = set([1,2,3])
103+
a.intersection_update([2,3,4,5])
104+
assert a == set([2,3])
105+
assert_raises(TypeError, lambda: a.intersection_update(1))
106+
107+
a = set([1,2,3])
108+
a.difference_update([3,4,5])
109+
assert a == set([1,2])
110+
assert_raises(TypeError, lambda: a.difference_update(1))
111+
112+
a = set([1,2,3])
113+
a.symmetric_difference_update([3,4,5])
114+
assert a == set([1,2,4,5])
115+
assert_raises(TypeError, lambda: a.difference_update(1))

vm/src/obj/objset.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,109 @@ fn set_combine_inner(
418418
))
419419
}
420420

421+
fn set_pop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
422+
arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]);
423+
424+
let mut mut_obj = s.borrow_mut();
425+
426+
match mut_obj.payload {
427+
PyObjectPayload::Set { ref mut elements } => match elements.clone().keys().next() {
428+
Some(key) => Ok(elements.remove(key).unwrap()),
429+
None => Err(vm.new_key_error("pop from an empty set".to_string())),
430+
},
431+
_ => Err(vm.new_type_error("".to_string())),
432+
}
433+
}
434+
435+
fn set_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
436+
arg_check!(
437+
vm,
438+
args,
439+
required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)]
440+
);
441+
442+
let mut mut_obj = zelf.borrow_mut();
443+
444+
match mut_obj.payload {
445+
PyObjectPayload::Set { ref mut elements } => {
446+
let iterator = objiter::get_iter(vm, iterable)?;
447+
while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) {
448+
insert_into_set(vm, elements, &v)?;
449+
}
450+
Ok(vm.get_none())
451+
}
452+
_ => Err(vm.new_type_error("set.update is called with no other".to_string())),
453+
}
454+
}
455+
456+
fn set_intersection_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
457+
set_combine_update_inner(vm, args, SetCombineOperation::Intersection)
458+
}
459+
460+
fn set_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
461+
set_combine_update_inner(vm, args, SetCombineOperation::Difference)
462+
}
463+
464+
fn set_combine_update_inner(
465+
vm: &mut VirtualMachine,
466+
args: PyFuncArgs,
467+
op: SetCombineOperation,
468+
) -> PyResult {
469+
arg_check!(
470+
vm,
471+
args,
472+
required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)]
473+
);
474+
475+
let mut mut_obj = zelf.borrow_mut();
476+
477+
match mut_obj.payload {
478+
PyObjectPayload::Set { ref mut elements } => {
479+
for element in elements.clone().iter() {
480+
let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?;
481+
let should_remove = match op {
482+
SetCombineOperation::Intersection => !objbool::get_value(&value),
483+
SetCombineOperation::Difference => objbool::get_value(&value),
484+
};
485+
if should_remove {
486+
elements.remove(&element.0.clone());
487+
}
488+
}
489+
Ok(vm.get_none())
490+
}
491+
_ => Err(vm.new_type_error("".to_string())),
492+
}
493+
}
494+
495+
fn set_symmetric_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
496+
arg_check!(
497+
vm,
498+
args,
499+
required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)]
500+
);
501+
502+
let mut mut_obj = zelf.borrow_mut();
503+
504+
match mut_obj.payload {
505+
PyObjectPayload::Set { ref mut elements } => {
506+
let elements_original = elements.clone();
507+
let iterator = objiter::get_iter(vm, iterable)?;
508+
while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) {
509+
insert_into_set(vm, elements, &v)?;
510+
}
511+
for element in elements_original.iter() {
512+
let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?;
513+
if objbool::get_value(&value) {
514+
elements.remove(&element.0.clone());
515+
}
516+
}
517+
518+
Ok(vm.get_none())
519+
}
520+
_ => Err(vm.new_type_error("".to_string())),
521+
}
522+
}
523+
421524
fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
422525
arg_check!(vm, args, required = [(o, Some(vm.ctx.frozenset_type()))]);
423526

@@ -488,6 +591,23 @@ pub fn init(context: &PyContext) {
488591
context.set_attr(&set_type, "discard", context.new_rustfunc(set_discard));
489592
context.set_attr(&set_type, "clear", context.new_rustfunc(set_clear));
490593
context.set_attr(&set_type, "copy", context.new_rustfunc(set_copy));
594+
context.set_attr(&set_type, "pop", context.new_rustfunc(set_pop));
595+
context.set_attr(&set_type, "update", context.new_rustfunc(set_update));
596+
context.set_attr(
597+
&set_type,
598+
"intersection_update",
599+
context.new_rustfunc(set_intersection_update),
600+
);
601+
context.set_attr(
602+
&set_type,
603+
"difference_update",
604+
context.new_rustfunc(set_difference_update),
605+
);
606+
context.set_attr(
607+
&set_type,
608+
"symmetric_difference_update",
609+
context.new_rustfunc(set_symmetric_difference_update),
610+
);
491611

492612
let frozenset_type = &context.frozenset_type;
493613

0 commit comments

Comments
 (0)