Skip to content

Commit 12c86fd

Browse files
so far so good
1 parent f64661e commit 12c86fd

File tree

3 files changed

+97
-64
lines changed

3 files changed

+97
-64
lines changed

Diff for: Cargo.lock

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: crates/pg_lsp/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pg_schema_cache.workspace = true
3333
pg_workspace.workspace = true
3434
pg_diagnostics.workspace = true
3535
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] }
36+
tokio-util = "0.7.12"
3637

3738
[dev-dependencies]
3839

Diff for: crates/pg_lsp/src/server.rs

+82-64
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ use pg_hover::HoverParams;
2727
use pg_schema_cache::SchemaCache;
2828
use pg_workspace::Workspace;
2929
use serde::{de::DeserializeOwned, Serialize};
30-
use std::{collections::HashSet, sync::Arc, time::Duration};
30+
use std::{collections::HashSet, future::Future, sync::Arc, time::Duration};
3131
use text_size::TextSize;
32-
use threadpool::ThreadPool;
3332

34-
use tokio::sync::{mpsc, oneshot};
33+
use tokio::sync::mpsc;
34+
use tokio_util::sync::CancellationToken;
3535

3636
use crate::{
3737
client::{client_flags::ClientFlags, LspClient},
@@ -72,9 +72,9 @@ impl DbConnection {
7272
/// For now, we move it into a separate task and use tokio's channels to communicate.
7373
fn get_client_receiver(
7474
connection: Connection,
75-
) -> (mpsc::UnboundedReceiver<Message>, oneshot::Receiver<()>) {
75+
cancel_token: Arc<CancellationToken>,
76+
) -> mpsc::UnboundedReceiver<Message> {
7677
let (message_tx, message_rx) = mpsc::unbounded_channel();
77-
let (close_tx, close_rx) = oneshot::channel();
7878

7979
tokio::task::spawn(async move {
8080
// TODO: improve Result handling
@@ -83,7 +83,7 @@ fn get_client_receiver(
8383

8484
match msg {
8585
Message::Request(r) if connection.handle_shutdown(&r).unwrap() => {
86-
close_tx.send(()).unwrap();
86+
cancel_token.cancel();
8787
return;
8888
}
8989

@@ -92,16 +92,15 @@ fn get_client_receiver(
9292
}
9393
});
9494

95-
(message_rx, close_rx)
95+
message_rx
9696
}
9797

9898
pub struct Server {
9999
client_rx: mpsc::UnboundedReceiver<Message>,
100-
close_rx: oneshot::Receiver<()>,
100+
cancel_token: Arc<tokio_util::sync::CancellationToken>,
101101
client: LspClient,
102102
internal_tx: mpsc::UnboundedSender<InternalMessage>,
103103
internal_rx: mpsc::UnboundedReceiver<InternalMessage>,
104-
pool: Arc<ThreadPool>,
105104
client_flags: Arc<ClientFlags>,
106105
ide: Arc<Workspace>,
107106
db_conn: Option<DbConnection>,
@@ -138,10 +137,12 @@ impl Server {
138137
let cloned_pool = pool.clone();
139138
let cloned_client = client.clone();
140139

141-
let (client_rx, close_rx) = get_client_receiver(connection);
140+
let cancel_token = Arc::new(CancellationToken::new());
141+
142+
let client_rx = get_client_receiver(connection, cancel_token.clone());
142143

143144
let server = Self {
144-
close_rx,
145+
cancel_token,
145146
client_rx,
146147
internal_rx,
147148
internal_tx,
@@ -186,7 +187,6 @@ impl Server {
186187
});
187188
},
188189
),
189-
pool,
190190
};
191191

192192
Ok(server)
@@ -200,7 +200,7 @@ impl Server {
200200

201201
self.compute_debouncer.clear();
202202

203-
tokio::spawn(async move {
203+
self.spawn_with_cancel(async move {
204204
client
205205
.send_notification::<ShowMessage>(ShowMessageParams {
206206
typ: lsp_types::MessageType::INFO,
@@ -714,15 +714,17 @@ impl Server {
714714
Q: FnOnce() -> anyhow::Result<R> + Send + 'static,
715715
{
716716
let client = self.client.clone();
717-
self.pool.execute(move || match query() {
718-
Ok(result) => {
719-
let response = lsp_server::Response::new_ok(id, result);
720-
client.send_response(response).unwrap();
721-
}
722-
Err(why) => {
723-
client
724-
.send_error(id, ErrorCode::InternalError, why.to_string())
725-
.unwrap();
717+
self.spawn_with_cancel(async move {
718+
match query() {
719+
Ok(result) => {
720+
let response = lsp_server::Response::new_ok(id, result);
721+
client.send_response(response).unwrap();
722+
}
723+
Err(why) => {
724+
client
725+
.send_error(id, ErrorCode::InternalError, why.to_string())
726+
.unwrap();
727+
}
726728
}
727729
});
728730
}
@@ -748,9 +750,11 @@ impl Server {
748750
let client = self.client.clone();
749751
let ide = Arc::clone(&self.ide);
750752

751-
self.pool.execute(move || {
753+
self.spawn_with_cancel(async move {
752754
let response = lsp_server::Response::new_ok(id, query(&ide));
753-
client.send_response(response).unwrap();
755+
client
756+
.send_response(response)
757+
.expect("Failed to send query to client");
754758
});
755759
}
756760

@@ -791,22 +795,21 @@ impl Server {
791795
async fn process_messages(&mut self) -> anyhow::Result<()> {
792796
loop {
793797
tokio::select! {
794-
_ = &mut self.close_rx => {
798+
_ = self.cancel_token.cancelled() => {
799+
// Close the loop, proceed to shutdown.
795800
return Ok(())
796801
},
797802

798803
msg = self.internal_rx.recv() => {
799804
match msg {
800-
// TODO: handle internal sender close? Is that valid state?
801-
None => return Ok(()),
802-
Some(m) => self.handle_internal_message(m)
805+
None => panic!("The LSP's internal sender closed. This should never happen."),
806+
Some(m) => self.handle_internal_message(m).await
803807
}
804808
},
805809

806810
msg = self.client_rx.recv() => {
807811
match msg {
808-
// the client sender is closed, we can return
809-
None => return Ok(()),
812+
None => panic!("The LSP's client closed, but not via an 'exit' method. This should never happen."),
810813
Some(m) => self.handle_message(m)
811814
}
812815
},
@@ -848,14 +851,14 @@ impl Server {
848851
Ok(())
849852
}
850853

851-
fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> {
854+
async fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> {
852855
match msg {
853856
InternalMessage::SetSchemaCache(c) => {
854857
self.ide.set_schema_cache(c);
855858
self.compute_now();
856859
}
857860
InternalMessage::RefreshSchemaCache => {
858-
self.refresh_schema_cache();
861+
self.refresh_schema_cache().await;
859862
}
860863
InternalMessage::PublishDiagnostics(uri) => {
861864
self.publish_diagnostics(uri)?;
@@ -869,10 +872,6 @@ impl Server {
869872
}
870873

871874
fn pull_options(&mut self) {
872-
if !self.client_flags.has_configuration {
873-
return;
874-
}
875-
876875
let params = ConfigurationParams {
877876
items: vec![ConfigurationItem {
878877
section: Some("postgres_lsp".to_string()),
@@ -881,53 +880,72 @@ impl Server {
881880
};
882881

883882
let client = self.client.clone();
884-
let sender = self.internal_tx.clone();
885-
self.pool.execute(move || {
883+
let internal_tx = self.internal_tx.clone();
884+
self.spawn_with_cancel(async move {
886885
match client.send_request::<WorkspaceConfiguration>(params) {
887886
Ok(mut json) => {
888887
let options = client
889888
.parse_options(json.pop().expect("invalid configuration request"))
890889
.unwrap();
891890

892-
sender.send(InternalMessage::SetOptions(options)).unwrap();
891+
if let Err(why) = internal_tx.send(InternalMessage::SetOptions(options)) {
892+
println!("Failed to set internal options: {}", why);
893+
}
893894
}
894-
Err(_why) => {
895-
// log::error!("Retrieving configuration failed: {}", why);
895+
Err(why) => {
896+
println!("Retrieving configuration failed: {}", why);
896897
}
897898
};
898899
});
899900
}
900901

901902
fn register_configuration(&mut self) {
902-
if self.client_flags.will_push_configuration {
903-
let registration = Registration {
904-
id: "pull-config".to_string(),
905-
method: DidChangeConfiguration::METHOD.to_string(),
906-
register_options: None,
907-
};
903+
let registration = Registration {
904+
id: "pull-config".to_string(),
905+
method: DidChangeConfiguration::METHOD.to_string(),
906+
register_options: None,
907+
};
908908

909-
let params = RegistrationParams {
910-
registrations: vec![registration],
911-
};
909+
let params = RegistrationParams {
910+
registrations: vec![registration],
911+
};
912912

913-
let client = self.client.clone();
914-
self.pool.execute(move || {
915-
if let Err(_why) = client.send_request::<RegisterCapability>(params) {
916-
// log::error!(
917-
// "Failed to register \"{}\" notification: {}",
918-
// DidChangeConfiguration::METHOD,
919-
// why
920-
// );
921-
}
922-
});
923-
}
913+
let client = self.client.clone();
914+
self.spawn_with_cancel(async move {
915+
if let Err(why) = client.send_request::<RegisterCapability>(params) {
916+
println!(
917+
"Failed to register \"{}\" notification: {}",
918+
DidChangeConfiguration::METHOD,
919+
why
920+
);
921+
}
922+
});
923+
}
924+
925+
fn spawn_with_cancel<F>(&self, f: F) -> tokio::task::JoinHandle<()>
926+
where
927+
F: Future + Send + 'static,
928+
{
929+
let cancel_token = self.cancel_token.clone();
930+
tokio::spawn(async move {
931+
tokio::select! {
932+
_ = cancel_token.cancelled() => {},
933+
_ = f => {}
934+
};
935+
})
924936
}
925937

926938
pub async fn run(mut self) -> anyhow::Result<()> {
927-
self.register_configuration();
928-
self.pull_options();
939+
if self.client_flags.will_push_configuration {
940+
self.register_configuration();
941+
}
942+
943+
if self.client_flags.has_configuration {
944+
self.pull_options();
945+
}
946+
929947
self.process_messages().await?;
930-
self.pool.join();
948+
931949
Ok(())
932950
}
933951
}

0 commit comments

Comments
 (0)