Skip to content

Commit 0e33aa3

Browse files
committed
Create objects with correct types in json.rs
Fixes RustPython#120.
1 parent a4a91f1 commit 0e33aa3

File tree

2 files changed

+137
-141
lines changed

2 files changed

+137
-141
lines changed

tests/snippets/json_snippet.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,36 @@ def round_trip_test(obj):
1616
assert '[]' == json.dumps([])
1717
assert '[1]' == json.dumps([1])
1818
assert '[[1]]' == json.dumps([[1]])
19-
# round_trip_test([1, "string", 1.0, True])
19+
round_trip_test([1, "string", 1.0, True])
2020

2121
assert '[]' == json.dumps(())
2222
assert '[1]' == json.dumps((1,))
2323
assert '[[1]]' == json.dumps(((1,),))
2424
# tuples don't round-trip through json
25-
# assert [1, "string", 1.0, True] == json.loads(json.dumps((1, "string", 1.0, True)))
26-
27-
# assert '{}' == json.dumps({})
28-
# # TODO: uncomment once dict comparison is implemented
29-
# # round_trip_test({'a': 'b'})
30-
31-
# assert 1 == json.loads("1")
32-
# assert -1 == json.loads("-1")
33-
# assert 1.0 == json.loads("1.0")
34-
# # TODO: uncomment once negative floats are implemented
35-
# # assert -1.0 == json.loads("-1.0")
36-
# assert "str" == json.loads('"str"')
37-
# # TODO: Use "is" once implemented
38-
# assert True == json.loads('true')
39-
# assert False == json.loads('false')
40-
# # TODO: uncomment once None comparison is implemented
41-
# assert None == json.loads('null')
42-
# assert [] == json.loads('[]')
43-
# assert ['a'] == json.loads('["a"]')
44-
# assert [['a'], 'b'] == json.loads('[["a"], "b"]')
45-
46-
# class String(str): pass
47-
48-
# assert '"string"' == json.dumps(String("string"))
25+
assert [1, "string", 1.0, True] == json.loads(json.dumps((1, "string", 1.0, True)))
26+
27+
assert '{}' == json.dumps({})
28+
# TODO: uncomment once dict comparison is implemented
29+
# round_trip_test({'a': 'b'})
30+
31+
assert 1 == json.loads("1")
32+
assert -1 == json.loads("-1")
33+
assert 1.0 == json.loads("1.0")
34+
# TODO: uncomment once negative floats are implemented
35+
# assert -1.0 == json.loads("-1.0")
36+
assert "str" == json.loads('"str"')
37+
# TODO: Use "is" once implemented
38+
assert True == json.loads('true')
39+
assert False == json.loads('false')
40+
# TODO: uncomment once None comparison is implemented
41+
assert None == json.loads('null')
42+
assert [] == json.loads('[]')
43+
assert ['a'] == json.loads('["a"]')
44+
assert [['a'], 'b'] == json.loads('[["a"], "b"]')
45+
46+
class String(str): pass
47+
48+
assert '"string"' == json.dumps(String("string"))
4949

5050
# TODO: Uncomment and test once int/float construction is supported
5151
# class Int(int): pass

vm/src/stdlib/json.rs

Lines changed: 112 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
use std::collections::HashMap;
21
use std::fmt;
32

43
use serde;
5-
use serde::de::Visitor;
4+
use serde::de::{DeserializeSeed, Visitor};
65
use serde::ser::{SerializeMap, SerializeSeq};
76
use serde_json;
87

98
use super::super::obj::{objdict, objfloat, objint, objlist, objstr, objtuple, objtype};
109
use super::super::objbool;
1110
use super::super::pyobject::{
12-
DictProtocol, PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult,
13-
TypeProtocol,
11+
DictProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
1412
};
1513
use super::super::VirtualMachine;
1614

@@ -75,128 +73,123 @@ impl<'s> serde::Serialize for PyObjectSerializer<'s> {
7573
}
7674
}
7775

