Commit ef8542c

feat: add mistral/chat support to talk to mistral api platform through chat api (TabbyML#2568)

* feat: support ExtendedOpenAIConfig

* update

* support mistral/chat use case
wsxiaoys authored Jul 9, 2024
1 parent 70a508b commit ef8542c
Showing 10 changed files with 156 additions and 55 deletions.
45 changes: 8 additions & 37 deletions Cargo.lock

Generated file; diff not rendered.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -39,7 +39,7 @@ anyhow = "1.0.71"
 tantivy = { git = "https://github.com/quickwit-oss/tantivy", rev = "4143d31" }
 async-trait = "0.1.72"
 reqwest = { version = "0.12" }
-derive_builder = "0.12.0"
+derive_builder = "0.20"
 futures = "0.3.30"
 async-stream = "0.3.5"
 regex = "1.10.0"
17 changes: 16 additions & 1 deletion crates/http-api-bindings/src/chat/mod.rs
@@ -2,12 +2,27 @@ use std::sync::Arc;
 
 use async_openai::config::OpenAIConfig;
 use tabby_common::config::HttpModelConfig;
-use tabby_inference::ChatCompletionStream;
+use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
     let config = OpenAIConfig::default()
         .with_api_base(model.api_endpoint.clone())
         .with_api_key(model.api_key.clone().unwrap_or_default());
 
+    let mut builder = ExtendedOpenAIConfig::builder();
+    builder
+        .base(config)
+        .model_name(model.model_name.as_deref().expect("Model name is required"));
+
+    if model.kind == "openai/chat" {
+        // Do nothing
+    } else if model.kind == "mistral/chat" {
+        builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
+    } else {
+        panic!("Unsupported model kind: {}", model.kind);
+    }
+
+    let config = builder.build().expect("Failed to build config");
+
     Arc::new(async_openai::Client::with_config(config))
 }
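
For context: the new `create` builds an `ExtendedOpenAIConfig` on top of the plain `OpenAIConfig`, then branches on the `kind` string. Below is a minimal standalone sketch of the same wiring; the endpoint URL and model name are hypothetical placeholders, not values taken from this commit.

use async_openai::config::OpenAIConfig;
use tabby_inference::ExtendedOpenAIConfig;

fn mistral_config(api_key: &str) -> ExtendedOpenAIConfig {
    // Base OpenAI-compatible client config pointed at the Mistral platform.
    let base = OpenAIConfig::default()
        .with_api_base("https://api.mistral.ai/v1") // hypothetical endpoint
        .with_api_key(api_key);

    let mut builder = ExtendedOpenAIConfig::builder();
    builder
        .base(base)
        .model_name("mistral-small"); // hypothetical model name

    // Strip the request fields the Mistral API does not accept.
    builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());

    builder.build().expect("Failed to build config")
}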
17 changes: 14 additions & 3 deletions crates/llama-cpp-server/src/lib.rs
@@ -3,7 +3,7 @@ mod supervisor
 use std::{path::PathBuf, sync::Arc};
 
 use anyhow::Result;
-use async_openai::config::OpenAIConfig;
+use async_openai::error::OpenAIError;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
 use serde::Deserialize;
@@ -139,9 +139,20 @@ impl ChatCompletionServer {
     }
 }
 
+#[async_trait]
 impl ChatCompletionStream for ChatCompletionServer {
-    fn get(&self) -> async_openai::Chat<'_, OpenAIConfig> {
-        self.chat_completion.get()
+    async fn chat(
+        &self,
+        request: async_openai::types::CreateChatCompletionRequest,
+    ) -> Result<async_openai::types::CreateChatCompletionResponse, OpenAIError> {
+        self.chat_completion.chat(request).await
+    }
+
+    async fn chat_stream(
+        &self,
+        request: async_openai::types::CreateChatCompletionRequest,
+    ) -> Result<async_openai::types::ChatCompletionResponseStream, OpenAIError> {
+        self.chat_completion.chat_stream(request).await
     }
 }
 
6 changes: 4 additions & 2 deletions crates/tabby-inference/Cargo.toml
@@ -12,8 +12,10 @@ anyhow.workspace = true
 async-stream = { workspace = true }
 async-trait = { workspace = true }
 dashmap = "5.5.3"
-derive_builder = "0.12.0"
+derive_builder.workspace = true
 futures = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 trie-rs = "0.1.1"
-async-openai.workspace = true
+async-openai.workspace = true
+secrecy = "0.8"
+reqwest.workspace = true
113 changes: 108 additions & 5 deletions crates/tabby-inference/src/chat.rs
@@ -1,11 +1,114 @@
-use async_openai::config::OpenAIConfig;
+use async_openai::{
+    config::OpenAIConfig,
+    error::OpenAIError,
+    types::{
+        ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
+    },
+};
+use async_trait::async_trait;
+use derive_builder::Builder;
 
+#[async_trait]
 pub trait ChatCompletionStream: Sync + Send {
-    fn get(&self) -> async_openai::Chat<'_, OpenAIConfig>;
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError>;
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError>;
 }
 
+#[derive(Clone)]
+pub enum OpenAIRequestFieldEnum {
+    PresencePenalty,
+    User,
+}
+
+#[derive(Builder, Clone)]
+pub struct ExtendedOpenAIConfig {
+    base: OpenAIConfig,
+
+    #[builder(setter(into))]
+    model_name: String,
+
+    #[builder(default)]
+    fields_to_remove: Vec<OpenAIRequestFieldEnum>,
+}
+
-impl ChatCompletionStream for async_openai::Client<OpenAIConfig> {
-    fn get(&self) -> async_openai::Chat<'_, OpenAIConfig> {
-        self.chat()
+impl ExtendedOpenAIConfig {
+    pub fn builder() -> ExtendedOpenAIConfigBuilder {
+        ExtendedOpenAIConfigBuilder::default()
     }
+
+    pub fn mistral_fields_to_remove() -> Vec<OpenAIRequestFieldEnum> {
+        vec![
+            OpenAIRequestFieldEnum::PresencePenalty,
+            OpenAIRequestFieldEnum::User,
+        ]
+    }
+
+    fn process_request(
+        &self,
+        mut request: CreateChatCompletionRequest,
+    ) -> CreateChatCompletionRequest {
+        request.model = self.model_name.clone();
+
+        for field in &self.fields_to_remove {
+            match field {
+                OpenAIRequestFieldEnum::PresencePenalty => {
+                    request.presence_penalty = None;
+                }
+                OpenAIRequestFieldEnum::User => {
+                    request.user = None;
+                }
+            }
+        }
+
+        request
+    }
 }
+
+impl async_openai::config::Config for ExtendedOpenAIConfig {
+    fn headers(&self) -> reqwest::header::HeaderMap {
+        self.base.headers()
+    }
+
+    fn url(&self, path: &str) -> String {
+        self.base.url(path)
+    }
+
+    fn query(&self) -> Vec<(&str, &str)> {
+        self.base.query()
+    }
+
+    fn api_base(&self) -> &str {
+        self.base.api_base()
+    }
+
+    fn api_key(&self) -> &secrecy::Secret<String> {
+        self.base.api_key()
+    }
+}
+
+#[async_trait]
+impl ChatCompletionStream for async_openai::Client<ExtendedOpenAIConfig> {
+    async fn chat(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
+        let request = self.config().process_request(request);
+        self.chat().create(request).await
+    }
+
+    async fn chat_stream(
+        &self,
+        request: CreateChatCompletionRequest,
+    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
+        let request = self.config().process_request(request);
+        eprintln!("Creating chat stream: {:?}", request);
+        self.chat().create_stream(request).await
+    }
+}
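
`process_request` is the crux of `ExtendedOpenAIConfig`: every outgoing request is rewritten to carry the configured `model_name`, and each field listed in `fields_to_remove` (for Mistral: `presence_penalty` and `user`) is cleared before `async_openai` serializes it. A minimal sketch of driving the new trait, assuming an already-configured client; the request/message builders are standard `async_openai` API, not part of this diff.

use async_openai::types::{
    ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs,
};
use futures::StreamExt;
use tabby_inference::ChatCompletionStream;

async fn demo(client: &dyn ChatCompletionStream) -> anyhow::Result<()> {
    let request = CreateChatCompletionRequestArgs::default()
        // Whatever model is set here is overwritten by process_request
        // with the configured model_name.
        .model("placeholder")
        .messages(vec![ChatCompletionRequestUserMessageArgs::default()
            .content("Hello!")
            .build()?
            .into()])
        .build()?;

    // Non-streaming: one complete response.
    let response = client.chat(request.clone()).await?;
    println!("{:?}", response.choices[0].message.content);

    // Streaming: print content deltas as they arrive.
    let mut stream = client.chat_stream(request).await?;
    while let Some(chunk) = stream.next().await {
        if let Some(choice) = chunk?.choices.first() {
            if let Some(delta) = &choice.delta.content {
                print!("{delta}");
            }
        }
    }
    Ok(())
}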
2 changes: 1 addition & 1 deletion crates/tabby-inference/src/lib.rs
@@ -5,7 +5,7 @@ mod completion;
 mod decoding;
 mod embedding;
 
-pub use chat::ChatCompletionStream;
+pub use chat::{ChatCompletionStream, ExtendedOpenAIConfig};
 pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
 pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream};
 pub use embedding::Embedding;
2 changes: 1 addition & 1 deletion crates/tabby-scheduler/src/doc/mod.rs
@@ -4,7 +4,7 @@ use std::{collections::HashSet, sync::Arc};
 
 use async_stream::stream;
 use async_trait::async_trait;
-use futures::{stream::BoxStream, StreamExt};
+use futures::stream::BoxStream;
 use public::WebDocument;
 use serde_json::json;
 use tabby_common::index::{self, corpus, doc};
2 changes: 1 addition & 1 deletion crates/tabby/src/routes/chat.rs
@@ -41,7 +41,7 @@ pub async fn chat_completions(
         request.user.replace(user);
     }
 
-    let s = match state.get().create_stream(request).await {
+    let s = match state.chat_stream(request).await {
         Ok(s) => s,
         Err(err) => {
             warn!("Error happens during chat completion: {}", err);
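
With this change the handler depends only on the `ChatCompletionStream` trait and receives a plain `ChatCompletionResponseStream`. As a hedged sketch (this is not Tabby's actual handler code), one way such a stream could be adapted into an axum server-sent-events response:

use async_openai::types::ChatCompletionResponseStream;
use axum::response::sse::{Event, Sse};
use futures::{Stream, StreamExt};

fn to_sse(
    stream: ChatCompletionResponseStream,
) -> Sse<impl Stream<Item = Result<Event, axum::Error>>> {
    Sse::new(stream.map(|chunk| match chunk {
        // Serialize each streamed chunk as one SSE data event.
        Ok(chunk) => Event::default().json_data(chunk),
        Err(err) => Err(axum::Error::new(err)),
    }))
}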
5 changes: 2 additions & 3 deletions crates/tabby/src/services/answer.rs
@@ -139,7 +139,7 @@ impl AnswerService {
             builder.build().expect("Failed to create ChatCompletionRequest")
         };
 
-        let s = match self.chat.get().create_stream(request).await {
+        let s = match self.chat.chat_stream(request).await {
             Ok(s) => s,
             Err(err) => {
                 warn!("Failed to create chat completion stream: {:?}", err);
@@ -267,8 +267,7 @@ Remember, based on the original question and related contexts, suggest three suc…
 
         let chat = self.chat.clone();
         let s = chat
-            .get()
-            .create(request)
+            .chat(request)
             .await
             .expect("Failed to create chat completion stream");
         let content = s.choices[0]
