Skip to content

Commit 5dddd02

Browse files
committed
improve (de)serialization: no crash on BigInts, non-string map keys
1 parent a332b74 commit 5dddd02

File tree

3 files changed

+64
-50
lines changed

3 files changed

+64
-50
lines changed

tests/snippets/json_snippet.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from testutils import assert_raises
12
import json
23

34
def round_trip_test(obj):
@@ -27,6 +28,12 @@ def round_trip_test(obj):
2728
assert '{}' == json.dumps({})
2829
round_trip_test({'a': 'b'})
2930

31+
# should reject non-str keys in JSON input
32+
assert_raises(json.JSONDecodeError, lambda: json.loads('{3: "abc"}'))
33+
34+
# should serialize non-str keys as strings
35+
assert json.dumps({'3': 'abc'}) == json.dumps({3: 'abc'})
36+
3037
assert 1 == json.loads("1")
3138
assert -1 == json.loads("-1")
3239
assert 1.0 == json.loads("1.0")
@@ -44,12 +51,23 @@ class String(str): pass
4451
assert "string" == json.loads(String('"string"'))
4552
assert '"string"' == json.dumps(String("string"))
4653

47-
# TODO: Uncomment and test once int/float construction is supported
48-
# class Int(int): pass
49-
# class Float(float): pass
54+
class Int(int): pass
55+
class Float(float): pass
56+
57+
assert '1' == json.dumps(Int(1))
58+
assert '0.5' == json.dumps(Float(0.5))
59+
60+
class List(list): pass
61+
class Tuple(tuple): pass
62+
class Dict(dict): pass
63+
64+
assert '[1]' == json.dumps(List([1]))
65+
assert json.dumps((1, "string", 1.0, True)) == json.dumps(Tuple((1, "string", 1.0, True)))
66+
assert json.dumps({'a': 'b'}) == json.dumps(Dict({'a': 'b'}))
5067

51-
# TODO: Uncomment and test once sequence/dict subclasses are supported by
52-
# json.dumps
53-
# class List(list): pass
54-
# class Tuple(tuple): pass
55-
# class Dict(dict): pass
68+
# big ints should not crash VM
69+
# TODO: test for correct output once actual serialization is implemented and doesn’t throw
70+
try:
71+
json.dumps(7**500)
72+
except:
73+
pass

tests/snippets/math_basics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
assert_raises(
3434
OverflowError,
3535
lambda: round(float('inf')),
36-
'OverflowError: cannot convert float NaN to integer')
36+
'OverflowError: cannot convert float infinity to integer')
3737
assert_raises(
3838
OverflowError,
3939
lambda: round(-float('inf')),
40-
'OverflowError: cannot convert float NaN to integer')
40+
'OverflowError: cannot convert float infinity to integer')
4141

4242
assert pow(0, 0) == 1
4343
assert pow(2, 2) == 4

vm/src/py_serde.rs

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,11 @@ use serde;
44
use serde::de::{DeserializeSeed, Visitor};
55
use serde::ser::{Serialize, SerializeMap, SerializeSeq};
66

7-
use crate::obj::{
8-
objbool,
9-
objdict::PyDictRef,
10-
objfloat, objint, objsequence,
11-
objstr::{self, PyString},
12-
objtype,
13-
};
7+
use crate::obj::{objbool, objdict::PyDictRef, objfloat, objint, objsequence, objstr, objtype};
148
use crate::pyobject::{IdProtocol, ItemProtocol, PyObjectRef, TypeProtocol};
159
use crate::VirtualMachine;
1610
use num_traits::cast::ToPrimitive;
11+
use num_traits::sign::Signed;
1712

