Skip to content

Commit da39dde

Browse files
authored
Merge pull request RustPython#1062 from palaviv/file-io-fd
Improve FileIO
2 parents 0642412 + 8cdf19c commit da39dde

File tree

2 files changed

+68
-62
lines changed

2 files changed

+68
-62
lines changed

tests/snippets/stdlib_io.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from io import BufferedReader, FileIO
2+
import os
23

34
fi = FileIO('README.md')
45
bb = BufferedReader(fi)
@@ -14,3 +15,9 @@
1415
assert len(result) <= 8*1024
1516
assert len(result) >= 0
1617
assert isinstance(result, bytes)
18+
19+
fd = os.open('README.md', os.O_RDONLY)
20+
21+
with FileIO(fd) as fio:
22+
res2 = fio.read()
23+
assert res == res2

vm/src/stdlib/io.rs

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
* I/O core tools.
33
*/
44
use std::cell::RefCell;
5-
use std::fs::File;
65
use std::io::prelude::*;
7-
use std::io::BufReader;
86
use std::io::Cursor;
97
use std::io::SeekFrom;
108

@@ -18,9 +16,12 @@ use crate::function::{OptionalArg, PyFuncArgs};
1816
use crate::import;
1917
use crate::obj::objbytearray::PyByteArray;
2018
use crate::obj::objbytes;
19+
use crate::obj::objbytes::PyBytes;
2120
use crate::obj::objint;
2221
use crate::obj::objstr;
22+
use crate::obj::objtype;
2323
use crate::obj::objtype::PyClassRef;
24+
use crate::pyobject::TypeProtocol;
2425
use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue};
2526
use crate::vm::VirtualMachine;
2627

@@ -284,16 +285,19 @@ fn buffered_reader_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
284285
}
285286

286287
fn compute_c_flag(mode: &str) -> u32 {
287-
let flags = match mode {
288-
"w" => os::FileCreationFlags::O_WRONLY | os::FileCreationFlags::O_CREAT,
289-
"x" => {
290-
os::FileCreationFlags::O_WRONLY
291-
| os::FileCreationFlags::O_CREAT
292-
| os::FileCreationFlags::O_EXCL
293-
}
294-
"a" => os::FileCreationFlags::O_APPEND,
295-
"+" => os::FileCreationFlags::O_RDWR,
296-
_ => os::FileCreationFlags::O_RDONLY,
288+
let flags = match mode.chars().next() {
289+
Some(mode) => match mode {
290+
'w' => os::FileCreationFlags::O_WRONLY | os::FileCreationFlags::O_CREAT,
291+
'x' => {
292+
os::FileCreationFlags::O_WRONLY
293+
| os::FileCreationFlags::O_CREAT
294+
| os::FileCreationFlags::O_EXCL
295+
}
296+
'a' => os::FileCreationFlags::O_APPEND,
297+
'+' => os::FileCreationFlags::O_RDWR,
298+
_ => os::FileCreationFlags::O_RDONLY,
299+
},
300+
None => os::FileCreationFlags::O_RDONLY,
297301
};
298302
flags.bits()
299303
}
@@ -302,47 +306,43 @@ fn file_io_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
302306
arg_check!(
303307
vm,
304308
args,
305-
required = [(file_io, None), (name, Some(vm.ctx.str_type()))],
309+
required = [(file_io, None), (name, None)],
306310
optional = [(mode, Some(vm.ctx.str_type()))]
307311
);
308312

309-
let rust_mode = mode.map_or("r".to_string(), objstr::get_value);
310-
311-
match compute_c_flag(&rust_mode).to_bigint() {
312-
Some(os_mode) => {
313-
let args = vec![name.clone(), vm.ctx.new_int(os_mode)];
314-
let file_no = os::os_open(vm, PyFuncArgs::new(args, vec![]))?;
315-
316-
vm.set_attr(file_io, "name", name.clone())?;
317-
vm.set_attr(file_io, "fileno", file_no)?;
318-
vm.set_attr(file_io, "closefd", vm.new_bool(false))?;
319-
vm.set_attr(file_io, "closed", vm.new_bool(false))?;
313+
let file_no = if objtype::isinstance(&name, &vm.ctx.str_type()) {
314+
let rust_mode = mode.map_or("r".to_string(), objstr::get_value);
315+
let args = vec![
316+
name.clone(),
317+
vm.ctx
318+
.new_int(compute_c_flag(&rust_mode).to_bigint().unwrap()),
319+
];
320+
os::os_open(vm, PyFuncArgs::new(args, vec![]))?
321+
} else if objtype::isinstance(&name, &vm.ctx.int_type()) {
322+
name.clone()
323+
} else {
324+
return Err(vm.new_type_error("name parameter must be string or int".to_string()));
325+
};
320326

321-
Ok(vm.get_none())
322-
}
323-
None => Err(vm.new_type_error(format!("invalid mode {}", rust_mode))),
324-
}
327+
vm.set_attr(file_io, "name", name.clone())?;
328+
vm.set_attr(file_io, "fileno", file_no)?;
329+
vm.set_attr(file_io, "closefd", vm.new_bool(false))?;
330+
vm.set_attr(file_io, "closed", vm.new_bool(false))?;
331+
Ok(vm.get_none())
325332
}
326333