78-
struct PyObjectKindVisitor;
79-
80-
impl<'de> Visitor<'de> for PyObjectKindVisitor {
81-
type Value = PyObjectKind;
76+
// This object is used as the seed for deserialization so we have access to the PyContext for type
77+
// creation
78+
#[derive(Clone)]
79+
struct PyObjectDeserializer<'c> {
80+
ctx: &'c PyContext,
81+
}
8282

83-
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
84-
formatter.write_str("a type that can deserialise in Python")
85-
}
83+
impl<'de> serde::de::DeserializeSeed<'de> for PyObjectDeserializer<'de> {
84+
type Value = PyObjectRef;
8685

87-
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
86+
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
8887
where
89-
E: serde::de::Error,
88+
D: serde::Deserializer<'de>,
9089
{
91-
Ok(PyObjectKind::String {
92-
value: value.to_string(),
93-
})
94-
}
90+
impl<'de> Visitor<'de> for PyObjectDeserializer<'de> {
91+
type Value = PyObjectRef;
9592

96-
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
97-
where
98-
E: serde::de::Error,
99-
{
100-
Ok(PyObjectKind::String { value })
101-
}
93+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
94+
formatter.write_str("a type that can deserialise in Python")
95+
}
10296

103-
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
104-
where
105-
E: serde::de::Error,
106-
{
107-
// The JSON deserialiser always uses the i64/u64 deserialisers, so we only need to
108-
// implement those for now
109-
use std::i32;
110-
if value >= i32::MIN as i64 && value <= i32::MAX as i64 {
111-
Ok(PyObjectKind::Integer {
112-
value: value as i32,
113-
})
114-
} else {
115-
Err(E::custom(format!("i64 out of range: {}", value)))
116-
}
117-
}
97+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
98+
where
99+
E: serde::de::Error,
100+
{
101+
Ok(self.ctx.new_str(value.to_string()))
102+
}
118103

119-
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
120-
where
121-
E: serde::de::Error,
122-
{
123-
// The JSON deserialiser always uses the i64/u64 deserialisers, so we only need to
124-
// implement those for now
125-
use std::i32;
126-
if value <= i32::MAX as u64 {
127-
Ok(PyObjectKind::Integer {
128-
value: value as i32,
129-
})
130-
} else {
131-
Err(E::custom(format!("u64 out of range: {}", value)))
132-
}
133-
}
104+
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
105+
where
106+
E: serde::de::Error,
107+
{
108+
Ok(self.ctx.new_str(value))
109+
}
134110

135-
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
136-
where
137-
E: serde::de::Error,
138-
{
139-
Ok(PyObjectKind::Float { value })
140-
}
111+
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
112+
where
113+
E: serde::de::Error,
114+
{
115+
// The JSON deserialiser always uses the i64/u64 deserialisers, so we only need to
116+
// implement those for now
117+
use std::i32;
118+
if value >= i32::MIN as i64 && value <= i32::MAX as i64 {
119+
Ok(self.ctx.new_int(value as i32))
120+
} else {
121+
Err(E::custom(format!("i64 out of range: {}", value)))
122+
}
123+
}
141124

142-
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
143-
where
144-
E: serde::de::Error,
145-
{
146-
Ok(PyObjectKind::Integer {
147-
value: if value { 1 } else { 0 },
148-
})
149-
}
125+
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
126+
where
127+
E: serde::de::Error,
128+
{
129+
// The JSON deserialiser always uses the i64/u64 deserialisers, so we only need to
130+
// implement those for now
131+
use std::i32;
132+
if value <= i32::MAX as u64 {
133+
Ok(self.ctx.new_int(value as i32))
134+
} else {
135+
Err(E::custom(format!("u64 out of range: {}", value)))
136+
}
137+
}
150138

