Skip to content

Commit e5d0a02

Browse files
Merge pull request RustPython#1119 from silmeth/text_file_io
make TextIOBase writable, handle malformed utf-8 in read()
2 parents 53ce362 + afec714 commit e5d0a02

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

vm/src/builtins.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,10 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
889889
"FileNotFoundError" => ctx.exceptions.file_not_found_error.clone(),
890890
"FileExistsError" => ctx.exceptions.file_exists_error.clone(),
891891
"StopIteration" => ctx.exceptions.stop_iteration.clone(),
892+
"UnicodeError" => ctx.exceptions.unicode_error.clone(),
893+
"UnicodeDecodeError" => ctx.exceptions.unicode_decode_error.clone(),
894+
"UnicodeEncodeError" => ctx.exceptions.unicode_encode_error.clone(),
895+
"UnicodeTranslateError" => ctx.exceptions.unicode_translate_error.clone(),
892896
"ZeroDivisionError" => ctx.exceptions.zero_division_error.clone(),
893897
"KeyError" => ctx.exceptions.key_error.clone(),
894898
"OSError" => ctx.exceptions.os_error.clone(),

vm/src/exceptions.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ pub struct ExceptionZoo {
215215
pub syntax_error: PyClassRef,
216216
pub type_error: PyClassRef,
217217
pub value_error: PyClassRef,
218+
pub unicode_error: PyClassRef,
219+
pub unicode_decode_error: PyClassRef,
220+
pub unicode_encode_error: PyClassRef,
221+
pub unicode_translate_error: PyClassRef,
218222
pub zero_division_error: PyClassRef,
219223
pub eof_error: PyClassRef,
220224

@@ -258,6 +262,11 @@ impl ExceptionZoo {
258262
let permission_error = create_type("PermissionError", &type_type, &os_error);
259263
let file_exists_error = create_type("FileExistsError", &type_type, &os_error);
260264
let eof_error = create_type("EOFError", &type_type, &exception_type);
265+
let unicode_error = create_type("UnicodeError", &type_type, &value_error);
266+
let unicode_decode_error = create_type("UnicodeDecodeError", &type_type, &unicode_error);
267+
let unicode_encode_error = create_type("UnicodeEncodeError", &type_type, &unicode_error);
268+
let unicode_translate_error =
269+
create_type("UnicodeTranslateError", &type_type, &unicode_error);
261270

262271
let warning = create_type("Warning", &type_type, &exception_type);
263272
let bytes_warning = create_type("BytesWarning", &type_type, &warning);
@@ -294,6 +303,10 @@ impl ExceptionZoo {
294303
syntax_error,
295304
type_error,
296305
value_error,
306+
unicode_error,
307+
unicode_decode_error,
308+
unicode_encode_error,
309+
unicode_translate_error,
297310
zero_division_error,
298311
eof_error,
299312
warning,

vm/src/stdlib/io.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,19 +442,62 @@ fn text_io_wrapper_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
442442
fn text_io_base_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
443443
arg_check!(vm, args, required = [(text_io_base, None)]);
444444

445+
let buffered_reader_class = vm.try_class("_io", "BufferedReader")?;
445446
let raw = vm.get_attribute(text_io_base.clone(), "buffer").unwrap();
446447

448+
if !objtype::isinstance(&raw, &buffered_reader_class) {
449+
// TODO: this should be io.UnsupportedOperation error which derives both from ValueError *and* OSError
450+
return Err(vm.new_value_error("not readable".to_string()));
451+
}
452+
447453
if let Ok(bytes) = vm.call_method(&raw, "read", PyFuncArgs::default()) {
448454
let value = objbytes::get_value(&bytes).to_vec();
449455

450456
//format bytes into string
451-
let rust_string = String::from_utf8(value).unwrap();
457+
let rust_string = String::from_utf8(value).map_err(|e| {
458+
vm.new_unicode_decode_error(format!(
459+
"cannot decode byte at index: {}",
460+
e.utf8_error().valid_up_to()
461+
))
462+
})?;
452463
Ok(vm.ctx.new_str(rust_string))
453464
} else {
454465
Err(vm.new_value_error("Error unpacking Bytes".to_string()))
455466
}
456467
}
457468

469+
fn text_io_base_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
470+
use std::str::from_utf8;
471+
472+
arg_check!(
473+
vm,
474+
args,
475+
required = [(text_io_base, None), (obj, Some(vm.ctx.str_type()))]
476+
);
477+
478+
let buffered_writer_class = vm.try_class("_io", "BufferedWriter")?;
479+
let raw = vm.get_attribute(text_io_base.clone(), "buffer").unwrap();
480+
481+
if !objtype::isinstance(&raw, &buffered_writer_class) {
482+
// TODO: this should be io.UnsupportedOperation error which derives from ValueError and OSError
483+
return Err(vm.new_value_error("not writable".to_string()));
484+
}
485+
486+
let bytes = objstr::get_value(obj).into_bytes();
487+
488+
let len = vm.call_method(&raw, "write", vec![vm.ctx.new_bytes(bytes.clone())])?;
489+
let len = objint::get_value(&len).to_usize().ok_or_else(|| {
490+
vm.new_overflow_error("int to large to convert to Rust usize".to_string())
491+
})?;
492+
493+
// returns the count of unicode code points written
494+
let len = from_utf8(&bytes[..len])
495+
.unwrap_or_else(|e| from_utf8(&bytes[..e.valid_up_to()]).unwrap())
496+
.chars()
497+
.count();
498+
Ok(vm.ctx.new_int(len))
499+
}
500+
458501
fn split_mode_string(mode_string: String) -> Result<(String, String), String> {
459502
let mut mode: char = '\0';
460503
let mut typ: char = '\0';
@@ -594,7 +637,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
594637

595638
//TextIO Base has no public constructor
596639
let text_io_base = py_class!(ctx, "TextIOBase", io_base.clone(), {
597-
"read" => ctx.new_rustfunc(text_io_base_read)
640+
"read" => ctx.new_rustfunc(text_io_base_read),
641+
"write" => ctx.new_rustfunc(text_io_base_write)
598642
});
599643

600644
// RawBaseIO Subclasses

vm/src/vm.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ impl VirtualMachine {
219219
self.new_exception(os_error, msg)
220220
}
221221

222+
/// Create a new python UnicodeDecodeError object carrying `msg`. Useful for
/// raising decoding errors from python functions implemented in rust.
pub fn new_unicode_decode_error(&self, msg: String) -> PyObjectRef {
    self.new_exception(self.ctx.exceptions.unicode_decode_error.clone(), msg)
}
226+
222227
/// Create a new python ValueError object. Useful for raising errors from
223228
/// python functions implemented in rust.
224229
pub fn new_value_error(&self, msg: String) -> PyObjectRef {

0 commit comments

Comments
 (0)