Skip to content

Commit a970f7e

Browse files
feat(completions): improve completions in WHERE clauses (#403)
1 parent b8c0643 commit a970f7e

File tree

6 files changed

+230
-17
lines changed

6 files changed

+230
-17
lines changed

crates/pgt_completions/src/context/mod.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ impl<'a> CompletionContext<'a> {
237237
executor.add_query_results::<queries::TableAliasMatch>();
238238
executor.add_query_results::<queries::SelectColumnMatch>();
239239
executor.add_query_results::<queries::InsertColumnMatch>();
240+
executor.add_query_results::<queries::WhereColumnMatch>();
240241

241242
for relation_match in executor.get_iter(stmt_range) {
242243
match relation_match {
@@ -251,6 +252,7 @@ impl<'a> CompletionContext<'a> {
251252
})
252253
.or_insert(HashSet::from([table_name]));
253254
}
255+
254256
QueryResult::TableAliases(table_alias_match) => {
255257
self.mentioned_table_aliases.insert(
256258
table_alias_match.get_alias(sql),
@@ -272,6 +274,20 @@ impl<'a> CompletionContext<'a> {
272274
.or_insert(HashSet::from([mentioned]));
273275
}
274276

277+
QueryResult::WhereClauseColumns(c) => {
278+
let mentioned = MentionedColumn {
279+
column: c.get_column(sql),
280+
alias: c.get_alias(sql),
281+
};
282+
283+
self.mentioned_columns
284+
.entry(Some(WrappingClause::Where))
285+
.and_modify(|s| {
286+
s.insert(mentioned.clone());
287+
})
288+
.or_insert(HashSet::from([mentioned]));
289+
}
290+
275291
QueryResult::InsertClauseColumns(c) => {
276292
let mentioned = MentionedColumn {
277293
column: c.get_column(sql),
@@ -359,8 +375,9 @@ impl<'a> CompletionContext<'a> {
359375
let parent_node_kind = parent_node.kind();
360376
let current_node_kind = current_node.kind();
361377

362-
// prevent infinite recursion – this can happen if we only have a PROGRAM node
363-
if current_node_kind == parent_node_kind {
378+
// prevent infinite recursion – this can happen with ERROR nodes
379+
if current_node_kind == parent_node_kind && ["ERROR", "program"].contains(&parent_node_kind)
380+
{
364381
self.node_under_cursor = Some(NodeUnderCursor::from(current_node));
365382
return;
366383
}

crates/pgt_completions/src/providers/columns.rs

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio
2323
};
2424

2525
// autocomplete with the alias in a join clause if we find one
26-
if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. })) {
26+
if matches!(
27+
ctx.wrapping_clause_type,
28+
Some(WrappingClause::Join { .. })
29+
| Some(WrappingClause::Where)
30+
| Some(WrappingClause::Select)
31+
) {
2732
item.completion_text = find_matching_alias_for_table(ctx, col.table_name.as_str())
2833
.and_then(|alias| {
2934
get_completion_text_with_schema_or_alias(ctx, col.name.as_str(), alias.as_str())
@@ -36,6 +41,8 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio
3641

3742
#[cfg(test)]
3843
mod tests {
44+
use std::vec;
45+
3946
use crate::{
4047
CompletionItem, CompletionItemKind, complete,
4148
test_helper::{
@@ -643,4 +650,81 @@ mod tests {
643650
)
644651
.await;
645652
}
653+
654+
#[tokio::test]
655+
async fn suggests_columns_in_where_clause() {
656+
let setup = r#"
657+
create table instruments (
658+
id bigint primary key generated always as identity,
659+
name text not null,
660+
z text,
661+
created_at timestamp with time zone default now()
662+
);
663+
664+
create table others (
665+
a text,
666+
b text,
667+
c text
668+
);
669+
"#;
670+
671+
assert_complete_results(
672+
format!("select name from instruments where {} ", CURSOR_POS).as_str(),
673+
vec![
674+
CompletionAssertion::Label("created_at".into()),
675+
CompletionAssertion::Label("id".into()),
676+
CompletionAssertion::Label("name".into()),
677+
CompletionAssertion::Label("z".into()),
678+
],
679+
setup,
680+
)
681+
.await;
682+
683+
assert_complete_results(
684+
format!(
685+
"select name from instruments where z = 'something' and created_at > {}",
686+
CURSOR_POS
687+
)
688+
.as_str(),
689+
// simply do not complete columns + schemas; functions etc. are ok
690+
vec![
691+
CompletionAssertion::KindNotExists(CompletionItemKind::Column),
692+
CompletionAssertion::KindNotExists(CompletionItemKind::Schema),
693+
],
694+
setup,
695+
)
696+
.await;
697+
698+
// prefers not mentioned columns
699+
assert_complete_results(
700+
format!(
701+
"select name from instruments where id = 'something' and {}",
702+
CURSOR_POS
703+
)
704+
.as_str(),
705+
vec![
706+
CompletionAssertion::Label("created_at".into()),
707+
CompletionAssertion::Label("name".into()),
708+
CompletionAssertion::Label("z".into()),
709+
],
710+
setup,
711+
)
712+
.await;
713+
714+
// // uses aliases
715+
assert_complete_results(
716+
format!(
717+
"select name from instruments i join others o on i.z = o.a where i.{}",
718+
CURSOR_POS
719+
)
720+
.as_str(),
721+
vec![
722+
CompletionAssertion::Label("created_at".into()),
723+
CompletionAssertion::Label("id".into()),
724+
CompletionAssertion::Label("name".into()),
725+
],
726+
setup,
727+
)
728+
.await;
729+
}
646730
}

crates/pgt_completions/src/relevance/filtering.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ impl CompletionFilter<'_> {
119119
.as_ref()
120120
.is_some_and(|n| n == &WrappingNode::List),
121121

122+
// only autocomplete left side of binary expression
123+
WrappingClause::Where => {
124+
ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"])
125+
|| (ctx.before_cursor_matches_kind(&["."])
126+
&& ctx.parent_matches_one_of_kind(&["field"]))
127+
}
128+
122129
_ => true,
123130
}
124131
}
@@ -133,12 +140,15 @@ impl CompletionFilter<'_> {
133140

134141
CompletionRelevanceData::Schema(_) => match clause {
135142
WrappingClause::Select
136-
| WrappingClause::Where
137143
| WrappingClause::From
138144
| WrappingClause::Join { .. }
139145
| WrappingClause::Update
140146
| WrappingClause::Delete => true,
141147

148+
WrappingClause::Where => {
149+
ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"])
150+
}
151+
142152
WrappingClause::DropTable | WrappingClause::AlterTable => ctx
143153
.before_cursor_matches_kind(&[
144154
"keyword_exists",

crates/pgt_completions/src/sanitization.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,6 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool
193193
return false;
194194
}
195195

196-
// not okay to be on the semi.
197-
if byte == leaf_node.start_byte() {
198-
return false;
199-
}
200-
201196
leaf_node
202197
.prev_named_sibling()
203198
.map(|n| n.end_byte() < byte)
@@ -355,19 +350,17 @@ mod tests {
355350
// select * from| ; <-- still touches the from
356351
assert!(!cursor_before_semicolon(&tree, TextSize::new(13)));
357352

358-
// not okay to be ON the semi.
359-
// select * from |;
360-
assert!(!cursor_before_semicolon(&tree, TextSize::new(18)));
361-
362353
// anything is fine here
363-
// select * from | ;
364-
// select * from | ;
365-
// select * from | ;
366-
// select * from |;
354+
// select * from | ;
355+
// select * from | ;
356+
// select * from | ;
357+
// select * from | ;
358+
// select * from |;
367359
assert!(cursor_before_semicolon(&tree, TextSize::new(14)));
368360
assert!(cursor_before_semicolon(&tree, TextSize::new(15)));
369361
assert!(cursor_before_semicolon(&tree, TextSize::new(16)));
370362
assert!(cursor_before_semicolon(&tree, TextSize::new(17)));
363+
assert!(cursor_before_semicolon(&tree, TextSize::new(18)));
371364
}
372365

373366
#[test]

crates/pgt_treesitter_queries/src/queries/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ mod parameters;
33
mod relations;
44
mod select_columns;
55
mod table_aliases;
6+
mod where_columns;
67

78
pub use insert_columns::*;
89
pub use parameters::*;
910
pub use relations::*;
1011
pub use select_columns::*;
1112
pub use table_aliases::*;
13+
pub use where_columns::*;
1214

1315
#[derive(Debug)]
1416
pub enum QueryResult<'a> {
@@ -17,6 +19,7 @@ pub enum QueryResult<'a> {
1719
TableAliases(TableAliasMatch<'a>),
1820
SelectClauseColumns(SelectColumnMatch<'a>),
1921
InsertClauseColumns(InsertColumnMatch<'a>),
22+
WhereClauseColumns(WhereColumnMatch<'a>),
2023
}
2124

2225
impl QueryResult<'_> {
@@ -53,6 +56,16 @@ impl QueryResult<'_> {
5356

5457
start >= range.start_point && end <= range.end_point
5558
}
59+
Self::WhereClauseColumns(cm) => {
60+
let start = match cm.alias {
61+
Some(n) => n.start_position(),
62+
None => cm.column.start_position(),
63+
};
64+
65+
let end = cm.column.end_position();
66+
67+
start >= range.start_point && end <= range.end_point
68+
}
5669
Self::InsertClauseColumns(cm) => {
5770
let start = cm.column.start_position();
5871
let end = cm.column.end_position();
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use std::sync::LazyLock;
2+
3+
use crate::{Query, QueryResult};
4+
5+
use super::QueryTryFrom;
6+
7+
static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
8+
static QUERY_STR: &str = r#"
9+
(where
10+
(binary_expression
11+
(binary_expression
12+
(field
13+
(object_reference)? @alias
14+
"."?
15+
(identifier) @column
16+
)
17+
)
18+
)
19+
)
20+
"#;
21+
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
22+
});
23+
24+
#[derive(Debug)]
25+
pub struct WhereColumnMatch<'a> {
26+
pub(crate) alias: Option<tree_sitter::Node<'a>>,
27+
pub(crate) column: tree_sitter::Node<'a>,
28+
}
29+
30+
impl WhereColumnMatch<'_> {
31+
pub fn get_alias(&self, sql: &str) -> Option<String> {
32+
let str = self
33+
.alias
34+
.as_ref()?
35+
.utf8_text(sql.as_bytes())
36+
.expect("Failed to get alias from ColumnMatch");
37+
38+
Some(str.to_string())
39+
}
40+
41+
pub fn get_column(&self, sql: &str) -> String {
42+
self.column
43+
.utf8_text(sql.as_bytes())
44+
.expect("Failed to get column from ColumnMatch")
45+
.to_string()
46+
}
47+
}
48+
49+
impl<'a> TryFrom<&'a QueryResult<'a>> for &'a WhereColumnMatch<'a> {
50+
type Error = String;
51+
52+
fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
53+
match q {
54+
QueryResult::WhereClauseColumns(c) => Ok(c),
55+
56+
#[allow(unreachable_patterns)]
57+
_ => Err("Invalid QueryResult type".into()),
58+
}
59+
}
60+
}
61+
62+
impl<'a> QueryTryFrom<'a> for WhereColumnMatch<'a> {
63+
type Ref = &'a WhereColumnMatch<'a>;
64+
}
65+
66+
impl<'a> Query<'a> for WhereColumnMatch<'a> {
67+
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
68+
let mut cursor = tree_sitter::QueryCursor::new();
69+
70+
let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());
71+
72+
let mut to_return = vec![];
73+
74+
for m in matches {
75+
if m.captures.len() == 1 {
76+
let capture = m.captures[0].node;
77+
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
78+
alias: None,
79+
column: capture,
80+
}));
81+
}
82+
83+
if m.captures.len() == 2 {
84+
let alias = m.captures[0].node;
85+
let column = m.captures[1].node;
86+
87+
to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch {
88+
alias: Some(alias),
89+
column,
90+
}));
91+
}
92+
}
93+
94+
to_return
95+
}
96+
}

0 commit comments

Comments
 (0)