diff --git a/src/core/interpreter.cc b/src/core/interpreter.cc index 478646833ec1..ada7b629ea57 100644 --- a/src/core/interpreter.cc +++ b/src/core/interpreter.cc @@ -746,6 +746,7 @@ bool Interpreter::AddInternal(const char* f_id, string_view body, string* error) return true; } +// Stack is cleaned for us, we can leave it dirty bool Interpreter::IsTableSafe() const { auto fres = FetchKey(lua_, "err"); if (fres && *fres == LUA_TSTRING) { @@ -757,40 +758,35 @@ bool Interpreter::IsTableSafe() const { return true; } - vector> lens; - unsigned len = lua_rawlen(lua_, -1); - unsigned i = 0; - - // implement dfs traversal - while (true) { - while (i < len) { - DVLOG(1) << "Stack " << lua_gettop(lua_) << "/" << i << "/" << len; - int t = lua_rawgeti(lua_, -1, i + 1); // push table element - if (t == LUA_TTABLE) { - if (lens.size() >= 127) // reached depth 128 - return false; - - CHECK(lua_checkstack(lua_, 1)); - lens.emplace_back(i + 1, len); // save the parent state. - - // reset to iterate on the next table. - i = 0; - len = lua_rawlen(lua_, -1); - } else { - lua_pop(lua_, 1); // pop table element - ++i; - } - } + // Copy root table because we remove it upon finishing traversal + lua_pushnil(lua_); + lua_copy(lua_, -2, -1); - if (lens.empty()) // exit criteria - break; + int depth = 1; + lua_pushnil(lua_); - // unwind to the state before we went down the stack. - tie(i, len) = lens.back(); - lens.pop_back(); + // DFS based on lua stack: [parent-table] [parent-key] [parent-value = table] [key] + while (depth > 0) { + if (lua_checkstack(lua_, 3) == 0 || depth > 128) + return false; - lua_pop(lua_, 1); - }; + bool descending = false; + for (; lua_next(lua_, -2) != 0; lua_pop(lua_, 1)) { + if (lua_type(lua_, -1) != LUA_TTABLE) + continue; + + // If we descend, keep value as new table and push nil for start key + depth++; + lua_pushnil(lua_); + descending = true; + break; + } + + if (!descending) { + lua_pop(lua_, 1); + depth--; + } + } return true; } @@ -827,7 +823,29 @@ void Interpreter::SerializeResult(ObjectExplorer* serializer) { break; } + fres = FetchKey(lua_, "map"); + if (fres && *fres == LUA_TTABLE) { + // Calculate length of map part, there is sadly no other way + unsigned len = 0; + for (lua_pushnil(lua_); lua_next(lua_, -2) != 0; lua_pop(lua_, 1)) + len++; + + serializer->OnMapStart(len); + for (lua_pushnil(lua_); lua_next(lua_, -2) != 0;) { + // Push key to stack top: key value key + lua_pushnil(lua_); + lua_copy(lua_, -3, -1); + SerializeResult(serializer); // pops key + SerializeResult(serializer); // pop value + } + serializer->OnMapEnd(); + + lua_pop(lua_, 2); + break; + } + unsigned len = lua_rawlen(lua_, -1); + serializer->OnArrayStart(len); for (unsigned i = 0; i < len; ++i) { t = lua_rawgeti(lua_, -1, i + 1); // push table element diff --git a/src/core/interpreter.h b/src/core/interpreter.h index e2d333aba88f..c5fb1bee7893 100644 --- a/src/core/interpreter.h +++ b/src/core/interpreter.h @@ -28,6 +28,14 @@ class ObjectExplorer { virtual void OnNil() = 0; virtual void OnStatus(std::string_view str) = 0; virtual void OnError(std::string_view str) = 0; + + virtual void OnMapStart(unsigned len) { + OnArrayStart(len * 2); + } + + virtual void OnMapEnd() { + OnArrayEnd(); + } }; class Interpreter { diff --git a/src/core/interpreter_test.cc b/src/core/interpreter_test.cc index 3f91a2e75bbb..8a1e4a95b6d9 100644 --- a/src/core/interpreter_test.cc +++ b/src/core/interpreter_test.cc @@ -54,6 +54,16 @@ class TestSerializer : public ObjectExplorer { absl::StrAppend(&res, "nil "); } + void OnMapStart(unsigned len) final { + absl::StrAppend(&res, "{"); + } + + void OnMapEnd() final { + if (res.back() == ' ') + res.pop_back(); + absl::StrAppend(&res, "} "); + } + void OnStatus(std::string_view str) { absl::StrAppend(&res, "status(", str, ") "); } @@ -254,6 +264,9 @@ TEST_F(InterpreterTest, Execute) { EXPECT_TRUE(Execute("return {1,2,3,'ciao', {1,2}}")); EXPECT_EQ("[i(1) i(2) i(3) str(ciao) [i(1) i(2)]]", ser_.res); + + EXPECT_TRUE(Execute("return {map={a=1,b=2}}")); + EXPECT_THAT(ser_.res, testing::AnyOf("{str(a) i(1) str(b) i(2)}", "{str(b) i(2) str(a) i(1)}")); } TEST_F(InterpreterTest, Call) { diff --git a/src/server/dragonfly_test.cc b/src/server/dragonfly_test.cc index a2a2d4eaf4cf..2776f989be47 100644 --- a/src/server/dragonfly_test.cc +++ b/src/server/dragonfly_test.cc @@ -35,6 +35,7 @@ using absl::SetFlag; using absl::StrCat; using fb2::Fiber; using ::io::Result; +using testing::AnyOf; using testing::ElementsAre; using testing::HasSubstr; @@ -144,6 +145,11 @@ TEST_F(DflyEngineTest, EvalResp) { resp = Run({"eval", "return {5, 'foo', 17.5}", "0"}); ASSERT_THAT(resp, ArrLen(3)); EXPECT_THAT(resp.GetVec(), ElementsAre(IntArg(5), "foo", "17.5")); + + resp = Run({"eval", "return {map={a=1,b=2}}", "0"}); + ASSERT_THAT(resp, ArrLen(4)); + EXPECT_THAT(resp.GetVec(), AnyOf(ElementsAre("a", IntArg(1), "b", IntArg(2)), + ElementsAre("b", IntArg(2), "a", IntArg(1)))); } TEST_F(DflyEngineTest, EvalPublish) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 254e1a37f175..922edb487132 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -26,8 +26,8 @@ extern "C" { #include "base/logging.h" #include "facade/dragonfly_connection.h" #include "facade/error.h" +#include "facade/reply_builder.h" #include "facade/reply_capture.h" -#include "facade/resp_expr.h" #include "server/acl/acl_commands_def.h" #include "server/acl/acl_family.h" #include "server/acl/user_registry.h" @@ -290,6 +290,7 @@ class InterpreterReplier : public RedisReplyBuilder { unsigned num_elems_ = 0; }; +// Serialized result of script invocation to Redis protocol class EvalSerializer : public ObjectExplorer { public: EvalSerializer(RedisReplyBuilder* rb) : rb_(rb) { @@ -327,6 +328,13 @@ class EvalSerializer : public ObjectExplorer { void OnArrayEnd() final { } + void OnMapStart(unsigned len) final { + rb_->StartCollection(len, RedisReplyBuilder::MAP); + } + + void OnMapEnd() final { + } + void OnNil() final { rb_->SendNull(); }