|
| 1 | +use std::fmt; |
| 2 | + |
| 3 | +use serde; |
| 4 | +use serde::de::Visitor; |
| 5 | +use serde::ser::{SerializeMap, SerializeSeq}; |
| 6 | + |
| 7 | +use crate::obj::{ |
| 8 | + objbool, |
| 9 | + objdict::PyDictRef, |
| 10 | + objfloat, objint, objsequence, |
| 11 | + objstr::{self, PyString}, |
| 12 | + objtype, |
| 13 | +}; |
| 14 | +use crate::pyobject::{IdProtocol, ItemProtocol, PyObjectRef, TypeProtocol}; |
| 15 | +use crate::VirtualMachine; |
| 16 | +use num_traits::cast::ToPrimitive; |
| 17 | + |
| 18 | +// We need to have a VM available to serialise a PyObject based on its subclass, so we implement |
| 19 | +// PyObject serialisation via a proxy object which holds a reference to a VM |
| 20 | +pub struct PyObjectSerializer<'s> { |
| 21 | + pyobject: &'s PyObjectRef, |
| 22 | + vm: &'s VirtualMachine, |
| 23 | +} |
| 24 | + |
| 25 | +impl<'s> PyObjectSerializer<'s> { |
| 26 | + pub fn new(vm: &'s VirtualMachine, pyobject: &'s PyObjectRef) -> Self { |
| 27 | + PyObjectSerializer { pyobject, vm } |
| 28 | + } |
| 29 | + |
| 30 | + fn clone_with_object(&self, pyobject: &'s PyObjectRef) -> PyObjectSerializer { |
| 31 | + PyObjectSerializer { |
| 32 | + pyobject, |
| 33 | + vm: self.vm, |
| 34 | + } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +impl<'s> serde::Serialize for PyObjectSerializer<'s> { |
| 39 | + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
| 40 | + where |
| 41 | + S: serde::Serializer, |
| 42 | + { |
| 43 | + let serialize_seq_elements = |
| 44 | + |serializer: S, elements: &Vec<PyObjectRef>| -> Result<S::Ok, S::Error> { |
| 45 | + let mut seq = serializer.serialize_seq(Some(elements.len()))?; |
| 46 | + for e in elements.iter() { |
| 47 | + seq.serialize_element(&self.clone_with_object(e))?; |
| 48 | + } |
| 49 | + seq.end() |
| 50 | + }; |
| 51 | + if objtype::isinstance(self.pyobject, &self.vm.ctx.str_type()) { |
| 52 | + serializer.serialize_str(&objstr::get_value(&self.pyobject)) |
| 53 | + } else if objtype::isinstance(self.pyobject, &self.vm.ctx.float_type()) { |
| 54 | + serializer.serialize_f64(objfloat::get_value(self.pyobject)) |
| 55 | + } else if objtype::isinstance(self.pyobject, &self.vm.ctx.bool_type()) { |
| 56 | + serializer.serialize_bool(objbool::get_value(self.pyobject)) |
| 57 | + } else if objtype::isinstance(self.pyobject, &self.vm.ctx.int_type()) { |
| 58 | + let v = objint::get_value(self.pyobject); |
| 59 | + serializer.serialize_i64(v.to_i64().unwrap()) |
| 60 | + // Although this may seem nice, it does not give the right result: |
| 61 | + // v.serialize(serializer) |
| 62 | + } else if objtype::isinstance(self.pyobject, &self.vm.ctx.list_type()) { |
| 63 | + let elements = objsequence::get_elements_list(self.pyobject); |
| 64 | + serialize_seq_elements(serializer, &elements) |
| 65 | + } else if objtype::isinstance(self.pyobject, &self.vm.ctx.tuple_type()) { |
| 66 | + let elements = objsequence::get_elements_tuple(self.pyobject); |
| 67 | + serialize_seq_elements(serializer, &elements) |
| 68 | + } else if objtype::isinstance(self.pyobject, &self.vm.ctx.dict_type()) { |
| 69 | + let dict: PyDictRef = self.pyobject.clone().downcast().unwrap(); |
| 70 | + let pairs: Vec<_> = dict.into_iter().collect(); |
| 71 | + let mut map = serializer.serialize_map(Some(pairs.len()))?; |
| 72 | + for (key, e) in pairs.iter() { |
| 73 | + map.serialize_entry(&self.clone_with_object(key), &self.clone_with_object(&e))?; |
| 74 | + } |
| 75 | + map.end() |
| 76 | + } else if self.pyobject.is(&self.vm.get_none()) { |
| 77 | + serializer.serialize_none() |
| 78 | + } else { |
| 79 | + Err(serde::ser::Error::custom(format!( |
| 80 | + "Object of type '{:?}' is not serializable", |
| 81 | + self.pyobject.class() |
| 82 | + ))) |
| 83 | + } |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +// This object is used as the seed for deserialization so we have access to the PyContext for type |
| 88 | +// creation |
| 89 | +#[derive(Clone)] |
| 90 | +pub struct PyObjectDeserializer<'c> { |
| 91 | + vm: &'c VirtualMachine, |
| 92 | +} |
| 93 | + |
| 94 | +impl<'c> PyObjectDeserializer<'c> { |
| 95 | + pub fn new(vm: &'c VirtualMachine) -> Self { |
| 96 | + PyObjectDeserializer { vm } |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +impl<'de> serde::de::DeserializeSeed<'de> for PyObjectDeserializer<'de> { |
| 101 | + type Value = PyObjectRef; |
| 102 | + |
| 103 | + fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> |
| 104 | + where |
| 105 | + D: serde::Deserializer<'de>, |
| 106 | + { |
| 107 | + deserializer.deserialize_any(self.clone()) |
| 108 | + } |
| 109 | +} |
| 110 | + |
| 111 | +impl<'de> Visitor<'de> for PyObjectDeserializer<'de> { |
| 112 | + type Value = PyObjectRef; |
| 113 | + |
| 114 | + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { |
| 115 | + formatter.write_str("a type that can deserialise in Python") |
| 116 | + } |
| 117 | + |
| 118 | + fn visit_str<E>(self, value: &str) -> Result<Self::Value, E> |
| 119 | + where |
| 120 | + E: serde::de::Error, |
| 121 | + { |
| 122 | + Ok(self.vm.ctx.new_str(value.to_string())) |
| 123 | + } |
| 124 | + |
| 125 | + fn visit_string<E>(self, value: String) -> Result<Self::Value, E> |
| 126 | + where |
| 127 | + E: serde::de::Error, |
| 128 | + { |
| 129 | + Ok(self.vm.ctx.new_str(value)) |
| 130 | + } |
| 131 | + |
| 132 | + fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> |
| 133 | + where |
| 134 | + E: serde::de::Error, |
| 135 | + { |
| 136 | + // The JSON deserializer always uses the i64/u64 deserializers, so we only need to |
| 137 | + // implement those for now |
| 138 | + Ok(self.vm.ctx.new_int(value)) |
| 139 | + } |
| 140 | + |
| 141 | + fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E> |
| 142 | + where |
| 143 | + E: serde::de::Error, |
| 144 | + { |
| 145 | + // The JSON deserializer always uses the i64/u64 deserializers, so we only need to |
| 146 | + // implement those for now |
| 147 | + Ok(self.vm.ctx.new_int(value)) |
| 148 | + } |
| 149 | + |
| 150 | + fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E> |
| 151 | + where |
| 152 | + E: serde::de::Error, |
| 153 | + { |
| 154 | + Ok(self.vm.ctx.new_float(value)) |
| 155 | + } |
| 156 | + |
| 157 | + fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E> |
| 158 | + where |
| 159 | + E: serde::de::Error, |
| 160 | + { |
| 161 | + Ok(self.vm.ctx.new_bool(value)) |
| 162 | + } |
| 163 | + |
| 164 | + fn visit_seq<A>(self, mut access: A) -> Result<Self::Value, A::Error> |
| 165 | + where |
| 166 | + A: serde::de::SeqAccess<'de>, |
| 167 | + { |
| 168 | + let mut seq = Vec::with_capacity(access.size_hint().unwrap_or(0)); |
| 169 | + while let Some(value) = access.next_element_seed(self.clone())? { |
| 170 | + seq.push(value); |
| 171 | + } |
| 172 | + Ok(self.vm.ctx.new_list(seq)) |
| 173 | + } |
| 174 | + |
| 175 | + fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error> |
| 176 | + where |
| 177 | + M: serde::de::MapAccess<'de>, |
| 178 | + { |
| 179 | + let dict = self.vm.ctx.new_dict(); |
| 180 | + // TODO: Given keys must be strings, we can probably do something more efficient |
| 181 | + // than wrapping the given object up and then unwrapping it to determine whether or |
| 182 | + // not it is a string |
| 183 | + while let Some((key_obj, value)) = access.next_entry_seed(self.clone(), self.clone())? { |
| 184 | + let key: String = match key_obj.payload::<PyString>() { |
| 185 | + Some(PyString { ref value }) => value.clone(), |
| 186 | + _ => unimplemented!("map keys must be strings"), |
| 187 | + }; |
| 188 | + dict.set_item(&key, value, self.vm).unwrap(); |
| 189 | + } |
| 190 | + Ok(dict.into_object()) |
| 191 | + } |
| 192 | + |
| 193 | + fn visit_unit<E>(self) -> Result<Self::Value, E> |
| 194 | + where |
| 195 | + E: serde::de::Error, |
| 196 | + { |
| 197 | + Ok(self.vm.get_none()) |
| 198 | + } |
| 199 | +} |
0 commit comments