From bd034e7151d3f043ca51c9642289af2488b9e8ad Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 17 May 2025 14:38:30 +0200 Subject: [PATCH 01/13] revert logic to whitelist --- crates/pgt_completions/src/context.rs | 12 ++- .../src/relevance/filtering.rs | 101 +++++++++++------- crates/pgt_completions/src/test_helper.rs | 2 + 3 files changed, 77 insertions(+), 38 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index a17cafa2..2fcb49b1 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -18,6 +18,7 @@ pub enum WrappingClause<'a> { }, Update, Delete, + ColumnDefinitions, } #[derive(PartialEq, Eq, Debug)] @@ -76,6 +77,7 @@ impl TryFrom for WrappingNode { } } +#[derive(Debug)] pub(crate) struct CompletionContext<'a> { pub node_under_cursor: Option>, @@ -140,6 +142,13 @@ impl<'a> CompletionContext<'a> { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); + if cfg!(test) { + println!("{:#?}", ctx.wrapping_clause_type); + println!("{:#?}", ctx.wrapping_node_kind); + println!("{:#?}", ctx.is_in_error_node); + println!("{:#?}", ctx.text); + } + ctx } @@ -303,7 +312,7 @@ impl<'a> CompletionContext<'a> { } } - "where" | "update" | "select" | "delete" | "from" | "join" => { + "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } @@ -367,6 +376,7 @@ impl<'a> CompletionContext<'a> { "select" => Some(WrappingClause::Select), "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), + "column_definitions" => Some(WrappingClause::ColumnDefinitions), "join" => { // sadly, we need to manually iterate over the children – // `node.child_by_field_id(..)` does not work as expected diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index ec12201c..ad520574 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -24,6 +24,13 @@ impl CompletionFilter<'_> { } fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { + if ctx.wrapping_node_kind.is_none() + && ctx.wrapping_clause_type.is_none() + && ctx.is_in_error_node + { + return None; + } + let current_node_kind = ctx.node_under_cursor.map(|n| n.kind()).unwrap_or(""); if current_node_kind.starts_with("keyword_") @@ -58,44 +65,53 @@ impl CompletionFilter<'_> { } fn check_clause(&self, ctx: &CompletionContext) -> Option<()> { - let clause = ctx.wrapping_clause_type.as_ref(); - - match self.data { - CompletionRelevanceData::Table(_) => { - let in_select_clause = clause.is_some_and(|c| c == &WrappingClause::Select); - let in_where_clause = clause.is_some_and(|c| c == &WrappingClause::Where); - - if in_select_clause || in_where_clause { - return None; - }; - } - CompletionRelevanceData::Column(_) => { - let in_from_clause = clause.is_some_and(|c| c == &WrappingClause::From); - if in_from_clause { - return None; - } - - // We can complete columns in JOIN cluases, but only if we are after the - // ON node in the "ON u.id = posts.user_id" part. - let in_join_clause_before_on_node = clause.is_some_and(|c| match c { - // we are in a JOIN, but definitely not after an ON - WrappingClause::Join { on_node: None } => true, - - WrappingClause::Join { on_node: Some(on) } => ctx - .node_under_cursor - .is_some_and(|n| n.end_byte() < on.start_byte()), - - _ => false, - }); - - if in_join_clause_before_on_node { - return None; + ctx.wrapping_clause_type + .as_ref() + .map(|clause| { + match self.data { + CompletionRelevanceData::Table(_) => match clause { + WrappingClause::Select + | WrappingClause::Where + | WrappingClause::ColumnDefinitions => false, + _ => true, + }, + CompletionRelevanceData::Column(_) => { + match clause { + WrappingClause::From | WrappingClause::ColumnDefinitions => false, + + // We can complete columns in JOIN cluases, but only if we are after the + // ON node in the "ON u.id = posts.user_id" part. + WrappingClause::Join { on_node: Some(on) } => ctx + .node_under_cursor + .is_some_and(|cn| cn.start_byte() >= on.end_byte()), + + // we are in a JOIN, but definitely not after an ON + WrappingClause::Join { on_node: None } => false, + + _ => true, + } + } + CompletionRelevanceData::Function(_) => match clause { + WrappingClause::From + | WrappingClause::Select + | WrappingClause::Where + | WrappingClause::Join { .. } => true, + + _ => false, + }, + CompletionRelevanceData::Schema(_) => match clause { + WrappingClause::Select + | WrappingClause::Where + | WrappingClause::From + | WrappingClause::Join { .. } + | WrappingClause::Update + | WrappingClause::Delete => true, + + WrappingClause::ColumnDefinitions => false, + }, } - } - _ => {} - } - - Some(()) + }) + .and_then(|is_ok| if is_ok { Some(()) } else { None }) } fn check_invocation(&self, ctx: &CompletionContext) -> Option<()> { @@ -170,4 +186,15 @@ mod tests { ) .await; } + + #[tokio::test] + async fn completion_after_create_table() { + assert_no_complete_results(format!("create table {}", CURSOR_POS).as_str(), "").await; + } + + #[tokio::test] + async fn completion_in_column_definitions() { + let query = format!(r#"create table instruments ( {} )"#, CURSOR_POS); + assert_no_complete_results(query.as_str(), "").await; + } } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index 937c11af..f3d5c2bf 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -244,6 +244,8 @@ pub(crate) async fn assert_complete_results( pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) { let (tree, cache) = get_test_deps(setup, query.into()).await; let params = get_test_params(&tree, &cache, query.into()); + println!("{:#?}", params.position); + println!("{:#?}", params.text); let items = complete(params); assert_eq!(items.len(), 0) From d9f1c8d3d662f21e502d50ac04cd189d55c84ceb Mon Sep 17 00:00:00 2001 From: Julian Date: Mon, 19 May 2025 09:58:23 +0200 Subject: [PATCH 02/13] not sure about this --- crates/pgt_completions/src/context.rs | 150 ++++++++++++------ .../pgt_completions/src/providers/columns.rs | 20 +++ .../src/relevance/filtering.rs | 28 +++- crates/pgt_completions/src/sanitization.rs | 25 ++- 4 files changed, 173 insertions(+), 50 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 2fcb49b1..274b6e79 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -1,4 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + cmp, + collections::{HashMap, HashSet}, +}; use pgt_schema_cache::SchemaCache; use pgt_treesitter_queries::{ @@ -8,7 +11,7 @@ use pgt_treesitter_queries::{ use crate::sanitization::SanitizedCompletionParams; -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum WrappingClause<'a> { Select, Where, @@ -19,6 +22,7 @@ pub enum WrappingClause<'a> { Update, Delete, ColumnDefinitions, + Insert, } #[derive(PartialEq, Eq, Debug)] @@ -46,6 +50,7 @@ pub enum WrappingNode { Relation, BinaryExpression, Assignment, + List, } impl TryFrom<&str> for WrappingNode { @@ -56,6 +61,7 @@ impl TryFrom<&str> for WrappingNode { "relation" => Ok(Self::Relation), "assignment" => Ok(Self::Assignment), "binary_expression" => Ok(Self::BinaryExpression), + "list" => Ok(Self::List), _ => { let message = format!("Unimplemented Relation: {}", value); @@ -142,13 +148,6 @@ impl<'a> CompletionContext<'a> { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); - if cfg!(test) { - println!("{:#?}", ctx.wrapping_clause_type); - println!("{:#?}", ctx.wrapping_node_kind); - println!("{:#?}", ctx.is_in_error_node); - println!("{:#?}", ctx.text); - } - ctx } @@ -240,10 +239,20 @@ impl<'a> CompletionContext<'a> { * `select * from use {}` becomes `select * from use{}`. */ let current_node = cursor.node(); - while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { - self.position -= 1; + + let mut chars = self.text.chars(); + + if chars + .nth(self.position) + .is_some_and(|c| !c.is_ascii_whitespace()) + { + self.position = cmp::min(self.position + 1, self.text.len()); + } else { + self.position = cmp::min(self.position, self.text.len()); } + cursor.goto_first_child_for_byte(self.position); + self.gather_context_from_node(cursor, current_node); } @@ -276,23 +285,11 @@ impl<'a> CompletionContext<'a> { // try to gather context from the siblings if we're within an error node. if self.is_in_error_node { - let mut next_sibling = current_node.next_named_sibling(); - while let Some(n) = next_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - next_sibling = n.next_named_sibling(); - } + if let Some(clause_type) = self.get_wrapping_clause_from_siblings(current_node) { + self.wrapping_clause_type = Some(clause_type); } - let mut prev_sibling = current_node.prev_named_sibling(); - while let Some(n) = prev_sibling { - if let Some(clause_type) = self.get_wrapping_clause_from_keyword_node(n) { - self.wrapping_clause_type = Some(clause_type); - break; - } else { - prev_sibling = n.prev_named_sibling(); - } + if let Some(wrapping_node) = self.get_wrapping_node_from_siblings(current_node) { + self.wrapping_node_kind = Some(wrapping_node) } } @@ -317,7 +314,7 @@ impl<'a> CompletionContext<'a> { self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } - "relation" | "binary_expression" | "assignment" => { + "relation" | "binary_expression" | "assignment" | "list" => { self.wrapping_node_kind = current_node_kind.try_into().ok(); } @@ -338,31 +335,89 @@ impl<'a> CompletionContext<'a> { self.gather_context_from_node(cursor, current_node); } - fn get_wrapping_clause_from_keyword_node( + fn get_first_sibling(&self, node: tree_sitter::Node<'a>) -> tree_sitter::Node<'a> { + let mut first_sibling = node; + while let Some(n) = first_sibling.prev_sibling() { + first_sibling = n; + } + first_sibling + } + + fn get_wrapping_node_from_siblings(&self, node: tree_sitter::Node<'a>) -> Option { + self.wrapping_clause_type + .as_ref() + .and_then(|clause| match clause { + WrappingClause::Insert => { + if node.prev_sibling().is_some_and(|n| n.kind() == "(") + || node.next_sibling().is_some_and(|n| n.kind() == ")") + { + Some(WrappingNode::List) + } else { + None + } + } + _ => None, + }) + } + + fn get_wrapping_clause_from_siblings( &self, node: tree_sitter::Node<'a>, ) -> Option> { - if node.kind().starts_with("keyword_") { - if let Some(txt) = self.get_ts_node_content(node).and_then(|txt| match txt { - NodeText::Original(txt) => Some(txt), - NodeText::Replaced => None, - }) { - match txt { - "where" => return Some(WrappingClause::Where), - "update" => return Some(WrappingClause::Update), - "select" => return Some(WrappingClause::Select), - "delete" => return Some(WrappingClause::Delete), - "from" => return Some(WrappingClause::From), - "join" => { - // TODO: not sure if we can infer it here. - return Some(WrappingClause::Join { on_node: None }); + let clause_combinations: Vec<(WrappingClause, &[&'static str])> = vec![ + (WrappingClause::Where, &["where"]), + (WrappingClause::Update, &["update"]), + (WrappingClause::Select, &["select"]), + (WrappingClause::Delete, &["delete"]), + (WrappingClause::Insert, &["insert", "into"]), + (WrappingClause::From, &["from"]), + (WrappingClause::Join { on_node: None }, &["join"]), + ]; + + let first_sibling = self.get_first_sibling(node); + + /* + * For each clause, we'll iterate from first_sibling to the next ones, + * either until the end or until we land on the node under the cursor. + * We'll score the `WrappingClause` by how many tokens it matches in order. + */ + let mut clauses_with_score: Vec<(WrappingClause, usize)> = clause_combinations + .into_iter() + .map(|(clause, tokens)| { + let mut idx = 0; + + let mut sibling = Some(first_sibling); + while let Some(sib) = sibling { + if sib.end_byte() >= node.end_byte() || idx >= tokens.len() { + break; } - _ => {} + + if let Some(sibling_content) = + self.get_ts_node_content(sib).and_then(|txt| match txt { + NodeText::Original(txt) => Some(txt), + NodeText::Replaced => None, + }) + { + if sibling_content == tokens[idx] { + idx += 1; + } + } else { + break; + } + + sibling = sib.next_sibling(); } - }; - } - None + (clause, idx) + }) + .collect(); + + clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a)); + clauses_with_score + .iter() + .filter(|(_, score)| *score > 0) + .next() + .map(|c| c.0.clone()) } fn get_wrapping_clause_from_current_node( @@ -377,6 +432,7 @@ impl<'a> CompletionContext<'a> { "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), "column_definitions" => Some(WrappingClause::ColumnDefinitions), + "insert" => Some(WrappingClause::Insert), "join" => { // sadly, we need to manually iterate over the children – // `node.child_by_field_id(..)` does not work as expected diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 8109ba83..9dc7bfa9 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -573,4 +573,24 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_columns_in_insert_clause() { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null + ); + "#; + + assert_complete_results( + format!("insert into instruments ({})", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("name".to_string()), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index ad520574..263880b3 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,4 +1,4 @@ -use crate::context::{CompletionContext, WrappingClause}; +use crate::context::{CompletionContext, WrappingClause, WrappingNode}; use super::CompletionRelevanceData; @@ -73,6 +73,17 @@ impl CompletionFilter<'_> { WrappingClause::Select | WrappingClause::Where | WrappingClause::ColumnDefinitions => false, + + WrappingClause::Insert => { + ctx.wrapping_node_kind + .as_ref() + .is_some_and(|n| n != &WrappingNode::List) + && ctx.node_under_cursor.is_some_and(|n| { + n.prev_sibling() + .is_some_and(|sib| sib.kind() == "keyword_into") + }) + } + _ => true, }, CompletionRelevanceData::Column(_) => { @@ -88,6 +99,11 @@ impl CompletionFilter<'_> { // we are in a JOIN, but definitely not after an ON WrappingClause::Join { on_node: None } => false, + WrappingClause::Insert => ctx + .wrapping_node_kind + .as_ref() + .is_some_and(|n| n == &WrappingNode::List), + _ => true, } } @@ -107,6 +123,16 @@ impl CompletionFilter<'_> { | WrappingClause::Update | WrappingClause::Delete => true, + WrappingClause::Insert => { + ctx.wrapping_node_kind + .as_ref() + .is_some_and(|n| n != &WrappingNode::List) + && ctx.node_under_cursor.is_some_and(|n| { + n.prev_sibling() + .is_some_and(|sib| sib.kind() == "keyword_into") + }) + } + WrappingClause::ColumnDefinitions => false, }, } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 248a0ffa..0f8a968b 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -25,6 +25,7 @@ where || cursor_prepared_to_write_token_after_last_node(params.tree, params.position) || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(¶ms.text, params.position) + || cursor_between_parentheses(¶ms.text, params.position) { SanitizedCompletionParams::with_adjusted_sql(params) } else { @@ -200,13 +201,19 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool .unwrap_or(false) } +fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + let mut chars = sql.chars(); + chars.nth(position - 1).is_some_and(|c| c == '(') && chars.next().is_some_and(|c| c == ')') +} + #[cfg(test)] mod tests { use pgt_text_size::TextSize; use crate::sanitization::{ - cursor_before_semicolon, cursor_inbetween_nodes, cursor_on_a_dot, - cursor_prepared_to_write_token_after_last_node, + cursor_before_semicolon, cursor_between_parentheses, cursor_inbetween_nodes, + cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node, }; #[test] @@ -317,4 +324,18 @@ mod tests { assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); } + + #[test] + fn between_parentheses() { + let input = "insert into instruments ()"; + + // insert into (|) <- right in the parentheses + assert!(cursor_between_parentheses(input, TextSize::new(25))); + + // insert into ()| <- too late + assert!(!cursor_between_parentheses(input, TextSize::new(26))); + + // insert into |() <- too early + assert!(!cursor_between_parentheses(input, TextSize::new(24))); + } } From 66df5c69fcda8561305daa439a5e37e9ada509fb Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 20 May 2025 07:54:30 +0200 Subject: [PATCH 03/13] ok --- crates/pgt_completions/src/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index 274b6e79..a236a0cd 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -244,7 +244,7 @@ impl<'a> CompletionContext<'a> { if chars .nth(self.position) - .is_some_and(|c| !c.is_ascii_whitespace()) + .is_some_and(|c| !c.is_ascii_whitespace() && c != ';') { self.position = cmp::min(self.position + 1, self.text.len()); } else { From d022e2ef958ca118a0d9ca2c2206a028a95ef8d0 Mon Sep 17 00:00:00 2001 From: Julian Date: Tue, 20 May 2025 09:01:09 +0200 Subject: [PATCH 04/13] alter table statements --- crates/pgt_completions/src/context.rs | 47 ++++++++++--- .../pgt_completions/src/providers/tables.rs | 69 +++++++++++++++++++ .../src/relevance/filtering.rs | 34 +++++---- 3 files changed, 127 insertions(+), 23 deletions(-) diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index a236a0cd..55751794 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -23,6 +23,8 @@ pub enum WrappingClause<'a> { Delete, ColumnDefinitions, Insert, + AlterTable, + DropTable, } #[derive(PartialEq, Eq, Debug)] @@ -118,9 +120,6 @@ pub(crate) struct CompletionContext<'a> { pub is_invocation: bool, pub wrapping_statement_range: Option, - /// Some incomplete statements can't be correctly parsed by TreeSitter. - pub is_in_error_node: bool, - pub mentioned_relations: HashMap, HashSet>, pub mentioned_table_aliases: HashMap, pub mentioned_columns: HashMap>, HashSet>, @@ -142,12 +141,19 @@ impl<'a> CompletionContext<'a> { mentioned_relations: HashMap::new(), mentioned_table_aliases: HashMap::new(), mentioned_columns: HashMap::new(), - is_in_error_node: false, }; ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); + // if cfg!(test) { + // println!("{:?}", ctx.position); + // println!("{:?}", ctx.text); + // println!("{:?}", ctx.wrapping_clause_type); + // println!("{:?}", ctx.wrapping_node_kind); + // println!("{:?}", ctx.before_cursor_matches_kind(&["keyword_table"])); + // } + ctx } @@ -284,7 +290,7 @@ impl<'a> CompletionContext<'a> { } // try to gather context from the siblings if we're within an error node. - if self.is_in_error_node { + if parent_node_kind == "ERROR" { if let Some(clause_type) = self.get_wrapping_clause_from_siblings(current_node) { self.wrapping_clause_type = Some(clause_type); } @@ -309,7 +315,8 @@ impl<'a> CompletionContext<'a> { } } - "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" => { + "where" | "update" | "select" | "delete" | "from" | "join" | "column_definitions" + | "drop_table" | "alter_table" => { self.wrapping_clause_type = self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } @@ -318,10 +325,6 @@ impl<'a> CompletionContext<'a> { self.wrapping_node_kind = current_node_kind.try_into().ok(); } - "ERROR" => { - self.is_in_error_node = true; - } - _ => {} } @@ -372,6 +375,16 @@ impl<'a> CompletionContext<'a> { (WrappingClause::Insert, &["insert", "into"]), (WrappingClause::From, &["from"]), (WrappingClause::Join { on_node: None }, &["join"]), + (WrappingClause::AlterTable, &["alter", "table"]), + ( + WrappingClause::AlterTable, + &["alter", "table", "if", "exists"], + ), + (WrappingClause::DropTable, &["drop", "table"]), + ( + WrappingClause::DropTable, + &["drop", "table", "if", "exists"], + ), ]; let first_sibling = self.get_first_sibling(node); @@ -431,6 +444,8 @@ impl<'a> CompletionContext<'a> { "select" => Some(WrappingClause::Select), "delete" => Some(WrappingClause::Delete), "from" => Some(WrappingClause::From), + "drop_table" => Some(WrappingClause::DropTable), + "alter_table" => Some(WrappingClause::AlterTable), "column_definitions" => Some(WrappingClause::ColumnDefinitions), "insert" => Some(WrappingClause::Insert), "join" => { @@ -449,6 +464,18 @@ impl<'a> CompletionContext<'a> { _ => None, } } + + pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool { + self.node_under_cursor.is_some_and(|mut node| { + // move up to the parent until we're at top OR we have a prev sibling + while node.prev_sibling().is_none() && node.parent().is_some() { + node = node.parent().unwrap(); + } + + node.prev_sibling() + .is_some_and(|sib| kinds.contains(&sib.kind())) + }) + } } #[cfg(test)] diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 57195da7..217db91f 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -310,4 +310,73 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_tables_in_alter_and_drop_statements() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + + create table auth.posts ( + pid serial primary key, + user_id int not null references auth.users(uid), + title text not null, + content text, + created_at timestamp default now() + ); + "#; + + assert_complete_results( + format!("alter table {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("alter table if exists {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("drop table {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + + assert_complete_results( + format!("drop table if exists {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("posts".into(), CompletionItemKind::Table), // self-join + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 263880b3..c237a4dc 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -24,10 +24,7 @@ impl CompletionFilter<'_> { } fn completable_context(&self, ctx: &CompletionContext) -> Option<()> { - if ctx.wrapping_node_kind.is_none() - && ctx.wrapping_clause_type.is_none() - && ctx.is_in_error_node - { + if ctx.wrapping_node_kind.is_none() && ctx.wrapping_clause_type.is_none() { return None; } @@ -78,17 +75,24 @@ impl CompletionFilter<'_> { ctx.wrapping_node_kind .as_ref() .is_some_and(|n| n != &WrappingNode::List) - && ctx.node_under_cursor.is_some_and(|n| { - n.prev_sibling() - .is_some_and(|sib| sib.kind() == "keyword_into") - }) + && ctx.before_cursor_matches_kind(&["keyword_into"]) } + WrappingClause::DropTable | WrappingClause::AlterTable => ctx + .before_cursor_matches_kind(&[ + "keyword_exists", + "keyword_only", + "keyword_table", + ]), + _ => true, }, CompletionRelevanceData::Column(_) => { match clause { - WrappingClause::From | WrappingClause::ColumnDefinitions => false, + WrappingClause::From + | WrappingClause::ColumnDefinitions + | WrappingClause::AlterTable + | WrappingClause::DropTable => false, // We can complete columns in JOIN cluases, but only if we are after the // ON node in the "ON u.id = posts.user_id" part. @@ -123,14 +127,18 @@ impl CompletionFilter<'_> { | WrappingClause::Update | WrappingClause::Delete => true, + WrappingClause::DropTable | WrappingClause::AlterTable => ctx + .before_cursor_matches_kind(&[ + "keyword_exists", + "keyword_only", + "keyword_table", + ]), + WrappingClause::Insert => { ctx.wrapping_node_kind .as_ref() .is_some_and(|n| n != &WrappingNode::List) - && ctx.node_under_cursor.is_some_and(|n| { - n.prev_sibling() - .is_some_and(|sib| sib.kind() == "keyword_into") - }) + && ctx.before_cursor_matches_kind(&["keyword_into"]) } WrappingClause::ColumnDefinitions => false, From 8689a36f45fe636c65f403e2476dfb2e62a95f8e Mon Sep 17 00:00:00 2001 From: Julian Date: Thu, 22 May 2025 09:39:25 +0200 Subject: [PATCH 05/13] fix --- crates/pgt_completions/src/context/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index d67e71e0..7ae5ab27 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -560,7 +560,8 @@ impl<'a> CompletionContext<'a> { current = current.parent().unwrap(); } - node.prev_sibling() + current + .prev_sibling() .is_some_and(|sib| kinds.contains(&sib.kind())) } From 05c95b6b48994c00ed5ec6cb6f06989c24f84748 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 13:04:41 +0200 Subject: [PATCH 06/13] wow it works! --- crates/pgt_completions/src/context/mod.rs | 91 +++++++++++++++--- .../pgt_completions/src/providers/columns.rs | 30 +++++- crates/pgt_completions/src/sanitization.rs | 95 ++++++++++++++++++- crates/pgt_lsp/src/capabilities.rs | 2 +- 4 files changed, 202 insertions(+), 16 deletions(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 7ae5ab27..ee40dc41 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -35,7 +35,7 @@ pub enum WrappingClause<'a> { ToRoleAssignment, } -#[derive(PartialEq, Eq, Hash, Debug)] +#[derive(PartialEq, Eq, Hash, Debug, Clone)] pub(crate) struct MentionedColumn { pub(crate) column: String, pub(crate) alias: Option, @@ -194,14 +194,6 @@ impl<'a> CompletionContext<'a> { ctx.gather_info_from_ts_queries(); } - // if cfg!(test) { - // println!("{:?}", ctx.position); - // println!("{:?}", ctx.text); - // println!("{:?}", ctx.wrapping_clause_type); - // println!("{:?}", ctx.wrapping_node_kind); - // println!("{:?}", ctx.before_cursor_matches_kind(&["keyword_table"])); - // } - ctx } @@ -334,7 +326,7 @@ impl<'a> CompletionContext<'a> { if chars .nth(self.position) - .is_some_and(|c| !c.is_ascii_whitespace() && c != ';') + .is_some_and(|c| !c.is_ascii_whitespace() && !&[';', ')'].contains(&c)) { self.position = cmp::min(self.position + 1, self.text.len()); } else { @@ -381,6 +373,8 @@ impl<'a> CompletionContext<'a> { if let Some(wrapping_node) = self.get_wrapping_node_from_siblings(current_node) { self.wrapping_node_kind = Some(wrapping_node) } + + self.get_info_from_error_node_child(current_node); } match current_node_kind { @@ -435,9 +429,28 @@ impl<'a> CompletionContext<'a> { .as_ref() .and_then(|clause| match clause { WrappingClause::Insert => { - if node.prev_sibling().is_some_and(|n| n.kind() == "(") - || node.next_sibling().is_some_and(|n| n.kind() == ")") - { + let mut first_sib = self.get_first_sibling(node); + + let mut after_opening_bracket = false; + let mut before_closing_bracket = false; + + while let Some(next_sib) = first_sib.next_sibling() { + if next_sib.kind() == "(" + && next_sib.end_position() <= node.start_position() + { + after_opening_bracket = true; + } + + if next_sib.kind() == ")" + && next_sib.start_position() >= node.end_position() + { + before_closing_bracket = true; + } + + first_sib = next_sib; + } + + if after_opening_bracket && before_closing_bracket { Some(WrappingNode::List) } else { None @@ -517,6 +530,58 @@ impl<'a> CompletionContext<'a> { .map(|c| c.0.clone()) } + fn get_info_from_error_node_child(&mut self, node: tree_sitter::Node<'a>) { + let mut first_sibling = self.get_first_sibling(node); + + if let Some(clause) = self.wrapping_clause_type.as_ref() { + match clause { + WrappingClause::Insert => { + while let Some(sib) = first_sibling.next_sibling() { + match sib.kind() { + "object_reference" => { + if let Some(NodeText::Original(txt)) = + self.get_ts_node_content(&sib) + { + let mut iter = txt.split('.').rev(); + let table = iter.next().unwrap().to_string(); + let schema = iter.next().map(|s| s.to_string()); + self.mentioned_relations + .entry(schema) + .and_modify(|s| { + s.insert(table.clone()); + }) + .or_insert(HashSet::from([table])); + } + } + "column" => { + if let Some(NodeText::Original(txt)) = + self.get_ts_node_content(&sib) + { + let entry = MentionedColumn { + column: txt, + alias: None, + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Insert)) + .and_modify(|s| { + s.insert(entry.clone()); + }) + .or_insert(HashSet::from([entry])); + } + } + + _ => {} + } + + first_sibling = sib; + } + } + _ => {} + } + } + } + fn get_wrapping_clause_from_current_node( &self, node: tree_sitter::Node<'a>, diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 9dc7bfa9..5926bbe2 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -579,18 +579,46 @@ mod tests { let setup = r#" create table instruments ( id bigint primary key generated always as identity, - name text not null + name text not null, + z text + ); + + create table others ( + id serial primary key, + a text, + b text ); "#; + // We should prefer the instrument columns, even though they + // are lower in the alphabet + assert_complete_results( format!("insert into instruments ({})", CURSOR_POS).as_str(), vec![ CompletionAssertion::Label("id".to_string()), CompletionAssertion::Label("name".to_string()), + CompletionAssertion::Label("z".to_string()), ], setup, ) .await; + + assert_complete_results( + format!("insert into instruments (id, {})", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("name".to_string()), + CompletionAssertion::Label("z".to_string()), + ], + setup, + ) + .await; + + assert_complete_results( + format!("insert into instruments (id, {}, name)", CURSOR_POS).as_str(), + vec![CompletionAssertion::Label("z".to_string())], + setup, + ) + .await; } } diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 75876887..8b1cae6e 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -206,8 +206,64 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool { let position: usize = position.into(); + + let mut level = 0; + let mut tracking_open_idx = None; + + let mut matching_open_idx = None; + let mut matching_close_idx = None; + + for (idx, char) in sql.chars().enumerate() { + if char == '(' { + tracking_open_idx = Some(idx); + level += 1; + } + + if char == ')' { + level -= 1; + + if tracking_open_idx.is_some_and(|it| it < position) && idx >= position { + matching_open_idx = tracking_open_idx; + matching_close_idx = Some(idx) + } + } + } + + // invalid statement + if level != 0 { + return false; + } + + // early check: '(|)' + // however, we want to check this after the level nesting. let mut chars = sql.chars(); - chars.nth(position - 1).is_some_and(|c| c == '(') && chars.next().is_some_and(|c| c == ')') + if chars.nth(position - 1).is_some_and(|c| c == '(') && chars.next().is_some_and(|c| c == ')') { + return true; + } + + // not *within* parentheses + if matching_open_idx.is_none() || matching_close_idx.is_none() { + return false; + } + + // use string indexing, because we can't `.rev()` after `.take()` + let before = sql[..position] + .to_string() + .chars() + .rev() + .find(|c| !c.is_whitespace()) + .unwrap_or_default(); + + let after = sql + .chars() + .skip(position) + .find(|c| !c.is_whitespace()) + .unwrap_or_default(); + + let before_matches = before == ',' || before == '('; + let after_matches = after == ',' || after == ')'; + + before_matches && after_matches } #[cfg(test)] @@ -326,5 +382,42 @@ mod tests { // insert into |() <- too early assert!(!cursor_between_parentheses(input, TextSize::new(24))); + + let input = "insert into instruments (name, id, )"; + + // insert into instruments (name, id, |) <-- we should sanitize the next column + assert!(cursor_between_parentheses(input, TextSize::new(35))); + + // insert into instruments (name, id|, ) <-- we are still on the previous token. + assert!(!cursor_between_parentheses(input, TextSize::new(33))); + + let input = "insert into instruments (name, , id)"; + + // insert into instruments (name, |, id) <-- we can sanitize! + assert!(cursor_between_parentheses(input, TextSize::new(31))); + + // insert into instruments (name, ,| id) <-- we are already on the next token + assert!(!cursor_between_parentheses(input, TextSize::new(32))); + + let input = "insert into instruments (, name, id)"; + + // insert into instruments (|, name, id) <-- we can sanitize! + assert!(cursor_between_parentheses(input, TextSize::new(25))); + + // insert into instruments (,| name, id) <-- already on next token + assert!(!cursor_between_parentheses(input, TextSize::new(26))); + + // bails on invalidly nested statements + assert!(!cursor_between_parentheses( + "insert into (instruments ()", + TextSize::new(26) + )); + + // can find its position in nested statements + // "insert into instruments (name) values (a_function(name, |))", + assert!(cursor_between_parentheses( + "insert into instruments (name) values (a_function(name, ))", + TextSize::new(56) + )); } } diff --git a/crates/pgt_lsp/src/capabilities.rs b/crates/pgt_lsp/src/capabilities.rs index b3e35b69..acfc60ed 100644 --- a/crates/pgt_lsp/src/capabilities.rs +++ b/crates/pgt_lsp/src/capabilities.rs @@ -37,7 +37,7 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa // The request is used to get more information about a simple CompletionItem. resolve_provider: None, - trigger_characters: Some(vec![".".to_owned(), " ".to_owned()]), + trigger_characters: Some(vec![".".to_owned(), " ".to_owned(), "(".to_owned()]), // No character will lead to automatically inserting the selected completion-item all_commit_characters: None, From a98f99616b360fabf75fcfd8b21af8cbd81d4307 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 13:14:08 +0200 Subject: [PATCH 07/13] ok --- crates/pgt_completions/src/context/mod.rs | 24 +++++++++++++++---- .../pgt_completions/src/providers/columns.rs | 11 +++++++-- crates/pgt_completions/src/test_helper.rs | 2 -- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index ee40dc41..d91a046d 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -367,10 +367,12 @@ impl<'a> CompletionContext<'a> { // try to gather context from the siblings if we're within an error node. if parent_node_kind == "ERROR" { - if let Some(clause_type) = self.get_wrapping_clause_from_siblings(current_node) { + if let Some(clause_type) = self.get_wrapping_clause_from_error_node_child(current_node) + { self.wrapping_clause_type = Some(clause_type); } - if let Some(wrapping_node) = self.get_wrapping_node_from_siblings(current_node) { + if let Some(wrapping_node) = self.get_wrapping_node_from_error_node_child(current_node) + { self.wrapping_node_kind = Some(wrapping_node) } @@ -399,10 +401,19 @@ impl<'a> CompletionContext<'a> { self.get_wrapping_clause_from_current_node(current_node, &mut cursor); } - "relation" | "binary_expression" | "assignment" | "list" => { + "relation" | "binary_expression" | "assignment" => { self.wrapping_node_kind = current_node_kind.try_into().ok(); } + "list" => { + if current_node + .prev_sibling() + .is_none_or(|n| n.kind() != "keyword_values") + { + self.wrapping_node_kind = current_node_kind.try_into().ok(); + } + } + _ => {} } @@ -424,7 +435,10 @@ impl<'a> CompletionContext<'a> { first_sibling } - fn get_wrapping_node_from_siblings(&self, node: tree_sitter::Node<'a>) -> Option { + fn get_wrapping_node_from_error_node_child( + &self, + node: tree_sitter::Node<'a>, + ) -> Option { self.wrapping_clause_type .as_ref() .and_then(|clause| match clause { @@ -460,7 +474,7 @@ impl<'a> CompletionContext<'a> { }) } - fn get_wrapping_clause_from_siblings( + fn get_wrapping_clause_from_error_node_child( &self, node: tree_sitter::Node<'a>, ) -> Option> { diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 5926bbe2..b23b0ccf 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -39,8 +39,8 @@ mod tests { use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ - CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, get_test_deps, - get_test_params, + CURSOR_POS, CompletionAssertion, InputQuery, assert_complete_results, + assert_no_complete_results, get_test_deps, get_test_params, }, }; @@ -620,5 +620,12 @@ mod tests { setup, ) .await; + + // no completions in the values list! + assert_no_complete_results( + format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(), + setup, + ) + .await; } } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index f3d5c2bf..937c11af 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -244,8 +244,6 @@ pub(crate) async fn assert_complete_results( pub(crate) async fn assert_no_complete_results(query: &str, setup: &str) { let (tree, cache) = get_test_deps(setup, query.into()).await; let params = get_test_params(&tree, &cache, query.into()); - println!("{:#?}", params.position); - println!("{:#?}", params.text); let items = complete(params); assert_eq!(items.len(), 0) From ec3d575e96b81d7f0ff89eafc95cd266f6451908 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 13:22:25 +0200 Subject: [PATCH 08/13] wow! --- .../pgt_completions/src/providers/tables.rs | 24 +++++++++++++++++++ .../src/relevance/filtering.rs | 4 ++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index 217db91f..e372125f 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -379,4 +379,28 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_tables_in_insert_into() { + let setup = r#" + create schema auth; + + create table auth.users ( + uid serial primary key, + name text not null, + email text unique not null + ); + "#; + + assert_complete_results( + format!("insert into {}", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 725c175d..8693f5c0 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -81,7 +81,7 @@ impl CompletionFilter<'_> { WrappingClause::Insert => { ctx.wrapping_node_kind .as_ref() - .is_some_and(|n| n != &WrappingNode::List) + .is_none_or(|n| n != &WrappingNode::List) && ctx.before_cursor_matches_kind(&["keyword_into"]) } @@ -148,7 +148,7 @@ impl CompletionFilter<'_> { WrappingClause::Insert => { ctx.wrapping_node_kind .as_ref() - .is_some_and(|n| n != &WrappingNode::List) + .is_none_or(|n| n != &WrappingNode::List) && ctx.before_cursor_matches_kind(&["keyword_into"]) } From 09084c13fa3f80ed172299a5fe28dc9bd7b59582 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 14:02:16 +0200 Subject: [PATCH 09/13] cool --- crates/pgt_completions/src/context/mod.rs | 57 ++++--- .../pgt_completions/src/providers/columns.rs | 15 ++ .../pgt_completions/src/providers/tables.rs | 26 +++ .../src/relevance/filtering.rs | 4 +- .../src/queries/insert_columns.rs | 150 ++++++++++++++++++ .../pgt_treesitter_queries/src/queries/mod.rs | 8 + .../src/queries/relations.rs | 106 +++++++++++++ 7 files changed, 347 insertions(+), 19 deletions(-) create mode 100644 crates/pgt_treesitter_queries/src/queries/insert_columns.rs diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index d91a046d..b5686788 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -236,6 +236,7 @@ impl<'a> CompletionContext<'a> { executor.add_query_results::(); executor.add_query_results::(); executor.add_query_results::(); + executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { match relation_match { @@ -243,13 +244,12 @@ impl<'a> CompletionContext<'a> { let schema_name = r.get_schema(sql); let table_name = r.get_table(sql); - if let Some(c) = self.mentioned_relations.get_mut(&schema_name) { - c.insert(table_name); - } else { - let mut new = HashSet::new(); - new.insert(table_name); - self.mentioned_relations.insert(schema_name, new); - } + self.mentioned_relations + .entry(schema_name) + .and_modify(|s| { + s.insert(table_name.clone()); + }) + .or_insert(HashSet::from([table_name])); } QueryResult::TableAliases(table_alias_match) => { self.mentioned_table_aliases.insert( @@ -257,23 +257,33 @@ impl<'a> CompletionContext<'a> { table_alias_match.get_table(sql), ); } + QueryResult::SelectClauseColumns(c) => { let mentioned = MentionedColumn { column: c.get_column(sql), alias: c.get_alias(sql), }; - if let Some(cols) = self - .mentioned_columns - .get_mut(&Some(WrappingClause::Select)) - { - cols.insert(mentioned); - } else { - let mut new = HashSet::new(); - new.insert(mentioned); - self.mentioned_columns - .insert(Some(WrappingClause::Select), new); - } + self.mentioned_columns + .entry(Some(WrappingClause::Select)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); + } + + QueryResult::InsertClauseColumns(c) => { + let mentioned = MentionedColumn { + column: c.get_column(sql), + alias: None, + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Insert)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); } }; } @@ -628,6 +638,17 @@ impl<'a> CompletionContext<'a> { } } + pub(crate) fn parent_matches_one_of_kind(&self, kinds: &[&'static str]) -> bool { + self.node_under_cursor + .as_ref() + .is_some_and(|under_cursor| match under_cursor { + NodeUnderCursor::TsNode(node) => node + .parent() + .is_some_and(|parent| kinds.contains(&parent.kind())), + + NodeUnderCursor::CustomNode { .. } => false, + }) + } pub(crate) fn before_cursor_matches_kind(&self, kinds: &[&'static str]) -> bool { self.node_under_cursor.as_ref().is_some_and(|under_cursor| { match under_cursor { diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index b23b0ccf..148504b9 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -621,6 +621,21 @@ mod tests { ) .await; + // works with completed statement + assert_complete_results( + format!( + "insert into instruments (name, {}) values ('my_bass');", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("z".to_string()), + ], + setup, + ) + .await; + // no completions in the values list! assert_no_complete_results( format!("insert into instruments (id, name) values ({})", CURSOR_POS).as_str(), diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index e372125f..96d327de 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -402,5 +402,31 @@ mod tests { setup, ) .await; + + assert_complete_results( + format!("insert into auth.{}", CURSOR_POS).as_str(), + vec![CompletionAssertion::LabelAndKind( + "users".into(), + CompletionItemKind::Table, + )], + setup, + ) + .await; + + // works with complete statement. + assert_complete_results( + format!( + "insert into {} (name, email) values ('jules', 'a@b.com');", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndKind("public".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("auth".into(), CompletionItemKind::Schema), + CompletionAssertion::LabelAndKind("users".into(), CompletionItemKind::Table), + ], + setup, + ) + .await; } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 8693f5c0..e0399591 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -82,7 +82,9 @@ impl CompletionFilter<'_> { ctx.wrapping_node_kind .as_ref() .is_none_or(|n| n != &WrappingNode::List) - && ctx.before_cursor_matches_kind(&["keyword_into"]) + && (ctx.before_cursor_matches_kind(&["keyword_into"]) + || (ctx.before_cursor_matches_kind(&["."]) + && ctx.parent_matches_one_of_kind(&["object_reference"]))) } WrappingClause::DropTable | WrappingClause::AlterTable => ctx diff --git a/crates/pgt_treesitter_queries/src/queries/insert_columns.rs b/crates/pgt_treesitter_queries/src/queries/insert_columns.rs new file mode 100644 index 00000000..3e88d998 --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/insert_columns.rs @@ -0,0 +1,150 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (insert + (object_reference) + (list + "("? + (column) @column + ","? + ")"? + ) + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct InsertColumnMatch<'a> { + pub(crate) column: tree_sitter::Node<'a>, +} + +impl InsertColumnMatch<'_> { + pub fn get_column(&self, sql: &str) -> String { + self.column + .utf8_text(sql.as_bytes()) + .expect("Failed to get column from ColumnMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a InsertColumnMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::InsertClauseColumns(c) => Ok(c), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for InsertColumnMatch<'a> { + type Ref = &'a InsertColumnMatch<'a>; +} + +impl<'a> Query<'a> for InsertColumnMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::InsertClauseColumns(InsertColumnMatch { + column: capture, + })); + } + } + + to_return + } +} +#[cfg(test)] +mod tests { + use super::InsertColumnMatch; + use crate::TreeSitterQueriesExecutor; + + #[test] + fn finds_all_insert_columns() { + let sql = r#"insert into users (id, email, name) values (1, 'a@b.com', 'Alice');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&InsertColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + let columns: Vec = results.iter().map(|c| c.get_column(sql)).collect(); + + assert_eq!(columns, vec!["id", "email", "name"]); + } + + #[test] + fn finds_insert_columns_with_whitespace_and_commas() { + let sql = r#" + insert into users ( + id, + email, + name + ) values (1, 'a@b.com', 'Alice'); + "#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&InsertColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + let columns: Vec = results.iter().map(|c| c.get_column(sql)).collect(); + + assert_eq!(columns, vec!["id", "email", "name"]); + } + + #[test] + fn returns_empty_for_insert_without_columns() { + let sql = r#"insert into users values (1, 'a@b.com', 'Alice');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&InsertColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert!(results.is_empty()); + } +} diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index e02d675b..aae7e1a1 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -1,7 +1,9 @@ +mod insert_columns; mod relations; mod select_columns; mod table_aliases; +pub use insert_columns::*; pub use relations::*; pub use select_columns::*; pub use table_aliases::*; @@ -11,6 +13,7 @@ pub enum QueryResult<'a> { Relation(RelationMatch<'a>), TableAliases(TableAliasMatch<'a>), SelectClauseColumns(SelectColumnMatch<'a>), + InsertClauseColumns(InsertColumnMatch<'a>), } impl QueryResult<'_> { @@ -41,6 +44,11 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } + Self::InsertClauseColumns(cm) => { + let start = cm.column.start_position(); + let end = cm.column.end_position(); + start >= range.start_point && end <= range.end_point + } } } } diff --git a/crates/pgt_treesitter_queries/src/queries/relations.rs b/crates/pgt_treesitter_queries/src/queries/relations.rs index f9061ce8..38fd0513 100644 --- a/crates/pgt_treesitter_queries/src/queries/relations.rs +++ b/crates/pgt_treesitter_queries/src/queries/relations.rs @@ -14,6 +14,14 @@ static TS_QUERY: LazyLock = LazyLock::new(|| { (identifier)? @table )+ ) + (insert + (object_reference + . + (identifier) @schema_or_table + "."? + (identifier)? @table + )+ + ) "#; tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") }); @@ -91,3 +99,101 @@ impl<'a> Query<'a> for RelationMatch<'a> { to_return } } + +#[cfg(test)] +mod tests { + use super::RelationMatch; + use crate::TreeSitterQueriesExecutor; + + #[test] + fn finds_table_without_schema() { + let sql = r#"select * from users;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), None); + assert_eq!(results[0].get_table(sql), "users"); + } + + #[test] + fn finds_table_with_schema() { + let sql = r#"select * from public.users;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("public".to_string())); + assert_eq!(results[0].get_table(sql), "users"); + } + + #[test] + fn finds_insert_into_with_schema_and_table() { + let sql = r#"insert into auth.accounts (id, email) values (1, 'a@b.com');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("auth".to_string())); + assert_eq!(results[0].get_table(sql), "accounts"); + } + + #[test] + fn finds_insert_into_without_schema() { + let sql = r#"insert into users (id, email) values (1, 'a@b.com');"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), None); + assert_eq!(results[0].get_table(sql), "users"); + } +} From 4225d56be9635d3760540a7f17098a7f362d5450 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 14:54:07 +0200 Subject: [PATCH 10/13] fine! --- crates/pgt_completions/src/context/mod.rs | 79 +++++++++---------- .../src/relevance/filtering.rs | 20 +++-- 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 79f85a84..91bc07b0 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -549,9 +549,7 @@ impl<'a> CompletionContext<'a> { clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a)); clauses_with_score - .iter() - .filter(|(_, score)| *score > 0) - .next() + .iter().find(|(_, score)| *score > 0) .map(|c| c.0.clone()) } @@ -559,50 +557,47 @@ impl<'a> CompletionContext<'a> { let mut first_sibling = self.get_first_sibling(node); if let Some(clause) = self.wrapping_clause_type.as_ref() { - match clause { - WrappingClause::Insert => { - while let Some(sib) = first_sibling.next_sibling() { - match sib.kind() { - "object_reference" => { - if let Some(NodeText::Original(txt)) = - self.get_ts_node_content(&sib) - { - let mut iter = txt.split('.').rev(); - let table = iter.next().unwrap().to_string(); - let schema = iter.next().map(|s| s.to_string()); - self.mentioned_relations - .entry(schema) - .and_modify(|s| { - s.insert(table.clone()); - }) - .or_insert(HashSet::from([table])); - } + if clause == &WrappingClause::Insert { + while let Some(sib) = first_sibling.next_sibling() { + match sib.kind() { + "object_reference" => { + if let Some(NodeText::Original(txt)) = + self.get_ts_node_content(&sib) + { + let mut iter = txt.split('.').rev(); + let table = iter.next().unwrap().to_string(); + let schema = iter.next().map(|s| s.to_string()); + self.mentioned_relations + .entry(schema) + .and_modify(|s| { + s.insert(table.clone()); + }) + .or_insert(HashSet::from([table])); } - "column" => { - if let Some(NodeText::Original(txt)) = - self.get_ts_node_content(&sib) - { - let entry = MentionedColumn { - column: txt, - alias: None, - }; - - self.mentioned_columns - .entry(Some(WrappingClause::Insert)) - .and_modify(|s| { - s.insert(entry.clone()); - }) - .or_insert(HashSet::from([entry])); - } + } + "column" => { + if let Some(NodeText::Original(txt)) = + self.get_ts_node_content(&sib) + { + let entry = MentionedColumn { + column: txt, + alias: None, + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Insert)) + .and_modify(|s| { + s.insert(entry.clone()); + }) + .or_insert(HashSet::from([entry])); } - - _ => {} } - first_sibling = sib; + _ => {} } + + first_sibling = sib; } - _ => {} } } } @@ -654,7 +649,7 @@ impl<'a> CompletionContext<'a> { self.node_under_cursor.as_ref().is_some_and(|under_cursor| { match under_cursor { NodeUnderCursor::TsNode(node) => { - let mut current = node.clone(); + let mut current = *node; // move up to the parent until we're at top OR we have a prev sibling while current.prev_sibling().is_none() && current.parent().is_some() { diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index e0399591..cb6d2cf6 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -123,14 +123,13 @@ impl CompletionFilter<'_> { } } - CompletionRelevanceData::Function(_) => match clause { + CompletionRelevanceData::Function(_) => matches!( + clause, WrappingClause::From - | WrappingClause::Select - | WrappingClause::Where - | WrappingClause::Join { .. } => true, - - _ => false, - }, + | WrappingClause::Select + | WrappingClause::Where + | WrappingClause::Join { .. } + ), CompletionRelevanceData::Schema(_) => match clause { WrappingClause::Select @@ -157,10 +156,9 @@ impl CompletionFilter<'_> { _ => false, }, - CompletionRelevanceData::Policy(_) => match clause { - WrappingClause::PolicyName => true, - _ => false, - }, + CompletionRelevanceData::Policy(_) => { + matches!(clause, WrappingClause::PolicyName) + } } }) .and_then(|is_ok| if is_ok { Some(()) } else { None }) From c084c5883de983732b841c1da313f239fc998802 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 14:55:54 +0200 Subject: [PATCH 11/13] lindt --- crates/pgt_workspace/src/workspace/server/sql_function.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs index 48f91ef4..bc2c6c3b 100644 --- a/crates/pgt_workspace/src/workspace/server/sql_function.rs +++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs @@ -15,6 +15,7 @@ pub struct SQLFunctionArg { #[derive(Debug, Clone)] pub struct SQLFunctionSignature { + #[allow(dead_code)] pub schema: Option, pub name: String, pub args: Vec, From b8c064311285582e36d8cd40b5081831ec4aa684 Mon Sep 17 00:00:00 2001 From: Julian Date: Sat, 24 May 2025 15:38:30 +0200 Subject: [PATCH 12/13] for maart --- crates/pgt_completions/src/context/mod.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 91bc07b0..d034de09 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -549,7 +549,8 @@ impl<'a> CompletionContext<'a> { clauses_with_score.sort_by(|(_, score_a), (_, score_b)| score_b.cmp(score_a)); clauses_with_score - .iter().find(|(_, score)| *score > 0) + .iter() + .find(|(_, score)| *score > 0) .map(|c| c.0.clone()) } @@ -561,9 +562,7 @@ impl<'a> CompletionContext<'a> { while let Some(sib) = first_sibling.next_sibling() { match sib.kind() { "object_reference" => { - if let Some(NodeText::Original(txt)) = - self.get_ts_node_content(&sib) - { + if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { let mut iter = txt.split('.').rev(); let table = iter.next().unwrap().to_string(); let schema = iter.next().map(|s| s.to_string()); @@ -576,9 +575,7 @@ impl<'a> CompletionContext<'a> { } } "column" => { - if let Some(NodeText::Original(txt)) = - self.get_ts_node_content(&sib) - { + if let Some(NodeText::Original(txt)) = self.get_ts_node_content(&sib) { let entry = MentionedColumn { column: txt, alias: None, From a970f7eebf34e3062fd45744fc718af38f632f02 Mon Sep 17 00:00:00 2001 From: Julian Domke <68325451+juleswritescode@users.noreply.github.com> Date: Sat, 24 May 2025 19:20:17 +0200 Subject: [PATCH 13/13] feat(completions): improve completions in WHERE clauses (#403) --- crates/pgt_completions/src/context/mod.rs | 21 +++- .../pgt_completions/src/providers/columns.rs | 86 ++++++++++++++++- .../src/relevance/filtering.rs | 12 ++- crates/pgt_completions/src/sanitization.rs | 19 ++-- .../pgt_treesitter_queries/src/queries/mod.rs | 13 +++ .../src/queries/where_columns.rs | 96 +++++++++++++++++++ 6 files changed, 230 insertions(+), 17 deletions(-) create mode 100644 crates/pgt_treesitter_queries/src/queries/where_columns.rs diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index d034de09..0bb190a9 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -237,6 +237,7 @@ impl<'a> CompletionContext<'a> { executor.add_query_results::(); executor.add_query_results::(); executor.add_query_results::(); + executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { match relation_match { @@ -251,6 +252,7 @@ impl<'a> CompletionContext<'a> { }) .or_insert(HashSet::from([table_name])); } + QueryResult::TableAliases(table_alias_match) => { self.mentioned_table_aliases.insert( table_alias_match.get_alias(sql), @@ -272,6 +274,20 @@ impl<'a> CompletionContext<'a> { .or_insert(HashSet::from([mentioned])); } + QueryResult::WhereClauseColumns(c) => { + let mentioned = MentionedColumn { + column: c.get_column(sql), + alias: c.get_alias(sql), + }; + + self.mentioned_columns + .entry(Some(WrappingClause::Where)) + .and_modify(|s| { + s.insert(mentioned.clone()); + }) + .or_insert(HashSet::from([mentioned])); + } + QueryResult::InsertClauseColumns(c) => { let mentioned = MentionedColumn { column: c.get_column(sql), @@ -359,8 +375,9 @@ impl<'a> CompletionContext<'a> { let parent_node_kind = parent_node.kind(); let current_node_kind = current_node.kind(); - // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node_kind == parent_node_kind { + // prevent infinite recursion – this can happen with ERROR nodes + if current_node_kind == parent_node_kind && ["ERROR", "program"].contains(&parent_node_kind) + { self.node_under_cursor = Some(NodeUnderCursor::from(current_node)); return; } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 148504b9..a040bab1 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -23,7 +23,12 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio }; // autocomplete with the alias in a join clause if we find one - if matches!(ctx.wrapping_clause_type, Some(WrappingClause::Join { .. })) { + if matches!( + ctx.wrapping_clause_type, + Some(WrappingClause::Join { .. }) + | Some(WrappingClause::Where) + | Some(WrappingClause::Select) + ) { item.completion_text = find_matching_alias_for_table(ctx, col.table_name.as_str()) .and_then(|alias| { get_completion_text_with_schema_or_alias(ctx, col.name.as_str(), alias.as_str()) @@ -36,6 +41,8 @@ pub fn complete_columns<'a>(ctx: &CompletionContext<'a>, builder: &mut Completio #[cfg(test)] mod tests { + use std::vec; + use crate::{ CompletionItem, CompletionItemKind, complete, test_helper::{ @@ -643,4 +650,81 @@ mod tests { ) .await; } + + #[tokio::test] + async fn suggests_columns_in_where_clause() { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null, + z text, + created_at timestamp with time zone default now() + ); + + create table others ( + a text, + b text, + c text + ); + "#; + + assert_complete_results( + format!("select name from instruments where {} ", CURSOR_POS).as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + setup, + ) + .await; + + assert_complete_results( + format!( + "select name from instruments where z = 'something' and created_at > {}", + CURSOR_POS + ) + .as_str(), + // simply do not complete columns + schemas; functions etc. are ok + vec![ + CompletionAssertion::KindNotExists(CompletionItemKind::Column), + CompletionAssertion::KindNotExists(CompletionItemKind::Schema), + ], + setup, + ) + .await; + + // prefers not mentioned columns + assert_complete_results( + format!( + "select name from instruments where id = 'something' and {}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + setup, + ) + .await; + + // // uses aliases + assert_complete_results( + format!( + "select name from instruments i join others o on i.z = o.a where i.{}", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + ], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index cb6d2cf6..5323e2bc 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -119,6 +119,13 @@ impl CompletionFilter<'_> { .as_ref() .is_some_and(|n| n == &WrappingNode::List), + // only autocomplete left side of binary expression + WrappingClause::Where => { + ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"]) + || (ctx.before_cursor_matches_kind(&["."]) + && ctx.parent_matches_one_of_kind(&["field"])) + } + _ => true, } } @@ -133,12 +140,15 @@ impl CompletionFilter<'_> { CompletionRelevanceData::Schema(_) => match clause { WrappingClause::Select - | WrappingClause::Where | WrappingClause::From | WrappingClause::Join { .. } | WrappingClause::Update | WrappingClause::Delete => true, + WrappingClause::Where => { + ctx.before_cursor_matches_kind(&["keyword_and", "keyword_where"]) + } + WrappingClause::DropTable | WrappingClause::AlterTable => ctx .before_cursor_matches_kind(&[ "keyword_exists", diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 8b1cae6e..40dea7e6 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -193,11 +193,6 @@ fn cursor_before_semicolon(tree: &tree_sitter::Tree, position: TextSize) -> bool return false; } - // not okay to be on the semi. - if byte == leaf_node.start_byte() { - return false; - } - leaf_node .prev_named_sibling() .map(|n| n.end_byte() < byte) @@ -355,19 +350,17 @@ mod tests { // select * from| ; <-- still touches the from assert!(!cursor_before_semicolon(&tree, TextSize::new(13))); - // not okay to be ON the semi. - // select * from |; - assert!(!cursor_before_semicolon(&tree, TextSize::new(18))); - // anything is fine here - // select * from | ; - // select * from | ; - // select * from | ; - // select * from |; + // select * from | ; + // select * from | ; + // select * from | ; + // select * from | ; + // select * from |; assert!(cursor_before_semicolon(&tree, TextSize::new(14))); assert!(cursor_before_semicolon(&tree, TextSize::new(15))); assert!(cursor_before_semicolon(&tree, TextSize::new(16))); assert!(cursor_before_semicolon(&tree, TextSize::new(17))); + assert!(cursor_before_semicolon(&tree, TextSize::new(18))); } #[test] diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 2d957872..b9f39aed 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -3,12 +3,14 @@ mod parameters; mod relations; mod select_columns; mod table_aliases; +mod where_columns; pub use insert_columns::*; pub use parameters::*; pub use relations::*; pub use select_columns::*; pub use table_aliases::*; +pub use where_columns::*; #[derive(Debug)] pub enum QueryResult<'a> { @@ -17,6 +19,7 @@ pub enum QueryResult<'a> { TableAliases(TableAliasMatch<'a>), SelectClauseColumns(SelectColumnMatch<'a>), InsertClauseColumns(InsertColumnMatch<'a>), + WhereClauseColumns(WhereColumnMatch<'a>), } impl QueryResult<'_> { @@ -53,6 +56,16 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } + Self::WhereClauseColumns(cm) => { + let start = match cm.alias { + Some(n) => n.start_position(), + None => cm.column.start_position(), + }; + + let end = cm.column.end_position(); + + start >= range.start_point && end <= range.end_point + } Self::InsertClauseColumns(cm) => { let start = cm.column.start_position(); let end = cm.column.end_position(); diff --git a/crates/pgt_treesitter_queries/src/queries/where_columns.rs b/crates/pgt_treesitter_queries/src/queries/where_columns.rs new file mode 100644 index 00000000..8e19590d --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/where_columns.rs @@ -0,0 +1,96 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (where + (binary_expression + (binary_expression + (field + (object_reference)? @alias + "."? + (identifier) @column + ) + ) + ) + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct WhereColumnMatch<'a> { + pub(crate) alias: Option>, + pub(crate) column: tree_sitter::Node<'a>, +} + +impl WhereColumnMatch<'_> { + pub fn get_alias(&self, sql: &str) -> Option { + let str = self + .alias + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from ColumnMatch"); + + Some(str.to_string()) + } + + pub fn get_column(&self, sql: &str) -> String { + self.column + .utf8_text(sql.as_bytes()) + .expect("Failed to get column from ColumnMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a WhereColumnMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::WhereClauseColumns(c) => Ok(c), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for WhereColumnMatch<'a> { + type Ref = &'a WhereColumnMatch<'a>; +} + +impl<'a> Query<'a> for WhereColumnMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch { + alias: None, + column: capture, + })); + } + + if m.captures.len() == 2 { + let alias = m.captures[0].node; + let column = m.captures[1].node; + + to_return.push(QueryResult::WhereClauseColumns(WhereColumnMatch { + alias: Some(alias), + column, + })); + } + } + + to_return + } +}