Skip to content

[IDE] Avoid uses of isBeforeInBuffer in TypeCheckASTNodeAtLocRequest #81028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 14, 2025
9 changes: 3 additions & 6 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2523,9 +2523,7 @@ class PatternBindingDecl final : public Decl,
Pattern *Pat, Expr *E,
DeclContext *Parent);

SourceLoc getStartLoc() const {
return StaticLoc.isValid() ? StaticLoc : VarLoc;
}
SourceLoc getStartLoc() const;
SourceRange getSourceRange() const;

unsigned getNumPatternEntries() const {
Expand Down Expand Up @@ -8413,9 +8411,8 @@ class FuncDecl : public AbstractFunctionDecl {
SourceLoc getStaticLoc() const { return StaticLoc; }
SourceLoc getFuncLoc() const { return FuncLoc; }

SourceLoc getStartLoc() const {
return StaticLoc.isValid() ? StaticLoc : FuncLoc;
}
SourceLoc getStartLoc() const;
SourceLoc getEndLoc() const;
SourceRange getSourceRange() const;

TypeRepr *getResultTypeRepr() const { return FnRetType.getTypeRepr(); }
Expand Down
106 changes: 69 additions & 37 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2528,6 +2528,19 @@ StringRef PatternBindingEntry::getInitStringRepresentation(
return extractInlinableText(ctx, init, scratch);
}

SourceLoc PatternBindingDecl::getStartLoc() const {
if (StaticLoc.isValid())
return StaticLoc;

if (VarLoc.isValid())
return VarLoc;

if (getPatternList().empty())
return SourceLoc();

return getPatternList().front().getStartLoc();
}

SourceRange PatternBindingDecl::getSourceRange() const {
SourceLoc startLoc = getStartLoc();
SourceLoc endLoc = getPatternList().empty()
Expand Down Expand Up @@ -10308,30 +10321,20 @@ SourceRange AbstractFunctionDecl::getSignatureSourceRange() const {
if (isImplicit())
return SourceRange();

SourceLoc endLoc;

// name(parameter list...) async throws(E)
SourceRange thrownTypeRange;
if (auto *typeRepr = getThrownTypeRepr())
endLoc = typeRepr->getSourceRange().End;
if (endLoc.isInvalid())
endLoc = getThrowsLoc();
if (endLoc.isInvalid())
endLoc = getAsyncLoc();
thrownTypeRange = typeRepr->getSourceRange();

if (endLoc.isInvalid())
return getParameterListSourceRange();
return SourceRange(getNameLoc(), endLoc);
// name(parameter list...) async throws(E)
return SourceRange::combine(getParameterListSourceRange(), getAsyncLoc(),
getThrowsLoc(), thrownTypeRange);
}

SourceRange AbstractFunctionDecl::getParameterListSourceRange() const {
if (isImplicit())
return SourceRange();

auto endLoc = getParameters()->getSourceRange().End;
if (endLoc.isInvalid())
return getNameLoc();

return SourceRange(getNameLoc(), endLoc);
return SourceRange::combine(getNameLoc(), getParameters()->getSourceRange());
}

std::optional<Fingerprint> AbstractFunctionDecl::getBodyFingerprint() const {
Expand Down Expand Up @@ -11454,33 +11457,62 @@ DestructorDecl *DestructorDecl::getSuperDeinit() const {
return nullptr;
}

SourceRange FuncDecl::getSourceRange() const {
SourceLoc startLoc = getStartLoc();
SourceLoc FuncDecl::getStartLoc() const {
if (StaticLoc)
return StaticLoc;

if (startLoc.isInvalid())
return SourceRange();
if (FuncLoc)
return FuncLoc;

if (getBodyKind() == BodyKind::Unparsed)
return { startLoc, BodyRange.End };
auto nameLoc = getNameLoc();
if (nameLoc)
return nameLoc;

SourceLoc endLoc = getOriginalBodySourceRange().End;
if (endLoc.isInvalid()) {
if (isa<AccessorDecl>(this))
return startLoc;
auto sigStart = getSignatureSourceRange().Start;
if (sigStart)
return sigStart;

if (getBodyKind() == BodyKind::Synthesize)
return SourceRange();
auto resultTyStart = getResultTypeSourceRange().Start;
if (resultTyStart)
return resultTyStart;

endLoc = getGenericTrailingWhereClauseSourceRange().End;
}
if (endLoc.isInvalid())
endLoc = getResultTypeSourceRange().End;
if (endLoc.isInvalid())
endLoc = getSignatureSourceRange().End;
if (endLoc.isInvalid())
endLoc = startLoc;
auto genericWhereStart = getGenericTrailingWhereClauseSourceRange().Start;
if (genericWhereStart)
return genericWhereStart;

return { startLoc, endLoc };
auto bodyStart = getOriginalBodySourceRange().Start;
if (bodyStart)
return bodyStart;

return SourceLoc();
}

SourceLoc FuncDecl::getEndLoc() const {
auto bodyEnd = getOriginalBodySourceRange().End;
if (bodyEnd)
return bodyEnd;

auto genericWhereEnd = getGenericTrailingWhereClauseSourceRange().End;
if (genericWhereEnd)
return genericWhereEnd;

auto resultTyEnd = getResultTypeSourceRange().End;
if (resultTyEnd)
return resultTyEnd;

auto sigEnd = getSignatureSourceRange().End;
if (sigEnd)
return sigEnd;

return getStartLoc();
}

SourceRange FuncDecl::getSourceRange() const {
SourceLoc startLoc = getStartLoc();
if (startLoc.isInvalid())
return SourceRange();

return { startLoc, getEndLoc() };
}

EnumElementDecl::EnumElementDecl(SourceLoc IdentifierLoc, DeclName Name,
Expand Down
2 changes: 1 addition & 1 deletion lib/Basic/SourceLoc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ namespace {

std::optional<unsigned>
SourceManager::findBufferContainingLocInternal(SourceLoc Loc) const {
assert(Loc.isValid());
ASSERT(Loc.isValid());

// If the cache is out-of-date, update it now.
unsigned numBuffers = LLVMSourceMgr.getNumBuffers();
Expand Down
3 changes: 0 additions & 3 deletions lib/IDE/PostfixCompletion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ getClosureActorIsolation(const Solution &S, AbstractClosureExpr *ACE) {
if (auto Ty = target->getClosureContextualType())
return Ty;
}
if (!S.hasType(E)) {
return Type();
}
return getTypeForCompletion(S, E);
};
auto getClosureActorIsolationThunk = [&S](AbstractClosureExpr *ACE) {
Expand Down
2 changes: 1 addition & 1 deletion lib/IDE/TypeCheckCompletionCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
}

if (!S.hasType(Node)) {
assert(false && "Expression wasn't type checked?");
CONDITIONAL_ASSERT(false && "Expression wasn't type checked?");
return nullptr;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9009,7 +9009,7 @@ Parser::parseAbstractFunctionBodyDelayed(AbstractFunctionDecl *AFD) {
auto bodyRange = AFD->getBodySourceRange();
auto BeginParserPosition = getParserPosition(bodyRange.Start,
/*previousLoc*/ SourceLoc());
auto EndLexerState = L->getStateForEndOfTokenLoc(AFD->getEndLoc());
auto EndLexerState = L->getStateForEndOfTokenLoc(bodyRange.End);

// ParserPositionRAII needs a primed parser to restore to.
if (Tok.is(tok::NUM_TOKENS))
Expand Down
2 changes: 1 addition & 1 deletion lib/Parse/ParseExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3639,7 +3639,7 @@ Parser::parseExprCollectionElement(std::optional<bool> &isDictionary) {
} else {
diagnose(Tok, diag::expected_colon_in_dictionary_literal);
Value = makeParserResult(makeParserError(),
new (Context) ErrorExpr(SourceRange()));
new (Context) ErrorExpr(PreviousLoc));
}

// Make a tuple of Key Value pair.
Expand Down
72 changes: 38 additions & 34 deletions lib/Sema/TypeCheckStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2462,12 +2462,29 @@ bool TypeCheckASTNodeAtLocRequest::evaluate(
return MacroWalking::ArgumentsAndExpansion;
}

/// Checks whether the given range, when treated as a character range,
/// contains the searched location.
bool charRangeContainsLoc(SourceRange range) {
if (!range)
return false;

if (SM.isBefore(Loc, range.Start))
return false;

// NOTE: We need to check the character loc here because the target
// loc can be inside the last token of the node. i.e. interpolated
// string.
return SM.isBefore(Loc, Lexer::getLocForEndOfToken(SM, range.End));
}

PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
if (auto *brace = dyn_cast<BraceStmt>(S)) {
auto braceCharRange = Lexer::getCharSourceRangeFromSourceRange(
SM, brace->getSourceRange());
auto braceRange = brace->getSourceRange();
auto braceCharRange = SourceRange(
braceRange.Start, Lexer::getLocForEndOfToken(SM, braceRange.End));

// Unless this brace contains the loc, there's nothing to do.
if (!braceCharRange.contains(Loc))
if (!SM.containsLoc(braceCharRange, Loc))
return Action::SkipNode(S);

// Reset the node found in a parent context if it's not part of this
Expand All @@ -2477,22 +2494,22 @@ bool TypeCheckASTNodeAtLocRequest::evaluate(
// syntactically part of the brace stmt's range but won't be walked as
// a child of the brace stmt.
if (!brace->isImplicit() && FoundNode) {
auto foundNodeCharRange = Lexer::getCharSourceRangeFromSourceRange(
SM, FoundNode->getSourceRange());
if (!braceCharRange.contains(foundNodeCharRange)) {
auto foundRange = FoundNode->getSourceRange();
auto foundCharRange = SourceRange(
foundRange.Start, Lexer::getLocForEndOfToken(SM, foundRange.End));
if (!SM.encloses(braceCharRange, foundCharRange))
FoundNode = nullptr;
}
}

for (ASTNode &node : brace->getElements()) {
if (SM.isBeforeInBuffer(Loc, node.getStartLoc()))
auto range = node.getSourceRange();
if (SM.isBefore(Loc, range.Start))
break;

// NOTE: We need to check the character loc here because the target
// loc can be inside the last token of the node. i.e. interpolated
// string.
SourceLoc endLoc = Lexer::getLocForEndOfToken(SM, node.getEndLoc());
if (SM.isBeforeInBuffer(endLoc, Loc) || endLoc == Loc)
if (!SM.isBefore(Loc, Lexer::getLocForEndOfToken(SM, range.End)))
continue;

// 'node' may be the target node, except 'CaseStmt' which cannot be
Expand All @@ -2509,13 +2526,11 @@ bool TypeCheckASTNodeAtLocRequest::evaluate(
return Action::Stop();
} else if (auto Conditional = dyn_cast<LabeledConditionalStmt>(S)) {
for (StmtConditionElement &Cond : Conditional->getCond()) {
if (SM.isBeforeInBuffer(Loc, Cond.getStartLoc())) {
auto range = Cond.getSourceRange();
if (SM.isBefore(Loc, range.Start))
break;
}
SourceLoc endLoc = Lexer::getLocForEndOfToken(SM, Cond.getEndLoc());
if (SM.isBeforeInBuffer(endLoc, Loc) || endLoc == Loc) {
if (!SM.isBefore(Loc, Lexer::getLocForEndOfToken(SM, range.End)))
continue;
}

FoundNodeStorage = ASTNode(&Cond);
FoundNode = &FoundNodeStorage;
Expand All @@ -2527,11 +2542,7 @@ bool TypeCheckASTNodeAtLocRequest::evaluate(
}

PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
if (SM.isBeforeInBuffer(Loc, E->getStartLoc()))
return Action::SkipNode(E);

SourceLoc endLoc = Lexer::getLocForEndOfToken(SM, E->getEndLoc());
if (SM.isBeforeInBuffer(endLoc, Loc))
if (!charRangeContainsLoc(E->getSourceRange()))
return Action::SkipNode(E);

// Don't walk into 'TapExpr'. They should be type checked with parent
Expand All @@ -2546,9 +2557,7 @@ bool TypeCheckASTNodeAtLocRequest::evaluate(
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
SmallVector<Expr *> scratch;
for (auto *result : SVE->getResultExprs(scratch)) {
auto resultCharRange = Lexer::getCharSourceRangeFromSourceRange(
SM, result->getSourceRange());
if (resultCharRange.contains(Loc)) {
if (charRangeContainsLoc(result->getSourceRange())) {
if (!result->walk(*this))
return Action::Stop();

Expand All @@ -2570,20 +2579,15 @@ bool TypeCheckASTNodeAtLocRequest::evaluate(
}

PreWalkAction walkToDeclPre(Decl *D) override {
if (!charRangeContainsLoc(D->getSourceRange()))
return Action::SkipNode();

if (auto *newDC = dyn_cast<DeclContext>(D))
DC = newDC;

if (!SM.isBeforeInBuffer(Loc, D->getStartLoc())) {
// NOTE: We need to check the character loc here because the target
// loc can be inside the last token of the node. i.e. interpolated
// string.
SourceLoc endLoc = Lexer::getLocForEndOfToken(SM, D->getEndLoc());
if (!(SM.isBeforeInBuffer(endLoc, Loc) || endLoc == Loc)) {
if (!isa<TopLevelCodeDecl>(D)) {
FoundNodeStorage = ASTNode(D);
FoundNode = &FoundNodeStorage;
}
}
if (!isa<TopLevelCodeDecl>(D)) {
FoundNodeStorage = ASTNode(D);
FoundNode = &FoundNodeStorage;
}
return Action::Continue();
}
Expand Down
4 changes: 2 additions & 2 deletions test/Concurrency/async_main_resolution.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ extension MainProtocol {
#endif

// CHECK-IS-SYNC-LABEL: "MyMain" interface_type="MyMain.Type"
// CHECK-IS-SYNC: (func_decl {{.*}}implicit "$main()" interface_type="(MyMain.Type) -> () -> ()"
// CHECK-IS-SYNC: (func_decl {{.*}}implicit range={{.*}} "$main()" interface_type="(MyMain.Type) -> () -> ()"
// CHECK-IS-SYNC: (declref_expr implicit type="(MyMain.Type) -> () -> ()"

// CHECK-IS-ASYNC-LABEL: "MyMain" interface_type="MyMain.Type"
// CHECK-IS-ASYNC: (func_decl {{.*}}implicit "$main()" interface_type="(MyMain.Type) -> () async -> ()"
// CHECK-IS-ASYNC: (func_decl {{.*}}implicit range={{.*}} "$main()" interface_type="(MyMain.Type) -> () async -> ()"
// CHECK-IS-ASYNC: (declref_expr implicit type="(MyMain.Type) -> () async -> ()"

// CHECK-IS-ERROR1: error: 'MyMain' is annotated with '@main' and must provide a main static function of type {{\(\) -> Void or \(\) throws -> Void|\(\) -> Void, \(\) throws -> Void, \(\) async -> Void, or \(\) async throws -> Void}}
Expand Down
6 changes: 3 additions & 3 deletions test/Concurrency/where_clause_main_resolution.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ protocol App {
// CHECK-SAME: interface_type="<Self where Self : App> (Self.Type) -> () async -> ()"

extension App where Configuration == Config1 {
// CHECK-CONFIG1: (func_decl {{.*}}implicit "$main()" interface_type="(MainType.Type) -> () -> ()"
// CHECK-CONFIG1: (func_decl {{.*}}implicit range=[{{.*}}:[[@LINE+20]]:1 - line:[[@LINE+20]]:1] "$main()" interface_type="(MainType.Type) -> () -> ()"
// CHECK-CONFIG1: [[SOURCE_FILE]]:[[# @LINE+1 ]]
static func main() { }
}

extension App where Configuration == Config2 {
// CHECK-CONFIG2: (func_decl {{.*}}implicit "$main()" interface_type="(MainType.Type) -> () async -> ()"
// CHECK-CONFIG2: (func_decl {{.*}}implicit range=[{{.*}}:[[@LINE+14]]:1 - line:[[@LINE+14]]:1] "$main()" interface_type="(MainType.Type) -> () async -> ()"
// CHECK-CONFIG2: [[SOURCE_FILE]]:[[# @LINE+1 ]]
static func main() async { }
}

extension App where Configuration == Config3 {
// CHECK-CONFIG3-ASYNC: (func_decl {{.*}}implicit "$main()" interface_type="(MainType.Type) -> () async -> ()"
// CHECK-CONFIG3-ASYNC: (func_decl {{.*}}implicit range=[{{.*}}:[[@LINE+8]]:1 - line:[[@LINE+8]]:1] "$main()" interface_type="(MainType.Type) -> () async -> ()"
// CHECK-CONFIG3-ASYNC: [[SOURCE_FILE]]:[[DEFAULT_ASYNCHRONOUS_MAIN_LINE]]
}

Expand Down
2 changes: 1 addition & 1 deletion test/attr/ApplicationMain/attr_main_throws.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct MyBase {
}
}

// CHECK-AST: (func_decl {{.*}} implicit "$main()" interface_type="(MyBase.Type) -> () throws -> ()" access=internal static
// CHECK-AST: (func_decl {{.*}} implicit range=[{{.*}}:[[@LINE-6]]:1 - line:[[@LINE-6]]:1] "$main()" interface_type="(MyBase.Type) -> () throws -> ()" access=internal static
// CHECK-AST-NEXT: (parameter "self" {{.*}})
// CHECK-AST-NEXT: (parameter_list)
// CHECK-AST-NEXT: (brace_stmt implicit
Expand Down
2 changes: 1 addition & 1 deletion test/expr/capture/top-level-guard.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ let closureCapture: () -> Void = { [x] in
}

// CHECK-LABEL: (defer_stmt
// CHECK-NEXT: (func_decl{{.*}}implicit "$defer()" interface_type="() -> ()" access=fileprivate captures=(x<direct><noescape>)
// CHECK-NEXT: (func_decl{{.*}}implicit range={{.*}} "$defer()" interface_type="() -> ()" access=fileprivate captures=(x<direct><noescape>)
defer {
_ = x
}
Expand Down
Loading