Skip to content

Commit 4517527

Browse files
committed
Support StringIdx in the vm
1 parent fdd5d08 commit 4517527

File tree

12 files changed

+209
-135
lines changed

12 files changed

+209
-135
lines changed

src/shell/helper.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rustpython_vm::obj::objstr::PyStringRef;
2-
use rustpython_vm::pyobject::{PyIterable, PyResult, TryFromObject};
2+
use rustpython_vm::pyobject::{PyIterable, PyResult, StringRef, TryFromObject};
33
use rustpython_vm::scope::{NameProtocol, Scope};
44
use rustpython_vm::VirtualMachine;
55

@@ -77,7 +77,9 @@ impl<'vm> ShellHelper<'vm> {
7777
// last: the last word, could be empty if it ends with a dot
7878
// parents: the words before the dot
7979

80-
let mut current = self.scope.load_global(self.vm, first)?;
80+
let mut current = self
81+
.scope
82+
.load_global(self.vm, &StringRef::new(first.into()))?;
8183

8284
for attr in parents {
8385
current = self.vm.get_attribute(current.clone(), attr.as_str()).ok()?;

vm/src/dictdatatype.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
use crate::obj::objstr::{PyString, PyStringRef};
22
use crate::pyhash;
3-
use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult};
3+
use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult, StringRef};
44
use crate::vm::VirtualMachine;
5-
use num_bigint::ToBigInt;
65
/// Ordered dictionary implementation.
76
/// Inspired by: https://morepypy.blogspot.com/2015/01/faster-more-memory-efficient-and-more.html
87
/// And: https://www.youtube.com/watch?v=p33CVV29OG8
98
/// And: http://code.activestate.com/recipes/578375/
10-
use std::collections::{hash_map::DefaultHasher, HashMap};
11-
use std::hash::{Hash, Hasher};
9+
use std::collections::HashMap;
1210
use std::mem::size_of;
1311
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
1412

