Skip to content

Commit d7e0cee

Browse files
committed
New (PyRef<T>) style types.
1 parent aa38a1a commit d7e0cee

File tree

1 file changed

+45
-59
lines changed

1 file changed

+45
-59
lines changed

vm/src/stdlib/ast.rs

Lines changed: 45 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@ use num_complex::Complex64;
99

1010
use rustpython_parser::{ast, parser};
1111

12-
use crate::function::PyFuncArgs;
1312
use crate::obj::objlist::PyListRef;
14-
use crate::obj::objstr;
13+
use crate::obj::objstr::PyStringRef;
1514
use crate::obj::objtype::PyClassRef;
16-
use crate::pyobject::{PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol};
15+
use crate::pyobject::{PyObjectRef, PyRef, PyResult, PyValue};
1716
use crate::vm::VirtualMachine;
1817

1918
#[derive(Debug)]
@@ -26,28 +25,12 @@ impl PyValue for AstNode {
2625
}
2726
}
2827

29-
/*
30-
* Idea: maybe we can create a sort of struct with some helper functions?
31-
struct AstToPyAst {
32-
ctx: &PyContext,
33-
}
34-
35-
impl AstToPyAst {
36-
fn new(ctx: &PyContext) -> Self {
37-
AstToPyAst {
38-
ctx: ctx,
39-
}
40-
}
41-
42-
}
43-
*/
44-
4528
macro_rules! node {
4629
( $vm: expr, $node_name:ident, { $($attr_name:ident => $attr_value:expr),* $(,)* }) => {
4730
{
4831
let node = create_node($vm, stringify!($node_name));
4932
$(
50-
$vm.set_attr(&node, stringify!($attr_name), $attr_value).unwrap();
33+
$vm.set_attr(node.as_object(), stringify!($attr_name), $attr_value).unwrap();
5134
)*
5235
node
5336
}
@@ -57,25 +40,27 @@ macro_rules! node {
5740
}
5841
}
5942

60-
fn program_to_ast(vm: &VirtualMachine, program: &ast::Program) -> PyObjectRef {
43+
fn program_to_ast(vm: &VirtualMachine, program: &ast::Program) -> AstNodeRef {
6144
let py_body = statements_to_ast(vm, &program.statements);
6245
node!(vm, Module, { body => py_body })
6346
}
6447

6548
// Create a node class instance
66-
fn create_node(vm: &VirtualMachine, name: &str) -> PyObjectRef {
67-
PyObject::new(AstNode, vm.class("ast", name), Some(vm.ctx.new_dict()))
49+
fn create_node(vm: &VirtualMachine, name: &str) -> AstNodeRef {
50+
AstNode
51+
.into_ref_with_type(vm, vm.class("ast", name))
52+
.unwrap()
6853
}
6954

7055
fn statements_to_ast(vm: &VirtualMachine, statements: &[ast::LocatedStatement]) -> PyListRef {
7156
let body = statements
7257
.iter()
73-
.map(|statement| statement_to_ast(&vm, statement))
58+
.map(|statement| statement_to_ast(&vm, statement).into_object())
7459
.collect();
7560
vm.ctx.new_list(body).downcast().unwrap()
7661
}
7762

78-
fn statement_to_ast(vm: &VirtualMachine, statement: &ast::LocatedStatement) -> PyObjectRef {
63+
fn statement_to_ast(vm: &VirtualMachine, statement: &ast::LocatedStatement) -> AstNodeRef {
7964
let node = match &statement.node {
8065
ast::Statement::ClassDef {
8166
name,
@@ -95,7 +80,7 @@ fn statement_to_ast(vm: &VirtualMachine, statement: &ast::LocatedStatement) -> P
9580
returns,
9681
} => {
9782
let py_returns = if let Some(hint) = returns {
98-
expression_to_ast(vm, hint)
83+
expression_to_ast(vm, hint).into_object()
9984
} else {
10085
vm.ctx.none()
10186
};
@@ -112,7 +97,7 @@ fn statement_to_ast(vm: &VirtualMachine, statement: &ast::LocatedStatement) -> P
11297
ast::Statement::Pass => node!(vm, Pass),
11398
ast::Statement::Assert { test, msg } => {
11499
let py_msg = match msg {
115-
Some(msg) => expression_to_ast(vm, msg),
100+
Some(msg) => expression_to_ast(vm, msg).into_object(),
116101
None => vm.ctx.none(),
117102
};
118103
node!(vm, Assert, {
@@ -121,15 +106,18 @@ fn statement_to_ast(vm: &VirtualMachine, statement: &ast::LocatedStatement) -> P
121106
})
122107
}
123108
ast::Statement::Delete { targets } => {
124-
let py_targets = vm
125-
.ctx
126-
.new_tuple(targets.iter().map(|v| expression_to_ast(vm, v)).collect());
109+
let py_targets = vm.ctx.new_tuple(
110+
targets
111+
.iter()
112+
.map(|v| expression_to_ast(vm, v).into_object())
113+
.collect(),
114+
);
127115

128116
node!(vm, Delete, { targets => py_targets })
129117
}
130118
ast::Statement::Return { value } => {
131119
let py_value = if let Some(value) = value {
132-
expression_to_ast(vm, value)
120+
expression_to_ast(vm, value).into_object()
133121
} else {
134122
vm.ctx.none()
135123
};
@@ -181,20 +169,20 @@ fn statement_to_ast(vm: &VirtualMachine, statement: &ast::LocatedStatement) -> P
181169

182170
// set lineno on node:
183171
let lineno = vm.ctx.new_int(statement.location.get_row());
184-
vm.set_attr(&node, "lineno", lineno).unwrap();
172+
vm.set_attr(node.as_object(), "lineno", lineno).unwrap();
185173

186174
node
187175
}
188176

189177
fn expressions_to_ast(vm: &VirtualMachine, expressions: &[ast::Expression]) -> PyListRef {
190178
let py_expression_nodes = expressions
191179
.iter()
192-
.map(|expression| expression_to_ast(vm, expression))
180+
.map(|expression| expression_to_ast(vm, expression).into_object())
193181
.collect();
194182
vm.ctx.new_list(py_expression_nodes).downcast().unwrap()
195183
}
196184

197-
fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObjectRef {
185+
fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> AstNodeRef {
198186
let node = match &expression {
199187
ast::Expression::Call { function, args, .. } => node!(vm, Call, {
200188
func => expression_to_ast(vm, function),
@@ -237,8 +225,8 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
237225
}
238226
ast::Expression::BoolOp { a, op, b } => {
239227
// Attach values:
240-
let py_a = expression_to_ast(vm, a);
241-
let py_b = expression_to_ast(vm, b);
228+
let py_a = expression_to_ast(vm, a).into_object();
229+
let py_b = expression_to_ast(vm, b).into_object();
242230
let py_values = vm.ctx.new_tuple(vec![py_a, py_b]);
243231

244232
let str_op = match op {
@@ -277,7 +265,7 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
277265
let py_b = vm.ctx.new_list(
278266
vals.iter()
279267
.skip(1)
280-
.map(|x| expression_to_ast(vm, x))
268+
.map(|x| expression_to_ast(vm, x).into_object())
281269
.collect(),
282270
);
283271
node!(vm, Compare, {
@@ -333,8 +321,8 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
333321
let mut keys = Vec::new();
334322
let mut values = Vec::new();
335323
for (k, v) in elements {
336-
keys.push(expression_to_ast(vm, k));
337-
values.push(expression_to_ast(vm, v));
324+
keys.push(expression_to_ast(vm, k).into_object());
325+
values.push(expression_to_ast(vm, v).into_object());
338326
}
339327

340328
node!(vm, Dict, {
@@ -345,7 +333,7 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
345333
ast::Expression::Comprehension { kind, generators } => {
346334
let g = generators
347335
.iter()
348-
.map(|g| comprehension_to_ast(vm, g))
336+
.map(|g| comprehension_to_ast(vm, g).into_object())
349337
.collect();
350338
let py_generators = vm.ctx.new_list(g);
351339

@@ -366,7 +354,7 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
366354
}
367355
ast::Expression::Yield { value } => {
368356
let py_value = match value {
369-
Some(value) => expression_to_ast(vm, value),
357+
Some(value) => expression_to_ast(vm, value).into_object(),
370358
None => vm.ctx.none(),
371359
};
372360
node!(vm, Yield, {
@@ -401,22 +389,24 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
401389

402390
// TODO: retrieve correct lineno:
403391
let lineno = vm.ctx.new_int(1);
404-
vm.set_attr(&node, "lineno", lineno).unwrap();
405-
392+
vm.set_attr(node.as_object(), "lineno", lineno).unwrap();
406393
node
407394
}
408395

409-
fn parameters_to_ast(vm: &VirtualMachine, args: &ast::Parameters) -> PyObjectRef {
410-
let args = vm
411-
.ctx
412-
.new_list(args.args.iter().map(|a| parameter_to_ast(vm, a)).collect());
396+
fn parameters_to_ast(vm: &VirtualMachine, args: &ast::Parameters) -> AstNodeRef {
397+
let args = vm.ctx.new_list(
398+
args.args
399+
.iter()
400+
.map(|a| parameter_to_ast(vm, a).into_object())
401+
.collect(),
402+
);
413403

414404
node!(vm, arguments, { args => args })
415405
}
416406

417-
fn parameter_to_ast(vm: &VirtualMachine, parameter: &ast::Parameter) -> PyObjectRef {
407+
fn parameter_to_ast(vm: &VirtualMachine, parameter: &ast::Parameter) -> AstNodeRef {
418408
let py_annotation = if let Some(annotation) = &parameter.annotation {
419-
expression_to_ast(vm, annotation)
409+
expression_to_ast(vm, annotation).into_object()
420410
} else {
421411
vm.ctx.none()
422412
};
@@ -427,15 +417,15 @@ fn parameter_to_ast(vm: &VirtualMachine, parameter: &ast::Parameter) -> PyObject
427417
})
428418
}
429419

430-
fn comprehension_to_ast(vm: &VirtualMachine, comprehension: &ast::Comprehension) -> PyObjectRef {
420+
fn comprehension_to_ast(vm: &VirtualMachine, comprehension: &ast::Comprehension) -> AstNodeRef {
431421
node!(vm, comprehension, {
432422
target => expression_to_ast(vm, &comprehension.target),
433423
iter => expression_to_ast(vm, &comprehension.iter),
434424
ifs => expressions_to_ast(vm, &comprehension.ifs),
435425
})
436426
}
437427

438-
fn string_to_ast(vm: &VirtualMachine, string: &ast::StringGroup) -> PyObjectRef {
428+
fn string_to_ast(vm: &VirtualMachine, string: &ast::StringGroup) -> AstNodeRef {
439429
match string {
440430
ast::StringGroup::Constant { value } => {
441431
node!(vm, Str, { s => vm.ctx.new_str(value.clone()) })
@@ -447,23 +437,19 @@ fn string_to_ast(vm: &VirtualMachine, string: &ast::StringGroup) -> PyObjectRef
447437
let py_values = vm.ctx.new_list(
448438
values
449439
.iter()
450-
.map(|value| string_to_ast(vm, value))
440+
.map(|value| string_to_ast(vm, value).into_object())
451441
.collect(),
452442
);
453443
node!(vm, JoinedStr, { values => py_values })
454444
}
455445
}
456446
}
457447

458-
fn ast_parse(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
459-
arg_check!(vm, args, required = [(source, Some(vm.ctx.str_type()))]);
460-
461-
let source_string = objstr::get_value(source);
462-
let internal_ast = parser::parse_program(&source_string)
448+
fn ast_parse(source: PyStringRef, vm: &VirtualMachine) -> PyResult<AstNodeRef> {
449+
let internal_ast = parser::parse_program(&source.value)
463450
.map_err(|err| vm.new_value_error(format!("{}", err)))?;
464451
// source.clone();
465-
let ast_node = program_to_ast(&vm, &internal_ast);
466-
Ok(ast_node)
452+
Ok(program_to_ast(&vm, &internal_ast))
467453
}
468454

469455
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {

0 commit comments

Comments
 (0)