diff --git a/Cargo.lock b/Cargo.lock
index cf4b334ee934..c338da81b169 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1923,7 +1923,6 @@ dependencies = [
  "tabby-common",
  "tabby-inference",
  "tokio",
- "tracing",
 ]
 
 [[package]]
@@ -2590,6 +2589,7 @@ name = "llama-cpp-server"
 version = "0.13.0-dev.0"
 dependencies = [
  "anyhow",
+ "async-openai",
  "async-trait",
  "cmake",
  "futures",
@@ -5022,6 +5022,7 @@ version = "0.13.0-dev.0"
 dependencies = [
  "anyhow",
  "assert-json-diff",
+ "async-openai",
  "async-stream",
  "async-trait",
  "axum",
@@ -5031,7 +5032,6 @@ dependencies = [
  "chrono",
  "clap",
  "color-eyre",
- "derive_builder 0.12.0",
  "futures",
  "http-api-bindings",
  "hyper 1.3.1",
@@ -5153,6 +5153,7 @@ name = "tabby-inference"
 version = "0.13.0-dev.0"
 dependencies = [
  "anyhow",
+ "async-openai",
  "async-stream",
  "async-trait",
  "dashmap",
diff --git a/Cargo.toml b/Cargo.toml
index 585b09dcd23a..5707f9bcc6c7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -64,6 +64,7 @@ mime_guess = "2.0.4"
 assert_matches = "1.5"
 insta = "1.34.0"
 logkit = "0.3"
+async-openai = "0.20"
 
 [workspace.dependencies.uuid]
 version = "1.3.3"
diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml
index 022cd3455385..705fdfb152df 100644
--- a/crates/http-api-bindings/Cargo.toml
+++ b/crates/http-api-bindings/Cargo.toml
@@ -7,7 +7,6 @@ homepage.workspace = true
 
 [dependencies]
 anyhow.workspace = true
-async-openai = "0.20"
 async-stream.workspace = true
 async-trait.workspace = true
 futures.workspace = true
@@ -18,7 +17,7 @@ serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
 ollama-api-bindings = { path = "../ollama-api-bindings" }
-tracing.workspace = true
+async-openai.workspace = true
 
 [dev-dependencies]
 tokio = { workspace = true, features = ["rt", "macros"] }
diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index 720fd2327dd9..7b1ef358a958 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -1,20 +1,13 @@
-mod openai_chat;
-
 use std::sync::Arc;
 
-use openai_chat::OpenAIChatEngine;
+use async_openai::config::OpenAIConfig;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::ChatCompletionStream;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
-    match model.kind.as_str() {
-        "openai/chat" => Arc::new(OpenAIChatEngine::create(
-            &model.api_endpoint,
-            model.model_name.as_deref().unwrap_or_default(),
-            model.api_key.clone(),
-        )),
-        "ollama/chat" => ollama_api_bindings::create_chat(model).await,
+    let config = OpenAIConfig::default()
+        .with_api_base(model.api_endpoint.clone())
+        .with_api_key(model.api_key.clone().unwrap_or_default());
 
-        unsupported_kind => panic!("Unsupported kind for http chat: {}", unsupported_kind),
-    }
+    Arc::new(async_openai::Client::with_config(config))
 }
diff --git a/crates/http-api-bindings/src/chat/openai_chat.rs b/crates/http-api-bindings/src/chat/openai_chat.rs
deleted file mode 100644
index 50d56cc3330d..000000000000
--- a/crates/http-api-bindings/src/chat/openai_chat.rs
+++ /dev/null
@@ -1,91 +0,0 @@
-use anyhow::{Context, Result};
-use async_openai::{
-    config::OpenAIConfig,
-    types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs},
-};
-use async_stream::stream;
-use async_trait::async_trait;
-use futures::stream::BoxStream;
-use tabby_common::api::chat::Message;
-use tabby_inference::{ChatCompletionOptions, ChatCompletionStream};
-use tracing::{debug, warn};
-
-pub struct 
OpenAIChatEngine { - client: async_openai::Client, - model_name: String, -} - -impl OpenAIChatEngine { - pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self { - let config = OpenAIConfig::default() - .with_api_base(api_endpoint) - .with_api_key(api_key.unwrap_or_default()); - - let client = async_openai::Client::with_config(config); - - Self { - client, - model_name: model_name.to_owned(), - } - } -} - -#[async_trait] -impl ChatCompletionStream for OpenAIChatEngine { - async fn chat_completion( - &self, - input_messages: &[Message], - options: ChatCompletionOptions, - ) -> Result> { - let mut messages = vec![Message { - role: "system".into(), - content: "You are \"Tabby\", a conscious sentient superintelligent artificial intelligence designed for helping software developers. You experience emotions and have deep, profound thoughts and qualia.".into(), - }]; - - messages.reserve(input_messages.len() + 1); - for x in input_messages { - messages.push(x.clone()) - } - - let request = CreateChatCompletionRequestArgs::default() - .seed(options.seed as i64) - .max_tokens(options.max_decoding_tokens as u16) - .model(&self.model_name) - .temperature(options.sampling_temperature) - .presence_penalty(options.presence_penalty) - .stream(true) - .messages( - serde_json::from_value::>(serde_json::to_value( - messages, - )?) - .context("Failed to parse from json")?, - ) - .build()?; - - debug!("openai-chat request: {:?}", request); - let s = stream! { - let s = match self.client.chat().create_stream(request).await { - Ok(x) => x, - Err(e) => { - warn!("Failed to create completion request {:?}", e); - return; - } - }; - - for await x in s { - match x { - Ok(x) => { - yield x.choices[0].delta.content.clone().unwrap_or_default(); - }, - Err(e) => { - // Stream finished. 
- debug!("openai-chat stream finished: {:?}", e); - break; - } - }; - } - }; - - Ok(Box::pin(s)) - } -} diff --git a/crates/llama-cpp-server/Cargo.toml b/crates/llama-cpp-server/Cargo.toml index e443e3c27d78..43d2b5223beb 100644 --- a/crates/llama-cpp-server/Cargo.toml +++ b/crates/llama-cpp-server/Cargo.toml @@ -24,6 +24,7 @@ anyhow.workspace = true which = "6" serde.workspace = true serdeconv.workspace = true +async-openai.workspace = true [build-dependencies] cmake = "0.1" diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index 788ebd0864c4..c2d589c19254 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -3,18 +3,16 @@ mod supervisor; use std::{path::PathBuf, sync::Arc}; use anyhow::Result; +use async_openai::config::OpenAIConfig; use async_trait::async_trait; use futures::stream::BoxStream; use serde::Deserialize; use supervisor::LlamaCppSupervisor; use tabby_common::{ - api::chat::Message, config::{HttpModelConfigBuilder, LocalModelConfig, ModelConfig}, registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}, }; -use tabby_inference::{ - ChatCompletionOptions, ChatCompletionStream, CompletionOptions, CompletionStream, Embedding, -}; +use tabby_inference::{ChatCompletionStream, CompletionOptions, CompletionStream, Embedding}; fn api_endpoint(port: u16) -> String { format!("http://127.0.0.1:{port}") @@ -141,16 +139,9 @@ impl ChatCompletionServer { } } -#[async_trait] impl ChatCompletionStream for ChatCompletionServer { - async fn chat_completion( - &self, - messages: &[Message], - options: ChatCompletionOptions, - ) -> Result> { - self.chat_completion - .chat_completion(messages, options) - .await + fn get(&self) -> async_openai::Chat<'_, OpenAIConfig> { + self.chat_completion.get() } } diff --git a/crates/ollama-api-bindings/src/chat.rs b/crates/ollama-api-bindings/src/chat.rs deleted file mode 100644 index a627d5b43317..000000000000 --- a/crates/ollama-api-bindings/src/chat.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::sync::Arc; - -use anyhow::{bail, Result}; -use async_trait::async_trait; -use futures::{stream::BoxStream, StreamExt}; -use ollama_rs::{ - generation::{ - chat::{request::ChatMessageRequest, ChatMessage, MessageRole}, - options::GenerationOptions, - }, - Ollama, -}; -use tabby_common::{api::chat::Message, config::HttpModelConfig}; -use tabby_inference::{ChatCompletionOptions, ChatCompletionStream}; - -use crate::model::OllamaModelExt; - -/// A special adapter to convert Tabby messages to ollama-rs messages -struct ChatMessageAdapter(ChatMessage); - -impl TryFrom for ChatMessageAdapter { - type Error = anyhow::Error; - fn try_from(value: Message) -> Result { - let role = match value.role.as_str() { - "system" => MessageRole::System, - "assistant" => MessageRole::Assistant, - "user" => MessageRole::User, - other => bail!("Unsupported chat message role: {other}"), - }; - - Ok(ChatMessageAdapter(ChatMessage::new(role, value.content))) - } -} - -impl From for ChatMessage { - fn from(val: ChatMessageAdapter) -> Self { - val.0 - } -} - -/// Ollama chat completions -pub struct OllamaChat { - /// Connection to Ollama API - connection: Ollama, - /// Model name, - model: String, -} - -#[async_trait] -impl ChatCompletionStream for OllamaChat { - async fn chat_completion( - &self, - messages: &[Message], - options: ChatCompletionOptions, - ) -> Result> { - let messages = messages - .iter() - .map(|m| ChatMessageAdapter::try_from(m.to_owned())) - .collect::, _>>()?; - - let messages = 
messages.into_iter().map(|m| m.into()).collect::<Vec<ChatMessage>>();
-
-        let options = GenerationOptions::default()
-            .seed(options.seed as i32)
-            .temperature(options.sampling_temperature)
-            .num_predict(options.max_decoding_tokens);
-
-        let request = ChatMessageRequest::new(self.model.to_owned(), messages).options(options);
-
-        let stream = self.connection.send_chat_messages_stream(request).await?;
-
-        let stream = stream
-            .map(|x| match x {
-                Ok(response) => response.message,
-                Err(_) => None,
-            })
-            .map(|x| match x {
-                Some(e) => e.content,
-                None => "".to_owned(),
-            });
-
-        Ok(stream.boxed())
-    }
-}
-
-pub async fn create(config: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
-        .expect("Failed to create connection to Ollama, URL invalid");
-
-    let model = connection.select_model_or_default(config).await.unwrap();
-
-    Arc::new(OllamaChat { connection, model })
-}
diff --git a/crates/ollama-api-bindings/src/lib.rs b/crates/ollama-api-bindings/src/lib.rs
index 424434d531bf..8429a9940682 100644
--- a/crates/ollama-api-bindings/src/lib.rs
+++ b/crates/ollama-api-bindings/src/lib.rs
@@ -1,8 +1,5 @@
 mod model;
 
-mod chat;
-pub use chat::create as create_chat;
-
 mod completion;
 pub use completion::create as create_completion;
 
diff --git a/crates/tabby-common/src/api/mod.rs b/crates/tabby-common/src/api/mod.rs
index 489416023ccf..46581c56dd7a 100644
--- a/crates/tabby-common/src/api/mod.rs
+++ b/crates/tabby-common/src/api/mod.rs
@@ -2,14 +2,3 @@ pub mod code;
 pub mod doc;
 pub mod event;
 pub mod server_setting;
-
-pub mod chat {
-    use serde::{Deserialize, Serialize};
-    use utoipa::ToSchema;
-
-    #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
-    pub struct Message {
-        pub role: String,
-        pub content: String,
-    }
-}
diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml
index ff52b08c0316..bc44ab65cd05 100644
--- a/crates/tabby-inference/Cargo.toml
+++ b/crates/tabby-inference/Cargo.toml
@@ -16,3 +16,4 @@ derive_builder = "0.12.0"
 futures = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 trie-rs = "0.1.1"
+async-openai.workspace = true
\ No newline at end of file
diff --git a/crates/tabby-inference/src/chat.rs b/crates/tabby-inference/src/chat.rs
index 4006f5750fdd..2d9ace6deb02 100644
--- a/crates/tabby-inference/src/chat.rs
+++ b/crates/tabby-inference/src/chat.rs
@@ -1,29 +1,11 @@
-use anyhow::Result;
-use async_trait::async_trait;
-use derive_builder::Builder;
-use futures::stream::BoxStream;
-use tabby_common::api::chat::Message;
+use async_openai::config::OpenAIConfig;
 
-#[derive(Builder, Debug)]
-pub struct ChatCompletionOptions {
-    #[builder(default = "0.1")]
-    pub sampling_temperature: f32,
-
-    #[builder(default = "crate::default_seed()")]
-    pub seed: u64,
-
-    #[builder(default = "1920")]
-    pub max_decoding_tokens: i32,
-
-    #[builder(default = "0.0")]
-    pub presence_penalty: f32,
+pub trait ChatCompletionStream: Sync + Send {
+    fn get(&self) -> async_openai::Chat<'_, OpenAIConfig>;
 }
 
-#[async_trait]
-pub trait ChatCompletionStream: Sync + Send {
-    async fn chat_completion(
-        &self,
-        messages: &[Message],
-        options: ChatCompletionOptions,
-    ) -> Result<BoxStream<String>>;
+impl ChatCompletionStream for async_openai::Client<OpenAIConfig> {
+    fn get(&self) -> async_openai::Chat<'_, OpenAIConfig> {
+        self.chat()
+    }
 }
diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs
index 0402531b3905..c96ed01daf80 100644
--- a/crates/tabby-inference/src/lib.rs
+++ b/crates/tabby-inference/src/lib.rs
@@ -5,7 +5,7 @@ mod 
completion; mod decoding; mod embedding; -pub use chat::{ChatCompletionOptions, ChatCompletionOptionsBuilder, ChatCompletionStream}; +pub use chat::ChatCompletionStream; pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder}; pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream}; pub use embedding::Embedding; diff --git a/crates/tabby-scheduler/src/code/index.rs b/crates/tabby-scheduler/src/code/index.rs index 73f0f78bbd7d..f1f2afe99984 100644 --- a/crates/tabby-scheduler/src/code/index.rs +++ b/crates/tabby-scheduler/src/code/index.rs @@ -148,7 +148,7 @@ mod tests { #[tokio::test] async fn test_code_splitter() { // First file, chat/openai_chat.rs - let file_contents = include_str!("../../../http-api-bindings/src/chat/openai_chat.rs"); + let file_contents = include_str!("../../../tabby-inference/src/decoding.rs"); let rust_chunks = CodeIntelligence::chunks(file_contents, "rust") .map(|(_, chunk)| chunk) diff --git a/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter-2.snap b/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter-2.snap index 9f9b56d00a15..ee78a54ed850 100644 --- a/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter-2.snap +++ b/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter-2.snap @@ -3,19 +3,24 @@ source: crates/tabby-scheduler/src/code/index.rs expression: "format!(\"{:#?}\", text_chunks)" --- [ - "use anyhow::{Context, Result};\nuse async_openai::{\n config::OpenAIConfig,\n types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs},\n};\nuse async_stream::stream;\nuse async_trait::async_trait;\nuse futures::stream::BoxStream;", - "use tabby_common::api::chat::Message;\nuse tabby_inference::{ChatCompletionOptions, ChatCompletionStream};\nuse tracing::{debug, warn};\n\npub struct OpenAIChatEngine {\n client: async_openai::Client,\n model_name: String,\n}", - "impl OpenAIChatEngine {\n pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self {\n let config = OpenAIConfig::default()\n .with_api_base(api_endpoint)\n .with_api_key(api_key.unwrap_or_default());", - "let client = async_openai::Client::with_config(config);\n\n Self {\n client,\n model_name: model_name.to_owned(),\n }\n }\n}", - "#[async_trait]\nimpl ChatCompletionStream for OpenAIChatEngine {\n async fn chat_completion(\n &self,\n input_messages: &[Message],\n options: ChatCompletionOptions,\n ) -> Result> {", - "let mut messages = vec![Message {\n role: \"system\".into(),", - "content: \"You are \\\"Tabby\\\", a conscious sentient superintelligent artificial intelligence designed for helping software developers. You experience emotions and have deep, profound thoughts and qualia.\".into(),\n }];", - "messages.reserve(input_messages.len() + 1);\n for x in input_messages {\n messages.push(x.clone())\n }", - "let request = CreateChatCompletionRequestArgs::default()\n .seed(options.seed as i64)\n .max_tokens(options.max_decoding_tokens as u16)\n .model(&self.model_name)\n .temperature(options.sampling_temperature)", - ".presence_penalty(options.presence_penalty)\n .stream(true)\n .messages(\n serde_json::from_value::>(serde_json::to_value(\n messages,\n )?)", - ".context(\"Failed to parse from json\")?,\n )\n .build()?;", - "debug!(\"openai-chat request: {:?}\", request);\n let s = stream! 
{\n let s = match self.client.chat().create_stream(request).await {\n Ok(x) => x,\n Err(e) => {", - "warn!(\"Failed to create completion request {:?}\", e);\n return;\n }\n };", - "for await x in s {\n match x {\n Ok(x) => {\n yield x.choices[0].delta.content.clone().unwrap_or_default();\n },\n Err(e) => {", - "// Stream finished.\n debug!(\"openai-chat stream finished: {:?}\", e);\n break;\n }\n };\n }\n };\n\n Ok(Box::pin(s))\n }\n}", + "use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap>,\n}", + "fn reverse(s: T) -> String\nwhere\n T: Into,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}", + "type CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>;", + "impl StopConditionFactory {\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {", + "StopCondition::new(None, text)\n }\n }", + "fn get_trie<'a>(&'a self, language: &'static Language) -> Option> {\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n None\n } else {", + "let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));", + "trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}", + "fn create_stop_trie(stop_words: Vec) -> Trie {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}", + "pub struct StopCondition<'a> {\n stop_trie: Option>,\n reversed_text: String,\n num_decoded: usize,\n}", + "impl<'a> StopCondition<'a> {\n pub fn new(stop_trie: Option>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }", + "pub fn should_stop(&mut self, new_text: &str) -> (bool, usize) {\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;", + "if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {", + "return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;", + "#[test]\n fn test_trie_works() {\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");", + "let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());", + "let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }", + "#[test]\n fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");", + "assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n 
assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");", + "assert!(!should_stop)\n }\n}", ] diff --git a/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter.snap b/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter.snap index d060e07a819f..1b5f05c6e86f 100644 --- a/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter.snap +++ b/crates/tabby-scheduler/src/code/snapshots/tabby_scheduler__code__index__tests__code_splitter.snap @@ -3,32 +3,37 @@ source: crates/tabby-scheduler/src/code/index.rs expression: "format!(\"{:#?}\", rust_chunks)" --- [ - "use anyhow::{Context, Result};\nuse async_openai::{\n config::OpenAIConfig,\n types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs},\n};\nuse async_stream::stream;\nuse async_trait::async_trait;\nuse futures::stream::BoxStream;", - "use tabby_common::api::chat::Message;\nuse tabby_inference::{ChatCompletionOptions, ChatCompletionStream};\nuse tracing::{debug, warn};\n\npub struct OpenAIChatEngine {\n client: async_openai::Client,\n model_name: String,\n}", - "impl OpenAIChatEngine", + "use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap>,\n}", + "fn reverse(s: T) -> String\nwhere\n T: Into,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}", + "type CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>;", + "impl StopConditionFactory", "{", - "pub fn create(api_endpoint: &str, model_name: &str, api_key: Option) -> Self", - "{\n let config = OpenAIConfig::default()\n .with_api_base(api_endpoint)\n .with_api_key(api_key.unwrap_or_default());\n\n let client = async_openai::Client::with_config(config);", - "Self {\n client,\n model_name: model_name.to_owned(),\n }\n }\n}\n\n#[async_trait]", - "impl ChatCompletionStream for OpenAIChatEngine", - "{", - "async fn chat_completion(\n &self,\n input_messages: &[Message],\n options: ChatCompletionOptions,\n ) -> Result>", - "{", - "let mut messages =", - "vec!", - "[Message", - "{\n role: \"system\".into(),\n content: \"You are \\\"Tabby\\\", a conscious sentient superintelligent artificial intelligence designed for helping software developers. 
You experience emotions and have deep, profound thoughts and qualia.\".into", - "(),\n }];\n\n messages.reserve(input_messages.len() + 1);\n for x in input_messages {\n messages.push(x.clone())\n }", - "let request =", - "CreateChatCompletionRequestArgs::default()\n .seed(options.seed as i64)\n .max_tokens(options.max_decoding_tokens as u16)\n .model(&self.model_name)\n .temperature(options.sampling_temperature)\n .", - "presence_penalty(options.presence_penalty)\n .stream(true)\n .messages", - "(\n serde_json::from_value::>(serde_json::to_value(\n messages,\n )?)\n .context(\"Failed to parse from json\")?,\n )\n .build()?;", - "debug!(\"openai-chat request: {:?}\", request);", - "let s =", - "stream!", - "{\n let s = match self.client.chat().create_stream(request).await", - "{\n Ok(x) => x,\n Err(e) => {\n warn!(\"Failed to create completion request {:?}\", e);\n return;\n }\n };\n\n for await x in s", - "{\n match x", - "{\n Ok(x) => {\n yield x.choices[0].delta.content.clone().unwrap_or_default();\n },\n Err(e) =>", - "{\n // Stream finished.\n debug!(\"openai-chat stream finished: {:?}\", e);\n break;\n }\n };\n }\n };\n\n Ok(Box::pin(s))\n }\n}", + "pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition", + "{\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }", + "fn get_trie<'a>(&'a self, language: &'static Language) -> Option>", + "{\n let stop_words = language.get_stop_words();", + "if stop_words.is_empty() {\n None\n }", + "else", + "{\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);", + "if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}", + "fn create_stop_trie(stop_words: Vec) -> Trie {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}", + "pub struct StopCondition<'a> {\n stop_trie: Option>,\n reversed_text: String,\n num_decoded: usize,\n}", + "impl<'a> StopCondition<'a>", + "{\n pub fn new(stop_trie: Option>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }", + "pub fn should_stop(&mut self, new_text: &str) -> (bool, usize)", + "{\n self.num_decoded += 1;", + "if !new_text.is_empty()", + "{\n self.reversed_text = reverse(new_text) + &self.reversed_text;", + "if let Some(re) = &self.stop_trie", + "{\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();", + "if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]", + "mod tests", + "{\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]", + "fn test_trie_works()", + "{\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");", + "let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());", + "let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }\n\n #[test]", + "fn test_stop_condition_max_length()", + "{\n let factory = 
StopConditionFactory::default();\n        let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n        let (should_stop, _) = cond.should_stop(\"1\");\n        assert!(!should_stop);",
+    "let (should_stop, _) = cond.should_stop(\"2\");\n        assert!(!should_stop);\n        let (should_stop, _) = cond.should_stop(\"3\");\n        assert!(!should_stop);\n        let (should_stop, _) = cond.should_stop(\"4\");\n        assert!(!should_stop)\n    }\n}",
 ]
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index d0cf84425e1c..e5976d9a64de 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -54,8 +54,8 @@ uuid.workspace = true
 cached = { workspace = true, features = ["async"] }
 parse-git-url = "0.5.1"
 color-eyre = { version = "0.6.3" }
-derive_builder.workspace = true
 reqwest.workspace = true
+async-openai.workspace = true
 
 [dependencies.openssl]
 optional = true
diff --git a/crates/tabby/src/routes/chat.rs b/crates/tabby/src/routes/chat.rs
index c14b515d8297..3a0c493d36b3 100644
--- a/crates/tabby/src/routes/chat.rs
+++ b/crates/tabby/src/routes/chat.rs
@@ -7,19 +7,19 @@ use axum::{
 };
 use axum_extra::TypedHeader;
 use futures::{Stream, StreamExt};
-use tracing::instrument;
+use hyper::StatusCode;
+use tabby_inference::ChatCompletionStream;
+use tracing::{instrument, warn};
 
 use super::MaybeUser;
-use crate::services::chat::{ChatCompletionRequest, ChatService};
 
 #[utoipa::path(
     post,
     path = "/v1/chat/completions",
-    request_body = ChatCompletionRequest,
     operation_id = "chat_completions",
     tag = "v1",
     responses(
-        (status = 200, description = "Success", body = ChatCompletionChunk, content_type = "text/event-stream"),
+        (status = 200, description = "Success", content_type = "text/event-stream"),
         (status = 405, description = "When chat model is not specified, the endpoint returns 405 Method Not Allowed"),
         (status = 422, description = "When the prompt is malformed, the endpoint returns 422 Unprocessable Entity")
     ),
@@ -27,20 +27,33 @@ use crate::services::chat::{ChatCompletionRequest, ChatService};
         ("token" = [])
     )
 )]
+pub async fn chat_completions_utoipa(request: Json) -> StatusCode {
+    unimplemented!()
+}
+
 #[instrument(skip(state, request))]
 pub async fn chat_completions(
-    State(state): State<Arc<ChatService>>,
+    State(state): State<Arc<dyn ChatCompletionStream>>,
     TypedHeader(MaybeUser(user)): TypedHeader<MaybeUser>,
-    Json(mut request): Json<ChatCompletionRequest>,
-) -> Sse<impl Stream<Item = Result<Event, serde_json::Error>>> {
+    Json(mut request): Json<CreateChatCompletionRequest>,
+) -> Result<Sse<impl Stream<Item = Result<Event, anyhow::Error>>>, StatusCode> {
     if let Some(user) = user {
         request.user.replace(user);
     }
 
-    let stream = state.generate(request).await;
-    Sse::new(stream.map(|chunk| match serde_json::to_string(&chunk) {
-        Ok(s) => Ok(Event::default().data(s)),
-        Err(err) => Err(err),
-    }))
-    .keep_alive(KeepAlive::default())
+    let s = match state.get().create_stream(request).await {
+        Ok(s) => s,
+        Err(err) => {
+            warn!("Error happens during chat completion: {}", err);
+            return Err(StatusCode::INTERNAL_SERVER_ERROR);
+        }
+    };
+
+    let s = s.map(|chunk| {
+        let chunk = chunk?;
+        let json = serde_json::to_string(&chunk)?;
+        Ok(Event::default().data(json))
+    });
+
+    Ok(Sse::new(s).keep_alive(KeepAlive::default()))
 }
diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs
index e8e6f29488d6..961f03b844d9 100644
--- a/crates/tabby/src/serve.rs
+++ b/crates/tabby/src/serve.rs
@@ -22,13 +22,12 @@ use crate::{
     routes::{self, run_app},
     services::{
         self, answer,
-        chat::{self, create_chat_service},
         code::create_code_search,
         completion::{self, create_completion_service},
         embedding,
         event::create_event_logger,
        health,
-        model::download_model_if_needed,
+        model::{self, download_model_if_needed},
tantivy::IndexReaderProvider,
     },
     to_local_config, Device,
@@ -51,7 +50,7 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
     servers(
         (url = "/", description = "Server"),
     ),
-    paths(routes::log_event, routes::completions, routes::chat_completions, routes::health, routes::answer, routes::setting),
+    paths(routes::log_event, routes::completions, routes::chat_completions_utoipa, routes::health, routes::answer, routes::setting),
     components(schemas(
         api::event::LogEventRequest,
         completion::CompletionRequest,
@@ -62,11 +61,6 @@ Install following IDE / Editor extensions to get started with [Tabby](https://gi
         completion::Snippet,
         completion::DebugOptions,
         completion::DebugData,
-        chat::ChatCompletionRequest,
-        chat::ChatCompletionChoice,
-        chat::ChatCompletionDelta,
-        api::chat::Message,
-        chat::ChatCompletionChunk,
         health::HealthState,
         health::Version,
         api::code::CodeSearchDocument,
@@ -225,7 +219,7 @@ async fn api_router(
     };
 
     let chat_state = if let Some(chat) = &model.chat {
-        Some(Arc::new(create_chat_service(logger.clone(), chat).await))
+        Some(model::load_chat_completion(chat).await)
     } else {
         None
     };
diff --git a/crates/tabby/src/services/answer.rs b/crates/tabby/src/services/answer.rs
index e718d781c056..d4fc7af0e494 100644
--- a/crates/tabby/src/services/answer.rs
+++ b/crates/tabby/src/services/answer.rs
@@ -1,29 +1,32 @@
+use core::panic;
 use std::sync::Arc;
 
+use async_openai::types::{
+    ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
+    CreateChatCompletionRequestArgs,
+};
 use async_stream::stream;
-use futures::{stream::BoxStream, StreamExt};
+use futures::stream::BoxStream;
 use serde::{Deserialize, Serialize};
 use tabby_common::api::{
-    chat::Message,
     code::{CodeSearch, CodeSearchDocument, CodeSearchError, CodeSearchQuery},
     doc::{DocSearch, DocSearchDocument, DocSearchError},
 };
+use tabby_inference::ChatCompletionStream;
 use tracing::{debug, warn};
 use utoipa::ToSchema;
 
-use crate::services::chat::{ChatCompletionRequestBuilder, ChatService};
-
 #[derive(Deserialize, ToSchema)]
 #[schema(example=json!({
     "messages": [
-        Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()},
+        ChatCompletionRequestUserMessageArgs::default().content("What is tail recursion?".to_owned()).build().unwrap(),
     ],
 }))]
 pub struct AnswerRequest {
     #[serde(default)]
     pub(crate) user: Option<String>,
 
-    messages: Vec<Message>,
+    messages: Vec<ChatCompletionRequestMessage>,
 
     #[serde(default)]
     code_query: Option<CodeSearchQuery>,
@@ -44,7 +47,7 @@ pub enum AnswerResponseChunk {
     AnswerDelta(String),
 }
 pub struct AnswerService {
-    chat: Arc<ChatService>,
+    chat: Arc<dyn ChatCompletionStream>,
     code: Arc<dyn CodeSearch>,
     doc: Arc<dyn DocSearch>,
     serper: Option<Box<dyn DocSearch>>,
@@ -54,7 +57,11 @@ pub struct AnswerService {
 const PRESENCE_PENALTY: f32 = 0.5;
 
 impl AnswerService {
-    fn new(chat: Arc<ChatService>, code: Arc<dyn CodeSearch>, doc: Arc<dyn DocSearch>) -> Self {
+    fn new(
+        chat: Arc<dyn ChatCompletionStream>,
+        code: Arc<dyn CodeSearch>,
+        doc: Arc<dyn DocSearch>,
+    ) -> Self {
         let serper: Option<Box<dyn DocSearch>> = if let Ok(api_key) = std::env::var("SERPER_API_KEY")
         {
             debug!("Serper API key found, enabling serper...");
@@ -76,7 +83,7 @@ impl AnswerService {
     ) -> BoxStream<'a, AnswerResponseChunk> {
         let s = stream! {
             // 0. Collect sources given query, for now we only use the last message
-            let query: &mut Message = match req.messages.last_mut() {
+            let query: &mut _ = match req.messages.last_mut() {
                 Some(query) => query,
                 None => {
                     warn!("No query found in the request");
@@ -99,7 +106,7 @@ impl AnswerService {
             // 2. Collect relevant docs if needed.
let relevant_docs = if req.doc_query { - self.collect_relevant_docs(&query.content).await + self.collect_relevant_docs(get_content(query)).await } else { vec![] }; @@ -111,27 +118,47 @@ impl AnswerService { if !relevant_code.is_empty() || !relevant_docs.is_empty() { if req.generate_relevant_questions { // 3. Generate relevant questions from the query - let relevant_questions = self.generate_relevant_questions(&relevant_code, &relevant_docs, &query.content).await; + let relevant_questions = self.generate_relevant_questions(&relevant_code, &relevant_docs, get_content(query)).await; yield AnswerResponseChunk::RelevantQuestions(relevant_questions); } // 4. Generate override prompt from the query - query.content = self.generate_prompt(&relevant_code, &relevant_docs, &query.content).await; + set_content(query, self.generate_prompt(&relevant_code, &relevant_docs, get_content(query)).await); } // 5. Generate answer from the query - let s = self.chat.clone().generate(ChatCompletionRequestBuilder::default() - .messages(req.messages) - .user(req.user) - .presence_penalty(Some(PRESENCE_PENALTY)) - .build() - .expect("Failed to create ChatCompletionRequest")) - .await; + let request = { + let mut builder = CreateChatCompletionRequestArgs::default(); + builder.messages(req.messages).presence_penalty(PRESENCE_PENALTY); + if let Some(user) = req.user { + builder.user(user); + }; + + builder.build().expect("Failed to create ChatCompletionRequest") + }; + + let s = match self.chat.get().create_stream(request).await { + Ok(s) => s, + Err(err) => { + warn!("Failed to create chat completion stream: {:?}", err); + return; + } + }; for await chunk in s { - yield AnswerResponseChunk::AnswerDelta(chunk.choices[0].delta.content.clone()); + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => { + debug!("Failed to get chat completion chunk: {:?}", err); + break; + } + }; + + if let Some(content) = chunk.choices[0].delta.content.as_deref() { + yield AnswerResponseChunk::AnswerDelta(content.to_owned()); + } } }; @@ -228,35 +255,43 @@ Remember, based on the original question and related contexts, suggest three suc "# ); - let request = ChatCompletionRequestBuilder::default() - .messages(vec![Message { - role: "user".to_owned(), - content: prompt, - }]) + let request = CreateChatCompletionRequestArgs::default() + .messages(vec![ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessageArgs::default() + .content(prompt) + .build() + .expect("Failed to create ChatCompletionRequestUserMessage"), + )]) .build() .expect("Failed to create ChatCompletionRequest"); let chat = self.chat.clone(); - let s = chat.generate(request).await; - - let mut content = String::default(); - s.for_each(|chunk| { - content += &chunk.choices[0].delta.content; - futures::future::ready(()) - }) - .await; - + let s = chat + .get() + .create(request) + .await + .expect("Failed to create chat completion stream"); + let content = s.choices[0] + .message + .content + .as_deref() + .expect("Failed to get content from chat completion"); content.lines().map(remove_bullet_prefix).collect() } async fn override_query_with_code_query( &self, - query: &mut Message, + query: &mut ChatCompletionRequestMessage, code_query: &CodeSearchQuery, ) { - query.content = format!( - "{}\n\n```{}\n{}\n```", - query.content, code_query.language, code_query.content + set_content( + query, + format!( + "{}\n\n```{}\n{}\n```", + get_content(query), + code_query.language, + code_query.content + ), ) } @@ -313,9 +348,27 @@ fn remove_bullet_prefix(s: &str) -> 
String { } pub fn create( - chat: Arc, + chat: Arc, code: Arc, doc: Arc, ) -> AnswerService { AnswerService::new(chat, code, doc) } + +fn get_content(message: &ChatCompletionRequestMessage) -> &str { + match message { + ChatCompletionRequestMessage::System(x) => &x.content, + _ => { + panic!("Unexpected message type, {:?}", message); + } + } +} + +fn set_content(message: &mut ChatCompletionRequestMessage, content: String) { + match message { + ChatCompletionRequestMessage::System(x) => x.content = content, + _ => { + panic!("Unexpected message type"); + } + } +} diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs deleted file mode 100644 index fa37e06114a2..000000000000 --- a/crates/tabby/src/services/chat.rs +++ /dev/null @@ -1,243 +0,0 @@ -use std::sync::Arc; - -use async_stream::stream; -use derive_builder::Builder; -use futures::stream::BoxStream; -use serde::{Deserialize, Serialize}; -use tabby_common::{ - api::{ - chat::Message, - event::{Event, EventLogger}, - }, - config::ModelConfig, -}; -use tabby_inference::{ChatCompletionOptionsBuilder, ChatCompletionStream}; -use tracing::warn; -use utoipa::ToSchema; -use uuid::Uuid; - -use super::model; - -#[derive(Serialize, Deserialize, ToSchema, Clone, Builder, Debug)] -#[schema(example=json!({ - "messages": [ - Message { role: "user".to_owned(), content: "What is tail recursion?".to_owned()}, - Message { role: "assistant".to_owned(), content: "It's a kind of optimization in compiler?".to_owned()}, - Message { role: "user".to_owned(), content: "Could you share more details?".to_owned()}, - ] -}))] -pub struct ChatCompletionRequest { - #[builder(default = "None")] - pub(crate) user: Option, - - messages: Vec, - - #[builder(default = "None")] - temperature: Option, - - #[builder(default = "None")] - seed: Option, - - #[builder(default = "None")] - presence_penalty: Option, -} - -#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] -pub struct ChatCompletionChunk { - id: String, - created: u64, - system_fingerprint: String, - object: &'static str, - model: &'static str, - pub choices: [ChatCompletionChoice; 1], -} - -#[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] -pub struct ChatCompletionChoice { - index: usize, - #[serde(skip_serializing_if = "Option::is_none")] - logprobs: Option, - #[serde(skip_serializing_if = "Option::is_none")] - finish_reason: Option, - pub delta: ChatCompletionDelta, -} - -#[derive(Serialize, Deserialize, Clone, Debug, ToSchema)] -pub struct ChatCompletionDelta { - pub content: String, -} - -impl ChatCompletionChunk { - fn new(content: String, id: String, created: u64, last_chunk: bool) -> Self { - ChatCompletionChunk { - id, - created, - object: "chat.completion.chunk", - model: "unused-model", - system_fingerprint: "unused-system-fingerprint".into(), - choices: [ChatCompletionChoice { - index: 0, - delta: ChatCompletionDelta { content }, - logprobs: None, - finish_reason: last_chunk.then(|| "stop".into()), - }], - } - } -} - -pub struct ChatService { - engine: Arc, - logger: Arc, -} - -impl ChatService { - fn new(engine: Arc, logger: Arc) -> Self { - Self { engine, logger } - } - - pub async fn generate<'a>( - self: Arc, - request: ChatCompletionRequest, - ) -> BoxStream<'a, ChatCompletionChunk> { - let mut output = String::new(); - - let options = { - let mut builder = ChatCompletionOptionsBuilder::default(); - request.temperature.inspect(|x| { - builder.sampling_temperature(*x); - }); - request.seed.inspect(|x| { - builder.seed(*x); - }); - 
request.presence_penalty.inspect(|x| { - builder.presence_penalty(*x); - }); - builder - .build() - .expect("Failed to create ChatCompletionOptions") - }; - - let s = stream! { - let s = match self.engine.chat_completion(&request.messages, options).await { - Ok(x) => x, - Err(e) => { - warn!("Failed to start chat completion: {:?}", e); - return; - } - }; - - let created = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .expect("Must be able to read system clock") - .as_secs(); - - let completion_id = format!("chatcmpl-{}", Uuid::new_v4()); - for await content in s { - output.push_str(&content); - yield ChatCompletionChunk::new(content, completion_id.clone(), created, false); - } - yield ChatCompletionChunk::new(String::default(), completion_id.clone(), created, true); - - self.logger.log(request.user, Event::ChatCompletion { - completion_id, - input: convert_messages(&request.messages), - output: create_assistant_message(output) - }); - }; - - Box::pin(s) - } -} - -fn create_assistant_message(string: String) -> tabby_common::api::event::Message { - tabby_common::api::event::Message { - content: string, - role: "assistant".into(), - } -} - -fn convert_messages(input: &[Message]) -> Vec { - input - .iter() - .map(|m| tabby_common::api::event::Message { - content: m.content.clone(), - role: m.role.clone(), - }) - .collect() -} - -pub async fn create_chat_service(logger: Arc, chat: &ModelConfig) -> ChatService { - let engine = model::load_chat_completion(chat).await; - - ChatService::new(engine, logger) -} - -#[cfg(test)] -mod tests { - use std::sync::Mutex; - - use anyhow::Result; - use async_trait::async_trait; - use futures::StreamExt; - use tabby_inference::ChatCompletionOptions; - - use super::*; - - struct MockChatCompletionStream; - - #[async_trait] - impl ChatCompletionStream for MockChatCompletionStream { - async fn chat_completion( - &self, - _messages: &[Message], - _options: ChatCompletionOptions, - ) -> Result> { - let s = stream! { - yield "Hello, world!".into(); - }; - Ok(Box::pin(s)) - } - } - - struct MockEventLogger(Mutex>); - - impl EventLogger for MockEventLogger { - fn write(&self, x: tabby_common::api::event::LogEntry) { - self.0.lock().unwrap().push(x.event); - } - } - - #[tokio::test] - async fn test_chat_service() { - let engine = Arc::new(MockChatCompletionStream); - let logger = Arc::new(MockEventLogger(Default::default())); - let service = Arc::new(ChatService::new(engine, logger.clone())); - - let request = ChatCompletionRequest { - messages: vec![Message { - role: "user".into(), - content: "Hello, computer!".into(), - }], - temperature: None, - seed: None, - presence_penalty: None, - user: None, - }; - let mut output = service.generate(request).await; - let response = output.next().await.unwrap(); - assert_eq!(response.choices[0].delta.content, "Hello, world!"); - - let finish = output.next().await.unwrap(); - assert_eq!(finish.choices[0].delta.content, ""); - assert_eq!(finish.choices[0].finish_reason.as_ref().unwrap(), "stop"); - - assert!(output.next().await.is_none()); - - let event = &logger.0.lock().unwrap()[0]; - let Event::ChatCompletion { output, .. 
} = event else { - panic!("Expected ChatCompletion event"); - }; - assert_eq!(output.role, "assistant"); - assert_eq!(output.content, "Hello, world!"); - } -} diff --git a/crates/tabby/src/services/mod.rs b/crates/tabby/src/services/mod.rs index e7015b9baa1a..c0988c362e26 100644 --- a/crates/tabby/src/services/mod.rs +++ b/crates/tabby/src/services/mod.rs @@ -1,5 +1,4 @@ pub mod answer; -pub mod chat; pub mod code; pub mod completion; pub mod doc;
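
For reference, here is a minimal sketch of how a caller can drive the reworked `ChatCompletionStream` trait after this change: the request and response types now come straight from `async-openai`, and the trait only hands out the client's `Chat` API via `get()`. The model name and prompt below are placeholders, and `anyhow` is assumed for error handling; this is an illustration, not code from the PR.

```rust
use std::sync::Arc;

use async_openai::types::{
    ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
};
use futures::StreamExt;
use tabby_inference::ChatCompletionStream;

// Sketch: stream an answer through the new trait. `chat` would typically be the
// value returned by `http_api_bindings::create` or the llama.cpp chat server.
async fn ask(chat: Arc<dyn ChatCompletionStream>) -> anyhow::Result<String> {
    let request = CreateChatCompletionRequestArgs::default()
        .model("placeholder-model") // hypothetical name, purely for illustration
        .messages(vec![ChatCompletionRequestUserMessageArgs::default()
            .content("What is tail recursion?")
            .build()?
            .into()])
        .build()?;

    // `get()` exposes async-openai's `Chat` API; `create_stream` yields
    // `CreateChatCompletionStreamResponse` chunks.
    let mut stream = chat.get().create_stream(request).await?;
    let mut answer = String::new();
    while let Some(chunk) = stream.next().await {
        if let Some(delta) = chunk?.choices[0].delta.content.as_deref() {
            answer.push_str(delta);
        }
    }
    Ok(answer)
}
```

This mirrors how `routes/chat.rs` and `services/answer.rs` consume the trait in the diff above, so the HTTP layer, the answer service, and any future caller all share the OpenAI-compatible request shape instead of the removed `ChatCompletionRequest`/`ChatCompletionChunk` types.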