@@ -26,7 +26,7 @@ use crate::pyobject::{
26
26
use crate :: vm:: VirtualMachine ;
27
27
28
28
fn byte_count ( bytes : OptionalOption < i64 > ) -> i64 {
29
- bytes. flatten ( ) . unwrap_or ( -1 as i64 )
29
+ bytes. flatten ( ) . unwrap_or ( -1 )
30
30
}
31
31
fn os_err ( vm : & VirtualMachine , err : io:: Error ) -> PyBaseExceptionRef {
32
32
#[ cfg( any( not( target_arch = "wasm32" ) , target_os = "wasi" ) ) ]
@@ -83,28 +83,32 @@ impl BufferedIO {
83
83
}
84
84
85
85
//Read k bytes from the object and return.
86
- fn read ( & mut self , bytes : i64 ) -> Option < Vec < u8 > > {
87
- let mut buffer = Vec :: new ( ) ;
88
-
86
+ fn read ( & mut self , bytes : Option < i64 > ) -> Option < Vec < u8 > > {
89
87
//for a defined number of bytes, i.e. bytes != -1
90
- if bytes >= 0 {
91
- let mut handle = self . cursor . clone ( ) . take ( bytes as u64 ) ;
92
- //read handle into buffer
93
-
94
- if handle. read_to_end ( & mut buffer) . is_err ( ) {
95
- return None ;
88
+ match bytes. and_then ( |v| v. to_usize ( ) ) {
89
+ Some ( bytes) => {
90
+ let mut buffer = unsafe {
91
+ // Do not move or edit any part of this block without a safety validation.
92
+ // `set_len` is guaranteed to be safe only when the new length is less than or equal to the capacity
93
+ let mut buffer = Vec :: with_capacity ( bytes) ;
94
+ buffer. set_len ( bytes) ;
95
+ buffer
96
+ } ;
97
+ //read handle into buffer
98
+ self . cursor
99
+ . read_exact ( & mut buffer)
100
+ . map_or ( None , |_| Some ( buffer) )
96
101
}
97
- //the take above consumes the struct value
98
- //we add this back in with the takes into_inner method
99
- self . cursor = handle. into_inner ( ) ;
100
- } else {
101
- //read handle into buffer
102
- if self . cursor . read_to_end ( & mut buffer) . is_err ( ) {
103
- return None ;
102
+ None => {
103
+ let mut buffer = Vec :: new ( ) ;
104
+ //read handle into buffer
105
+ if self . cursor . read_to_end ( & mut buffer) . is_err ( ) {
106
+ None
107
+ } else {
108
+ Some ( buffer)
109
+ }
104
110
}
105
- } ;
106
-
107
- Some ( buffer)
111
+ }
108
112
}
109
113
110
114
fn tell ( & self ) -> u64 {
@@ -209,7 +213,7 @@ impl PyStringIORef {
209
213
//If k is undefined || k == -1, then we read all bytes until the end of the file.
210
214
//This also increments the stream position by the value of k
211
215
fn read ( self , bytes : OptionalOption < i64 > , vm : & VirtualMachine ) -> PyResult {
212
- let data = match self . buffer ( vm) ?. read ( byte_count ( bytes) ) {
216
+ let data = match self . buffer ( vm) ?. read ( bytes. flatten ( ) ) {
213
217
Some ( value) => value,
214
218
None => Vec :: new ( ) ,
215
219
} ;
@@ -263,11 +267,12 @@ fn string_io_new(
263
267
_args : StringIOArgs ,
264
268
vm : & VirtualMachine ,
265
269
) -> PyResult < PyStringIORef > {
266
- let flatten = object. flatten ( ) ;
267
- let input = flatten. map_or_else ( Vec :: new, |v| objstr:: borrow_value ( & v) . as_bytes ( ) . to_vec ( ) ) ;
270
+ let raw_bytes = object
271
+ . flatten ( )
272
+ . map_or_else ( Vec :: new, |v| objstr:: borrow_value ( & v) . as_bytes ( ) . to_vec ( ) ) ;
268
273
269
274
PyStringIO {
270
- buffer : PyRwLock :: new ( BufferedIO :: new ( Cursor :: new ( input ) ) ) ,
275
+ buffer : PyRwLock :: new ( BufferedIO :: new ( Cursor :: new ( raw_bytes ) ) ) ,
271
276
closed : AtomicCell :: new ( false ) ,
272
277
}
273
278
. into_ref_with_type ( vm, cls)
@@ -312,7 +317,7 @@ impl PyBytesIORef {
312
317
//If k is undefined || k == -1, then we read all bytes until the end of the file.
313
318
//This also increments the stream position by the value of k
314
319
fn read ( self , bytes : OptionalOption < i64 > , vm : & VirtualMachine ) -> PyResult {
315
- match self . buffer ( vm) ?. read ( byte_count ( bytes) ) {
320
+ match self . buffer ( vm) ?. read ( bytes. flatten ( ) ) {
316
321
Some ( value) => Ok ( vm. ctx . new_bytes ( value) ) ,
317
322
None => Err ( vm. new_value_error ( "Error Retrieving Value" . to_owned ( ) ) ) ,
318
323
}
@@ -363,10 +368,9 @@ fn bytes_io_new(
363
368
object : OptionalArg < Option < PyBytesRef > > ,
364
369
vm : & VirtualMachine ,
365
370
) -> PyResult < PyBytesIORef > {
366
- let raw_bytes = match object {
367
- OptionalArg :: Present ( Some ( ref input) ) => input. get_value ( ) . to_vec ( ) ,
368
- _ => vec ! [ ] ,
369
- } ;
371
+ let raw_bytes = object
372
+ . flatten ( )
373
+ . map_or_else ( Vec :: new, |input| input. get_value ( ) . to_vec ( ) ) ;
370
374
371
375
PyBytesIO {
372
376
buffer : PyRwLock :: new ( BufferedIO :: new ( Cursor :: new ( raw_bytes) ) ) ,
@@ -1446,7 +1450,7 @@ mod tests {
1446
1450
cursor : Cursor :: new ( data. clone ( ) ) ,
1447
1451
} ;
1448
1452
1449
- assert_eq ! ( buffered. read( bytes) . unwrap( ) , data) ;
1453
+ assert_eq ! ( buffered. read( Some ( bytes) ) . unwrap( ) , data) ;
1450
1454
}
1451
1455
1452
1456
#[ test]
@@ -1458,7 +1462,7 @@ mod tests {
1458
1462
} ;
1459
1463
1460
1464
assert_eq ! ( buffered. seek( SeekFrom :: Start ( count) ) . unwrap( ) , count) ;
1461
- assert_eq ! ( buffered. read( count . clone ( ) as i64 ) . unwrap( ) , vec![ 3 , 4 ] ) ;
1465
+ assert_eq ! ( buffered. read( Some ( count as i64 ) ) . unwrap( ) , vec![ 3 , 4 ] ) ;
1462
1466
}
1463
1467
1464
1468
#[ test]
0 commit comments