Skip to content

Commit c0e1702

Browse files
committed
split mode/type for io.open properly, fix open('r') case
1 parent 1b7088c commit c0e1702

File tree

2 files changed

+189
-52
lines changed

2 files changed

+189
-52
lines changed

tests/snippets/builtin_open.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,14 @@
66
assert_raises(FileNotFoundError, lambda: open('DoesNotExist'))

# Use open as a context manager.
# The type flag in the mode decides what read() hands back: text modes
# (including a bare 'r', which defaults to text) yield str, binary yields bytes.
for mode, expected_type in [('rt', str), ('r', str), ('rb', bytes)]:
    with open('README.md', mode) as fp:
        contents = fp.read()
        assert type(contents) == expected_type, "type is " + str(type(contents))

vm/src/stdlib/io.rs

Lines changed: 178 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
/*
22
* I/O core tools.
33
*/
4-
54
use std::cell::RefCell;
6-
use std::collections::HashSet;
75
use std::fs::File;
86
use std::io::prelude::*;
97
use std::io::BufReader;
@@ -294,76 +292,119 @@ fn text_io_base_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
294292
}
295293
}
296294

295+
/// Splits a Python `open()` mode string into its access mode and its type flag.
///
/// On success the first element is the access mode: exactly one of `r`, `w`,
/// `a` or `x`, with `+` appended when the input contained a plus (wherever it
/// appeared, so `"r+b"` and `"rb+"` both give `"r+"`). The second element is
/// the type flag, `"t"` or `"b"`; it defaults to `"t"` when the input names
/// neither.
///
/// # Errors
///
/// Returns the message to report to the user when:
/// * a character is not one of `r`, `w`, `a`, `x`, `t`, `b`, `+`, or any of
///   them is repeated (`invalid mode: '...'`),
/// * both `t` and `b` are given,
/// * more than one access mode is given,
/// * no access mode is given at all.
fn split_mode_string(mode_string: String) -> Result<(String, String), String> {
    let invalid_mode = || format!("invalid mode: '{}'", mode_string);

    // `None` means "not seen yet"; this avoids reserving a sentinel character.
    let mut mode: Option<char> = None;
    let mut typ: Option<char> = None;
    let mut plus_is_set = false;

    for ch in mode_string.chars() {
        match ch {
            '+' => {
                if plus_is_set {
                    return Err(invalid_mode());
                }
                plus_is_set = true;
            }
            't' | 'b' => match typ {
                None => typ = Some(ch),
                // no duplicates allowed
                Some(seen) if seen == ch => return Err(invalid_mode()),
                Some(_) => return Err("can't have text and binary mode at once".to_string()),
            },
            // 'x' (exclusive creation) is the "create" mode the error messages
            // below refer to, so it has to be accepted alongside the others.
            'a' | 'r' | 'w' | 'x' => match mode {
                None => mode = Some(ch),
                // no duplicates allowed
                Some(seen) if seen == ch => return Err(invalid_mode()),
                Some(_) => {
                    return Err(
                        "must have exactly one of create/read/write/append mode".to_string()
                    );
                }
            },
            _ => return Err(invalid_mode()),
        }
    }

    let mut mode = mode
        .ok_or_else(|| {
            "Must have exactly one of create/read/write/append mode and at most one plus"
                .to_string()
        })?
        .to_string();
    if plus_is_set {
        mode.push('+');
    }

    // Text is the default when neither 't' nor 'b' was requested.
    Ok((mode, typ.unwrap_or('t').to_string()))
}
351+
297352
pub fn io_open(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
298353
arg_check!(
299354
vm,
300355
args,
301356
required = [(file, Some(vm.ctx.str_type()))],
302357
optional = [(mode, Some(vm.ctx.str_type()))]
303358
);
359+
// mode is optional: 'rt' is the default mode (open from reading text)
360+
let mode_string = mode.map_or("rt".to_string(), |m| objstr::get_value(m));
361+
let (mode, typ) = match split_mode_string(mode_string) {
362+
Ok((mode, typ)) => (mode, typ),
363+
Err(error_message) => {
364+
return Err(vm.new_value_error(error_message));
365+
}
366+
};
304367

305368
let module = import::import_module(vm, PathBuf::default(), "io").unwrap();
306369

307-
//mode is optional: 'rt' is the default mode (open from reading text)
308-
let rust_mode = mode.map_or("rt".to_string(), |m| objstr::get_value(m));
309-
310-
let mut raw_modes = HashSet::new();
311-
312-
//add raw modes
313-
raw_modes.insert("a".to_string());
314-
raw_modes.insert("r".to_string());
315-
raw_modes.insert("x".to_string());
316-
raw_modes.insert("w".to_string());
317-
318-
//This is not a terribly elegant way to separate the file mode from
319-
//the "type" flag - this should be improved. The intention here is to
320-
//match a valid flag for the file_io_init call:
321-
//https://docs.python.org/3/library/io.html#io.FileIO
322-
let modes: Vec<char> = rust_mode
323-
.chars()
324-
.filter(|a| raw_modes.contains(&a.to_string()))
325-
.collect();
326-
327-
if modes.is_empty() || modes.len() > 1 {
328-
return Err(vm.new_value_error("Invalid Mode".to_string()));
329-
}
330-
331-
//Class objects (potentially) consumed by io.open
332-
//RawIO: FileIO
333-
//Buffered: BufferedWriter, BufferedReader
334-
//Text: TextIOWrapper
370+
// Class objects (potentially) consumed by io.open
371+
// RawIO: FileIO
372+
// Buffered: BufferedWriter, BufferedReader
373+
// Text: TextIOWrapper
335374
let file_io_class = vm.get_attribute(module.clone(), "FileIO").unwrap();
336375
let buffered_writer_class = vm.get_attribute(module.clone(), "BufferedWriter").unwrap();
337376
let buffered_reader_class = vm.get_attribute(module.clone(), "BufferedReader").unwrap();
338377
let text_io_wrapper_class = vm.get_attribute(module, "TextIOWrapper").unwrap();
339378

340-
//Construct a FileIO (subclass of RawIOBase)
341-
//This is subsequently consumed by a Buffered Class.
342-
let file_args = vec![file.clone(), vm.ctx.new_str(modes[0].to_string())];
343-
let file_io = vm.invoke(file_io_class, file_args)?;
379+
// Construct a FileIO (subclass of RawIOBase)
380+
// This is subsequently consumed by a Buffered Class.
381+
let file_args = vec![file.clone(), vm.ctx.new_str(mode.clone())];
382+
let file_io_obj = vm.invoke(file_io_class, file_args)?;
344383

345-
//Create Buffered class to consume FileIO. The type of buffered class depends on
346-
//the operation in the mode.
347-
//There are 3 possible classes here, each inheriting from the RawBaseIO
384+
// Create Buffered class to consume FileIO. The type of buffered class depends on
385+
// the operation in the mode.
386+
// There are 3 possible classes here, each inheriting from the RawBaseIO
348387
// creating || writing || appending => BufferedWriter
349-
let buffered = if rust_mode.contains('w') {
350-
vm.invoke(buffered_writer_class, vec![file_io.clone()])
351-
// reading => BufferedReader
352-
} else {
353-
vm.invoke(buffered_reader_class, vec![file_io.clone()])
388+
let buffered = match mode.chars().next().unwrap() {
389+
'w' => vm.invoke(buffered_writer_class, vec![file_io_obj.clone()]),
390+
'r' => vm.invoke(buffered_reader_class, vec![file_io_obj.clone()]),
354391
//TODO: updating => PyBufferedRandom
392+
_ => unimplemented!("'a' mode is not yet implemented"),
355393
};
356394

357-
if rust_mode.contains('t') {
358-
//If the mode is text this buffer type is consumed on construction of
359-
//a TextIOWrapper which is subsequently returned.
360-
vm.invoke(text_io_wrapper_class, vec![buffered.unwrap()])
361-
} else {
395+
let io_obj = match typ.chars().next().unwrap() {
396+
// If the mode is text this buffer type is consumed on construction of
397+
// a TextIOWrapper which is subsequently returned.
398+
't' => vm.invoke(text_io_wrapper_class, vec![buffered.unwrap()]),
399+
362400
// If the mode is binary this Buffered class is returned directly at
363401
// this point.
364-
//For Buffered class construct "raw" IO class e.g. FileIO and pass this into corresponding field
365-
buffered
366-
}
402+
// For Buffered class construct "raw" IO class e.g. FileIO and pass this into corresponding field
403+
'b' => buffered,
404+
405+
_ => unreachable!(),
406+
};
407+
io_obj
367408
}
368409

369410
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
@@ -439,3 +480,90 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
439480
"BytesIO" => bytes_io,
440481
})
441482
}
483+
484+
#[cfg(test)]
mod tests {
    use super::*;

