diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cd12c620..f19601c9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -129,8 +129,8 @@ jobs: - name: Run compile tests (macos lua54) if: ${{ matrix.os == 'macos-latest' && matrix.lua == 'lua54' }} run: | - TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored" -- --ignored - TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored,async,send,serialize,macros" -- --ignored + TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored" --tests -- --ignored + TRYBUILD=overwrite cargo test --features "${{ matrix.lua }},vendored,async,send,serialize,macros" --tests -- --ignored shell: bash test_with_sanitizer: diff --git a/CHANGELOG.md b/CHANGELOG.md index a8230bda..6d89e661 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,29 @@ -## v0.10.2 (Jan 27th, 2025) +## v0.10.5 (May 24th, 2025) + +- mlua-sys is back to 0.6.x (Luau 0.663) +- Reverted: Trigger abort when Luau userdata destructors are panic (requires new mlua-sys) +- Reverted: Added large (52bit) integers support for Luau (breaking change) + +## v0.10.4 (May 5th, 2025) + +_yanked_ because of semver-breaking changes + +- Luau updated to 0.672 +- New serde option `encode_empty_tables_as_array` to serialize empty tables as arrays +- Added `WeakLua` and `Lua::weak()` to create weak references to Lua state +- Trigger abort when Luau userdata destructors are panic (Luau GC does not support it) +- Added `AnyUserData::type_id()` method to get the type id of the userdata +- Added `Chunk::name()`, `Chunk::environment()` and `Chunk::mode()` functions +- Support borrowing underlying wrapped types for `UserDataRef` and `UserDataRefMut` (under `userdata-wrappers` feature) +- Added large (52bit) integers support for Luau +- Enable `serde` for `bstr` if `serialize` feature flag is enabled +- Recursive warnings (Lua 5.4) are no longer allowed +- Implemented `IntoLua`/`FromLua` for `BorrowedString` and `BorrowedBytes` +- Implemented `IntoLua`/`FromLua` for `char` +- Enable `Thread::reset()` for all Lua versions (limited support for 5.1-5.3) +- Bugfixes and improvements + +## v0.10.3 (Jan 27th, 2025) - Set `Default` for `Value` to be `Nil` - Allow exhaustive match on `Value` (#502) @@ -304,7 +329,7 @@ Other: ## v0.8.0 Changes since 0.7.4 -- Roblox Luau support +- Luau support - Removed C glue - Added async support to `__index` and `__newindex` metamethods - Added `Function::info()` to get information about functions (#149). @@ -354,7 +379,7 @@ Breaking changes: ## v0.8.0-beta.1 -- Roblox Luau support +- Luau support - Refactored ffi module. C glue is no longer required - Added async support to `__index` and `__newindex` metamethods @@ -467,7 +492,7 @@ Breaking changes: - [**Breaking**] Removed `AnyUserData::has_metamethod()` - Added `Thread::reset()` for luajit/lua54 to recycle threads. It's possible to attach a new function to a thread (coroutine). -- Added `chunk!` macro support to load chunks of Lua code using the Rust tokenizer and optinally capturing Rust variables. +- Added `chunk!` macro support to load chunks of Lua code using the Rust tokenizer and optionally capturing Rust variables. - Improved error reporting (`Error`'s `__tostring` method formats full stacktraces). This is useful in the module mode. ## v0.6.0-beta.1 @@ -523,7 +548,7 @@ Breaking changes: - Lua 5.4 support with `MetaMethod::Close`. - `lua53` feature is disabled by default. Now preferred Lua version have to be chosen explicitly. -- Provide safety guaraness for Lua state, which means that potenially unsafe operations, like loading C modules (using `require` or `package.loadlib`) are disabled. Equalient for the previous `Lua::new()` function is `Lua::unsafe_new()`. +- Provide safety guarantees for Lua state, which means that potentially unsafe operations, like loading C modules (using `require` or `package.loadlib`) are disabled. Equivalent to the previous `Lua::new()` function is `Lua::unsafe_new()`. - New `send` feature to require `Send`. - New `module` feature, that disables linking to Lua Core Libraries. Required for modules. - Don't allow `'callback` outlive `'lua` in `Lua::create_function()` to fix [the unsoundness](tests/compile/static_callback_args.rs). diff --git a/Cargo.toml b/Cargo.toml index efbb8d5a..8389b685 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,17 +1,17 @@ [package] name = "mlua" -version = "0.10.3" # remember to update mlua_derive +version = "0.10.5" # remember to update mlua_derive authors = ["Aleksandr Orlenko ", "kyren "] rust-version = "1.79.0" edition = "2021" -repository = "https://github.com/khvzak/mlua" +repository = "https://github.com/mlua-rs/mlua" documentation = "https://docs.rs/mlua" readme = "README.md" keywords = ["lua", "luajit", "luau", "async", "scripting"] categories = ["api-bindings", "asynchronous"] license = "MIT" description = """ -High level bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau +High level bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Luau with async/await features and support of writing native Lua modules in Rust. """ @@ -36,11 +36,11 @@ luau = ["ffi/luau", "dep:libloading"] luau-jit = ["luau", "ffi/luau-codegen"] luau-vector4 = ["luau", "ffi/luau-vector4"] vendored = ["ffi/vendored"] -module = ["dep:mlua_derive", "ffi/module"] +module = ["mlua_derive", "ffi/module"] async = ["dep:futures-util"] send = ["parking_lot/send_guard", "error-send"] error-send = [] -serialize = ["dep:serde", "dep:erased-serde", "dep:serde-value"] +serialize = ["dep:serde", "dep:erased-serde", "dep:serde-value", "bstr/serde"] macros = ["mlua_derive/macros"] anyhow = ["dep:anyhow", "error-send"] userdata-wrappers = [] @@ -57,8 +57,9 @@ erased-serde = { version = "0.4", optional = true } serde-value = { version = "0.7", optional = true } parking_lot = { version = "0.12", features = ["arc_lock"] } anyhow = { version = "1.0", optional = true } +rustversion = "1.0" -ffi = { package = "mlua-sys", version = "0.6.6", path = "mlua-sys" } +ffi = { package = "mlua-sys", version = "0.6.8", path = "mlua-sys" } [target.'cfg(unix)'.dependencies] libloading = { version = "0.8", optional = true } @@ -78,7 +79,7 @@ static_assertions = "1.0" [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } -rustyline = "14.0" +rustyline = "15.0" tokio = { version = "1.0", features = ["full"] } [lints.rust] diff --git a/README.md b/README.md index 5f14a829..3bf2b771 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # mlua [![Build Status]][github-actions] [![Latest Version]][crates.io] [![API Documentation]][docs.rs] [![Coverage Status]][codecov.io] ![MSRV] -[Build Status]: https://github.com/khvzak/mlua/workflows/CI/badge.svg -[github-actions]: https://github.com/khvzak/mlua/actions +[Build Status]: https://github.com/mlua-rs/mlua/workflows/CI/badge.svg +[github-actions]: https://github.com/mlua-rs/mlua/actions [Latest Version]: https://img.shields.io/crates/v/mlua.svg [crates.io]: https://crates.io/crates/mlua [API Documentation]: https://docs.rs/mlua/badge.svg @@ -19,19 +19,19 @@ > **Note** > -> See v0.10 [release notes](https://github.com/khvzak/mlua/blob/main/docs/release_notes/v0.10.md). +> See v0.10 [release notes](https://github.com/mlua-rs/mlua/blob/main/docs/release_notes/v0.10.md). `mlua` is bindings to [Lua](https://www.lua.org) programming language for Rust with a goal to provide _safe_ (as far as it's possible), high level, easy to use, practical and flexible API. -Started as `rlua` fork, `mlua` supports Lua 5.4, 5.3, 5.2, 5.1 (including LuaJIT) and [Roblox Luau] and allows to write native Lua modules in Rust as well as use Lua in a standalone mode. +Started as `rlua` fork, `mlua` supports Lua 5.4, 5.3, 5.2, 5.1 (including LuaJIT) and [Luau] and allows to write native Lua modules in Rust as well as use Lua in a standalone mode. `mlua` tested on Windows/macOS/Linux including module mode in [GitHub Actions] on `x86_64` platform and cross-compilation to `aarch64` (other targets are also supported). -WebAssembly (WASM) is supported through `wasm32-unknown-emscripten` target for all Lua versions excluding JIT. +WebAssembly (WASM) is supported through `wasm32-unknown-emscripten` target for all Lua/Luau versions excluding JIT. -[GitHub Actions]: https://github.com/khvzak/mlua/actions -[Roblox Luau]: https://luau.org +[GitHub Actions]: https://github.com/mlua-rs/mlua/actions +[Luau]: https://luau.org ## Usage @@ -64,9 +64,9 @@ Below is a list of the available feature flags. By default `mlua` does not enabl [5.2]: https://www.lua.org/manual/5.2/manual.html [5.1]: https://www.lua.org/manual/5.1/manual.html [LuaJIT]: https://luajit.org/ -[Luau]: https://github.com/Roblox/luau -[lua-src]: https://github.com/khvzak/lua-src-rs -[luajit-src]: https://github.com/khvzak/luajit-src-rs +[Luau]: https://github.com/luau-lang/luau +[lua-src]: https://github.com/mlua-rs/lua-src-rs +[luajit-src]: https://github.com/mlua-rs/luajit-src-rs [tokio]: https://github.com/tokio-rs/tokio [async-std]: https://github.com/async-rs/async-std [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html diff --git a/docs/release_notes/v0.10.md b/docs/release_notes/v0.10.md index c2e09008..db01e8b2 100644 --- a/docs/release_notes/v0.10.md +++ b/docs/release_notes/v0.10.md @@ -3,7 +3,7 @@ The v0.10 version of mlua has goal to improve the user experience while keeping the same performance and safety guarantees. This document highlights the most notable features. For a full list of changes, see the [CHANGELOG]. -[CHANGELOG]: https://github.com/khvzak/mlua/blob/main/CHANGELOG.md +[CHANGELOG]: https://github.com/mlua-rs/mlua/blob/main/CHANGELOG.md ### New features diff --git a/docs/release_notes/v0.9.md b/docs/release_notes/v0.9.md index 33c6a2a5..cbc2d29b 100644 --- a/docs/release_notes/v0.9.md +++ b/docs/release_notes/v0.9.md @@ -3,7 +3,7 @@ The v0.9 version of mlua is a major release that includes a number of API changes and improvements. This release is a stepping stone towards the v1.0. This document highlights the most important changes. For a full list of changes, see the [CHANGELOG]. -[CHANGELOG]: https://github.com/khvzak/mlua/blob/main/CHANGELOG.md +[CHANGELOG]: https://github.com/mlua-rs/mlua/blob/main/CHANGELOG.md ### New features @@ -304,7 +304,7 @@ assert_eq!(f.call::<_, mlua::String>(())?, "hello"); The new mlua version has a number of performance improvements. Please check the [benchmarks results] to see how mlua compares to rlua and rhai. -[benchmarks results]: https://github.com/khvzak/script-bench-rs +[benchmarks results]: https://github.com/mlua-rs/script-bench-rs ### Changes in `module` mode diff --git a/mlua-sys/Cargo.toml b/mlua-sys/Cargo.toml index f4171a15..9eae26db 100644 --- a/mlua-sys/Cargo.toml +++ b/mlua-sys/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "mlua-sys" -version = "0.6.7" +version = "0.6.8" authors = ["Aleksandr Orlenko "] rust-version = "1.71" edition = "2021" -repository = "https://github.com/khvzak/mlua" +repository = "https://github.com/mlua-rs/mlua" documentation = "https://docs.rs/mlua-sys" readme = "README.md" categories = ["external-ffi-bindings"] @@ -12,7 +12,7 @@ license = "MIT" links = "lua" build = "build/main.rs" description = """ -Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau +Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Luau """ [package.metadata.docs.rs] diff --git a/mlua-sys/README.md b/mlua-sys/README.md index d0de6252..927ebbd6 100644 --- a/mlua-sys/README.md +++ b/mlua-sys/README.md @@ -1,8 +1,8 @@ # mlua-sys -Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox [Luau]. +Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and [Luau]. Intended to be consumed by the [mlua] crate. -[Luau]: https://github.com/Roblox/luau +[Luau]: https://github.com/luau-lang/luau [mlua]: https://crates.io/crates/mlua diff --git a/mlua-sys/src/lib.rs b/mlua-sys/src/lib.rs index 629dfd88..6bbb595c 100644 --- a/mlua-sys/src/lib.rs +++ b/mlua-sys/src/lib.rs @@ -1,4 +1,4 @@ -//! Low level bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau. +//! Low level bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Luau. #![allow(non_camel_case_types, non_snake_case, dead_code)] #![allow(clippy::missing_safety_doc)] diff --git a/mlua-sys/src/lua51/compat.rs b/mlua-sys/src/lua51/compat.rs index 2294cd4f..027077f0 100644 --- a/mlua-sys/src/lua51/compat.rs +++ b/mlua-sys/src/lua51/compat.rs @@ -548,7 +548,7 @@ pub unsafe fn luaL_getsubtable(L: *mut lua_State, idx: c_int, fname: *const c_ch pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) { luaL_checkstack(L, 3, cstr!("not enough stack slots available")); - luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED")); + luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE); if lua_getfield(L, -1, modname) == LUA_TNIL { lua_pop(L, 1); lua_pushcfunction(L, openf); diff --git a/mlua-sys/src/lua51/lauxlib.rs b/mlua-sys/src/lua51/lauxlib.rs index 54238858..78cef292 100644 --- a/mlua-sys/src/lua51/lauxlib.rs +++ b/mlua-sys/src/lua51/lauxlib.rs @@ -8,6 +8,9 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State}; // Extra error code for 'luaL_load' pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; +// Key, in the registry, for table of loaded modules +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); + #[repr(C)] pub struct luaL_Reg { pub name: *const c_char, diff --git a/mlua-sys/src/lua52/compat.rs b/mlua-sys/src/lua52/compat.rs index d6a79ac6..29cd888e 100644 --- a/mlua-sys/src/lua52/compat.rs +++ b/mlua-sys/src/lua52/compat.rs @@ -232,7 +232,7 @@ pub unsafe fn luaL_tolstring(L: *mut lua_State, mut idx: c_int, len: *mut usize) pub unsafe fn luaL_requiref(L: *mut lua_State, modname: *const c_char, openf: lua_CFunction, glb: c_int) { luaL_checkstack(L, 3, cstr!("not enough stack slots available")); - luaL_getsubtable(L, LUA_REGISTRYINDEX, cstr!("_LOADED")); + luaL_getsubtable(L, LUA_REGISTRYINDEX, LUA_LOADED_TABLE); if lua_getfield(L, -1, modname) == LUA_TNIL { lua_pop(L, 1); lua_pushcfunction(L, openf); diff --git a/mlua-sys/src/lua52/lauxlib.rs b/mlua-sys/src/lua52/lauxlib.rs index d5cdf664..fad19ebe 100644 --- a/mlua-sys/src/lua52/lauxlib.rs +++ b/mlua-sys/src/lua52/lauxlib.rs @@ -8,6 +8,12 @@ use super::lua::{self, lua_CFunction, lua_Integer, lua_Number, lua_State, lua_Un // Extra error code for 'luaL_load' pub const LUA_ERRFILE: c_int = lua::LUA_ERRERR + 1; +// Key, in the registry, for table of loaded modules +pub const LUA_LOADED_TABLE: *const c_char = cstr!("_LOADED"); + +// Key, in the registry, for table of preloaded loaders +pub const LUA_PRELOAD_TABLE: *const c_char = cstr!("_PRELOAD"); + #[repr(C)] pub struct luaL_Reg { pub name: *const c_char, diff --git a/mlua-sys/src/luau/compat.rs b/mlua-sys/src/luau/compat.rs index 3d189cec..29c7b1e1 100644 --- a/mlua-sys/src/luau/compat.rs +++ b/mlua-sys/src/luau/compat.rs @@ -1,4 +1,4 @@ -//! MLua compatibility layer for Roblox Luau. +//! MLua compatibility layer for Luau. //! //! Based on github.com/keplerproject/lua-compat-5.3 diff --git a/mlua_derive/Cargo.toml b/mlua_derive/Cargo.toml index 1e121341..59b96e84 100644 --- a/mlua_derive/Cargo.toml +++ b/mlua_derive/Cargo.toml @@ -4,7 +4,7 @@ version = "0.10.1" authors = ["Aleksandr Orlenko "] edition = "2021" description = "Procedural macros for the mlua crate." -repository = "https://github.com/khvzak/mlua" +repository = "https://github.com/mlua-rs/mlua" keywords = ["lua", "mlua"] license = "MIT" @@ -19,6 +19,6 @@ quote = "1.0" proc-macro2 = { version = "1.0", features = ["span-locations"] } proc-macro-error2 = { version = "2.0.1", optional = true } syn = { version = "2.0", features = ["full"] } -itertools = { version = "0.13", optional = true } +itertools = { version = "0.14", optional = true } regex = { version = "1.4", optional = true } once_cell = { version = "1.0", optional = true } diff --git a/src/chunk.rs b/src/chunk.rs index 089f4b82..a973d27c 100644 --- a/src/chunk.rs +++ b/src/chunk.rs @@ -307,6 +307,11 @@ impl Compiler { } impl Chunk<'_> { + /// Returns the name of this chunk. + pub fn name(&self) -> &str { + &self.name + } + /// Sets the name of this chunk, which results in more informative error traces. /// /// Possible name prefixes: @@ -318,6 +323,11 @@ impl Chunk<'_> { self } + /// Returns the environment of this chunk. + pub fn environment(&self) -> Option<&Table> { + self.env.as_ref().ok()?.as_ref() + } + /// Sets the environment of the loaded chunk to the given value. /// /// In Lua >=5.2 main chunks always have exactly one upvalue, and this upvalue is used as the @@ -334,6 +344,11 @@ impl Chunk<'_> { self } + /// Returns the mode (auto-detected by default) of this chunk. + pub fn mode(&self) -> ChunkMode { + self.detect_mode() + } + /// Sets whether the chunk is text or binary (autodetected by default). /// /// Be aware, Lua does not check the consistency of the code inside binary chunks. diff --git a/src/conversion.rs b/src/conversion.rs index ad0a7003..2f647199 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -5,7 +5,7 @@ use std::hash::{BuildHasher, Hash}; use std::os::raw::c_int; use std::path::{Path, PathBuf}; use std::string::String as StdString; -use std::{slice, str}; +use std::{mem, slice, str}; use bstr::{BStr, BString, ByteSlice, ByteVec}; use num_traits::cast; @@ -13,7 +13,7 @@ use num_traits::cast; use crate::error::{Error, Result}; use crate::function::Function; use crate::state::{Lua, RawLua}; -use crate::string::String; +use crate::string::{BorrowedBytes, BorrowedStr, String}; use crate::table::Table; use crate::thread::Thread; use crate::traits::{FromLua, IntoLua, ShortTypeName as _}; @@ -91,6 +91,94 @@ impl FromLua for String { } } +impl IntoLua for BorrowedStr<'_> { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(self.borrow.into_owned())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.borrow.0); + Ok(()) + } +} + +impl IntoLua for &BorrowedStr<'_> { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(self.borrow.clone().into_owned())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.borrow.0); + Ok(()) + } +} + +impl FromLua for BorrowedStr<'_> { + fn from_lua(value: Value, lua: &Lua) -> Result { + let s = String::from_lua(value, lua)?; + let BorrowedStr { buf, _lua, .. } = BorrowedStr::try_from(&s)?; + let buf = unsafe { mem::transmute::<&str, &'static str>(buf) }; + let borrow = Cow::Owned(s); + Ok(Self { buf, borrow, _lua }) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let s = String::from_stack(idx, lua)?; + let BorrowedStr { buf, _lua, .. } = BorrowedStr::try_from(&s)?; + let buf = unsafe { mem::transmute::<&str, &'static str>(buf) }; + let borrow = Cow::Owned(s); + Ok(Self { buf, borrow, _lua }) + } +} + +impl IntoLua for BorrowedBytes<'_> { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(self.borrow.into_owned())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.borrow.0); + Ok(()) + } +} + +impl IntoLua for &BorrowedBytes<'_> { + #[inline] + fn into_lua(self, _: &Lua) -> Result { + Ok(Value::String(self.borrow.clone().into_owned())) + } + + #[inline] + unsafe fn push_into_stack(self, lua: &RawLua) -> Result<()> { + lua.push_ref(&self.borrow.0); + Ok(()) + } +} + +impl FromLua for BorrowedBytes<'_> { + fn from_lua(value: Value, lua: &Lua) -> Result { + let s = String::from_lua(value, lua)?; + let BorrowedBytes { buf, _lua, .. } = BorrowedBytes::from(&s); + let buf = unsafe { mem::transmute::<&[u8], &'static [u8]>(buf) }; + let borrow = Cow::Owned(s); + Ok(Self { buf, borrow, _lua }) + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + let s = String::from_stack(idx, lua)?; + let BorrowedBytes { buf, _lua, .. } = BorrowedBytes::from(&s); + let buf = unsafe { mem::transmute::<&[u8], &'static [u8]>(buf) }; + let borrow = Cow::Owned(s); + Ok(Self { buf, borrow, _lua }) + } +} + impl IntoLua for Table { #[inline] fn into_lua(self, _: &Lua) -> Result { @@ -655,6 +743,51 @@ impl IntoLua for &Path { } } +impl IntoLua for char { + #[inline] + fn into_lua(self, lua: &Lua) -> Result { + let mut char_bytes = [0; 4]; + self.encode_utf8(&mut char_bytes); + Ok(Value::String(lua.create_string(&char_bytes[..self.len_utf8()])?)) + } +} + +impl FromLua for char { + fn from_lua(value: Value, _lua: &Lua) -> Result { + let ty = value.type_name(); + match value { + Value::Integer(i) => { + cast(i) + .and_then(char::from_u32) + .ok_or_else(|| Error::FromLuaConversionError { + from: ty, + to: "char".to_string(), + message: Some("integer out of range when converting to char".to_string()), + }) + } + Value::String(s) => { + let str = s.to_str()?; + let mut str_iter = str.chars(); + match (str_iter.next(), str_iter.next()) { + (Some(char), None) => Ok(char), + _ => Err(Error::FromLuaConversionError { + from: ty, + to: "char".to_string(), + message: Some( + "expected string to have exactly one char when converting to char".to_string(), + ), + }), + } + } + _ => Err(Error::FromLuaConversionError { + from: ty, + to: Self::type_name(), + message: Some("expected string or integer".to_string()), + }), + } + } +} + #[inline] unsafe fn push_bytes_into_stack(this: T, lua: &RawLua) -> Result<()> where diff --git a/src/function.rs b/src/function.rs index 37b33fa9..4e6e4069 100644 --- a/src/function.rs +++ b/src/function.rs @@ -146,7 +146,7 @@ impl Function { /// Ok(()) /// })?; /// - /// sleep.call_async(10).await?; + /// sleep.call_async::<()>(10).await?; /// /// # Ok(()) /// # } diff --git a/src/lib.rs b/src/lib.rs index a404594c..bbf947c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,6 +66,7 @@ // warnings at all. #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(not(send), allow(clippy::arc_with_non_send_sync))] +#![allow(clippy::ptr_eq)] #[macro_use] mod macros; @@ -104,7 +105,7 @@ pub use crate::function::{Function, FunctionInfo}; pub use crate::hook::{Debug, DebugEvent, DebugNames, DebugSource, DebugStack}; pub use crate::multi::{MultiValue, Variadic}; pub use crate::scope::Scope; -pub use crate::state::{GCMode, Lua, LuaOptions}; +pub use crate::state::{GCMode, Lua, LuaOptions, WeakLua}; pub use crate::stdlib::StdLib; pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; @@ -209,7 +210,7 @@ pub use mlua_derive::FromLua; /// /// You can register multiple entrypoints as required. /// -/// ``` +/// ```ignore /// use mlua::{Lua, Result, Table}; /// /// #[mlua::lua_module] @@ -246,7 +247,7 @@ pub use mlua_derive::FromLua; /// ... /// } /// ``` -#[cfg(any(feature = "module", docsrs))] +#[cfg(all(feature = "mlua_derive", any(feature = "module", doc)))] #[cfg_attr(docsrs, doc(cfg(feature = "module")))] pub use mlua_derive::lua_module; diff --git a/src/luau/package.rs b/src/luau/package.rs index fc1aa1ac..f30d54d6 100644 --- a/src/luau/package.rs +++ b/src/luau/package.rs @@ -20,7 +20,7 @@ use {libloading::Library, rustc_hash::FxHashMap}; // #[cfg(unix)] -const TARGET_MLUA_LUAU_ABI_VERSION: u32 = 2; +const TARGET_MLUA_LUAU_ABI_VERSION: u32 = 3; #[cfg(all(unix, feature = "module"))] #[no_mangle] @@ -130,10 +130,10 @@ unsafe extern "C-unwind" fn lua_require(state: *mut ffi::lua_State) -> c_int { for i in 1.. { if ffi::lua_rawgeti(state, -1, i) == ffi::LUA_TNIL { // no more loaders? - if (*err_buf).is_empty() { + if (&*err_buf).is_empty() { ffi::luaL_error(state, cstr!("module '%s' not found"), name); } else { - let bytes = (*err_buf).as_bytes(); + let bytes = (&*err_buf).as_bytes(); let extra = ffi::lua_pushlstring(state, bytes.as_ptr() as *const _, bytes.len()); ffi::luaL_error(state, cstr!("module '%s' not found:%s"), name, extra); } diff --git a/src/memory.rs b/src/memory.rs index 672a8647..f469fe12 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -18,15 +18,32 @@ pub(crate) struct MemoryState { } impl MemoryState { + #[cfg(feature = "luau")] #[inline] pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { let mut mem_state = ptr::null_mut(); - #[cfg(feature = "luau")] - { - ffi::lua_getallocf(state, &mut mem_state); - mlua_assert!(!mem_state.is_null(), "Luau state has no allocator userdata"); + ffi::lua_getallocf(state, &mut mem_state); + mlua_assert!(!mem_state.is_null(), "Luau state has no allocator userdata"); + mem_state as *mut MemoryState + } + + #[cfg(not(feature = "luau"))] + #[rustversion::since(1.85)] + #[inline] + #[allow(clippy::incompatible_msrv)] + pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { + let mut mem_state = ptr::null_mut(); + if !ptr::fn_addr_eq(ffi::lua_getallocf(state, &mut mem_state), ALLOCATOR) { + mem_state = ptr::null_mut(); } - #[cfg(not(feature = "luau"))] + mem_state as *mut MemoryState + } + + #[cfg(not(feature = "luau"))] + #[rustversion::before(1.85)] + #[inline] + pub(crate) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self { + let mut mem_state = ptr::null_mut(); if ffi::lua_getallocf(state, &mut mem_state) != ALLOCATOR { mem_state = ptr::null_mut(); } diff --git a/src/multi.rs b/src/multi.rs index b4fb0e2e..c171962c 100644 --- a/src/multi.rs +++ b/src/multi.rs @@ -126,8 +126,7 @@ impl MultiValue { /// Creates a `MultiValue` container from vector of values. /// - /// This methods needs *O*(*n*) data movement if the circular buffer doesn't happen to be at the - /// beginning of the allocation. + /// This method works in *O*(1) time and does not allocate any additional memory. #[inline] pub fn from_vec(vec: Vec) -> MultiValue { vec.into() @@ -135,7 +134,8 @@ impl MultiValue { /// Consumes the `MultiValue` and returns a vector of values. /// - /// This methods works in *O*(1) time and does not allocate any additional memory. + /// This method needs *O*(*n*) data movement if the circular buffer doesn't happen to be at the + /// beginning of the allocation. #[inline] pub fn into_vec(self) -> Vec { self.into() diff --git a/src/prelude.rs b/src/prelude.rs index 68ba8f2f..1b070d97 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -13,7 +13,7 @@ pub use crate::{ UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, - VmState as LuaVmState, + Variadic as LuaVariadic, VmState as LuaVmState, WeakLua, }; #[cfg(not(feature = "luau"))] diff --git a/src/scope.rs b/src/scope.rs index bbe6dd79..7c155710 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -168,7 +168,7 @@ impl<'scope, 'env: 'scope> Scope<'scope, 'env> { #[cfg(feature = "luau")] let ud_ptr = { let data = UserDataStorage::new_scoped(data); - util::push_userdata::>(state, data, protect)? + util::push_userdata(state, data, protect)? }; #[cfg(not(feature = "luau"))] let ud_ptr = util::push_uninit_userdata::>(state, protect)?; @@ -216,7 +216,7 @@ impl<'scope, 'env: 'scope> Scope<'scope, 'env> { #[cfg(feature = "luau")] let ud_ptr = { let data = UserDataStorage::new_scoped(data); - util::push_userdata::>(state, data, protect)? + util::push_userdata(state, data, protect)? }; #[cfg(not(feature = "luau"))] let ud_ptr = util::push_uninit_userdata::>(state, protect)?; diff --git a/src/serde/de.rs b/src/serde/de.rs index 0a941170..69d9056c 100644 --- a/src/serde/de.rs +++ b/src/serde/de.rs @@ -49,6 +49,11 @@ pub struct Options { /// /// Default: **false** pub sort_keys: bool, + + /// If true, empty Lua tables will be encoded as array, instead of map. + /// + /// Default: **false** + pub encode_empty_tables_as_array: bool, } impl Default for Options { @@ -64,6 +69,7 @@ impl Options { deny_unsupported_types: true, deny_recursive_tables: true, sort_keys: false, + encode_empty_tables_as_array: false, } } @@ -93,6 +99,15 @@ impl Options { self.sort_keys = enabled; self } + + /// Sets [`encode_empty_tables_as_array`] option. + /// + /// [`encode_empty_tables_as_array`]: #structfield.encode_empty_tables_as_array + #[must_use] + pub const fn encode_empty_tables_as_array(mut self, enabled: bool) -> Self { + self.encode_empty_tables_as_array = enabled; + self + } } impl Deserializer { @@ -141,6 +156,9 @@ impl<'de> serde::Deserializer<'de> for Deserializer { Err(_) => visitor.visit_bytes(&s.as_bytes()), }, Value::Table(ref t) if t.raw_len() > 0 || t.is_array() => self.deserialize_seq(visitor), + Value::Table(ref t) if self.options.encode_empty_tables_as_array && t.is_empty() => { + self.deserialize_seq(visitor) + } Value::Table(_) => self.deserialize_map(visitor), Value::LightUserData(ud) if ud.0.is_null() => visitor.visit_none(), Value::UserData(ud) if ud.is_serializable() => { diff --git a/src/state.rs b/src/state.rs index 35d9e4ec..fa360c73 100644 --- a/src/state.rs +++ b/src/state.rs @@ -46,18 +46,20 @@ use serde::Serialize; pub(crate) use extra::ExtraData; pub use raw::RawLua; -use util::{callback_error_ext, StateGuard}; +use util::callback_error_ext; /// Top level Lua struct which represents an instance of Lua VM. -#[derive(Clone)] pub struct Lua { pub(self) raw: XRc>, // Controls whether garbage collection should be run on drop pub(self) collect_garbage: bool, } +/// Weak reference to Lua instance. +/// +/// This can used to prevent circular references between Lua and Rust objects. #[derive(Clone)] -pub(crate) struct WeakLua(XWeak>); +pub struct WeakLua(XWeak>); pub(crate) struct LuaGuard(ArcReentrantMutexGuard); @@ -98,9 +100,6 @@ pub struct LuaOptions { /// Max size of thread (coroutine) object pool used to execute asynchronous functions. /// - /// It works on Lua 5.4 and Luau, where [`lua_resetthread`] function - /// is available and allows to reuse old coroutines after resetting their state. - /// /// Default: **0** (disabled) /// /// [`lua_resetthread`]: https://www.lua.org/manual/5.4/manual.html#lua_resetthread @@ -154,6 +153,16 @@ impl Drop for Lua { } } +impl Clone for Lua { + #[inline] + fn clone(&self) -> Self { + Lua { + raw: XRc::clone(&self.raw), + collect_garbage: false, + } + } +} + impl fmt::Debug for Lua { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Lua({:p})", self.lock().state()) @@ -421,7 +430,6 @@ impl Lua { callback_error_ext(state, ptr::null_mut(), move |extra, nargs| { let rawlua = (*extra).raw_lua(); - let _guard = StateGuard::new(rawlua, state); let args = A::from_stack_args(nargs, 1, None, rawlua)?; func(rawlua.lua(), args)?.push_into_stack(rawlua)?; Ok(1) @@ -612,7 +620,7 @@ impl Lua { /// .into_function()?, /// )?; /// while co.status() == ThreadStatus::Resumable { - /// co.resume(())?; + /// co.resume::<()>(())?; /// } /// # Ok(()) /// # } @@ -639,7 +647,6 @@ impl Lua { if Rc::strong_count(&interrupt_cb) > 2 { return Ok(VmState::Continue); // Don't allow recursion } - let _guard = StateGuard::new((*extra).raw_lua(), state); interrupt_cb((*extra).lua()) }); match result { @@ -687,18 +694,19 @@ impl Lua { unsafe extern "C-unwind" fn warn_proc(ud: *mut c_void, msg: *const c_char, tocont: c_int) { let extra = ud as *mut ExtraData; callback_error_ext((*extra).raw_lua().state(), extra, |extra, _| { - let cb = mlua_expect!( - (*extra).warn_callback.as_ref(), - "no warning callback set in warn_proc" - ); + let warn_callback = (*extra).warn_callback.clone(); + let warn_callback = mlua_expect!(warn_callback, "no warning callback set in warn_proc"); + if XRc::strong_count(&warn_callback) > 2 { + return Ok(()); + } let msg = StdString::from_utf8_lossy(CStr::from_ptr(msg).to_bytes()); - cb((*extra).lua(), &msg, tocont != 0) + warn_callback((*extra).lua(), &msg, tocont != 0) }); } let lua = self.lock(); unsafe { - (*lua.extra.get()).warn_callback = Some(Box::new(callback)); + (*lua.extra.get()).warn_callback = Some(XRc::new(callback)); ffi::lua_setwarnf(lua.state(), Some(warn_proc), lua.extra.get() as *mut c_void); } } @@ -1313,7 +1321,7 @@ impl Lua { /// This methods provides a way to add fields or methods to userdata objects of a type `T`. pub fn register_userdata_type(&self, f: impl FnOnce(&mut UserDataRegistry)) -> Result<()> { let type_id = TypeId::of::(); - let mut registry = UserDataRegistry::new(self, type_id); + let mut registry = UserDataRegistry::new(self); f(&mut registry); let lua = self.lock(); @@ -1792,8 +1800,8 @@ impl Lua { let state = lua.state(); unsafe { let mut unref_list = (*lua.extra.get()).registry_unref_list.lock(); - let unref_list = mem::replace(&mut *unref_list, Some(Vec::new())); - for id in mlua_expect!(unref_list, "unref list not set") { + let unref_list = unref_list.replace(Vec::new()); + for id in mlua_expect!(unref_list, "unref list is not set") { ffi::luaL_unref(state, ffi::LUA_REGISTRYINDEX, id); } } @@ -1823,7 +1831,7 @@ impl Lua { /// fn main() -> Result<()> { /// let lua = Lua::new(); /// lua.set_app_data("hello"); - /// lua.create_function(hello)?.call(())?; + /// lua.create_function(hello)?.call::<()>(())?; /// let s = lua.app_data_ref::<&str>().unwrap(); /// assert_eq!(*s, "world"); /// Ok(()) @@ -1907,6 +1915,8 @@ impl Lua { } /// Returns an internal `Poll::Pending` constant used for executing async callbacks. + /// + /// Every time when [`Future`] is Pending, Lua corotine is suspended with this constant. #[cfg(feature = "async")] #[doc(hidden)] #[inline(always)] @@ -1915,6 +1925,15 @@ impl Lua { LightUserData(&ASYNC_POLL_PENDING as *const u8 as *mut std::os::raw::c_void) } + /// Returns a weak reference to the Lua instance. + /// + /// This is useful for creating a reference to the Lua instance that does not prevent it from + /// being deallocated. + #[inline(always)] + pub fn weak(&self) -> WeakLua { + WeakLua(XRc::downgrade(&self.raw)) + } + // Luau version located in `luau/mod.rs` #[cfg(not(feature = "luau"))] fn disable_c_modules(&self) -> Result<()> { @@ -1953,11 +1972,6 @@ impl Lua { LuaGuard(self.raw.lock_arc()) } - #[inline(always)] - pub(crate) fn weak(&self) -> WeakLua { - WeakLua(XRc::downgrade(&self.raw)) - } - /// Returns a handle to the unprotected Lua state without any synchronization. /// /// This is useful where we know that the lock is already held by the caller. @@ -1980,14 +1994,30 @@ impl WeakLua { Some(LuaGuard::new(self.0.upgrade()?)) } + /// Upgrades the weak Lua reference to a strong reference. + /// + /// # Panics + /// + /// Panics if the Lua instance is destroyed. #[track_caller] #[inline(always)] - pub(crate) fn upgrade(&self) -> Lua { + pub fn upgrade(&self) -> Lua { Lua { raw: self.0.upgrade().expect("Lua instance is destroyed"), collect_garbage: false, } } + + /// Tries to upgrade the weak Lua reference to a strong reference. + /// + /// Returns `None` if the Lua instance is destroyed. + #[inline(always)] + pub fn try_upgrade(&self) -> Option { + Some(Lua { + raw: self.0.upgrade()?, + collect_garbage: false, + }) + } } impl PartialEq for WeakLua { diff --git a/src/state/extra.rs b/src/state/extra.rs index d1823b5c..07bfd391 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -28,7 +28,7 @@ use super::{Lua, WeakLua}; static EXTRA_REGISTRY_KEY: u8 = 0; const WRAPPED_FAILURE_POOL_DEFAULT_CAPACITY: usize = 64; -const REF_STACK_RESERVE: c_int = 1; +const REF_STACK_RESERVE: c_int = 2; /// Data associated with the Lua state. pub(crate) struct ExtraData { diff --git a/src/state/raw.rs b/src/state/raw.rs index 0731f846..d4ab85d6 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -12,7 +12,7 @@ use crate::chunk::ChunkMode; use crate::error::{Error, Result}; use crate::function::Function; use crate::memory::{MemoryState, ALLOCATOR}; -use crate::state::util::{callback_error_ext, ref_stack_pop, StateGuard}; +use crate::state::util::{callback_error_ext, ref_stack_pop}; use crate::stdlib::StdLib; use crate::string::String; use crate::table::Table; @@ -23,13 +23,14 @@ use crate::types::{ MaybeSend, ReentrantMutex, RegistryKey, ValueRef, XRc, }; use crate::userdata::{ - AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, UserDataStorage, + init_userdata_metatable, AnyUserData, MetaMethod, RawUserDataRegistry, UserData, UserDataRegistry, + UserDataStorage, }; use crate::util::{ assert_stack, check_stack, get_destructed_userdata_metatable, get_internal_userdata, get_main_state, - get_metatable_ptr, get_userdata, init_error_registry, init_internal_metatable, init_userdata_metatable, - pop_error, push_internal_userdata, push_string, push_table, rawset_field, safe_pcall, safe_xpcall, - short_type_name, StackGuard, WrappedFailure, + get_metatable_ptr, get_userdata, init_error_registry, init_internal_metatable, pop_error, + push_internal_userdata, push_string, push_table, rawset_field, safe_pcall, safe_xpcall, short_type_name, + StackGuard, WrappedFailure, }; use crate::value::{Nil, Value}; @@ -400,7 +401,6 @@ impl RawLua { return Ok(VmState::Continue); // Don't allow recursion } let rawlua = (*extra).raw_lua(); - let _guard = StateGuard::new(rawlua, state); let debug = Debug::new(rawlua, ar); hook_cb((*extra).lua(), debug) }); @@ -505,7 +505,6 @@ impl RawLua { /// Wraps a Lua function into a new or recycled thread (coroutine). #[cfg(feature = "async")] pub(crate) unsafe fn create_recycled_thread(&self, func: &Function) -> Result { - #[cfg(any(feature = "lua54", feature = "luau"))] if let Some(index) = (*self.extra.get()).thread_pool.pop() { let thread_state = ffi::lua_tothread(self.ref_thread(), index); ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index); @@ -525,27 +524,47 @@ impl RawLua { /// Resets thread (coroutine) and returns it to the pool for later use. #[cfg(feature = "async")] - #[cfg(any(feature = "lua54", feature = "luau"))] - pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) -> bool { + pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) { + let thread_state = thread.1; let extra = &mut *self.extra.get(); - if extra.thread_pool.len() < extra.thread_pool.capacity() { - let thread_state = ffi::lua_tothread(extra.ref_thread, thread.0.index); - #[cfg(all(feature = "lua54", not(feature = "vendored")))] - let status = ffi::lua_resetthread(thread_state); - #[cfg(all(feature = "lua54", feature = "vendored"))] - let status = ffi::lua_closethread(thread_state, self.state()); + if extra.thread_pool.len() == extra.thread_pool.capacity() { #[cfg(feature = "lua54")] - if status != ffi::LUA_OK { - // Error object is on top, drop it + if ffi::lua_status(thread_state) != ffi::LUA_OK { + // Close all to-be-closed variables without returning thread to the pool + #[cfg(not(feature = "vendored"))] + ffi::lua_resetthread(thread_state); + #[cfg(feature = "vendored")] + ffi::lua_closethread(thread_state, self.state()); + } + return; + } + + let mut reset_ok = false; + if ffi::lua_status(thread_state) == ffi::LUA_OK { + if ffi::lua_gettop(thread_state) > 0 { ffi::lua_settop(thread_state, 0); } - #[cfg(feature = "luau")] + reset_ok = true; + } + + #[cfg(feature = "lua54")] + if !reset_ok { + #[cfg(not(feature = "vendored"))] + let status = ffi::lua_resetthread(thread_state); + #[cfg(feature = "vendored")] + let status = ffi::lua_closethread(thread_state, self.state()); + reset_ok = status == ffi::LUA_OK; + } + #[cfg(feature = "luau")] + if !reset_ok { ffi::lua_resetthread(thread_state); + reset_ok = true; + } + + if reset_ok { extra.thread_pool.push(thread.0.index); thread.0.drop = false; // Prevent thread from being garbage collected - return true; } - false } /// Pushes a value that implements `IntoLua` onto the Lua stack. @@ -779,7 +798,7 @@ impl RawLua { } // Create a new metatable from `UserData` definition - let mut registry = UserDataRegistry::new(self.lua(), type_id); + let mut registry = UserDataRegistry::new(self.lua()); T::register(&mut registry); self.create_userdata_metatable(registry.into_raw()) @@ -800,7 +819,7 @@ impl RawLua { // Check if metatable creation is pending or create an empty metatable otherwise let registry = match (*self.extra.get()).pending_userdata_reg.remove(&type_id) { Some(registry) => registry, - None => UserDataRegistry::::new(self.lua(), type_id).into_raw(), + None => UserDataRegistry::::new(self.lua()).into_raw(), }; self.create_userdata_metatable(registry) }) @@ -815,12 +834,11 @@ impl RawLua { let _sg = StackGuard::new(state); check_stack(state, 3)?; - // We push metatable first to ensure having correct metatable with `__gc` method - ffi::lua_pushnil(state); - ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, get_metatable_id()?); + // We generate metatable first to make sure it *always* available when userdata pushed + let mt_id = get_metatable_id()?; let protect = !self.unlikely_memory_error(); crate::util::push_userdata(state, data, protect)?; - ffi::lua_replace(state, -3); + ffi::lua_rawgeti(state, ffi::LUA_REGISTRYINDEX, mt_id); ffi::lua_setmetatable(state, -2); // Set empty environment for Lua 5.1 @@ -1018,17 +1036,22 @@ impl RawLua { // Returns `TypeId` for the userdata ref, checking that it's registered and not destructed. // // Returns `None` if the userdata is registered but non-static. - pub(crate) unsafe fn get_userdata_ref_type_id(&self, vref: &ValueRef) -> Result> { - self.get_userdata_type_id_inner(self.ref_thread(), vref.index) + #[inline(always)] + pub(crate) fn get_userdata_ref_type_id(&self, vref: &ValueRef) -> Result> { + unsafe { self.get_userdata_type_id_inner(self.ref_thread(), vref.index) } } // Same as `get_userdata_ref_type_id` but assumes the userdata is already on the stack. - pub(crate) unsafe fn get_userdata_type_id(&self, idx: c_int) -> Result> { - match self.get_userdata_type_id_inner(self.state(), idx) { + pub(crate) unsafe fn get_userdata_type_id( + &self, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result> { + match self.get_userdata_type_id_inner(state, idx) { Ok(type_id) => Ok(type_id), - Err(Error::UserDataTypeMismatch) if ffi::lua_type(self.state(), idx) != ffi::LUA_TUSERDATA => { + Err(Error::UserDataTypeMismatch) if ffi::lua_type(state, idx) != ffi::LUA_TUSERDATA => { // Report `FromLuaConversionError` instead - let idx_type_name = CStr::from_ptr(ffi::luaL_typename(self.state(), idx)); + let idx_type_name = CStr::from_ptr(ffi::luaL_typename(state, idx)); let idx_type_name = idx_type_name.to_str().unwrap(); let message = format!("expected userdata of type '{}'", short_type_name::()); Err(Error::from_lua_conversion(idx_type_name, "userdata", message)) @@ -1081,7 +1104,6 @@ impl RawLua { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); - let _guard = StateGuard::new(rawlua, state); match (*upvalue).data { Some(ref func) => func(rawlua, nargs), None => Err(Error::CallbackDestructed), @@ -1125,12 +1147,10 @@ impl RawLua { // Async functions cannot be scoped and therefore destroyed, // so the first upvalue is always valid let upvalue = get_userdata::(state, ffi::lua_upvalueindex(1)); - let extra = (*upvalue).extra.get(); - callback_error_ext(state, extra, |extra, nargs| { + callback_error_ext(state, (*upvalue).extra.get(), |extra, nargs| { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the callback is executed let rawlua = (*extra).raw_lua(); - let _guard = StateGuard::new(rawlua, state); let func = &*(*upvalue).data; let fut = func(rawlua, nargs); @@ -1155,7 +1175,6 @@ impl RawLua { // Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments) // The lock must be already held as the future is polled let rawlua = (*extra).raw_lua(); - let _guard = StateGuard::new(rawlua, state); let fut = &mut (*upvalue).data; let mut ctx = Context::from_waker(rawlua.waker()); diff --git a/src/state/util.rs b/src/state/util.rs index ec701eaf..2d54e20d 100644 --- a/src/state/util.rs +++ b/src/state/util.rs @@ -7,10 +7,10 @@ use crate::error::{Error, Result}; use crate::state::{ExtraData, RawLua}; use crate::util::{self, get_internal_metatable, WrappedFailure}; -pub(super) struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); +struct StateGuard<'a>(&'a RawLua, *mut ffi::lua_State); impl<'a> StateGuard<'a> { - pub(super) fn new(inner: &'a RawLua, mut state: *mut ffi::lua_State) -> Self { + fn new(inner: &'a RawLua, mut state: *mut ffi::lua_State) -> Self { state = inner.state.replace(state); Self(inner, state) } @@ -23,7 +23,7 @@ impl Drop for StateGuard<'_> { } // An optimized version of `callback_error` that does not allocate `WrappedFailure` userdata -// and instead reuses unsed values from previous calls (or allocates new). +// and instead reuses unused values from previous calls (or allocates new). pub(super) unsafe fn callback_error_ext( state: *mut ffi::lua_State, mut extra: *mut ExtraData, @@ -51,7 +51,7 @@ where } // We need to check stack for Luau in case when callback is called from interrupt - // See https://github.com/Roblox/luau/issues/446 and mlua #142 and #153 + // See https://github.com/luau-lang/luau/issues/446 and mlua #142 and #153 #[cfg(feature = "luau")] ffi::lua_rawcheckstack(state, 2); // Place it to the beginning of the stack @@ -101,7 +101,11 @@ where // to store a wrapped failure (error or panic) *before* we proceed. let prealloc_failure = PreallocatedFailure::reserve(state, extra); - match catch_unwind(AssertUnwindSafe(|| f(extra, nargs))) { + match catch_unwind(AssertUnwindSafe(|| { + let rawlua = (*extra).raw_lua(); + let _guard = StateGuard::new(rawlua, state); + f(extra, nargs) + })) { Ok(Ok(r)) => { // Return unused `WrappedFailure` to the pool prealloc_failure.release(state, extra); diff --git a/src/string.rs b/src/string.rs index 38fb8850..9c86102b 100644 --- a/src/string.rs +++ b/src/string.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; +use std::borrow::{Borrow, Cow}; use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::os::raw::{c_int, c_void}; @@ -44,13 +44,7 @@ impl String { /// ``` #[inline] pub fn to_str(&self) -> Result { - let BorrowedBytes(bytes, guard) = self.as_bytes(); - let s = str::from_utf8(bytes).map_err(|e| Error::FromLuaConversionError { - from: "string", - to: "&str".to_string(), - message: Some(e.to_string()), - })?; - Ok(BorrowedStr(s, guard)) + BorrowedStr::try_from(self) } /// Converts this string to a [`StdString`]. @@ -109,19 +103,21 @@ impl String { /// ``` #[inline] pub fn as_bytes(&self) -> BorrowedBytes { - let (bytes, guard) = unsafe { self.to_slice() }; - BorrowedBytes(&bytes[..bytes.len() - 1], guard) + BorrowedBytes::from(self) } /// Get the bytes that make up this string, including the trailing nul byte. pub fn as_bytes_with_nul(&self) -> BorrowedBytes { - let (bytes, guard) = unsafe { self.to_slice() }; - BorrowedBytes(bytes, guard) + let BorrowedBytes { buf, borrow, _lua } = BorrowedBytes::from(self); + // Include the trailing nul byte (it's always present but excluded by default) + let buf = unsafe { slice::from_raw_parts((*buf).as_ptr(), (*buf).len() + 1) }; + BorrowedBytes { buf, borrow, _lua } } + // Does not return the terminating nul byte unsafe fn to_slice(&self) -> (&[u8], Lua) { let lua = self.0.lua.upgrade(); - let slice = unsafe { + let slice = { let rawlua = lua.lock(); let ref_thread = rawlua.ref_thread(); @@ -134,7 +130,7 @@ impl String { // string type let mut size = 0; let data = ffi::lua_tolstring(ref_thread, self.0.index, &mut size); - slice::from_raw_parts(data as *const u8, size + 1) + slice::from_raw_parts(data as *const u8, size) }; (slice, lua) } @@ -238,40 +234,45 @@ impl fmt::Display for Display<'_> { } /// A borrowed string (`&str`) that holds a strong reference to the Lua state. -pub struct BorrowedStr<'a>(&'a str, #[allow(unused)] Lua); +pub struct BorrowedStr<'a> { + // `buf` points to a readonly memory managed by Lua + pub(crate) buf: &'a str, + pub(crate) borrow: Cow<'a, String>, + pub(crate) _lua: Lua, +} impl Deref for BorrowedStr<'_> { type Target = str; #[inline(always)] fn deref(&self) -> &str { - self.0 + self.buf } } impl Borrow for BorrowedStr<'_> { #[inline(always)] fn borrow(&self) -> &str { - self.0 + self.buf } } impl AsRef for BorrowedStr<'_> { #[inline(always)] fn as_ref(&self) -> &str { - self.0 + self.buf } } impl fmt::Display for BorrowedStr<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) + self.buf.fmt(f) } } impl fmt::Debug for BorrowedStr<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) + self.buf.fmt(f) } } @@ -280,7 +281,7 @@ where T: AsRef, { fn eq(&self, other: &T) -> bool { - self.0 == other.as_ref() + self.buf == other.as_ref() } } @@ -291,45 +292,65 @@ where T: AsRef, { fn partial_cmp(&self, other: &T) -> Option { - self.0.partial_cmp(other.as_ref()) + self.buf.partial_cmp(other.as_ref()) } } impl Ord for BorrowedStr<'_> { fn cmp(&self, other: &Self) -> cmp::Ordering { - self.0.cmp(other.0) + self.buf.cmp(other.buf) + } +} + +impl<'a> TryFrom<&'a String> for BorrowedStr<'a> { + type Error = Error; + + #[inline] + fn try_from(value: &'a String) -> Result { + let BorrowedBytes { buf, borrow, _lua } = BorrowedBytes::from(value); + let buf = str::from_utf8(buf).map_err(|e| Error::FromLuaConversionError { + from: "string", + to: "&str".to_string(), + message: Some(e.to_string()), + })?; + Ok(Self { buf, borrow, _lua }) } } /// A borrowed byte slice (`&[u8]`) that holds a strong reference to the Lua state. -pub struct BorrowedBytes<'a>(&'a [u8], #[allow(unused)] Lua); +pub struct BorrowedBytes<'a> { + // `buf` points to a readonly memory managed by Lua + pub(crate) buf: &'a [u8], + pub(crate) borrow: Cow<'a, String>, + pub(crate) _lua: Lua, +} impl Deref for BorrowedBytes<'_> { type Target = [u8]; #[inline(always)] fn deref(&self) -> &[u8] { - self.0 + self.buf } } impl Borrow<[u8]> for BorrowedBytes<'_> { #[inline(always)] fn borrow(&self) -> &[u8] { - self.0 + self.buf } } impl AsRef<[u8]> for BorrowedBytes<'_> { #[inline(always)] fn as_ref(&self) -> &[u8] { - self.0 + self.buf } } impl fmt::Debug for BorrowedBytes<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.0.fmt(f) + self.buf.fmt(f) } } @@ -338,7 +359,7 @@ where T: AsRef<[u8]>, { fn eq(&self, other: &T) -> bool { - self.0 == other.as_ref() + self.buf == other.as_ref() } } @@ -349,22 +370,31 @@ where T: AsRef<[u8]>, { fn partial_cmp(&self, other: &T) -> Option { - self.0.partial_cmp(other.as_ref()) + self.buf.partial_cmp(other.as_ref()) } } impl Ord for BorrowedBytes<'_> { fn cmp(&self, other: &Self) -> cmp::Ordering { - self.0.cmp(other.0) + self.buf.cmp(other.buf) } } -impl<'a> IntoIterator for BorrowedBytes<'a> { +impl<'a> IntoIterator for &'a BorrowedBytes<'_> { type Item = &'a u8; type IntoIter = slice::Iter<'a, u8>; fn into_iter(self) -> Self::IntoIter { - self.0.iter() + self.iter() + } +} + +impl<'a> From<&'a String> for BorrowedBytes<'a> { + #[inline] + fn from(value: &'a String) -> Self { + let (buf, _lua) = unsafe { value.to_slice() }; + let borrow = Cow::Borrowed(value); + Self { buf, borrow, _lua } } } diff --git a/src/table.rs b/src/table.rs index 3f100cad..8d43ff55 100644 --- a/src/table.rs +++ b/src/table.rs @@ -468,26 +468,16 @@ impl Table { /// /// It checks both the array part and the hash part. pub fn is_empty(&self) -> bool { - // Check array part - if self.raw_len() != 0 { - return false; - } - - // Check hash part let lua = self.0.lua.lock(); - let state = lua.state(); + let ref_thread = lua.ref_thread(); unsafe { - let _sg = StackGuard::new(state); - assert_stack(state, 4); - - lua.push_ref(&self.0); - ffi::lua_pushnil(state); - if ffi::lua_next(state, -2) != 0 { - return false; + ffi::lua_pushnil(ref_thread); + if ffi::lua_next(ref_thread, self.0.index) == 0 { + return true; } + ffi::lua_pop(ref_thread, 2); } - - true + false } /// Returns a reference to the metatable of this table, or `None` if no metatable is set. @@ -1030,7 +1020,10 @@ impl Serialize for SerializableTable<'_> { // Array let len = self.table.raw_len(); - if len > 0 || self.table.is_array() { + if len > 0 + || self.table.is_array() + || (self.options.encode_empty_tables_as_array && self.table.is_empty()) + { let mut seq = serializer.serialize_seq(Some(len))?; let mut serialize_err = None; let res = self.table.for_each_value::(|value| { diff --git a/src/thread.rs b/src/thread.rs index aaec30a7..35e1ee3d 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -2,8 +2,7 @@ use std::fmt; use std::os::raw::{c_int, c_void}; use crate::error::{Error, Result}; -#[allow(unused)] -use crate::state::Lua; +use crate::function::Function; use crate::state::RawLua; use crate::traits::{FromLuaMulti, IntoLuaMulti}; use crate::types::{LuaType, ValueRef}; @@ -42,6 +41,26 @@ pub enum ThreadStatus { Error, } +/// Internal representation of a Lua thread status. +/// +/// The number in `New` and `Yielded` variants is the number of arguments pushed +/// to the thread stack. +#[derive(Clone, Copy)] +enum ThreadStatusInner { + New, + Running, + Yielded, + Finished, + Error, +} + +impl ThreadStatusInner { + #[inline(always)] + fn is_resumable(self) -> bool { + matches!(self, ThreadStatusInner::New | ThreadStatusInner::Yielded) + } +} + /// Handle to an internal Lua thread (coroutine). #[derive(Clone)] pub struct Thread(pub(crate) ValueRef, pub(crate) *mut ffi::lua_State); @@ -122,7 +141,7 @@ impl Thread { R: FromLuaMulti, { let lua = self.0.lua.lock(); - if self.status_inner(&lua) != ThreadStatus::Resumable { + if !self.status_inner(&lua).is_resumable() { return Err(Error::CoroutineUnresumable); } @@ -170,23 +189,27 @@ impl Thread { /// Gets the status of the thread. pub fn status(&self) -> ThreadStatus { - self.status_inner(&self.0.lua.lock()) + match self.status_inner(&self.0.lua.lock()) { + ThreadStatusInner::New | ThreadStatusInner::Yielded => ThreadStatus::Resumable, + ThreadStatusInner::Running => ThreadStatus::Running, + ThreadStatusInner::Finished => ThreadStatus::Finished, + ThreadStatusInner::Error => ThreadStatus::Error, + } } /// Gets the status of the thread (internal implementation). - pub(crate) fn status_inner(&self, lua: &RawLua) -> ThreadStatus { + fn status_inner(&self, lua: &RawLua) -> ThreadStatusInner { let thread_state = self.state(); if thread_state == lua.state() { // The thread is currently running - return ThreadStatus::Running; + return ThreadStatusInner::Running; } let status = unsafe { ffi::lua_status(thread_state) }; - if status != ffi::LUA_OK && status != ffi::LUA_YIELD { - ThreadStatus::Error - } else if status == ffi::LUA_YIELD || unsafe { ffi::lua_gettop(thread_state) > 0 } { - ThreadStatus::Resumable - } else { - ThreadStatus::Finished + match status { + ffi::LUA_YIELD => ThreadStatusInner::Yielded, + ffi::LUA_OK if unsafe { ffi::lua_gettop(thread_state) } > 0 => ThreadStatusInner::New, + ffi::LUA_OK => ThreadStatusInner::Finished, + _ => ThreadStatusInner::Error, } } @@ -198,7 +221,7 @@ impl Thread { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn set_hook(&self, triggers: HookTriggers, callback: F) where - F: Fn(&Lua, Debug) -> Result + MaybeSend + 'static, + F: Fn(&crate::Lua, Debug) -> Result + MaybeSend + 'static, { let lua = self.0.lua.lock(); unsafe { @@ -215,32 +238,37 @@ impl Thread { /// In Luau: resets to the initial state of a newly created Lua thread. /// Lua threads in arbitrary states (like yielded or errored) can be reset properly. /// - /// Sets a Lua function for the thread afterwards. + /// Other Lua versions can reset only new or finished threads. /// - /// Requires `feature = "lua54"` OR `feature = "luau"`. + /// Sets a Lua function for the thread afterwards. /// /// [Lua 5.4]: https://www.lua.org/manual/5.4/manual.html#lua_closethread - #[cfg(any(feature = "lua54", feature = "luau"))] - #[cfg_attr(docsrs, doc(cfg(any(feature = "lua54", feature = "luau"))))] - pub fn reset(&self, func: crate::function::Function) -> Result<()> { + pub fn reset(&self, func: Function) -> Result<()> { let lua = self.0.lua.lock(); - if self.status_inner(&lua) == ThreadStatus::Running { - return Err(Error::runtime("cannot reset a running thread")); + let thread_state = self.state(); + match self.status_inner(&lua) { + ThreadStatusInner::Running => return Err(Error::runtime("cannot reset a running thread")), + // Any Lua can reuse new or finished thread + ThreadStatusInner::New => unsafe { ffi::lua_settop(thread_state, 0) }, + ThreadStatusInner::Finished => {} + #[cfg(not(any(feature = "lua54", feature = "luau")))] + _ => return Err(Error::runtime("cannot reset non-finished thread")), + #[cfg(any(feature = "lua54", feature = "luau"))] + _ => unsafe { + #[cfg(all(feature = "lua54", not(feature = "vendored")))] + let status = ffi::lua_resetthread(thread_state); + #[cfg(all(feature = "lua54", feature = "vendored"))] + let status = ffi::lua_closethread(thread_state, lua.state()); + #[cfg(feature = "lua54")] + if status != ffi::LUA_OK { + return Err(pop_error(thread_state, status)); + } + #[cfg(feature = "luau")] + ffi::lua_resetthread(thread_state); + }, } - let thread_state = self.state(); unsafe { - #[cfg(all(feature = "lua54", not(feature = "vendored")))] - let status = ffi::lua_resetthread(thread_state); - #[cfg(all(feature = "lua54", feature = "vendored"))] - let status = ffi::lua_closethread(thread_state, lua.state()); - #[cfg(feature = "lua54")] - if status != ffi::LUA_OK { - return Err(pop_error(thread_state, status)); - } - #[cfg(feature = "luau")] - ffi::lua_resetthread(thread_state); - // Push function to the top of the thread stack ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index); @@ -338,7 +366,7 @@ impl Thread { /// Ok(()) /// })?)?; /// thread.sandbox()?; - /// thread.resume(())?; + /// thread.resume::<()>(())?; /// /// // The global environment should be unchanged /// assert_eq!(lua.globals().get::>("var")?, None); @@ -393,30 +421,19 @@ impl LuaType for Thread { #[cfg(feature = "async")] impl AsyncThread { - #[inline] + #[inline(always)] pub(crate) fn set_recyclable(&mut self, recyclable: bool) { self.recycle = recyclable; } } #[cfg(feature = "async")] -#[cfg(any(feature = "lua54", feature = "luau"))] impl Drop for AsyncThread { fn drop(&mut self) { if self.recycle { if let Some(lua) = self.thread.0.lua.try_lock() { - unsafe { - // For Lua 5.4 this also closes all pending to-be-closed variables - if !lua.recycle_thread(&mut self.thread) { - #[cfg(feature = "lua54")] - if self.thread.status_inner(&lua) == ThreadStatus::Error { - #[cfg(not(feature = "vendored"))] - ffi::lua_resetthread(self.thread.state()); - #[cfg(feature = "vendored")] - ffi::lua_closethread(self.thread.state(), lua.state()); - } - } - } + // For Lua 5.4 this also closes all pending to-be-closed variables + unsafe { lua.recycle_thread(&mut self.thread) }; } } } @@ -428,7 +445,7 @@ impl Stream for AsyncThread { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let lua = self.thread.0.lua.lock(); - if self.thread.status_inner(&lua) != ThreadStatus::Resumable { + if !self.thread.status_inner(&lua).is_resumable() { return Poll::Ready(None); } @@ -466,7 +483,7 @@ impl Future for AsyncThread { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let lua = self.thread.0.lua.lock(); - if self.thread.status_inner(&lua) != ThreadStatus::Resumable { + if !self.thread.status_inner(&lua).is_resumable() { return Poll::Ready(Err(Error::CoroutineUnresumable)); } @@ -506,7 +523,7 @@ impl Future for AsyncThread { #[cfg(feature = "async")] #[inline(always)] unsafe fn is_poll_pending(state: *mut ffi::lua_State) -> bool { - ffi::lua_tolightuserdata(state, -1) == Lua::poll_pending().0 + ffi::lua_tolightuserdata(state, -1) == crate::Lua::poll_pending().0 } #[cfg(feature = "async")] diff --git a/src/types.rs b/src/types.rs index afeb239d..6a2b4d25 100644 --- a/src/types.rs +++ b/src/types.rs @@ -86,10 +86,10 @@ pub(crate) type InterruptCallback = Rc Result + Send>; pub(crate) type InterruptCallback = Rc Result>; #[cfg(all(feature = "send", feature = "lua54"))] -pub(crate) type WarnCallback = Box Result<()> + Send>; +pub(crate) type WarnCallback = XRc Result<()> + Send>; #[cfg(all(not(feature = "send"), feature = "lua54"))] -pub(crate) type WarnCallback = Box Result<()>>; +pub(crate) type WarnCallback = XRc Result<()>>; /// A trait that adds `Send` requirement if `send` feature is enabled. #[cfg(feature = "send")] diff --git a/src/userdata.rs b/src/userdata.rs index 83c5efcd..a4631905 100644 --- a/src/userdata.rs +++ b/src/userdata.rs @@ -26,9 +26,13 @@ use { // Re-export for convenience pub(crate) use cell::UserDataStorage; -pub use cell::{UserDataRef, UserDataRefMut}; +pub use r#ref::{UserDataRef, UserDataRefMut}; pub use registry::UserDataRegistry; pub(crate) use registry::{RawUserDataRegistry, UserDataProxy}; +pub(crate) use util::{ + borrow_userdata_scoped, borrow_userdata_scoped_mut, collect_userdata, init_userdata_metatable, + TypeIdHints, +}; /// Kinds of metamethods that can be overridden. /// @@ -622,7 +626,9 @@ impl AnyUserData { /// Checks whether the type of this userdata is `T`. #[inline] pub fn is(&self) -> bool { - self.inspect::(|_| Ok(())).is_ok() + let type_id = self.type_id(); + // We do not use wrapped types here, rather prefer to check the "real" type of the userdata + matches!(type_id, Some(type_id) if type_id == TypeId::of::()) } /// Borrow this userdata immutably if it is of type `T`. @@ -637,7 +643,8 @@ impl AnyUserData { /// [`DataTypeMismatch`]: crate::Error::UserDataTypeMismatch #[inline] pub fn borrow(&self) -> Result> { - self.inspect(|ud| ud.try_borrow_owned()) + let lua = self.0.lua.lock(); + unsafe { UserDataRef::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } } /// Borrow this userdata immutably if it is of type `T`, passing the borrowed value @@ -645,7 +652,10 @@ impl AnyUserData { /// /// This method is the only way to borrow scoped userdata (created inside [`Lua::scope`]). pub fn borrow_scoped(&self, f: impl FnOnce(&T) -> R) -> Result { - self.inspect(|ud| ud.try_borrow_scoped(|ud| f(ud))) + let lua = self.0.lua.lock(); + let type_id = lua.get_userdata_ref_type_id(&self.0)?; + let type_hints = TypeIdHints::new::(); + unsafe { borrow_userdata_scoped(lua.ref_thread(), self.0.index, type_id, type_hints, f) } } /// Borrow this userdata mutably if it is of type `T`. @@ -660,7 +670,8 @@ impl AnyUserData { /// [`UserDataTypeMismatch`]: crate::Error::UserDataTypeMismatch #[inline] pub fn borrow_mut(&self) -> Result> { - self.inspect(|ud| ud.try_borrow_owned_mut()) + let lua = self.0.lua.lock(); + unsafe { UserDataRefMut::borrow_from_stack(&lua, lua.ref_thread(), self.0.index) } } /// Borrow this userdata mutably if it is of type `T`, passing the borrowed value @@ -668,7 +679,10 @@ impl AnyUserData { /// /// This method is the only way to borrow scoped userdata (created inside [`Lua::scope`]). pub fn borrow_mut_scoped(&self, f: impl FnOnce(&mut T) -> R) -> Result { - self.inspect(|ud| ud.try_borrow_scoped_mut(|ud| f(ud))) + let lua = self.0.lua.lock(); + let type_id = lua.get_userdata_ref_type_id(&self.0)?; + let type_hints = TypeIdHints::new::(); + unsafe { borrow_userdata_scoped_mut(lua.ref_thread(), self.0.index, type_id, type_hints, f) } } /// Takes the value out of this userdata. @@ -687,9 +701,11 @@ impl AnyUserData { let type_id = lua.push_userdata_ref(&self.0)?; match type_id { Some(type_id) if type_id == TypeId::of::() => { - // Try to borrow userdata exclusively - let _ = (*get_userdata::>(state, -1)).try_borrow_mut()?; - take_userdata::>(state).into_inner() + if (*get_userdata::>(state, -1)).has_exclusive_access() { + take_userdata::>(state).into_inner() + } else { + Err(Error::UserDataBorrowMutError) + } } _ => Err(Error::UserDataTypeMismatch), } @@ -910,6 +926,15 @@ impl AnyUserData { self.0.to_pointer() } + /// Returns [`TypeId`] of this userdata if it is registered and `'static`. + /// + /// This method is not available for scoped userdata. + #[inline] + pub fn type_id(&self) -> Option { + let lua = self.0.lua.lock(); + lua.get_userdata_ref_type_id(&self.0).ok().flatten() + } + /// Returns a type name of this `UserData` (from a metatable field). pub(crate) fn type_name(&self) -> Result> { let lua = self.0.lua.lock(); @@ -965,24 +990,6 @@ impl AnyUserData { }; is_serializable().unwrap_or(false) } - - pub(crate) fn inspect(&self, func: F) -> Result - where - T: 'static, - F: FnOnce(&UserDataStorage) -> Result, - { - let lua = self.0.lua.lock(); - unsafe { - let type_id = lua.get_userdata_ref_type_id(&self.0)?; - match type_id { - Some(type_id) if type_id == TypeId::of::() => { - let ud = get_userdata::>(lua.ref_thread(), self.0.index); - func(&*ud) - } - _ => Err(Error::UserDataTypeMismatch), - } - } - } } /// Handle to a [`AnyUserData`] metatable. @@ -1106,6 +1113,7 @@ where mod cell; mod lock; mod object; +mod r#ref; mod registry; mod util; diff --git a/src/userdata/cell.rs b/src/userdata/cell.rs index d5d8e005..538e33d7 100644 --- a/src/userdata/cell.rs +++ b/src/userdata/cell.rs @@ -1,22 +1,13 @@ -use std::any::{type_name, TypeId}; -use std::cell::{Cell, RefCell, UnsafeCell}; -use std::fmt; -use std::ops::{Deref, DerefMut}; -use std::os::raw::c_int; +use std::cell::{RefCell, UnsafeCell}; #[cfg(feature = "serialize")] use serde::ser::{Serialize, Serializer}; use crate::error::{Error, Result}; -use crate::state::{Lua, RawLua}; -use crate::traits::FromLua; use crate::types::XRc; -use crate::userdata::AnyUserData; -use crate::util::get_userdata; -use crate::value::Value; use super::lock::{RawLock, UserDataLock}; -use super::util::is_sync; +use super::r#ref::{UserDataRef, UserDataRefMut}; #[cfg(all(feature = "serialize", not(feature = "send")))] type DynSerialize = dyn erased_serde::Serialize; @@ -34,7 +25,7 @@ pub(crate) enum UserDataStorage { pub(crate) enum UserDataVariant { Default(XRc>), #[cfg(feature = "serialize")] - Serializable(XRc>>), + Serializable(XRc>>, bool), // bool is `is_sync` } impl Clone for UserDataVariant { @@ -43,28 +34,34 @@ impl Clone for UserDataVariant { match self { Self::Default(inner) => Self::Default(XRc::clone(inner)), #[cfg(feature = "serialize")] - Self::Serializable(inner) => Self::Serializable(XRc::clone(inner)), + Self::Serializable(inner, is_sync) => Self::Serializable(XRc::clone(inner), *is_sync), } } } impl UserDataVariant { - // Immutably borrows the wrapped value in-place. #[inline(always)] - fn try_borrow(&self) -> Result> { - UserDataBorrowRef::try_from(self) + pub(super) fn try_borrow_scoped(&self, f: impl FnOnce(&T) -> R) -> Result { + // We don't need to check for `T: Sync` because when this method is used (internally), + // Lua mutex is already locked. + // If non-`Sync` userdata is already borrowed by another thread (via `UserDataRef`), it will be + // exclusively locked. + let _guard = (self.raw_lock().try_lock_shared_guarded()).map_err(|_| Error::UserDataBorrowError)?; + Ok(f(unsafe { &*self.as_ptr() })) } - // Immutably borrows the wrapped value and returns an owned reference. + // Mutably borrows the wrapped value in-place. #[inline(always)] - fn try_borrow_owned(&self) -> Result> { - UserDataRef::try_from(self.clone()) + fn try_borrow_scoped_mut(&self, f: impl FnOnce(&mut T) -> R) -> Result { + let _guard = + (self.raw_lock().try_lock_exclusive_guarded()).map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(unsafe { &mut *self.as_ptr() })) } - // Mutably borrows the wrapped value in-place. + // Immutably borrows the wrapped value and returns an owned reference. #[inline(always)] - fn try_borrow_mut(&self) -> Result> { - UserDataBorrowMut::try_from(self) + fn try_borrow_owned(&self) -> Result> { + UserDataRef::try_from(self.clone()) } // Mutably borrows the wrapped value and returns an owned reference. @@ -83,7 +80,7 @@ impl UserDataVariant { Ok(match self { Self::Default(inner) => XRc::into_inner(inner).unwrap().value.into_inner(), #[cfg(feature = "serialize")] - Self::Serializable(inner) => unsafe { + Self::Serializable(inner, _) => unsafe { let raw = Box::into_raw(XRc::into_inner(inner).unwrap().value.into_inner()); *Box::from_raw(raw as *mut T) }, @@ -91,29 +88,29 @@ impl UserDataVariant { } #[inline(always)] - fn raw_lock(&self) -> &RawLock { + fn strong_count(&self) -> usize { match self { - Self::Default(inner) => &inner.raw_lock, + Self::Default(inner) => XRc::strong_count(inner), #[cfg(feature = "serialize")] - Self::Serializable(inner) => &inner.raw_lock, + Self::Serializable(inner, _) => XRc::strong_count(inner), } } #[inline(always)] - fn borrow_count(&self) -> &Cell { + pub(super) fn raw_lock(&self) -> &RawLock { match self { - Self::Default(inner) => &inner.borrow_count, + Self::Default(inner) => &inner.raw_lock, #[cfg(feature = "serialize")] - Self::Serializable(inner) => &inner.borrow_count, + Self::Serializable(inner, _) => &inner.raw_lock, } } #[inline(always)] - fn as_ptr(&self) -> *mut T { + pub(super) fn as_ptr(&self) -> *mut T { match self { Self::Default(inner) => inner.value.get(), #[cfg(feature = "serialize")] - Self::Serializable(inner) => unsafe { &mut **(inner.value.get() as *mut Box) }, + Self::Serializable(inner, _) => unsafe { &mut **(inner.value.get() as *mut Box) }, } } } @@ -122,14 +119,24 @@ impl UserDataVariant { impl Serialize for UserDataStorage<()> { fn serialize(&self, serializer: S) -> std::result::Result { match self { - Self::Owned(UserDataVariant::Serializable(inner)) => unsafe { - // We need to borrow the inner value exclusively to serialize it. + Self::Owned(variant @ UserDataVariant::Serializable(inner, is_sync)) => unsafe { #[cfg(feature = "send")] - let _guard = self.try_borrow_mut().map_err(serde::ser::Error::custom)?; - // No need to do this if the `send` feature is disabled. + if *is_sync { + let _guard = (variant.raw_lock().try_lock_shared_guarded()) + .map_err(|_| serde::ser::Error::custom(Error::UserDataBorrowError))?; + (*inner.value.get()).serialize(serializer) + } else { + let _guard = (variant.raw_lock().try_lock_exclusive_guarded()) + .map_err(|_| serde::ser::Error::custom(Error::UserDataBorrowError))?; + (*inner.value.get()).serialize(serializer) + } #[cfg(not(feature = "send"))] - let _guard = self.try_borrow().map_err(serde::ser::Error::custom)?; - (*inner.value.get()).serialize(serializer) + { + let _ = is_sync; + let _guard = (variant.raw_lock().try_lock_shared_guarded()) + .map_err(|_| serde::ser::Error::custom(Error::UserDataBorrowError))?; + (*inner.value.get()).serialize(serializer) + } }, _ => Err(serde::ser::Error::custom("cannot serialize ")), } @@ -139,7 +146,6 @@ impl Serialize for UserDataStorage<()> { /// A type that provides interior mutability for a userdata value (thread-safe). pub(crate) struct UserDataCell { raw_lock: RawLock, - borrow_count: Cell, value: UnsafeCell, } @@ -153,242 +159,11 @@ impl UserDataCell { fn new(value: T) -> Self { UserDataCell { raw_lock: RawLock::INIT, - borrow_count: Cell::new(0), value: UnsafeCell::new(value), } } } -/// A wrapper type for a userdata value that provides read access. -/// -/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. -pub struct UserDataRef(UserDataVariant); - -impl Deref for UserDataRef { - type Target = T; - - #[inline] - fn deref(&self) -> &T { - unsafe { &*self.0.as_ptr() } - } -} - -impl Drop for UserDataRef { - #[inline] - fn drop(&mut self) { - if !cfg!(feature = "send") || is_sync::() { - unsafe { self.0.raw_lock().unlock_shared() }; - } else { - unsafe { self.0.raw_lock().unlock_exclusive() }; - } - } -} - -impl fmt::Debug for UserDataRef { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (**self).fmt(f) - } -} - -impl fmt::Display for UserDataRef { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (**self).fmt(f) - } -} - -impl TryFrom> for UserDataRef { - type Error = Error; - - #[inline] - fn try_from(variant: UserDataVariant) -> Result { - if !cfg!(feature = "send") || is_sync::() { - if !variant.raw_lock().try_lock_shared() { - return Err(Error::UserDataBorrowError); - } - } else if !variant.raw_lock().try_lock_exclusive() { - return Err(Error::UserDataBorrowError); - } - Ok(UserDataRef(variant)) - } -} - -impl FromLua for UserDataRef { - fn from_lua(value: Value, _: &Lua) -> Result { - try_value_to_userdata::(value)?.borrow() - } - - unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { - let type_id = lua.get_userdata_type_id::(idx)?; - match type_id { - Some(type_id) if type_id == TypeId::of::() => { - (*get_userdata::>(lua.state(), idx)).try_borrow_owned() - } - _ => Err(Error::UserDataTypeMismatch), - } - } -} - -/// A wrapper type for a userdata value that provides read and write access. -/// -/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. -pub struct UserDataRefMut(UserDataVariant); - -impl Deref for UserDataRefMut { - type Target = T; - - #[inline] - fn deref(&self) -> &Self::Target { - unsafe { &*self.0.as_ptr() } - } -} - -impl DerefMut for UserDataRefMut { - #[inline] - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *self.0.as_ptr() } - } -} - -impl Drop for UserDataRefMut { - #[inline] - fn drop(&mut self) { - unsafe { self.0.raw_lock().unlock_exclusive() }; - } -} - -impl fmt::Debug for UserDataRefMut { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (**self).fmt(f) - } -} - -impl fmt::Display for UserDataRefMut { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - (**self).fmt(f) - } -} - -impl TryFrom> for UserDataRefMut { - type Error = Error; - - #[inline] - fn try_from(variant: UserDataVariant) -> Result { - if !variant.raw_lock().try_lock_exclusive() { - return Err(Error::UserDataBorrowMutError); - } - Ok(UserDataRefMut(variant)) - } -} - -impl FromLua for UserDataRefMut { - fn from_lua(value: Value, _: &Lua) -> Result { - try_value_to_userdata::(value)?.borrow_mut() - } - - unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { - let type_id = lua.get_userdata_type_id::(idx)?; - match type_id { - Some(type_id) if type_id == TypeId::of::() => { - (*get_userdata::>(lua.state(), idx)).try_borrow_owned_mut() - } - _ => Err(Error::UserDataTypeMismatch), - } - } -} - -/// A type that provides read access to a userdata value (borrowing the value). -pub(crate) struct UserDataBorrowRef<'a, T>(&'a UserDataVariant); - -impl Drop for UserDataBorrowRef<'_, T> { - #[inline] - fn drop(&mut self) { - unsafe { - self.0.borrow_count().set(self.0.borrow_count().get() - 1); - self.0.raw_lock().unlock_shared(); - } - } -} - -impl Deref for UserDataBorrowRef<'_, T> { - type Target = T; - - #[inline] - fn deref(&self) -> &T { - // SAFETY: `UserDataBorrowRef` is only created with shared access to the value. - unsafe { &*self.0.as_ptr() } - } -} - -impl<'a, T> TryFrom<&'a UserDataVariant> for UserDataBorrowRef<'a, T> { - type Error = Error; - - #[inline(always)] - fn try_from(variant: &'a UserDataVariant) -> Result { - // We don't need to check for `T: Sync` because when this method is used (internally), - // Lua mutex is already locked. - // If non-`Sync` userdata is already borrowed by another thread (via `UserDataRef`), it will be - // exclusively locked. - if !variant.raw_lock().try_lock_shared() { - return Err(Error::UserDataBorrowError); - } - variant.borrow_count().set(variant.borrow_count().get() + 1); - Ok(UserDataBorrowRef(variant)) - } -} - -pub(crate) struct UserDataBorrowMut<'a, T>(&'a UserDataVariant); - -impl Drop for UserDataBorrowMut<'_, T> { - #[inline] - fn drop(&mut self) { - unsafe { - self.0.borrow_count().set(self.0.borrow_count().get() - 1); - self.0.raw_lock().unlock_exclusive(); - } - } -} - -impl Deref for UserDataBorrowMut<'_, T> { - type Target = T; - - #[inline] - fn deref(&self) -> &T { - unsafe { &*self.0.as_ptr() } - } -} - -impl DerefMut for UserDataBorrowMut<'_, T> { - #[inline] - fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.0.as_ptr() } - } -} - -impl<'a, T> TryFrom<&'a UserDataVariant> for UserDataBorrowMut<'a, T> { - type Error = Error; - - #[inline(always)] - fn try_from(variant: &'a UserDataVariant) -> Result { - if !variant.raw_lock().try_lock_exclusive() { - return Err(Error::UserDataBorrowMutError); - } - variant.borrow_count().set(variant.borrow_count().get() + 1); - Ok(UserDataBorrowMut(variant)) - } -} - -#[inline] -fn try_value_to_userdata(value: Value) -> Result { - match value { - Value::UserData(ud) => Ok(ud), - _ => Err(Error::FromLuaConversionError { - from: value.type_name(), - to: "userdata".to_string(), - message: Some(format!("expected userdata of type {}", type_name::())), - }), - } -} - pub(crate) enum ScopedUserDataVariant { Ref(*const T), RefMut(RefCell<*mut T>), @@ -429,13 +204,15 @@ impl UserDataStorage { T: Serialize + crate::types::MaybeSend, { let data = Box::new(data) as Box; - Self::Owned(UserDataVariant::Serializable(XRc::new(UserDataCell::new(data)))) + let is_sync = super::util::is_sync::(); + let variant = UserDataVariant::Serializable(XRc::new(UserDataCell::new(data)), is_sync); + Self::Owned(variant) } #[cfg(feature = "serialize")] #[inline(always)] pub(crate) fn is_serializable(&self) -> bool { - matches!(self, Self::Owned(UserDataVariant::Serializable(_))) + matches!(self, Self::Owned(UserDataVariant::Serializable(..))) } // Immutably borrows the wrapped value and returns an owned reference. @@ -447,23 +224,6 @@ impl UserDataStorage { } } - #[allow(unused)] - #[inline(always)] - pub(crate) fn try_borrow(&self) -> Result> { - match self { - Self::Owned(data) => data.try_borrow(), - Self::Scoped(_) => Err(Error::UserDataTypeMismatch), - } - } - - #[inline(always)] - pub(crate) fn try_borrow_mut(&self) -> Result> { - match self { - Self::Owned(data) => data.try_borrow_mut(), - Self::Scoped(_) => Err(Error::UserDataTypeMismatch), - } - } - // Mutably borrows the wrapped value and returns an owned reference. #[inline(always)] pub(crate) fn try_borrow_owned_mut(&self) -> Result> { @@ -489,18 +249,31 @@ impl UserDataStorage { Self::Scoped(ScopedUserDataVariant::Boxed(RefCell::new(data))) } + /// Returns `true` if it's safe to destroy the container. + /// + /// It's safe to destroy the container if the reference count is greater than 1 or the lock is + /// not acquired. #[inline(always)] - pub(crate) fn is_borrowed(&self) -> bool { + pub(crate) fn is_safe_to_destroy(&self) -> bool { match self { - Self::Owned(variant) => variant.borrow_count().get() > 0, - Self::Scoped(_) => true, + Self::Owned(variant) => variant.strong_count() > 1 || !variant.raw_lock().is_locked(), + Self::Scoped(_) => false, + } + } + + /// Returns `true` if the container has exclusive access to the value. + #[inline(always)] + pub(crate) fn has_exclusive_access(&self) -> bool { + match self { + Self::Owned(variant) => !variant.raw_lock().is_locked(), + Self::Scoped(_) => false, } } #[inline] pub(crate) fn try_borrow_scoped(&self, f: impl FnOnce(&T) -> R) -> Result { match self { - Self::Owned(data) => Ok(f(&*data.try_borrow()?)), + Self::Owned(data) => data.try_borrow_scoped(f), Self::Scoped(ScopedUserDataVariant::Ref(value)) => Ok(f(unsafe { &**value })), Self::Scoped(ScopedUserDataVariant::RefMut(value) | ScopedUserDataVariant::Boxed(value)) => { let t = value.try_borrow().map_err(|_| Error::UserDataBorrowError)?; @@ -512,7 +285,7 @@ impl UserDataStorage { #[inline] pub(crate) fn try_borrow_scoped_mut(&self, f: impl FnOnce(&mut T) -> R) -> Result { match self { - Self::Owned(data) => Ok(f(&mut *data.try_borrow_mut()?)), + Self::Owned(data) => data.try_borrow_scoped_mut(f), Self::Scoped(ScopedUserDataVariant::Ref(_)) => Err(Error::UserDataBorrowMutError), Self::Scoped(ScopedUserDataVariant::RefMut(value) | ScopedUserDataVariant::Boxed(value)) => { let mut t = value @@ -523,30 +296,3 @@ impl UserDataStorage { } } } - -#[cfg(test)] -mod assertions { - use super::*; - - #[cfg(feature = "send")] - static_assertions::assert_impl_all!(UserDataRef<()>: Send, Sync); - #[cfg(feature = "send")] - static_assertions::assert_not_impl_all!(UserDataRef>: Send, Sync); - #[cfg(feature = "send")] - static_assertions::assert_impl_all!(UserDataRefMut<()>: Sync, Send); - #[cfg(feature = "send")] - static_assertions::assert_not_impl_all!(UserDataRefMut>: Send, Sync); - #[cfg(feature = "send")] - static_assertions::assert_impl_all!(UserDataBorrowRef<'_, ()>: Send, Sync); - #[cfg(feature = "send")] - static_assertions::assert_impl_all!(UserDataBorrowMut<'_, ()>: Send, Sync); - - #[cfg(not(feature = "send"))] - static_assertions::assert_not_impl_all!(UserDataRef<()>: Send, Sync); - #[cfg(not(feature = "send"))] - static_assertions::assert_not_impl_all!(UserDataRefMut<()>: Send, Sync); - #[cfg(not(feature = "send"))] - static_assertions::assert_not_impl_all!(UserDataBorrowRef<'_, ()>: Send, Sync); - #[cfg(not(feature = "send"))] - static_assertions::assert_not_impl_all!(UserDataBorrowMut<'_, ()>: Send, Sync); -} diff --git a/src/userdata/lock.rs b/src/userdata/lock.rs index c5690444..e0e5d1af 100644 --- a/src/userdata/lock.rs +++ b/src/userdata/lock.rs @@ -1,11 +1,51 @@ pub(crate) trait UserDataLock { const INIT: Self; + fn is_locked(&self) -> bool; fn try_lock_shared(&self) -> bool; fn try_lock_exclusive(&self) -> bool; unsafe fn unlock_shared(&self); unsafe fn unlock_exclusive(&self); + + fn try_lock_shared_guarded(&self) -> Result, ()> { + if self.try_lock_shared() { + Ok(LockGuard { + lock: self, + exclusive: false, + }) + } else { + Err(()) + } + } + + fn try_lock_exclusive_guarded(&self) -> Result, ()> { + if self.try_lock_exclusive() { + Ok(LockGuard { + lock: self, + exclusive: true, + }) + } else { + Err(()) + } + } +} + +pub(crate) struct LockGuard<'a, L: UserDataLock + ?Sized> { + lock: &'a L, + exclusive: bool, +} + +impl Drop for LockGuard<'_, L> { + fn drop(&mut self) { + unsafe { + if self.exclusive { + self.lock.unlock_exclusive(); + } else { + self.lock.unlock_shared(); + } + } + } } pub(crate) use lock_impl::RawLock; @@ -25,6 +65,11 @@ mod lock_impl { #[allow(clippy::declare_interior_mutable_const)] const INIT: Self = Cell::new(UNUSED); + #[inline(always)] + fn is_locked(&self) -> bool { + self.get() != UNUSED + } + #[inline(always)] fn try_lock_shared(&self) -> bool { let flag = self.get().wrapping_add(1); @@ -71,6 +116,11 @@ mod lock_impl { #[allow(clippy::declare_interior_mutable_const)] const INIT: Self = ::INIT; + #[inline(always)] + fn is_locked(&self) -> bool { + RawRwLock::is_locked(self) + } + #[inline(always)] fn try_lock_shared(&self) -> bool { RawRwLock::try_lock_shared(self) diff --git a/src/userdata/ref.rs b/src/userdata/ref.rs new file mode 100644 index 00000000..4302d0e3 --- /dev/null +++ b/src/userdata/ref.rs @@ -0,0 +1,474 @@ +use std::any::TypeId; +use std::ops::{Deref, DerefMut}; +use std::os::raw::c_int; +use std::{fmt, mem}; + +use crate::error::{Error, Result}; +use crate::state::{Lua, RawLua}; +use crate::traits::FromLua; +use crate::userdata::AnyUserData; +use crate::util::{get_userdata, short_type_name}; +use crate::value::Value; + +use super::cell::{UserDataStorage, UserDataVariant}; +use super::lock::{LockGuard, RawLock, UserDataLock}; +use super::util::is_sync; + +#[cfg(feature = "userdata-wrappers")] +use { + parking_lot::{ + Mutex as MutexPL, MutexGuard as MutexGuardPL, RwLock as RwLockPL, + RwLockReadGuard as RwLockReadGuardPL, RwLockWriteGuard as RwLockWriteGuardPL, + }, + std::sync::Arc, +}; +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +use { + std::cell::{Ref, RefCell, RefMut}, + std::rc::Rc, +}; + +/// A wrapper type for a userdata value that provides read access. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. +pub struct UserDataRef { + // It's important to drop the guard first, as it refers to the `inner` data. + _guard: LockGuard<'static, RawLock>, + inner: UserDataRefInner, +} + +impl Deref for UserDataRef { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + &self.inner + } +} + +impl fmt::Debug for UserDataRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Display for UserDataRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl TryFrom> for UserDataRef { + type Error = Error; + + #[inline] + fn try_from(variant: UserDataVariant) -> Result { + let guard = if !cfg!(feature = "send") || is_sync::() { + variant.raw_lock().try_lock_shared_guarded() + } else { + variant.raw_lock().try_lock_exclusive_guarded() + }; + let guard = guard.map_err(|_| Error::UserDataBorrowError)?; + let guard = unsafe { mem::transmute::, LockGuard<'static, _>>(guard) }; + Ok(UserDataRef::from_parts(UserDataRefInner::Default(variant), guard)) + } +} + +impl FromLua for UserDataRef { + fn from_lua(value: Value, _: &Lua) -> Result { + try_value_to_userdata::(value)?.borrow() + } + + #[inline] + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + Self::borrow_from_stack(lua, lua.state(), idx) + } +} + +impl UserDataRef { + #[inline(always)] + fn from_parts(inner: UserDataRefInner, guard: LockGuard<'static, RawLock>) -> Self { + Self { _guard: guard, inner } + } + + #[cfg(feature = "userdata-wrappers")] + fn remap( + self, + f: impl FnOnce(UserDataVariant) -> Result>, + ) -> Result> { + match &self.inner { + UserDataRefInner::Default(variant) => { + let inner = f(variant.clone())?; + Ok(UserDataRef::from_parts(inner, self._guard)) + } + _ => Err(Error::UserDataTypeMismatch), + } + } + + pub(crate) unsafe fn borrow_from_stack( + lua: &RawLua, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result { + let type_id = lua.get_userdata_type_id::(state, idx)?; + match type_id { + Some(type_id) if type_id == TypeId::of::() => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_owned() + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>() => { + let ud = get_userdata::>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_rc()) + } + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_rc_refcell()) + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>() => { + let ud = get_userdata::>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_arc()) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_arc_mutex_pl()) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned()).and_then(|ud| ud.transform_arc_rwlock_pl()) + } + _ => Err(Error::UserDataTypeMismatch), + } + } +} + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +impl UserDataRef> { + fn transform_rc(self) -> Result> { + self.remap(|variant| Ok(UserDataRefInner::Rc(variant))) + } +} + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +impl UserDataRef>> { + fn transform_rc_refcell(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let r#ref = obj.try_borrow().map_err(|_| Error::UserDataBorrowError)?; + let borrow = std::mem::transmute::, Ref<'static, T>>(r#ref); + Ok(UserDataRefInner::RcRefCell(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRef> { + fn transform_arc(self) -> Result> { + self.remap(|variant| Ok(UserDataRefInner::Arc(variant))) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRef>> { + fn transform_arc_mutex_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_lock().ok_or(Error::UserDataBorrowError)?; + let borrow = std::mem::transmute::, MutexGuardPL<'static, T>>(guard); + Ok(UserDataRefInner::ArcMutexPL(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRef>> { + fn transform_arc_rwlock_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_read().ok_or(Error::UserDataBorrowError)?; + let borrow = std::mem::transmute::, RwLockReadGuardPL<'static, T>>(guard); + Ok(UserDataRefInner::ArcRwLockPL(borrow, variant)) + }) + } +} + +#[allow(unused)] +enum UserDataRefInner { + Default(UserDataVariant), + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Rc(UserDataVariant>), + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + RcRefCell(Ref<'static, T>, UserDataVariant>>), + + #[cfg(feature = "userdata-wrappers")] + Arc(UserDataVariant>), + #[cfg(feature = "userdata-wrappers")] + ArcMutexPL(MutexGuardPL<'static, T>, UserDataVariant>>), + #[cfg(feature = "userdata-wrappers")] + ArcRwLockPL(RwLockReadGuardPL<'static, T>, UserDataVariant>>), +} + +impl Deref for UserDataRefInner { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + match self { + Self::Default(inner) => unsafe { &*inner.as_ptr() }, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::Rc(inner) => unsafe { &*Rc::as_ptr(&*inner.as_ptr()) }, + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::RcRefCell(x, ..) => x, + + #[cfg(feature = "userdata-wrappers")] + Self::Arc(inner) => unsafe { &*Arc::as_ptr(&*inner.as_ptr()) }, + #[cfg(feature = "userdata-wrappers")] + Self::ArcMutexPL(x, ..) => x, + #[cfg(feature = "userdata-wrappers")] + Self::ArcRwLockPL(x, ..) => x, + } + } +} + +/// A wrapper type for a userdata value that provides read and write access. +/// +/// It implements [`FromLua`] and can be used to receive a typed userdata from Lua. +pub struct UserDataRefMut { + // It's important to drop the guard first, as it refers to the `inner` data. + _guard: LockGuard<'static, RawLock>, + inner: UserDataRefMutInner, +} + +impl Deref for UserDataRefMut { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for UserDataRefMut { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl fmt::Debug for UserDataRefMut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl fmt::Display for UserDataRefMut { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl TryFrom> for UserDataRefMut { + type Error = Error; + + #[inline] + fn try_from(variant: UserDataVariant) -> Result { + let guard = variant.raw_lock().try_lock_exclusive_guarded(); + let guard = guard.map_err(|_| Error::UserDataBorrowMutError)?; + let guard = unsafe { mem::transmute::, LockGuard<'static, _>>(guard) }; + Ok(UserDataRefMut::from_parts( + UserDataRefMutInner::Default(variant), + guard, + )) + } +} + +impl FromLua for UserDataRefMut { + fn from_lua(value: Value, _: &Lua) -> Result { + try_value_to_userdata::(value)?.borrow_mut() + } + + unsafe fn from_stack(idx: c_int, lua: &RawLua) -> Result { + Self::borrow_from_stack(lua, lua.state(), idx) + } +} + +impl UserDataRefMut { + #[inline(always)] + fn from_parts(inner: UserDataRefMutInner, guard: LockGuard<'static, RawLock>) -> Self { + Self { _guard: guard, inner } + } + + #[cfg(feature = "userdata-wrappers")] + fn remap( + self, + f: impl FnOnce(UserDataVariant) -> Result>, + ) -> Result> { + match &self.inner { + UserDataRefMutInner::Default(variant) => { + let inner = f(variant.clone())?; + Ok(UserDataRefMut::from_parts(inner, self._guard)) + } + _ => Err(Error::UserDataTypeMismatch), + } + } + + pub(crate) unsafe fn borrow_from_stack( + lua: &RawLua, + state: *mut ffi::lua_State, + idx: c_int, + ) -> Result { + let type_id = lua.get_userdata_type_id::(state, idx)?; + match type_id { + Some(type_id) if type_id == TypeId::of::() => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_owned_mut() + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>() => Err(Error::UserDataBorrowMutError), + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned_mut()).and_then(|ud| ud.transform_rc_refcell()) + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>() => Err(Error::UserDataBorrowMutError), + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned_mut()).and_then(|ud| ud.transform_arc_mutex_pl()) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == TypeId::of::>>() => { + let ud = get_userdata::>>>(state, idx); + ((*ud).try_borrow_owned_mut()).and_then(|ud| ud.transform_arc_rwlock_pl()) + } + _ => Err(Error::UserDataTypeMismatch), + } + } +} + +#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] +impl UserDataRefMut>> { + fn transform_rc_refcell(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let refmut = obj.try_borrow_mut().map_err(|_| Error::UserDataBorrowMutError)?; + let borrow = std::mem::transmute::, RefMut<'static, T>>(refmut); + Ok(UserDataRefMutInner::RcRefCell(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRefMut>> { + fn transform_arc_mutex_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_lock().ok_or(Error::UserDataBorrowMutError)?; + let borrow = std::mem::transmute::, MutexGuardPL<'static, T>>(guard); + Ok(UserDataRefMutInner::ArcMutexPL(borrow, variant)) + }) + } +} + +#[cfg(feature = "userdata-wrappers")] +impl UserDataRefMut>> { + fn transform_arc_rwlock_pl(self) -> Result> { + self.remap(|variant| unsafe { + let obj = &*variant.as_ptr(); + let guard = obj.try_write().ok_or(Error::UserDataBorrowMutError)?; + let borrow = std::mem::transmute::, RwLockWriteGuardPL<'static, T>>(guard); + Ok(UserDataRefMutInner::ArcRwLockPL(borrow, variant)) + }) + } +} + +#[allow(unused)] +enum UserDataRefMutInner { + Default(UserDataVariant), + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + RcRefCell(RefMut<'static, T>, UserDataVariant>>), + + #[cfg(feature = "userdata-wrappers")] + ArcMutexPL(MutexGuardPL<'static, T>, UserDataVariant>>), + #[cfg(feature = "userdata-wrappers")] + ArcRwLockPL(RwLockWriteGuardPL<'static, T>, UserDataVariant>>), +} + +impl Deref for UserDataRefMutInner { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + match self { + Self::Default(inner) => unsafe { &*inner.as_ptr() }, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::RcRefCell(x, ..) => x, + + #[cfg(feature = "userdata-wrappers")] + Self::ArcMutexPL(x, ..) => x, + #[cfg(feature = "userdata-wrappers")] + Self::ArcRwLockPL(x, ..) => x, + } + } +} + +impl DerefMut for UserDataRefMutInner { + #[inline] + fn deref_mut(&mut self) -> &mut T { + match self { + Self::Default(inner) => unsafe { &mut *inner.as_ptr() }, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Self::RcRefCell(x, ..) => x, + + #[cfg(feature = "userdata-wrappers")] + Self::ArcMutexPL(x, ..) => x, + #[cfg(feature = "userdata-wrappers")] + Self::ArcRwLockPL(x, ..) => x, + } + } +} + +#[inline] +fn try_value_to_userdata(value: Value) -> Result { + match value { + Value::UserData(ud) => Ok(ud), + _ => Err(Error::FromLuaConversionError { + from: value.type_name(), + to: "userdata".to_string(), + message: Some(format!("expected userdata of type {}", short_type_name::())), + }), + } +} + +#[cfg(test)] +mod assertions { + use super::*; + + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(UserDataRef<()>: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_not_impl_all!(UserDataRef>: Send, Sync); + #[cfg(feature = "send")] + static_assertions::assert_impl_all!(UserDataRefMut<()>: Sync, Send); + #[cfg(feature = "send")] + static_assertions::assert_not_impl_all!(UserDataRefMut>: Send, Sync); + + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_all!(UserDataRef<()>: Send, Sync); + #[cfg(not(feature = "send"))] + static_assertions::assert_not_impl_all!(UserDataRefMut<()>: Send, Sync); +} diff --git a/src/userdata/registry.rs b/src/userdata/registry.rs index ec5a989b..74e36e7d 100644 --- a/src/userdata/registry.rs +++ b/src/userdata/registry.rs @@ -10,8 +10,11 @@ use crate::error::{Error, Result}; use crate::state::{Lua, LuaGuard}; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{Callback, MaybeSend}; -use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMethods, UserDataStorage}; -use crate::util::{get_userdata, short_type_name}; +use crate::userdata::{ + borrow_userdata_scoped, borrow_userdata_scoped_mut, AnyUserData, MetaMethod, TypeIdHints, UserData, + UserDataFields, UserDataMethods, UserDataStorage, +}; +use crate::util::short_type_name; use crate::value::Value; #[cfg(feature = "async")] @@ -21,38 +24,18 @@ use { std::future::{self, Future}, }; -#[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] -use std::rc::Rc; -#[cfg(feature = "userdata-wrappers")] -use std::sync::{Arc, Mutex, RwLock}; - #[derive(Clone, Copy)] -enum UserDataTypeId { - Shared(TypeId), +enum UserDataType { + Shared(TypeIdHints), Unique(*mut c_void), - - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - Rc(TypeId), - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - RcRefCell(TypeId), - #[cfg(feature = "userdata-wrappers")] - Arc(TypeId), - #[cfg(feature = "userdata-wrappers")] - ArcMutex(TypeId), - #[cfg(feature = "userdata-wrappers")] - ArcRwLock(TypeId), - #[cfg(feature = "userdata-wrappers")] - ArcParkingLotMutex(TypeId), - #[cfg(feature = "userdata-wrappers")] - ArcParkingLotRwLock(TypeId), } /// Handle to registry for userdata methods and metamethods. pub struct UserDataRegistry { lua: LuaGuard, raw: RawUserDataRegistry, - ud_type_id: UserDataTypeId, - _type: PhantomData, + r#type: UserDataType, + _phantom: PhantomData, } pub(crate) struct RawUserDataRegistry { @@ -75,46 +58,34 @@ pub(crate) struct RawUserDataRegistry { pub(crate) type_name: StdString, } -impl UserDataTypeId { +impl UserDataType { #[inline] - pub(crate) fn type_id(self) -> Option { + pub(crate) fn type_id(&self) -> Option { match self { - UserDataTypeId::Shared(type_id) => Some(type_id), - UserDataTypeId::Unique(_) => None, - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - UserDataTypeId::Rc(type_id) => Some(type_id), - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - UserDataTypeId::RcRefCell(type_id) => Some(type_id), - #[cfg(feature = "userdata-wrappers")] - UserDataTypeId::Arc(type_id) => Some(type_id), - #[cfg(feature = "userdata-wrappers")] - UserDataTypeId::ArcMutex(type_id) => Some(type_id), - #[cfg(feature = "userdata-wrappers")] - UserDataTypeId::ArcRwLock(type_id) => Some(type_id), - #[cfg(feature = "userdata-wrappers")] - UserDataTypeId::ArcParkingLotMutex(type_id) => Some(type_id), - #[cfg(feature = "userdata-wrappers")] - UserDataTypeId::ArcParkingLotRwLock(type_id) => Some(type_id), + UserDataType::Shared(hints) => Some(hints.type_id()), + UserDataType::Unique(_) => None, } } } #[cfg(feature = "send")] -unsafe impl Send for UserDataTypeId {} +unsafe impl Send for UserDataType {} -impl UserDataRegistry { +impl UserDataRegistry { #[inline(always)] - pub(crate) fn new(lua: &Lua, type_id: TypeId) -> Self { - Self::with_type_id(lua, UserDataTypeId::Shared(type_id)) + pub(crate) fn new(lua: &Lua) -> Self { + Self::with_type(lua, UserDataType::Shared(TypeIdHints::new::())) } +} +impl UserDataRegistry { #[inline(always)] pub(crate) fn new_unique(lua: &Lua, ud_ptr: *mut c_void) -> Self { - Self::with_type_id(lua, UserDataTypeId::Unique(ud_ptr)) + Self::with_type(lua, UserDataType::Unique(ud_ptr)) } #[inline(always)] - fn with_type_id(lua: &Lua, ud_type_id: UserDataTypeId) -> Self { + fn with_type(lua: &Lua, r#type: UserDataType) -> Self { let raw = RawUserDataRegistry { fields: Vec::new(), field_getters: Vec::new(), @@ -126,16 +97,16 @@ impl UserDataRegistry { meta_methods: Vec::new(), #[cfg(feature = "async")] async_meta_methods: Vec::new(), - destructor: super::util::userdata_destructor::, - type_id: ud_type_id.type_id(), + destructor: super::util::destroy_userdata_storage::, + type_id: r#type.type_id(), type_name: short_type_name::(), }; UserDataRegistry { lua: lua.lock_arc(), raw, - ud_type_id, - _type: PhantomData, + r#type, + _phantom: PhantomData, } } @@ -152,7 +123,7 @@ impl UserDataRegistry { }; } - let target_type_id = self.ud_type_id; + let target_type = self.r#type; Box::new(move |rawlua, nargs| unsafe { if nargs == 0 { let err = Error::from_lua_conversion("missing argument", "userdata", None); @@ -164,103 +135,24 @@ impl UserDataRegistry { // Self was at position 1, so we pass 2 here let args = A::from_stack_args(nargs - 1, 2, Some(&name), rawlua); - match target_type_id { + match target_type { #[rustfmt::skip] - UserDataTypeId::Shared(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { + UserDataType::Shared(type_hints) => { + let type_id = try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + try_self_arg!(borrow_userdata_scoped(state, self_index, type_id, type_hints, |ud| { method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) })) } - #[rustfmt::skip] - UserDataTypeId::Unique(target_ptr) - if get_userdata::>(state, self_index) as *mut c_void == target_ptr => - { + UserDataType::Unique(target_ptr) if ffi::lua_touserdata(state, self_index) == target_ptr => { let ud = target_ptr as *mut UserDataStorage; try_self_arg!((*ud).try_borrow_scoped(|ud| { method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) })) } - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - #[rustfmt::skip] - UserDataTypeId::Rc(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - #[rustfmt::skip] - UserDataTypeId::RcRefCell(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?; - method(rawlua.lua(), &ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::Arc(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcMutex(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?; - method(rawlua.lua(), &ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcRwLock(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?; - method(rawlua.lua(), &ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcParkingLotMutex(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) - == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let ud = ud.try_lock().ok_or(Error::UserDataBorrowError)?; - method(rawlua.lua(), &ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcParkingLotRwLock(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) - == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let ud = ud.try_read().ok_or(Error::UserDataBorrowError)?; - method(rawlua.lua(), &ud, args?)?.push_into_stack_multi(rawlua) - })) + UserDataType::Unique(_) => { + try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)) } - _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } }) } @@ -279,7 +171,7 @@ impl UserDataRegistry { } let method = RefCell::new(method); - let target_type_id = self.ud_type_id; + let target_type = self.r#type; Box::new(move |rawlua, nargs| unsafe { let mut method = method.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?; if nargs == 0 { @@ -292,97 +184,24 @@ impl UserDataRegistry { // Self was at position 1, so we pass 2 here let args = A::from_stack_args(nargs - 1, 2, Some(&name), rawlua); - match target_type_id { + match target_type { #[rustfmt::skip] - UserDataTypeId::Shared(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped_mut(|ud| { + UserDataType::Shared(type_hints) => { + let type_id = try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + try_self_arg!(borrow_userdata_scoped_mut(state, self_index, type_id, type_hints, |ud| { method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) })) } - #[rustfmt::skip] - UserDataTypeId::Unique(target_ptr) - if get_userdata::>(state, self_index) as *mut c_void == target_ptr => - { + UserDataType::Unique(target_ptr) if ffi::lua_touserdata(state, self_index) == target_ptr => { let ud = target_ptr as *mut UserDataStorage; try_self_arg!((*ud).try_borrow_scoped_mut(|ud| { method(rawlua.lua(), ud, args?)?.push_into_stack_multi(rawlua) })) } - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - #[rustfmt::skip] - UserDataTypeId::Rc(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>(self_index)) == Some(target_type_id) => - { - Err(Error::UserDataBorrowMutError) - }, - #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] - #[rustfmt::skip] - UserDataTypeId::RcRefCell(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let mut ud = ud.try_borrow_mut().map_err(|_| Error::UserDataBorrowMutError)?; - method(rawlua.lua(), &mut ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::Arc(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>(self_index)) == Some(target_type_id) => - { - Err(Error::UserDataBorrowMutError) - }, - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcMutex(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let mut ud = ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?; - method(rawlua.lua(), &mut ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcRwLock(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let mut ud = ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?; - method(rawlua.lua(), &mut ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcParkingLotMutex(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) - == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let mut ud = ud.try_lock().ok_or(Error::UserDataBorrowMutError)?; - method(rawlua.lua(), &mut ud, args?)?.push_into_stack_multi(rawlua) - })) - } - #[cfg(feature = "userdata-wrappers")] - #[rustfmt::skip] - UserDataTypeId::ArcParkingLotRwLock(target_type_id) - if try_self_arg!(rawlua.get_userdata_type_id::>>(self_index)) - == Some(target_type_id) => - { - let ud = get_userdata::>>>(state, self_index); - try_self_arg!((*ud).try_borrow_scoped(|ud| { - let mut ud = ud.try_write().ok_or(Error::UserDataBorrowMutError)?; - method(rawlua.lua(), &mut ud, args?)?.push_into_stack_multi(rawlua) - })) + UserDataType::Unique(_) => { + try_self_arg!(rawlua.get_userdata_type_id::(state, self_index)); + Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)) } - _ => Err(Error::bad_self_argument(&name, Error::UserDataTypeMismatch)), } }) } @@ -789,14 +608,10 @@ impl UserDataMethods for UserDataRegistry { } macro_rules! lua_userdata_impl { - ($type:ty => $type_variant:tt) => { - lua_userdata_impl!($type, UserDataTypeId::$type_variant(TypeId::of::<$type>())); - }; - - ($type:ty, $type_id:expr) => { + ($type:ty) => { impl UserData for $type { fn register(registry: &mut UserDataRegistry) { - let mut orig_registry = UserDataRegistry::with_type_id(registry.lua.lua(), $type_id); + let mut orig_registry = UserDataRegistry::new(registry.lua.lua()); T::register(&mut orig_registry); // Copy all fields, methods, etc. from the original registry @@ -818,22 +633,22 @@ macro_rules! lua_userdata_impl { // A special proxy object for UserData pub(crate) struct UserDataProxy(pub(crate) PhantomData); -lua_userdata_impl!(UserDataProxy, UserDataTypeId::Shared(TypeId::of::())); +lua_userdata_impl!(UserDataProxy); #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] -lua_userdata_impl!(Rc => Rc); +lua_userdata_impl!(std::rc::Rc); #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] -lua_userdata_impl!(Rc> => RcRefCell); +lua_userdata_impl!(std::rc::Rc>); #[cfg(feature = "userdata-wrappers")] -lua_userdata_impl!(Arc => Arc); +lua_userdata_impl!(std::sync::Arc); #[cfg(feature = "userdata-wrappers")] -lua_userdata_impl!(Arc> => ArcMutex); +lua_userdata_impl!(std::sync::Arc>); #[cfg(feature = "userdata-wrappers")] -lua_userdata_impl!(Arc> => ArcRwLock); +lua_userdata_impl!(std::sync::Arc>); #[cfg(feature = "userdata-wrappers")] -lua_userdata_impl!(Arc> => ArcParkingLotMutex); +lua_userdata_impl!(std::sync::Arc>); #[cfg(feature = "userdata-wrappers")] -lua_userdata_impl!(Arc> => ArcParkingLotRwLock); +lua_userdata_impl!(std::sync::Arc>); #[cfg(test)] mod assertions { diff --git a/src/userdata/util.rs b/src/userdata/util.rs index 02c5ea4e..27cc63d3 100644 --- a/src/userdata/util.rs +++ b/src/userdata/util.rs @@ -1,9 +1,12 @@ +use std::any::TypeId; use std::cell::Cell; use std::marker::PhantomData; use std::os::raw::c_int; +use std::ptr; use super::UserDataStorage; -use crate::util::{get_userdata, take_userdata}; +use crate::error::{Error, Result}; +use crate::util::{get_userdata, rawget_field, rawset_field, take_userdata}; // This is a trick to check if a type is `Sync` or not. // It uses leaked specialization feature from stdlib. @@ -34,9 +37,416 @@ pub(crate) fn is_sync() -> bool { is_sync.get() } -pub(super) unsafe extern "C-unwind" fn userdata_destructor(state: *mut ffi::lua_State) -> c_int { +// Userdata type hints, used to match types of wrapped userdata +#[derive(Clone, Copy)] +pub(crate) struct TypeIdHints { + t: TypeId, + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc: TypeId, + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc_refcell: TypeId, + + #[cfg(feature = "userdata-wrappers")] + arc: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_mutex: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_rwlock: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_pl_mutex: TypeId, + #[cfg(feature = "userdata-wrappers")] + arc_pl_rwlock: TypeId, +} + +impl TypeIdHints { + pub(crate) fn new() -> Self { + Self { + t: TypeId::of::(), + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc: TypeId::of::>(), + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + rc_refcell: TypeId::of::>>(), + + #[cfg(feature = "userdata-wrappers")] + arc: TypeId::of::>(), + #[cfg(feature = "userdata-wrappers")] + arc_mutex: TypeId::of::>>(), + #[cfg(feature = "userdata-wrappers")] + arc_rwlock: TypeId::of::>>(), + #[cfg(feature = "userdata-wrappers")] + arc_pl_mutex: TypeId::of::>>(), + #[cfg(feature = "userdata-wrappers")] + arc_pl_rwlock: TypeId::of::>>(), + } + } + + #[inline(always)] + pub(crate) fn type_id(&self) -> TypeId { + self.t + } +} + +pub(crate) unsafe fn borrow_userdata_scoped( + state: *mut ffi::lua_State, + idx: c_int, + type_id: Option, + type_hints: TypeIdHints, + f: impl FnOnce(&T) -> R, +) -> Result { + match type_id { + Some(type_id) if type_id == type_hints.t => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_scoped(|ud| f(ud)) + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped(|ud| f(ud)) + } + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc_refcell => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_borrow().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped(|ud| f(ud)) + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_lock().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_read().map_err(|_| Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_lock().ok_or(Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let ud = ud.try_read().ok_or(Error::UserDataBorrowError)?; + Ok(f(&ud)) + })? + } + _ => Err(Error::UserDataTypeMismatch), + } +} + +pub(crate) unsafe fn borrow_userdata_scoped_mut( + state: *mut ffi::lua_State, + idx: c_int, + type_id: Option, + type_hints: TypeIdHints, + f: impl FnOnce(&mut T) -> R, +) -> Result { + match type_id { + Some(type_id) if type_id == type_hints.t => { + let ud = get_userdata::>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| f(ud)) + } + + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| match std::rc::Rc::get_mut(ud) { + Some(ud) => Ok(f(ud)), + None => Err(Error::UserDataBorrowMutError), + })? + } + #[cfg(all(feature = "userdata-wrappers", not(feature = "send")))] + Some(type_id) if type_id == type_hints.rc_refcell => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped(|ud| { + let mut ud = ud.try_borrow_mut().map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc => { + let ud = get_userdata::>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| match std::sync::Arc::get_mut(ud) { + Some(ud) => Ok(f(ud)), + None => Err(Error::UserDataBorrowMutError), + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_lock().map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_write().map_err(|_| Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_mutex => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_lock().ok_or(Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + #[cfg(feature = "userdata-wrappers")] + Some(type_id) if type_id == type_hints.arc_pl_rwlock => { + let ud = get_userdata::>>>(state, idx); + (*ud).try_borrow_scoped_mut(|ud| { + let mut ud = ud.try_write().ok_or(Error::UserDataBorrowMutError)?; + Ok(f(&mut ud)) + })? + } + _ => Err(Error::UserDataTypeMismatch), + } +} + +// Populates the given table with the appropriate members to be a userdata metatable for the given +// type. This function takes the given table at the `metatable` index, and adds an appropriate +// `__gc` member to it for the given type and a `__metatable` entry to protect the table from script +// access. The function also, if given a `field_getters` or `methods` tables, will create an +// `__index` metamethod (capturing previous one) to lookup in `field_getters` first, then `methods` +// and falling back to the captured `__index` if no matches found. +// The same is also applicable for `__newindex` metamethod and `field_setters` table. +// Internally uses 9 stack spaces and does not call checkstack. +pub(crate) unsafe fn init_userdata_metatable( + state: *mut ffi::lua_State, + metatable: c_int, + field_getters: Option, + field_setters: Option, + methods: Option, +) -> Result<()> { + if field_getters.is_some() || methods.is_some() { + // Push `__index` generator function + init_userdata_metatable_index(state)?; + + let index_type = rawget_field(state, metatable, "__index")?; + match index_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { + for &idx in &[field_getters, methods] { + if let Some(idx) = idx { + ffi::lua_pushvalue(state, idx); + } else { + ffi::lua_pushnil(state); + } + } + + // Generate `__index` + protect_lua!(state, 4, 1, fn(state) ffi::lua_call(state, 3, 1))?; + } + _ => mlua_panic!("improper `__index` type: {}", index_type), + } + + rawset_field(state, metatable, "__index")?; + } + + if let Some(field_setters) = field_setters { + // Push `__newindex` generator function + init_userdata_metatable_newindex(state)?; + + let newindex_type = rawget_field(state, metatable, "__newindex")?; + match newindex_type { + ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { + ffi::lua_pushvalue(state, field_setters); + // Generate `__newindex` + protect_lua!(state, 3, 1, fn(state) ffi::lua_call(state, 2, 1))?; + } + _ => mlua_panic!("improper `__newindex` type: {}", newindex_type), + } + + rawset_field(state, metatable, "__newindex")?; + } + + ffi::lua_pushboolean(state, 0); + rawset_field(state, metatable, "__metatable")?; + + Ok(()) +} + +unsafe extern "C-unwind" fn lua_error_impl(state: *mut ffi::lua_State) -> c_int { + ffi::lua_error(state); +} + +unsafe extern "C-unwind" fn lua_isfunction_impl(state: *mut ffi::lua_State) -> c_int { + ffi::lua_pushboolean(state, ffi::lua_isfunction(state, -1)); + 1 +} + +unsafe extern "C-unwind" fn lua_istable_impl(state: *mut ffi::lua_State) -> c_int { + ffi::lua_pushboolean(state, ffi::lua_istable(state, -1)); + 1 +} + +unsafe fn init_userdata_metatable_index(state: *mut ffi::lua_State) -> Result<()> { + let index_key = &USERDATA_METATABLE_INDEX as *const u8 as *const _; + if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, index_key) == ffi::LUA_TFUNCTION { + return Ok(()); + } + ffi::lua_pop(state, 1); + + // Create and cache `__index` generator + let code = cr#" + local error, isfunction, istable = ... + return function (__index, field_getters, methods) + -- Common case: has field getters and index is a table + if field_getters ~= nil and methods == nil and istable(__index) then + return function (self, key) + local field_getter = field_getters[key] + if field_getter ~= nil then + return field_getter(self) + end + return __index[key] + end + end + + return function (self, key) + if field_getters ~= nil then + local field_getter = field_getters[key] + if field_getter ~= nil then + return field_getter(self) + end + end + + if methods ~= nil then + local method = methods[key] + if method ~= nil then + return method + end + end + + if isfunction(__index) then + return __index(self, key) + elseif __index == nil then + error("attempt to get an unknown field '"..key.."'") + else + return __index[key] + end + end + end + "#; + protect_lua!(state, 0, 1, |state| { + let ret = ffi::luaL_loadbuffer(state, code.as_ptr(), code.count_bytes(), cstr!("=__mlua_index")); + if ret != ffi::LUA_OK { + ffi::lua_error(state); + } + ffi::lua_pushcfunction(state, lua_error_impl); + ffi::lua_pushcfunction(state, lua_isfunction_impl); + ffi::lua_pushcfunction(state, lua_istable_impl); + ffi::lua_call(state, 3, 1); + + #[cfg(feature = "luau-jit")] + if ffi::luau_codegen_supported() != 0 { + ffi::luau_codegen_compile(state, -1); + } + + // Store in the registry + ffi::lua_pushvalue(state, -1); + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, index_key); + }) +} + +unsafe fn init_userdata_metatable_newindex(state: *mut ffi::lua_State) -> Result<()> { + let newindex_key = &USERDATA_METATABLE_NEWINDEX as *const u8 as *const _; + if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, newindex_key) == ffi::LUA_TFUNCTION { + return Ok(()); + } + ffi::lua_pop(state, 1); + + // Create and cache `__newindex` generator + let code = cr#" + local error, isfunction = ... + return function (__newindex, field_setters) + return function (self, key, value) + if field_setters ~= nil then + local field_setter = field_setters[key] + if field_setter ~= nil then + field_setter(self, value) + return + end + end + + if isfunction(__newindex) then + __newindex(self, key, value) + elseif __newindex == nil then + error("attempt to set an unknown field '"..key.."'") + else + __newindex[key] = value + end + end + end + "#; + protect_lua!(state, 0, 1, |state| { + let code_len = code.count_bytes(); + let ret = ffi::luaL_loadbuffer(state, code.as_ptr(), code_len, cstr!("=__mlua_newindex")); + if ret != ffi::LUA_OK { + ffi::lua_error(state); + } + ffi::lua_pushcfunction(state, lua_error_impl); + ffi::lua_pushcfunction(state, lua_isfunction_impl); + ffi::lua_call(state, 2, 1); + + #[cfg(feature = "luau-jit")] + if ffi::luau_codegen_supported() != 0 { + ffi::luau_codegen_compile(state, -1); + } + + // Store in the registry + ffi::lua_pushvalue(state, -1); + ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, newindex_key); + }) +} + +// This method is called by Lua GC when it's time to collect the userdata. +// +// This method is usually used to collect internal userdata. +#[cfg(not(feature = "luau"))] +pub(crate) unsafe extern "C-unwind" fn collect_userdata(state: *mut ffi::lua_State) -> c_int { + let ud = get_userdata::(state, -1); + ptr::drop_in_place(ud); + 0 +} + +// This method is called by Luau GC when it's time to collect the userdata. +#[cfg(feature = "luau")] +pub(crate) unsafe extern "C-unwind" fn collect_userdata(ud: *mut std::os::raw::c_void) { + ptr::drop_in_place(ud as *mut T); +} + +// This method can be called by user or Lua GC to destroy the userdata. +// It checks if the userdata is safe to destroy and sets the "destroyed" metatable +// to prevent further GC collection. +pub(super) unsafe extern "C-unwind" fn destroy_userdata_storage(state: *mut ffi::lua_State) -> c_int { let ud = get_userdata::>(state, -1); - if !(*ud).is_borrowed() { + if (*ud).is_safe_to_destroy() { take_userdata::>(state); ffi::lua_pushboolean(state, 1); } else { @@ -44,3 +454,6 @@ pub(super) unsafe extern "C-unwind" fn userdata_destructor(state: *mut ffi::l } 1 } + +static USERDATA_METATABLE_INDEX: u8 = 0; +static USERDATA_METATABLE_NEWINDEX: u8 = 0; diff --git a/src/util/error.rs b/src/util/error.rs index eb9ee1f5..ef9184e2 100644 --- a/src/util/error.rs +++ b/src/util/error.rs @@ -9,9 +9,8 @@ use std::sync::Arc; use crate::error::{Error, Result}; use crate::memory::MemoryState; use crate::util::{ - check_stack, get_internal_metatable, get_internal_userdata, init_internal_metatable, - push_internal_userdata, push_string, push_table, rawset_field, to_string, TypeKey, - DESTRUCTED_USERDATA_METATABLE, + check_stack, get_internal_userdata, init_internal_metatable, push_internal_userdata, push_string, + push_table, rawset_field, to_string, TypeKey, DESTRUCTED_USERDATA_METATABLE, }; static WRAPPED_FAILURE_TYPE_KEY: u8 = 0; @@ -31,12 +30,8 @@ impl TypeKey for WrappedFailure { impl WrappedFailure { pub(crate) unsafe fn new_userdata(state: *mut ffi::lua_State) -> *mut Self { - #[cfg(feature = "luau")] - let ud = ffi::lua_newuserdata_t::(state); - #[cfg(not(feature = "luau"))] - let ud = ffi::lua_newuserdata(state, std::mem::size_of::()) as *mut Self; - ptr::write(ud, WrappedFailure::None); - ud + // Unprotected calls always return `Ok` + push_internal_userdata(state, WrappedFailure::None, false).unwrap() } } @@ -90,16 +85,11 @@ where let cause = Arc::new(err); let wrapped_error = WrappedFailure::Error(Error::CallbackError { traceback, cause }); ptr::write(ud, wrapped_error); - get_internal_metatable::(state); - ffi::lua_setmetatable(state, -2); - ffi::lua_error(state) } Err(p) => { ffi::lua_settop(state, 1); ptr::write(ud, WrappedFailure::Panic(Some(p))); - get_internal_metatable::(state); - ffi::lua_setmetatable(state, -2); ffi::lua_error(state) } } @@ -262,7 +252,7 @@ where pub(crate) unsafe extern "C-unwind" fn error_traceback(state: *mut ffi::lua_State) -> c_int { // Luau calls error handler for memory allocation errors, skip it - // See https://github.com/Roblox/luau/issues/880 + // See https://github.com/luau-lang/luau/issues/880 #[cfg(feature = "luau")] if MemoryState::limit_reached(state) { return 0; diff --git a/src/util/mod.rs b/src/util/mod.rs index 48e7d8fa..f5fbae52 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -13,13 +13,12 @@ pub(crate) use short_names::short_type_name; pub(crate) use types::TypeKey; pub(crate) use userdata::{ get_destructed_userdata_metatable, get_internal_metatable, get_internal_userdata, get_userdata, - init_internal_metatable, init_userdata_metatable, push_internal_userdata, take_userdata, + init_internal_metatable, push_internal_userdata, push_userdata, take_userdata, DESTRUCTED_USERDATA_METATABLE, }; #[cfg(not(feature = "luau"))] pub(crate) use userdata::push_uninit_userdata; -pub(crate) use userdata::push_userdata; // Checks that Lua has enough free stack space for future stack operations. On failure, this will // panic with an internal error message. diff --git a/src/util/userdata.rs b/src/util/userdata.rs index 119b8c8d..96f9c7b4 100644 --- a/src/util/userdata.rs +++ b/src/util/userdata.rs @@ -1,8 +1,9 @@ use std::os::raw::{c_int, c_void}; -use std::{ptr, str}; +use std::{mem, ptr}; use crate::error::Result; -use crate::util::{check_stack, get_metatable_ptr, push_table, rawget_field, rawset_field, TypeKey}; +use crate::userdata::collect_userdata; +use crate::util::{check_stack, get_metatable_ptr, push_table, rawset_field, TypeKey}; // Pushes the userdata and attaches a metatable with __gc method. // Internally uses 3 stack spaces, does not call checkstack. @@ -10,11 +11,27 @@ pub(crate) unsafe fn push_internal_userdata( state: *mut ffi::lua_State, t: T, protect: bool, -) -> Result<()> { - push_userdata(state, t, protect)?; +) -> Result<*mut T> { + #[cfg(not(feature = "luau"))] + let ud_ptr = if protect { + protect_lua!(state, 0, 1, move |state| { + ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T + })? + } else { + ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T + }; + + #[cfg(feature = "luau")] + let ud_ptr = if protect { + protect_lua!(state, 0, 1, move |state| ffi::lua_newuserdata_t::(state))? + } else { + ffi::lua_newuserdata_t::(state) + }; + + ptr::write(ud_ptr, t); get_internal_metatable::(state); ffi::lua_setmetatable(state, -2); - Ok(()) + Ok(ud_ptr) } #[track_caller] @@ -35,7 +52,7 @@ pub(crate) unsafe fn init_internal_metatable( #[cfg(not(feature = "luau"))] { - ffi::lua_pushcfunction(state, userdata_destructor::); + ffi::lua_pushcfunction(state, collect_userdata::); rawset_field(state, -2, "__gc")?; } @@ -81,24 +98,34 @@ pub(crate) unsafe fn get_internal_userdata( pub(crate) unsafe fn push_uninit_userdata(state: *mut ffi::lua_State, protect: bool) -> Result<*mut T> { if protect { protect_lua!(state, 0, 1, |state| { - ffi::lua_newuserdata(state, std::mem::size_of::()) as *mut T + ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T }) } else { - Ok(ffi::lua_newuserdata(state, std::mem::size_of::()) as *mut T) + Ok(ffi::lua_newuserdata(state, const { mem::size_of::() }) as *mut T) } } // Internally uses 3 stack spaces, does not call checkstack. #[inline] pub(crate) unsafe fn push_userdata(state: *mut ffi::lua_State, t: T, protect: bool) -> Result<*mut T> { + let size = const { mem::size_of::() }; + #[cfg(not(feature = "luau"))] - let ud_ptr = push_uninit_userdata(state, protect)?; + let ud_ptr = if protect { + protect_lua!(state, 0, 1, move |state| ffi::lua_newuserdata(state, size))? + } else { + ffi::lua_newuserdata(state, size) + } as *mut T; + #[cfg(feature = "luau")] let ud_ptr = if protect { - protect_lua!(state, 0, 1, |state| { ffi::lua_newuserdata_t::(state) })? + protect_lua!(state, 0, 1, |state| { + ffi::lua_newuserdatadtor(state, size, collect_userdata::) + })? } else { - ffi::lua_newuserdata_t::(state) - }; + ffi::lua_newuserdatadtor(state, size, collect_userdata::) + } as *mut T; + ptr::write(ud_ptr, t); Ok(ud_ptr) } @@ -137,208 +164,4 @@ pub(crate) unsafe fn get_destructed_userdata_metatable(state: *mut ffi::lua_Stat ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, key); } -// Populates the given table with the appropriate members to be a userdata metatable for the given -// type. This function takes the given table at the `metatable` index, and adds an appropriate -// `__gc` member to it for the given type and a `__metatable` entry to protect the table from script -// access. The function also, if given a `field_getters` or `methods` tables, will create an -// `__index` metamethod (capturing previous one) to lookup in `field_getters` first, then `methods` -// and falling back to the captured `__index` if no matches found. -// The same is also applicable for `__newindex` metamethod and `field_setters` table. -// Internally uses 9 stack spaces and does not call checkstack. -pub(crate) unsafe fn init_userdata_metatable( - state: *mut ffi::lua_State, - metatable: c_int, - field_getters: Option, - field_setters: Option, - methods: Option, -) -> Result<()> { - if field_getters.is_some() || methods.is_some() { - // Push `__index` generator function - init_userdata_metatable_index(state)?; - - let index_type = rawget_field(state, metatable, "__index")?; - match index_type { - ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { - for &idx in &[field_getters, methods] { - if let Some(idx) = idx { - ffi::lua_pushvalue(state, idx); - } else { - ffi::lua_pushnil(state); - } - } - - // Generate `__index` - protect_lua!(state, 4, 1, fn(state) ffi::lua_call(state, 3, 1))?; - } - _ => mlua_panic!("improper `__index` type: {}", index_type), - } - - rawset_field(state, metatable, "__index")?; - } - - if let Some(field_setters) = field_setters { - // Push `__newindex` generator function - init_userdata_metatable_newindex(state)?; - - let newindex_type = rawget_field(state, metatable, "__newindex")?; - match newindex_type { - ffi::LUA_TNIL | ffi::LUA_TTABLE | ffi::LUA_TFUNCTION => { - ffi::lua_pushvalue(state, field_setters); - // Generate `__newindex` - protect_lua!(state, 3, 1, fn(state) ffi::lua_call(state, 2, 1))?; - } - _ => mlua_panic!("improper `__newindex` type: {}", newindex_type), - } - - rawset_field(state, metatable, "__newindex")?; - } - - ffi::lua_pushboolean(state, 0); - rawset_field(state, metatable, "__metatable")?; - - Ok(()) -} - -unsafe extern "C-unwind" fn lua_error_impl(state: *mut ffi::lua_State) -> c_int { - ffi::lua_error(state); -} - -unsafe extern "C-unwind" fn lua_isfunction_impl(state: *mut ffi::lua_State) -> c_int { - ffi::lua_pushboolean(state, ffi::lua_isfunction(state, -1)); - 1 -} - -unsafe extern "C-unwind" fn lua_istable_impl(state: *mut ffi::lua_State) -> c_int { - ffi::lua_pushboolean(state, ffi::lua_istable(state, -1)); - 1 -} - -unsafe fn init_userdata_metatable_index(state: *mut ffi::lua_State) -> Result<()> { - let index_key = &USERDATA_METATABLE_INDEX as *const u8 as *const _; - if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, index_key) == ffi::LUA_TFUNCTION { - return Ok(()); - } - ffi::lua_pop(state, 1); - - // Create and cache `__index` generator - let code = cr#" - local error, isfunction, istable = ... - return function (__index, field_getters, methods) - -- Common case: has field getters and index is a table - if field_getters ~= nil and methods == nil and istable(__index) then - return function (self, key) - local field_getter = field_getters[key] - if field_getter ~= nil then - return field_getter(self) - end - return __index[key] - end - end - - return function (self, key) - if field_getters ~= nil then - local field_getter = field_getters[key] - if field_getter ~= nil then - return field_getter(self) - end - end - - if methods ~= nil then - local method = methods[key] - if method ~= nil then - return method - end - end - - if isfunction(__index) then - return __index(self, key) - elseif __index == nil then - error("attempt to get an unknown field '"..key.."'") - else - return __index[key] - end - end - end - "#; - protect_lua!(state, 0, 1, |state| { - let ret = ffi::luaL_loadbuffer(state, code.as_ptr(), code.count_bytes(), cstr!("__mlua_index")); - if ret != ffi::LUA_OK { - ffi::lua_error(state); - } - ffi::lua_pushcfunction(state, lua_error_impl); - ffi::lua_pushcfunction(state, lua_isfunction_impl); - ffi::lua_pushcfunction(state, lua_istable_impl); - ffi::lua_call(state, 3, 1); - - #[cfg(feature = "luau-jit")] - if ffi::luau_codegen_supported() != 0 { - ffi::luau_codegen_compile(state, -1); - } - - // Store in the registry - ffi::lua_pushvalue(state, -1); - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, index_key); - }) -} - -unsafe fn init_userdata_metatable_newindex(state: *mut ffi::lua_State) -> Result<()> { - let newindex_key = &USERDATA_METATABLE_NEWINDEX as *const u8 as *const _; - if ffi::lua_rawgetp(state, ffi::LUA_REGISTRYINDEX, newindex_key) == ffi::LUA_TFUNCTION { - return Ok(()); - } - ffi::lua_pop(state, 1); - - // Create and cache `__newindex` generator - let code = cr#" - local error, isfunction = ... - return function (__newindex, field_setters) - return function (self, key, value) - if field_setters ~= nil then - local field_setter = field_setters[key] - if field_setter ~= nil then - field_setter(self, value) - return - end - end - - if isfunction(__newindex) then - __newindex(self, key, value) - elseif __newindex == nil then - error("attempt to set an unknown field '"..key.."'") - else - __newindex[key] = value - end - end - end - "#; - protect_lua!(state, 0, 1, |state| { - let ret = ffi::luaL_loadbuffer(state, code.as_ptr(), code.count_bytes(), cstr!("__mlua_newindex")); - if ret != ffi::LUA_OK { - ffi::lua_error(state); - } - ffi::lua_pushcfunction(state, lua_error_impl); - ffi::lua_pushcfunction(state, lua_isfunction_impl); - ffi::lua_call(state, 2, 1); - - #[cfg(feature = "luau-jit")] - if ffi::luau_codegen_supported() != 0 { - ffi::luau_codegen_compile(state, -1); - } - - // Store in the registry - ffi::lua_pushvalue(state, -1); - ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, newindex_key); - }) -} - -#[cfg(not(feature = "luau"))] -unsafe extern "C-unwind" fn userdata_destructor(state: *mut ffi::lua_State) -> c_int { - // It's probably NOT a good idea to catch Rust panics in finalizer - // Lua 5.4 ignores it, other versions generates `LUA_ERRGCMM` without calling message handler - take_userdata::(state); - 0 -} - pub(crate) static DESTRUCTED_USERDATA_METATABLE: u8 = 0; -static USERDATA_METATABLE_INDEX: u8 = 0; -static USERDATA_METATABLE_NEWINDEX: u8 = 0; diff --git a/src/value.rs b/src/value.rs index bd088887..d77fddf7 100644 --- a/src/value.rs +++ b/src/value.rs @@ -697,6 +697,15 @@ impl<'a> SerializableValue<'a> { self.options.sort_keys = enabled; self } + + /// If true, empty Lua tables will be encoded as array, instead of map. + /// + /// Default: **false** + #[must_use] + pub const fn encode_empty_tables_as_array(mut self, enabled: bool) -> Self { + self.options.encode_empty_tables_as_array = enabled; + self + } } #[cfg(feature = "serialize")] diff --git a/tests/chunk.rs b/tests/chunk.rs index 16df553b..67d98c25 100644 --- a/tests/chunk.rs +++ b/tests/chunk.rs @@ -1,6 +1,24 @@ use std::{fs, io}; -use mlua::{Chunk, Lua, Result}; +use mlua::{Chunk, ChunkMode, Lua, Result}; + +#[test] +fn test_chunk_methods() -> Result<()> { + let lua = Lua::new(); + + #[cfg(unix)] + assert!(lua.load("return 123").name().contains("tests/chunk.rs")); + let chunk2 = lua.load("return 123").set_name("@new_name"); + assert_eq!(chunk2.name(), "@new_name"); + + let env = lua.create_table_from([("a", 987)])?; + let chunk3 = lua.load("return a").set_environment(env.clone()); + assert_eq!(chunk3.environment().unwrap(), &env); + assert_eq!(chunk3.mode(), ChunkMode::Text); + assert_eq!(chunk3.call::(())?, 987); + + Ok(()) +} #[test] fn test_chunk_path() -> Result<()> { diff --git a/tests/compile/lua_norefunwindsafe.stderr b/tests/compile/lua_norefunwindsafe.stderr index ea2442bd..a482a8d7 100644 --- a/tests/compile/lua_norefunwindsafe.stderr +++ b/tests/compile/lua_norefunwindsafe.stderr @@ -1,32 +1,28 @@ -error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary +error[E0277]: the type `UnsafeCell<*mut lua_State>` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary --> tests/compile/lua_norefunwindsafe.rs:7:18 | 7 | catch_unwind(|| lua.create_table().unwrap()); - | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell<*mut lua_State>` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary | | | required by a bound introduced by this call | - = help: within `Lua`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell`, which is required by `{closure@$DIR/tests/compile/lua_norefunwindsafe.rs:7:18: 7:20}: UnwindSafe` -note: required because it appears within the type `lock_api::remutex::ReentrantMutex` - --> $CARGO/lock_api-0.4.12/src/remutex.rs - | - | pub struct ReentrantMutex { - | ^^^^^^^^^^^^^^ -note: required because it appears within the type `alloc::sync::ArcInner>` - --> $RUST/alloc/src/sync.rs + = help: within `mlua::types::sync::inner::ReentrantMutex`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell<*mut lua_State>` +note: required because it appears within the type `Cell<*mut lua_State>` + --> $RUST/core/src/cell.rs | - | struct ArcInner { - | ^^^^^^^^ -note: required because it appears within the type `PhantomData>>` - --> $RUST/core/src/marker.rs + | pub struct Cell { + | ^^^^ +note: required because it appears within the type `mlua::state::raw::RawLua` + --> src/state/raw.rs | - | pub struct PhantomData; - | ^^^^^^^^^^^ -note: required because it appears within the type `Arc>` - --> $RUST/alloc/src/sync.rs + | pub struct RawLua { + | ^^^^^^ +note: required because it appears within the type `mlua::types::sync::inner::ReentrantMutex` + --> src/types/sync.rs | - | pub struct Arc< - | ^^^ + | pub(crate) struct ReentrantMutex(T); + | ^^^^^^^^^^^^^^ + = note: required for `Rc>` to implement `RefUnwindSafe` note: required because it appears within the type `Lua` --> src/state.rs | @@ -44,45 +40,27 @@ note: required by a bound in `std::panic::catch_unwind` | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { | ^^^^^^^^^^ required by this bound in `catch_unwind` -error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary --> tests/compile/lua_norefunwindsafe.rs:7:18 | 7 | catch_unwind(|| lua.create_table().unwrap()); - | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary | | | required by a bound introduced by this call | - = help: within `Lua`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell`, which is required by `{closure@$DIR/tests/compile/lua_norefunwindsafe.rs:7:18: 7:20}: UnwindSafe` -note: required because it appears within the type `Cell` - --> $RUST/core/src/cell.rs + = help: the trait `RefUnwindSafe` is not implemented for `UnsafeCell` + = note: required for `Rc>` to implement `RefUnwindSafe` +note: required because it appears within the type `mlua::state::raw::RawLua` + --> src/state/raw.rs | - | pub struct Cell { - | ^^^^ -note: required because it appears within the type `lock_api::remutex::RawReentrantMutex` - --> $CARGO/lock_api-0.4.12/src/remutex.rs - | - | pub struct RawReentrantMutex { - | ^^^^^^^^^^^^^^^^^ -note: required because it appears within the type `lock_api::remutex::ReentrantMutex` - --> $CARGO/lock_api-0.4.12/src/remutex.rs - | - | pub struct ReentrantMutex { - | ^^^^^^^^^^^^^^ -note: required because it appears within the type `alloc::sync::ArcInner>` - --> $RUST/alloc/src/sync.rs + | pub struct RawLua { + | ^^^^^^ +note: required because it appears within the type `mlua::types::sync::inner::ReentrantMutex` + --> src/types/sync.rs | - | struct ArcInner { - | ^^^^^^^^ -note: required because it appears within the type `PhantomData>>` - --> $RUST/core/src/marker.rs - | - | pub struct PhantomData; - | ^^^^^^^^^^^ -note: required because it appears within the type `Arc>` - --> $RUST/alloc/src/sync.rs - | - | pub struct Arc< - | ^^^ + | pub(crate) struct ReentrantMutex(T); + | ^^^^^^^^^^^^^^ + = note: required for `Rc>` to implement `RefUnwindSafe` note: required because it appears within the type `Lua` --> src/state.rs | diff --git a/tests/compile/non_send.stderr b/tests/compile/non_send.stderr index c7e28da1..c94b720f 100644 --- a/tests/compile/non_send.stderr +++ b/tests/compile/non_send.stderr @@ -8,7 +8,7 @@ error[E0277]: `Rc>` cannot be sent between threads safely | | within this `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}` | required by a bound introduced by this call | - = help: within `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}`, the trait `Send` is not implemented for `Rc>`, which is required by `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}: MaybeSend` + = help: within `{closure@$DIR/tests/compile/non_send.rs:11:25: 11:37}`, the trait `Send` is not implemented for `Rc>` note: required because it's used within this closure --> tests/compile/non_send.rs:11:25 | diff --git a/tests/compile/ref_nounwindsafe.stderr b/tests/compile/ref_nounwindsafe.stderr index 39e70812..048f9d32 100644 --- a/tests/compile/ref_nounwindsafe.stderr +++ b/tests/compile/ref_nounwindsafe.stderr @@ -1,38 +1,38 @@ -error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary --> tests/compile/ref_nounwindsafe.rs:8:18 | 8 | catch_unwind(move || table.set("a", "b").unwrap()); - | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary | | | required by a bound introduced by this call | - = help: within `alloc::sync::ArcInner>`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell`, which is required by `{closure@$DIR/tests/compile/ref_nounwindsafe.rs:8:18: 8:25}: UnwindSafe` -note: required because it appears within the type `lock_api::remutex::ReentrantMutex` - --> $CARGO/lock_api-0.4.12/src/remutex.rs + = help: within `rc::RcInner>`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell` +note: required because it appears within the type `Cell` + --> $RUST/core/src/cell.rs | - | pub struct ReentrantMutex { - | ^^^^^^^^^^^^^^ -note: required because it appears within the type `alloc::sync::ArcInner>` - --> $RUST/alloc/src/sync.rs + | pub struct Cell { + | ^^^^ +note: required because it appears within the type `rc::RcInner>` + --> $RUST/alloc/src/rc.rs | - | struct ArcInner { - | ^^^^^^^^ - = note: required for `NonNull>>` to implement `UnwindSafe` -note: required because it appears within the type `std::sync::Weak>` - --> $RUST/alloc/src/sync.rs + | struct RcInner { + | ^^^^^^^ + = note: required for `NonNull>>` to implement `UnwindSafe` +note: required because it appears within the type `std::rc::Weak>` + --> $RUST/alloc/src/rc.rs | | pub struct Weak< | ^^^^ -note: required because it appears within the type `mlua::state::WeakLua` +note: required because it appears within the type `WeakLua` --> src/state.rs | - | pub(crate) struct WeakLua(XWeak>); - | ^^^^^^^ -note: required because it appears within the type `mlua::types::ValueRef` - --> src/types.rs + | pub struct WeakLua(XWeak>); + | ^^^^^^^ +note: required because it appears within the type `mlua::types::value_ref::ValueRef` + --> src/types/value_ref.rs | - | pub(crate) struct ValueRef { - | ^^^^^^^^ + | pub struct ValueRef { + | ^^^^^^^^ note: required because it appears within the type `LuaTable` --> src/table.rs | @@ -49,51 +49,108 @@ note: required by a bound in `std::panic::catch_unwind` | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { | ^^^^^^^^^^ required by this bound in `catch_unwind` -error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary +error[E0277]: the type `UnsafeCell<*mut lua_State>` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary --> tests/compile/ref_nounwindsafe.rs:8:18 | 8 | catch_unwind(move || table.set("a", "b").unwrap()); - | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell<*mut lua_State>` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary | | | required by a bound introduced by this call | - = help: within `alloc::sync::ArcInner>`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell`, which is required by `{closure@$DIR/tests/compile/ref_nounwindsafe.rs:8:18: 8:25}: UnwindSafe` -note: required because it appears within the type `Cell` + = help: within `rc::RcInner>`, the trait `RefUnwindSafe` is not implemented for `UnsafeCell<*mut lua_State>` +note: required because it appears within the type `Cell<*mut lua_State>` --> $RUST/core/src/cell.rs | | pub struct Cell { | ^^^^ -note: required because it appears within the type `lock_api::remutex::RawReentrantMutex` - --> $CARGO/lock_api-0.4.12/src/remutex.rs +note: required because it appears within the type `mlua::state::raw::RawLua` + --> src/state/raw.rs + | + | pub struct RawLua { + | ^^^^^^ +note: required because it appears within the type `mlua::types::sync::inner::ReentrantMutex` + --> src/types/sync.rs + | + | pub(crate) struct ReentrantMutex(T); + | ^^^^^^^^^^^^^^ +note: required because it appears within the type `rc::RcInner>` + --> $RUST/alloc/src/rc.rs + | + | struct RcInner { + | ^^^^^^^ + = note: required for `NonNull>>` to implement `UnwindSafe` +note: required because it appears within the type `std::rc::Weak>` + --> $RUST/alloc/src/rc.rs + | + | pub struct Weak< + | ^^^^ +note: required because it appears within the type `WeakLua` + --> src/state.rs + | + | pub struct WeakLua(XWeak>); + | ^^^^^^^ +note: required because it appears within the type `mlua::types::value_ref::ValueRef` + --> src/types/value_ref.rs + | + | pub struct ValueRef { + | ^^^^^^^^ +note: required because it appears within the type `LuaTable` + --> src/table.rs + | + | pub struct Table(pub(crate) ValueRef); + | ^^^^^ +note: required because it's used within this closure + --> tests/compile/ref_nounwindsafe.rs:8:18 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ^^^^^^^ +note: required by a bound in `std::panic::catch_unwind` + --> $RUST/std/src/panic.rs + | + | pub fn catch_unwind R + UnwindSafe, R>(f: F) -> Result { + | ^^^^^^^^^^ required by this bound in `catch_unwind` + +error[E0277]: the type `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + --> tests/compile/ref_nounwindsafe.rs:8:18 + | +8 | catch_unwind(move || table.set("a", "b").unwrap()); + | ------------ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `UnsafeCell` may contain interior mutability and a reference may not be safely transferrable across a catch_unwind boundary + | | + | required by a bound introduced by this call + | + = help: the trait `RefUnwindSafe` is not implemented for `UnsafeCell` + = note: required for `Rc>` to implement `RefUnwindSafe` +note: required because it appears within the type `mlua::state::raw::RawLua` + --> src/state/raw.rs | - | pub struct RawReentrantMutex { - | ^^^^^^^^^^^^^^^^^ -note: required because it appears within the type `lock_api::remutex::ReentrantMutex` - --> $CARGO/lock_api-0.4.12/src/remutex.rs + | pub struct RawLua { + | ^^^^^^ +note: required because it appears within the type `mlua::types::sync::inner::ReentrantMutex` + --> src/types/sync.rs | - | pub struct ReentrantMutex { - | ^^^^^^^^^^^^^^ -note: required because it appears within the type `alloc::sync::ArcInner>` - --> $RUST/alloc/src/sync.rs + | pub(crate) struct ReentrantMutex(T); + | ^^^^^^^^^^^^^^ +note: required because it appears within the type `rc::RcInner>` + --> $RUST/alloc/src/rc.rs | - | struct ArcInner { - | ^^^^^^^^ - = note: required for `NonNull>>` to implement `UnwindSafe` -note: required because it appears within the type `std::sync::Weak>` - --> $RUST/alloc/src/sync.rs + | struct RcInner { + | ^^^^^^^ + = note: required for `NonNull>>` to implement `UnwindSafe` +note: required because it appears within the type `std::rc::Weak>` + --> $RUST/alloc/src/rc.rs | | pub struct Weak< | ^^^^ -note: required because it appears within the type `mlua::state::WeakLua` +note: required because it appears within the type `WeakLua` --> src/state.rs | - | pub(crate) struct WeakLua(XWeak>); - | ^^^^^^^ -note: required because it appears within the type `mlua::types::ValueRef` - --> src/types.rs + | pub struct WeakLua(XWeak>); + | ^^^^^^^ +note: required because it appears within the type `mlua::types::value_ref::ValueRef` + --> src/types/value_ref.rs | - | pub(crate) struct ValueRef { - | ^^^^^^^^ + | pub struct ValueRef { + | ^^^^^^^^ note: required because it appears within the type `LuaTable` --> src/table.rs | diff --git a/tests/compile/scope_callback_capture.stderr b/tests/compile/scope_callback_capture.stderr index a6993916..94844dc9 100644 --- a/tests/compile/scope_callback_capture.stderr +++ b/tests/compile/scope_callback_capture.stderr @@ -2,7 +2,7 @@ error[E0373]: closure may outlive the current function, but it borrows `inner`, --> tests/compile/scope_callback_capture.rs:7:43 | 5 | lua.scope(|scope| { - | ----- has type `&'1 mut mlua::scope::Scope<'1, '_>` + | ----- has type `&'1 mlua::Scope<'1, '_>` 6 | let mut inner: Option = None; 7 | let f = scope.create_function_mut(|_, t: Table| { | ^^^^^^^^^^^^^ may outlive borrowed value `inner` diff --git a/tests/compile/scope_invariance.stderr b/tests/compile/scope_invariance.stderr index 91158aa6..a3f218df 100644 --- a/tests/compile/scope_invariance.stderr +++ b/tests/compile/scope_invariance.stderr @@ -2,7 +2,7 @@ error[E0373]: closure may outlive the current function, but it borrows `test.fie --> tests/compile/scope_invariance.rs:13:39 | 9 | lua.scope(|scope| { - | ----- has type `&'1 mut mlua::scope::Scope<'1, '_>` + | ----- has type `&'1 mlua::Scope<'1, '_>` ... 13 | scope.create_function_mut(|_, ()| { | ^^^^^^^ may outlive borrowed value `test.field` diff --git a/tests/compile/scope_mutable_aliasing.stderr b/tests/compile/scope_mutable_aliasing.stderr index 362cf91d..e6e57f13 100644 --- a/tests/compile/scope_mutable_aliasing.stderr +++ b/tests/compile/scope_mutable_aliasing.stderr @@ -2,7 +2,7 @@ error[E0499]: cannot borrow `i` as mutable more than once at a time --> tests/compile/scope_mutable_aliasing.rs:12:51 | 10 | lua.scope(|scope| { - | ----- has type `&mut mlua::scope::Scope<'_, '1>` + | ----- has type `&mlua::Scope<'_, '1>` 11 | let _a = scope.create_userdata(MyUserData(&mut i)).unwrap(); | ----------------------------------------- | | | diff --git a/tests/compile/scope_userdata_borrow.stderr b/tests/compile/scope_userdata_borrow.stderr index 043a99b1..43025ddf 100644 --- a/tests/compile/scope_userdata_borrow.stderr +++ b/tests/compile/scope_userdata_borrow.stderr @@ -2,7 +2,7 @@ error[E0597]: `ibad` does not live long enough --> tests/compile/scope_userdata_borrow.rs:15:46 | 11 | lua.scope(|scope| { - | ----- has type `&mut mlua::scope::Scope<'_, '1>` + | ----- has type `&mlua::Scope<'_, '1>` ... 14 | let ibad = 42; | ---- binding `ibad` declared here diff --git a/tests/conversion.rs b/tests/conversion.rs index 2cb9b880..e75a3a06 100644 --- a/tests/conversion.rs +++ b/tests/conversion.rs @@ -6,8 +6,8 @@ use std::path::PathBuf; use bstr::BString; use maplit::{btreemap, btreeset, hashmap, hashset}; use mlua::{ - AnyUserData, Either, Error, Function, IntoLua, Lua, RegistryKey, Result, Table, Thread, UserDataRef, - Value, + AnyUserData, BorrowedBytes, BorrowedStr, Either, Error, Function, IntoLua, Lua, RegistryKey, Result, + Table, Thread, UserDataRef, Value, }; #[test] @@ -60,6 +60,66 @@ fn test_string_from_lua() -> Result<()> { Ok(()) } +#[test] +fn test_borrowedstr_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world!")?; + let bs = s.to_str()?; + let bs2 = (&bs).into_lua(&lua)?; + assert_eq!(bs2.as_string().unwrap(), "hello, world!"); + + // Push into stack + let table = lua.create_table()?; + table.set("bs", &bs)?; + assert_eq!(bs, table.get::("bs")?); + + Ok(()) +} + +#[test] +fn test_borrowedstr_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, s: BorrowedStr| Ok(s))?; + let s = f.call::("hello, world!")?; + assert_eq!(s, "hello, world!"); + + Ok(()) +} + +#[test] +fn test_borrowedbytes_into_lua() -> Result<()> { + let lua = Lua::new(); + + // Direct conversion + let s = lua.create_string("hello, world!")?; + let bb = s.as_bytes(); + let bb2 = (&bb).into_lua(&lua)?; + assert_eq!(bb2.as_string().unwrap(), "hello, world!"); + + // Push into stack + let table = lua.create_table()?; + table.set("bb", &bb)?; + assert_eq!(bb, table.get::("bb")?.as_bytes()); + + Ok(()) +} + +#[test] +fn test_borrowedbytes_from_lua() -> Result<()> { + let lua = Lua::new(); + + // From stack + let f = lua.create_function(|_, s: BorrowedBytes| Ok(s))?; + let s = f.call::("hello, world!")?; + assert_eq!(s, "hello, world!"); + + Ok(()) +} + #[test] fn test_table_into_lua() -> Result<()> { let lua = Lua::new(); @@ -657,3 +717,34 @@ fn test_either_from_lua() -> Result<()> { Ok(()) } + +#[test] +fn test_char_into_lua() -> Result<()> { + let lua = Lua::new(); + + let v = '🦀'; + let v2 = v.into_lua(&lua)?; + assert_eq!(Some(v.to_string()), v2.as_string_lossy()); + + Ok(()) +} + +#[test] +fn test_char_from_lua() -> Result<()> { + let lua = Lua::new(); + + assert_eq!(lua.convert::("A")?, 'A'); + assert_eq!(lua.convert::(65)?, 'A'); + assert_eq!(lua.convert::(128175)?, '💯'); + assert!(lua + .convert::(5456324) + .is_err_and(|e| e.to_string().contains("integer out of range"))); + assert!(lua + .convert::("hello") + .is_err_and(|e| e.to_string().contains("expected string to have exactly one char"))); + assert!(lua + .convert::(HashMap::::new()) + .is_err_and(|e| e.to_string().contains("expected string or integer"))); + + Ok(()) +} diff --git a/tests/function.rs b/tests/function.rs index b8c10703..683eca9c 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -214,7 +214,7 @@ fn test_function_dump() -> Result<()> { #[cfg(feature = "luau")] #[test] -fn test_finction_coverage() -> Result<()> { +fn test_function_coverage() -> Result<()> { let lua = Lua::new(); lua.set_compiler(mlua::Compiler::default().set_coverage_level(1)); diff --git a/tests/memory.rs b/tests/memory.rs index ba30761b..ba1d7e48 100644 --- a/tests/memory.rs +++ b/tests/memory.rs @@ -57,8 +57,8 @@ fn test_memory_limit_thread() -> Result<()> { return Ok(()); } - lua.set_memory_limit(lua.used_memory() + 10000)?; let thread = lua.create_thread(f)?; + lua.set_memory_limit(lua.used_memory() + 10000)?; match thread.resume::<()>(()) { Err(Error::MemoryError(_)) => {} something_else => panic!("did not trigger memory error: {:?}", something_else), diff --git a/tests/scope.rs b/tests/scope.rs index 4113f385..9aa3ee19 100644 --- a/tests/scope.rs +++ b/tests/scope.rs @@ -235,12 +235,14 @@ fn test_scope_userdata_values() -> Result<()> { #[test] fn test_scope_userdata_mismatch() -> Result<()> { - struct MyUserData<'a>(&'a Cell); + struct MyUserData<'a>(&'a mut i64); impl<'a> UserData for MyUserData<'a> { fn register(reg: &mut UserDataRegistry) { - reg.add_method("inc", |_, data, ()| { - data.0.set(data.0.get() + 1); + reg.add_method("get", |_, data, ()| Ok(*data.0)); + + reg.add_method_mut("inc", |_, data, ()| { + *data.0 = data.0.wrapping_add(1); Ok(()) }); } @@ -251,30 +253,53 @@ fn test_scope_userdata_mismatch() -> Result<()> { lua.load( r#" function inc(a, b) a.inc(b) end + function get(a, b) a.get(b) end "#, ) .exec()?; - let a = Cell::new(1); - let b = Cell::new(1); + let mut a = 1; + let mut b = 1; - let inc: Function = lua.globals().get("inc")?; lua.scope(|scope| { - let au = scope.create_userdata(MyUserData(&a))?; - let bu = scope.create_userdata(MyUserData(&b))?; - assert!(inc.call::<()>((&au, &au)).is_ok()); - match inc.call::<()>((&au, &bu)) { - Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { - Error::BadArgument { to, pos, name, cause } => { - assert_eq!(to.as_deref(), Some("MyUserData.inc")); - assert_eq!(*pos, 1); - assert_eq!(name.as_deref(), Some("self")); - assert!(matches!(*cause.as_ref(), Error::UserDataTypeMismatch)); - } + let au = scope.create_userdata(MyUserData(&mut a))?; + let bu = scope.create_userdata(MyUserData(&mut b))?; + for method_name in ["get", "inc"] { + let f: Function = lua.globals().get(method_name)?; + let full_name = format!("MyUserData.{method_name}"); + let full_name = full_name.as_str(); + + assert!(f.call::<()>((&au, &au)).is_ok()); + match f.call::<()>((&au, &bu)) { + Err(Error::CallbackError { ref cause, .. }) => match cause.as_ref() { + Error::BadArgument { to, pos, name, cause } => { + assert_eq!(to.as_deref(), Some(full_name)); + assert_eq!(*pos, 1); + assert_eq!(name.as_deref(), Some("self")); + assert!(matches!(*cause.as_ref(), Error::UserDataTypeMismatch)); + } + other => panic!("wrong error type {other:?}"), + }, + Err(other) => panic!("wrong error type {other:?}"), + Ok(_) => panic!("incorrectly returned Ok"), + } + + // Pass non-userdata type + let err = f.call::<()>((&au, 321)).err().unwrap(); + match err { + Error::CallbackError { ref cause, .. } => match cause.as_ref() { + Error::BadArgument { to, pos, name, cause } => { + assert_eq!(to.as_deref(), Some(full_name)); + assert_eq!(*pos, 1); + assert_eq!(name.as_deref(), Some("self")); + assert!(matches!(*cause.as_ref(), Error::FromLuaConversionError { .. })); + } + other => panic!("wrong error type {other:?}"), + }, other => panic!("wrong error type {other:?}"), - }, - Err(other) => panic!("wrong error type {other:?}"), - Ok(_) => panic!("incorrectly returned Ok"), + } + let err_msg = format!("bad argument `self` to `{full_name}`: error converting Lua number to userdata (expected userdata of type 'MyUserData')"); + assert!(err.to_string().contains(&err_msg)); } Ok(()) })?; @@ -318,7 +343,7 @@ fn test_scope_userdata_drop() -> Result<()> { let ud = lua.globals().get::("ud")?; match ud.borrow_scoped::(|_| Ok::<_, Error>(())) { - Ok(_) => panic!("succesfull borrow for destructed userdata"), + Ok(_) => panic!("successful borrow for destructed userdata"), Err(Error::UserDataDestructed) => {} Err(err) => panic!("improper borrow error for destructed userdata: {err:?}"), } diff --git a/tests/serde.rs b/tests/serde.rs index 4efb6537..9e3d5984 100644 --- a/tests/serde.rs +++ b/tests/serde.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::error::Error as StdError; +use bstr::BString; use mlua::{ AnyUserData, DeserializeOptions, Error, ExternalResult, IntoLua, Lua, LuaSerdeExt, Result as LuaResult, SerializeOptions, UserData, Value, @@ -248,6 +249,26 @@ fn test_serialize_same_table_twice() -> LuaResult<()> { Ok(()) } +#[test] +fn test_serialize_empty_table() -> LuaResult<()> { + let lua = Lua::new(); + + let table = Value::Table(lua.create_table()?); + let json = serde_json::to_string(&table.to_serializable()).unwrap(); + assert_eq!(json, "{}"); + + // Set the option to encode empty tables as array + let json = serde_json::to_string(&table.to_serializable().encode_empty_tables_as_array(true)).unwrap(); + assert_eq!(json, "[]"); + + // Check hashmap table with this option + table.as_table().unwrap().set("hello", "world")?; + let json = serde_json::to_string(&table.to_serializable().encode_empty_tables_as_array(true)).unwrap(); + assert_eq!(json, r#"{"hello":"world"}"#); + + Ok(()) +} + #[test] fn test_to_value_struct() -> LuaResult<()> { let lua = Lua::new(); @@ -420,6 +441,7 @@ fn test_from_value_struct() -> Result<(), Box> { map: HashMap, empty: Vec<()>, tuple: (u8, u8, u8), + bytes: BString, } let value = lua @@ -431,6 +453,7 @@ fn test_from_value_struct() -> Result<(), Box> { map = {2, [4] = 1}, empty = {}, tuple = {10, 20, 30}, + bytes = "\240\040\140\040", } "#, ) @@ -443,6 +466,7 @@ fn test_from_value_struct() -> Result<(), Box> { map: vec![(1, 2), (4, 1)].into_iter().collect(), empty: vec![], tuple: (10, 20, 30), + bytes: BString::from([240, 40, 140, 40]), }, got ); @@ -663,6 +687,37 @@ fn test_from_value_userdata() -> Result<(), Box> { Ok(()) } +#[test] +fn test_from_value_empty_table() -> Result<(), Box> { + let lua = Lua::new(); + + // By default we encode empty tables as objects + let t = lua.create_table()?; + let got = lua.from_value::(Value::Table(t.clone()))?; + assert_eq!(got, serde_json::json!({})); + + // Set the option to encode empty tables as array + let got = lua + .from_value_with::( + Value::Table(t.clone()), + DeserializeOptions::new().encode_empty_tables_as_array(true), + ) + .unwrap(); + assert_eq!(got, serde_json::json!([])); + + // Check hashmap table with this option + t.raw_set("hello", "world")?; + let got = lua + .from_value_with::( + Value::Table(t), + DeserializeOptions::new().encode_empty_tables_as_array(true), + ) + .unwrap(); + assert_eq!(got, serde_json::json!({"hello": "world"})); + + Ok(()) +} + #[test] fn test_from_value_sorted() -> Result<(), Box> { let lua = Lua::new(); diff --git a/tests/string.rs b/tests/string.rs index 1f849df9..f6bdd995 100644 --- a/tests/string.rs +++ b/tests/string.rs @@ -143,3 +143,17 @@ fn test_string_wrap() -> Result<()> { Ok(()) } + +#[test] +fn test_bytes_into_iter() -> Result<()> { + let lua = Lua::new(); + + let s = lua.create_string("hello")?; + let bytes = s.as_bytes(); + + for (i, &b) in bytes.into_iter().enumerate() { + assert_eq!(b, s.as_bytes()[i]); + } + + Ok(()) +} diff --git a/tests/tests.rs b/tests/tests.rs index 89bece8a..4a00724d 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -12,6 +12,24 @@ use mlua::{ Value, Variadic, }; +#[test] +fn test_weak_lua() { + let lua = Lua::new(); + let weak_lua = lua.weak(); + assert!(weak_lua.try_upgrade().is_some()); + drop(lua); + assert!(weak_lua.try_upgrade().is_none()); +} + +#[test] +#[should_panic(expected = "Lua instance is destroyed")] +fn test_weak_lua_panic() { + let lua = Lua::new(); + let weak_lua = lua.weak(); + drop(lua); + let _ = weak_lua.upgrade(); +} + #[cfg(not(feature = "luau"))] #[test] fn test_safety() -> Result<()> { @@ -322,7 +340,7 @@ fn test_error() -> Result<()> { let return_string_error = globals.get::("return_string_error")?; assert!(return_string_error.call::(()).is_ok()); - match lua.load("if youre happy and you know it syntax error").exec() { + match lua.load("if you are happy and you know it syntax error").exec() { Err(Error::SyntaxError { incomplete_input: false, .. @@ -1289,6 +1307,13 @@ fn test_warnings() -> Result<()> { if matches!(*cause, Error::RuntimeError(ref err) if err == "warning error") )); + // Recursive warning + lua.set_warning_function(|lua, _, _| { + lua.warning("inner", false); + Ok(()) + }); + lua.warning("hello", false); + Ok(()) } diff --git a/tests/thread.rs b/tests/thread.rs index 7ece2b56..74f75614 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -107,7 +107,6 @@ fn test_thread() -> Result<()> { } #[test] -#[cfg(any(feature = "lua54", feature = "luau"))] fn test_thread_reset() -> Result<()> { use mlua::{AnyUserData, UserData}; use std::sync::Arc; @@ -120,7 +119,8 @@ fn test_thread_reset() -> Result<()> { let arc = Arc::new(()); let func: Function = lua.load(r#"function(ud) coroutine.yield(ud) end"#).eval()?; - let thread = lua.create_thread(func.clone())?; + let thread = lua.create_thread(lua.load("return 0").into_function()?)?; // Dummy function first + assert!(thread.reset(func.clone()).is_ok()); for _ in 0..2 { assert_eq!(thread.status(), ThreadStatus::Resumable); @@ -145,11 +145,7 @@ fn test_thread_reset() -> Result<()> { assert!(thread.reset(func.clone()).is_err()); // Reset behavior has changed in Lua v5.4.4 // It's became possible to force reset thread by popping error object - assert!(matches!( - thread.status(), - ThreadStatus::Finished | ThreadStatus::Error - )); - // Would pass in 5.4.4 + assert!(matches!(thread.status(), ThreadStatus::Finished)); assert!(thread.reset(func.clone()).is_ok()); assert_eq!(thread.status(), ThreadStatus::Resumable); } diff --git a/tests/userdata.rs b/tests/userdata.rs index 59248d0f..63f9e0f5 100644 --- a/tests/userdata.rs +++ b/tests/userdata.rs @@ -1,3 +1,4 @@ +use std::any::TypeId; use std::collections::HashMap; use std::string::String as StdString; use std::sync::Arc; @@ -23,9 +24,11 @@ fn test_userdata() -> Result<()> { let userdata2 = lua.create_userdata(UserData2(Box::new(2)))?; assert!(userdata1.is::()); + assert!(userdata1.type_id() == Some(TypeId::of::())); assert!(!userdata1.is::()); assert!(userdata2.is::()); assert!(!userdata2.is::()); + assert!(userdata2.type_id() == Some(TypeId::of::())); assert_eq!(userdata1.borrow::()?.0, 1); assert_eq!(*userdata2.borrow::()?.0, 2); @@ -261,7 +264,7 @@ fn test_gc_userdata() -> Result<()> { impl UserData for MyUserdata { fn add_methods>(methods: &mut M) { methods.add_method("access", |_, this, ()| { - assert!(this.id == 123); + assert_eq!(this.id, 123); Ok(()) }); } @@ -410,7 +413,7 @@ fn test_userdata_destroy() -> Result<()> { let ud_ref = ud.borrow::()?; // With active `UserDataRef` this methods only marks userdata as destructed // without running destructor - ud.destroy()?; + ud.destroy().unwrap(); assert_eq!(Arc::strong_count(&rc), 2); drop(ud_ref); assert_eq!(Arc::strong_count(&rc), 1); @@ -419,7 +422,7 @@ fn test_userdata_destroy() -> Result<()> { let ud = lua.create_userdata(MyUserdata(rc.clone()))?; lua.globals().set("ud", &ud)?; lua.load("ud:try_destroy()").exec().unwrap(); - ud.destroy()?; + ud.destroy().unwrap(); assert_eq!(Arc::strong_count(&rc), 1); Ok(()) @@ -913,6 +916,7 @@ fn test_nested_userdata_gc() -> Result<()> { #[cfg(feature = "userdata-wrappers")] #[test] fn test_userdata_wrappers() -> Result<()> { + #[derive(Debug)] struct MyUserData(i64); impl UserData for MyUserData { @@ -924,6 +928,10 @@ fn test_userdata_wrappers() -> Result<()> { Ok(()) }) } + + fn add_methods>(methods: &mut M) { + methods.add_method("dbg", |_, this, ()| Ok(format!("{this:?}"))); + } } let lua = Lua::new(); @@ -932,136 +940,359 @@ fn test_userdata_wrappers() -> Result<()> { // Rc #[cfg(not(feature = "send"))] { - let ud = std::rc::Rc::new(MyUserData(1)); - globals.set("rc_ud", ud.clone())?; + use std::rc::Rc; + + let ud = Rc::new(MyUserData(1)); + globals.set("ud", ud.clone())?; lua.load( r#" - assert(rc_ud.static == "constant") - local ok, err = pcall(function() rc_ud.data = 2 end) + assert(ud.static == "constant") + local ok, err = pcall(function() ud.data = 2 end) assert( - tostring(err):sub(1, 32) == "error mutably borrowing userdata", - "expected error mutably borrowing userdata, got " .. tostring(err) + tostring(err):find("error mutably borrowing userdata") ~= nil, + "expected 'error mutably borrowing userdata', got '" .. tostring(err) .. "'" ) - assert(rc_ud.data == 1) + assert(ud.data == 1) + assert(ud:dbg(), "MyUserData(1)") "#, ) .exec() .unwrap(); - globals.set("rc_ud", Nil)?; + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 1); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + assert!(ud.borrow_mut::>().is_ok()); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 1); + assert!(matches!( + ud.borrow_mut_scoped::(|_| ()), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; - assert_eq!(std::rc::Rc::strong_count(&ud), 1); + assert_eq!(Rc::strong_count(&ud), 1); + + // We must be able to mutate userdata when having one reference only + globals.set("ud", ud)?; + lua.load( + r#" + ud.data = 2 + assert(ud.data == 2) + "#, + ) + .exec() + .unwrap(); } // Rc> #[cfg(not(feature = "send"))] { - let ud = std::rc::Rc::new(std::cell::RefCell::new(MyUserData(2))); - globals.set("rc_refcell_ud", ud.clone())?; + use std::cell::RefCell; + use std::rc::Rc; + + let ud = Rc::new(RefCell::new(MyUserData(2))); + globals.set("ud", ud.clone())?; lua.load( r#" - assert(rc_refcell_ud.static == "constant") - rc_refcell_ud.data = rc_refcell_ud.data + 1 - assert(rc_refcell_ud.data == 3) - "#, + assert(ud.static == "constant") + assert(ud.data == 2) + ud.data = 10 + assert(ud.data == 10) + assert(ud:dbg() == "MyUserData(10)") + "#, ) - .exec()?; - assert_eq!(ud.borrow().0, 3); - globals.set("rc_refcell_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 10); + assert_eq!(ud.borrow_mut::()?.0, 10); + ud.borrow_mut::()?.0 = 20; + assert_eq!(ud.borrow::()?.0, 20); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 20); + ud.borrow_mut_scoped::(|x| x.0 = 30)?; + assert_eq!(ud.borrow::()?.0, 30); + + // Double (read) borrow is okay + let _borrow = ud.borrow::()?; + assert_eq!(ud.borrow::()?.0, 30); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; + lua.gc_collect()?; + assert_eq!(Rc::strong_count(&ud), 1); + + // Check destroying wrapped UserDataRef without references in Lua + let ud = lua.convert::>(ud)?; lua.gc_collect()?; - assert_eq!(std::rc::Rc::strong_count(&ud), 1); + assert_eq!(ud.0, 30); + drop(ud); } // Arc { let ud = Arc::new(MyUserData(3)); - globals.set("arc_ud", ud.clone())?; + globals.set("ud", ud.clone())?; lua.load( r#" - assert(arc_ud.static == "constant") - local ok, err = pcall(function() arc_ud.data = 10 end) + assert(ud.static == "constant") + local ok, err = pcall(function() ud.data = 4 end) assert( - tostring(err):sub(1, 32) == "error mutably borrowing userdata", - "expected error mutably borrowing userdata, got " .. tostring(err) + tostring(err):find("error mutably borrowing userdata") ~= nil, + "expected 'error mutably borrowing userdata', got '" .. tostring(err) .. "'" ) - assert(arc_ud.data == 3) - "#, + assert(ud.data == 3) + assert(ud:dbg() == "MyUserData(3)") + "#, ) - .exec()?; - globals.set("arc_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 3); + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + assert!(ud.borrow_mut::>().is_ok()); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 3); + assert!(matches!( + ud.borrow_mut_scoped::(|_| ()), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; assert_eq!(Arc::strong_count(&ud), 1); + + // We must be able to mutate userdata when having one reference only + globals.set("ud", ud)?; + lua.load( + r#" + ud.data = 4 + assert(ud.data == 4) + "#, + ) + .exec() + .unwrap(); } // Arc> { - let ud = Arc::new(std::sync::Mutex::new(MyUserData(4))); - globals.set("arc_mutex_ud", ud.clone())?; + use std::sync::Mutex; + + let ud = Arc::new(Mutex::new(MyUserData(5))); + globals.set("ud", ud.clone())?; lua.load( r#" - assert(arc_mutex_ud.static == "constant") - arc_mutex_ud.data = arc_mutex_ud.data + 1 - assert(arc_mutex_ud.data == 5) - "#, + assert(ud.static == "constant") + assert(ud.data == 5) + ud.data = 6 + assert(ud.data == 6) + assert(ud:dbg() == "MyUserData(6)") + "#, ) - .exec()?; - assert_eq!(ud.lock().unwrap().0, 5); - globals.set("arc_mutex_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + #[rustfmt::skip] + assert!(matches!(ud.borrow::(), Err(Error::UserDataTypeMismatch))); + #[rustfmt::skip] + assert!(matches!(ud.borrow_mut::(), Err(Error::UserDataTypeMismatch))); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 6); + ud.borrow_mut_scoped::(|x| x.0 = 8)?; + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 8); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; assert_eq!(Arc::strong_count(&ud), 1); } // Arc> { - let ud = Arc::new(std::sync::RwLock::new(MyUserData(6))); - globals.set("arc_rwlock_ud", ud.clone())?; + use std::sync::RwLock; + + let ud = Arc::new(RwLock::new(MyUserData(9))); + globals.set("ud", ud.clone())?; lua.load( r#" - assert(arc_rwlock_ud.static == "constant") - arc_rwlock_ud.data = arc_rwlock_ud.data + 1 - assert(arc_rwlock_ud.data == 7) - "#, + assert(ud.static == "constant") + assert(ud.data == 9) + ud.data = 10 + assert(ud.data == 10) + assert(ud:dbg() == "MyUserData(10)") + "#, ) - .exec()?; - assert_eq!(ud.read().unwrap().0, 7); - globals.set("arc_rwlock_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + #[rustfmt::skip] + assert!(matches!(ud.borrow::(), Err(Error::UserDataTypeMismatch))); + #[rustfmt::skip] + assert!(matches!(ud.borrow_mut::(), Err(Error::UserDataTypeMismatch))); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 10); + ud.borrow_mut_scoped::(|x| x.0 = 12)?; + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 12); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; assert_eq!(Arc::strong_count(&ud), 1); } // Arc> { - let ud = Arc::new(parking_lot::Mutex::new(MyUserData(8))); - globals.set("arc_parking_lot_mutex_ud", ud.clone())?; + use parking_lot::Mutex; + + let ud = Arc::new(Mutex::new(MyUserData(13))); + globals.set("ud", ud.clone())?; lua.load( r#" - assert(arc_parking_lot_mutex_ud.static == "constant") - arc_parking_lot_mutex_ud.data = arc_parking_lot_mutex_ud.data + 1 - assert(arc_parking_lot_mutex_ud.data == 9) - "#, + assert(ud.static == "constant") + assert(ud.data == 13) + ud.data = 14 + assert(ud.data == 14) + assert(ud:dbg() == "MyUserData(14)") + "#, ) - .exec()?; - assert_eq!(ud.lock().0, 9); - globals.set("arc_parking_lot_mutex_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 14); + assert_eq!(ud.borrow_mut::()?.0, 14); + ud.borrow_mut::()?.0 = 15; + assert_eq!(ud.borrow::()?.0, 15); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 15); + ud.borrow_mut_scoped::(|x| x.0 = 16)?; + assert_eq!(ud.borrow::()?.0, 16); + + // Double borrow is not allowed + let _borrow = ud.borrow::()?; + assert!(matches!( + ud.borrow::(), + Err(Error::UserDataBorrowError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; assert_eq!(Arc::strong_count(&ud), 1); + + // Check destroying wrapped UserDataRef without references in Lua + let ud = lua.convert::>(ud)?; + lua.gc_collect()?; + assert_eq!(ud.0, 16); + drop(ud); } // Arc> { - let ud = Arc::new(parking_lot::RwLock::new(MyUserData(10))); - globals.set("arc_parking_lot_rwlock_ud", ud.clone())?; + use parking_lot::RwLock; + + let ud = Arc::new(RwLock::new(MyUserData(17))); + globals.set("ud", ud.clone())?; lua.load( r#" - assert(arc_parking_lot_rwlock_ud.static == "constant") - arc_parking_lot_rwlock_ud.data = arc_parking_lot_rwlock_ud.data + 1 - assert(arc_parking_lot_rwlock_ud.data == 11) - "#, + assert(ud.static == "constant") + assert(ud.data == 17) + ud.data = 18 + assert(ud.data == 18) + assert(ud:dbg() == "MyUserData(18)") + "#, ) - .exec()?; - assert_eq!(ud.read().0, 11); - globals.set("arc_parking_lot_rwlock_ud", Nil)?; + .exec() + .unwrap(); + + // Test borrowing original userdata + { + let ud = globals.get::("ud")?; + assert!(ud.is::>>()); + assert!(!ud.is::()); + + assert_eq!(ud.borrow::()?.0, 18); + assert_eq!(ud.borrow_mut::()?.0, 18); + ud.borrow_mut::()?.0 = 19; + assert_eq!(ud.borrow::()?.0, 19); + + assert_eq!(ud.borrow_scoped::(|x| x.0)?, 19); + ud.borrow_mut_scoped::(|x| x.0 = 20)?; + assert_eq!(ud.borrow::()?.0, 20); + + // Multiple read borrows are allowed with parking_lot::RwLock + let _borrow1 = ud.borrow::()?; + let _borrow2 = ud.borrow::()?; + assert!(matches!( + ud.borrow_mut::(), + Err(Error::UserDataBorrowMutError) + )); + } + + // Collect userdata + globals.set("ud", Nil)?; lua.gc_collect()?; assert_eq!(Arc::strong_count(&ud), 1); + + // Check destroying wrapped UserDataRef without references in Lua + let ud = lua.convert::>(ud)?; + lua.gc_collect()?; + assert_eq!(ud.0, 20); + drop(ud); } Ok(())