Skip to content

Commit 206ba4d

Browse files
committed
Change socket to new args style
1 parent 3afb5d3 commit 206ba4d

File tree

1 file changed

+79
-113
lines changed

1 file changed

+79
-113
lines changed

vm/src/stdlib/socket.rs

Lines changed: 79 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ use std::ops::Deref;
77

88
use crate::function::PyFuncArgs;
99
use crate::obj::objbytes;
10+
use crate::obj::objbytes::PyBytesRef;
1011
use crate::obj::objint;
11-
use crate::obj::objsequence::get_elements;
1212
use crate::obj::objstr;
13+
use crate::obj::objtuple::PyTupleRef;
1314
use crate::pyobject::{PyObjectRef, PyRef, PyResult, PyValue, TryFromObject};
1415
use crate::vm::VirtualMachine;
1516

@@ -167,88 +168,94 @@ fn get_socket<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Socket> + 'a {
167168

168169
type SocketRef = PyRef<Socket>;
169170

170-
fn socket_new(
171-
cls: PyClassRef,
172-
family: AddressFamily,
173-
kind: SocketKind,
174-
vm: &VirtualMachine,
175-
) -> PyResult<SocketRef> {
176-
Socket::new(family, kind).into_ref_with_type(vm, cls)
177-
}
178-
179-
fn socket_connect(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
180-
arg_check!(
181-
vm,
182-
args,
183-
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
184-
);
185-
186-
let address_string = get_address_string(vm, address)?;
171+
impl SocketRef {
172+
fn new(
173+
cls: PyClassRef,
174+
family: AddressFamily,
175+
kind: SocketKind,
176+
vm: &VirtualMachine,
177+
) -> PyResult<SocketRef> {
178+
Socket::new(family, kind).into_ref_with_type(vm, cls)
179+
}
187180

188-
let socket = get_socket(zelf);
181+
fn connect(self, address: PyTupleRef, vm: &VirtualMachine) -> PyResult {
182+
let address_string = get_address_string(vm, address)?;
189183

190-
match socket.socket_kind {
191-
SocketKind::Stream => match TcpStream::connect(address_string) {
192-
Ok(stream) => {
193-
socket
194-
.con
195-
.borrow_mut()
196-
.replace(Connection::TcpStream(stream));
197-
Ok(vm.get_none())
198-
}
199-
Err(s) => Err(vm.new_os_error(s.to_string())),
200-
},
201-
SocketKind::Dgram => {
202-
if let Some(Connection::UdpSocket(con)) = socket.con.borrow().as_ref() {
203-
match con.connect(address_string) {
204-
Ok(_) => Ok(vm.get_none()),
205-
Err(s) => Err(vm.new_os_error(s.to_string())),
184+
match self.socket_kind {
185+
SocketKind::Stream => match TcpStream::connect(address_string) {
186+
Ok(stream) => {
187+
self.con.borrow_mut().replace(Connection::TcpStream(stream));
188+
Ok(vm.get_none())
189+
}
190+
Err(s) => Err(vm.new_os_error(s.to_string())),
191+
},
192+
SocketKind::Dgram => {
193+
if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() {
194+
match con.connect(address_string) {
195+
Ok(_) => Ok(vm.get_none()),
196+
Err(s) => Err(vm.new_os_error(s.to_string())),
197+
}
198+
} else {
199+
Err(vm.new_type_error("".to_string()))
206200
}
207-
} else {
208-
Err(vm.new_type_error("".to_string()))
209201
}
210202
}
211203
}
212-
}
213204

214-
fn socket_bind(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
215-
arg_check!(
216-
vm,
217-
args,
218-
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
219-
);
205+
fn bind(self, address: PyTupleRef, vm: &VirtualMachine) -> PyResult {
206+
let address_string = get_address_string(vm, address)?;
220207

221-
let address_string = get_address_string(vm, address)?;
208+
match self.socket_kind {
209+
SocketKind::Stream => match TcpListener::bind(address_string) {
210+
Ok(stream) => {
211+
self.con
212+
.borrow_mut()
213+
.replace(Connection::TcpListener(stream));
214+
Ok(vm.get_none())
215+
}
216+
Err(s) => Err(vm.new_os_error(s.to_string())),
217+
},
218+
SocketKind::Dgram => match UdpSocket::bind(address_string) {
219+
Ok(dgram) => {
220+
self.con.borrow_mut().replace(Connection::UdpSocket(dgram));
221+
Ok(vm.get_none())
222+
}
223+
Err(s) => Err(vm.new_os_error(s.to_string())),
224+
},
225+
}
226+
}
222227

223-
let socket = get_socket(zelf);
228+
fn sendto(self, bytes: PyBytesRef, address: PyTupleRef, vm: &VirtualMachine) -> PyResult {
229+
let address_string = get_address_string(vm, address)?;
224230

225-
match socket.socket_kind {
226-
SocketKind::Stream => match TcpListener::bind(address_string) {
227-
Ok(stream) => {
228-
socket
229-
.con
230-
.borrow_mut()
231-
.replace(Connection::TcpListener(stream));
232-
Ok(vm.get_none())
233-
}
234-
Err(s) => Err(vm.new_os_error(s.to_string())),
235-
},
236-
SocketKind::Dgram => match UdpSocket::bind(address_string) {
237-
Ok(dgram) => {
238-
socket
239-
.con
240-
.borrow_mut()
241-
.replace(Connection::UdpSocket(dgram));
242-
Ok(vm.get_none())
231+
match self.socket_kind {
232+
SocketKind::Dgram => {
233+
if let Some(v) = self.con.borrow().as_ref() {
234+
return match v.send_to(&bytes, address_string) {
235+
Ok(_) => Ok(vm.get_none()),
236+
Err(s) => Err(vm.new_os_error(s.to_string())),
237+
};
238+
}
239+
// Doing implicit bind
240+
match UdpSocket::bind("0.0.0.0:0") {
241+
Ok(dgram) => match dgram.send_to(&bytes, address_string) {
242+
Ok(_) => {
243+
self.con.borrow_mut().replace(Connection::UdpSocket(dgram));
244+
Ok(vm.get_none())
245+
}
246+
Err(s) => Err(vm.new_os_error(s.to_string())),
247+
},
248+
Err(s) => Err(vm.new_os_error(s.to_string())),
249+
}
243250
}
244-
Err(s) => Err(vm.new_os_error(s.to_string())),
245-
},
251+
_ => Err(vm.new_not_implemented_error("".to_string())),
252+
}
246253
}
247254
}
248255

249-
fn get_address_string(vm: &VirtualMachine, address: &PyObjectRef) -> Result<String, PyObjectRef> {
256+
fn get_address_string(vm: &VirtualMachine, address: PyTupleRef) -> Result<String, PyObjectRef> {
250257
let args = PyFuncArgs {
251-
args: get_elements(address).to_vec(),
258+
args: address.elements.borrow().to_vec(),
252259
kwargs: vec![],
253260
};
254261
arg_check!(
@@ -365,47 +372,6 @@ fn socket_send(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
365372
Ok(vm.get_none())
366373
}
367374

368-
fn socket_sendto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
369-
arg_check!(
370-
vm,
371-
args,
372-
required = [
373-
(zelf, None),
374-
(bytes, Some(vm.ctx.bytes_type())),
375-
(address, Some(vm.ctx.tuple_type()))
376-
]
377-
);
378-
let address_string = get_address_string(vm, address)?;
379-
380-
let socket = get_socket(zelf);
381-
382-
match socket.socket_kind {
383-
SocketKind::Dgram => {
384-
if let Some(v) = socket.con.borrow().as_ref() {
385-
return match v.send_to(&objbytes::get_value(&bytes), address_string) {
386-
Ok(_) => Ok(vm.get_none()),
387-
Err(s) => Err(vm.new_os_error(s.to_string())),
388-
};
389-
}
390-
// Doing implicit bind
391-
match UdpSocket::bind("0.0.0.0:0") {
392-
Ok(dgram) => match dgram.send_to(&objbytes::get_value(&bytes), address_string) {
393-
Ok(_) => {
394-
socket
395-
.con
396-
.borrow_mut()
397-
.replace(Connection::UdpSocket(dgram));
398-
Ok(vm.get_none())
399-
}
400-
Err(s) => Err(vm.new_os_error(s.to_string())),
401-
},
402-
Err(s) => Err(vm.new_os_error(s.to_string())),
403-
}
404-
}
405-
_ => Err(vm.new_not_implemented_error("".to_string())),
406-
}
407-
}
408-
409375
fn socket_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
410376
arg_check!(vm, args, required = [(zelf, None)]);
411377

@@ -452,16 +418,16 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
452418
let ctx = &vm.ctx;
453419

454420
let socket = py_class!(ctx, "socket", ctx.object(), {
455-
"__new__" => ctx.new_rustfunc(socket_new),
456-
"connect" => ctx.new_rustfunc(socket_connect),
421+
"__new__" => ctx.new_rustfunc(SocketRef::new),
422+
"connect" => ctx.new_rustfunc(SocketRef::connect),
457423
"recv" => ctx.new_rustfunc(socket_recv),
458424
"send" => ctx.new_rustfunc(socket_send),
459-
"bind" => ctx.new_rustfunc(socket_bind),
425+
"bind" => ctx.new_rustfunc(SocketRef::bind),
460426
"accept" => ctx.new_rustfunc(socket_accept),
461427
"listen" => ctx.new_rustfunc(socket_listen),
462428
"close" => ctx.new_rustfunc(socket_close),
463429
"getsockname" => ctx.new_rustfunc(socket_getsockname),
464-
"sendto" => ctx.new_rustfunc(socket_sendto),
430+
"sendto" => ctx.new_rustfunc(SocketRef::sendto),
465431
"recvfrom" => ctx.new_rustfunc(socket_recvfrom),
466432
"fileno" => ctx.new_rustfunc(socket_fileno),
467433
});

0 commit comments

Comments
 (0)