    /// Message expected whenever the mode string names no access mode.
    const NO_ACCESS_MODE: &str =
        "Must have exactly one of create/read/write/append mode and at most one plus";

    /// Checks that `input` parses successfully into `(mode, typ)`.
    fn expect_split(input: &str, mode: &str, typ: &str) {
        assert_eq!(
            split_mode_string(input.to_string()),
            Ok((mode.to_string(), typ.to_string()))
        );
    }

    /// Checks that `input` is rejected with exactly `message`.
    fn expect_error(input: &str, message: &str) {
        assert_eq!(
            split_mode_string(input.to_string()),
            Err(message.to_string())
        );
    }

    #[test]
    fn test_split_mode_valid_cases() {
        // (input, expected access mode, expected type flag)
        let cases = [
            ("r", "r", "t"),
            ("rb", "r", "b"),
            ("rt", "r", "t"),
            ("r+t", "r+", "t"),
            ("w+t", "w+", "t"),
            ("r+b", "r+", "b"),
            ("w+b", "w+", "b"),
        ];
        for (input, mode, typ) in cases.iter() {
            expect_split(input, mode, typ);
        }
    }

    #[test]
    fn test_invalid_mode() {
        // Unknown characters and repeated flags are both plain "invalid mode".
        expect_error("rbsss", "invalid mode: 'rbsss'");
        expect_error("rrb", "invalid mode: 'rrb'");
        expect_error("rbb", "invalid mode: 'rbb'");
    }

    #[test]
    fn test_mode_not_specified() {
        // A type flag on its own does not count as an access mode.
        for input in ["", "b", "t"].iter() {
            expect_error(input, NO_ACCESS_MODE);
        }
    }

    #[test]
    fn test_text_and_binary_at_once() {
        expect_error("rbt", "can't have text and binary mode at once");
    }

    #[test]
    fn test_exactly_one_mode() {
        expect_error(
            "rwb",
            "must have exactly one of create/read/write/append mode",
        );
    }

    #[test]
    fn test_at_most_one_plus() {
        expect_error("a++", "invalid mode: 'a++'");
    }
}

0 commit comments

Comments
 (0)