Skip to content

Commit ae45a4a

Browse files
committed
Fix class vs method scopes
1 parent c973ed8 commit ae45a4a

File tree

7 files changed

+81
-23
lines changed

7 files changed

+81
-23
lines changed

bytecode/src/bytecode.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ bitflags! {
5353
const HAS_DEFAULTS = 0x01;
5454
const HAS_KW_ONLY_DEFAULTS = 0x02;
5555
const HAS_ANNOTATIONS = 0x04;
56+
const IS_CLASS = 0x08;
5657
}
5758
}
5859

compiler/src/compile.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ impl<O: OutputStream> Compiler<O> {
975975

976976
// Turn code object into function object:
977977
self.emit(Instruction::MakeFunction {
978-
flags: bytecode::FunctionOpArg::empty(),
978+
flags: bytecode::FunctionOpArg::IS_CLASS,
979979
});
980980

981981
self.emit(Instruction::LoadConst {

tests/snippets/class.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,22 @@ class T5(int):
164164
assert str(super(int, T5(5))) == "<super: <class 'int'>, <T5 object>>"
165165

166166
#assert str(super(type, None)) == "<super: <class 'type'>, NULL>"
167+
168+
a = 1
169+
class A:
170+
a = 2
171+
def b():
172+
assert a == 1
173+
b()
174+
assert a == 2
175+
A.b()
176+
177+
def func():
178+
class A:
179+
a = 2
180+
def b():
181+
assert a == 1
182+
b()
183+
assert a == 2
184+
A.b()
185+
func()

vm/src/builtins.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ use std::str;
1010
use num_bigint::Sign;
1111
use num_traits::{Signed, ToPrimitive, Zero};
1212

13+
use crate::frame::Frame;
1314
use crate::obj::objbool;
1415
use crate::obj::objbytes::PyBytesRef;
1516
use crate::obj::objcode::PyCodeRef;
1617
use crate::obj::objdict::PyDictRef;
18+
use crate::obj::objfunction::{PyFunction, PyFunctionRef};
1719
use crate::obj::objint::{self, PyIntRef};
1820
use crate::obj::objiter;
1921
use crate::obj::objstr::{PyString, PyStringRef};
@@ -896,7 +898,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
896898
}
897899

898900
pub fn builtin_build_class_(
899-
function: PyObjectRef,
901+
function: PyFunctionRef,
900902
qualified_name: PyStringRef,
901903
bases: Args<PyClassRef>,
902904
mut kwargs: KwArgs,
@@ -933,7 +935,20 @@ pub fn builtin_build_class_(
933935

934936
let cells = vm.ctx.new_dict();
935937

936-
vm.invoke_with_locals(&function, cells.clone(), namespace.clone())?;
938+
let PyFunction { code, scope, .. } = &*function;
939+
940+
let is_class = scope.is_class();
941+
942+
let mut scope = scope
943+
.new_child_scope_with_locals(cells.clone())
944+
.new_child_scope_with_locals(namespace.clone());
945+
946+
if is_class {
947+
scope = scope.as_class();
948+
}
949+
950+
let frame = Frame::new(code.clone(), scope).into_ref(vm);
951+
vm.run_frame_full(frame)?;
937952

938953
namespace.set_item("__name__", name_obj.clone(), vm)?;
939954
namespace.set_item("__qualname__", qualified_name.into_object(), vm)?;

vm/src/frame.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,16 @@ impl Frame {
10641064
// pop argc arguments
10651065
// argument: name, args, globals
10661066
let scope = self.scope.clone();
1067+
let scope = if flags.contains(bytecode::FunctionOpArg::IS_CLASS) {
1068+
// if the function we're making is a class initializer
1069+
scope.new_child_scope(&vm.ctx).as_class()
1070+
} else if scope.is_class() {
1071+
// if the surrounding scope is a class, i.e. the function we're making is a method,
1072+
// then get the parent scope. See builtin_build_class for why.
1073+
scope.parent_scope()
1074+
} else {
1075+
scope
1076+
};
10671077
let func_obj = vm
10681078
.ctx
10691079
.new_function(code_obj, scope, defaults, kw_only_defaults);

vm/src/scope.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::vm::VirtualMachine;
1212
pub struct Scope {
1313
locals: Vec<PyDictRef>,
1414
pub globals: PyDictRef,
15+
is_class: bool,
1516
}
1617

1718
impl fmt::Debug for Scope {
@@ -27,7 +28,11 @@ impl Scope {
2728
Some(dict) => vec![dict],
2829
None => vec![],
2930
};
30-
let scope = Scope { locals, globals };
31+
let scope = Scope {
32+
locals,
33+
globals,
34+
is_class: false,
35+
};
3136
scope.store_name(vm, "__annotations__", vm.ctx.new_dict().into_object());
3237
scope
3338
}
@@ -64,12 +69,30 @@ impl Scope {
6469
Scope {
6570
locals: new_locals,
6671
globals: self.globals.clone(),
72+
is_class: false,
6773
}
6874
}
6975

7076
pub fn new_child_scope(&self, ctx: &PyContext) -> Scope {
7177
self.new_child_scope_with_locals(ctx.new_dict())
7278
}
79+
80+
pub fn parent_scope(&self) -> Scope {
81+
Scope {
82+
locals: self.locals[1..].to_vec(),
83+
globals: self.globals.clone(),
84+
is_class: false,
85+
}
86+
}
87+
88+
pub fn is_class(&self) -> bool {
89+
self.is_class
90+
}
91+
92+
pub fn as_class(mut self) -> Self {
93+
self.is_class = true;
94+
self
95+
}
7396
}
7497

7598
pub trait NameProtocol {
@@ -131,6 +154,15 @@ impl NameProtocol for Scope {
131154
#[cfg_attr(feature = "flame-it", flame("Scope"))]
132155
/// Load a global name.
133156
fn load_global(&self, vm: &VirtualMachine, name: &str) -> Option<PyObjectRef> {
157+
// First, take a look in the outmost local scope (the scope at top level)
158+
let last_local_dict = self.locals.iter().last();
159+
if let Some(local_dict) = last_local_dict {
160+
if let Some(value) = local_dict.get_item_option(name, vm).unwrap() {
161+
return Some(value);
162+
}
163+
}
164+
165+
// Now, take a look at the globals or builtins.
134166
if let Some(value) = self.globals.get_item_option(name, vm).unwrap() {
135167
Some(value)
136168
} else {

vm/src/vm.rs

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -662,25 +662,6 @@ impl VirtualMachine {
662662
}
663663
}
664664

665-
pub fn invoke_with_locals(
666-
&self,
667-
function: &PyObjectRef,
668-
cells: PyDictRef,
669-
locals: PyDictRef,
670-
) -> PyResult {
671-
if let Some(PyFunction { code, scope, .. }) = &function.payload() {
672-
let scope = scope
673-
.new_child_scope_with_locals(cells)
674-
.new_child_scope_with_locals(locals);
675-
let frame = Frame::new(code.clone(), scope).into_ref(self);
676-
return self.run_frame_full(frame);
677-
}
678-
panic!(
679-
"invoke_with_locals: expected python function, got: {:?}",
680-
*function
681-
);
682-
}
683-
684665
fn fill_locals_from_args(
685666
&self,
686667
code_object: &bytecode::CodeObject,

0 commit comments

Comments
 (0)