327334
fn file_io_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
328335
arg_check!(vm, args, required = [(file_io, None)]);
329-
let py_name = vm.get_attribute(file_io.clone(), "name")?;
330-
let f = match File::open(objstr::get_value(&py_name)) {
331-
Ok(v) => Ok(v),
332-
Err(_) => Err(vm.new_type_error("Error opening file".to_string())),
333-
};
334336

335-
let buffer = match f {
336-
Ok(v) => Ok(BufReader::new(v)),
337-
Err(_) => Err(vm.new_type_error("Error reading from file".to_string())),
338-
};
337+
let file_no = vm.get_attribute(file_io.clone(), "fileno")?;
338+
let raw_fd = objint::get_value(&file_no).to_i64().unwrap();
339+
340+
let mut handle = os::rust_file(raw_fd);
339341

340342
let mut bytes = vec![];
341-
if let Ok(mut buff) = buffer {
342-
match buff.read_to_end(&mut bytes) {
343-
Ok(_) => {}
344-
Err(_) => return Err(vm.new_value_error("Error reading from Buffer".to_string())),
345-
}
343+
match handle.read_to_end(&mut bytes) {
344+
Ok(_) => {}
345+
Err(_) => return Err(vm.new_value_error("Error reading from Buffer".to_string())),
346346
}
347347

348348
Ok(vm.ctx.new_bytes(bytes))
@@ -385,11 +385,7 @@ fn file_io_readinto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
385385
}
386386

387387
fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
388-
arg_check!(
389-
vm,
390-
args,
391-
required = [(file_io, None), (obj, Some(vm.ctx.bytes_type()))]
392-
);
388+
arg_check!(vm, args, required = [(file_io, None), (obj, None)]);
393389

394390
let file_no = vm.get_attribute(file_io.clone(), "fileno")?;
395391
let raw_fd = objint::get_value(&file_no).to_i64().unwrap();
@@ -399,22 +395,25 @@ fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
399395
//to support windows - i.e. raw file_handles
400396
let mut handle = os::rust_file(raw_fd);
401397

402-
match obj.payload::<PyByteArray>() {
403-
Some(bytes) => {
404-
let value_mut = &mut bytes.inner.borrow_mut().elements;
405-
match handle.write(&value_mut[..]) {
406-
Ok(len) => {
407-
//reset raw fd on the FileIO object
408-
let updated = os::raw_file_number(handle);
409-
vm.set_attr(file_io, "fileno", vm.ctx.new_int(updated))?;
410-
411-
//return number of bytes written
412-
Ok(vm.ctx.new_int(len))
413-
}
414-
Err(_) => Err(vm.new_value_error("Error Writing Bytes to Handle".to_string())),
415-
}
398+
let bytes = match_class!(obj.clone(),
399+
i @ PyBytes => Ok(i.get_value().to_vec()),
400+
j @ PyByteArray => Ok(j.inner.borrow().elements.to_vec()),
401+
obj => Err(vm.new_type_error(format!(
402+
"a bytes-like object is required, not {}",
403+
obj.class()
404+
)))
405+
);
406+
407+
match handle.write(&bytes?) {
408+
Ok(len) => {
409+
//reset raw fd on the FileIO object
410+
let updated = os::raw_file_number(handle);
411+
vm.set_attr(file_io, "fileno", vm.ctx.new_int(updated))?;
412+
413+
//return number of bytes written
414+
Ok(vm.ctx.new_int(len))
416415
}
417-
None => Err(vm.new_value_error("Expected Bytes Object".to_string())),
416+
Err(_) => Err(vm.new_value_error("Error Writing Bytes to Handle".to_string())),
418417
}
419418
}
420419

0 commit comments

Comments
 (0)