diff --git a/main.lua b/main.lua index cbf27129a..4902072e6 100644 --- a/main.lua +++ b/main.lua @@ -73,6 +73,7 @@ require 'tracy' xpcall(dofile, log.debug, (ROOT / 'debugger.lua'):string()) +require 'main' require 'cli' local _, service = xpcall(require, log.error, 'service') diff --git a/script/class.lua b/script/class.lua new file mode 100644 index 000000000..1ef6ae544 --- /dev/null +++ b/script/class.lua @@ -0,0 +1,59 @@ +---@class Class +local m = {} + +m._classes = {} + +-- 创建一个类 +---@generic T: string +---@param name `T` +---@param super? string +---@return T +function m.declare(name, super) + if m._classes[name] then + return m._classes[name] + end + local class = {} + class.__index = class + class.__name = name + class.__call = function () end + m._classes[name] = class + + local superClass = m._classes[super] + if superClass then + assert(class ~= superClass, ('class %q can not inherit itself'):format(name)) + setmetatable(class, superClass) + end + + return class +end + +-- 获取一个类 +---@generic T: string +---@param name `T` +---@return T +function m.get(name) + return m._classes[name] +end + +---@generic T: string +---@param name `T` +---@return T +function m.new(name) + local class = m._classes[name] + assert(class, ('class %q not found'):format(name)) + + local instance = setmetatable({}, class) + + return instance +end + +---@param obj any +---@return string? +function m.type(obj) + if type(obj) ~= 'table' then + return nil + end + return obj.__name +end + +return m diff --git a/script/core/definition.lua b/script/core/definition.lua index 3619916e6..a0dd5155f 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -169,6 +169,9 @@ return function (uri, offset) if src.type == 'self' then goto CONTINUE end + if src.hideView then + goto CONTINUE + end src = src.field or src.method or src if src.type == 'getindex' or src.type == 'setindex' diff --git a/script/main.lua b/script/main.lua new file mode 100644 index 000000000..88f4d96cf --- /dev/null +++ b/script/main.lua @@ -0,0 +1,14 @@ +local class = require 'class' + +Class = class.declare +New = class.new + +---@class LS +LS = {} + +LS.gc = require 'gc' +LS.timer = require 'timer' +LS.inspect = require 'inspect' +LS.util = require 'utility' +LS.fsu = require 'fs-utility' +LS.furi = require 'file-uri' diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index c022adcbf..0eff6e14e 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -156,6 +156,8 @@ Symbol <- ({} { ---@field calls? parser.object[] ---@field generics? parser.object[] ---@field generic? parser.object +---@field hasGeneric? true +---@field hideView? true local function parseTokens(text, offset) Ci = 0 @@ -334,6 +336,17 @@ local function parseTable(parent) return typeUnit end +local function markHasGeneric(obj) + if not obj or obj.type == 'doc' then + return + end + if obj.hasGeneric then + return + end + obj.hasGeneric = true + markHasGeneric(obj.parent) +end + local function parseSigns(parent) if not checkToken('symbol', '<', 1) then return nil @@ -350,6 +363,8 @@ local function parseSigns(parent) } break end + sign.generic = sign + markHasGeneric(sign) signs[#signs+1] = sign if checkToken('symbol', ',', 1) then nextToken() @@ -836,6 +851,9 @@ local docSwitch = util.switch() } return result end + if extend.type == 'doc.type.table' then + extend.hideView = true + end result.extends[#result.extends+1] = extend result.finish = getFinish() if not checkToken('symbol', ',', 1) then @@ -1652,13 +1670,16 @@ local function bindGeneric(binded) if generics[name] then src.type = 'doc.generic.name' src.generic = generics[name] + markHasGeneric(src) end end) guide.eachSourceType(doc, 'doc.type.code', function (src) local name = src[1] if generics[name] then src.type = 'doc.generic.name' + src.generic = generics[name] src.literal = true + markHasGeneric(src) end end) end diff --git a/script/semantic/init.lua b/script/semantic/init.lua new file mode 100644 index 000000000..d1f067704 --- /dev/null +++ b/script/semantic/init.lua @@ -0,0 +1,83 @@ +require 'semantic.type' +require 'semantic.union' + +local Type = Class 'SType' +local files = require 'files' +local ws = require 'workspace' + +---@class Semantic +Semantic = {} + +Semantic.bindMap = {} + +---@alias SNode +---| SType +---| SUnion + +-- 获取类型对象 +---@param name string +---@return SType +function Semantic.getType(name) + return Type.get(name) +end + +-- 绑定语义 +---@param source parser.object +---@param smt SNode +---@return SNode +function Semantic.bind(source, smt) + Semantic.bindMap[source] = smt + return smt +end + +-- 获取绑定的语义 +---@param source parser.object +---@return SNode? +function Semantic.get(source) + return Semantic.bindMap[source] +end + +-- 清除绑定的语义 +---@param source parser.object +function Semantic.remove(source) + Semantic.bindMap[source] = nil +end + +-- 清空所有绑定的语义 +function Semantic.clear() + Semantic.bindMap = {} +end + +-- 创建联合类型 +---@param a SNode +---@param b SNode +---@return SUnion +function Semantic.newUnion(a, b) + return New 'SUnion' (a, b) +end + +files.watch(function (ev, uri) + if ev == 'update' then + Type.dropByUri(uri) + end + if ev == 'remove' then + Type.dropByUri(uri) + end + if ev == 'compile' then + local state = files.getLastState(uri) + if state then + Type.compileAst(state.ast) + end + end + if ev == 'version' then + if ws.isReady(uri) then + Semantic.clear() + end + end +end) + +ws.watch(function (ev, uri) + if ev == 'reload' then + Semantic.clear() + end +end) diff --git a/script/semantic/subManager.lua b/script/semantic/subManager.lua new file mode 100644 index 000000000..04937ebf9 --- /dev/null +++ b/script/semantic/subManager.lua @@ -0,0 +1,119 @@ +local scope = require 'workspace.scope' + +---@class SubMgr.Link +---@overload fun():self +---@field sets parser.object[] +---@field gets parser.object[] +local SubMgrLink = Class 'SubMgr.Link' + +function SubMgrLink:__call() + self.sets = {} + self.gets = {} + return self +end + +---@class SubMgr +---@overload fun():self +---@field private links table +---@field private setsCache? table +local SubMgr = Class 'SubMgr' + +function SubMgr:__call() + self.links = LS.util.multiTable(2, function () + return New 'GlobalLink' () + end) + return self +end + +-- 向订阅管理器中添加一个订阅者,类型为赋值 +---@param uri uri +---@param obj parser.object +function SubMgr:addSet(uri, obj) + local link = self.links[uri] + table.insert(link.sets, obj) + self:clearCache() +end + +-- 向订阅管理器中添加一个订阅者,类型为获取 +---@param uri uri +---@param obj parser.object +function SubMgr:addGet(uri, obj) + local link = self.links[uri] + table.insert(link.gets, obj) +end + +---@private +function SubMgr:clearCache() + self.setsCache = nil +end + +---@private +---@param token string +---@return boolean hasCached +---@return parser.object[] +function SubMgr:getSetsCache(token) + if not self.setsCache then + self.setsCache = {} + end + local cache = self.setsCache[token] + if cache then + return true, cache + end + if not cache then + cache = {} + self.setsCache[token] = cache + end + return false, cache +end + +-- 获取订阅管理器的所有赋值(基于可见性) +---@param suri uri +---@return parser.object[] +function SubMgr:getSets(suri) + local scp = scope.getScope(suri) + local token = scp.uri or '' + local hasCached, cache = self:getSetsCache(token) + if hasCached then + return cache + end + for uri, link in pairs(self.links) do + if link.sets then + if scp:isVisible(uri) then + for _, source in ipairs(link.sets) do + cache[#cache+1] = source + end + end + end + end + return cache +end + +-- 获取订阅管理器的所有赋值 +---@return parser.object[] +function SubMgr:getAllSets() + local hasCached, cache = self:getSetsCache('*') + if hasCached then + return cache + end + for _, link in pairs(self.links) do + if link.sets then + for _, source in ipairs(link.sets) do + cache[#cache+1] = source + end + end + end + return cache +end + +-- 清空每个路径下的所有订阅者 +---@param uri uri +function SubMgr:dropUri(uri) + self.links[uri] = nil + self:clearCache() +end + +-- 检查是否还有任何订阅者 +---@return boolean +function SubMgr:hasAnyLink() + return next(self.links) ~= nil +end diff --git a/script/semantic/type.lua b/script/semantic/type.lua new file mode 100644 index 000000000..4bbcc3747 --- /dev/null +++ b/script/semantic/type.lua @@ -0,0 +1,124 @@ +local guide = require 'parser.guide' + +---@alias Type.Category 'class'|'alias'|'enum'|'unknown' + +---@class SType +---@overload fun(name: string): self +---@field private name string +---@field private subMgr SubMgr +---@field private cate Type.Category|nil +local Type = Class 'SType' + +---@private +---@type table +Type.allTypes = LS.util.multiTable(2, function (name) + return New 'Type' (name) +end) + +---@private +---@type table> +Type.uriSubs = LS.util.multTable(2) + +---@param name string +function Type:__call(name) + self.name = name + self.subMgr = New 'SubMgr' () + return self +end + +-- 获取类型名称 +---@return string +function Type:getName() + return self.name +end + +-- 获取类型的分类(根据可见性) +---@param suri uri +---@return Type.Category +function Type:getCate(suri) + if self.cate then + return self.cate + end + self.cate = 'unknown' + for _, set in ipairs(self.subMgr:getSets(suri)) do + if set.type == 'doc.class' then + self.cate = 'class' + break + end + if set.type == 'doc.alias' then + self.cate = 'alias' + break + end + if set.type == 'doc.enum' then + self.cate = 'enum' + break + end + end + return self.cate +end + +-- 添加一个订阅(赋值) +---@param uri uri +---@param obj parser.object +function Type:addSet(uri, obj) + self.subMgr:addSet(uri, obj) + Type.uriSubs[uri][self.name] = true +end + +-- 添加一个订阅(获取) +function Type:addGet(uri, obj) + self.subMgr:addGet(uri, obj) + Type.uriSubs[uri][self.name] = true +end + +-- 丢弃链接,如果没有任何链接则移除整个类型 +---@package +---@param uri uri +function Type:dropUri(uri) + self:clearCache() + self.subMgr:dropUri(uri) + if not self.subMgr:hasAnyLink() then + Type.allTypes[self.name] = nil + end +end + +---@private +function Type:clearCache() + self.cate = nil +end + +-- 预编译语法树,绑定所有类型 +---@param source parser.object +function Type.compileAst(source) + local uri = guide.getUri(source) + guide.eachSourceTypes(source.docs, { + 'doc.class', + 'doc.alias', + 'doc.enum', + }, function (src) + local name = guide.getKeyName(src) + if not name then + return + end + local type = Type.get(name) + type:addSet(uri, src) + end) +end + +-- 获取一个类型 +---@param name string +---@return SType +function Type.get(name) + return Type.allTypes[name] +end + +-- 丢弃整个文件内的所有订阅 +---@param uri uri +function Type.dropByUri(uri) + local subs = Type.uriSubs[uri] + Type.uriSubs[uri] = nil + for name in pairs(subs) do + local subType = Type.allTypes[name] + subType:dropUri(uri) + end +end diff --git a/script/semantic/union.lua b/script/semantic/union.lua new file mode 100644 index 000000000..6db82f31e --- /dev/null +++ b/script/semantic/union.lua @@ -0,0 +1,13 @@ +---@class SUnion +---@overload fun(a: SNode, b: SNode): self +---@field package sub1 SNode +---@field package sub2 SNode +local Union = Class 'SUnion' + +---@param a SNode +---@param b SNode +function Union:__call(a, b) + self.sub1 = a + self.sub2 = b + return self +end diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index dc22f6b9b..8a764e37a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -960,10 +960,17 @@ local function compileForVars(source, target) --> local k, v = iterator(status, initValue) if not source._iterator then source._iterator = { - type = 'dummyfunc', + type = 'forstate.iterator', parent = source, } - source._iterArgs = {{},{}} + source._iterArgs = { + { + type = 'forstate.table' + }, + { + type = 'forstate.initValue' + }, + } source._iterVars = {} end -- iterator @@ -1107,16 +1114,13 @@ local function bindReturnOfFunction(source, mfunc, index, args) local returnNode = vm.compileNode(returnObject) for rnode in returnNode:eachObject() do if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(source), args) + returnNode = rnode:resolve(guide.getUri(source), args, true) break end end if returnNode then for rnode in returnNode:eachObject() do - -- TODO: narrow type - if rnode.type ~= 'doc.generic.name' then - vm.setNode(source, rnode) - end + vm.setNode(source, rnode) end if returnNode:isOptional() then vm.getNode(source):addOptional() @@ -1415,13 +1419,7 @@ local compilerSwitch = util.switch() if rtn.name then source.name = rtn.name[1] end - local hasGeneric - if sign then - guide.eachSourceType(rtn, 'doc.generic.name', function (src) - hasGeneric = true - end) - end - if hasGeneric then + if rtn.hasGeneric then ---@cast sign -? vm.setNode(source, vm.createGeneric(rtn, sign)) else @@ -1678,22 +1676,15 @@ local compilerSwitch = util.switch() return end for _, set in ipairs(global:getSets(uri)) do - if set.type == 'doc.class' then - if set.extends then - for _, ext in ipairs(set.extends) do - if ext.type == 'doc.type.table' then - if vm.getGeneric(ext) then - local resolved = vm.getGeneric(ext):resolve(uri, source.signs) - vm.setNode(source, resolved) - end - end - end - end - end - if set.type == 'doc.alias' then - if vm.getGeneric(set.extends) then - local resolved = vm.getGeneric(set.extends):resolve(uri, source.signs) + if set.type == 'doc.class' + or set.type == 'doc.alias' then + local sign = vm.getSign(set) + if sign then + local generic = vm.getGeneric(set) + or vm.createGeneric(set, sign) + local resolved = generic:resolve(uri, source.signs) vm.setNode(source, resolved) + break end end end @@ -1823,17 +1814,13 @@ local compilerSwitch = util.switch() if set.extends then for _, ext in ipairs(set.extends) do if ext.type == 'doc.type.table' then - if not vm.getGeneric(ext) then - vm.setNode(source, vm.compileNode(ext)) - end + vm.setNode(source, vm.compileNode(ext)) end end end end if set.type == 'doc.alias' then - if not vm.getGeneric(set.extends) then - vm.setNode(source, vm.compileNode(set.extends)) - end + vm.setNode(source, vm.compileNode(set.extends)) end end end @@ -1924,8 +1911,8 @@ local function compileByParentNode(source) end) end ----@param source vm.node.object | vm.variable ----@return vm.node +---@param source parser.object +---@return SNode function vm.compileNode(source) if not source then if TEST then @@ -1935,20 +1922,20 @@ function vm.compileNode(source) end end - local cache = vm.getNode(source) - if cache ~= nil then - return cache + local semantic = Semantic.get(source) + if semantic ~= nil then + return semantic end - ---@cast source parser.object - vm.setNode(source, vm.createNode(), true) + Semantic.bind(source, Semantic.getType 'unknown') + vm.compileByGlobal(source) vm.compileByVariable(source) compileByNode(source) compileByParentNode(source) matchCall(source) - local node = vm.getNode(source) + local node = Semantic.get(source) ---@cast node -? return node end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index ed832d9bd..e62c1983e 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -12,26 +12,66 @@ local mt = {} mt.__index = mt mt.type = 'generic' +local function markHasGeneric(obj) + if obj.type == 'doc' then + return + end + if obj.hasGeneric then + return + end + obj.hasGeneric = true + markHasGeneric(obj.parent) +end + ---@param source vm.object? ----@param resolved? table +---@param resolved? table +---@param parent? parser.object +---@param removeUnresolved? boolean ---@return vm.object? -local function cloneObject(source, resolved) - if not resolved or not source then +local function cloneObject(source, resolved, parent, removeUnresolved) + if not source then + return source + end + if not source.hasGeneric then return source end if source.type == 'doc.generic.name' then - local key = source[1] + local generic = source.generic local newName = { - type = source.type, - start = source.start, - finish = source.finish, - parent = source.parent, - [1] = source[1], + type = source.type, + start = source.start, + finish = source.finish, + parent = parent or source.parent, + generic = source.generic, + cloned = true, + [1] = source[1], } - if resolved[key] then - vm.setNode(newName, resolved[key], true) - newName._resolved = resolved[key] + local resolvedNode = resolved and resolved[generic] + if not resolvedNode then + if removeUnresolved then + newName.type = 'doc.generic.resolved' + else + markHasGeneric(newName) + end + return newName end + + vm.setNode(newName, resolvedNode, true) + newName._resolved = resolvedNode + local hasGeneric + for n in resolvedNode:eachObject() do + if n.type == 'doc.generic.name' then + hasGeneric = true + break + end + end + + if hasGeneric then + markHasGeneric(newName) + else + newName.type = 'doc.generic.resolved' + end + return newName end if source.type == 'doc.type' then @@ -39,12 +79,13 @@ local function cloneObject(source, resolved) type = source.type, start = source.start, finish = source.finish, - parent = source.parent, + parent = parent or source.parent, optional = source.optional, + cloned = true, types = {}, } for i, typeUnit in ipairs(source.types) do - local newObj = cloneObject(typeUnit, resolved) + local newObj = cloneObject(typeUnit, resolved, newType) newType.types[i] = newObj end return newType @@ -54,10 +95,11 @@ local function cloneObject(source, resolved) type = source.type, start = source.start, finish = source.finish, - parent = source.parent, + parent = parent or source.parent, name = source.name, - extends = cloneObject(source.extends, resolved) + cloned = true, } + newArg.extends = cloneObject(source.extends, resolved, newArg) return newArg end if source.type == 'doc.type.array' then @@ -65,9 +107,10 @@ local function cloneObject(source, resolved) type = source.type, start = source.start, finish = source.finish, - parent = source.parent, - node = cloneObject(source.node, resolved), + parent = parent or source.parent, + cloned = true, } + newArray.node = cloneObject(source.node, resolved, newArray) return newArray end if source.type == 'doc.type.table' then @@ -75,19 +118,25 @@ local function cloneObject(source, resolved) type = source.type, start = source.start, finish = source.finish, - parent = source.parent, + parent = parent or source.parent, + cloned = true, fields = {}, } for i, field in ipairs(source.fields) do - local newField = { - type = field.type, - start = field.start, - finish = field.finish, - parent = newTable, - name = cloneObject(field.name, resolved), - extends = cloneObject(field.extends, resolved), - } - newTable.fields[i] = newField + if field.hasGeneric then + local newField = { + type = field.type, + start = field.start, + finish = field.finish, + parent = newTable, + cloned = true, + } + newField.name = cloneObject(field.name, resolved, newField) + newField.extends = cloneObject(field.extends, resolved, newField) + newTable.fields[i] = newField + else + newTable.fields[i] = field + end end return newTable end @@ -96,20 +145,20 @@ local function cloneObject(source, resolved) type = source.type, start = source.start, finish = source.finish, - parent = source.parent, + parent = parent or source.parent, + cloned = true, args = {}, returns = {}, } for i, arg in ipairs(source.args) do - local newObj = cloneObject(arg, resolved) + local newObj = cloneObject(arg, resolved, newDocFunc) newObj.optional = arg.optional newDocFunc.args[i] = newObj end for i, ret in ipairs(source.returns) do - local newObj = cloneObject(ret, resolved) - newObj.parent = newDocFunc + local newObj = cloneObject(ret, resolved, newDocFunc) newObj.optional = ret.optional - newDocFunc.returns[i] = cloneObject(ret, resolved) + newDocFunc.returns[i] = cloneObject(ret, resolved, newDocFunc) end return newDocFunc end @@ -118,18 +167,29 @@ end ---@param uri uri ---@param args parser.object +---@param removeUnresolved? boolean ---@return vm.node -function mt:resolve(uri, args) +function mt:resolve(uri, args, removeUnresolved) local resolved = self.sign:resolve(uri, args) local protoNode = vm.compileNode(self.proto) local result = vm.createNode() for nd in protoNode:eachObject() do - if nd.type == 'global' or nd.type == 'variable' then - ---@cast nd vm.global | vm.variable + if nd.type == 'global' then + ---@cast nd vm.global + if nd.cate == 'variable' then + result:merge(nd) + end + if nd.cate == 'type' then + if not nd:getSigns(uri) then + result:merge(nd) + end + end + elseif nd.type == 'variable' then + ---@cast nd vm.variable result:merge(nd) else ---@cast nd -vm.global, -vm.variable - local clonedObject = cloneObject(nd, resolved) + local clonedObject = cloneObject(nd, resolved, nil, removeUnresolved) if clonedObject then local clonedNode = vm.compileNode(clonedObject) result:merge(clonedNode) @@ -148,12 +208,6 @@ function vm.getGenericResolved(source) return source._resolved end ----@param source parser.object ----@param generic vm.generic -function vm.setGeneric(source, generic) - source._generic = generic -end - ---@param source parser.object ---@return vm.generic? function vm.getGeneric(source) @@ -168,5 +222,6 @@ function vm.createGeneric(proto, sign) sign = sign, proto = proto, }, mt) + proto._generic = generic return generic end diff --git a/script/vm/global.lua b/script/vm/global.lua index c1b5f3205..f3bbd61d9 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -22,6 +22,7 @@ local globalSubs = util.multiTable(2) ---@field links table ---@field setsCache? table ---@field cate vm.global.cate +---@field signsCache? parser.object[]|false local mt = {} mt.__index = mt mt.type = 'global' @@ -33,6 +34,7 @@ function mt:addSet(uri, source) local link = self.links[uri] link.sets[#link.sets+1] = source self.setsCache = nil + self.signsCache = nil end ---@param uri uri @@ -72,6 +74,22 @@ function mt:getSets(suri) return cache end +---@param uri uri +---@return parser.object[]? +function mt:getSigns(uri) + if self.signsCache ~= nil then + return self.signsCache or nil + end + for _, set in ipairs(self:getSets(uri)) do + if set.signs then + self.signsCache = set.signs + return set.signs + end + end + self.signsCache = false + return nil +end + ---@return parser.object[] function mt:getAllSets() if not self.setsCache then @@ -97,6 +115,7 @@ end function mt:dropUri(uri) self.links[uri] = nil self.setsCache = nil + self.signsCache = nil end ---@return string @@ -320,21 +339,6 @@ local compilerGlobalSwitch = util.switch() local class = vm.declareGlobal('type', name, uri) class:addSet(uri, source) source._globalNode = class - - if source.signs then - local sign = vm.createSign() - vm.setSign(source, sign) - for _, obj in ipairs(source.signs) do - sign:addSign(vm.compileNode(obj)) - end - if source.extends then - for _, ext in ipairs(source.extends) do - if ext.type == 'doc.type.table' then - vm.setGeneric(ext, vm.createGeneric(ext, sign)) - end - end - end - end end) : case 'doc.alias' : call(function (source) @@ -346,14 +350,6 @@ local compilerGlobalSwitch = util.switch() local alias = vm.declareGlobal('type', name, uri) alias:addSet(uri, source) source._globalNode = alias - - if source.signs then - source._sign = vm.createSign() - for _, sign in ipairs(source.signs) do - source._sign:addSign(vm.compileNode(sign)) - end - source.extends._generic = vm.createGeneric(source.extends, source._sign) - end end) : case 'doc.enum' : call(function (source) @@ -597,6 +593,7 @@ function vm.getGlobalBase(source) type = 'globalbase', parent = root, global = global, + signs = source.signs, start = 0, finish = 0, } diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 94fdfd887..64a7e40b3 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -147,10 +147,14 @@ local viewNodeSwitch;viewNodeSwitch = util.switch() end end end + infer._drop[source.node[1]] = true return ('%s<%s>'):format(source.node[1], table.concat(buf, ', ')) end) : case 'doc.type.table' : call(function (source, infer, uri) + if source.hideView then + return + end if #source.fields == 0 then infer._hasTable = true return diff --git a/script/vm/node.lua b/script/vm/node.lua index 0ffd8c70f..8a33b40ad 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -397,6 +397,7 @@ function mt:copy() return vm.createNode(self) end +---@deprecated ---@param source vm.node.object | vm.generic ---@param node vm.node | vm.node.object ---@param cover? boolean diff --git a/script/vm/sign.lua b/script/vm/sign.lua index 1f4344758..1890458a3 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -1,11 +1,12 @@ local guide = require 'parser.guide' ---@class vm local vm = require 'vm.vm' +local util = require 'utility' ---@class vm.sign ----@field parent parser.object ----@field signList vm.node[] ----@field docGenric parser.object[] +---@field parent parser.object +---@field signList vm.node[] +---@field genericObjects parser.object[] local mt = {} mt.__index = mt mt.type = 'sign' @@ -15,20 +16,26 @@ function mt:addSign(node) self.signList[#self.signList+1] = node end ----@param doc parser.object -function mt:addDocGeneric(doc) - self.docGenric[#self.docGenric+1] = doc +---@param object parser.object +function mt:addGenericObject(object) + self.genericObjects[#self.genericObjects+1] = object +end + +---@param generic parser.object +---@return boolean +function mt:isValidGeneric(generic) + return util.arrayHas(self.genericObjects, generic) end ---@param uri uri ---@param args parser.object ----@return table? +---@return table? function mt:resolve(uri, args) if not args then return nil end - ---@type table + ---@type table local resolved = {} ---@param object vm.node|vm.node.object @@ -46,31 +53,30 @@ function mt:resolve(uri, args) return end if object.type == 'doc.generic.name' then - ---@type string - local key = object[1] + local generic = object.generic + if not generic or not self:isValidGeneric(generic) then + return + end if object.literal then -- 'number' -> `T` for n in node:eachObject() do if n.type == 'string' then ---@cast n parser.object local type = vm.declareGlobal('type', n[1], guide.getUri(n)) - resolved[key] = vm.createNode(type, resolved[key]) + resolved[generic] = vm.createNode(type, resolved[generic]) end end else -- number -> T for n in node:eachObject() do - if n.type ~= 'doc.generic.name' - and n.type ~= 'generic' then - if resolved[key] then - resolved[key]:merge(n) - else - resolved[key] = vm.createNode(n) - end + if resolved[generic] then + resolved[generic]:merge(n) + else + resolved[generic] = vm.createNode(n) end end - if resolved[key] and node:isOptional() then - resolved[key]:addOptional() + if resolved[generic] and node:isOptional() then + resolved[generic]:addOptional() end end return @@ -147,7 +153,9 @@ function mt:resolve(uri, args) or n.type == 'doc.type.function' then ---@cast n parser.object local farg = n.args and n.args[i] - if farg then + if farg + and not farg.hasGeneric + and farg.type ~= 'generic' then resolve(arg.extends, vm.compileNode(farg)) end end @@ -159,7 +167,9 @@ function mt:resolve(uri, args) or n.type == 'doc.type.function' then ---@cast n parser.object local fret = vm.getReturnOfFunction(n, i) - if fret then + if fret + and not fret.hasGeneric + and fret.type ~= 'generic' then resolve(ret, vm.compileNode(fret)) end end @@ -171,39 +181,27 @@ function mt:resolve(uri, args) ---@param sign vm.node ---@return table - ---@return table local function getSignInfo(sign) local knownTypes = {} - local genericsNames = {} for obj in sign:eachObject() do - if obj.type == 'doc.generic.name' then - genericsNames[obj[1]] = true + if obj.hasGeneric then goto CONTINUE end - if obj.type == 'doc.type.table' - or obj.type == 'doc.type.function' - or obj.type == 'doc.type.array' then - ---@cast obj parser.object - local hasGeneric - guide.eachSourceType(obj, 'doc.generic.name', function (src) - hasGeneric = true - genericsNames[src[1]] = true - end) - if hasGeneric then - goto CONTINUE - end - end if obj.type == 'variable' or obj.type == 'local' then goto CONTINUE end + if obj.type == 'global' + and obj:getSigns(uri) then + goto CONTINUE + end local view = vm.getInfer(obj):view(uri) if view then knownTypes[view] = true end ::CONTINUE:: end - return knownTypes, genericsNames + return knownTypes end -- remove un-generic type @@ -236,9 +234,8 @@ function mt:resolve(uri, args) return newArgNode end - ---@param genericNames table - local function isAllResolved(genericNames) - for n in pairs(genericNames) do + local function isAllResolved() + for _, n in ipairs(self.genericObjects) do if not resolved[n] then return false end @@ -252,9 +249,9 @@ function mt:resolve(uri, args) break end local argNode = vm.compileNode(arg) - local knownTypes, genericNames = getSignInfo(sign) - if not isAllResolved(genericNames) then - local newArgNode = buildArgNode(argNode,sign, knownTypes) + local knownTypes = getSignInfo(sign) + if not isAllResolved() then + local newArgNode = buildArgNode(argNode, sign, knownTypes) resolve(sign, newArgNode) end end @@ -265,8 +262,8 @@ end ---@return vm.sign function vm.createSign() local genericMgr = setmetatable({ - signList = {}, - docGenric = {}, + signList = {}, + genericObjects = {}, }, mt) return genericMgr end @@ -296,7 +293,9 @@ function vm.getSign(source) if not source._sign then source._sign = vm.createSign() end - source._sign:addDocGeneric(doc) + for _, object in ipairs(doc.generics) do + source._sign:addGenericObject(object) + end end end if not source._sign then @@ -312,30 +311,41 @@ function vm.getSign(source) end end end - if source.type == 'doc.type.function' - or source.type == 'doc.type.table' - or source.type == 'doc.type.array' then - local hasGeneric - guide.eachSourceType(source, 'doc.generic.name', function (_) - hasGeneric = true - end) - if not hasGeneric then + if source.type == 'doc.type.function' then + if not source.hasGeneric then return nil end source._sign = vm.createSign() - if source.type == 'doc.type.function' then + if source.args then for _, arg in ipairs(source.args) do - if arg.extends then - local argNode = vm.compileNode(arg.extends) - if arg.optional then - argNode:addOptional() - end - source._sign:addSign(argNode) - else - source._sign:addSign(vm.createNode()) + local argNode = vm.compileNode(arg) + if arg.optional then + argNode:addOptional() end + source._sign:addSign(argNode) end end + local mark = {} + guide.eachSourceType(source, 'doc.generic.name', function (src) + local genericObject = src.generic + assert(genericObject) + if not mark[genericObject] then + mark[genericObject] = true + source._sign:addGenericObject(genericObject) + end + end) + end + if source.type == 'doc.class' + or source.type == 'doc.alias' then + if not source.signs then + return nil + end + source._sign = vm.createSign() + for _, sign in ipairs(source.signs) do + local signNode = vm.compileNode(sign) + source._sign:addSign(signNode) + source._sign:addGenericObject(sign.generic) + end end return source._sign or nil end diff --git a/script/vm/type.lua b/script/vm/type.lua index 8382eb869..814eeee24 100644 --- a/script/vm/type.lua +++ b/script/vm/type.lua @@ -165,12 +165,13 @@ local function checkChildEnum(childName, parent , uri, mark, errs) return true end +---@param uri uri ---@param parent vm.node.object ---@param child vm.node.object ---@param mark table ---@param errs? typecheck.err[] ---@return boolean -local function checkValue(parent, child, mark, errs) +local function checkValue(uri, parent, child, mark, errs) if parent.type == 'doc.type.integer' then if child.type == 'integer' or child.type == 'doc.type.integer' @@ -226,7 +227,6 @@ local function checkValue(parent, child, mark, errs) end ---@cast parent parser.object ---@cast child parser.object - local uri = guide.getUri(parent) local tnode = vm.compileNode(child) for _, pfield in ipairs(parent.fields) do local knode = vm.compileNode(pfield.name) @@ -404,7 +404,7 @@ function vm.isSubType(uri, child, parent, mark, errs) end if childName == parentName then - if not checkValue(parent, child, mark, errs) then + if not checkValue(uri, parent, child, mark, errs) then return false end return true @@ -517,7 +517,7 @@ function vm.getTableValue(uri, tnode, knode, inversion) for tn in tnode:eachObject() do if tn.type == 'doc.type.table' then for _, field in ipairs(tn.fields) do - if field.extends then + if field.extends and not field.hasGeneric then if inversion then if vm.isSubType(uri, vm.compileNode(field.name), knode) then result:merge(vm.compileNode(field.extends)) @@ -591,7 +591,8 @@ function vm.getTableKey(uri, tnode, vnode, reverse) if tn.type == 'doc.type.table' then for _, field in ipairs(tn.fields) do if field.name.type ~= 'doc.field.name' - and field.extends then + and field.extends + and not field.hasGeneric then if reverse then if vm.isSubType(uri, vm.compileNode(field.extends), vnode) then result:merge(vm.compileNode(field.name)) diff --git a/test.lua b/test.lua index 9e596e880..d9eefb07d 100644 --- a/test.lua +++ b/test.lua @@ -25,6 +25,7 @@ LOCALE = 'zh-cn' --dofile((ROOT / 'build_package.lua'):string()) require 'tracy' +require 'main' local function loadAllLibs() assert(require 'bee.filesystem') @@ -52,8 +53,8 @@ end local function testAll() test 'basic' - test 'definition' test 'type_inference' + test 'definition' test 'references' test 'hover' test 'completion' diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index b877f5cd5..a6e62b60a 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -152,7 +152,7 @@ local !> ]] TEST [[ ----@class A: +---@class A: {} ---@type A local !> @@ -647,7 +647,7 @@ TEST [[ ---@type TT local t ----@class A: +---@class A: {} print(t.) ]] @@ -658,7 +658,7 @@ TEST [[ ---@type TT local t ----@class A: +---@class A: {} print(t.) ]] diff --git a/test/diagnostics/type-check.lua b/test/diagnostics/type-check.lua index 18e7190d9..cd2bfb617 100644 --- a/test/diagnostics/type-check.lua +++ b/test/diagnostics/type-check.lua @@ -1255,6 +1255,26 @@ local var func(var) ]] +TEST [[ +---@type string|table +local x + +---@type string|table +local y + +x = y +]] + +TEST [[ +---@type table +local x + +---@type table +local y + +x = y +]] + config.remove(nil, 'Lua.diagnostics.disable', 'unused-local') config.remove(nil, 'Lua.diagnostics.disable', 'unused-function') config.remove(nil, 'Lua.diagnostics.disable', 'undefined-global') diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 55b4f7afe..0eae5a8d1 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -242,6 +242,27 @@ table = {} () ]] +TEST 'X' [[ +---@class X + +---@type X +local +]] + +TEST 'X' [[ +---@class X + +---@type X +local +]] + +TEST 'X' [[ +---@class X + +---@type X +local +]] + TEST 'string' [[ _VERSION = 'Lua 5.4' @@ -286,6 +307,11 @@ _VERSION = 'Lua 5.4' = _VERSION.xxx ]] +TEST 'table' [[ +---@type table +local +]] + TEST 'table' [[ = setmetatable({}) ]] @@ -632,10 +658,10 @@ local k, v = next() TEST 'string' [[ ---@class string ----@generic K, V ----@param t table ----@return K ----@return V +---@generic KK, VV +---@param t table +---@return KK +---@return VV local function next(t) end ---@type table @@ -3846,8 +3872,8 @@ TEST 'integer[]' [[ ---@return T[] local function x(f) end ----@param x integer -local = x(function (x) end) +---@param y integer +local = x(function (y) end) ]] TEST 'integer[]' [[ @@ -4298,3 +4324,62 @@ local x = 1 repeat until ]] + +TEST 'integer' [[ +---@generic A, B, C +---@type fun(x: A, y: B, z: C):C, B, A +local f + +local , y, z = f(true, '', 1) +]] + +TEST 'string' [[ +---@generic A, B, C +---@type fun(x: A, y: B, z: C):C, B, A +local f + +local x, , z = f(true, '', 1) +]] + +TEST 'boolean' [[ +---@generic A, B, C +---@type fun(x: A, y: B, z: C):C, B, A +local f + +local x, y, = f(true, '', 1) +]] + +TEST '[]' [[ +---@generic A +---@param x A[] +local function f(x) + local v = [1] +end +]] + +TEST '' [[ +---@generic A +---@param x A[] +local function f(x) + local = x[1] +end +]] + +TEST '' [[ +---@generic A +---@param x A[] +local function f(x) + ---@generic B + ---@param y B[] + ---@return B + local function g(y) end + + local = g(x) +end +]] + +TEST 'unknown' [[ +local t = setmetatable({}, {}) + +local = t[x] +]]