Skip to content

Commit b797b51

Browse files
committed
Fix class scopes by modifying symboltable
1 parent ae45a4a commit b797b51

File tree

8 files changed

+41
-79
lines changed

8 files changed

+41
-79
lines changed

bytecode/src/bytecode.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ bitflags! {
5353
const HAS_DEFAULTS = 0x01;
5454
const HAS_KW_ONLY_DEFAULTS = 0x02;
5555
const HAS_ANNOTATIONS = 0x04;
56-
const IS_CLASS = 0x08;
5756
}
5857
}
5958

compiler/src/compile.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ impl<O: OutputStream> Compiler<O> {
975975

976976
// Turn code object into function object:
977977
self.emit(Instruction::MakeFunction {
978-
flags: bytecode::FunctionOpArg::IS_CLASS,
978+
flags: bytecode::FunctionOpArg::empty(),
979979
});
980980

981981
self.emit(Instruction::LoadConst {

compiler/src/symboltable.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl SymbolTable {
6161
}
6262
}
6363

64-
#[derive(Clone)]
64+
#[derive(Clone, PartialEq)]
6565
pub enum SymbolTableType {
6666
Module,
6767
Class,
@@ -241,12 +241,10 @@ impl SymbolTableAnalyzer {
241241
} else {
242242
// Interesting stuff about the __class__ variable:
243243
// https://docs.python.org/3/reference/datamodel.html?highlight=__class__#creating-the-class-object
244-
let found_in_outer_scope = (symbol.name == "__class__")
245-
|| self
246-
.tables
247-
.iter()
248-
.skip(1)
249-
.any(|t| t.symbols.contains_key(&symbol.name));
244+
let found_in_outer_scope = symbol.name == "__class__"
245+
|| self.tables.iter().skip(1).any(|t| {
246+
t.typ != SymbolTableType::Class && t.symbols.contains_key(&symbol.name)
247+
});
250248

251249
if found_in_outer_scope {
252250
// Symbol is in some outer scope.
@@ -387,7 +385,6 @@ impl SymbolTableBuilder {
387385
keywords,
388386
decorator_list,
389387
} => {
390-
self.register_name(name, SymbolUsage::Assigned)?;
391388
self.enter_scope(name, SymbolTableType::Class, statement.location.row());
392389
self.scan_statements(body)?;
393390
self.leave_scope();
@@ -396,6 +393,7 @@ impl SymbolTableBuilder {
396393
self.scan_expression(&keyword.value, &ExpressionContext::Load)?;
397394
}
398395
self.scan_expressions(decorator_list, &ExpressionContext::Load)?;
396+
self.register_name(name, SymbolUsage::Assigned)?;
399397
}
400398
Expression { expression } => {
401399
self.scan_expression(expression, &ExpressionContext::Load)?

tests/snippets/class.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,15 @@ def b():
174174
assert a == 2
175175
A.b()
176176

177-
def func():
178-
class A:
179-
a = 2
180-
def b():
181-
assert a == 1
182-
b()
183-
assert a == 2
184-
A.b()
185-
func()
177+
# TODO: uncomment once free vars/cells are working
178+
# a = 1
179+
# def nested_scope():
180+
# a = 2
181+
# class A:
182+
# a = 3
183+
# def b():
184+
# assert a == 2
185+
# b()
186+
# assert a == 3
187+
# A.b()
188+
# nested_scope()

vm/src/builtins.rs

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@ use std::str;
1010
use num_bigint::Sign;
1111
use num_traits::{Signed, ToPrimitive, Zero};
1212

13-
use crate::frame::Frame;
1413
use crate::obj::objbool;
1514
use crate::obj::objbytes::PyBytesRef;
1615
use crate::obj::objcode::PyCodeRef;
1716
use crate::obj::objdict::PyDictRef;
18-
use crate::obj::objfunction::{PyFunction, PyFunctionRef};
1917
use crate::obj::objint::{self, PyIntRef};
2018
use crate::obj::objiter;
2119
use crate::obj::objstr::{PyString, PyStringRef};
@@ -898,7 +896,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
898896
}
899897

900898
pub fn builtin_build_class_(
901-
function: PyFunctionRef,
899+
function: PyObjectRef,
902900
qualified_name: PyStringRef,
903901
bases: Args<PyClassRef>,
904902
mut kwargs: KwArgs,
@@ -935,20 +933,7 @@ pub fn builtin_build_class_(
935933

936934
let cells = vm.ctx.new_dict();
937935

938-
let PyFunction { code, scope, .. } = &*function;
939-
940-
let is_class = scope.is_class();
941-
942-
let mut scope = scope
943-
.new_child_scope_with_locals(cells.clone())
944-
.new_child_scope_with_locals(namespace.clone());
945-
946-
if is_class {
947-
scope = scope.as_class();
948-
}
949-
950-
let frame = Frame::new(code.clone(), scope).into_ref(vm);
951-
vm.run_frame_full(frame)?;
936+
vm.invoke_with_locals(&function, cells.clone(), namespace.clone())?;
952937

953938
namespace.set_item("__name__", name_obj.clone(), vm)?;
954939
namespace.set_item("__qualname__", qualified_name.into_object(), vm)?;

vm/src/frame.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,16 +1064,6 @@ impl Frame {
10641064
// pop argc arguments
10651065
// argument: name, args, globals
10661066
let scope = self.scope.clone();
1067-
let scope = if flags.contains(bytecode::FunctionOpArg::IS_CLASS) {
1068-
// if the function we're making is a class initializer
1069-
scope.new_child_scope(&vm.ctx).as_class()
1070-
} else if scope.is_class() {
1071-
// if the surrounding scope is a class, i.e. the function we're making is a method,
1072-
// then get the parent scope. See builtin_build_class for why.
1073-
scope.parent_scope()
1074-
} else {
1075-
scope
1076-
};
10771067
let func_obj = vm
10781068
.ctx
10791069
.new_function(code_obj, scope, defaults, kw_only_defaults);

vm/src/scope.rs

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use crate::vm::VirtualMachine;
1212
pub struct Scope {
1313
locals: Vec<PyDictRef>,
1414
pub globals: PyDictRef,
15-
is_class: bool,
1615
}
1716

1817
impl fmt::Debug for Scope {
@@ -28,11 +27,7 @@ impl Scope {
2827
Some(dict) => vec![dict],
2928
None => vec![],
3029
};
31-
let scope = Scope {
32-
locals,
33-
globals,
34-
is_class: false,
35-
};
30+
let scope = Scope { locals, globals };
3631
scope.store_name(vm, "__annotations__", vm.ctx.new_dict().into_object());
3732
scope
3833
}
@@ -69,30 +64,12 @@ impl Scope {
6964
Scope {
7065
locals: new_locals,
7166
globals: self.globals.clone(),
72-
is_class: false,
7367
}
7468
}
7569

7670
pub fn new_child_scope(&self, ctx: &PyContext) -> Scope {
7771
self.new_child_scope_with_locals(ctx.new_dict())
7872
}
79-
80-
pub fn parent_scope(&self) -> Scope {
81-
Scope {
82-
locals: self.locals[1..].to_vec(),
83-
globals: self.globals.clone(),
84-
is_class: false,
85-
}
86-
}
87-
88-
pub fn is_class(&self) -> bool {
89-
self.is_class
90-
}
91-
92-
pub fn as_class(mut self) -> Self {
93-
self.is_class = true;
94-
self
95-
}
9673
}
9774

9875
pub trait NameProtocol {
@@ -154,15 +131,6 @@ impl NameProtocol for Scope {
154131
#[cfg_attr(feature = "flame-it", flame("Scope"))]
155132
/// Load a global name.
156133
fn load_global(&self, vm: &VirtualMachine, name: &str) -> Option<PyObjectRef> {
157-
// First, take a look in the outmost local scope (the scope at top level)
158-
let last_local_dict = self.locals.iter().last();
159-
if let Some(local_dict) = last_local_dict {
160-
if let Some(value) = local_dict.get_item_option(name, vm).unwrap() {
161-
return Some(value);
162-
}
163-
}
164-
165-
// Now, take a look at the globals or builtins.
166134
if let Some(value) = self.globals.get_item_option(name, vm).unwrap() {
167135
Some(value)
168136
} else {

vm/src/vm.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,25 @@ impl VirtualMachine {
662662
}
663663
}
664664

665+
pub fn invoke_with_locals(
666+
&self,
667+
function: &PyObjectRef,
668+
cells: PyDictRef,
669+
locals: PyDictRef,
670+
) -> PyResult {
671+
if let Some(PyFunction { code, scope, .. }) = &function.payload() {
672+
let scope = scope
673+
.new_child_scope_with_locals(cells)
674+
.new_child_scope_with_locals(locals);
675+
let frame = Frame::new(code.clone(), scope).into_ref(self);
676+
return self.run_frame_full(frame);
677+
}
678+
panic!(
679+
"invoke_with_locals: expected python function, got: {:?}",
680+
*function
681+
);
682+
}
683+
665684
fn fill_locals_from_args(
666685
&self,
667686
code_object: &bytecode::CodeObject,

0 commit comments

Comments
 (0)