Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve auto completions #310

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/pgt_completions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pgt_treesitter_queries.workspace = true
schemars = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tracing = { workspace = true }
tree-sitter.workspace = true
tree_sitter_sql.workspace = true

Expand Down
129 changes: 127 additions & 2 deletions crates/pgt_completions/src/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,49 @@ pub struct CompletionParams<'a> {
pub tree: Option<&'a tree_sitter::Tree>,
}

pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {
let ctx = CompletionContext::new(&params);
#[tracing::instrument(level = "debug", skip_all, fields(
text = params.text,
position = params.position.to_string()
))]
pub fn complete(mut params: CompletionParams) -> Vec<CompletionItem> {
let should_adjust_params = params.tree.is_some()
&& (cursor_inbetween_nodes(params.tree.unwrap(), params.position)
|| cursor_prepared_to_write_token_after_last_node(
params.tree.unwrap(),
params.position,
));

let usable_sql = if should_adjust_params {
let pos: usize = params.position.into();

let mut mutated_sql = String::new();

for (idx, c) in params.text.chars().enumerate() {
if idx == pos {
mutated_sql.push_str("REPLACED_TOKEN ");
}
mutated_sql.push(c);
}

mutated_sql
} else {
params.text
};

let usable_tree = if should_adjust_params {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(tree_sitter_sql::language())
.expect("Error loading sql language");
parser.parse(usable_sql.clone(), None)
} else {
tracing::info!("We're reusing the previous tree.");
None
};

params.text = usable_sql;

let ctx = CompletionContext::new(&params, usable_tree.as_ref().or(params.tree));

let mut builder = CompletionBuilder::new();

Expand All @@ -28,3 +69,87 @@ pub fn complete(params: CompletionParams) -> Vec<CompletionItem> {

builder.finish()
}

fn cursor_inbetween_nodes(tree: &tree_sitter::Tree, position: TextSize) -> bool {
let mut cursor = tree.walk();
let mut node = tree.root_node();

loop {
let child_dx = cursor.goto_first_child_for_byte(position.into());
if child_dx.is_none() {
break;
}
node = cursor.node();
}

let byte = position.into();

// Return true if the cursor is NOT within the node's bounds, INCLUSIVE
!(node.start_byte() <= byte && node.end_byte() >= byte)
}

fn cursor_prepared_to_write_token_after_last_node(
tree: &tree_sitter::Tree,
position: TextSize,
) -> bool {
let cursor_pos: usize = position.into();
cursor_pos == tree.root_node().end_byte() + 1
}

#[cfg(test)]
mod tests {
use pgt_text_size::TextSize;

use crate::complete::{cursor_inbetween_nodes, cursor_prepared_to_write_token_after_last_node};

#[test]
fn test_cursor_inbetween_nodes() {
let input = "select from users;";

let mut parser = tree_sitter::Parser::new();
parser
.set_language(tree_sitter_sql::language())
.expect("Error loading sql language");

let mut tree = parser.parse(input.to_string(), None).unwrap();

// select | from users;
assert!(cursor_inbetween_nodes(&mut tree, TextSize::new(7)));

// select |from users;
assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(8)));

// select| from users;
assert!(!cursor_inbetween_nodes(&mut tree, TextSize::new(6)));
}

#[test]
fn test_cursor_after_nodes() {
let input = "select * from ";

let mut parser = tree_sitter::Parser::new();
parser
.set_language(tree_sitter_sql::language())
.expect("Error loading sql language");

let mut tree = parser.parse(input.to_string(), None).unwrap();

// select * from|; <-- still on previous token
assert!(!cursor_prepared_to_write_token_after_last_node(
&mut tree,
TextSize::new(14)
));

// select * from |; <-- too far off
assert!(!cursor_prepared_to_write_token_after_last_node(
&mut tree,
TextSize::new(16)
));

// select * from |; <-- just right
assert!(cursor_prepared_to_write_token_after_last_node(
&mut tree,
TextSize::new(15)
));
}
}
50 changes: 29 additions & 21 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ impl TryFrom<String> for ClauseType {
}

pub(crate) struct CompletionContext<'a> {
pub ts_node: Option<tree_sitter::Node<'a>>,
pub node_under_cursor: Option<tree_sitter::Node<'a>>,

pub tree: Option<&'a tree_sitter::Tree>,
pub text: &'a str,
pub schema_cache: &'a SchemaCache,
Expand All @@ -64,21 +65,24 @@ pub(crate) struct CompletionContext<'a> {
}

impl<'a> CompletionContext<'a> {
pub fn new(params: &'a CompletionParams) -> Self {
pub fn new(params: &'a CompletionParams, usable_tree: Option<&'a tree_sitter::Tree>) -> Self {
let mut ctx = Self {
tree: params.tree,
tree: usable_tree,
text: &params.text,
schema_cache: params.schema,
position: usize::from(params.position),
ts_node: None,
node_under_cursor: None,
schema_name: None,
wrapping_clause_type: None,
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
};

tracing::warn!("gathering tree context");
ctx.gather_tree_context();

tracing::warn!("gathering info from ts query");
ctx.gather_info_from_ts_queries();

ctx
Expand Down Expand Up @@ -155,20 +159,20 @@ impl<'a> CompletionContext<'a> {
fn gather_context_from_node(
&mut self,
mut cursor: tree_sitter::TreeCursor<'a>,
previous_node: tree_sitter::Node<'a>,
parent_node: tree_sitter::Node<'a>,
) {
let current_node = cursor.node();

// prevent infinite recursion – this can happen if we only have a PROGRAM node
if current_node.kind() == previous_node.kind() {
self.ts_node = Some(current_node);
if current_node.kind() == parent_node.kind() {
self.node_under_cursor = Some(current_node);
return;
}

match previous_node.kind() {
match parent_node.kind() {
"statement" | "subquery" => {
self.wrapping_clause_type = current_node.kind().try_into().ok();
self.wrapping_statement_range = Some(previous_node.range());
self.wrapping_statement_range = Some(parent_node.range());
}
"invocation" => self.is_invocation = true,

Expand Down Expand Up @@ -200,7 +204,11 @@ impl<'a> CompletionContext<'a> {

// We have arrived at the leaf node
if current_node.child_count() == 0 {
self.ts_node = Some(current_node);
if self.get_ts_node_content(current_node).unwrap() == "REPLACED_TOKEN" {
self.node_under_cursor = None;
} else {
self.node_under_cursor = Some(current_node);
}
return;
}

Expand Down Expand Up @@ -266,7 +274,7 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok());
}
Expand Down Expand Up @@ -298,7 +306,7 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string()));
}
Expand Down Expand Up @@ -332,7 +340,7 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

assert_eq!(ctx.is_invocation, is_invocation);
}
Expand All @@ -357,9 +365,9 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("select"));

Expand All @@ -385,9 +393,9 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("from"));
assert_eq!(
Expand All @@ -411,9 +419,9 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some(""));
assert_eq!(ctx.wrapping_clause_type, None);
Expand All @@ -436,9 +444,9 @@ mod tests {
schema: &pgt_schema_cache::SchemaCache::default(),
};

let ctx = CompletionContext::new(&params);
let ctx = CompletionContext::new(&params, Some(&tree));

let node = ctx.ts_node.unwrap();
let node = ctx.node_under_cursor.unwrap();

assert_eq!(ctx.get_ts_node_content(node), Some("fro"));
assert_eq!(ctx.wrapping_clause_type, Some(ClauseType::Select));
Expand Down
Loading
Loading