forked from TabbyML/tabby
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add mistral/chat support to talk to mistral api platform through chat api (TabbyML#2568)
* feat: support ExtendedOpenAIConfig * update * support mistral/chat use case
- Loading branch information
Showing
10 changed files
with
156 additions
and
55 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,114 @@ | ||
use async_openai::config::OpenAIConfig; | ||
use async_openai::{ | ||
config::OpenAIConfig, | ||
error::OpenAIError, | ||
types::{ | ||
ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse, | ||
}, | ||
}; | ||
use async_trait::async_trait; | ||
use derive_builder::Builder; | ||
|
||
#[async_trait] | ||
pub trait ChatCompletionStream: Sync + Send { | ||
fn get(&self) -> async_openai::Chat<'_, OpenAIConfig>; | ||
async fn chat( | ||
&self, | ||
request: CreateChatCompletionRequest, | ||
) -> Result<CreateChatCompletionResponse, OpenAIError>; | ||
|
||
async fn chat_stream( | ||
&self, | ||
request: CreateChatCompletionRequest, | ||
) -> Result<ChatCompletionResponseStream, OpenAIError>; | ||
} | ||
|
||
/// Names an optional field of a chat-completion request.
///
/// Used as a marker for request fields that should be cleared before a
/// request is sent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpenAIRequestFieldEnum {
    /// The request's `presence_penalty` field.
    PresencePenalty,
    /// The request's `user` field.
    User,
}
|
||
#[derive(Builder, Clone)] | ||
pub struct ExtendedOpenAIConfig { | ||
base: OpenAIConfig, | ||
|
||
#[builder(setter(into))] | ||
model_name: String, | ||
|
||
#[builder(default)] | ||
fields_to_remove: Vec<OpenAIRequestFieldEnum>, | ||
} | ||
|
||
impl ChatCompletionStream for async_openai::Client<OpenAIConfig> { | ||
fn get(&self) -> async_openai::Chat<'_, OpenAIConfig> { | ||
self.chat() | ||
impl ExtendedOpenAIConfig { | ||
pub fn builder() -> ExtendedOpenAIConfigBuilder { | ||
ExtendedOpenAIConfigBuilder::default() | ||
} | ||
|
||
pub fn mistral_fields_to_remove() -> Vec<OpenAIRequestFieldEnum> { | ||
vec![ | ||
OpenAIRequestFieldEnum::PresencePenalty, | ||
OpenAIRequestFieldEnum::User, | ||
] | ||
} | ||
|
||
fn process_request( | ||
&self, | ||
mut request: CreateChatCompletionRequest, | ||
) -> CreateChatCompletionRequest { | ||
request.model = self.model_name.clone(); | ||
|
||
for field in &self.fields_to_remove { | ||
match field { | ||
OpenAIRequestFieldEnum::PresencePenalty => { | ||
request.presence_penalty = None; | ||
} | ||
OpenAIRequestFieldEnum::User => { | ||
request.user = None; | ||
} | ||
} | ||
} | ||
|
||
request | ||
} | ||
} | ||
|
||
impl async_openai::config::Config for ExtendedOpenAIConfig { | ||
fn headers(&self) -> reqwest::header::HeaderMap { | ||
self.base.headers() | ||
} | ||
|
||
fn url(&self, path: &str) -> String { | ||
self.base.url(path) | ||
} | ||
|
||
fn query(&self) -> Vec<(&str, &str)> { | ||
self.base.query() | ||
} | ||
|
||
fn api_base(&self) -> &str { | ||
self.base.api_base() | ||
} | ||
|
||
fn api_key(&self) -> &secrecy::Secret<String> { | ||
self.base.api_key() | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl ChatCompletionStream for async_openai::Client<ExtendedOpenAIConfig> { | ||
async fn chat( | ||
&self, | ||
request: CreateChatCompletionRequest, | ||
) -> Result<CreateChatCompletionResponse, OpenAIError> { | ||
let request = self.config().process_request(request); | ||
self.chat().create(request).await | ||
} | ||
|
||
async fn chat_stream( | ||
&self, | ||
request: CreateChatCompletionRequest, | ||
) -> Result<ChatCompletionResponseStream, OpenAIError> { | ||
let request = self.config().process_request(request); | ||
eprintln!("Creating chat stream: {:?}", request); | ||
self.chat().create_stream(request).await | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters