@@ -5,6 +5,9 @@ use std::cell::RefCell;
5
5
use std:: fs:: File ;
6
6
use std:: io:: prelude:: * ;
7
7
use std:: io:: BufReader ;
8
+ use std:: io:: Cursor ;
9
+ use std:: io:: SeekFrom ;
10
+
8
11
use std:: path:: PathBuf ;
9
12
10
13
use num_bigint:: ToBigInt ;
@@ -33,7 +36,7 @@ fn compute_c_flag(mode: &str) -> u16 {
33
36
34
37
#[ derive( Debug ) ]
35
38
struct PyStringIO {
36
- data : RefCell < String > ,
39
+ data : RefCell < Cursor < Vec < u8 > > > ,
37
40
}
38
41
39
42
type PyStringIORef = PyRef < PyStringIO > ;
@@ -45,19 +48,68 @@ impl PyValue for PyStringIO {
45
48
}
46
49
47
50
impl PyStringIORef {
48
- fn write ( self , data : objstr:: PyStringRef , _vm : & VirtualMachine ) {
49
- let data = data. value . clone ( ) ;
50
- self . data . borrow_mut ( ) . push_str ( & data) ;
51
+ //write string to underlying vector
52
+ fn write ( self , data : objstr:: PyStringRef , vm : & VirtualMachine ) -> PyResult {
53
+ let bytes = & data. value . clone ( ) . into_bytes ( ) ;
54
+ let length = bytes. len ( ) ;
55
+
56
+ let mut cursor = self . data . borrow_mut ( ) ;
57
+ match cursor. write_all ( bytes) {
58
+ Ok ( _) => Ok ( vm. ctx . new_int ( length) ) ,
59
+ Err ( _) => Err ( vm. new_type_error ( "Error Writing String" . to_string ( ) ) ) ,
60
+ }
51
61
}
52
62
53
- fn getvalue ( self , _vm : & VirtualMachine ) -> String {
54
- self . data . borrow ( ) . clone ( )
63
+ //return the entire contents of the underlying
64
+ fn getvalue ( self , vm : & VirtualMachine ) -> PyResult {
65
+ match String :: from_utf8 ( self . data . borrow ( ) . clone ( ) . into_inner ( ) ) {
66
+ Ok ( result) => Ok ( vm. ctx . new_str ( result) ) ,
67
+ Err ( _) => Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ,
68
+ }
55
69
}
56
70
57
- fn read ( self , _vm : & VirtualMachine ) -> String {
58
- let data = self . data . borrow ( ) . clone ( ) ;
59
- self . data . borrow_mut ( ) . clear ( ) ;
60
- data
71
+ //skip to the jth position
72
+ fn seek ( self , offset : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
73
+ let position = objint:: get_value ( & offset) . to_u64 ( ) . unwrap ( ) ;
74
+ if let Err ( _) = self
75
+ . data
76
+ . borrow_mut ( )
77
+ . seek ( SeekFrom :: Start ( position. clone ( ) ) )
78
+ {
79
+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
80
+ }
81
+
82
+ Ok ( vm. ctx . new_int ( position) )
83
+ }
84
+
85
+ //Read k bytes from the object and return.
86
+ //If k is undefined || k == -1, then we read all bytes until the end of the file.
87
+ //This also increments the stream position by the value of k
88
+ fn read ( self , bytes : OptionalArg < Option < PyObjectRef > > , vm : & VirtualMachine ) -> PyResult {
89
+ let mut buffer = String :: new ( ) ;
90
+
91
+ match bytes {
92
+ OptionalArg :: Present ( Some ( ref integer) ) => {
93
+ let k = objint:: get_value ( integer) . to_u64 ( ) . unwrap ( ) ;
94
+ let mut handle = self . data . borrow ( ) . clone ( ) . take ( k) ;
95
+
96
+ //read bytes into string
97
+ if let Err ( _) = handle. read_to_string ( & mut buffer) {
98
+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
99
+ }
100
+
101
+ //the take above consumes the struct value
102
+ //we add this back in with the takes into_inner method
103
+ self . data . replace ( handle. into_inner ( ) ) ;
104
+ }
105
+ _ => {
106
+ if let Err ( _) = self . data . borrow_mut ( ) . read_to_string ( & mut buffer) {
107
+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
108
+ }
109
+ }
110
+ } ;
111
+
112
+ Ok ( vm. ctx . new_str ( buffer) )
61
113
}
62
114
}
63
115
@@ -72,14 +124,14 @@ fn string_io_new(
72
124
} ;
73
125
74
126
PyStringIO {
75
- data : RefCell :: new ( raw_string) ,
127
+ data : RefCell :: new ( Cursor :: new ( raw_string. into_bytes ( ) ) ) ,
76
128
}
77
129
. into_ref_with_type ( vm, cls)
78
130
}
79
131
80
- #[ derive( Debug , Default , Clone ) ]
132
+ #[ derive( Debug ) ]
81
133
struct PyBytesIO {
82
- data : RefCell < Vec < u8 > > ,
134
+ data : RefCell < Cursor < Vec < u8 > > > ,
83
135
}
84
136
85
137
type PyBytesIORef = PyRef < PyBytesIO > ;
@@ -91,19 +143,65 @@ impl PyValue for PyBytesIO {
91
143
}
92
144
93
145
impl PyBytesIORef {
94
- fn write ( self , data : objbytes:: PyBytesRef , _vm : & VirtualMachine ) {
95
- let data = data. get_value ( ) ;
96
- self . data . borrow_mut ( ) . extend ( data) ;
146
+ //write string to underlying vector
147
+ fn write ( self , data : objbytes:: PyBytesRef , vm : & VirtualMachine ) -> PyResult {
148
+ let bytes = data. get_value ( ) ;
149
+ let length = bytes. len ( ) ;
150
+
151
+ let mut cursor = self . data . borrow_mut ( ) ;
152
+ match cursor. write_all ( bytes) {
153
+ Ok ( _) => Ok ( vm. ctx . new_int ( length) ) ,
154
+ Err ( _) => Err ( vm. new_type_error ( "Error Writing String" . to_string ( ) ) ) ,
155
+ }
97
156
}
98
157
158
+ //return the entire contents of the underlying
99
159
fn getvalue ( self , vm : & VirtualMachine ) -> PyResult {
100
- Ok ( vm. ctx . new_bytes ( self . data . borrow ( ) . clone ( ) ) )
160
+ Ok ( vm. ctx . new_bytes ( self . data . borrow ( ) . clone ( ) . into_inner ( ) ) )
161
+ }
162
+
163
+ //skip to the jth position
164
+ fn seek ( self , offset : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
165
+ let position = objint:: get_value ( & offset) . to_u64 ( ) . unwrap ( ) ;
166
+ if let Err ( _) = self
167
+ . data
168
+ . borrow_mut ( )
169
+ . seek ( SeekFrom :: Start ( position. clone ( ) ) )
170
+ {
171
+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
172
+ }
173
+
174
+ Ok ( vm. ctx . new_int ( position) )
101
175
}
102
176
103
- fn read ( self , vm : & VirtualMachine ) -> PyResult {
104
- let data = self . data . borrow ( ) . clone ( ) ;
105
- self . data . borrow_mut ( ) . clear ( ) ;
106
- Ok ( vm. ctx . new_bytes ( data) )
177
+ //Read k bytes from the object and return.
178
+ //If k is undefined || k == -1, then we read all bytes until the end of the file.
179
+ //This also increments the stream position by the value of k
180
+ fn read ( self , bytes : OptionalArg < Option < PyObjectRef > > , vm : & VirtualMachine ) -> PyResult {
181
+ let mut buffer = Vec :: new ( ) ;
182
+
183
+ match bytes {
184
+ OptionalArg :: Present ( Some ( ref integer) ) => {
185
+ let k = objint:: get_value ( integer) . to_u64 ( ) . unwrap ( ) ;
186
+ let mut handle = self . data . borrow ( ) . clone ( ) . take ( k) ;
187
+
188
+ //read bytes into string
189
+ if let Err ( _) = handle. read_to_end ( & mut buffer) {
190
+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
191
+ }
192
+
193
+ //the take above consumes the struct value
194
+ //we add this back in with the takes into_inner method
195
+ self . data . replace ( handle. into_inner ( ) ) ;
196
+ }
197
+ _ => {
198
+ if let Err ( _) = self . data . borrow_mut ( ) . read_to_end ( & mut buffer) {
199
+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
200
+ }
201
+ }
202
+ } ;
203
+
204
+ Ok ( vm. ctx . new_bytes ( buffer) )
107
205
}
108
206
}
109
207
@@ -118,7 +216,7 @@ fn bytes_io_new(
118
216
} ;
119
217
120
218
PyBytesIO {
121
- data : RefCell :: new ( raw_bytes) ,
219
+ data : RefCell :: new ( Cursor :: new ( raw_bytes) ) ,
122
220
}
123
221
. into_ref_with_type ( vm, cls)
124
222
}
@@ -514,6 +612,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
514
612
//StringIO: in-memory text
515
613
let string_io = py_class ! ( ctx, "StringIO" , text_io_base. clone( ) , {
516
614
"__new__" => ctx. new_rustfunc( string_io_new) ,
615
+ "seek" => ctx. new_rustfunc( PyStringIORef :: seek) ,
517
616
"read" => ctx. new_rustfunc( PyStringIORef :: read) ,
518
617
"write" => ctx. new_rustfunc( PyStringIORef :: write) ,
519
618
"getvalue" => ctx. new_rustfunc( PyStringIORef :: getvalue)
@@ -523,6 +622,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
523
622
let bytes_io = py_class ! ( ctx, "BytesIO" , buffered_io_base. clone( ) , {
524
623
"__new__" => ctx. new_rustfunc( bytes_io_new) ,
525
624
"read" => ctx. new_rustfunc( PyBytesIORef :: read) ,
625
+ "read1" => ctx. new_rustfunc( PyBytesIORef :: read) ,
626
+ "seek" => ctx. new_rustfunc( PyBytesIORef :: seek) ,
526
627
"write" => ctx. new_rustfunc( PyBytesIORef :: write) ,
527
628
"getvalue" => ctx. new_rustfunc( PyBytesIORef :: getvalue)
528
629
} ) ;
@@ -627,4 +728,5 @@ mod tests {
627
728
Err ( "invalid mode: 'a++'" . to_string( ) )
628
729
) ;
629
730
}
731
+
630
732
}
0 commit comments