@@ -423,10 +421,7 @@ pub trait DictKey {
423421
/// to index dictionaries.
424422
impl DictKey for &PyObjectRef {
425423
fn do_hash(self, vm: &VirtualMachine) -> PyResult<HashValue> {
426-
let raw_hash = vm._hash(self)?;
427-
let mut hasher = DefaultHasher::new();
428-
raw_hash.hash(&mut hasher);
429-
Ok(hasher.finish() as HashValue)
424+
vm._hash(self)
430425
}
431426

432427
fn do_is(self, other: &PyObjectRef) -> bool {
@@ -438,9 +433,14 @@ impl DictKey for &PyObjectRef {
438433
}
439434
}
440435

436+
// for the following DictKey impls, we use pyhash::hash_as_bigint because vm._hash() takes the
437+
// BigInt that obj.__hash__() returns and does `raw_hash % MODULUS` to get an i64.
438+
// pyhash::hash_as_bigint does the same here, to ensure that the string types' do_hash returns the
439+
// same as what PyObjectRef::do_hash() does for an equivalent python str
440+
441441
impl DictKey for &PyStringRef {
442442
fn do_hash(self, _vm: &VirtualMachine) -> PyResult<HashValue> {
443-
Ok(self.hash())
443+
Ok(pyhash::hash_as_bigint(self.hash()))
444444
}
445445

446446
fn do_is(self, other: &PyObjectRef) -> bool {
@@ -458,16 +458,30 @@ impl DictKey for &PyStringRef {
458458
}
459459
}
460460

461+
impl DictKey for &StringRef {
462+
fn do_hash(self, _vm: &VirtualMachine) -> PyResult<HashValue> {
463+
Ok(pyhash::hash_as_bigint(self.hash_value()))
464+
}
465+
466+
fn do_is(self, _other: &PyObjectRef) -> bool {
467+
false
468+
}
469+
470+
fn do_eq(self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult<bool> {
471+
if let Some(py_str_value) = other_key.payload::<PyString>() {
472+
Ok(py_str_value.as_str() == self.as_str())
473+
} else {
474+
vm.bool_eq(vm.new_str(self), other_key.clone())
475+
}
476+
}
477+
}
478+
461479
/// Implement trait for the str type, so that we can use strings
462480
/// to index dictionaries.
463481
impl DictKey for &str {
464482
fn do_hash(self, _vm: &VirtualMachine) -> PyResult<HashValue> {
465483
// follow a similar route as the hashing of PyStringRef
466-
let raw_hash = pyhash::hash_value(&self.to_owned()).to_bigint().unwrap();
467-
let raw_hash = pyhash::hash_bigint(&raw_hash);
468-
let mut hasher = DefaultHasher::new();
469-
raw_hash.hash(&mut hasher);
470-
Ok(hasher.finish() as HashValue)
484+
Ok(pyhash::hash_as_bigint(pyhash::hash_value(self) as HashValue))
471485
}
472486

473487
fn do_is(self, _other: &PyObjectRef) -> bool {

vm/src/frame.rs

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use indexmap::IndexMap;
66
use itertools::Itertools;
77

88
use crate::builtins::builtin_isinstance;
9-
use crate::bytecode;
9+
use crate::bytecode::{self, StringIdx};
1010
use crate::exceptions::{self, ExceptionCtor, PyBaseExceptionRef};
1111
use crate::function::PyFuncArgs;
1212
use crate::obj::objasyncgenerator::PyAsyncGenWrappedValue;
@@ -348,12 +348,12 @@ impl ExecutingFrame<'_> {
348348
bytecode::Instruction::LoadName {
349349
ref name,
350350
ref scope,
351-
} => self.load_name(vm, name, scope),
351+
} => self.load_name(vm, *name, scope),
352352
bytecode::Instruction::StoreName {
353353
ref name,
354354
ref scope,
355-
} => self.store_name(vm, name, scope),
356-
bytecode::Instruction::DeleteName { ref name } => self.delete_name(vm, name),
355+
} => self.store_name(vm, *name, scope),
356+
bytecode::Instruction::DeleteName { ref name } => self.delete_name(vm, *name),
357357
bytecode::Instruction::Subscript => self.execute_subscript(vm),
358358
bytecode::Instruction::StoreSubscript => self.execute_store_subscript(vm),
359359
bytecode::Instruction::DeleteSubscript => self.execute_delete_subscript(vm),
@@ -437,9 +437,9 @@ impl ExecutingFrame<'_> {
437437
bytecode::Instruction::BinaryOperation { ref op, inplace } => {
438438
self.execute_binop(vm, op, *inplace)
439439
}
440-
bytecode::Instruction::LoadAttr { ref name } => self.load_attr(vm, name),
441-
bytecode::Instruction::StoreAttr { ref name } => self.store_attr(vm, name),
442-
bytecode::Instruction::DeleteAttr { ref name } => self.delete_attr(vm, name),
440+
bytecode::Instruction::LoadAttr { ref name } => self.load_attr(vm, *name),
441+
bytecode::Instruction::StoreAttr { ref name } => self.store_attr(vm, *name),
442+
bytecode::Instruction::DeleteAttr { ref name } => self.delete_attr(vm, *name),
443443
bytecode::Instruction::UnaryOperation { ref op } => self.execute_unop(vm, op),
444444
bytecode::Instruction::CompareOperation { ref op } => self.execute_compare(vm, op),
445445
bytecode::Instruction::ReturnValue => {
@@ -786,9 +786,8 @@ impl ExecutingFrame<'_> {
786786
if let Some(dict) = module.dict() {
787787
for (k, v) in &dict {
788788
let k = vm.to_str(&k)?;
789-
let k = k.as_str();
790-
if !k.starts_with('_') {
791-
self.scope.store_name(&vm, k, v);
789+
if !k.as_str().starts_with('_') {
790+
self.scope.store_name(&vm, &k.data(), v);
792791
}
793792
}
794793
}
@@ -859,28 +858,22 @@ impl ExecutingFrame<'_> {
859858
fn store_name(
860859
&mut self,
861860
vm: &VirtualMachine,
862-
name: &str,
861+
name: StringIdx,
863862
name_scope: &bytecode::NameScope,
864863
) -> FrameResult {
865864
let obj = self.pop_value();
865+
let name = self.code.get_string(name);
866866
match name_scope {
867-
bytecode::NameScope::Global => {
868-
self.scope.store_global(vm, name, obj);
869-
}
870-
bytecode::NameScope::NonLocal => {
871-
self.scope.store_cell(vm, name, obj);
872-
}
873-
bytecode::NameScope::Local => {
874-
self.scope.store_name(vm, name, obj);
875-
}
876-
bytecode::NameScope::Free => {
877-
self.scope.store_name(vm, name, obj);
878-
}
867+
bytecode::NameScope::Global => self.scope.store_global(vm, name, obj),
868+
bytecode::NameScope::NonLocal => self.scope.store_cell(vm, name, obj),
869+
bytecode::NameScope::Local => self.scope.store_name(vm, name, obj),
870+
bytecode::NameScope::Free => self.scope.store_name(vm, name, obj),
879871
}
880872
Ok(None)
881873
}
882874

883-
fn delete_name(&self, vm: &VirtualMachine, name: &str) -> FrameResult {
875+
fn delete_name(&self, vm: &VirtualMachine, name: StringIdx) -> FrameResult {
876+
let name = self.code.get_string(name);
884877
match self.scope.delete_name(vm, name) {
885878
Ok(_) => Ok(None),
886879
Err(_) => Err(vm.new_name_error(format!("name '{}' is not defined", name))),
@@ -891,9 +884,10 @@ impl ExecutingFrame<'_> {
891884
fn load_name(
892885
&mut self,
893886
vm: &VirtualMachine,
894-
name: &str,
887+
name: StringIdx,
895888
name_scope: &bytecode::NameScope,
896889
) -> FrameResult {
890+
let name = self.code.get_string(name);
897891
let optional_value = match name_scope {
898892
bytecode::NameScope::Global => self.scope.load_global(vm, name),
899893
bytecode::NameScope::NonLocal => self.scope.load_cell(vm, name),
@@ -1412,23 +1406,25 @@ impl ExecutingFrame<'_> {
14121406
Ok(None)
14131407
}
14141408

1415-
fn load_attr(&mut self, vm: &VirtualMachine, attr_name: &str) -> FrameResult {
1409+
fn load_attr(&mut self, vm: &VirtualMachine, attr_name: StringIdx) -> FrameResult {
1410+
let attr_name = vm.new_str(self.code.get_string(attr_name));
14161411
let parent = self.pop_value();
14171412
let obj = vm.get_attribute(parent, attr_name)?;
14181413
self.push_value(obj);
14191414
Ok(None)
14201415
}
14211416

1422-
fn store_attr(&mut self, vm: &VirtualMachine, attr_name: &str) -> FrameResult {
1417+
fn store_attr(&mut self, vm: &VirtualMachine, attr_name: StringIdx) -> FrameResult {
1418+
let attr_name = vm.new_str(self.code.get_string(attr_name));
14231419
let parent = self.pop_value();
14241420
let value = self.pop_value();
1425-
vm.set_attr(&parent, vm.new_str(attr_name.to_owned()), value)?;
1421+
vm.set_attr(&parent, attr_name, value)?;
14261422
Ok(None)
14271423
}
14281424

1429-
fn delete_attr(&mut self, vm: &VirtualMachine, attr_name: &str) -> FrameResult {
1425+
fn delete_attr(&mut self, vm: &VirtualMachine, attr_name: StringIdx) -> FrameResult {
14301426
let parent = self.pop_value();
1431-
let name = vm.ctx.new_str(attr_name.to_owned());
1427+
let name = vm.new_str(self.code.get_string(attr_name));
14321428
vm.del_attr(&parent, name)?;
14331429
Ok(None)
14341430
}

vm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,6 @@ pub use rustpython_bytecode::*;
8888
#[doc(hidden)]
8989
pub mod __exports {
9090
pub use maplit::hashmap;
91+
pub use once_cell::sync::Lazy;
9192
pub use smallbox::smallbox;
9293
}

vm/src/macros.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,14 @@ macro_rules! named_function {
342342
}
343343
}};
344344
}
345+
346+
#[macro_export]
347+
macro_rules! strdata {
348+
($s:literal) => {{
349+
use $crate::__exports::Lazy;
350+
static DATA: Lazy<$crate::pyobject::StringRef> =
351+
Lazy::new(|| $crate::pyobject::StringRef::new($s.into()));
352+
let data: &$crate::pyobject::StringRef = &DATA;
353+
data
354+
}};
355+
}

vm/src/obj/objfunction.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,18 @@ impl PyFunction {
8181
// See also: PyEval_EvalCodeWithName in cpython:
8282
// https://github.com/python/cpython/blob/master/Python/ceval.c#L3681
8383

84-
let n = if nargs > nexpected_args {
85-
nexpected_args
86-
} else {
87-
nargs
88-
};
84+
let n = std::cmp::min(nargs, nexpected_args);
8985

9086
// Copy positional arguments into local variables
9187
for i in 0..n {
92-
let arg_name = &code_object.arg_names[i];
88+
let arg_name = code_object.get_string(code_object.arg_names[i]);
9389
let arg = &func_args.args[i];
9490
locals.set_item(arg_name, arg.clone(), vm)?;
9591
}
9692

9793
// Pack other positional arguments in to *args:
98-
if let Some(ref vararg_name) = code_object.varargs_name {
94+
if let Some(vararg_name) = code_object.varargs_name {
95+
let vararg_name = code_object.get_string(vararg_name);
9996
let mut last_args = vec![];
10097
for i in n..nargs {
10198
let arg = &func_args.args[i];
@@ -120,7 +117,8 @@ impl PyFunction {
120117
.contains(bytecode::CodeFlags::HAS_VARKEYWORDS)
121118
{
122119
let d = vm.ctx.new_dict();
123-
if let Some(ref kwargs_name) = code_object.varkeywords_name {
120+
if let Some(kwargs_name) = code_object.varkeywords_name {
121+
let kwargs_name = code_object.get_string(kwargs_name);
124122
locals.set_item(kwargs_name, d.as_object().clone(), vm)?;
125123
}
126124
Some(d)
@@ -132,9 +130,14 @@ impl PyFunction {
132130
// Handle keyword arguments
133131
for (name, value) in func_args.kwargs {
134132
// Check if we have a parameter with this name:
135-
if code_object.arg_names.contains(&name) || code_object.kwonlyarg_names.contains(&name)
133+
let contains_str = |v: &[crate::bytecode::StringIdx], s| {
134+
v.iter()
135+
.any(|&idx| code_object.get_string(idx).as_str() == s)
136+
};
137+
if contains_str(&code_object.arg_names, &name)
138+
|| contains_str(&code_object.kwonlyarg_names, &name)
136139
{
137-
if posonly_args.contains(&name) {
140+
if contains_str(&posonly_args, &name) {
138141
posonly_passed_as_kwarg.push(name);
139142
continue;
140143
} else if locals.contains_key(&name, vm) {
@@ -170,24 +173,24 @@ impl PyFunction {
170173
let required_args = nexpected_args - num_defaults_available;
171174
let mut missing = vec![];
172175
for i in 0..required_args {
173-
let variable_name = &code_object.arg_names[i];
176+
let variable_name = code_object.get_string(code_object.arg_names[i]);
174177
if !locals.contains_key(variable_name, vm) {
175178
missing.push(variable_name)
176179
}
177180
}
178181
if !missing.is_empty() {
179182
return Err(vm.new_type_error(format!(
180-
"Missing {} required positional arguments: {:?}",
183+
"Missing {} required positional arguments: {}",
181184
missing.len(),
182-
missing
185+
missing.iter().format(", "),
183186
)));
184187
}
185188
if let Some(defaults) = &self.defaults {
186189
let defaults = defaults.as_slice();
187190
// We have sufficient defaults, so iterate over the corresponding names and use
188191
// the default if we don't already have a value
189192
for (default_index, i) in (required_args..nexpected_args).enumerate() {
190-
let arg_name = &code_object.arg_names[i];
193+
let arg_name = code_object.get_string(code_object.arg_names[i]);
191194
if !locals.contains_key(arg_name, vm) {
192195
locals.set_item(arg_name, defaults[default_index].clone(), vm)?;
193196
}
@@ -197,6 +200,7 @@ impl PyFunction {
197200

198201
// Check if kw only arguments are all present:
199202
for arg_name in &code_object.kwonlyarg_names {
203+
let arg_name = code_object.get_string(*arg_name);
200204
if !locals.contains_key(arg_name, vm) {
201205
if let Some(kw_only_defaults) = &self.kw_only_defaults {
202206
if let Some(default) = kw_only_defaults.get_item_option(arg_name, vm)? {

0 commit comments

Comments
 (0)