forked from SilasMarvin/lsp-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.rs
171 lines (154 loc) · 6.6 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
use anyhow::Result;
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId};
use lsp_types::{
request::Completion, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
};
use std::{
collections::HashMap,
sync::{mpsc, Arc},
thread,
};
use tracing::error;
use tracing_subscriber::{EnvFilter, FmtSubscriber};
mod config;
mod custom_requests;
mod memory_backends;
mod memory_worker;
#[cfg(feature = "llama_cpp")]
mod template;
mod transformer_backends;
mod transformer_worker;
mod utils;
use config::Config;
use custom_requests::generation::Generation;
use memory_backends::MemoryBackend;
use transformer_backends::TransformerBackend;
use transformer_worker::{CompletionRequest, GenerationRequest, WorkerRequest};
use crate::{
custom_requests::generation_stream::GenerationStream,
transformer_worker::GenerationStreamRequest,
};
/// Reports whether `notification` carries the method string of the LSP
/// notification type `N` (e.g. `textDocument/didOpen` for
/// [`lsp_types::notification::DidOpenTextDocument`]).
fn notification_is<N: lsp_types::notification::Notification>(notification: &Notification) -> bool {
    N::METHOD == notification.method
}
/// Reports whether `request` carries the method string of the LSP request
/// type `R` (e.g. `textDocument/completion` for [`Completion`]).
fn request_is<R: lsp_types::request::Request>(request: &Request) -> bool {
    R::METHOD == request.method
}
fn cast<R>(req: Request) -> Result<(RequestId, R::Params), ExtractError<Request>>
where
R: lsp_types::request::Request,
R::Params: serde::de::DeserializeOwned,
{
req.extract(R::METHOD)
}
/// Entry point: installs logging, performs the LSP initialization handshake
/// over stdio, runs the main message loop until shutdown, then joins the
/// connection's I/O threads.
fn main() -> Result<()> {
    // Builds a tracing subscriber from the `LSP_AI_LOG` environment variable.
    // If the variable's value is malformed or missing, sets the default log level to ERROR.
    // Output goes to stderr (without ANSI colors or timestamps) because stdout
    // carries the LSP protocol stream created by `Connection::stdio` below.
    FmtSubscriber::builder()
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .without_time()
        .with_env_filter(EnvFilter::from_env("LSP_AI_LOG"))
        .init();
    // stdio transport: the editor/client talks to us over stdin/stdout.
    let (connection, io_threads) = Connection::stdio();
    // Advertise what this server supports: completions plus incremental
    // text-document synchronization.
    let server_capabilities = serde_json::to_value(ServerCapabilities {
        completion_provider: Some(CompletionOptions::default()),
        text_document_sync: Some(lsp_types::TextDocumentSyncCapability::Kind(
            TextDocumentSyncKind::INCREMENTAL,
        )),
        ..Default::default()
    })?;
    // Run the `initialize` handshake; the returned value holds the client's
    // initialization options, which `main_loop` parses into our `Config`.
    let initialization_args = connection.initialize(server_capabilities)?;
    main_loop(connection, initialization_args)?;
    io_threads.join()?;
    Ok(())
}
fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
// Build our configuration
let config = Config::new(args)?;
// Wrap the connection for sharing between threads
let connection = Arc::new(connection);
// Our channel we use to communicate with our transformer worker
// let last_worker_request = Arc::new(Mutex::new(None));
let (transformer_tx, transformer_rx) = mpsc::channel();
// The channel we use to communicate with our memory worker
let (memory_tx, memory_rx) = mpsc::channel();
// Setup the transformer worker
let memory_backend: Box<dyn MemoryBackend + Send + Sync> = config.clone().try_into()?;
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
// Setup our transformer worker
// let transformer_backend: Box<dyn TransformerBackend + Send + Sync> =
// config.clone().try_into()?;
let transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>> = config
.config
.models
.clone()
.into_iter()
.map(|(key, value)| Ok((key, value.try_into()?)))
.collect::<anyhow::Result<HashMap<String, Box<dyn TransformerBackend + Send + Sync>>>>()?;
let thread_connection = connection.clone();
let thread_memory_tx = memory_tx.clone();
let thread_config = config.clone();
thread::spawn(move || {
transformer_worker::run(
transformer_backends,
thread_memory_tx,
transformer_rx,
thread_connection,
thread_config,
)
});
for msg in &connection.receiver {
match msg {
Message::Request(req) => {
if connection.handle_shutdown(&req)? {
return Ok(());
}
if request_is::<Completion>(&req) {
match cast::<Completion>(req) {
Ok((id, params)) => {
let completion_request = CompletionRequest::new(id, params);
transformer_tx.send(WorkerRequest::Completion(completion_request))?;
}
Err(err) => error!("{err:?}"),
}
} else if request_is::<Generation>(&req) {
match cast::<Generation>(req) {
Ok((id, params)) => {
let generation_request = GenerationRequest::new(id, params);
transformer_tx.send(WorkerRequest::Generation(generation_request))?;
}
Err(err) => error!("{err:?}"),
}
} else if request_is::<GenerationStream>(&req) {
match cast::<GenerationStream>(req) {
Ok((id, params)) => {
let generation_stream_request =
GenerationStreamRequest::new(id, params);
transformer_tx
.send(WorkerRequest::GenerationStream(generation_stream_request))?;
}
Err(err) => error!("{err:?}"),
}
} else {
error!("lsp-ai currently only supports textDocument/completion, textDocument/generation and textDocument/generationStream")
}
}
Message::Notification(not) => {
if notification_is::<lsp_types::notification::DidOpenTextDocument>(¬) {
let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?;
memory_tx.send(memory_worker::WorkerRequest::DidOpenTextDocument(params))?;
} else if notification_is::<lsp_types::notification::DidChangeTextDocument>(¬) {
let params: DidChangeTextDocumentParams = serde_json::from_value(not.params)?;
memory_tx.send(memory_worker::WorkerRequest::DidChangeTextDocument(params))?;
} else if notification_is::<lsp_types::notification::DidRenameFiles>(¬) {
let params: RenameFilesParams = serde_json::from_value(not.params)?;
memory_tx.send(memory_worker::WorkerRequest::DidRenameFiles(params))?;
}
}
_ => (),
}
}
Ok(())
}