@@ -17,7 +17,7 @@ use crate::obj::objbytes::PyBytes;
17
17
use crate :: obj:: objint;
18
18
use crate :: obj:: objstr;
19
19
use crate :: obj:: objtype;
20
- use crate :: obj:: objtype:: PyClassRef ;
20
+ use crate :: obj:: objtype:: { PyClass , PyClassRef } ;
21
21
use crate :: pyobject:: TypeProtocol ;
22
22
use crate :: pyobject:: { BufferProtocol , PyObjectRef , PyRef , PyResult , PyValue } ;
23
23
use crate :: vm:: VirtualMachine ;
@@ -442,16 +442,71 @@ fn text_io_wrapper_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
442
442
fn text_io_base_read ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
443
443
arg_check ! ( vm, args, required = [ ( text_io_base, None ) ] ) ;
444
444
445
+ let io_module = vm. import ( "_io" , & vm. ctx . new_tuple ( vec ! [ ] ) , 0 ) ?;
446
+ let buffered_reader_class = vm
447
+ . get_attribute ( io_module. clone ( ) , "BufferedReader" )
448
+ . unwrap ( )
449
+ . downcast :: < PyClass > ( )
450
+ . unwrap ( ) ;
445
451
let raw = vm. get_attribute ( text_io_base. clone ( ) , "buffer" ) . unwrap ( ) ;
446
452
447
- if let Ok ( bytes) = vm. call_method ( & raw , "read" , PyFuncArgs :: default ( ) ) {
448
- let value = objbytes:: get_value ( & bytes) . to_vec ( ) ;
453
+ if objtype:: isinstance ( & raw , & buffered_reader_class) {
454
+ if let Ok ( bytes) = vm. call_method ( & raw , "read" , PyFuncArgs :: default ( ) ) {
455
+ let value = objbytes:: get_value ( & bytes) . to_vec ( ) ;
456
+
457
+ //format bytes into string
458
+ let rust_string = String :: from_utf8 ( value) . map_err ( |e| {
459
+ vm. new_unicode_decode_error ( format ! (
460
+ "cannot decode byte at index: {}" ,
461
+ e. utf8_error( ) . valid_up_to( )
462
+ ) )
463
+ } ) ?;
464
+ Ok ( vm. ctx . new_str ( rust_string) )
465
+ } else {
466
+ Err ( vm. new_value_error ( "Error unpacking Bytes" . to_string ( ) ) )
467
+ }
468
+ } else {
469
+ // TODO: this should be io.UnsupportedOperation error which derives both from ValueError *and* OSError
470
+ Err ( vm. new_value_error ( "not readable" . to_string ( ) ) )
471
+ }
472
+ }
449
473
450
- //format bytes into string
451
- let rust_string = String :: from_utf8 ( value) . unwrap ( ) ;
452
- Ok ( vm. ctx . new_str ( rust_string) )
474
+ fn text_io_base_write ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
475
+ arg_check ! (
476
+ vm,
477
+ args,
478
+ required = [ ( text_io_base, None ) , ( obj, Some ( vm. ctx. str_type( ) ) ) ]
479
+ ) ;
480
+
481
+ let io_module = vm. import ( "_io" , & vm. ctx . new_tuple ( vec ! [ ] ) , 0 ) ?;
482
+ let buffered_writer_class = vm
483
+ . get_attribute ( io_module. clone ( ) , "BufferedWriter" )
484
+ . unwrap ( )
485
+ . downcast :: < PyClass > ( )
486
+ . unwrap ( ) ;
487
+ let raw = vm. get_attribute ( text_io_base. clone ( ) , "buffer" ) . unwrap ( ) ;
488
+ if objtype:: isinstance ( & raw , & buffered_writer_class) {
489
+ let write = vm
490
+ . get_method ( raw. clone ( ) , "write" )
491
+ . ok_or_else ( || vm. new_attribute_error ( "BufferedWriter has no write method" . to_owned ( ) ) )
492
+ . and_then ( |it| it) ?;
493
+ let bytes = objstr:: get_value ( obj) . into_bytes ( ) ;
494
+
495
+ let len = vm. invoke (
496
+ write,
497
+ PyFuncArgs :: new ( vec ! [ vm. ctx. new_bytes( bytes. clone( ) ) ] , vec ! [ ] ) ,
498
+ ) ?;
499
+ let len = objint:: get_value ( & len) . to_usize ( ) . ok_or_else ( || {
500
+ vm. new_overflow_error ( "int to large to convert to Rust usize" . to_string ( ) )
501
+ } ) ?;
502
+
503
+ // returns the count of unicode code points written
504
+ Ok ( vm
505
+ . ctx
506
+ . new_int ( String :: from_utf8_lossy ( & bytes[ 0 ..len] ) . chars ( ) . count ( ) ) )
453
507
} else {
454
- Err ( vm. new_value_error ( "Error unpacking Bytes" . to_string ( ) ) )
508
+ // TODO: this should be io.UnsupportedOperation error which derives from ValueError and OSError
509
+ Err ( vm. new_value_error ( "not writable" . to_string ( ) ) )
455
510
}
456
511
}
457
512
@@ -594,7 +649,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
594
649
595
650
//TextIO Base has no public constructor
596
651
let text_io_base = py_class ! ( ctx, "TextIOBase" , io_base. clone( ) , {
597
- "read" => ctx. new_rustfunc( text_io_base_read)
652
+ "read" => ctx. new_rustfunc( text_io_base_read) ,
653
+ "write" => ctx. new_rustfunc( text_io_base_write)
598
654
} ) ;
599
655
600
656
// RawBaseIO Subclasses
0 commit comments