Skip to content

Commit b98926a

Browse files
Merge pull request RustPython#572 from RustPython/joey/fun-with-functions
Derive types, arity, conversions and more from rust fns
2 parents 30ddb48 + 3478251 commit b98926a

File tree

9 files changed

+438
-250
lines changed

9 files changed

+438
-250
lines changed

vm/src/function.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
use std::marker::PhantomData;
2+
use std::ops::Deref;
3+
4+
use crate::obj::objtype;
5+
use crate::pyobject::{PyObjectPayload2, PyObjectRef, PyResult, TryFromObject};
6+
use crate::vm::VirtualMachine;
7+
8+
// TODO: Move PyFuncArgs, FromArgs, etc. here
9+
10+
pub struct PyRef<T> {
11+
// invariant: this obj must always have payload of type T
12+
obj: PyObjectRef,
13+
_payload: PhantomData<T>,
14+
}
15+
16+
impl<T> Deref for PyRef<T>
17+
where
18+
T: PyObjectPayload2,
19+
{
20+
type Target = T;
21+
22+
fn deref(&self) -> &T {
23+
self.obj.payload().expect("unexpected payload for type")
24+
}
25+
}
26+
27+
impl<T> TryFromObject for PyRef<T>
28+
where
29+
T: PyObjectPayload2,
30+
{
31+
fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
32+
if objtype::isinstance(&obj, &T::required_type(&vm.ctx)) {
33+
Ok(PyRef {
34+
obj,
35+
_payload: PhantomData,
36+
})
37+
} else {
38+
Err(vm.new_type_error("wrong type".to_string())) // TODO: better message
39+
}
40+
}
41+
}

vm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ pub mod eval;
4141
mod exceptions;
4242
pub mod format;
4343
pub mod frame;
44+
pub mod function;
4445
pub mod import;
4546
pub mod obj;
4647
pub mod pyobject;

