Skip to content

Commit

Permalink
feat: add support for OpenAI completion API (TabbyML#2604)
Browse files Browse the repository at this point in the history
* feat: add support for OpenAI completion endpoint

* add openai/completion example in model configuration documentation
  • Loading branch information
Syst3m1cAn0maly authored Jul 10, 2024
1 parent dcc91d1 commit 510a63c
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
10 changes: 10 additions & 0 deletions crates/http-api-bindings/src/completion/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod llama;
mod mistral;
mod openai;

use std::sync::Arc;

use llama::LlamaCppEngine;
use mistral::MistralFIMEngine;
use openai::OpenAICompletionEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::CompletionStream;

Expand All @@ -24,6 +26,14 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
);
Arc::new(engine)
}
"openai/completion" => {
let engine = OpenAICompletionEngine::create(
model.model_name.clone(),
&model.api_endpoint,
model.api_key.clone(),
);
Arc::new(engine)
}

unsupported_kind => panic!(
"Unsupported model kind for http completion: {}",
Expand Down
92 changes: 92 additions & 0 deletions crates/http-api-bindings/src/completion/openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use async_stream::stream;
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

pub struct OpenAICompletionEngine {
client: reqwest::Client,
model_name: String,
api_endpoint: String,
api_key: Option<String>,
}

impl OpenAICompletionEngine {
pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
let model_name = model_name.unwrap();
let client = reqwest::Client::new();

Self {
client,
model_name,
api_endpoint: format!("{}/completions", api_endpoint),
api_key,
}
}
}

#[derive(Serialize)]
struct CompletionRequest {
model: String,
prompt: String,
max_tokens: i32,
temperature: f32,
stream: bool,
presence_penalty: f32,
}

#[derive(Deserialize)]
struct CompletionResponseChunk {
choices: Vec<CompletionResponseChoice>,
}

#[derive(Deserialize)]
struct CompletionResponseChoice {
text: String,
finish_reason: Option<String>,
}

#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let request = CompletionRequest {
model: self.model_name.clone(),
prompt: prompt.to_owned(),
max_tokens: options.max_decoding_tokens,
temperature: options.sampling_temperature,
stream: true,
presence_penalty: options.presence_penalty,
};

let mut request = self.client.post(&self.api_endpoint).json(&request);
if let Some(api_key) = &self.api_key {
request = request.bearer_auth(api_key);
}

let s = stream! {
let mut es = EventSource::new(request).expect("Failed to create event source");
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(message)) => {
let x: CompletionResponseChunk = serde_json::from_str(&message.data).expect("Failed to parse response");
if let Some(choice) = x.choices.first() {
yield choice.text.clone();

if choice.finish_reason.is_some() {
break;
}
}
}
Err(_) => {
// StreamEnd
break;
}
}
}
};

Box::pin(s)
}
}
12 changes: 12 additions & 0 deletions website/docs/administration/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ api_endpoint = "https://api.mistral.ai"
api_key = "secret-api-key"
```

#### [openai completion](https://platform.openai.com/docs/api-reference/completions)

Configure Tabby with an OpenAI-compatible completion model (`/v1/completions`) using an online service or a self-hosted backend (vLLM, Nvidia NIM, LocalAI, ...) as follows:

```toml
[model.completion.http]
kind = "openai/completion"
model_name = "your_model"
api_endpoint = "https://url_to_your_backend_or_service"
api_key = "secret-api-key"
```

### Chat Model

Chat models adhere to the standard interface specified by OpenAI's `/chat/completions` API.
Expand Down

0 comments on commit 510a63c

Please sign in to comment.