Skip to content

Commit

Permalink
Impl multi-threading and request cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
oxalica committed Sep 27, 2022
1 parent fedf5eb commit ca9c971
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 64 deletions.
188 changes: 125 additions & 63 deletions crates/nil/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use lsp_server::{ErrorCode, Message, Notification, ReqQueue, Request, RequestId,
use lsp_types::notification::Notification as _;
use lsp_types::{
notification as notif, request as req, ConfigurationItem, ConfigurationParams, MessageType,
PublishDiagnosticsParams, ShowMessageParams, Url,
NumberOrString, PublishDiagnosticsParams, ShowMessageParams, Url,
};
use serde::{Deserialize, Serialize};
use serde::Deserialize;
use std::cell::Cell;
use std::collections::HashSet;
use std::panic::UnwindSafe;
use std::path::PathBuf;
use std::sync::{Arc, Once, RwLock};
use std::{fs, panic};
use std::{fs, panic, thread};

const FILTER_FILE_EXTENTION: &str = "nix";
const CONFIG_KEY: &str = "nil";
Expand All @@ -26,6 +26,12 @@ pub struct Config {

type ReqHandler = fn(&mut Server, Response);

type Task = Box<dyn FnOnce() -> Event + Send>;

enum Event {
Response(Response),
}

pub struct Server {
// States.
host: AnalysisHost,
Expand All @@ -34,19 +40,35 @@ pub struct Server {
config: Arc<Config>,
is_shutdown: bool,

// Request & response.
// Message passing.
req_queue: ReqQueue<(), ReqHandler>,
sender: Sender<Message>,
lsp_tx: Sender<Message>,
task_tx: Sender<Task>,
event_rx: Receiver<Event>,

// Immutable settings.
workspace_root: Option<PathBuf>,
}

impl Server {
pub fn new(responder: Sender<Message>, workspace_root: Option<PathBuf>) -> Self {
pub fn new(lsp_tx: Sender<Message>, workspace_root: Option<PathBuf>) -> Self {
// Vfs root must be absolute.
let workspace_root = workspace_root.and_then(|root| root.canonicalize().ok());
let vfs = Vfs::new(workspace_root.clone().unwrap_or_else(|| PathBuf::from("/")));

let (task_tx, task_rx) = crossbeam_channel::unbounded();
let (event_tx, event_rx) = crossbeam_channel::unbounded();
let worker_cnt = thread::available_parallelism().map_or(1, |n| n.get());
for _ in 0..worker_cnt {
let task_rx = task_rx.clone();
let event_tx = event_tx.clone();
thread::Builder::new()
.name("Worker".into())
.spawn(move || Self::worker(task_rx, event_tx))
.expect("Failed to spawn worker threads");
}
tracing::info!("Started {worker_cnt} workers");

Self {
host: Default::default(),
vfs: Arc::new(RwLock::new(vfs)),
Expand All @@ -55,13 +77,23 @@ impl Server {
is_shutdown: false,

req_queue: ReqQueue::default(),
sender: responder,
lsp_tx,
task_tx,
event_rx,

workspace_root,
}
}

pub fn run(&mut self, lsp_receiver: Receiver<Message>) -> Result<()> {
fn worker(task_rx: Receiver<Task>, event_tx: Sender<Event>) {
while let Ok(task) = task_rx.recv() {
if event_tx.send(task()).is_err() {
break;
}
}
}

pub fn run(&mut self, lsp_rx: Receiver<Message>) -> Result<()> {
if let Some(root) = &self.workspace_root {
let mut vfs = self.vfs.write().unwrap();
for entry in ignore::WalkBuilder::new(root).follow_links(false).build() {
Expand All @@ -86,24 +118,35 @@ impl Server {
self.apply_vfs_change();
}

for msg in &lsp_receiver {
match msg {
Message::Request(req) => self.dispatch_request(req),
Message::Notification(notif) => {
if notif.method == notif::Exit::METHOD {
return Ok(());
loop {
crossbeam_channel::select! {
recv(lsp_rx) -> msg => {
match msg.map_err(|_| "Channel closed")? {
Message::Request(req) => self.dispatch_request(req),
Message::Notification(notif) => {
if notif.method == notif::Exit::METHOD {
return Ok(());
}
self.dispatch_notification(notif)?;
}
Message::Response(resp) => {
if let Some(callback) = self.req_queue.outgoing.complete(resp.id.clone()) {
callback(self, resp);
}
}
}
self.dispatch_notification(notif)?;
}
Message::Response(resp) => {
if let Some(callback) = self.req_queue.outgoing.complete(resp.id.clone()) {
callback(self, resp);
recv(self.event_rx) -> event => {
match event.map_err(|_| "Worker panicked")? {
Event::Response(resp) => {
if let Some(()) = self.req_queue.incoming.complete(resp.id.clone()) {
self.lsp_tx.send(resp.into())?;
}
}
}
}
}
}

Err("Channel closed".into())
}

fn dispatch_request(&mut self, req: Request) {
Expand All @@ -113,7 +156,7 @@ impl Server {
ErrorCode::InvalidRequest as i32,
"Shutdown already requested.".into(),
);
self.sender.send(resp.into()).unwrap();
self.lsp_tx.send(resp.into()).unwrap();
return;
}

Expand All @@ -138,6 +181,16 @@ impl Server {

fn dispatch_notification(&mut self, notif: Notification) -> Result<()> {
NotificationDispatcher(self, Some(notif))
.on_sync_mut::<notif::Cancel>(|st, params| {
let id: RequestId = match params.id {
NumberOrString::Number(id) => id.into(),
NumberOrString::String(id) => id.into(),
};
if let Some(resp) = st.req_queue.incoming.cancel(id) {
st.lsp_tx.send(resp.into()).unwrap();
}
Ok(())
})?
.on_sync_mut::<notif::DidOpenTextDocument>(|st, params| {
let uri = &params.text_document.uri;
st.opened_files.insert(uri.clone());
Expand Down Expand Up @@ -209,11 +262,11 @@ impl Server {
.req_queue
.outgoing
.register(R::METHOD.into(), params, callback);
self.sender.send(req.into()).unwrap();
self.lsp_tx.send(req.into()).unwrap();
}

fn send_notification<N: notif::Notification>(&self, params: N::Params) {
self.sender
self.lsp_tx
.send(Notification::new(N::METHOD.into(), params).into())
.unwrap();
}
Expand Down Expand Up @@ -339,49 +392,44 @@ impl<'s> RequestDispatcher<'s> {
) -> Self {
if matches!(&self.1, Some(notif) if notif.method == R::METHOD) {
let req = self.1.take().unwrap();
let ret = match serde_json::from_value::<R::Params>(req.params) {
Ok(params) => result_to_response(req.id, f(self.0, params)),
Err(err) => Ok(Response::new_err(
req.id,
ErrorCode::InvalidParams as i32,
err.to_string(),
)),
};
if let Ok(resp) = ret {
self.0.sender.send(resp.into()).unwrap();
}
let ret = (|| {
let params = serde_json::from_value::<R::Params>(req.params)?;
let v = f(self.0, params)?;
Ok(serde_json::to_value(v).unwrap())
})();
let resp = result_to_response(req.id, ret);
self.0.lsp_tx.send(resp.into()).unwrap();
}
self
}

fn on<R: req::Request>(mut self, f: fn(StateSnapshot, R::Params) -> Result<R::Result>) -> Self
fn on<R>(mut self, f: fn(StateSnapshot, R::Params) -> Result<R::Result>) -> Self
where
R::Params: UnwindSafe,
R: req::Request,
R::Params: 'static,
R::Result: 'static,
{
if matches!(&self.1, Some(notif) if notif.method == R::METHOD) {
let req = self.1.take().unwrap();
let ret = match serde_json::from_value::<R::Params>(req.params) {
Ok(params) => {
let snap = self.0.snapshot();
result_to_response(req.id, with_catch_unwind(R::METHOD, || f(snap, params)))
}
Err(err) => Ok(Response::new_err(
req.id,
ErrorCode::InvalidParams as i32,
err.to_string(),
)),
let snap = self.0.snapshot();
self.0.req_queue.incoming.register(req.id.clone(), ());
let task = move || {
let ret = with_catch_unwind(R::METHOD, || {
let params = serde_json::from_value::<R::Params>(req.params)?;
let resp = f(snap, params)?;
Ok(serde_json::to_value(resp)?)
});
Event::Response(result_to_response(req.id, ret))
};
if let Ok(resp) = ret {
self.0.sender.send(resp.into()).unwrap();
}
self.0.task_tx.send(Box::new(task)).unwrap();
}
self
}

fn finish(self) {
if let Some(req) = self.1 {
let resp = Response::new_err(req.id, ErrorCode::MethodNotFound as _, String::new());
self.0.sender.send(resp.into()).unwrap();
self.0.lsp_tx.send(resp.into()).unwrap();
}
}
}
Expand All @@ -395,9 +443,14 @@ impl<'s> NotificationDispatcher<'s> {
f: fn(&mut Server, N::Params) -> Result<()>,
) -> Result<Self> {
if matches!(&self.1, Some(notif) if notif.method == N::METHOD) {
let params =
serde_json::from_value::<N::Params>(self.1.take().unwrap().params).unwrap();
f(self.0, params)?;
match serde_json::from_value::<N::Params>(self.1.take().unwrap().params) {
Ok(params) => {
f(self.0, params)?;
}
Err(err) => {
tracing::error!("Failed to parse notification {}: {}", N::METHOD, err)
}
}
}
Ok(self)
}
Expand Down Expand Up @@ -448,18 +501,27 @@ fn with_catch_unwind<T>(ctx: &str, f: impl FnOnce() -> Result<T> + UnwindSafe) -
}
}

fn result_to_response(id: RequestId, ret: Result<impl Serialize>) -> Result<Response, Cancelled> {
match ret {
Ok(ret) => Ok(Response::new_ok(id, ret)),
Err(err) => match err.downcast::<Cancelled>() {
Ok(cancelled) => Err(*cancelled),
Err(err) => Ok(Response::new_err(
fn result_to_response(id: RequestId, ret: Result<serde_json::Value>) -> Response {
let err = match ret {
Ok(v) => {
return Response {
id,
ErrorCode::InternalError as i32,
err.to_string(),
)),
},
result: Some(v),
error: None,
}
}
Err(err) => err,
};

if err.is::<Cancelled>() {
// When client cancelled a request, a response is immediately sent back,
// and this will be ignored.
return Response::new_err(id, ErrorCode::ServerCancelled as i32, "Cancelled".into());
}
if let Some(err) = err.downcast_ref::<serde_json::Error>() {
return Response::new_err(id, ErrorCode::InvalidParams as i32, err.to_string());
}
Response::new_err(id, ErrorCode::InternalError as i32, err.to_string())
}

#[derive(Debug)]
Expand Down
3 changes: 2 additions & 1 deletion docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ This incomplete list tracks noteble features currently implemented or planned.
```

- [ ] Cross-file analysis.
- [ ] Multi-threaded.
- [x] Multi-threaded.
- [x] Request cancellation. `$/cancelRequest`

[`coc.nvim`]: https://github.com/neoclide/coc.nvim

0 comments on commit ca9c971

Please sign in to comment.