Skip to content

Commit bf1fe9e

Browse files
Convert set payload
1 parent a264017 commit bf1fe9e

File tree

2 files changed

+65
-53
lines changed

2 files changed

+65
-53
lines changed

vm/src/obj/objset.rs

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,26 @@ use super::objiter;
1313
use super::objstr;
1414
use super::objtype;
1515
use crate::pyobject::{
16-
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
16+
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult,
17+
TypeProtocol,
1718
};
1819
use crate::vm::{ReprGuard, VirtualMachine};
1920

20-
pub fn get_elements(obj: &PyObjectRef) -> HashMap<u64, PyObjectRef> {
21-
if let PyObjectPayload::Set { elements } = &obj.payload {
22-
elements.borrow().clone()
23-
} else {
24-
panic!("Cannot extract set elements from non-set");
21+
#[derive(Debug, Default)]
22+
pub struct PySet {
23+
elements: RefCell<HashMap<u64, PyObjectRef>>,
24+
}
25+
26+
impl PyObjectPayload2 for PySet {
27+
fn required_type(ctx: &PyContext) -> PyObjectRef {
28+
ctx.set_type()
2529
}
2630
}
2731

32+
pub fn get_elements(obj: &PyObjectRef) -> HashMap<u64, PyObjectRef> {
33+
obj.payload::<PySet>().unwrap().elements.borrow().clone()
34+
}
35+
2836
fn perform_action_with_hash(
2937
vm: &mut VirtualMachine,
3038
elements: &mut HashMap<u64, PyObjectRef>,
@@ -62,12 +70,10 @@ fn set_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
6270
arg_check!(
6371
vm,
6472
args,
65-
required = [(s, Some(vm.ctx.set_type())), (item, None)]
73+
required = [(zelf, Some(vm.ctx.set_type())), (item, None)]
6674
);
67-
match s.payload {
68-
PyObjectPayload::Set { ref elements } => {
69-
insert_into_set(vm, &mut elements.borrow_mut(), item)
70-
}
75+
match zelf.payload::<PySet>() {
76+
Some(set) => insert_into_set(vm, &mut set.elements.borrow_mut(), item),
7177
_ => Err(vm.new_type_error("set.add is called with no item".to_string())),
7278
}
7379
}
@@ -79,8 +85,8 @@ fn set_remove(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
7985
args,
8086
required = [(s, Some(vm.ctx.set_type())), (item, None)]
8187
);
82-
match s.payload {
83-
PyObjectPayload::Set { ref elements } => {
88+
match s.payload::<PySet>() {
89+
Some(set) => {
8490
fn remove(
8591
vm: &mut VirtualMachine,
8692
elements: &mut HashMap<u64, PyObjectRef>,
@@ -95,7 +101,7 @@ fn set_remove(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
95101
Some(_) => Ok(vm.get_none()),
96102
}
97103
}
98-
perform_action_with_hash(vm, &mut elements.borrow_mut(), item, &remove)
104+
perform_action_with_hash(vm, &mut set.elements.borrow_mut(), item, &remove)
99105
}
100106
_ => Err(vm.new_type_error("set.remove is called with no item".to_string())),
101107
}
@@ -108,8 +114,8 @@ fn set_discard(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
108114
args,
109115
required = [(s, Some(vm.ctx.set_type())), (item, None)]
110116
);
111-
match s.payload {
112-
PyObjectPayload::Set { ref elements } => {
117+
match s.payload::<PySet>() {
118+
Some(set) => {
113119
fn discard(
114120
vm: &mut VirtualMachine,
115121
elements: &mut HashMap<u64, PyObjectRef>,
@@ -119,21 +125,21 @@ fn set_discard(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
119125
elements.remove(&key);
120126
Ok(vm.get_none())
121127
}
122-
perform_action_with_hash(vm, &mut elements.borrow_mut(), item, &discard)
128+
perform_action_with_hash(vm, &mut set.elements.borrow_mut(), item, &discard)
123129
}
124-
_ => Err(vm.new_type_error("set.discard is called with no item".to_string())),
130+
None => Err(vm.new_type_error("set.discard is called with no item".to_string())),
125131
}
126132
}
127133

128134
fn set_clear(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
129135
trace!("set.clear called");
130136
arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]);
131-
match s.payload {
132-
PyObjectPayload::Set { ref elements } => {
133-
elements.borrow_mut().clear();
137+
match s.payload::<PySet>() {
138+
Some(set) => {
139+
set.elements.borrow_mut().clear();
134140
Ok(vm.get_none())
135141
}
136-
_ => Err(vm.new_type_error("".to_string())),
142+
None => Err(vm.new_type_error("".to_string())),
137143
}
138144
}
139145

@@ -163,8 +169,10 @@ fn set_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
163169
};
164170

165171
Ok(PyObject::new(
166-
PyObjectPayload::Set {
167-
elements: RefCell::new(elements),
172+
PyObjectPayload::AnyRustValue {
173+
value: Box::new(PySet {
174+
elements: RefCell::new(elements),
175+
}),
168176
},
169177
cls.clone(),
170178
))
@@ -182,8 +190,10 @@ fn set_copy(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
182190
arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]);
183191
let elements = get_elements(s);
184192
Ok(PyObject::new(
185-
PyObjectPayload::Set {
186-
elements: RefCell::new(elements),
193+
PyObjectPayload::AnyRustValue {
194+
value: Box::new(PySet {
195+
elements: RefCell::new(elements),
196+
}),
187197
},
188198
vm.ctx.set_type(),
189199
))
@@ -336,8 +346,10 @@ fn set_union(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
336346
elements.extend(get_elements(other).clone());
337347

338348
Ok(PyObject::new(
339-
PyObjectPayload::Set {
340-
elements: RefCell::new(elements),
349+
PyObjectPayload::AnyRustValue {
350+
value: Box::new(PySet {
351+
elements: RefCell::new(elements),
352+
}),
341353
},
342354
vm.ctx.set_type(),
343355
))
@@ -378,8 +390,10 @@ fn set_symmetric_difference(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResu
378390
}
379391

380392
Ok(PyObject::new(
381-
PyObjectPayload::Set {
382-
elements: RefCell::new(elements),
393+
PyObjectPayload::AnyRustValue {
394+
value: Box::new(PySet {
395+
elements: RefCell::new(elements),
396+
}),
383397
},
384398
vm.ctx.set_type(),
385399
))
@@ -418,8 +432,10 @@ fn set_combine_inner(
418432
}
419433

420434
Ok(PyObject::new(
421-
PyObjectPayload::Set {
422-
elements: RefCell::new(elements),
435+
PyObjectPayload::AnyRustValue {
436+
value: Box::new(PySet {
437+
elements: RefCell::new(elements),
438+
}),
423439
},
424440
vm.ctx.set_type(),
425441
))
@@ -428,9 +444,9 @@ fn set_combine_inner(
428444
fn set_pop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
429445
arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]);
430446

431-
match s.payload {
432-
PyObjectPayload::Set { ref elements } => {
433-
let mut elements = elements.borrow_mut();
447+
match s.payload::<PySet>() {
448+
Some(set) => {
449+
let mut elements = set.elements.borrow_mut();
434450
match elements.clone().keys().next() {
435451
Some(key) => Ok(elements.remove(key).unwrap()),
436452
None => Err(vm.new_key_error("pop from an empty set".to_string())),
@@ -452,11 +468,11 @@ fn set_ior(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
452468
required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)]
453469
);
454470

455-
match zelf.payload {
456-
PyObjectPayload::Set { ref elements } => {
471+
match zelf.payload::<PySet>() {
472+
Some(set) => {
457473
let iterator = objiter::get_iter(vm, iterable)?;
458474
while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) {
459-
insert_into_set(vm, &mut elements.borrow_mut(), &v)?;
475+
insert_into_set(vm, &mut set.elements.borrow_mut(), &v)?;
460476
}
461477
}
462478
_ => return Err(vm.new_type_error("set.update is called with no other".to_string())),
@@ -493,9 +509,9 @@ fn set_combine_update_inner(
493509
required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)]
494510
);
495511

496-
match zelf.payload {
497-
PyObjectPayload::Set { ref elements } => {
498-
let mut elements = elements.borrow_mut();
512+
match zelf.payload::<PySet>() {
513+
Some(set) => {
514+
let mut elements = set.elements.borrow_mut();
499515
for element in elements.clone().iter() {
500516
let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?;
501517
let should_remove = match op {
@@ -524,17 +540,17 @@ fn set_ixor(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
524540
required = [(zelf, Some(vm.ctx.set_type())), (iterable, None)]
525541
);
526542

527-
match zelf.payload {
528-
PyObjectPayload::Set { ref elements } => {
529-
let elements_original = elements.borrow().clone();
543+
match zelf.payload::<PySet>() {
544+
Some(set) => {
545+
let elements_original = set.elements.borrow().clone();
530546
let iterator = objiter::get_iter(vm, iterable)?;
531547
while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) {
532-
insert_into_set(vm, &mut elements.borrow_mut(), &v)?;
548+
insert_into_set(vm, &mut set.elements.borrow_mut(), &v)?;
533549
}
534550
for element in elements_original.iter() {
535551
let value = vm.call_method(iterable, "__contains__", vec![element.1.clone()])?;
536552
if objbool::get_value(&value) {
537-
elements.borrow_mut().remove(&element.0.clone());
553+
set.elements.borrow_mut().remove(&element.0.clone());
538554
}
539555
}
540556
}

vm/src/pyobject.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::obj::objnone;
3434
use crate::obj::objobject;
3535
use crate::obj::objproperty;
3636
use crate::obj::objrange;
37-
use crate::obj::objset;
37+
use crate::obj::objset::{self, PySet};
3838
use crate::obj::objslice;
3939
use crate::obj::objstr;
4040
use crate::obj::objsuper;
@@ -536,8 +536,8 @@ impl PyContext {
536536
// Initialized empty, as calling __hash__ is required for adding each object to the set
537537
// which requires a VM context - this is done in the objset code itself.
538538
PyObject::new(
539-
PyObjectPayload::Set {
540-
elements: RefCell::new(HashMap::new()),
539+
PyObjectPayload::AnyRustValue {
540+
value: Box::new(PySet::default()),
541541
},
542542
self.set_type(),
543543
)
@@ -1476,9 +1476,6 @@ pub enum PyObjectPayload {
14761476
dict: RefCell<PyAttributes>,
14771477
mro: Vec<PyObjectRef>,
14781478
},
1479-
Set {
1480-
elements: RefCell<HashMap<u64, PyObjectRef>>,
1481-
},
14821479
WeakRef {
14831480
referent: PyObjectWeakRef,
14841481
},
@@ -1500,7 +1497,6 @@ impl fmt::Debug for PyObjectPayload {
15001497
PyObjectPayload::MemoryView { ref obj } => write!(f, "bytes/bytearray {:?}", obj),
15011498
PyObjectPayload::Sequence { .. } => write!(f, "list or tuple"),
15021499
PyObjectPayload::Dict { .. } => write!(f, "dict"),
1503-
PyObjectPayload::Set { .. } => write!(f, "set"),
15041500
PyObjectPayload::WeakRef { .. } => write!(f, "weakref"),
15051501
PyObjectPayload::Iterator { .. } => write!(f, "iterator"),
15061502
PyObjectPayload::EnumerateIterator { .. } => write!(f, "enumerate"),

0 commit comments

Comments
 (0)