Skip to content

Commit 6542d35

Browse files
committed
Optional args on new for Bytes/StringIO
1 parent 6767b4e commit 6542d35

File tree

3 files changed

+30
-17
lines changed

3 files changed

+30
-17
lines changed

tests/snippets/bytes_io.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@ def test_01():
1111

1212
def test_02():
1313
bytes_string = b'Test String 2'
14-
15-
f = BytesIO()
16-
f.write(bytes_string)
14+
f = BytesIO(bytes_string)
1715

1816
assert f.read() == bytes_string
1917
assert f.read() == b''
20-
assert f.getvalue() == b''
2118

2219
if __name__ == "__main__":
2320
test_01()

tests/snippets/string_io.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@ def test_01():
1010

1111
def test_02():
1212
string = 'Test String 2'
13-
f = StringIO()
14-
f.write(string)
13+
f = StringIO(string)
1514

1615
assert f.read() == string
1716
assert f.read() == ''
18-
assert f.getvalue() == ''
1917

2018
if __name__ == "__main__":
2119
test_01()

vm/src/stdlib/io.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ use num_bigint::ToBigInt;
1111
use num_traits::ToPrimitive;
1212

1313
use super::os;
14-
use crate::function::PyFuncArgs;
14+
use crate::function::{OptionalArg, PyFuncArgs};
1515
use crate::import;
1616
use crate::obj::objbytearray::PyByteArray;
1717
use crate::obj::objbytes;
1818
use crate::obj::objint;
19-
use crate::obj::objstr;
19+
use crate::obj::objstr::{get_value, PyStringRef};
2020
use crate::obj::objtype::PyClassRef;
2121
use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue};
2222
use crate::vm::VirtualMachine;
@@ -45,7 +45,7 @@ impl PyValue for PyStringIO {
4545
}
4646

4747
impl PyStringIORef {
48-
fn write(self, data: objstr::PyStringRef, _vm: &VirtualMachine) {
48+
fn write(self, data: PyStringRef, _vm: &VirtualMachine) {
4949
let data = data.value.clone();
5050
self.data.borrow_mut().push_str(&data);
5151
}
@@ -61,9 +61,18 @@ impl PyStringIORef {
6161
}
6262
}
6363

64-
fn string_io_new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyStringIORef> {
64+
fn string_io_new(
65+
cls: PyClassRef,
66+
object: OptionalArg<PyObjectRef>,
67+
vm: &VirtualMachine,
68+
) -> PyResult<PyStringIORef> {
69+
let raw_string = match object {
70+
OptionalArg::Present(ref input) => get_value(input),
71+
OptionalArg::Missing => String::new(),
72+
};
73+
6574
PyStringIO {
66-
data: RefCell::new(String::default()),
75+
data: RefCell::new(raw_string),
6776
}
6877
.into_ref_with_type(vm, cls)
6978
}
@@ -98,9 +107,18 @@ impl PyBytesIORef {
98107
}
99108
}
100109

101-
fn bytes_io_new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyBytesIORef> {
110+
fn bytes_io_new(
111+
cls: PyClassRef,
112+
object: OptionalArg<PyObjectRef>,
113+
vm: &VirtualMachine,
114+
) -> PyResult<PyBytesIORef> {
115+
let raw_bytes = match object {
116+
OptionalArg::Present(ref input) => objbytes::get_value(input).to_vec(),
117+
OptionalArg::Missing => vec![],
118+
};
119+
102120
PyBytesIO {
103-
data: RefCell::new(Vec::new()),
121+
data: RefCell::new(raw_bytes),
104122
}
105123
.into_ref_with_type(vm, cls)
106124
}
@@ -172,7 +190,7 @@ fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
172190
optional = [(mode, Some(vm.ctx.str_type()))]
173191
);
174192

175-
let rust_mode = mode.map_or("r".to_string(), |m| objstr::get_value(m));
193+
let rust_mode = mode.map_or("r".to_string(), |m| get_value(m));
176194

177195
match compute_c_flag(&rust_mode).to_bigint() {
178196
Some(os_mode) => {
@@ -193,7 +211,7 @@ fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
193211
fn file_io_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
194212
arg_check!(vm, args, required = [(file_io, None)]);
195213
let py_name = vm.get_attribute(file_io.clone(), "name")?;
196-
let f = match File::open(objstr::get_value(&py_name)) {
214+
let f = match File::open(get_value(&py_name)) {
197215
Ok(v) => Ok(v),
198216
Err(_) => Err(vm.new_type_error("Error opening file".to_string())),
199217
};
@@ -389,7 +407,7 @@ pub fn io_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
389407
optional = [(mode, Some(vm.ctx.str_type()))]
390408
);
391409
// mode is optional: 'rt' is the default mode (open from reading text)
392-
let mode_string = mode.map_or("rt".to_string(), |m| objstr::get_value(m));
410+
let mode_string = mode.map_or("rt".to_string(), |m| get_value(m));
393411
let (mode, typ) = match split_mode_string(mode_string) {
394412
Ok((mode, typ)) => (mode, typ),
395413
Err(error_message) => {

0 commit comments

Comments
 (0)