Skip to content

Commit 8faab1f

Browse files
committed
Implement String, BytesIO with Cursor<Vec<u8>>
1 parent 33885a8 commit 8faab1f

File tree

3 files changed

+189
-22
lines changed

3 files changed

+189
-22
lines changed

tests/snippets/bytes_io.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,35 @@ def test_02():
1616
assert f.read() == bytes_string
1717
assert f.read() == b''
1818

19+
def test_03():
20+
"""
21+
Tests that the read method (integer arg)
22+
returns the expected value
23+
"""
24+
string = b'Test String 3'
25+
f = BytesIO(string)
26+
27+
assert f.read(1) == b'T'
28+
assert f.read(1) == b'e'
29+
assert f.read(1) == b's'
30+
assert f.read(1) == b't'
31+
32+
def test_04():
33+
"""
34+
Tests that the read method increments the
35+
cursor position and the seek method moves
36+
the cursor to the appropriate position
37+
"""
38+
string = b'Test String 4'
39+
f = BytesIO(string)
40+
41+
assert f.read(4) == b'Test'
42+
assert f.seek(0) == 0
43+
assert f.read(4) == b'Test'
44+
1945
if __name__ == "__main__":
2046
test_01()
2147
test_02()
48+
test_03()
49+
test_04()
50+

tests/snippets/string_io.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,55 @@
22
from io import StringIO
33

44
def test_01():
5+
"""
6+
Test that the constructor and getvalue
7+
method return expected values
8+
"""
59
string = 'Test String 1'
610
f = StringIO()
711
f.write(string)
812

913
assert f.getvalue() == string
1014

1115
def test_02():
16+
"""
17+
Test that the read method (no arg)
18+
results the expected value
19+
"""
1220
string = 'Test String 2'
1321
f = StringIO(string)
1422

1523
assert f.read() == string
1624
assert f.read() == ''
1725

26+
def test_03():
27+
"""
28+
Tests that the read method (integer arg)
29+
returns the expected value
30+
"""
31+
string = 'Test String 3'
32+
f = StringIO(string)
33+
34+
assert f.read(1) == 'T'
35+
assert f.read(1) == 'e'
36+
assert f.read(1) == 's'
37+
assert f.read(1) == 't'
38+
39+
def test_04():
40+
"""
41+
Tests that the read method increments the
42+
cursor position and the seek method moves
43+
the cursor to the appropriate position
44+
"""
45+
string = 'Test String 4'
46+
f = StringIO(string)
47+
48+
assert f.read(4) == 'Test'
49+
assert f.seek(0) == 0
50+
assert f.read(4) == 'Test'
51+
1852
if __name__ == "__main__":
1953
test_01()
2054
test_02()
55+
test_03()
56+
test_04()

vm/src/stdlib/io.rs

Lines changed: 124 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ use std::cell::RefCell;
55
use std::fs::File;
66
use std::io::prelude::*;
77
use std::io::BufReader;
8+
use std::io::Cursor;
9+
use std::io::SeekFrom;
10+
811
use std::path::PathBuf;
912

