@@ -3,6 +3,7 @@ use std::io;
3
3
use std:: io:: Read ;
4
4
use std:: io:: Write ;
5
5
use std:: net:: { SocketAddr , TcpListener , TcpStream } ;
6
+ use std:: ops:: DerefMut ;
6
7
7
8
use crate :: obj:: objbytes;
8
9
use crate :: obj:: objint;
@@ -108,6 +109,15 @@ impl Socket {
108
109
}
109
110
}
110
111
112
+ fn get_socket < ' a > ( obj : & ' a PyObjectRef ) -> impl DerefMut < Target = Socket > + ' a {
113
+ if let PyObjectPayload :: AnyRustValue { ref value } = obj. payload {
114
+ if let Some ( socket) = value. downcast_ref :: < RefCell < Socket > > ( ) {
115
+ return socket. borrow_mut ( ) ;
116
+ }
117
+ }
118
+ panic ! ( "Inner error getting socket {:?}" , obj) ;
119
+ }
120
+
111
121
fn socket_new ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
112
122
arg_check ! (
113
123
vm,
@@ -125,7 +135,9 @@ fn socket_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
125
135
let socket = RefCell :: new ( Socket :: new ( address_family, kind) ) ;
126
136
127
137
Ok ( PyObject :: new (
128
- PyObjectPayload :: Socket { socket } ,
138
+ PyObjectPayload :: AnyRustValue {
139
+ value : Box :: new ( socket) ,
140
+ } ,
129
141
cls. clone ( ) ,
130
142
) )
131
143
}
@@ -139,17 +151,14 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
139
151
140
152
let address_string = get_address_string ( vm, address) ?;
141
153
142
- match zelf. payload {
143
- PyObjectPayload :: Socket { ref socket } => {
144
- if let Ok ( stream) = TcpStream :: connect ( address_string) {
145
- socket. borrow_mut ( ) . con = Some ( Connection :: TcpStream ( stream) ) ;
146
- Ok ( vm. get_none ( ) )
147
- } else {
148
- // TODO: Socket error
149
- Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
150
- }
151
- }
152
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
154
+ let mut socket = get_socket ( zelf) ;
155
+
156
+ if let Ok ( stream) = TcpStream :: connect ( address_string) {
157
+ socket. con = Some ( Connection :: TcpStream ( stream) ) ;
158
+ Ok ( vm. get_none ( ) )
159
+ } else {
160
+ // TODO: Socket error
161
+ Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
153
162
}
154
163
}
155
164
@@ -162,17 +171,14 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
162
171
163
172
let address_string = get_address_string ( vm, address) ?;
164
173
165
- match zelf. payload {
166
- PyObjectPayload :: Socket { ref socket } => {
167
- if let Ok ( stream) = TcpListener :: bind ( address_string) {
168
- socket. borrow_mut ( ) . con = Some ( Connection :: TcpListener ( stream) ) ;
169
- Ok ( vm. get_none ( ) )
170
- } else {
171
- // TODO: Socket error
172
- Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
173
- }
174
- }
175
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
174
+ let mut socket = get_socket ( zelf) ;
175
+
176
+ if let Ok ( stream) = TcpListener :: bind ( address_string) {
177
+ socket. con = Some ( Connection :: TcpListener ( stream) ) ;
178
+ Ok ( vm. get_none ( ) )
179
+ } else {
180
+ // TODO: Socket error
181
+ Err ( vm. new_type_error ( "socket failed" . to_string ( ) ) )
176
182
}
177
183
}
178
184
@@ -212,35 +218,37 @@ fn socket_listen(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
212
218
fn socket_accept ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
213
219
arg_check ! ( vm, args, required = [ ( zelf, None ) ] ) ;
214
220
215
- match zelf. payload {
216
- PyObjectPayload :: Socket { ref socket } => {
217
- let ret = match socket. borrow_mut ( ) . con {
218
- Some ( ref mut v) => v. accept ( ) ,
219
- None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
220
- } ;
221
+ let mut socket = get_socket ( zelf) ;
221
222
222
- let tcp_stream = match ret {
223
- Ok ( ( socket , _addr ) ) => socket ,
224
- _ => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
225
- } ;
223
+ let ret = match socket . con {
224
+ Some ( ref mut v ) => v . accept ( ) ,
225
+ None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
226
+ } ;
226
227
227
- let socket = RefCell :: new ( Socket {
228
- address_family : socket. borrow ( ) . address_family ,
229
- socket_kind : socket. borrow ( ) . socket_kind ,
230
- con : Some ( Connection :: TcpStream ( tcp_stream) ) ,
231
- } ) ;
228
+ let tcp_stream = match ret {
229
+ Ok ( ( socket, _addr) ) => socket,
230
+ _ => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
231
+ } ;
232
232
233
- let sock_obj = PyObject :: new ( PyObjectPayload :: Socket { socket } , zelf. typ ( ) ) ;
233
+ let socket = RefCell :: new ( Socket {
234
+ address_family : socket. address_family ,
235
+ socket_kind : socket. socket_kind ,
236
+ con : Some ( Connection :: TcpStream ( tcp_stream) ) ,
237
+ } ) ;
238
+
239
+ let sock_obj = PyObject :: new (
240
+ PyObjectPayload :: AnyRustValue {
241
+ value : Box :: new ( socket) ,
242
+ } ,
243
+ zelf. typ ( ) ,
244
+ ) ;
234
245
235
- let elements = RefCell :: new ( vec ! [ sock_obj, vm. get_none( ) ] ) ;
246
+ let elements = RefCell :: new ( vec ! [ sock_obj, vm. get_none( ) ] ) ;
236
247
237
- Ok ( PyObject :: new (
238
- PyObjectPayload :: Sequence { elements } ,
239
- vm. ctx . tuple_type ( ) ,
240
- ) )
241
- }
242
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
243
- }
248
+ Ok ( PyObject :: new (
249
+ PyObjectPayload :: Sequence { elements } ,
250
+ vm. ctx . tuple_type ( ) ,
251
+ ) )
244
252
}
245
253
246
254
fn socket_recv ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
@@ -249,17 +257,14 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
249
257
args,
250
258
required = [ ( zelf, None ) , ( bufsize, Some ( vm. ctx. int_type( ) ) ) ]
251
259
) ;
252
- match zelf. payload {
253
- PyObjectPayload :: Socket { ref socket } => {
254
- let mut buffer = vec ! [ 0u8 ; objint:: get_value( bufsize) . to_usize( ) . unwrap( ) ] ;
255
- match socket. borrow_mut ( ) . con {
256
- Some ( ref mut v) => v. read_exact ( & mut buffer) . unwrap ( ) ,
257
- None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
258
- } ;
259
- Ok ( vm. ctx . new_bytes ( buffer) )
260
- }
261
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
262
- }
260
+ let mut socket = get_socket ( zelf) ;
261
+
262
+ let mut buffer = vec ! [ 0u8 ; objint:: get_value( bufsize) . to_usize( ) . unwrap( ) ] ;
263
+ match socket. con {
264
+ Some ( ref mut v) => v. read_exact ( & mut buffer) . unwrap ( ) ,
265
+ None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
266
+ } ;
267
+ Ok ( vm. ctx . new_bytes ( buffer) )
263
268
}
264
269
265
270
fn socket_send ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
@@ -268,60 +273,42 @@ fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
268
273
args,
269
274
required = [ ( zelf, None ) , ( bytes, Some ( vm. ctx. bytes_type( ) ) ) ]
270
275
) ;
271
- match zelf. payload {
272
- PyObjectPayload :: Socket { ref socket } => {
273
- match socket. borrow_mut ( ) . con {
274
- Some ( ref mut v) => v. write ( & objbytes:: get_value ( & bytes) ) . unwrap ( ) ,
275
- None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
276
- } ;
277
- Ok ( vm. get_none ( ) )
278
- }
279
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
280
- }
276
+ let mut socket = get_socket ( zelf) ;
277
+
278
+ match socket. con {
279
+ Some ( ref mut v) => v. write ( & objbytes:: get_value ( & bytes) ) . unwrap ( ) ,
280
+ None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
281
+ } ;
282
+ Ok ( vm. get_none ( ) )
281
283
}
282
284
283
285
fn socket_close ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
284
286
arg_check ! ( vm, args, required = [ ( zelf, None ) ] ) ;
285
- match zelf. payload {
286
- PyObjectPayload :: Socket { ref socket } => {
287
- let mut socket = socket. borrow_mut ( ) ;
288
- match socket. address_family {
289
- AddressFamily :: Inet => match socket. socket_kind {
290
- SocketKind :: Stream => {
291
- socket. con = None ;
292
- Ok ( vm. get_none ( ) )
293
- }
294
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
295
- } ,
296
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
297
- }
298
- }
299
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
300
- }
287
+
288
+ let mut socket = get_socket ( zelf) ;
289
+ socket. con = None ;
290
+ Ok ( vm. get_none ( ) )
301
291
}
302
292
303
293
fn socket_getsockname ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
304
294
arg_check ! ( vm, args, required = [ ( zelf, None ) ] ) ;
305
- match zelf. payload {
306
- PyObjectPayload :: Socket { ref socket } => {
307
- let addr = match socket. borrow_mut ( ) . con {
308
- Some ( ref mut v) => v. local_addr ( ) ,
309
- None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
310
- } ;
311
-
312
- match addr {
313
- Ok ( addr) => {
314
- let port = vm. ctx . new_int ( addr. port ( ) ) ;
315
- let ip = vm. ctx . new_str ( addr. ip ( ) . to_string ( ) ) ;
316
- let elements = RefCell :: new ( vec ! [ ip, port] ) ;
317
-
318
- Ok ( PyObject :: new (
319
- PyObjectPayload :: Sequence { elements } ,
320
- vm. ctx . tuple_type ( ) ,
321
- ) )
322
- }
323
- _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
324
- }
295
+ let mut socket = get_socket ( zelf) ;
296
+
297
+ let addr = match socket. con {
298
+ Some ( ref mut v) => v. local_addr ( ) ,
299
+ None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
300
+ } ;
301
+
302
+ match addr {
303
+ Ok ( addr) => {
304
+ let port = vm. ctx . new_int ( addr. port ( ) ) ;
305
+ let ip = vm. ctx . new_str ( addr. ip ( ) . to_string ( ) ) ;
306
+ let elements = RefCell :: new ( vec ! [ ip, port] ) ;
307
+
308
+ Ok ( PyObject :: new (
309
+ PyObjectPayload :: Sequence { elements } ,
310
+ vm. ctx . tuple_type ( ) ,
311
+ ) )
325
312
}
326
313
_ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
327
314
}
0 commit comments