Skip to content

Commit 03b7299

Browse files
committed
make TextIOBase writable, handle malformed utf-8 in read()
1 parent fabc260 commit 03b7299

File tree

4 files changed

+86
-8
lines changed

4 files changed

+86
-8
lines changed

vm/src/builtins.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,10 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
889889
"FileNotFoundError" => ctx.exceptions.file_not_found_error.clone(),
890890
"FileExistsError" => ctx.exceptions.file_exists_error.clone(),
891891
"StopIteration" => ctx.exceptions.stop_iteration.clone(),
892+
"UnicodeError" => ctx.exceptions.unicode_error.clone(),
893+
"UnicodeDecodeError" => ctx.exceptions.unicode_decode_error.clone(),
894+
"UnicodeEncodeError" => ctx.exceptions.unicode_encode_error.clone(),
895+
"UnicodeTranslateError" => ctx.exceptions.unicode_translate_error.clone(),
892896
"ZeroDivisionError" => ctx.exceptions.zero_division_error.clone(),
893897
"KeyError" => ctx.exceptions.key_error.clone(),
894898
"OSError" => ctx.exceptions.os_error.clone(),

vm/src/exceptions.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ pub struct ExceptionZoo {
215215
pub syntax_error: PyClassRef,
216216
pub type_error: PyClassRef,
217217
pub value_error: PyClassRef,
218+
pub unicode_error: PyClassRef,
219+
pub unicode_decode_error: PyClassRef,
220+
pub unicode_encode_error: PyClassRef,
221+
pub unicode_translate_error: PyClassRef,
218222
pub zero_division_error: PyClassRef,
219223
pub eof_error: PyClassRef,
220224

@@ -258,6 +262,11 @@ impl ExceptionZoo {
258262
let permission_error = create_type("PermissionError", &type_type, &os_error);
259263
let file_exists_error = create_type("FileExistsError", &type_type, &os_error);
260264
let eof_error = create_type("EOFError", &type_type, &exception_type);
265+
let unicode_error = create_type("UnicodeError", &type_type, &value_error);
266+
let unicode_decode_error = create_type("UnicodeDecodeError", &type_type, &unicode_error);
267+
let unicode_encode_error = create_type("UnicodeEncodeError", &type_type, &unicode_error);
268+
let unicode_translate_error =
269+
create_type("UnicodeTranslateError", &type_type, &unicode_error);
261270

262271
let warning = create_type("Warning", &type_type, &exception_type);
263272
let bytes_warning = create_type("BytesWarning", &type_type, &warning);
@@ -294,6 +303,10 @@ impl ExceptionZoo {
294303
syntax_error,
295304
type_error,
296305
value_error,
306+
unicode_error,
307+
unicode_decode_error,
308+
unicode_encode_error,
309+
unicode_translate_error,
297310
zero_division_error,
298311
eof_error,
299312
warning,

vm/src/stdlib/io.rs

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::obj::objbytes::PyBytes;
1717
use crate::obj::objint;
1818
use crate::obj::objstr;
1919
use crate::obj::objtype;
20-
use crate::obj::objtype::PyClassRef;
20+
use crate::obj::objtype::{PyClass, PyClassRef};
2121
use crate::pyobject::TypeProtocol;
2222
use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue};
2323
use crate::vm::VirtualMachine;
@@ -442,16 +442,71 @@ fn text_io_wrapper_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
442442
fn text_io_base_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
443443
arg_check!(vm, args, required = [(text_io_base, None)]);
444444

445+
let io_module = vm.import("_io", &vm.ctx.new_tuple(vec![]), 0)?;
446+
let buffered_reader_class = vm
447+
.get_attribute(io_module.clone(), "BufferedReader")
448+
.unwrap()
449+
.downcast::<PyClass>()
450+
.unwrap();
445451
let raw = vm.get_attribute(text_io_base.clone(), "buffer").unwrap();
446452

447-
if let Ok(bytes) = vm.call_method(&raw, "read", PyFuncArgs::default()) {
448-
let value = objbytes::get_value(&bytes).to_vec();
453+
if objtype::isinstance(&raw, &buffered_reader_class) {
454+
if let Ok(bytes) = vm.call_method(&raw, "read", PyFuncArgs::default()) {
455+
let value = objbytes::get_value(&bytes).to_vec();
456+
457+
//format bytes into string
458+
let rust_string = String::from_utf8(value).map_err(|e| {
459+
vm.new_unicode_decode_error(format!(
460+
"cannot decode byte at index: {}",
461+
e.utf8_error().valid_up_to()
462+
))
463+
})?;
464+
Ok(vm.ctx.new_str(rust_string))
465+
} else {
466+
Err(vm.new_value_error("Error unpacking Bytes".to_string()))
467+
}
468+
} else {
469+
// TODO: this should be io.UnsupportedOperation error which derives both from ValueError *and* OSError
470+
Err(vm.new_value_error("not readable".to_string()))
471+
}
472+
}
449473

450-
//format bytes into string
451-
let rust_string = String::from_utf8(value).unwrap();
452-
Ok(vm.ctx.new_str(rust_string))
474+
fn text_io_base_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
475+
arg_check!(
476+
vm,
477+
args,
478+
required = [(text_io_base, None), (obj, Some(vm.ctx.str_type()))]
479+
);
480+
481+
let io_module = vm.import("_io", &vm.ctx.new_tuple(vec![]), 0)?;
482+
let buffered_writer_class = vm
483+
.get_attribute(io_module.clone(), "BufferedWriter")
484+
.unwrap()
485+
.downcast::<PyClass>()
486+
.unwrap();
487+
let raw = vm.get_attribute(text_io_base.clone(), "buffer").unwrap();
488+
if objtype::isinstance(&raw, &buffered_writer_class) {
489+
let write = vm
490+
.get_method(raw.clone(), "write")
491+
.ok_or_else(|| vm.new_attribute_error("BufferedWriter has no write method".to_owned()))
492+
.and_then(|it| it)?;
493+
let bytes = objstr::get_value(obj).into_bytes();
494+
495+
let len = vm.invoke(
496+
write,
497+
PyFuncArgs::new(vec![vm.ctx.new_bytes(bytes.clone())], vec![]),
498+
)?;
499+
let len = objint::get_value(&len).to_usize().ok_or_else(|| {
500+
vm.new_overflow_error("int to large to convert to Rust usize".to_string())
501+
})?;
502+
503+
// returns the count of unicode code points written
504+
Ok(vm
505+
.ctx
506+
.new_int(String::from_utf8_lossy(&bytes[0..len]).chars().count()))
453507
} else {
454-
Err(vm.new_value_error("Error unpacking Bytes".to_string()))
508+
// TODO: this should be io.UnsupportedOperation error which derives from ValueError and OSError
509+
Err(vm.new_value_error("not writable".to_string()))
455510
}
456511
}
457512

@@ -594,7 +649,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
594649

595650
//TextIO Base has no public constructor
596651
let text_io_base = py_class!(ctx, "TextIOBase", io_base.clone(), {
597-
"read" => ctx.new_rustfunc(text_io_base_read)
652+
"read" => ctx.new_rustfunc(text_io_base_read),
653+
"write" => ctx.new_rustfunc(text_io_base_write)
598654
});
599655

600656
// RawBaseIO Subclasses

vm/src/vm.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ impl VirtualMachine {
218218
self.new_exception(os_error, msg)
219219
}
220220

221+
pub fn new_unicode_decode_error(&self, msg: String) -> PyObjectRef {
222+
let unicode_decode_error = self.ctx.exceptions.unicode_decode_error.clone();
223+
self.new_exception(unicode_decode_error, msg)
224+
}
225+
221226
/// Create a new python ValueError object. Useful for raising errors from
222227
/// python functions implemented in rust.
223228
pub fn new_value_error(&self, msg: String) -> PyObjectRef {

0 commit comments

Comments
 (0)