Skip to content

Commit a8d8019

Browse files
committed
Move PyObject serialization to its own file
1 parent 0f73c74 commit a8d8019

File tree

5 files changed

+218
-214
lines changed

5 files changed

+218
-214
lines changed

vm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pub mod import;
5555
pub mod obj;
5656
mod pyhash;
5757
pub mod pyobject;
58+
pub mod ser_de;
5859
pub mod stdlib;
5960
mod sysmodule;
6061
mod traceback;

vm/src/ser_de.rs

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
use std::fmt;
2+
3+
use serde;
4+
use serde::de::Visitor;
5+
use serde::ser::{SerializeMap, SerializeSeq};
6+
7+
use crate::obj::{
8+
objbool,
9+
objdict::PyDictRef,
10+
objfloat, objint, objsequence,
11+
objstr::{self, PyString},
12+
objtype,
13+
};
14+
use crate::pyobject::{IdProtocol, ItemProtocol, PyObjectRef, TypeProtocol};
15+
use crate::VirtualMachine;
16+
use num_traits::cast::ToPrimitive;
17+
18+
// We need to have a VM available to serialise a PyObject based on its subclass, so we implement
19+
// PyObject serialisation via a proxy object which holds a reference to a VM
20+
pub struct PyObjectSerializer<'s> {
21+
pyobject: &'s PyObjectRef,
22+
vm: &'s VirtualMachine,
23+
}
24+
25+
impl<'s> PyObjectSerializer<'s> {
26+
pub fn new(vm: &'s VirtualMachine, pyobject: &'s PyObjectRef) -> Self {
27+
PyObjectSerializer { pyobject, vm }
28+
}
29+
30+
fn clone_with_object(&self, pyobject: &'s PyObjectRef) -> PyObjectSerializer {
31+
PyObjectSerializer {
32+
pyobject,
33+
vm: self.vm,
34+
}
35+
}
36+
}
37+
38+
impl<'s> serde::Serialize for PyObjectSerializer<'s> {
39+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40+
where
41+
S: serde::Serializer,
42+
{
43+
let serialize_seq_elements =
44+
|serializer: S, elements: &Vec<PyObjectRef>| -> Result<S::Ok, S::Error> {
45+
let mut seq = serializer.serialize_seq(Some(elements.len()))?;
46+
for e in elements.iter() {
47+
seq.serialize_element(&self.clone_with_object(e))?;
48+
}
49+
seq.end()
50+
};
51+
if objtype::isinstance(self.pyobject, &self.vm.ctx.str_type()) {
52+
serializer.serialize_str(&objstr::get_value(&self.pyobject))
53+
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.float_type()) {
54+
serializer.serialize_f64(objfloat::get_value(self.pyobject))
55+
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.bool_type()) {
56+
serializer.serialize_bool(objbool::get_value(self.pyobject))
57+
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.int_type()) {
58+
let v = objint::get_value(self.pyobject);
59+
serializer.serialize_i64(v.to_i64().unwrap())
60+
// Although this may seem nice, it does not give the right result:
61+
// v.serialize(serializer)
62+
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.list_type()) {
63+
let elements = objsequence::get_elements_list(self.pyobject);
64+
serialize_seq_elements(serializer, &elements)
65+
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.tuple_type()) {
66+
let elements = objsequence::get_elements_tuple(self.pyobject);
67+
serialize_seq_elements(serializer, &elements)
68+
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.dict_type()) {
69+
let dict: PyDictRef = self.pyobject.clone().downcast().unwrap();
70+
let pairs: Vec<_> = dict.into_iter().collect();
71+
let mut map = serializer.serialize_map(Some(pairs.len()))?;
72+
for (key, e) in pairs.iter() {
73+
map.serialize_entry(&self.clone_with_object(key), &self.clone_with_object(&e))?;
74+
}
75+
map.end()
76+
} else if self.pyobject.is(&self.vm.get_none()) {
77+
serializer.serialize_none()
78+
} else {
79+
Err(serde::ser::Error::custom(format!(
80+
"Object of type '{:?}' is not serializable",
81+
self.pyobject.class()
82+
)))
83+
}
84+
}
85+
}
86+
87+
// This object is used as the seed for deserialization so we have access to the PyContext for type
88+
// creation
89+
#[derive(Clone)]
90+
pub struct PyObjectDeserializer<'c> {
91+
vm: &'c VirtualMachine,
92+
}
93+
94+
impl<'c> PyObjectDeserializer<'c> {
95+
pub fn new(vm: &'c VirtualMachine) -> Self {
96+
PyObjectDeserializer { vm }
97+
}
98+
}
99+
100+
impl<'de> serde::de::DeserializeSeed<'de> for PyObjectDeserializer<'de> {
101+
type Value = PyObjectRef;
102+
103+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
104+
where
105+
D: serde::Deserializer<'de>,
106+
{
107+
deserializer.deserialize_any(self.clone())
108+
}
109+
}
110+
111+
impl<'de> Visitor<'de> for PyObjectDeserializer<'de> {
112+
type Value = PyObjectRef;
113+
114+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
115+
formatter.write_str("a type that can deserialise in Python")
116+
}
117+
118+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
119+
where
120+
E: serde::de::Error,
121+
{
122+
Ok(self.vm.ctx.new_str(value.to_string()))
123+
}
124+
125+
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
126+
where
127+
E: serde::de::Error,
128+
{
129+
Ok(self.vm.ctx.new_str(value))
130+
}
131+
132+
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
133+
where
134+
E: serde::de::Error,
135+
{
136+
// The JSON deserializer always uses the i64/u64 deserializers, so we only need to
137+
// implement those for now
138+
Ok(self.vm.ctx.new_int(value))
139+
}
140+
141+
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
142+
where
143+
E: serde::de::Error,
144+
{
145+
// The JSON deserializer always uses the i64/u64 deserializers, so we only need to
146+
// implement those for now
147+
Ok(self.vm.ctx.new_int(value))
148+
}
149+
150+
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
151+
where
152+
E: serde::de::Error,
153+
{
154+
Ok(self.vm.ctx.new_float(value))
155+
}
156+
157+
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
158+
where
159+
E: serde::de::Error,
160+
{
161+
Ok(self.vm.ctx.new_bool(value))
162+
}
163+
164+
fn visit_seq<A>(self, mut access: A) -> Result<Self::Value, A::Error>
165+
where
166+
A: serde::de::SeqAccess<'de>,
167+
{
168+
let mut seq = Vec::with_capacity(access.size_hint().unwrap_or(0));
169+
while let Some(value) = access.next_element_seed(self.clone())? {
170+
seq.push(value);
171+
}
172+
Ok(self.vm.ctx.new_list(seq))
173+
}
174+
175+
fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
176+
where
177+
M: serde::de::MapAccess<'de>,
178+
{
179+
let dict = self.vm.ctx.new_dict();
180+
// TODO: Given keys must be strings, we can probably do something more efficient
181+
// than wrapping the given object up and then unwrapping it to determine whether or
182+
// not it is a string
183+
while let Some((key_obj, value)) = access.next_entry_seed(self.clone(), self.clone())? {
184+
let key: String = match key_obj.payload::<PyString>() {
185+
Some(PyString { ref value }) => value.clone(),
186+
_ => unimplemented!("map keys must be strings"),
187+
};
188+
dict.set_item(&key, value, self.vm).unwrap();
189+
}
190+
Ok(dict.into_object())
191+
}
192+
193+
fn visit_unit<E>(self) -> Result<Self::Value, E>
194+
where
195+
E: serde::de::Error,
196+
{
197+
Ok(self.vm.get_none())
198+
}
199+
}

0 commit comments

Comments
 (0)