1013
use num_bigint::ToBigInt;
@@ -33,7 +36,7 @@ fn compute_c_flag(mode: &str) -> u16 {
3336

3437
#[derive(Debug)]
3538
struct PyStringIO {
36-
data: RefCell<String>,
39+
data: RefCell<Cursor<Vec<u8>>>,
3740
}
3841

3942
type PyStringIORef = PyRef<PyStringIO>;
@@ -45,19 +48,68 @@ impl PyValue for PyStringIO {
4548
}
4649

4750
impl PyStringIORef {
48-
fn write(self, data: objstr::PyStringRef, _vm: &VirtualMachine) {
49-
let data = data.value.clone();
50-
self.data.borrow_mut().push_str(&data);
51+
//write string to underlying vector
52+
fn write(self, data: objstr::PyStringRef, vm: &VirtualMachine) -> PyResult {
53+
let bytes = &data.value.clone().into_bytes();
54+
let length = bytes.len();
55+
56+
let mut cursor = self.data.borrow_mut();
57+
match cursor.write_all(bytes) {
58+
Ok(_) => Ok(vm.ctx.new_int(length)),
59+
Err(_) => Err(vm.new_type_error("Error Writing String".to_string())),
60+
}
5161
}
5262

53-
fn getvalue(self, _vm: &VirtualMachine) -> String {
54-
self.data.borrow().clone()
63+
//return the entire contents of the underlying
64+
fn getvalue(self, vm: &VirtualMachine) -> PyResult {
65+
match String::from_utf8(self.data.borrow().clone().into_inner()) {
66+
Ok(result) => Ok(vm.ctx.new_str(result)),
67+
Err(_) => Err(vm.new_value_error("Error Retrieving Value".to_string())),
68+
}
5569
}
5670

57-
fn read(self, _vm: &VirtualMachine) -> String {
58-
let data = self.data.borrow().clone();
59-
self.data.borrow_mut().clear();
60-
data
71+
//skip to the jth position
72+
fn seek(self, offset: PyObjectRef, vm: &VirtualMachine) -> PyResult {
73+
let position = objint::get_value(&offset).to_u64().unwrap();
74+
if let Err(_) = self
75+
.data
76+
.borrow_mut()
77+
.seek(SeekFrom::Start(position.clone()))
78+
{
79+
return Err(vm.new_value_error("Error Retrieving Value".to_string()));
80+
}
81+
82+
Ok(vm.ctx.new_int(position))
83+
}
84+
85+
//Read k bytes from the object and return.
86+
//If k is undefined || k == -1, then we read all bytes until the end of the file.
87+
//This also increments the stream position by the value of k
88+
fn read(self, bytes: OptionalArg<Option<PyObjectRef>>, vm: &VirtualMachine) -> PyResult {
89+
let mut buffer = String::new();
90+
91+
match bytes {
92+
OptionalArg::Present(Some(ref integer)) => {
93+
let k = objint::get_value(integer).to_u64().unwrap();
94+
let mut handle = self.data.borrow().clone().take(k);
95+
96+
//read bytes into string
97+
if let Err(_) = handle.read_to_string(&mut buffer) {
98+
return Err(vm.new_value_error("Error Retrieving Value".to_string()));
99+
}
100+
101+
//the take above consumes the struct value
102+
//we add this back in with the takes into_inner method
103+
self.data.replace(handle.into_inner());
104+
}
105+
_ => {
106+
if let Err(_) = self.data.borrow_mut().read_to_string(&mut buffer) {
107+
return Err(vm.new_value_error("Error Retrieving Value".to_string()));
108+
}
109+
}
110+
};
111+
112+
Ok(vm.ctx.new_str(buffer))
61113
}
62114
}
63115

@@ -72,14 +124,14 @@ fn string_io_new(
72124
};
73125

74126
PyStringIO {
75-
data: RefCell::new(raw_string),
127+
data: RefCell::new(Cursor::new(raw_string.into_bytes())),
76128
}
77129
.into_ref_with_type(vm, cls)
78130
}
79131

80-
#[derive(Debug, Default, Clone)]
132+
#[derive(Debug)]
81133
struct PyBytesIO {
82-
data: RefCell<Vec<u8>>,
134+
data: RefCell<Cursor<Vec<u8>>>,
83135
}
84136

85137
type PyBytesIORef = PyRef<PyBytesIO>;
@@ -91,19 +143,65 @@ impl PyValue for PyBytesIO {
91143
}
92144

93145
impl PyBytesIORef {
94-
fn write(self, data: objbytes::PyBytesRef, _vm: &VirtualMachine) {
95-
let data = data.get_value();
96-
self.data.borrow_mut().extend(data);
146+
//write string to underlying vector
147+
fn write(self, data: objbytes::PyBytesRef, vm: &VirtualMachine) -> PyResult {
148+
let bytes = data.get_value();
149+
let length = bytes.len();
150+
151+
let mut cursor = self.data.borrow_mut();
152+
match cursor.write_all(bytes) {
153+
Ok(_) => Ok(vm.ctx.new_int(length)),
154+
Err(_) => Err(vm.new_type_error("Error Writing String".to_string())),
155+
}
97156
}
98157

158+
//return the entire contents of the underlying
99159
fn getvalue(self, vm: &VirtualMachine) -> PyResult {
100-
Ok(vm.ctx.new_bytes(self.data.borrow().clone()))
160+
Ok(vm.ctx.new_bytes(self.data.borrow().clone().into_inner()))
161+
}
162+
163+
//skip to the jth position
164+
fn seek(self, offset: PyObjectRef, vm: &VirtualMachine) -> PyResult {
165+
let position = objint::get_value(&offset).to_u64().unwrap();
166+
if let Err(_) = self
167+
.data
168+
.borrow_mut()
169+
.seek(SeekFrom::Start(position.clone()))
170+
{
171+
return Err(vm.new_value_error("Error Retrieving Value".to_string()));
172+
}
173+
174+
Ok(vm.ctx.new_int(position))
101175
}
102176

103-
fn read(self, vm: &VirtualMachine) -> PyResult {
104-
let data = self.data.borrow().clone();
105-
self.data.borrow_mut().clear();
106-
Ok(vm.ctx.new_bytes(data))
177+
//Read k bytes from the object and return.
178+
//If k is undefined || k == -1, then we read all bytes until the end of the file.
179+
//This also increments the stream position by the value of k
180+
fn read(self, bytes: OptionalArg<Option<PyObjectRef>>, vm: &VirtualMachine) -> PyResult {
181+
let mut buffer = Vec::new();
182+
183+
match bytes {
184+
OptionalArg::Present(Some(ref integer)) => {
185+
let k = objint::get_value(integer).to_u64().unwrap();
186+
let mut handle = self.data.borrow().clone().take(k);
187+
188+
//read bytes into string
189+
if let Err(_) = handle.read_to_end(&mut buffer) {
190+
return Err(vm.new_value_error("Error Retrieving Value".to_string()));
191+
}
192+
193+
//the take above consumes the struct value
194+
//we add this back in with the takes into_inner method
195+
self.data.replace(handle.into_inner());
196+
}
197+
_ => {
198+
if let Err(_) = self.data.borrow_mut().read_to_end(&mut buffer) {
199+
return Err(vm.new_value_error("Error Retrieving Value".to_string()));
200+
}
201+
}
202+
};
203+
204+
Ok(vm.ctx.new_bytes(buffer))
107205
}
108206
}
109207

@@ -118,7 +216,7 @@ fn bytes_io_new(
118216
};
119217

120218
PyBytesIO {
121-
data: RefCell::new(raw_bytes),
219+
data: RefCell::new(Cursor::new(raw_bytes)),
122220
}
123221
.into_ref_with_type(vm, cls)
124222
}
@@ -514,6 +612,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
514612
//StringIO: in-memory text
515613
let string_io = py_class!(ctx, "StringIO", text_io_base.clone(), {
516614
"__new__" => ctx.new_rustfunc(string_io_new),
615+
"seek" => ctx.new_rustfunc(PyStringIORef::seek),
517616
"read" => ctx.new_rustfunc(PyStringIORef::read),
518617
"write" => ctx.new_rustfunc(PyStringIORef::write),
519618
"getvalue" => ctx.new_rustfunc(PyStringIORef::getvalue)
@@ -523,6 +622,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
523622
let bytes_io = py_class!(ctx, "BytesIO", buffered_io_base.clone(), {
524623
"__new__" => ctx.new_rustfunc(bytes_io_new),
525624
"read" => ctx.new_rustfunc(PyBytesIORef::read),
625+
"read1" => ctx.new_rustfunc(PyBytesIORef::read),
626+
"seek" => ctx.new_rustfunc(PyBytesIORef::seek),
526627
"write" => ctx.new_rustfunc(PyBytesIORef::write),
527628
"getvalue" => ctx.new_rustfunc(PyBytesIORef::getvalue)
528629
});
@@ -627,4 +728,5 @@ mod tests {
627728
Err("invalid mode: 'a++'".to_string())
628729
);
629730
}
731+
630732
}

0 commit comments

Comments
 (0)