Skip to content

Commit 7bbc729

Browse files
committed
BufferedIO.read avoid clone
1 parent 3d5ea1b commit 7bbc729

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

vm/src/stdlib/io.rs

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::pyobject::{
2626
use crate::vm::VirtualMachine;
2727

2828
fn byte_count(bytes: OptionalOption<i64>) -> i64 {
29-
bytes.flatten().unwrap_or(-1 as i64)
29+
bytes.flatten().unwrap_or(-1)
3030
}
3131
fn os_err(vm: &VirtualMachine, err: io::Error) -> PyBaseExceptionRef {
3232
#[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))]
@@ -83,28 +83,32 @@ impl BufferedIO {
8383
}
8484

8585
//Read k bytes from the object and return.
86-
fn read(&mut self, bytes: i64) -> Option<Vec<u8>> {
87-
let mut buffer = Vec::new();
88-
86+
fn read(&mut self, bytes: Option<i64>) -> Option<Vec<u8>> {
8987
//for a defined number of bytes, i.e. bytes != -1
90-
if bytes >= 0 {
91-
let mut handle = self.cursor.clone().take(bytes as u64);
92-
//read handle into buffer
93-
94-
if handle.read_to_end(&mut buffer).is_err() {
95-
return None;
88+
match bytes.and_then(|v| v.to_usize()) {
89+
Some(bytes) => {
90+
let mut buffer = unsafe {
91+
// Do not move or edit any part of this block without a safety validation.
92+
// `set_len` is guaranteed to be safe only when the new length is less than or equal to the capacity
93+
let mut buffer = Vec::with_capacity(bytes);
94+
buffer.set_len(bytes);
95+
buffer
96+
};
97+
//read handle into buffer
98+
self.cursor
99+
.read_exact(&mut buffer)
100+
.map_or(None, |_| Some(buffer))
96101
}
97-
//the take above consumes the struct value
98-
//we add this back in with the takes into_inner method
99-
self.cursor = handle.into_inner();
100-
} else {
101-
//read handle into buffer
102-
if self.cursor.read_to_end(&mut buffer).is_err() {
103-
return None;
102+
None => {
103+
let mut buffer = Vec::new();
104+
//read handle into buffer
105+
if self.cursor.read_to_end(&mut buffer).is_err() {
106+
None
107+
} else {
108+
Some(buffer)
109+
}
104110
}
105-
};
106-
107-
Some(buffer)
111+
}
108112
}
109113

110114
fn tell(&self) -> u64 {
@@ -209,7 +213,7 @@ impl PyStringIORef {
209213
//If k is undefined || k == -1, then we read all bytes until the end of the file.
210214
//This also increments the stream position by the value of k
211215
fn read(self, bytes: OptionalOption<i64>, vm: &VirtualMachine) -> PyResult {
212-
let data = match self.buffer(vm)?.read(byte_count(bytes)) {
216+
let data = match self.buffer(vm)?.read(bytes.flatten()) {
213217
Some(value) => value,
214218
None => Vec::new(),
215219
};
@@ -263,11 +267,12 @@ fn string_io_new(
263267
_args: StringIOArgs,
264268
vm: &VirtualMachine,
265269
) -> PyResult<PyStringIORef> {
266-
let flatten = object.flatten();
267-
let input = flatten.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec());
270+
let raw_bytes = object
271+
.flatten()
272+
.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec());
268273

269274
PyStringIO {
270-
buffer: PyRwLock::new(BufferedIO::new(Cursor::new(input))),
275+
buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))),
271276
closed: AtomicCell::new(false),
272277
}
273278
.into_ref_with_type(vm, cls)
@@ -312,7 +317,7 @@ impl PyBytesIORef {
312317
//If k is undefined || k == -1, then we read all bytes until the end of the file.
313318
//This also increments the stream position by the value of k
314319
fn read(self, bytes: OptionalOption<i64>, vm: &VirtualMachine) -> PyResult {
315-
match self.buffer(vm)?.read(byte_count(bytes)) {
320+
match self.buffer(vm)?.read(bytes.flatten()) {
316321
Some(value) => Ok(vm.ctx.new_bytes(value)),
317322
None => Err(vm.new_value_error("Error Retrieving Value".to_owned())),
318323
}
@@ -363,10 +368,9 @@ fn bytes_io_new(
363368
object: OptionalArg<Option<PyBytesRef>>,
364369
vm: &VirtualMachine,
365370
) -> PyResult<PyBytesIORef> {
366-
let raw_bytes = match object {
367-
OptionalArg::Present(Some(ref input)) => input.get_value().to_vec(),
368-
_ => vec![],
369-
};
371+
let raw_bytes = object
372+
.flatten()
373+
.map_or_else(Vec::new, |input| input.get_value().to_vec());
370374

371375
PyBytesIO {
372376
buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))),
@@ -1446,7 +1450,7 @@ mod tests {
14461450
cursor: Cursor::new(data.clone()),
14471451
};
14481452

1449-
assert_eq!(buffered.read(bytes).unwrap(), data);
1453+
assert_eq!(buffered.read(Some(bytes)).unwrap(), data);
14501454
}
14511455

14521456
#[test]
@@ -1458,7 +1462,7 @@ mod tests {
14581462
};
14591463

14601464
assert_eq!(buffered.seek(SeekFrom::Start(count)).unwrap(), count);
1461-
assert_eq!(buffered.read(count.clone() as i64).unwrap(), vec![3, 4]);
1465+
assert_eq!(buffered.read(Some(count as i64)).unwrap(), vec![3, 4]);
14621466
}
14631467

14641468
#[test]

0 commit comments

Comments
 (0)