151-
fn visit_seq<A>(self, mut access: A) -> Result<Self::Value, A::Error>
152-
where
153-
A: serde::de::SeqAccess<'de>,
154-
{
155-
let mut seq = Vec::with_capacity(access.size_hint().unwrap_or(0));
156-
while let Some(value) = access.next_element()? {
157-
seq.push(
158-
PyObject {
159-
kind: value,
160-
typ: None, // TODO: Determine the effect this None will have
161-
}.into_ref(),
162-
);
163-
}
164-
Ok(PyObjectKind::List { elements: seq })
165-
}
139+
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
140+
where
141+
E: serde::de::Error,
142+
{
143+
Ok(self.ctx.new_float(value))
144+
}
166145

167-
fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
168-
where
169-
M: serde::de::MapAccess<'de>,
170-
{
171-
let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
172-
173-
while let Some((key, value)) = access.next_entry()? {
174-
map.insert(
175-
key,
176-
PyObject {
177-
kind: value,
178-
typ: None, // TODO: Determine the effect this None will have
179-
}.into_ref(),
180-
);
181-
}
146+
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
147+
where
148+
E: serde::de::Error,
149+
{
150+
Ok(self.ctx.new_bool(value))
151+
}
182152

183-
Ok(PyObjectKind::Dict { elements: map })
184-
}
153+
fn visit_seq<A>(self, mut access: A) -> Result<Self::Value, A::Error>
154+
where
155+
A: serde::de::SeqAccess<'de>,
156+
{
157+
let mut seq = Vec::with_capacity(access.size_hint().unwrap_or(0));
158+
while let Some(value) = access.next_element_seed(self.clone())? {
159+
seq.push(value);
160+
}
161+
Ok(self.ctx.new_list(seq))
162+
}
185163

186-
fn visit_unit<E>(self) -> Result<Self::Value, E>
187-
where
188-
E: serde::de::Error,
189-
{
190-
Ok(PyObjectKind::None)
191-
}
192-
}
164+
fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
165+
where
166+
M: serde::de::MapAccess<'de>,
167+
{
168+
let dict = self.ctx.new_dict();
169+
// TODO: Given keys must be strings, we can probably do something more efficient
170+
// than wrapping the given object up and then unwrapping it to determine whether or
171+
// not it is a string
172+
while let Some((key_obj, value)) =
173+
access.next_entry_seed(self.clone(), self.clone())?
174+
{
175+
let key = match key_obj.borrow().kind {
176+
PyObjectKind::String { ref value } => value.clone(),
177+
_ => unimplemented!("map keys must be strings"),
178+
};
179+
dict.set_item(&key, value);
180+
}
181+
Ok(dict)
182+
}
193183

194-
impl<'de> serde::Deserialize<'de> for PyObjectKind {
195-
fn deserialize<D>(deserializer: D) -> Result<PyObjectKind, D::Error>
196-
where
197-
D: serde::Deserializer<'de>,
198-
{
199-
deserializer.deserialize_any(PyObjectKindVisitor)
184+
fn visit_unit<E>(self) -> Result<Self::Value, E>
185+
where
186+
E: serde::de::Error,
187+
{
188+
Ok(self.ctx.none.clone())
189+
}
190+
}
191+
192+
deserializer.deserialize_any(self.clone())
200193
}
201194
}
202195

@@ -216,11 +209,14 @@ fn loads(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
216209
// TODO: Implement non-trivial deserialisation case
217210
arg_check!(vm, args, required = [(string, Some(vm.ctx.str_type()))]);
218211
// TODO: Raise an exception for deserialisation errors
219-
let kind: PyObjectKind = match string.borrow().kind {
220-
PyObjectKind::String { ref value } => serde_json::from_str(&value).unwrap(),
212+
let de = PyObjectDeserializer { ctx: &vm.ctx };
213+
// TODO: Support deserializing string sub-classes
214+
Ok(match string.borrow().kind {
215+
PyObjectKind::String { ref value } => de
216+
.deserialize(&mut serde_json::Deserializer::from_str(&value))
217+
.unwrap(),
221218
_ => unimplemented!("json.loads only handles strings"),
222-
};
223-
Ok(PyObject::new(kind, vm.get_type()))
219+
})
224220
}
225221

226222
pub fn mk_module(ctx: &PyContext) -> PyObjectRef {

0 commit comments

Comments
 (0)