vm/src/obj/objbool.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::objstr::PyString;
12
use super::objtype;
23
use crate::pyobject::{
34
IntoPyObject, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
@@ -12,12 +13,14 @@ impl IntoPyObject for bool {
1213
}
1314

1415
pub fn boolval(vm: &mut VirtualMachine, obj: PyObjectRef) -> Result<bool, PyObjectRef> {
16+
if let Some(s) = obj.payload::<PyString>() {
17+
return Ok(!s.value.is_empty());
18+
}
1519
let result = match obj.payload {
1620
PyObjectPayload::Integer { ref value } => !value.is_zero(),
1721
PyObjectPayload::Float { value } => value != 0.0,
1822
PyObjectPayload::Sequence { ref elements } => !elements.borrow().is_empty(),
1923
PyObjectPayload::Dict { ref elements } => !elements.borrow().is_empty(),
20-
PyObjectPayload::String { ref value } => !value.is_empty(),
2124
PyObjectPayload::None { .. } => false,
2225
_ => {
2326
if let Ok(f) = vm.get_method(obj.clone(), "__bool__") {

vm/src/obj/objint.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use super::objtype;
44
use crate::format::FormatSpec;
55
use crate::pyobject::{
66
FromPyObjectRef, IntoPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef,
7-
PyResult, TypeProtocol,
7+
PyResult, TryFromObject, TypeProtocol,
88
};
99
use crate::vm::VirtualMachine;
1010
use num_bigint::{BigInt, ToBigInt};
@@ -31,6 +31,26 @@ impl IntoPyObject for usize {
3131
}
3232
}
3333

34+
impl TryFromObject for usize {
35+
fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
36+
// FIXME: don't use get_value
37+
match get_value(&obj).to_usize() {
38+
Some(value) => Ok(value),
39+
None => Err(vm.new_overflow_error("Int value cannot fit into Rust usize".to_string())),
40+
}
41+
}
42+
}
43+
44+
impl TryFromObject for isize {
45+
fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
46+
// FIXME: don't use get_value
47+
match get_value(&obj).to_isize() {
48+
Some(value) => Ok(value),
49+
None => Err(vm.new_overflow_error("Int value cannot fit into Rust isize".to_string())),
50+
}
51+
}
52+
}
53+
3454
fn int_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
3555
arg_check!(vm, args, required = [(int, Some(vm.ctx.int_type()))]);
3656
let v = get_value(int);

vm/src/obj/objrange.rs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ use std::ops::Mul;
44
use super::objint;
55
use super::objtype;
66
use crate::pyobject::{
7-
FromPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult,
8-
TypeProtocol,
7+
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
98
};
109
use crate::vm::VirtualMachine;
1110
use num_bigint::{BigInt, Sign};
@@ -21,18 +20,6 @@ pub struct RangeType {
2120
pub step: BigInt,
2221
}
2322

24-
type PyRange = RangeType;
25-
26-
impl FromPyObject for PyRange {
27-
fn typ(ctx: &PyContext) -> Option<PyObjectRef> {
28-
Some(ctx.range_type())
29-
}
30-
31-
fn from_pyobject(obj: PyObjectRef) -> PyResult<Self> {
32-
Ok(get_value(&obj))
33-
}
34-
}
35-
3623
impl RangeType {
3724
#[inline]
3825
pub fn try_len(&self) -> Option<usize> {
@@ -360,12 +347,22 @@ fn range_bool(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
360347
Ok(vm.ctx.new_bool(len > 0))
361348
}
362349

363-
fn range_contains(vm: &mut VirtualMachine, zelf: PyRange, needle: PyObjectRef) -> bool {
364-
if objtype::isinstance(&needle, &vm.ctx.int_type()) {
365-
zelf.contains(&objint::get_value(&needle))
350+
fn range_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
351+
arg_check!(
352+
vm,
353+
args,
354+
required = [(zelf, Some(vm.ctx.range_type())), (needle, None)]
355+
);
356+
357+
let range = get_value(zelf);
358+
359+
let result = if objtype::isinstance(needle, &vm.ctx.int_type()) {
360+
range.contains(&objint::get_value(needle))
366361
} else {
367362
false
368-
}
363+
};
364+
365+
Ok(vm.ctx.new_bool(result))
369366
}
370367

371368
fn range_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

vm/src/obj/objstr.rs

Lines changed: 77 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ use super::objint;
22
use super::objsequence::PySliceableSequence;
33
use super::objtype;
44
use crate::format::{FormatParseError, FormatPart, FormatString};
5+
use crate::function::PyRef;
56
use crate::pyobject::{
6-
FromPyObject, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
7+
OptArg, PyContext, PyFuncArgs, PyIterable, PyObjectPayload, PyObjectPayload2, PyObjectRef,
8+
PyResult, TypeProtocol,
79
};
810
use crate::vm::VirtualMachine;
911
use num_traits::ToPrimitive;
@@ -16,13 +18,71 @@ extern crate unicode_segmentation;
1618

1719
use self::unicode_segmentation::UnicodeSegmentation;
1820

19-
impl FromPyObject for String {
20-
fn typ(ctx: &PyContext) -> Option<PyObjectRef> {
21-
Some(ctx.str_type())
21+
#[derive(Clone, Debug)]
22+
pub struct PyString {
23+
// TODO: shouldn't be public
24+
pub value: String,
25+
}
26+
27+
impl PyString {
28+
pub fn endswith(
29+
zelf: PyRef<Self>,
30+
suffix: PyRef<Self>,
31+
start: OptArg<usize>,
32+
end: OptArg<usize>,
33+
_vm: &mut VirtualMachine,
34+
) -> bool {
35+
let start = start.unwrap_or(0);
36+
let end = end.unwrap_or(zelf.value.len());
37+
zelf.value[start..end].ends_with(&suffix.value)
2238
}
2339

24-
fn from_pyobject(obj: PyObjectRef) -> PyResult<Self> {
25-
Ok(get_value(&obj))
40+
pub fn startswith(
41+
zelf: PyRef<Self>,
42+
prefix: PyRef<Self>,
43+
start: OptArg<usize>,
44+
end: OptArg<usize>,
45+
_vm: &mut VirtualMachine,
46+
) -> bool {
47+
let start = start.unwrap_or(0);
48+
let end = end.unwrap_or(zelf.value.len());
49+
zelf.value[start..end].starts_with(&prefix.value)
50+
}
51+
52+
fn upper(zelf: PyRef<Self>, _vm: &mut VirtualMachine) -> PyString {
53+
PyString {
54+
value: zelf.value.to_uppercase(),
55+
}
56+
}
57+
58+
fn lower(zelf: PyRef<Self>, _vm: &mut VirtualMachine) -> PyString {
59+
PyString {
60+
value: zelf.value.to_lowercase(),
61+
}
62+
}
63+
64+
fn join(
65+
zelf: PyRef<Self>,
66+
iterable: PyIterable<PyRef<Self>>,
67+
vm: &mut VirtualMachine,
68+
) -> PyResult<PyString> {
69+
let mut joined = String::new();
70+
71+
for (idx, elem) in iterable.iter(vm)?.enumerate() {
72+
let elem = elem?;
73+
if idx != 0 {
74+
joined.push_str(&zelf.value);
75+
}
76+
joined.push_str(&elem.value)
77+
}
78+
79+
Ok(PyString { value: joined })
80+
}
81+
}
82+
83+
impl PyObjectPayload2 for PyString {
84+
fn required_type(ctx: &PyContext) -> PyObjectRef {
85+
ctx.str_type()
2686
}
2787
}
2888

@@ -47,9 +107,9 @@ pub fn init(context: &PyContext) {
47107
context.set_attr(&str_type, "__str__", context.new_rustfunc(str_str));
48108
context.set_attr(&str_type, "__repr__", context.new_rustfunc(str_repr));
49109
context.set_attr(&str_type, "format", context.new_rustfunc(str_format));
50-
context.set_attr(&str_type, "lower", context.new_rustfunc(str_lower));
110+
context.set_attr(&str_type, "lower", context.new_rustfunc(PyString::lower));
51111
context.set_attr(&str_type, "casefold", context.new_rustfunc(str_casefold));
52-
context.set_attr(&str_type, "upper", context.new_rustfunc(str_upper));
112+
context.set_attr(&str_type, "upper", context.new_rustfunc(PyString::upper));
53113
context.set_attr(
54114
&str_type,
55115
"capitalize",
@@ -60,11 +120,15 @@ pub fn init(context: &PyContext) {
60120
context.set_attr(&str_type, "strip", context.new_rustfunc(str_strip));
61121
context.set_attr(&str_type, "lstrip", context.new_rustfunc(str_lstrip));
62122
context.set_attr(&str_type, "rstrip", context.new_rustfunc(str_rstrip));
63-
context.set_attr(&str_type, "endswith", context.new_rustfunc(str_endswith));
123+
context.set_attr(
124+
&str_type,
125+
"endswith",
126+
context.new_rustfunc(PyString::endswith),
127+
);
64128
context.set_attr(
65129
&str_type,
66130
"startswith",
67-
context.new_rustfunc(str_startswith),
131+
context.new_rustfunc(PyString::startswith),
68132
);
69133
context.set_attr(&str_type, "isalnum", context.new_rustfunc(str_isalnum));
70134
context.set_attr(&str_type, "isnumeric", context.new_rustfunc(str_isnumeric));
@@ -84,7 +148,7 @@ pub fn init(context: &PyContext) {
84148
"splitlines",
85149
context.new_rustfunc(str_splitlines),
86150
);
87-
context.set_attr(&str_type, "join", context.new_rustfunc(str_join));
151+
context.set_attr(&str_type, "join", context.new_rustfunc(PyString::join));
88152
context.set_attr(&str_type, "find", context.new_rustfunc(str_find));
89153
context.set_attr(&str_type, "rfind", context.new_rustfunc(str_rfind));
90154
context.set_attr(&str_type, "index", context.new_rustfunc(str_index));
@@ -113,19 +177,11 @@ pub fn init(context: &PyContext) {
113177
}
114178

115179
pub fn get_value(obj: &PyObjectRef) -> String {
116-
if let PyObjectPayload::String { value } = &obj.payload {
117-
value.to_string()
118-
} else {
119-
panic!("Inner error getting str");
120-
}
180+
obj.payload::<PyString>().unwrap().value.clone()
121181
}
122182

123183
pub fn borrow_value(obj: &PyObjectRef) -> &str {
124-
if let PyObjectPayload::String { value } = &obj.payload {
125-
value.as_str()
126-
} else {
127-
panic!("Inner error getting str");
128-
}
184+
&obj.payload::<PyString>().unwrap().value
129185
}
130186

131187
fn str_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -387,18 +443,6 @@ fn str_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
387443
}
388444
}
389445

390-
fn str_upper(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
391-
arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]);
392-
let value = get_value(&s).to_uppercase();
393-
Ok(vm.ctx.new_str(value))
394-
}
395-
396-
fn str_lower(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
397-
arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]);
398-
let value = get_value(&s).to_lowercase();
399-
Ok(vm.ctx.new_str(value))
400-
}
401-
402446
fn str_capitalize(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
403447
arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]);
404448
let value = get_value(&s);
@@ -477,10 +521,6 @@ fn str_rstrip(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
477521
Ok(vm.ctx.new_str(value))
478522
}
479523

480-
fn str_endswith(_vm: &mut VirtualMachine, zelf: String, suffix: String) -> bool {
481-
zelf.ends_with(&suffix)
482-
}
483-
484524
fn str_isidentifier(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
485525
arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]);
486526
let value = get_value(&s);
@@ -560,22 +600,6 @@ fn str_zfill(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
560600
Ok(vm.ctx.new_str(new_str))
561601
}
562602

563-
fn str_join(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
564-
arg_check!(
565-
vm,
566-
args,
567-
required = [(s, Some(vm.ctx.str_type())), (iterable, None)]
568-
);
569-
let value = get_value(&s);
570-
let elements: Vec<String> = vm
571-
.extract_elements(iterable)?
572-
.iter()
573-
.map(|w| get_value(&w))
574-
.collect();
575-
let joined = elements.join(&value);
576-
Ok(vm.ctx.new_str(joined))
577-
}
578-
579603
fn str_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
580604
arg_check!(
581605
vm,
@@ -865,17 +889,6 @@ fn str_center(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
865889
Ok(vm.ctx.new_str(new_str))
866890
}
867891

868-
fn str_startswith(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
869-
arg_check!(
870-
vm,
871-
args,
872-
required = [(s, Some(vm.ctx.str_type())), (pat, Some(vm.ctx.str_type()))]
873-
);
874-
let value = get_value(&s);
875-
let pat = get_value(&pat);
876-
Ok(vm.ctx.new_bool(value.starts_with(pat.as_str())))
877-
}
878-
879892
fn str_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
880893
arg_check!(
881894
vm,

0 commit comments

Comments
 (0)