@@ -11,7 +11,7 @@ use rustpython_vm::{
11
11
import,
12
12
obj:: objstr:: PyStringRef ,
13
13
print_exception,
14
- pyobject:: { ItemProtocol , PyResult } ,
14
+ pyobject:: { ItemProtocol , PyObjectRef , PyResult } ,
15
15
scope:: Scope ,
16
16
util, PySettings , VirtualMachine ,
17
17
} ;
@@ -434,40 +434,32 @@ fn test_run_script() {
434
434
assert ! ( r. is_ok( ) ) ;
435
435
}
436
436
437
- fn shell_exec ( vm : & VirtualMachine , source : & str , scope : Scope ) -> Result < ( ) , CompileError > {
437
+ enum ShellExecResult {
438
+ Ok ,
439
+ PyErr ( PyObjectRef ) ,
440
+ Continue ,
441
+ }
442
+
443
+ fn shell_exec ( vm : & VirtualMachine , source : & str , scope : Scope ) -> ShellExecResult {
438
444
match vm. compile ( source, compile:: Mode :: Single , "<stdin>" . to_string ( ) ) {
439
445
Ok ( code) => {
440
446
match vm. run_code_obj ( code, scope. clone ( ) ) {
441
447
Ok ( value) => {
442
448
// Save non-None values as "_"
443
-
444
- use rustpython_vm:: pyobject:: IdProtocol ;
445
-
446
- if !value. is ( & vm. get_none ( ) ) {
449
+ if !vm. is_none ( & value) {
447
450
let key = "_" ;
448
451
scope. globals . set_item ( key, value, vm) . unwrap ( ) ;
449
452
}
453
+ ShellExecResult :: Ok
450
454
}
451
-
452
- Err ( err) => {
453
- print_exception ( vm, & err) ;
454
- }
455
+ Err ( err) => ShellExecResult :: PyErr ( err) ,
455
456
}
456
-
457
- Ok ( ( ) )
458
- }
459
- // Don't inject syntax errors for line continuation
460
- Err (
461
- err @ CompileError {
462
- error : CompileErrorType :: Parse ( ParseErrorType :: EOF ) ,
463
- ..
464
- } ,
465
- ) => Err ( err) ,
466
- Err ( err) => {
467
- let exc = vm. new_syntax_error ( & err) ;
468
- print_exception ( vm, & exc) ;
469
- Err ( err)
470
457
}
458
+ Err ( CompileError {
459
+ error : CompileErrorType :: Parse ( ParseErrorType :: EOF ) ,
460
+ ..
461
+ } ) => ShellExecResult :: Continue ,
462
+ Err ( err) => ShellExecResult :: PyErr ( vm. new_syntax_error ( & err) ) ,
471
463
}
472
464
}
473
465
@@ -519,48 +511,66 @@ fn run_shell(vm: &VirtualMachine, scope: Scope) -> PyResult<()> {
519
511
} else {
520
512
get_prompt ( vm, "ps1" )
521
513
} ;
522
- let prompt = prompt. as_ref ( ) . map ( |s| s. as_str ( ) ) . unwrap_or ( "" ) ;
523
- match repl. readline ( prompt) {
514
+ let prompt = match prompt {
515
+ Some ( ref s) => s. as_str ( ) ,
516
+ None => "" ,
517
+ } ;
518
+ let result = match repl. readline ( prompt) {
524
519
Ok ( line) => {
525
520
debug ! ( "You entered {:?}" , line) ;
526
- input. push_str ( & line) ;
527
- input. push ( '\n' ) ;
521
+
528
522
repl. add_history_entry ( line. trim_end ( ) ) ;
529
523
524
+ let stop_continuing = line. is_empty ( ) ;
525
+
526
+ if input. is_empty ( ) {
527
+ input = line;
528
+ } else {
529
+ input. push_str ( & line) ;
530
+ }
531
+ input. push_str ( "\n " ) ;
532
+
530
533
if continuing {
531
- if line . is_empty ( ) {
534
+ if stop_continuing {
532
535
continuing = false ;
533
536
} else {
534
537
continue ;
535
538
}
536
539
}
537
540
538
541
match shell_exec ( vm, & input, scope. clone ( ) ) {
539
- Err ( CompileError {
540
- error : CompileErrorType :: Parse ( ParseErrorType :: EOF ) ,
541
- ..
542
- } ) => {
542
+ ShellExecResult :: Ok => {
543
+ input = String :: new ( ) ;
544
+ Ok ( ( ) )
545
+ }
546
+ ShellExecResult :: Continue => {
543
547
continuing = true ;
548
+ Ok ( ( ) )
544
549
}
545
- _ => {
550
+ ShellExecResult :: PyErr ( err ) => {
546
551
input = String :: new ( ) ;
552
+ Err ( err)
547
553
}
548
554
}
549
555
}
550
556
Err ( ReadlineError :: Interrupted ) => {
551
- let exc = vm
557
+ continuing = false ;
558
+ let keyboard_interrupt = vm
552
559
. new_empty_exception ( vm. ctx . exceptions . keyboard_interrupt . clone ( ) )
553
560
. unwrap ( ) ;
554
- print_exception ( vm, & exc) ;
555
- continuing = false ;
561
+ Err ( keyboard_interrupt)
556
562
}
557
563
Err ( ReadlineError :: Eof ) => {
558
564
break ;
559
565
}
560
566
Err ( err) => {
561
- println ! ( "Error : {:?}" , err) ;
567
+ eprintln ! ( "Readline error : {:?}" , err) ;
562
568
break ;
563
569
}
570
+ } ;
571
+
572
+ if let Err ( exc) = result {
573
+ print_exception ( vm, & exc) ;
564
574
}
565
575
}
566
576
repl. save_history ( repl_history_path_str) . unwrap ( ) ;
@@ -586,7 +596,11 @@ fn run_shell(vm: &VirtualMachine, scope: Scope) -> PyResult<()> {
586
596
for line in stdin. lock ( ) . lines ( ) {
587
597
let mut line = line. expect ( "line failed" ) ;
588
598
line. push ( '\n' ) ;
589
- let _ = shell_exec ( vm, & line, scope. clone ( ) ) ;
599
+ match shell_exec ( vm, & line, scope. clone ( ) ) {
600
+ ShellExecResult :: Ok => { }
601
+ ShellExecResult :: Continue => println ! ( "Unexpected EOF" ) ,
602
+ ShellExecResult :: PyErr ( exc) => print_exception ( vm, & exc) ,
603
+ }
590
604
print ! ( "{}" , get_prompt( vm, "ps1" ) ) ;
591
605
stdout. flush ( ) . expect ( "flush failed" ) ;
592
606
}
0 commit comments