1813
#[inline]
1914
pub fn serialize<S>(
@@ -79,9 +74,16 @@ impl<'s> serde::Serialize for PyObjectSerializer<'s> {
7974
serializer.serialize_bool(objbool::get_value(self.pyobject))
8075
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.int_type()) {
8176
let v = objint::get_value(self.pyobject);
82-
serializer.serialize_i64(v.to_i64().unwrap())
83-
// Although this may seem nice, it does not give the right result:
84-
// v.serialize(serializer)
77+
let int_too_large = || serde::ser::Error::custom("int too large to serialize");
78+
// TODO: serialize BigInt when it does not fit into i64
79+
// BigInt implements serialization to a tuple of sign and a list of u32s,
80+
// e.g. -1 is [-1, [1]], 0 is [0, []], 12345678900000654321 is [1, [2710766577, 2874452364]]
81+
// CPython serializes big ints as long decimal integer literals
82+
if v.is_positive() {
83+
serializer.serialize_u64(v.to_u64().ok_or_else(int_too_large)?)
84+
} else {
85+
serializer.serialize_i64(v.to_i64().ok_or_else(int_too_large)?)
86+
}
8587
} else if objtype::isinstance(self.pyobject, &self.vm.ctx.list_type()) {
8688
let elements = objsequence::get_elements_list(self.pyobject);
8789
serialize_seq_elements(serializer, &elements)
@@ -138,35 +140,26 @@ impl<'de> Visitor<'de> for PyObjectDeserializer<'de> {
138140
formatter.write_str("a type that can deserialise in Python")
139141
}
140142

141-
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
142-
where
143-
E: serde::de::Error,
144-
{
145-
Ok(self.vm.ctx.new_str(value.to_string()))
146-
}
147-
148-
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
143+
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
149144
where
150145
E: serde::de::Error,
151146
{
152-
Ok(self.vm.ctx.new_str(value))
147+
Ok(self.vm.ctx.new_bool(value))
153148
}
154149

150+
// Other signed integer visitors delegate to this method by default, so it’s the only one needed
155151
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
156152
where
157153
E: serde::de::Error,
158154
{
159-
// The JSON deserializer always uses the i64/u64 deserializers, so we only need to
160-
// implement those for now
161155
Ok(self.vm.ctx.new_int(value))
162156
}
163157

158+
// Other unsigned integer visitors delegate to this method by default, so it’s the only one needed
164159
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
165160
where
166161
E: serde::de::Error,
167162
{
168-
// The JSON deserializer always uses the i64/u64 deserializers, so we only need to
169-
// implement those for now
170163
Ok(self.vm.ctx.new_int(value))
171164
}
172165

@@ -177,11 +170,26 @@ impl<'de> Visitor<'de> for PyObjectDeserializer<'de> {
177170
Ok(self.vm.ctx.new_float(value))
178171
}
179172

180-
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
173+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
181174
where
182175
E: serde::de::Error,
183176
{
184-
Ok(self.vm.ctx.new_bool(value))
177+
// Owned value needed anyway, delegate to visit_string
178+
self.visit_string(value.to_string())
179+
}
180+
181+
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
182+
where
183+
E: serde::de::Error,
184+
{
185+
Ok(self.vm.ctx.new_str(value))
186+
}
187+
188+
fn visit_unit<E>(self) -> Result<Self::Value, E>
189+
where
190+
E: serde::de::Error,
191+
{
192+
Ok(self.vm.get_none())
185193
}
186194

187195
fn visit_seq<A>(self, mut access: A) -> Result<Self::Value, A::Error>
@@ -200,23 +208,11 @@ impl<'de> Visitor<'de> for PyObjectDeserializer<'de> {
200208
M: serde::de::MapAccess<'de>,
201209
{
202210
let dict = self.vm.ctx.new_dict();
203-
// TODO: Given keys must be strings, we can probably do something more efficient
204-
// than wrapping the given object up and then unwrapping it to determine whether or
205-
// not it is a string
211+
// Although JSON keys must be strings, this implementation accepts keys of any type,
212+
// so it can be reused by other deserializers that have no such limitation
206213
while let Some((key_obj, value)) = access.next_entry_seed(self.clone(), self.clone())? {
207-
let key: String = match key_obj.payload::<PyString>() {
208-
Some(PyString { ref value }) => value.clone(),
209-
_ => unimplemented!("map keys must be strings"),
210-
};
211-
dict.set_item(&key, value, self.vm).unwrap();
214+
dict.set_item(key_obj, value, self.vm).unwrap();
212215
}
213216
Ok(dict.into_object())
214217
}
215-
216-
fn visit_unit<E>(self) -> Result<Self::Value, E>
217-
where
218-
E: serde::de::Error,
219-
{
220-
Ok(self.vm.get_none())
221-
}
222218
}

0 commit comments

Comments
 (0)