@@ -2,7 +2,7 @@ use std::cell::RefCell;
2
2
use std:: io;
3
3
use std:: io:: Read ;
4
4
use std:: io:: Write ;
5
- use std:: net:: { SocketAddr , TcpListener , TcpStream , UdpSocket } ;
5
+ use std:: net:: { SocketAddr , TcpListener , TcpStream } ;
6
6
7
7
use crate :: obj:: objbytes;
8
8
use crate :: obj:: objint;
@@ -52,7 +52,7 @@ impl SocketKind {
52
52
enum Connection {
53
53
TcpListener ( TcpListener ) ,
54
54
TcpStream ( TcpStream ) ,
55
- UdpSocket ( UdpSocket ) ,
55
+ // UdpSocket(UdpSocket),
56
56
}
57
57
58
58
impl Connection {
@@ -62,6 +62,13 @@ impl Connection {
62
62
_ => Err ( io:: Error :: new ( io:: ErrorKind :: Other , "oh no!" ) ) ,
63
63
}
64
64
}
65
+
66
+ fn local_addr ( & self ) -> io:: Result < SocketAddr > {
67
+ match self {
68
+ Connection :: TcpListener ( con) => con. local_addr ( ) ,
69
+ _ => Err ( io:: Error :: new ( io:: ErrorKind :: Other , "oh no!" ) ) ,
70
+ }
71
+ }
65
72
}
66
73
67
74
impl Read for Connection {
@@ -87,15 +94,15 @@ impl Write for Connection {
87
94
88
95
pub struct Socket {
89
96
address_family : AddressFamily ,
90
- sk : SocketKind ,
97
+ socket_kind : SocketKind ,
91
98
con : Option < Connection > ,
92
99
}
93
100
94
101
impl Socket {
95
- fn new ( address_family : AddressFamily , sk : SocketKind ) -> Socket {
102
+ fn new ( address_family : AddressFamily , socket_kind : SocketKind ) -> Socket {
96
103
Socket {
97
104
address_family,
98
- sk : sk ,
105
+ socket_kind : socket_kind ,
99
106
con : None ,
100
107
}
101
108
}
@@ -130,11 +137,7 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
130
137
required = [ ( zelf, None ) , ( address, Some ( vm. ctx. tuple_type( ) ) ) ]
131
138
) ;
132
139
133
- let elements = get_elements ( address) ;
134
- let host = objstr:: get_value ( & elements[ 0 ] ) ;
135
- let port = objint:: get_value ( & elements[ 1 ] ) ;
136
-
137
- let address_string = format ! ( "{}:{}" , host, port. to_string( ) ) ;
140
+ let address_string = get_address_string ( vm, address) ?;
138
141
139
142
match zelf. payload {
140
143
PyObjectPayload :: Socket { ref socket } => {
@@ -157,11 +160,7 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
157
160
required = [ ( zelf, None ) , ( address, Some ( vm. ctx. tuple_type( ) ) ) ]
158
161
) ;
159
162
160
- let elements = get_elements ( address) ;
161
- let host = objstr:: get_value ( & elements[ 0 ] ) ;
162
- let port = objint:: get_value ( & elements[ 1 ] ) ;
163
-
164
- let address_string = format ! ( "{}:{}" , host, port. to_string( ) ) ;
163
+ let address_string = get_address_string ( vm, address) ?;
165
164
166
165
match zelf. payload {
167
166
PyObjectPayload :: Socket { ref socket } => {
@@ -177,7 +176,36 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
177
176
}
178
177
}
179
178
179
+ fn get_address_string (
180
+ vm : & mut VirtualMachine ,
181
+ address : & PyObjectRef ,
182
+ ) -> Result < String , PyObjectRef > {
183
+ let args = PyFuncArgs {
184
+ args : get_elements ( address) . to_vec ( ) ,
185
+ kwargs : vec ! [ ] ,
186
+ } ;
187
+ arg_check ! (
188
+ vm,
189
+ args,
190
+ required = [
191
+ ( host, Some ( vm. ctx. str_type( ) ) ) ,
192
+ ( port, Some ( vm. ctx. int_type( ) ) )
193
+ ]
194
+ ) ;
195
+
196
+ Ok ( format ! (
197
+ "{}:{}" ,
198
+ objstr:: get_value( host) ,
199
+ objint:: get_value( port) . to_string( )
200
+ ) )
201
+ }
202
+
180
203
fn socket_listen ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
204
+ arg_check ! (
205
+ vm,
206
+ args,
207
+ required = [ ( _zelf, None ) , ( _num, Some ( vm. ctx. int_type( ) ) ) ]
208
+ ) ;
181
209
Ok ( vm. get_none ( ) )
182
210
}
183
211
@@ -197,8 +225,8 @@ fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
197
225
} ;
198
226
199
227
let socket = RefCell :: new ( Socket {
200
- address_family : socket. borrow ( ) . address_family . clone ( ) ,
201
- sk : socket. borrow ( ) . sk . clone ( ) ,
228
+ address_family : socket. borrow ( ) . address_family ,
229
+ socket_kind : socket. borrow ( ) . socket_kind ,
202
230
con : Some ( Connection :: TcpStream ( tcp_stream) ) ,
203
231
} ) ;
204
232
@@ -223,10 +251,10 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
223
251
) ;
224
252
match zelf. payload {
225
253
PyObjectPayload :: Socket { ref socket } => {
226
- let mut buffer = Vec :: new ( ) ;
227
- let _temp = match socket. borrow_mut ( ) . con {
228
- Some ( ref mut v) => v. read_to_end ( & mut buffer) . unwrap ( ) ,
229
- None => 0 ,
254
+ let mut buffer = vec ! [ 0u8 ; objint :: get_value ( bufsize ) . to_usize ( ) . unwrap ( ) ] ;
255
+ match socket. borrow_mut ( ) . con {
256
+ Some ( ref mut v) => v. read_exact ( & mut buffer) . unwrap ( ) ,
257
+ None => return Err ( vm . new_type_error ( "" . to_string ( ) ) ) ,
230
258
} ;
231
259
Ok ( vm. ctx . new_bytes ( buffer) )
232
260
}
@@ -258,7 +286,7 @@ fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
258
286
PyObjectPayload :: Socket { ref socket } => {
259
287
let mut socket = socket. borrow_mut ( ) ;
260
288
match socket. address_family {
261
- AddressFamily :: AfInet => match socket. sk {
289
+ AddressFamily :: AfInet => match socket. socket_kind {
262
290
SocketKind :: SockStream => {
263
291
socket. con = None ;
264
292
Ok ( vm. get_none ( ) )
@@ -272,6 +300,33 @@ fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
272
300
}
273
301
}
274
302
303
+ fn socket_getsockname ( vm : & mut VirtualMachine , args : PyFuncArgs ) -> PyResult {
304
+ arg_check ! ( vm, args, required = [ ( zelf, None ) ] ) ;
305
+ match zelf. payload {
306
+ PyObjectPayload :: Socket { ref socket } => {
307
+ let addr = match socket. borrow_mut ( ) . con {
308
+ Some ( ref mut v) => v. local_addr ( ) ,
309
+ None => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
310
+ } ;
311
+
312
+ match addr {
313
+ Ok ( addr) => {
314
+ let port = vm. ctx . new_int ( addr. port ( ) ) ;
315
+ let ip = vm. ctx . new_str ( addr. ip ( ) . to_string ( ) ) ;
316
+ let elements = RefCell :: new ( vec ! [ ip, port] ) ;
317
+
318
+ Ok ( PyObject :: new (
319
+ PyObjectPayload :: Sequence { elements } ,
320
+ vm. ctx . tuple_type ( ) ,
321
+ ) )
322
+ }
323
+ _ => return Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
324
+ }
325
+ }
326
+ _ => Err ( vm. new_type_error ( "" . to_string ( ) ) ) ,
327
+ }
328
+ }
329
+
275
330
pub fn mk_module ( ctx : & PyContext ) -> PyObjectRef {
276
331
let py_mod = ctx. new_module ( & "socket" . to_string ( ) , ctx. new_scope ( None ) ) ;
277
332
@@ -297,6 +352,7 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
297
352
ctx. set_attr ( & socket, "accept" , ctx. new_rustfunc ( socket_accept) ) ;
298
353
ctx. set_attr ( & socket, "listen" , ctx. new_rustfunc ( socket_listen) ) ;
299
354
ctx. set_attr ( & socket, "close" , ctx. new_rustfunc ( socket_close) ) ;
355
+ ctx. set_attr ( & socket, "getsockname" , ctx. new_rustfunc ( socket_getsockname) ) ;
300
356
socket
301
357
} ;
302
358
ctx. set_attr ( & py_mod, "socket" , socket. clone ( ) ) ;
0 commit comments